def test_usage_on(dummy_exporter, enabling_toggle):
    _reload_feast()
    from feast.feature_store import FeatureStore

    with tempfile.TemporaryDirectory() as temp_dir:
        test_feature_store = FeatureStore(
            config=RepoConfig(
                registry=os.path.join(temp_dir, "registry.db"),
                project="fake_project",
                provider="local",
                online_store=SqliteOnlineStoreConfig(
                    path=os.path.join(temp_dir, "online.db")
                ),
            )
        )
        entity = Entity(
            name="driver_car_id",
            description="Car driver id",
            value_type=ValueType.STRING,
            tags={"team": "matchmaking"},
        )

        test_feature_store.apply([entity])

        assert len(dummy_exporter) == 3
        assert {
            "entrypoint": "feast.infra.local.LocalRegistryStore.get_registry_proto"
        }.items() <= dummy_exporter[0].items()
        assert {
            "entrypoint": "feast.infra.local.LocalRegistryStore.update_registry_proto"
        }.items() <= dummy_exporter[1].items()
        assert {
            "entrypoint": "feast.feature_store.FeatureStore.apply"
        }.items() <= dummy_exporter[2].items()
def construct_test_environment(
    test_repo_config: IntegrationTestRepoConfig,
    test_suite_name: str = "integration_test",
) -> Environment:
    project = f"{test_suite_name}_{str(uuid.uuid4()).replace('-', '')[:8]}"

    offline_creator: DataSourceCreator = test_repo_config.offline_store_creator(project)

    offline_store_config = offline_creator.create_offline_store_config()
    online_store = test_repo_config.online_store

    with tempfile.TemporaryDirectory() as repo_dir_name:
        config = RepoConfig(
            registry=str(Path(repo_dir_name) / "registry.db"),
            project=project,
            provider=test_repo_config.provider,
            offline_store=offline_store_config,
            online_store=online_store,
            repo_path=repo_dir_name,
        )
        fs = FeatureStore(config=config)
        # We need to initialize the registry, because if nothing is applied in the test before tearing down
        # the feature store, that will cause the teardown method to blow up.
        fs.registry._initialize_registry()
        environment = Environment(
            name=project,
            test_repo_config=test_repo_config,
            feature_store=fs,
            data_source_creator=offline_creator,
        )

        try:
            yield environment
        finally:
            fs.teardown()
def test_telemetry_on():
    old_environ = dict(os.environ)
    test_telemetry_id = str(uuid.uuid4())
    os.environ["FEAST_FORCE_TELEMETRY_UUID"] = test_telemetry_id
    os.environ["FEAST_IS_TELEMETRY_TEST"] = "True"
    os.environ["FEAST_TELEMETRY"] = "True"

    with tempfile.TemporaryDirectory() as temp_dir:
        test_feature_store = FeatureStore(
            config=RepoConfig(
                registry=os.path.join(temp_dir, "registry.db"),
                project="fake_project",
                provider="local",
                online_store=SqliteOnlineStoreConfig(
                    path=os.path.join(temp_dir, "online.db")
                ),
            )
        )
        entity = Entity(
            name="driver_car_id",
            description="Car driver id",
            value_type=ValueType.STRING,
            labels={"team": "matchmaking"},
        )

        test_feature_store.apply([entity])

        os.environ.clear()
        os.environ.update(old_environ)
        ensure_bigquery_telemetry_id_with_retry(test_telemetry_id)
def make_feature_store_yaml(project, test_repo_config, repo_dir_name: Path):
    offline_creator: DataSourceCreator = test_repo_config.offline_store_creator(project)

    offline_store_config = offline_creator.create_offline_store_config()
    online_store = test_repo_config.online_store

    config = RepoConfig(
        registry=str(Path(repo_dir_name) / "registry.db"),
        project=project,
        provider=test_repo_config.provider,
        offline_store=offline_store_config,
        online_store=online_store,
        repo_path=str(Path(repo_dir_name)),
    )
    config_dict = config.dict()
    if (
        isinstance(config_dict["online_store"], dict)
        and "redis_type" in config_dict["online_store"]
    ):
        if str(config_dict["online_store"]["redis_type"]) == "RedisType.redis_cluster":
            config_dict["online_store"]["redis_type"] = "redis_cluster"
        elif str(config_dict["online_store"]["redis_type"]) == "RedisType.redis":
            config_dict["online_store"]["redis_type"] = "redis"
    config_dict["repo_path"] = str(config_dict["repo_path"])
    return yaml.safe_dump(config_dict)
def test_usage_off():
    old_environ = dict(os.environ)
    test_usage_id = str(uuid.uuid4())
    os.environ["FEAST_IS_USAGE_TEST"] = "True"
    os.environ["FEAST_USAGE"] = "False"
    os.environ["FEAST_FORCE_USAGE_UUID"] = test_usage_id

    with tempfile.TemporaryDirectory() as temp_dir:
        test_feature_store = FeatureStore(
            config=RepoConfig(
                registry=os.path.join(temp_dir, "registry.db"),
                project="fake_project",
                provider="local",
                online_store=SqliteOnlineStoreConfig(
                    path=os.path.join(temp_dir, "online.db")
                ),
            )
        )
        entity = Entity(
            name="driver_car_id",
            description="Car driver id",
            value_type=ValueType.STRING,
            labels={"team": "matchmaking"},
        )
        test_feature_store.apply([entity])

        os.environ.clear()
        os.environ.update(old_environ)
        sleep(30)
        rows = read_bigquery_usage_id(test_usage_id)
        assert rows.total_rows == 0
