@@ -148,6 +148,40 @@ def search_nearest_entities_mock():
148148 yield search_nearest_entities_mock
149149
150150
151+ @pytest .fixture
152+ def transport_mock ():
153+ with mock .patch (
154+ "google.cloud.aiplatform_v1.services.feature_online_store_service.transports.grpc.FeatureOnlineStoreServiceGrpcTransport"
155+ ) as transport :
156+ transport .return_value = mock .MagicMock (autospec = True )
157+ yield transport
158+
159+
160+ @pytest .fixture
161+ def grpc_insecure_channel_mock ():
162+ import grpc
163+
164+ with mock .patch .object (grpc , "insecure_channel" , autospec = True ) as channel :
165+ channel .return_value = mock .MagicMock (autospec = True )
166+ yield channel
167+
168+
169+ @pytest .fixture
170+ def client_mock ():
171+ with mock .patch (
172+ "google.cloud.aiplatform_v1.services.feature_online_store_service.FeatureOnlineStoreServiceClient"
173+ ) as client_mock :
174+ yield client_mock
175+
176+
177+ @pytest .fixture
178+ def utils_client_with_override_mock ():
179+ with mock .patch (
180+ "google.cloud.aiplatform.utils.FeatureOnlineStoreClientWithOverride"
181+ ) as client_mock :
182+ yield client_mock
183+
184+
151185def fv_eq (
152186 fv_to_check : FeatureView ,
153187 name : str ,
@@ -428,6 +462,308 @@ def test_fetch_feature_values_optimized_no_endpoint(
428462 FeatureView (_TEST_OPTIMIZED_FV2_PATH ).read (key = ["key1" ]).to_dict ()
429463
430464
465+ def test_ffv_optimized_psc_with_no_connection_options_raises_error (
466+ get_psc_optimized_fos_mock ,
467+ get_optimized_fv_mock ,
468+ ):
469+ with pytest .raises (ValueError ) as excinfo :
470+ FeatureView (_TEST_OPTIMIZED_FV1_PATH ).read (key = ["key1" ])
471+
472+ assert str (excinfo .value ) == (
473+ "Use `connection_options` to specify an IP address. Required for optimized online store with private service connect."
474+ )
475+
476+
477+ def test_ffv_optimized_psc_with_no_connection_transport_raises_error (
478+ get_psc_optimized_fos_mock ,
479+ get_optimized_fv_mock ,
480+ ):
481+ with pytest .raises (ValueError ) as excinfo :
482+ FeatureView (_TEST_OPTIMIZED_FV1_PATH ).read (
483+ key = ["key1" ],
484+ connection_options = fs_utils .ConnectionOptions (
485+ host = "1.2.3.4" , transport = None
486+ ),
487+ )
488+
489+ assert str (excinfo .value ) == (
490+ "Unsupported connection transport type, got transport: None"
491+ )
492+
493+
494+ def test_ffv_optimized_psc_with_bad_connection_transport_raises_error (
495+ get_psc_optimized_fos_mock ,
496+ get_optimized_fv_mock ,
497+ ):
498+ with pytest .raises (ValueError ) as excinfo :
499+ FeatureView (_TEST_OPTIMIZED_FV1_PATH ).read (
500+ key = ["key1" ],
501+ connection_options = fs_utils .ConnectionOptions (
502+ host = "1.2.3.4" , transport = "hi"
503+ ),
504+ )
505+
506+ assert str (excinfo .value ) == (
507+ "Unsupported connection transport type, got transport: hi"
508+ )
509+
510+
511+ @pytest .mark .parametrize ("output_type" , ["dict" , "proto" ])
512+ def test_ffv_optimized_psc (
513+ get_psc_optimized_fos_mock ,
514+ get_optimized_fv_mock ,
515+ transport_mock ,
516+ grpc_insecure_channel_mock ,
517+ fetch_feature_values_mock ,
518+ output_type ,
519+ ):
520+ rsp = FeatureView (_TEST_OPTIMIZED_FV1_PATH ).read (
521+ key = ["key1" ],
522+ connection_options = fs_utils .ConnectionOptions (
523+ host = "1.2.3.4" ,
524+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
525+ ),
526+ )
527+
528+ # Ensure that we create and use insecure channel to the target.
529+ grpc_insecure_channel_mock .assert_called_once_with ("1.2.3.4:10002" )
530+ transport_grpc_channel = transport_mock .call_args .kwargs ["channel" ]
531+ assert transport_grpc_channel == grpc_insecure_channel_mock .return_value
532+
533+ if output_type == "dict" :
534+ assert rsp .to_dict () == {
535+ "features" : [{"name" : "key1" , "value" : {"string_value" : "value1" }}]
536+ }
537+ elif output_type == "proto" :
538+ assert rsp .to_proto () == _TEST_FV_FETCH1
539+
540+
541+ def test_same_connection_options_are_equal ():
542+ opt1 = fs_utils .ConnectionOptions (
543+ host = "1.1.1.1" ,
544+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
545+ )
546+ opt2 = fs_utils .ConnectionOptions (
547+ host = "1.1.1.1" ,
548+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
549+ )
550+ assert opt1 == opt2
551+
552+
553+ def test_different_host_in_connection_options_are_not_equal ():
554+ opt1 = fs_utils .ConnectionOptions (
555+ host = "1.1.1.2" ,
556+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
557+ )
558+ opt2 = fs_utils .ConnectionOptions (
559+ host = "1.1.1.1" ,
560+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
561+ )
562+
563+ assert opt1 != opt2
564+
565+
566+ def test_bad_transport_in_compared_connection_options_raises_error ():
567+ opt1 = fs_utils .ConnectionOptions (
568+ host = "1.1.1.1" ,
569+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
570+ )
571+ opt2 = fs_utils .ConnectionOptions (
572+ host = "1.1.1.1" ,
573+ transport = None ,
574+ )
575+
576+ with pytest .raises (ValueError ) as excinfo :
577+ assert opt1 != opt2
578+
579+ assert str (excinfo .value ) == (
580+ "Transport 'ConnectionOptions.InsecureGrpcChannel()' cannot be compared to transport 'None'."
581+ )
582+
583+
584+ def test_bad_transport_in_connection_options_raises_error ():
585+ opt1 = fs_utils .ConnectionOptions (
586+ host = "1.1.1.1" ,
587+ transport = None ,
588+ )
589+ opt2 = fs_utils .ConnectionOptions (
590+ host = "1.1.1.1" ,
591+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
592+ )
593+
594+ with pytest .raises (ValueError ) as excinfo :
595+ assert opt1 != opt2
596+
597+ assert str (excinfo .value ) == ("Unsupported transport supplied: None" )
598+
599+
600+ def test_same_connection_options_have_same_hash ():
601+ opt1 = fs_utils .ConnectionOptions (
602+ host = "1.1.1.1" ,
603+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
604+ )
605+ opt2 = fs_utils .ConnectionOptions (
606+ host = "1.1.1.1" ,
607+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
608+ )
609+
610+ d = {}
611+ d [opt1 ] = "hi"
612+ assert d [opt2 ] == "hi"
613+
614+
615+ @pytest .mark .parametrize (
616+ "hosts" ,
617+ [
618+ ("1.1.1.1" , "1.1.1.2" ),
619+ ("1.1.1.2" , "1.1.1.1" ),
620+ ("10.0.0.1" , "9.9.9.9" ),
621+ ],
622+ )
623+ def test_different_host_in_connection_options_have_different_hash (hosts ):
624+ opt1 = fs_utils .ConnectionOptions (
625+ host = hosts [0 ],
626+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
627+ )
628+ opt2 = fs_utils .ConnectionOptions (
629+ host = hosts [1 ],
630+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
631+ )
632+
633+ d = {}
634+ d [opt1 ] = "hi"
635+ assert opt2 not in d
636+
637+
638+ @pytest .mark .parametrize (
639+ "transports" ,
640+ [
641+ (fs_utils .ConnectionOptions .InsecureGrpcChannel (), None ),
642+ (None , fs_utils .ConnectionOptions .InsecureGrpcChannel ()),
643+ (None , "hi" ),
644+ ("hi" , None ),
645+ ],
646+ )
647+ def test_bad_transport_in_connection_options_have_different_hash (transports ):
648+ opt1 = fs_utils .ConnectionOptions (
649+ host = "1.1.1.1" ,
650+ transport = transports [0 ],
651+ )
652+ opt2 = fs_utils .ConnectionOptions (
653+ host = "1.1.1.1" ,
654+ transport = transports [1 ],
655+ )
656+
657+ d = {}
658+ d [opt1 ] = "hi"
659+ assert opt2 not in d
660+
661+
662+ def test_diff_host_and_bad_transport_in_connection_options_have_different_hash ():
663+ opt1 = fs_utils .ConnectionOptions (
664+ host = "1.1.1.1" ,
665+ transport = None ,
666+ )
667+ opt2 = fs_utils .ConnectionOptions (
668+ host = "9.9.9.9" ,
669+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
670+ )
671+
672+ d = {}
673+ d [opt1 ] = "hi"
674+ assert opt2 not in d
675+
676+
677+ def test_ffv_optimized_psc_reuse_client_for_same_connection_options_in_same_ffv (
678+ get_psc_optimized_fos_mock ,
679+ get_optimized_fv_mock ,
680+ client_mock ,
681+ transport_mock ,
682+ grpc_insecure_channel_mock ,
683+ fetch_feature_values_mock ,
684+ ):
685+ fv = FeatureView (_TEST_OPTIMIZED_FV1_PATH )
686+ fv .read (
687+ key = ["key1" ],
688+ connection_options = fs_utils .ConnectionOptions (
689+ host = "1.1.1.1" ,
690+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
691+ ),
692+ )
693+ fv .read (
694+ key = ["key2" ],
695+ connection_options = fs_utils .ConnectionOptions (
696+ host = "1.1.1.1" ,
697+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
698+ ),
699+ )
700+
701+ # Insecure channel and transport creation should only be done once.
702+ assert grpc_insecure_channel_mock .call_args_list == [mock .call ("1.1.1.1:10002" )]
703+ assert transport_mock .call_args_list == [
704+ mock .call (channel = grpc_insecure_channel_mock .return_value ),
705+ ]
706+
707+
708+ def test_ffv_optimized_psc_different_client_for_different_connection_options (
709+ get_psc_optimized_fos_mock ,
710+ get_optimized_fv_mock ,
711+ client_mock ,
712+ transport_mock ,
713+ grpc_insecure_channel_mock ,
714+ fetch_feature_values_mock ,
715+ ):
716+ # Return two different grpc channels each time insecure channel is called.
717+ import grpc
718+
719+ grpc_chan1 = mock .MagicMock (spec = grpc .Channel )
720+ grpc_chan2 = mock .MagicMock (spec = grpc .Channel )
721+ grpc_insecure_channel_mock .side_effect = [grpc_chan1 , grpc_chan2 ]
722+
723+ fv = FeatureView (_TEST_OPTIMIZED_FV1_PATH )
724+ fv .read (
725+ key = ["key1" ],
726+ connection_options = fs_utils .ConnectionOptions (
727+ host = "1.1.1.1" ,
728+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
729+ ),
730+ )
731+ fv .read (
732+ key = ["key2" ],
733+ connection_options = fs_utils .ConnectionOptions (
734+ host = "1.2.3.4" ,
735+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
736+ ),
737+ )
738+
739+ # Insecure channel and transport creation should be done twice - one for each different connection.
740+ assert grpc_insecure_channel_mock .call_args_list == [
741+ mock .call ("1.1.1.1:10002" ),
742+ mock .call ("1.2.3.4:10002" ),
743+ ]
744+ assert transport_mock .call_args_list == [
745+ mock .call (channel = grpc_chan1 ),
746+ mock .call (channel = grpc_chan2 ),
747+ ]
748+
749+
750+ def test_ffv_optimized_psc_bad_gapic_client_raises_error (
751+ get_psc_optimized_fos_mock , get_optimized_fv_mock , utils_client_with_override_mock
752+ ):
753+ with pytest .raises (ValueError ) as excinfo :
754+ FeatureView (_TEST_OPTIMIZED_FV1_PATH ).read (
755+ key = ["key1" ],
756+ connection_options = fs_utils .ConnectionOptions (
757+ host = "1.1.1.1" ,
758+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
759+ ),
760+ )
761+
762+ assert str (excinfo .value ) == (
763+ f"Unexpected gapic class '{ utils_client_with_override_mock .get_gapic_client_class .return_value } ' used by internal client."
764+ )
765+
766+
431767@pytest .mark .parametrize ("output_type" , ["dict" , "proto" ])
432768def test_search_nearest_entities (
433769 get_esf_optimized_fos_mock ,
0 commit comments