def test_update_feature_views_with_inferred_features():
    file_source = FileSource(name="test", path="test path")
    entity1 = Entity(name="test1", join_keys=["test_column_1"])
    entity2 = Entity(name="test2", join_keys=["test_column_2"])
    feature_view_1 = FeatureView(
        name="test1",
        entities=[entity1],
        schema=[
            Field(name="feature", dtype=Float32),
            Field(name="test_column_1", dtype=String),
        ],
        source=file_source,
    )
    feature_view_2 = FeatureView(
        name="test2",
        entities=[entity1, entity2],
        schema=[
            Field(name="feature", dtype=Float32),
            Field(name="test_column_1", dtype=String),
            Field(name="test_column_2", dtype=String),
        ],
        source=file_source,
    )

    assert len(feature_view_1.schema) == 2
    assert len(feature_view_1.features) == 2

    # The entity field should be deleted from the schema and features of the feature view.
    update_feature_views_with_inferred_features(
        [feature_view_1], [entity1], RepoConfig(provider="local", project="test")
    )
    assert len(feature_view_1.schema) == 1
    assert len(feature_view_1.features) == 1

    assert len(feature_view_2.schema) == 3
    assert len(feature_view_2.features) == 3

    # The entity fields should be deleted from the schema and features of the feature view.
    update_feature_views_with_inferred_features(
        [feature_view_2],
        [entity1, entity2],
        RepoConfig(provider="local", project="test"),
    )
    assert len(feature_view_2.schema) == 1
    assert len(feature_view_2.features) == 1
def test_update_entities_with_inferred_types_from_feature_views(
    simple_dataset_1, simple_dataset_2
):
    with prep_file_source(
        df=simple_dataset_1, event_timestamp_column="ts_1"
    ) as file_source, prep_file_source(
        df=simple_dataset_2, event_timestamp_column="ts_1"
    ) as file_source_2:
        fv1 = FeatureView(
            name="fv1", entities=["id"], batch_source=file_source, ttl=None,
        )
        fv2 = FeatureView(
            name="fv2", entities=["id"], batch_source=file_source_2, ttl=None,
        )

        actual_1 = Entity(name="id", join_key="id_join_key")
        actual_2 = Entity(name="id", join_key="id_join_key")

        update_entities_with_inferred_types_from_feature_views(
            [actual_1], [fv1], RepoConfig(provider="local", project="test")
        )
        update_entities_with_inferred_types_from_feature_views(
            [actual_2], [fv2], RepoConfig(provider="local", project="test")
        )
        assert actual_1 == Entity(
            name="id", join_key="id_join_key", value_type=ValueType.INT64
        )
        assert actual_2 == Entity(
            name="id", join_key="id_join_key", value_type=ValueType.STRING
        )

        with pytest.raises(RegistryInferenceFailure):
            # two viable data types
            update_entities_with_inferred_types_from_feature_views(
                [Entity(name="id", join_key="id_join_key")],
                [fv1, fv2],
                RepoConfig(provider="local", project="test"),
            )
def test_update_data_sources_with_inferred_event_timestamp_col(simple_dataset_1):
    df_with_two_viable_timestamp_cols = simple_dataset_1.copy(deep=True)
    df_with_two_viable_timestamp_cols["ts_2"] = simple_dataset_1["ts_1"]

    with prep_file_source(df=simple_dataset_1) as file_source:
        data_sources = [
            file_source,
            simple_bq_source_using_table_ref_arg(simple_dataset_1),
            simple_bq_source_using_query_arg(simple_dataset_1),
        ]
        update_data_sources_with_inferred_event_timestamp_col(
            data_sources, RepoConfig(provider="local", project="test")
        )
        actual_event_timestamp_cols = [
            source.event_timestamp_column for source in data_sources
        ]

        assert actual_event_timestamp_cols == ["ts_1", "ts_1", "ts_1"]

    with prep_file_source(df=df_with_two_viable_timestamp_cols) as file_source:
        with pytest.raises(RegistryInferenceFailure):
            # two viable event_timestamp_columns
            update_data_sources_with_inferred_event_timestamp_col(
                [file_source], RepoConfig(provider="local", project="test")
            )
def test_update_data_sources_with_inferred_event_timestamp_col(universal_data_sources):
    (_, _, data_sources) = universal_data_sources
    data_sources_copy = deepcopy(data_sources)

    # remove defined timestamp_field to allow for inference
    for data_source in data_sources_copy.values():
        data_source.timestamp_field = None
        data_source.event_timestamp_column = None

    update_data_sources_with_inferred_event_timestamp_col(
        data_sources_copy.values(), RepoConfig(provider="local", project="test"),
    )
    actual_event_timestamp_cols = [
        source.timestamp_field for source in data_sources_copy.values()
    ]

    assert actual_event_timestamp_cols == ["event_timestamp"] * len(
        data_sources_copy.values()
    )
def test_usage_off(dummy_exporter, enabling_toggle):
    enabling_toggle.__bool__.return_value = False

    _reload_feast()
    from feast.feature_store import FeatureStore

    with tempfile.TemporaryDirectory() as temp_dir:
        test_feature_store = FeatureStore(
            config=RepoConfig(
                registry=os.path.join(temp_dir, "registry.db"),
                project="fake_project",
                provider="local",
                online_store=SqliteOnlineStoreConfig(
                    path=os.path.join(temp_dir, "online.db")
                ),
            )
        )
        entity = Entity(
            name="driver_car_id",
            description="Car driver id",
            value_type=ValueType.STRING,
            tags={"team": "matchmaking"},
        )
        test_feature_store.apply([entity])

        assert not dummy_exporter
def test_historical_features_from_bigquery_sources(
    provider_type, infer_event_timestamp_col, capsys
):
    start_date = datetime.now().replace(microsecond=0, second=0, minute=0)
    (
        customer_entities,
        driver_entities,
        end_date,
        orders_df,
        start_date,
    ) = generate_entities(start_date, infer_event_timestamp_col)

    bigquery_dataset = (
        f"test_hist_retrieval_{int(time.time_ns())}_{random.randint(1000, 9999)}"
    )

    with BigQueryDataSet(bigquery_dataset), TemporaryDirectory() as temp_dir:
        gcp_project = bigquery.Client().project

        # Orders Query
        table_id = f"{bigquery_dataset}.orders"
        stage_orders_bigquery(orders_df, table_id)
        entity_df_query = f"SELECT * FROM {gcp_project}.{table_id}"

        # Driver Feature View
        driver_df = driver_data.create_driver_hourly_stats_df(
            driver_entities, start_date, end_date
        )
        driver_table_id = f"{gcp_project}.{bigquery_dataset}.driver_hourly"
        stage_driver_hourly_stats_bigquery_source(driver_df, driver_table_id)
        driver_source = BigQuerySource(
            table_ref=driver_table_id,
            event_timestamp_column="datetime",
            created_timestamp_column="created",
        )
        driver_fv = create_driver_hourly_stats_feature_view(driver_source)

        # Customer Feature View
        customer_df = driver_data.create_customer_daily_profile_df(
            customer_entities, start_date, end_date
        )
        customer_table_id = f"{gcp_project}.{bigquery_dataset}.customer_profile"
        stage_customer_daily_profile_bigquery_source(customer_df, customer_table_id)
        customer_source = BigQuerySource(
            table_ref=customer_table_id,
            event_timestamp_column="datetime",
            created_timestamp_column="",
        )
        customer_fv = create_customer_daily_profile_feature_view(customer_source)

        driver = Entity(name="driver", join_key="driver_id", value_type=ValueType.INT64)
        customer = Entity(name="customer_id", value_type=ValueType.INT64)

        if provider_type == "local":
            store = FeatureStore(
                config=RepoConfig(
                    registry=os.path.join(temp_dir, "registry.db"),
                    project="default",
                    provider="local",
                    online_store=SqliteOnlineStoreConfig(
                        path=os.path.join(temp_dir, "online_store.db"),
                    ),
                    offline_store=BigQueryOfflineStoreConfig(
                        type="bigquery", dataset=bigquery_dataset
                    ),
                )
            )
        elif provider_type == "gcp":
            store = FeatureStore(
                config=RepoConfig(
                    registry=os.path.join(temp_dir, "registry.db"),
                    project="".join(
                        random.choices(string.ascii_uppercase + string.digits, k=10)
                    ),
                    provider="gcp",
                    offline_store=BigQueryOfflineStoreConfig(
                        type="bigquery", dataset=bigquery_dataset
                    ),
                )
            )
        elif provider_type == "gcp_custom_offline_config":
            store = FeatureStore(
                config=RepoConfig(
                    registry=os.path.join(temp_dir, "registry.db"),
                    project="".join(
                        random.choices(string.ascii_uppercase + string.digits, k=10)
                    ),
                    provider="gcp",
                    offline_store=BigQueryOfflineStoreConfig(
                        type="bigquery", dataset="foo"
                    ),
                )
            )
        else:
            raise Exception("Invalid provider used as part of test configuration")

        store.apply([driver, customer, driver_fv, customer_fv])

        event_timestamp = (
            DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
            if DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL in orders_df.columns
            else "e_ts"
        )
        expected_df = get_expected_training_df(
            customer_df, customer_fv, driver_df, driver_fv, orders_df, event_timestamp,
        )

        job_from_sql = store.get_historical_features(
            entity_df=entity_df_query,
            feature_refs=[
                "driver_stats:conv_rate",
                "driver_stats:avg_daily_trips",
                "customer_profile:current_balance",
                "customer_profile:avg_passenger_count",
                "customer_profile:lifetime_trip_count",
            ],
        )

        start_time = datetime.utcnow()
        actual_df_from_sql_entities = job_from_sql.to_df()
        end_time = datetime.utcnow()
        with capsys.disabled():
            print(
                str(
                    f"\nTime to execute job_from_sql.to_df() = '{(end_time - start_time)}'"
                )
            )

        assert sorted(expected_df.columns) == sorted(
            actual_df_from_sql_entities.columns
        )
        assert_frame_equal(
            expected_df.sort_values(
                by=[event_timestamp, "order_id", "driver_id", "customer_id"]
            ).reset_index(drop=True),
            actual_df_from_sql_entities[expected_df.columns]
            .sort_values(by=[event_timestamp, "order_id", "driver_id", "customer_id"])
            .reset_index(drop=True),
            check_dtype=False,
        )

        table_from_sql_entities = job_from_sql.to_arrow()
        assert_frame_equal(
            actual_df_from_sql_entities, table_from_sql_entities.to_pandas()
        )

        timestamp_column = (
            "e_ts"
            if infer_event_timestamp_col
            else DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
        )

        entity_df_query_with_invalid_join_key = (
            f"select order_id, driver_id, customer_id as customer, "
            f"order_is_success, {timestamp_column}, FROM {gcp_project}.{table_id}"
        )
        # Rename the join key; this should now raise an error.
        assertpy.assert_that(store.get_historical_features).raises(
            errors.FeastEntityDFMissingColumnsError
        ).when_called_with(
            entity_df=entity_df_query_with_invalid_join_key,
            feature_refs=[
                "driver_stats:conv_rate",
                "driver_stats:avg_daily_trips",
                "customer_profile:current_balance",
                "customer_profile:avg_passenger_count",
                "customer_profile:lifetime_trip_count",
            ],
        )

        job_from_df = store.get_historical_features(
            entity_df=orders_df,
            feature_refs=[
                "driver_stats:conv_rate",
                "driver_stats:avg_daily_trips",
                "customer_profile:current_balance",
                "customer_profile:avg_passenger_count",
                "customer_profile:lifetime_trip_count",
            ],
        )

        # Rename the join key; this should now raise an error.
        orders_df_with_invalid_join_key = orders_df.rename(
            {"customer_id": "customer"}, axis="columns"
        )
        assertpy.assert_that(store.get_historical_features).raises(
            errors.FeastEntityDFMissingColumnsError
        ).when_called_with(
            entity_df=orders_df_with_invalid_join_key,
            feature_refs=[
                "driver_stats:conv_rate",
                "driver_stats:avg_daily_trips",
                "customer_profile:current_balance",
                "customer_profile:avg_passenger_count",
                "customer_profile:lifetime_trip_count",
            ],
        )

        # Make sure that custom dataset name is being used from the offline_store config
        if provider_type == "gcp_custom_offline_config":
            assertpy.assert_that(job_from_df.query).contains("foo.entity_df")
        else:
            assertpy.assert_that(job_from_df.query).contains(
                f"{bigquery_dataset}.entity_df"
            )

        start_time = datetime.utcnow()
        actual_df_from_df_entities = job_from_df.to_df()
        end_time = datetime.utcnow()
        with capsys.disabled():
            print(
                str(
                    f"Time to execute job_from_df.to_df() = '{(end_time - start_time)}'\n"
                )
            )

        assert sorted(expected_df.columns) == sorted(actual_df_from_df_entities.columns)
        assert_frame_equal(
            expected_df.sort_values(
                by=[event_timestamp, "order_id", "driver_id", "customer_id"]
            ).reset_index(drop=True),
            actual_df_from_df_entities[expected_df.columns]
            .sort_values(by=[event_timestamp, "order_id", "driver_id", "customer_id"])
            .reset_index(drop=True),
            check_dtype=False,
        )
        table_from_df_entities = job_from_df.to_arrow()
        assert_frame_equal(
            actual_df_from_df_entities, table_from_df_entities.to_pandas()
        )
def test_historical_features_from_parquet_sources(infer_event_timestamp_col):
    start_date = datetime.now().replace(microsecond=0, second=0, minute=0)
    (
        customer_entities,
        driver_entities,
        end_date,
        orders_df,
        start_date,
    ) = generate_entities(start_date, infer_event_timestamp_col)

    with TemporaryDirectory() as temp_dir:
        driver_df = driver_data.create_driver_hourly_stats_df(
            driver_entities, start_date, end_date
        )
        driver_source = stage_driver_hourly_stats_parquet_source(temp_dir, driver_df)
        driver_fv = create_driver_hourly_stats_feature_view(driver_source)
        customer_df = driver_data.create_customer_daily_profile_df(
            customer_entities, start_date, end_date
        )
        customer_source = stage_customer_daily_profile_parquet_source(
            temp_dir, customer_df
        )
        customer_fv = create_customer_daily_profile_feature_view(customer_source)
        driver = Entity(name="driver", join_key="driver_id", value_type=ValueType.INT64)
        customer = Entity(name="customer_id", value_type=ValueType.INT64)

        store = FeatureStore(
            config=RepoConfig(
                registry=os.path.join(temp_dir, "registry.db"),
                project="default",
                provider="local",
                online_store=SqliteOnlineStoreConfig(
                    path=os.path.join(temp_dir, "online_store.db")
                ),
            )
        )

        store.apply([driver, customer, driver_fv, customer_fv])

        job = store.get_historical_features(
            entity_df=orders_df,
            feature_refs=[
                "driver_stats:conv_rate",
                "driver_stats:avg_daily_trips",
                "customer_profile:current_balance",
                "customer_profile:avg_passenger_count",
                "customer_profile:lifetime_trip_count",
            ],
        )

        actual_df = job.to_df()
        event_timestamp = (
            DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
            if DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL in orders_df.columns
            else "e_ts"
        )
        expected_df = get_expected_training_df(
            customer_df, customer_fv, driver_df, driver_fv, orders_df, event_timestamp,
        )
        assert_frame_equal(
            expected_df.sort_values(
                by=[event_timestamp, "order_id", "driver_id", "customer_id"]
            ).reset_index(drop=True),
            actual_df.sort_values(
                by=[event_timestamp, "order_id", "driver_id", "customer_id"]
            ).reset_index(drop=True),
        )
def test_online() -> None:
    """
    Test reading from the online store in local mode.
    """
    runner = CliRunner()
    with runner.local_repo(get_example_repo("example_feature_repo_1.py")) as store:
        # Write some data to two tables
        driver_locations_fv = store.get_feature_view(name="driver_locations")
        customer_profile_fv = store.get_feature_view(name="customer_profile")
        customer_driver_combined_fv = store.get_feature_view(
            name="customer_driver_combined"
        )

        provider = store._get_provider()

        driver_key = EntityKeyProto(
            join_keys=["driver"], entity_values=[ValueProto(int64_val=1)]
        )
        provider.online_write_batch(
            project=store.config.project,
            table=driver_locations_fv,
            data=[
                (
                    driver_key,
                    {
                        "lat": ValueProto(double_val=0.1),
                        "lon": ValueProto(string_val="1.0"),
                    },
                    datetime.utcnow(),
                    datetime.utcnow(),
                )
            ],
            progress=None,
        )

        customer_key = EntityKeyProto(
            join_keys=["customer"], entity_values=[ValueProto(int64_val=5)]
        )
        provider.online_write_batch(
            project=store.config.project,
            table=customer_profile_fv,
            data=[
                (
                    customer_key,
                    {
                        "avg_orders_day": ValueProto(float_val=1.0),
                        "name": ValueProto(string_val="John"),
                        "age": ValueProto(int64_val=3),
                    },
                    datetime.utcnow(),
                    datetime.utcnow(),
                )
            ],
            progress=None,
        )

        customer_key = EntityKeyProto(
            join_keys=["customer", "driver"],
            entity_values=[ValueProto(int64_val=5), ValueProto(int64_val=1)],
        )
        provider.online_write_batch(
            project=store.config.project,
            table=customer_driver_combined_fv,
            data=[
                (
                    customer_key,
                    {"trips": ValueProto(int64_val=7)},
                    datetime.utcnow(),
                    datetime.utcnow(),
                )
            ],
            progress=None,
        )

        # Retrieve two features using two keys, one valid one non-existing
        result = store.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[
                {"driver": 1, "customer": 5},
                {"driver": 1, "customer": 5},
            ],
        ).to_dict()

        assert "driver_locations__lon" in result
        assert "customer_profile__avg_orders_day" in result
        assert "customer_profile__name" in result

        assert result["driver"] == [1, 1]
        assert result["customer"] == [5, 5]
        assert result["driver_locations__lon"] == ["1.0", "1.0"]
        assert result["customer_profile__avg_orders_day"] == [1.0, 1.0]
        assert result["customer_profile__name"] == ["John", "John"]
        assert result["customer_driver_combined__trips"] == [7, 7]

        # Ensure features are still in result when keys not found
        result = store.get_online_features(
            feature_refs=["customer_driver_combined:trips"],
            entity_rows=[{"driver": 0, "customer": 0}],
        ).to_dict()

        assert "customer_driver_combined__trips" in result

        # invalid table reference
        with pytest.raises(FeatureViewNotFoundException):
            store.get_online_features(
                feature_refs=["driver_locations_bad:lon"],
                entity_rows=[{"driver": 1}],
            )

        # Create new FeatureStore object with fast cache invalidation
        cache_ttl = 1
        fs_fast_ttl = FeatureStore(
            config=RepoConfig(
                registry=RegistryConfig(
                    path=store.config.registry, cache_ttl_seconds=cache_ttl
                ),
                online_store=store.config.online_store,
                project=store.config.project,
                provider=store.config.provider,
            )
        )

        # Should download the registry and cache it permanently (or until manually refreshed)
        result = fs_fast_ttl.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{"driver": 1, "customer": 5}],
        ).to_dict()
        assert result["driver_locations__lon"] == ["1.0"]
        assert result["customer_driver_combined__trips"] == [7]

        # Rename the registry.db so that it cant be used for refreshes
        os.rename(store.config.registry, store.config.registry + "_fake")

        # Wait for registry to expire
        time.sleep(cache_ttl)

        # Will try to reload registry because it has expired (it will fail because we deleted the actual registry file)
        with pytest.raises(FileNotFoundError):
            fs_fast_ttl.get_online_features(
                feature_refs=[
                    "driver_locations:lon",
                    "customer_profile:avg_orders_day",
                    "customer_profile:name",
                    "customer_driver_combined:trips",
                ],
                entity_rows=[{"driver": 1, "customer": 5}],
            ).to_dict()

        # Restore registry.db so that we can see if it actually reloads registry
        os.rename(store.config.registry + "_fake", store.config.registry)

        # Test if registry is actually reloaded and whether results return
        result = fs_fast_ttl.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{"driver": 1, "customer": 5}],
        ).to_dict()
        assert result["driver_locations__lon"] == ["1.0"]
        assert result["customer_driver_combined__trips"] == [7]

        # Create a registry with infinite cache (for users that want to manually refresh the registry)
        fs_infinite_ttl = FeatureStore(
            config=RepoConfig(
                registry=RegistryConfig(
                    path=store.config.registry, cache_ttl_seconds=0
                ),
                online_store=store.config.online_store,
                project=store.config.project,
                provider=store.config.provider,
            )
        )

        # Should return results (and fill the registry cache)
        result = fs_infinite_ttl.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{"driver": 1, "customer": 5}],
        ).to_dict()
        assert result["driver_locations__lon"] == ["1.0"]
        assert result["customer_driver_combined__trips"] == [7]

        # Wait a bit so that an arbitrary TTL would take effect
        time.sleep(2)

        # Rename the registry.db so that it cant be used for refreshes
        os.rename(store.config.registry, store.config.registry + "_fake")

        # TTL is infinite so this method should use registry cache
        result = fs_infinite_ttl.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{"driver": 1, "customer": 5}],
        ).to_dict()
        assert result["driver_locations__lon"] == ["1.0"]
        assert result["customer_driver_combined__trips"] == [7]

        # Force registry reload (should fail because file is missing)
        with pytest.raises(FileNotFoundError):
            fs_infinite_ttl.refresh_registry()

        # Restore registry.db so that teardown works
        os.rename(store.config.registry + "_fake", store.config.registry)
def construct_test_environment(
    test_repo_config: TestRepoConfig,
    create_and_apply: bool = False,
    materialize: bool = False,
) -> Environment:
    """
    This method should take in the parameters from the test repo config, create a feature repo,
    apply it, and return the constructed feature store object to callers.

    This feature store object can be interacted with for the purposes of tests.
    The user is *not* expected to perform any clean up actions.

    :param test_repo_config: configuration
    :return: A feature store built using the supplied configuration.
    """
    df = create_dataset()

    project = f"test_correctness_{str(uuid.uuid4()).replace('-', '')[:8]}"

    module_name, config_class_name = test_repo_config.offline_store_creator.rsplit(
        ".", 1
    )

    offline_creator: DataSourceCreator = importer.get_class_from_type(
        module_name, config_class_name, "DataSourceCreator"
    )(project)
    ds = offline_creator.create_data_source(
        project, df, field_mapping={"ts_1": "ts", "id": "driver_id"}
    )
    offline_store = offline_creator.create_offline_store_config()
    online_store = test_repo_config.online_store

    with tempfile.TemporaryDirectory() as repo_dir_name:
        config = RepoConfig(
            registry=str(Path(repo_dir_name) / "registry.db"),
            project=project,
            provider=test_repo_config.provider,
            offline_store=offline_store,
            online_store=online_store,
            repo_path=repo_dir_name,
        )
        fs = FeatureStore(config=config)
        environment = Environment(
            name=project,
            test_repo_config=test_repo_config,
            feature_store=fs,
            data_source=ds,
            data_source_creator=offline_creator,
        )

        fvs = []
        entities = []
        try:
            if create_and_apply:
                entities.extend([driver(), customer()])
                fvs.extend(
                    [
                        environment.driver_stats_feature_view(),
                        environment.customer_feature_view(),
                    ]
                )
                fs.apply(fvs + entities)

            if materialize:
                fs.materialize(environment.start_date, environment.end_date)

            yield environment
        finally:
            offline_creator.teardown()
            fs.teardown()
def construct_test_environment(
    test_repo_config: IntegrationTestRepoConfig,
    test_suite_name: str = "integration_test",
) -> Environment:
    _uuid = str(uuid.uuid4()).replace("-", "")[:8]

    run_id = os.getenv("GITHUB_RUN_ID", default=None)
    run_id = f"gh_run_{run_id}_{_uuid}" if run_id else _uuid
    run_num = os.getenv("GITHUB_RUN_NUMBER", default=1)

    project = f"{test_suite_name}_{run_id}_{run_num}"

    offline_creator: DataSourceCreator = test_repo_config.offline_store_creator(project)

    offline_store_config = offline_creator.create_offline_store_config()
    online_store = test_repo_config.online_store

    repo_dir_name = tempfile.mkdtemp()

    if test_repo_config.python_feature_server:
        from feast.infra.feature_servers.aws_lambda.config import (
            AwsLambdaFeatureServerConfig,
        )

        feature_server = AwsLambdaFeatureServerConfig(
            enabled=True,
            execution_role_name="arn:aws:iam::402087665549:role/lambda_execution_role",
        )

        registry = f"s3://feast-integration-tests/registries/{project}/registry.db"
    else:
        feature_server = None
        registry = str(Path(repo_dir_name) / "registry.db")

    config = RepoConfig(
        registry=registry,
        project=project,
        provider=test_repo_config.provider,
        offline_store=offline_store_config,
        online_store=online_store,
        repo_path=repo_dir_name,
        feature_server=feature_server,
    )

    # Create feature_store.yaml out of the config
    with open(Path(repo_dir_name) / "feature_store.yaml", "w") as f:
        yaml.safe_dump(json.loads(config.json()), f)

    fs = FeatureStore(repo_dir_name)
    # We need to initialize the registry, because if nothing is applied in the test before tearing down
    # the feature store, that will cause the teardown method to blow up.
    fs.registry._initialize_registry()

    environment = Environment(
        name=project,
        test_repo_config=test_repo_config,
        feature_store=fs,
        data_source_creator=offline_creator,
        python_feature_server=test_repo_config.python_feature_server,
    )

    return environment
def test_historical_features_from_bigquery_sources_containing_backfills(capsys):
    now = datetime.now().replace(microsecond=0, second=0, minute=0)
    tomorrow = now + timedelta(days=1)

    entity_dataframe = pd.DataFrame(
        data=[
            {"driver_id": 1001, "event_timestamp": now + timedelta(days=2)},
            {"driver_id": 1002, "event_timestamp": now + timedelta(days=2)},
        ]
    )

    driver_stats_df = pd.DataFrame(
        data=[
            # Duplicated rows simple case
            {
                "driver_id": 1001,
                "avg_daily_trips": 10,
                "event_timestamp": now,
                "created": tomorrow,
            },
            {
                "driver_id": 1001,
                "avg_daily_trips": 20,
                "event_timestamp": tomorrow,
                "created": tomorrow,
            },
            # Duplicated rows after a backfill
            {
                "driver_id": 1002,
                "avg_daily_trips": 30,
                "event_timestamp": now,
                "created": tomorrow,
            },
            {
                "driver_id": 1002,
                "avg_daily_trips": 40,
                "event_timestamp": tomorrow,
                "created": now,
            },
        ]
    )

    expected_df = pd.DataFrame(
        data=[
            {
                "driver_id": 1001,
                "event_timestamp": now + timedelta(days=2),
                "avg_daily_trips": 20,
            },
            {
                "driver_id": 1002,
                "event_timestamp": now + timedelta(days=2),
                "avg_daily_trips": 40,
            },
        ]
    )

    bigquery_dataset = (
        f"test_hist_retrieval_{int(time.time_ns())}_{random.randint(1000, 9999)}"
    )

    with BigQueryDataSet(bigquery_dataset), TemporaryDirectory() as temp_dir:
        gcp_project = bigquery.Client().project

        # Entity Dataframe SQL query
        table_id = f"{bigquery_dataset}.orders"
        stage_orders_bigquery(entity_dataframe, table_id)
        entity_df_query = f"SELECT * FROM {gcp_project}.{table_id}"

        # Driver Feature View
        driver_table_id = f"{gcp_project}.{bigquery_dataset}.driver_hourly"
        stage_driver_hourly_stats_bigquery_source(driver_stats_df, driver_table_id)

        store = FeatureStore(
            config=RepoConfig(
                registry=os.path.join(temp_dir, "registry.db"),
                project="".join(
                    random.choices(string.ascii_uppercase + string.digits, k=10)
                ),
                provider="gcp",
                offline_store=BigQueryOfflineStoreConfig(
                    type="bigquery", dataset=bigquery_dataset
                ),
            )
        )

        driver = Entity(name="driver", join_key="driver_id", value_type=ValueType.INT64)

        driver_fv = FeatureView(
            name="driver_stats",
            entities=["driver"],
            features=[Feature(name="avg_daily_trips", dtype=ValueType.INT32)],
            batch_source=BigQuerySource(
                table_ref=driver_table_id,
                event_timestamp_column="event_timestamp",
                created_timestamp_column="created",
            ),
            ttl=None,
        )

        store.apply([driver, driver_fv])

        try:
            job_from_sql = store.get_historical_features(
                entity_df=entity_df_query,
                features=["driver_stats:avg_daily_trips"],
                full_feature_names=False,
            )

            start_time = datetime.utcnow()
            actual_df_from_sql_entities = job_from_sql.to_df()
            end_time = datetime.utcnow()
            with capsys.disabled():
                print(
                    str(
                        f"\nTime to execute job_from_sql.to_df() = '{(end_time - start_time)}'"
                    )
                )

            assert sorted(expected_df.columns) == sorted(
                actual_df_from_sql_entities.columns
            )
            assert_frame_equal(
                expected_df.sort_values(by=["driver_id"]).reset_index(drop=True),
                actual_df_from_sql_entities[expected_df.columns]
                .sort_values(by=["driver_id"])
                .reset_index(drop=True),
                check_dtype=False,
            )
        finally:
            store.teardown()
def test_historical_features_from_redshift_sources(
    provider_type, infer_event_timestamp_col, capsys, full_feature_names
):
    client = aws_utils.get_redshift_data_client("us-west-2")
    s3 = aws_utils.get_s3_resource("us-west-2")

    offline_store = RedshiftOfflineStoreConfig(
        cluster_id="feast-integration-tests",
        region="us-west-2",
        user="******",
        database="feast",
        s3_staging_location="s3://feast-integration-tests/redshift/tests/ingestion",
        iam_role="arn:aws:iam::402087665549:role/redshift_s3_access_role",
    )

    start_date = datetime.now().replace(microsecond=0, second=0, minute=0)
    (
        customer_entities,
        driver_entities,
        end_date,
        orders_df,
        start_date,
    ) = generate_entities(start_date, infer_event_timestamp_col)

    redshift_table_prefix = (
        f"test_hist_retrieval_{int(time.time_ns())}_{random.randint(1000, 9999)}"
    )

    # Stage orders_df to Redshift
    table_name = f"{redshift_table_prefix}_orders"
    entity_df_query = f"SELECT * FROM {table_name}"
    orders_context = aws_utils.temporarily_upload_df_to_redshift(
        client,
        offline_store.cluster_id,
        offline_store.database,
        offline_store.user,
        s3,
        f"{offline_store.s3_staging_location}/copy/{table_name}.parquet",
        offline_store.iam_role,
        table_name,
        orders_df,
    )

    # Stage driver_df to Redshift
    driver_df = driver_data.create_driver_hourly_stats_df(
        driver_entities, start_date, end_date
    )
    driver_table_name = f"{redshift_table_prefix}_driver_hourly"
    driver_context = aws_utils.temporarily_upload_df_to_redshift(
        client,
        offline_store.cluster_id,
        offline_store.database,
        offline_store.user,
        s3,
        f"{offline_store.s3_staging_location}/copy/{driver_table_name}.parquet",
        offline_store.iam_role,
        driver_table_name,
        driver_df,
    )

    # Stage customer_df to Redshift
    customer_df = driver_data.create_customer_daily_profile_df(
        customer_entities, start_date, end_date
    )
    customer_table_name = f"{redshift_table_prefix}_customer_profile"
    customer_context = aws_utils.temporarily_upload_df_to_redshift(
        client,
        offline_store.cluster_id,
        offline_store.database,
        offline_store.user,
        s3,
        f"{offline_store.s3_staging_location}/copy/{customer_table_name}.parquet",
        offline_store.iam_role,
        customer_table_name,
        customer_df,
    )

    with orders_context, driver_context, customer_context, TemporaryDirectory() as temp_dir:
        driver_source = RedshiftSource(
            table=driver_table_name,
            event_timestamp_column="event_timestamp",
            created_timestamp_column="created",
        )
        driver_fv = create_driver_hourly_stats_feature_view(driver_source)

        customer_source = RedshiftSource(
            table=customer_table_name,
            event_timestamp_column="event_timestamp",
            created_timestamp_column="created",
        )
        customer_fv = create_customer_daily_profile_feature_view(customer_source)

        driver = Entity(name="driver", join_key="driver_id", value_type=ValueType.INT64)
        customer = Entity(name="customer_id", value_type=ValueType.INT64)

        if provider_type == "local":
            store = FeatureStore(
                config=RepoConfig(
                    registry=os.path.join(temp_dir, "registry.db"),
                    project="default",
                    provider="local",
                    online_store=SqliteOnlineStoreConfig(
                        path=os.path.join(temp_dir, "online_store.db"),
                    ),
                    offline_store=offline_store,
                )
            )
        elif provider_type == "aws":
            store = FeatureStore(
                config=RepoConfig(
                    registry=os.path.join(temp_dir, "registry.db"),
                    project="".join(
                        random.choices(string.ascii_uppercase + string.digits, k=10)
                    ),
                    provider="aws",
                    online_store=DynamoDBOnlineStoreConfig(region="us-west-2"),
                    offline_store=offline_store,
                )
            )
        else:
            raise Exception("Invalid provider used as part of test configuration")

        store.apply([driver, customer, driver_fv, customer_fv])

        try:
            event_timestamp = (
                DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
                if DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL in orders_df.columns
                else "e_ts"
            )
            expected_df = get_expected_training_df(
                customer_df,
                customer_fv,
                driver_df,
                driver_fv,
                orders_df,
                event_timestamp,
                full_feature_names,
            )

            job_from_sql = store.get_historical_features(
                entity_df=entity_df_query,
                features=[
                    "driver_stats:conv_rate",
                    "driver_stats:avg_daily_trips",
                    "customer_profile:current_balance",
                    "customer_profile:avg_passenger_count",
                    "customer_profile:lifetime_trip_count",
                ],
                full_feature_names=full_feature_names,
            )

            start_time = datetime.utcnow()
            actual_df_from_sql_entities = job_from_sql.to_df()
            end_time = datetime.utcnow()
            with capsys.disabled():
                print(
                    str(
                        f"\nTime to execute job_from_sql.to_df() = '{(end_time - start_time)}'"
                    )
                )

            assert sorted(expected_df.columns) == sorted(
                actual_df_from_sql_entities.columns
            )
            assert_frame_equal(
                expected_df.sort_values(
                    by=[event_timestamp, "order_id", "driver_id", "customer_id"]
                ).reset_index(drop=True),
                actual_df_from_sql_entities[expected_df.columns]
                .sort_values(
                    by=[event_timestamp, "order_id", "driver_id", "customer_id"]
                )
                .reset_index(drop=True),
                check_dtype=False,
            )

            table_from_sql_entities = job_from_sql.to_arrow()
            assert_frame_equal(
                actual_df_from_sql_entities.sort_values(
                    by=[event_timestamp, "order_id", "driver_id", "customer_id"]
                ).reset_index(drop=True),
                table_from_sql_entities.to_pandas()
                .sort_values(
                    by=[event_timestamp, "order_id", "driver_id", "customer_id"]
                )
                .reset_index(drop=True),
            )

            timestamp_column = (
                "e_ts"
                if infer_event_timestamp_col
                else DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
            )

            entity_df_query_with_invalid_join_key = (
                f"select order_id, driver_id, customer_id as customer, "
                f"order_is_success, {timestamp_column} FROM {table_name}"
            )
            # Rename the join key; this should now raise an error.
            assertpy.assert_that(
                store.get_historical_features(
                    entity_df=entity_df_query_with_invalid_join_key,
                    features=[
                        "driver_stats:conv_rate",
                        "driver_stats:avg_daily_trips",
                        "customer_profile:current_balance",
                        "customer_profile:avg_passenger_count",
                        "customer_profile:lifetime_trip_count",
                    ],
                ).to_df
            ).raises(errors.FeastEntityDFMissingColumnsError).when_called_with()

            job_from_df = store.get_historical_features(
                entity_df=orders_df,
                features=[
                    "driver_stats:conv_rate",
                    "driver_stats:avg_daily_trips",
                    "customer_profile:current_balance",
                    "customer_profile:avg_passenger_count",
                    "customer_profile:lifetime_trip_count",
                ],
                full_feature_names=full_feature_names,
            )

            # Rename the join key; this should now raise an error.
            orders_df_with_invalid_join_key = orders_df.rename(
                {"customer_id": "customer"}, axis="columns"
            )
            assertpy.assert_that(
                store.get_historical_features(
                    entity_df=orders_df_with_invalid_join_key,
                    features=[
                        "driver_stats:conv_rate",
                        "driver_stats:avg_daily_trips",
                        "customer_profile:current_balance",
                        "customer_profile:avg_passenger_count",
                        "customer_profile:lifetime_trip_count",
                    ],
                ).to_df
            ).raises(errors.FeastEntityDFMissingColumnsError).when_called_with()

            start_time = datetime.utcnow()
            actual_df_from_df_entities = job_from_df.to_df()
            end_time = datetime.utcnow()
            with capsys.disabled():
                print(
                    str(
                        f"Time to execute job_from_df.to_df() = '{(end_time - start_time)}'\n"
                    )
                )

            assert sorted(expected_df.columns) == sorted(
                actual_df_from_df_entities.columns
            )
            assert_frame_equal(
                expected_df.sort_values(
                    by=[event_timestamp, "order_id", "driver_id", "customer_id"]
                ).reset_index(drop=True),
                actual_df_from_df_entities[expected_df.columns]
                .sort_values(
                    by=[event_timestamp, "order_id", "driver_id", "customer_id"]
                )
                .reset_index(drop=True),
                check_dtype=False,
            )

            table_from_df_entities = job_from_df.to_arrow()
            assert_frame_equal(
                actual_df_from_df_entities.sort_values(
                    by=[event_timestamp, "order_id", "driver_id", "customer_id"]
                ).reset_index(drop=True),
                table_from_df_entities.to_pandas()
                .sort_values(
                    by=[event_timestamp, "order_id", "driver_id", "customer_id"]
                )
                .reset_index(drop=True),
            )
        finally:
            store.teardown()