Пример #1
0
def test_usage_on(dummy_exporter, enabling_toggle):
    """Applying an entity with usage enabled should export three usage events.

    The events captured by ``dummy_exporter`` are expected, in order, to carry
    the entrypoints for reading the registry, updating the registry and
    ``FeatureStore.apply``.
    """
    _reload_feast()
    # Imported after _reload_feast(); presumably so the freshly reloaded
    # module (with the patched toggle/exporter) is the one in use.
    from feast.feature_store import FeatureStore

    expected_entrypoints = (
        "feast.infra.local.LocalRegistryStore.get_registry_proto",
        "feast.infra.local.LocalRegistryStore.update_registry_proto",
        "feast.feature_store.FeatureStore.apply",
    )

    with tempfile.TemporaryDirectory() as workdir:
        repo_config = RepoConfig(
            registry=os.path.join(workdir, "registry.db"),
            project="fake_project",
            provider="local",
            online_store=SqliteOnlineStoreConfig(
                path=os.path.join(workdir, "online.db")),
        )
        store = FeatureStore(config=repo_config)
        driver_car = Entity(
            name="driver_car_id",
            description="Car driver id",
            value_type=ValueType.STRING,
            tags={"team": "matchmaking"},
        )

        store.apply([driver_car])

        assert len(dummy_exporter) == len(expected_entrypoints)
        # Each exported event must contain the expected entrypoint entry;
        # other keys in the event are ignored by the subset comparison.
        for position, entrypoint in enumerate(expected_entrypoints):
            wanted = {"entrypoint": entrypoint}
            assert wanted.items() <= dummy_exporter[position].items()
Пример #2
0
def construct_test_environment(
    test_repo_config: IntegrationTestRepoConfig,
    test_suite_name: str = "integration_test",
) -> Environment:
    """Yield an ``Environment`` backed by a temporary feature repository.

    This function is a generator: it yields exactly one ``Environment`` and
    tears the feature store down once the consumer is finished. It is
    presumably wrapped by a context-manager or fixture decorator that is not
    visible here.

    :param test_repo_config: supplies the provider, the online store and the
        offline store creator used to build the repo config.
    :param test_suite_name: prefix of the generated, unique project name.
    """
    unique_suffix = str(uuid.uuid4()).replace("-", "")[:8]
    project = f"{test_suite_name}_{unique_suffix}"

    offline_creator: DataSourceCreator = test_repo_config.offline_store_creator(project)
    offline_cfg = offline_creator.create_offline_store_config()
    online_cfg = test_repo_config.online_store

    with tempfile.TemporaryDirectory() as repo_dir_name:
        fs = FeatureStore(
            config=RepoConfig(
                registry=str(Path(repo_dir_name) / "registry.db"),
                project=project,
                provider=test_repo_config.provider,
                offline_store=offline_cfg,
                online_store=online_cfg,
                repo_path=repo_dir_name,
            )
        )
        # Initialize the registry up front: if a test applies nothing before
        # the feature store is torn down, teardown would otherwise blow up.
        fs.registry._initialize_registry()

        try:
            yield Environment(
                name=project,
                test_repo_config=test_repo_config,
                feature_store=fs,
                data_source_creator=offline_creator,
            )
        finally:
            fs.teardown()
Пример #3
0
def test_telemetry_on():
    """Apply an entity with telemetry enabled and check that it was reported.

    A known telemetry UUID is forced through environment variables, a single
    entity is applied to a local feature store, and the UUID is then handed
    to ``ensure_bigquery_telemetry_id_with_retry`` (presumably a lookup of
    the reported event in BigQuery).

    The process environment is restored in a ``finally`` clause so that a
    failure while building or applying the feature store cannot leak the
    ``FEAST_*`` variables into other tests.
    """
    old_environ = dict(os.environ)
    test_telemetry_id = str(uuid.uuid4())

    with tempfile.TemporaryDirectory() as temp_dir:
        try:
            os.environ["FEAST_FORCE_TELEMETRY_UUID"] = test_telemetry_id
            os.environ["FEAST_IS_TELEMETRY_TEST"] = "True"
            os.environ["FEAST_TELEMETRY"] = "True"

            test_feature_store = FeatureStore(
                config=RepoConfig(
                    registry=os.path.join(temp_dir, "registry.db"),
                    project="fake_project",
                    provider="local",
                    online_store=SqliteOnlineStoreConfig(
                        path=os.path.join(temp_dir, "online.db")
                    ),
                )
            )
            entity = Entity(
                name="driver_car_id",
                description="Car driver id",
                value_type=ValueType.STRING,
                labels={"team": "matchmaking"},
            )

            test_feature_store.apply([entity])
        finally:
            # Always put the original environment back, even on failure.
            os.environ.clear()
            os.environ.update(old_environ)

        # Runs with the original environment restored, as before.
        ensure_bigquery_telemetry_id_with_retry(test_telemetry_id)
Пример #4
0
def make_feature_store_yaml(project, test_repo_config, repo_dir_name: Path):
    """Render a ``RepoConfig`` for the given test setup as a YAML string.

    :param project: project name placed in the config; also passed to the
        offline store creator.
    :param test_repo_config: supplies the provider, the online store and the
        offline store creator.
    :param repo_dir_name: repository directory; ``registry.db`` is placed
        inside it and it is recorded as ``repo_path``.
    :return: the config dictionary dumped with ``yaml.safe_dump``.
    """
    offline_creator: DataSourceCreator = test_repo_config.offline_store_creator(
        project)
    offline_cfg = offline_creator.create_offline_store_config()
    online_cfg = test_repo_config.online_store

    repo_config = RepoConfig(
        registry=str(Path(repo_dir_name) / "registry.db"),
        project=project,
        provider=test_repo_config.provider,
        offline_store=offline_cfg,
        online_store=online_cfg,
        repo_path=str(Path(repo_dir_name)),
    )
    serializable = repo_config.dict()

    # Replace the string form of the redis type enum by its plain name so the
    # dumped YAML carries "redis" / "redis_cluster" rather than "RedisType.*".
    online_section = serializable["online_store"]
    if isinstance(online_section, dict) and "redis_type" in online_section:
        plain_names = {
            "RedisType.redis_cluster": "redis_cluster",
            "RedisType.redis": "redis",
        }
        enum_text = str(online_section["redis_type"])
        if enum_text in plain_names:
            online_section["redis_type"] = plain_names[enum_text]

    serializable["repo_path"] = str(serializable["repo_path"])
    return yaml.safe_dump(serializable)
Пример #5
0
def test_usage_off():
    """Apply an entity with usage reporting disabled and expect no rows.

    ``FEAST_USAGE`` is set to ``"False"`` and a known usage UUID is forced
    through the environment. After applying one entity to a local feature
    store, ``read_bigquery_usage_id`` must report zero rows for that UUID.

    The process environment is restored in a ``finally`` clause so that a
    failure while building or applying the feature store cannot leak the
    ``FEAST_*`` variables into other tests.
    """
    old_environ = dict(os.environ)
    test_usage_id = str(uuid.uuid4())

    with tempfile.TemporaryDirectory() as temp_dir:
        try:
            os.environ["FEAST_IS_USAGE_TEST"] = "True"
            os.environ["FEAST_USAGE"] = "False"
            os.environ["FEAST_FORCE_USAGE_UUID"] = test_usage_id

            test_feature_store = FeatureStore(
                config=RepoConfig(
                    registry=os.path.join(temp_dir, "registry.db"),
                    project="fake_project",
                    provider="local",
                    online_store=SqliteOnlineStoreConfig(
                        path=os.path.join(temp_dir, "online.db")
                    ),
                )
            )
            entity = Entity(
                name="driver_car_id",
                description="Car driver id",
                value_type=ValueType.STRING,
                labels={"team": "matchmaking"},
            )
            test_feature_store.apply([entity])
        finally:
            # Always put the original environment back, even on failure.
            os.environ.clear()
            os.environ.update(old_environ)

        # Presumably gives any (unexpected) usage event time to arrive before
        # the lookup, so that "no rows" is meaningful.
        sleep(30)
        rows = read_bigquery_usage_id(test_usage_id)
        assert rows.total_rows == 0
Пример #6
0
def test_update_feature_views_with_inferred_features():
    """Entity join-key fields must be removed from a view's schema/features.

    Two feature views are declared with their entity columns listed in the
    schema. After ``update_feature_views_with_inferred_features`` runs, only
    the real feature field is expected to remain in each view.
    """
    source = FileSource(name="test", path="test path")
    first_entity = Entity(name="test1", join_keys=["test_column_1"])
    second_entity = Entity(name="test2", join_keys=["test_column_2"])

    one_entity_view = FeatureView(
        name="test1",
        entities=[first_entity],
        schema=[
            Field(name="feature", dtype=Float32),
            Field(name="test_column_1", dtype=String),
        ],
        source=source,
    )
    two_entity_view = FeatureView(
        name="test2",
        entities=[first_entity, second_entity],
        schema=[
            Field(name="feature", dtype=Float32),
            Field(name="test_column_1", dtype=String),
            Field(name="test_column_2", dtype=String),
        ],
        source=source,
    )

    # Single-entity view: one feature plus one entity column before the update.
    assert len(one_entity_view.schema) == 2
    assert len(one_entity_view.features) == 2

    update_feature_views_with_inferred_features(
        [one_entity_view],
        [first_entity],
        RepoConfig(provider="local", project="test"),
    )
    # The entity column is gone; only the feature remains.
    assert len(one_entity_view.schema) == 1
    assert len(one_entity_view.features) == 1

    # Two-entity view: one feature plus two entity columns before the update.
    assert len(two_entity_view.schema) == 3
    assert len(two_entity_view.features) == 3

    update_feature_views_with_inferred_features(
        [two_entity_view],
        [first_entity, second_entity],
        RepoConfig(provider="local", project="test"),
    )
    # Both entity columns are gone; only the feature remains.
    assert len(two_entity_view.schema) == 1
    assert len(two_entity_view.features) == 1
Пример #7
0
def test_update_entities_with_inferred_types_from_feature_views(
        simple_dataset_1, simple_dataset_2):
    """Entity value types should be inferred from the feature view sources.

    An entity paired with the view over ``simple_dataset_1`` is expected to
    become ``INT64``, one paired with the view over ``simple_dataset_2`` is
    expected to become ``STRING``, and pairing a single entity with both
    views must raise ``RegistryInferenceFailure``.
    """
    with prep_file_source(df=simple_dataset_1,
                          event_timestamp_column="ts_1") as first_source:
        with prep_file_source(df=simple_dataset_2,
                              event_timestamp_column="ts_1") as second_source:
            first_view = FeatureView(
                name="fv1",
                entities=["id"],
                batch_source=first_source,
                ttl=None,
            )
            second_view = FeatureView(
                name="fv2",
                entities=["id"],
                batch_source=second_source,
                ttl=None,
            )

            entity_for_first = Entity(name="id", join_key="id_join_key")
            entity_for_second = Entity(name="id", join_key="id_join_key")

            update_entities_with_inferred_types_from_feature_views(
                [entity_for_first],
                [first_view],
                RepoConfig(provider="local", project="test"),
            )
            update_entities_with_inferred_types_from_feature_views(
                [entity_for_second],
                [second_view],
                RepoConfig(provider="local", project="test"),
            )

            expected_first = Entity(
                name="id",
                join_key="id_join_key",
                value_type=ValueType.INT64,
            )
            assert entity_for_first == expected_first

            expected_second = Entity(
                name="id",
                join_key="id_join_key",
                value_type=ValueType.STRING,
            )
            assert entity_for_second == expected_second

            # With both views there are two viable data types, so the
            # inference has to fail.
            with pytest.raises(RegistryInferenceFailure):
                update_entities_with_inferred_types_from_feature_views(
                    [Entity(name="id", join_key="id_join_key")],
                    [first_view, second_view],
                    RepoConfig(provider="local", project="test"),
                )
Пример #8
0
def test_update_data_sources_with_inferred_event_timestamp_col(
        simple_dataset_1):
    """The event timestamp column should be inferred for every data source.

    For a file source and two BigQuery sources built from ``simple_dataset_1``
    the inferred column is expected to be ``"ts_1"``. A dataset carrying a
    second candidate column (``"ts_2"``) must raise
    ``RegistryInferenceFailure``.
    """
    ambiguous_df = simple_dataset_1.copy(deep=True)
    ambiguous_df["ts_2"] = simple_dataset_1["ts_1"]

    with prep_file_source(df=simple_dataset_1) as file_source:
        sources = [
            file_source,
            simple_bq_source_using_table_ref_arg(simple_dataset_1),
            simple_bq_source_using_query_arg(simple_dataset_1),
        ]
        update_data_sources_with_inferred_event_timestamp_col(
            sources,
            RepoConfig(provider="local", project="test"),
        )

        inferred_columns = [src.event_timestamp_column for src in sources]
        assert inferred_columns == ["ts_1", "ts_1", "ts_1"]

    # Two viable event timestamp columns: the inference has to fail.
    with prep_file_source(df=ambiguous_df) as file_source, \
            pytest.raises(RegistryInferenceFailure):
        update_data_sources_with_inferred_event_timestamp_col(
            [file_source],
            RepoConfig(provider="local", project="test"),
        )
Пример #9
0
def test_update_data_sources_with_inferred_event_timestamp_col(
        universal_data_sources):
    """Timestamp fields cleared on copied sources should be re-inferred.

    Every data source in the fixture is deep-copied, its timestamp settings
    are reset to ``None`` and, after inference, each copy is expected to have
    ``timestamp_field == "event_timestamp"``.
    """
    _first, _second, data_sources = universal_data_sources
    sources_by_name = deepcopy(data_sources)

    # Drop the declared timestamp settings so that inference has work to do.
    for src in sources_by_name.values():
        src.timestamp_field = None
        src.event_timestamp_column = None

    update_data_sources_with_inferred_event_timestamp_col(
        sources_by_name.values(),
        RepoConfig(provider="local", project="test"),
    )

    inferred_fields = [
        src.timestamp_field for src in sources_by_name.values()
    ]
    expected_fields = ["event_timestamp"] * len(sources_by_name.values())
    assert inferred_fields == expected_fields
Пример #10
0
def test_usage_off(dummy_exporter, enabling_toggle):
    """With the usage toggle evaluating to False, nothing may be exported."""
    # Make the toggle falsy before feast is reloaded.
    enabling_toggle.__bool__.return_value = False

    _reload_feast()
    # Imported after _reload_feast(); presumably so the freshly reloaded
    # module (with the patched toggle/exporter) is the one in use.
    from feast.feature_store import FeatureStore

    with tempfile.TemporaryDirectory() as workdir:
        repo_config = RepoConfig(
            registry=os.path.join(workdir, "registry.db"),
            project="fake_project",
            provider="local",
            online_store=SqliteOnlineStoreConfig(
                path=os.path.join(workdir, "online.db")),
        )
        store = FeatureStore(config=repo_config)
        driver_car = Entity(
            name="driver_car_id",
            description="Car driver id",
            value_type=ValueType.STRING,
            tags={"team": "matchmaking"},
        )
        store.apply([driver_car])

        # No usage event should have reached the exporter.
        assert not dummy_exporter
Пример #11
0
def test_historical_features_from_bigquery_sources(provider_type,
                                                   infer_event_timestamp_col,
                                                   capsys):
    """Retrieve historical features from BigQuery sources and check them.

    Driver and customer data are staged into a freshly named BigQuery
    dataset, a feature store is built for the requested ``provider_type``
    and historical features are retrieved twice: once with a SQL query as
    the entity dataframe and once with a pandas dataframe. Both results are
    compared against ``get_expected_training_df``. Entity inputs with a
    renamed join key must raise ``FeastEntityDFMissingColumnsError``.

    :param provider_type: one of ``"local"``, ``"gcp"`` or
        ``"gcp_custom_offline_config"``; anything else raises.
    :param infer_event_timestamp_col: forwarded to ``generate_entities``;
        when truthy the invalid-join-key query selects ``e_ts`` as the
        timestamp column.
    :param capsys: pytest capture fixture, disabled while printing timings.
    """
    start_date = datetime.now().replace(microsecond=0, second=0, minute=0)
    (
        customer_entities,
        driver_entities,
        end_date,
        orders_df,
        start_date,
    ) = generate_entities(start_date, infer_event_timestamp_col)

    # Dataset name built from the current time and a random number.
    bigquery_dataset = (
        f"test_hist_retrieval_{int(time.time_ns())}_{random.randint(1000, 9999)}"
    )

    with BigQueryDataSet(bigquery_dataset), TemporaryDirectory() as temp_dir:
        gcp_project = bigquery.Client().project

        # Orders Query
        table_id = f"{bigquery_dataset}.orders"
        stage_orders_bigquery(orders_df, table_id)
        entity_df_query = f"SELECT * FROM {gcp_project}.{table_id}"

        # Driver Feature View
        driver_df = driver_data.create_driver_hourly_stats_df(
            driver_entities, start_date, end_date)
        driver_table_id = f"{gcp_project}.{bigquery_dataset}.driver_hourly"
        stage_driver_hourly_stats_bigquery_source(driver_df, driver_table_id)
        driver_source = BigQuerySource(
            table_ref=driver_table_id,
            event_timestamp_column="datetime",
            created_timestamp_column="created",
        )
        driver_fv = create_driver_hourly_stats_feature_view(driver_source)

        # Customer Feature View
        customer_df = driver_data.create_customer_daily_profile_df(
            customer_entities, start_date, end_date)
        customer_table_id = f"{gcp_project}.{bigquery_dataset}.customer_profile"

        stage_customer_daily_profile_bigquery_source(customer_df,
                                                     customer_table_id)
        customer_source = BigQuerySource(
            table_ref=customer_table_id,
            event_timestamp_column="datetime",
            created_timestamp_column="",
        )
        customer_fv = create_customer_daily_profile_feature_view(
            customer_source)

        driver = Entity(name="driver",
                        join_key="driver_id",
                        value_type=ValueType.INT64)
        customer = Entity(name="customer_id", value_type=ValueType.INT64)

        # Build the feature store for the provider flavour under test.
        if provider_type == "local":
            store = FeatureStore(config=RepoConfig(
                registry=os.path.join(temp_dir, "registry.db"),
                project="default",
                provider="local",
                online_store=SqliteOnlineStoreConfig(path=os.path.join(
                    temp_dir, "online_store.db"), ),
                offline_store=BigQueryOfflineStoreConfig(
                    type="bigquery", dataset=bigquery_dataset),
            ))
        elif provider_type == "gcp":
            store = FeatureStore(config=RepoConfig(
                registry=os.path.join(temp_dir, "registry.db"),
                project="".join(
                    random.choices(string.ascii_uppercase + string.digits,
                                   k=10)),
                provider="gcp",
                offline_store=BigQueryOfflineStoreConfig(
                    type="bigquery", dataset=bigquery_dataset),
            ))
        elif provider_type == "gcp_custom_offline_config":
            # Same as "gcp" but with a fixed, custom dataset name ("foo").
            store = FeatureStore(config=RepoConfig(
                registry=os.path.join(temp_dir, "registry.db"),
                project="".join(
                    random.choices(string.ascii_uppercase + string.digits,
                                   k=10)),
                provider="gcp",
                offline_store=BigQueryOfflineStoreConfig(type="bigquery",
                                                         dataset="foo"),
            ))
        else:
            raise Exception(
                "Invalid provider used as part of test configuration")

        store.apply([driver, customer, driver_fv, customer_fv])

        # Use the default timestamp column when orders_df has it, else "e_ts".
        event_timestamp = (DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
                           if DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
                           in orders_df.columns else "e_ts")
        expected_df = get_expected_training_df(
            customer_df,
            customer_fv,
            driver_df,
            driver_fv,
            orders_df,
            event_timestamp,
        )

        # First retrieval: the entity dataframe is given as a SQL query.
        job_from_sql = store.get_historical_features(
            entity_df=entity_df_query,
            feature_refs=[
                "driver_stats:conv_rate",
                "driver_stats:avg_daily_trips",
                "customer_profile:current_balance",
                "customer_profile:avg_passenger_count",
                "customer_profile:lifetime_trip_count",
            ],
        )

        # Time the materialisation and print it with output capture disabled.
        start_time = datetime.utcnow()
        actual_df_from_sql_entities = job_from_sql.to_df()
        end_time = datetime.utcnow()
        with capsys.disabled():
            print(
                str(f"\nTime to execute job_from_sql.to_df() = '{(end_time - start_time)}'"
                    ))

        # Compare after sorting rows and aligning column order; dtypes are
        # not checked.
        assert sorted(expected_df.columns) == sorted(
            actual_df_from_sql_entities.columns)
        assert_frame_equal(
            expected_df.sort_values(
                by=[event_timestamp, "order_id", "driver_id", "customer_id"
                    ]).reset_index(drop=True),
            actual_df_from_sql_entities[expected_df.columns].sort_values(
                by=[event_timestamp, "order_id", "driver_id", "customer_id"
                    ]).reset_index(drop=True),
            check_dtype=False,
        )

        # The Arrow result must match the pandas result.
        table_from_sql_entities = job_from_sql.to_arrow()
        assert_frame_equal(actual_df_from_sql_entities,
                           table_from_sql_entities.to_pandas())

        timestamp_column = ("e_ts" if infer_event_timestamp_col else
                            DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL)

        # Query that exposes customer_id under the alias "customer".
        entity_df_query_with_invalid_join_key = (
            f"select order_id, driver_id, customer_id as customer, "
            f"order_is_success, {timestamp_column}, FROM {gcp_project}.{table_id}"
        )
        # Rename the join key; this should now raise an error.
        assertpy.assert_that(store.get_historical_features).raises(
            errors.FeastEntityDFMissingColumnsError).when_called_with(
                entity_df=entity_df_query_with_invalid_join_key,
                feature_refs=[
                    "driver_stats:conv_rate",
                    "driver_stats:avg_daily_trips",
                    "customer_profile:current_balance",
                    "customer_profile:avg_passenger_count",
                    "customer_profile:lifetime_trip_count",
                ],
            )

        # Second retrieval: the entity dataframe is given as a pandas frame.
        job_from_df = store.get_historical_features(
            entity_df=orders_df,
            feature_refs=[
                "driver_stats:conv_rate",
                "driver_stats:avg_daily_trips",
                "customer_profile:current_balance",
                "customer_profile:avg_passenger_count",
                "customer_profile:lifetime_trip_count",
            ],
        )

        # Rename the join key; this should now raise an error.
        orders_df_with_invalid_join_key = orders_df.rename(
            {"customer_id": "customer"}, axis="columns")
        assertpy.assert_that(store.get_historical_features).raises(
            errors.FeastEntityDFMissingColumnsError).when_called_with(
                entity_df=orders_df_with_invalid_join_key,
                feature_refs=[
                    "driver_stats:conv_rate",
                    "driver_stats:avg_daily_trips",
                    "customer_profile:current_balance",
                    "customer_profile:avg_passenger_count",
                    "customer_profile:lifetime_trip_count",
                ],
            )

        # Make sure that custom dataset name is being used from the offline_store config
        if provider_type == "gcp_custom_offline_config":
            assertpy.assert_that(job_from_df.query).contains("foo.entity_df")
        else:
            assertpy.assert_that(
                job_from_df.query).contains(f"{bigquery_dataset}.entity_df")

        # Time the materialisation and print it with output capture disabled.
        start_time = datetime.utcnow()
        actual_df_from_df_entities = job_from_df.to_df()
        end_time = datetime.utcnow()
        with capsys.disabled():
            print(
                str(f"Time to execute job_from_df.to_df() = '{(end_time - start_time)}'\n"
                    ))

        # Same comparison as for the SQL-based retrieval.
        assert sorted(expected_df.columns) == sorted(
            actual_df_from_df_entities.columns)
        assert_frame_equal(
            expected_df.sort_values(
                by=[event_timestamp, "order_id", "driver_id", "customer_id"
                    ]).reset_index(drop=True),
            actual_df_from_df_entities[expected_df.columns].sort_values(
                by=[event_timestamp, "order_id", "driver_id", "customer_id"
                    ]).reset_index(drop=True),
            check_dtype=False,
        )
        # The Arrow result must match the pandas result.
        table_from_df_entities = job_from_df.to_arrow()
        assert_frame_equal(actual_df_from_df_entities,
                           table_from_df_entities.to_pandas())
Пример #12
0
def test_historical_features_from_parquet_sources(infer_event_timestamp_col):
    """Retrieve historical features from parquet sources and check them.

    Driver and customer data are staged as parquet sources in a temporary
    directory, applied to a local feature store and retrieved for
    ``orders_df``. The result must equal ``get_expected_training_df`` once
    both frames are sorted and re-indexed.

    :param infer_event_timestamp_col: forwarded to ``generate_entities``.
    """
    current_hour = datetime.now().replace(minute=0, second=0, microsecond=0)
    (
        customer_entities,
        driver_entities,
        end_date,
        orders_df,
        start_date,
    ) = generate_entities(current_hour, infer_event_timestamp_col)

    with TemporaryDirectory() as workdir:
        driver_df = driver_data.create_driver_hourly_stats_df(
            driver_entities, start_date, end_date)
        driver_fv = create_driver_hourly_stats_feature_view(
            stage_driver_hourly_stats_parquet_source(workdir, driver_df))

        customer_df = driver_data.create_customer_daily_profile_df(
            customer_entities, start_date, end_date)
        customer_fv = create_customer_daily_profile_feature_view(
            stage_customer_daily_profile_parquet_source(workdir, customer_df))

        driver = Entity(
            name="driver",
            join_key="driver_id",
            value_type=ValueType.INT64,
        )
        customer = Entity(name="customer_id", value_type=ValueType.INT64)

        repo_config = RepoConfig(
            registry=os.path.join(workdir, "registry.db"),
            project="default",
            provider="local",
            online_store=SqliteOnlineStoreConfig(
                path=os.path.join(workdir, "online_store.db")),
        )
        store = FeatureStore(config=repo_config)
        store.apply([driver, customer, driver_fv, customer_fv])

        retrieval_job = store.get_historical_features(
            entity_df=orders_df,
            feature_refs=[
                "driver_stats:conv_rate",
                "driver_stats:avg_daily_trips",
                "customer_profile:current_balance",
                "customer_profile:avg_passenger_count",
                "customer_profile:lifetime_trip_count",
            ],
        )
        actual_df = retrieval_job.to_df()

        # Use the default timestamp column when orders_df has it, else "e_ts".
        if DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL in orders_df.columns:
            event_timestamp = DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
        else:
            event_timestamp = "e_ts"

        expected_df = get_expected_training_df(
            customer_df,
            customer_fv,
            driver_df,
            driver_fv,
            orders_df,
            event_timestamp,
        )

        # Row order is not part of the contract: sort both frames the same way.
        sort_columns = [event_timestamp, "order_id", "driver_id", "customer_id"]
        expected_sorted = expected_df.sort_values(
            by=sort_columns).reset_index(drop=True)
        actual_sorted = actual_df.sort_values(
            by=sort_columns).reset_index(drop=True)
        assert_frame_equal(expected_sorted, actual_sorted)
Пример #13
0
def test_online() -> None:
    """
    Test reading from the online store in local mode.

    Rows are written to three feature views through the provider and read
    back with ``get_online_features``. The test then exercises registry
    caching with a 1-second TTL and with ``cache_ttl_seconds=0`` by renaming
    the registry file away and back again.
    """
    runner = CliRunner()
    with runner.local_repo(
            get_example_repo("example_feature_repo_1.py")) as store:
        # Write some data to three feature views

        driver_locations_fv = store.get_feature_view(name="driver_locations")
        customer_profile_fv = store.get_feature_view(name="customer_profile")
        customer_driver_combined_fv = store.get_feature_view(
            name="customer_driver_combined")

        provider = store._get_provider()

        # Row for driver 1 in driver_locations.
        driver_key = EntityKeyProto(join_keys=["driver"],
                                    entity_values=[ValueProto(int64_val=1)])
        provider.online_write_batch(
            project=store.config.project,
            table=driver_locations_fv,
            data=[(
                driver_key,
                {
                    "lat": ValueProto(double_val=0.1),
                    "lon": ValueProto(string_val="1.0"),
                },
                datetime.utcnow(),
                datetime.utcnow(),
            )],
            progress=None,
        )

        # Row for customer 5 in customer_profile.
        customer_key = EntityKeyProto(join_keys=["customer"],
                                      entity_values=[ValueProto(int64_val=5)])
        provider.online_write_batch(
            project=store.config.project,
            table=customer_profile_fv,
            data=[(
                customer_key,
                {
                    "avg_orders_day": ValueProto(float_val=1.0),
                    "name": ValueProto(string_val="John"),
                    "age": ValueProto(int64_val=3),
                },
                datetime.utcnow(),
                datetime.utcnow(),
            )],
            progress=None,
        )

        # Row for the composite key (customer 5, driver 1) in
        # customer_driver_combined; customer_key is reused for it.
        customer_key = EntityKeyProto(
            join_keys=["customer", "driver"],
            entity_values=[ValueProto(int64_val=5),
                           ValueProto(int64_val=1)],
        )
        provider.online_write_batch(
            project=store.config.project,
            table=customer_driver_combined_fv,
            data=[(
                customer_key,
                {
                    "trips": ValueProto(int64_val=7)
                },
                datetime.utcnow(),
                datetime.utcnow(),
            )],
            progress=None,
        )

        # Retrieve four features for two (identical) entity rows
        result = store.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{
                "driver": 1,
                "customer": 5
            }, {
                "driver": 1,
                "customer": 5
            }],
        ).to_dict()

        # Result keys use "<feature view>__<feature>" and hold one value per
        # entity row.
        assert "driver_locations__lon" in result
        assert "customer_profile__avg_orders_day" in result
        assert "customer_profile__name" in result
        assert result["driver"] == [1, 1]
        assert result["customer"] == [5, 5]
        assert result["driver_locations__lon"] == ["1.0", "1.0"]
        assert result["customer_profile__avg_orders_day"] == [1.0, 1.0]
        assert result["customer_profile__name"] == ["John", "John"]
        assert result["customer_driver_combined__trips"] == [7, 7]

        # Ensure features are still in result when keys not found
        result = store.get_online_features(
            feature_refs=["customer_driver_combined:trips"],
            entity_rows=[{
                "driver": 0,
                "customer": 0
            }],
        ).to_dict()

        assert "customer_driver_combined__trips" in result

        # invalid table reference
        with pytest.raises(FeatureViewNotFoundException):
            store.get_online_features(
                feature_refs=["driver_locations_bad:lon"],
                entity_rows=[{
                    "driver": 1
                }],
            )

        # Create new FeatureStore object with fast cache invalidation
        cache_ttl = 1
        fs_fast_ttl = FeatureStore(config=RepoConfig(
            registry=RegistryConfig(path=store.config.registry,
                                    cache_ttl_seconds=cache_ttl),
            online_store=store.config.online_store,
            project=store.config.project,
            provider=store.config.provider,
        ))

        # Should load the registry and cache it (the steps below rely on the
        # cache expiring after cache_ttl seconds)
        result = fs_fast_ttl.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{
                "driver": 1,
                "customer": 5
            }],
        ).to_dict()
        assert result["driver_locations__lon"] == ["1.0"]
        assert result["customer_driver_combined__trips"] == [7]

        # Rename the registry.db so that it can't be used for refreshes
        os.rename(store.config.registry, store.config.registry + "_fake")

        # Wait for registry to expire
        time.sleep(cache_ttl)

        # Will try to reload registry because it has expired (it will fail
        # because the actual registry file was renamed away)
        with pytest.raises(FileNotFoundError):
            fs_fast_ttl.get_online_features(
                feature_refs=[
                    "driver_locations:lon",
                    "customer_profile:avg_orders_day",
                    "customer_profile:name",
                    "customer_driver_combined:trips",
                ],
                entity_rows=[{
                    "driver": 1,
                    "customer": 5
                }],
            ).to_dict()

        # Restore registry.db so that we can see if it actually reloads registry
        os.rename(store.config.registry + "_fake", store.config.registry)

        # Test if registry is actually reloaded and whether results return
        result = fs_fast_ttl.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{
                "driver": 1,
                "customer": 5
            }],
        ).to_dict()
        assert result["driver_locations__lon"] == ["1.0"]
        assert result["customer_driver_combined__trips"] == [7]

        # Create a registry with infinite cache (for users that want to manually refresh the registry)
        fs_infinite_ttl = FeatureStore(config=RepoConfig(
            registry=RegistryConfig(path=store.config.registry,
                                    cache_ttl_seconds=0),
            online_store=store.config.online_store,
            project=store.config.project,
            provider=store.config.provider,
        ))

        # Should return results (and fill the registry cache)
        result = fs_infinite_ttl.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{
                "driver": 1,
                "customer": 5
            }],
        ).to_dict()
        assert result["driver_locations__lon"] == ["1.0"]
        assert result["customer_driver_combined__trips"] == [7]

        # Wait a bit so that an arbitrary TTL would take effect
        time.sleep(2)

        # Rename the registry.db so that it can't be used for refreshes
        os.rename(store.config.registry, store.config.registry + "_fake")

        # TTL is infinite so this method should use registry cache
        result = fs_infinite_ttl.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{
                "driver": 1,
                "customer": 5
            }],
        ).to_dict()
        assert result["driver_locations__lon"] == ["1.0"]
        assert result["customer_driver_combined__trips"] == [7]

        # Force registry reload (should fail because file is missing)
        with pytest.raises(FileNotFoundError):
            fs_infinite_ttl.refresh_registry()

        # Restore registry.db so that teardown works
        os.rename(store.config.registry + "_fake", store.config.registry)
Пример #14
0
def construct_test_environment(
    test_repo_config: TestRepoConfig,
    create_and_apply: bool = False,
    materialize: bool = False,
) -> Environment:
    """
    This method takes in the parameters from the test repo config, creates a feature repo, optionally
    applies and materializes it, and hands the constructed environment to the caller.

    The function body is a generator: the environment is yielded, and both the offline data source
    creator and the feature store are torn down once the generator is resumed or closed.
    NOTE(review): presumably this is wrapped with ``contextlib.contextmanager`` (or registered as a
    fixture) where it is defined; the decorator is not visible here - confirm.

    The user is *not* expected to perform any clean up actions.

    :param test_repo_config: configuration providing the dotted path of the offline store creator,
        the provider and the online store.
    :param create_and_apply: when True, the driver and customer entities and the environment's
        driver-stats and customer feature views are applied to the feature store before yielding.
    :param materialize: when True, the feature store is materialized between the environment's
        ``start_date`` and ``end_date`` before yielding.
    :return: yields the ``Environment`` built using the supplied configuration.
    """
    df = create_dataset()

    project = f"test_correctness_{str(uuid.uuid4()).replace('-', '')[:8]}"

    # The creator is configured as "package.module.ClassName"; split off the class name.
    module_name, config_class_name = test_repo_config.offline_store_creator.rsplit(
        ".", 1)

    offline_creator: DataSourceCreator = importer.get_class_from_type(
        module_name, config_class_name, "DataSourceCreator")(project)
    ds = offline_creator.create_data_source(project,
                                            df,
                                            field_mapping={
                                                "ts_1": "ts",
                                                "id": "driver_id"
                                            })
    offline_store = offline_creator.create_offline_store_config()
    online_store = test_repo_config.online_store

    with tempfile.TemporaryDirectory() as repo_dir_name:
        config = RepoConfig(
            registry=str(Path(repo_dir_name) / "registry.db"),
            project=project,
            provider=test_repo_config.provider,
            offline_store=offline_store,
            online_store=online_store,
            repo_path=repo_dir_name,
        )
        fs = FeatureStore(config=config)
        environment = Environment(
            name=project,
            test_repo_config=test_repo_config,
            feature_store=fs,
            data_source=ds,
            data_source_creator=offline_creator,
        )

        try:
            if create_and_apply:
                entities = [driver(), customer()]
                fvs = [
                    environment.driver_stats_feature_view(),
                    environment.customer_feature_view(),
                ]
                fs.apply(fvs + entities)

            if materialize:
                fs.materialize(environment.start_date, environment.end_date)

            yield environment
        finally:
            # Tear down both sides even when the first teardown raises: a
            # failing offline-store cleanup must not skip the feature store
            # teardown. The original exception still propagates.
            try:
                offline_creator.teardown()
            finally:
                fs.teardown()
Пример #15
0
def construct_test_environment(
    test_repo_config: IntegrationTestRepoConfig,
    test_suite_name: str = "integration_test",
) -> Environment:
    """
    Build an ``Environment`` around a freshly created feature repo directory.

    A unique project name is derived from ``test_suite_name``, a random 8-character hex
    suffix and, when present, the ``GITHUB_RUN_ID`` / ``GITHUB_RUN_NUMBER`` environment
    variables. The repo config is written to ``feature_store.yaml`` inside a new temporary
    directory and the feature store is loaded from that directory.

    :param test_repo_config: supplies the offline store creator, provider, online store and
        the ``python_feature_server`` flag.
    :param test_suite_name: prefix used for the generated project name.
    :return: the constructed environment. The temporary directory is not removed here.
    """
    _uuid = uuid.uuid4().hex[:8]

    run_id = os.getenv("GITHUB_RUN_ID", default=None)
    if run_id:
        run_id = f"gh_run_{run_id}_{_uuid}"
    else:
        run_id = _uuid
    run_num = os.getenv("GITHUB_RUN_NUMBER", default=1)

    project = f"{test_suite_name}_{run_id}_{run_num}"

    offline_creator: DataSourceCreator = test_repo_config.offline_store_creator(project)
    offline_store_config = offline_creator.create_offline_store_config()
    online_store = test_repo_config.online_store

    repo_dir_name = tempfile.mkdtemp()
    repo_dir = Path(repo_dir_name)

    if not test_repo_config.python_feature_server:
        # No feature server: keep the registry in the local repo directory.
        feature_server = None
        registry = str(repo_dir / "registry.db")
    else:
        from feast.infra.feature_servers.aws_lambda.config import (
            AwsLambdaFeatureServerConfig,
        )

        feature_server = AwsLambdaFeatureServerConfig(
            enabled=True,
            execution_role_name="arn:aws:iam::402087665549:role/lambda_execution_role",
        )
        # With the feature server enabled the registry is placed on S3 instead.
        registry = f"s3://feast-integration-tests/registries/{project}/registry.db"

    config = RepoConfig(
        registry=registry,
        project=project,
        provider=test_repo_config.provider,
        offline_store=offline_store_config,
        online_store=online_store,
        repo_path=repo_dir_name,
        feature_server=feature_server,
    )

    # Serialize the config to feature_store.yaml so the store can be loaded by path.
    config_as_dict = json.loads(config.json())
    with open(repo_dir / "feature_store.yaml", "w") as f:
        yaml.safe_dump(config_as_dict, f)

    fs = FeatureStore(repo_dir_name)
    # The registry has to be initialized up front: if a test applies nothing before
    # tearing down the feature store, the teardown method would otherwise blow up.
    fs.registry._initialize_registry()

    return Environment(
        name=project,
        test_repo_config=test_repo_config,
        feature_store=fs,
        data_source_creator=offline_creator,
        python_feature_server=test_repo_config.python_feature_server,
    )
Пример #16
0
def test_historical_features_from_bigquery_sources_containing_backfills(
        capsys):
    """Check historical retrieval from a BigQuery source that contains duplicated/backfilled rows.

    Two drivers are staged with two rows each, differing in ``event_timestamp`` and
    ``created``. The entity rows look both drivers up two days after ``now``; the expected
    ``avg_daily_trips`` are 20 (driver 1001) and 40 (driver 1002).

    :param capsys: pytest capture fixture; capture is disabled while the timing line is printed.
    """

    def _by_driver(frame):
        # Normalize row order so the frames can be compared positionally.
        return frame.sort_values(by=["driver_id"]).reset_index(drop=True)

    now = datetime.now().replace(microsecond=0, second=0, minute=0)
    tomorrow = now + timedelta(days=1)
    lookup_time = now + timedelta(days=2)

    entity_dataframe = pd.DataFrame(data=[
        {
            "driver_id": driver_id,
            "event_timestamp": lookup_time
        } for driver_id in (1001, 1002)
    ])

    # Each tuple is (driver_id, avg_daily_trips, event_timestamp, created).
    staged_rows = [
        # Duplicated rows simple case
        (1001, 10, now, tomorrow),
        (1001, 20, tomorrow, tomorrow),
        # Duplicated rows after a backfill
        (1002, 30, now, tomorrow),
        (1002, 40, tomorrow, now),
    ]
    driver_stats_df = pd.DataFrame(data=[
        {
            "driver_id": driver_id,
            "avg_daily_trips": trips,
            "event_timestamp": event_ts,
            "created": created_ts,
        } for driver_id, trips, event_ts, created_ts in staged_rows
    ])

    expected_df = pd.DataFrame(data=[
        {
            "driver_id": driver_id,
            "event_timestamp": lookup_time,
            "avg_daily_trips": trips,
        } for driver_id, trips in ((1001, 20), (1002, 40))
    ])

    bigquery_dataset = (
        f"test_hist_retrieval_{int(time.time_ns())}_{random.randint(1000, 9999)}"
    )

    with BigQueryDataSet(bigquery_dataset), TemporaryDirectory() as temp_dir:
        gcp_project = bigquery.Client().project

        # Entity Dataframe SQL query
        table_id = f"{bigquery_dataset}.orders"
        stage_orders_bigquery(entity_dataframe, table_id)
        entity_df_query = f"SELECT * FROM {gcp_project}.{table_id}"

        # Driver Feature View
        driver_table_id = f"{gcp_project}.{bigquery_dataset}.driver_hourly"
        stage_driver_hourly_stats_bigquery_source(driver_stats_df,
                                                  driver_table_id)

        repo_config = RepoConfig(
            registry=os.path.join(temp_dir, "registry.db"),
            project="".join(
                random.choices(string.ascii_uppercase + string.digits, k=10)),
            provider="gcp",
            offline_store=BigQueryOfflineStoreConfig(type="bigquery",
                                                     dataset=bigquery_dataset),
        )
        store = FeatureStore(config=repo_config)

        driver = Entity(name="driver",
                        join_key="driver_id",
                        value_type=ValueType.INT64)
        trips_feature = Feature(name="avg_daily_trips", dtype=ValueType.INT32)
        batch_source = BigQuerySource(
            table_ref=driver_table_id,
            event_timestamp_column="event_timestamp",
            created_timestamp_column="created",
        )
        driver_fv = FeatureView(
            name="driver_stats",
            entities=["driver"],
            features=[trips_feature],
            batch_source=batch_source,
            ttl=None,
        )

        store.apply([driver, driver_fv])

        try:
            job_from_sql = store.get_historical_features(
                entity_df=entity_df_query,
                features=["driver_stats:avg_daily_trips"],
                full_feature_names=False,
            )

            # Time only the materialization of the job into a dataframe.
            start_time = datetime.utcnow()
            actual_df_from_sql_entities = job_from_sql.to_df()
            end_time = datetime.utcnow()
            with capsys.disabled():
                print(
                    f"\nTime to execute job_from_sql.to_df() = '{(end_time - start_time)}'"
                )

            assert sorted(expected_df.columns) == sorted(
                actual_df_from_sql_entities.columns)
            assert_frame_equal(
                _by_driver(expected_df),
                _by_driver(actual_df_from_sql_entities[expected_df.columns]),
                check_dtype=False,
            )
        finally:
            store.teardown()
Пример #17
0
def test_historical_features_from_redshift_sources(provider_type,
                                                   infer_event_timestamp_col,
                                                   capsys, full_feature_names):
    """Check historical feature retrieval against sources staged in Redshift.

    Orders, driver stats and customer profile dataframes are uploaded to Redshift tables
    sharing a unique prefix. Features are then retrieved twice - once with a SQL query as
    the entity dataframe and once with the pandas ``orders_df`` - and each result is
    compared with the frame from ``get_expected_training_df`` and with its own
    ``to_arrow()`` output. For both entity inputs, renaming ``customer_id`` to
    ``customer`` is expected to raise ``FeastEntityDFMissingColumnsError``.

    :param provider_type: ``"local"`` or ``"aws"``; any other value raises ``Exception``.
    :param infer_event_timestamp_col: forwarded to ``generate_entities``; when truthy the
        timestamp column selected for the invalid-join-key query is ``"e_ts"``.
    :param capsys: pytest capture fixture; capture is disabled while timings are printed.
    :param full_feature_names: forwarded to ``get_expected_training_df`` and to the
        ``get_historical_features`` calls whose results are compared.
    """
    client = aws_utils.get_redshift_data_client("us-west-2")
    s3 = aws_utils.get_s3_resource("us-west-2")

    offline_store = RedshiftOfflineStoreConfig(
        cluster_id="feast-integration-tests",
        region="us-west-2",
        user="******",
        database="feast",
        s3_staging_location=
        "s3://feast-integration-tests/redshift/tests/ingestion",
        iam_role="arn:aws:iam::402087665549:role/redshift_s3_access_role",
    )

    # Truncate to the hour; note that start_date is rebound below by the tuple
    # returned from generate_entities.
    start_date = datetime.now().replace(microsecond=0, second=0, minute=0)
    (
        customer_entities,
        driver_entities,
        end_date,
        orders_df,
        start_date,
    ) = generate_entities(start_date, infer_event_timestamp_col)

    # Nanosecond timestamp plus a random number, shared by all tables of this run.
    redshift_table_prefix = (
        f"test_hist_retrieval_{int(time.time_ns())}_{random.randint(1000, 9999)}"
    )

    # Stage orders_df to Redshift
    # NOTE(review): the three *_context objects are only entered in the `with`
    # statement further down; presumably the upload happens on entry and the
    # table is dropped on exit - confirm in aws_utils.
    table_name = f"{redshift_table_prefix}_orders"
    entity_df_query = f"SELECT * FROM {table_name}"
    orders_context = aws_utils.temporarily_upload_df_to_redshift(
        client,
        offline_store.cluster_id,
        offline_store.database,
        offline_store.user,
        s3,
        f"{offline_store.s3_staging_location}/copy/{table_name}.parquet",
        offline_store.iam_role,
        table_name,
        orders_df,
    )

    # Stage driver_df to Redshift
    driver_df = driver_data.create_driver_hourly_stats_df(
        driver_entities, start_date, end_date)
    driver_table_name = f"{redshift_table_prefix}_driver_hourly"
    driver_context = aws_utils.temporarily_upload_df_to_redshift(
        client,
        offline_store.cluster_id,
        offline_store.database,
        offline_store.user,
        s3,
        f"{offline_store.s3_staging_location}/copy/{driver_table_name}.parquet",
        offline_store.iam_role,
        driver_table_name,
        driver_df,
    )

    # Stage customer_df to Redshift
    customer_df = driver_data.create_customer_daily_profile_df(
        customer_entities, start_date, end_date)
    customer_table_name = f"{redshift_table_prefix}_customer_profile"
    customer_context = aws_utils.temporarily_upload_df_to_redshift(
        client,
        offline_store.cluster_id,
        offline_store.database,
        offline_store.user,
        s3,
        f"{offline_store.s3_staging_location}/copy/{customer_table_name}.parquet",
        offline_store.iam_role,
        customer_table_name,
        customer_df,
    )

    # Enter all three staging contexts plus a scratch directory for the
    # registry (and, for the local provider, the SQLite online store).
    with orders_context, driver_context, customer_context, TemporaryDirectory(
    ) as temp_dir:
        driver_source = RedshiftSource(
            table=driver_table_name,
            event_timestamp_column="event_timestamp",
            created_timestamp_column="created",
        )
        driver_fv = create_driver_hourly_stats_feature_view(driver_source)

        customer_source = RedshiftSource(
            table=customer_table_name,
            event_timestamp_column="event_timestamp",
            created_timestamp_column="created",
        )
        customer_fv = create_customer_daily_profile_feature_view(
            customer_source)

        driver = Entity(name="driver",
                        join_key="driver_id",
                        value_type=ValueType.INT64)
        # No explicit join_key here, unlike the driver entity above.
        customer = Entity(name="customer_id", value_type=ValueType.INT64)

        # Both providers use the same Redshift offline store; they differ in
        # project name and online store.
        if provider_type == "local":
            store = FeatureStore(config=RepoConfig(
                registry=os.path.join(temp_dir, "registry.db"),
                project="default",
                provider="local",
                online_store=SqliteOnlineStoreConfig(path=os.path.join(
                    temp_dir, "online_store.db"), ),
                offline_store=offline_store,
            ))
        elif provider_type == "aws":
            store = FeatureStore(config=RepoConfig(
                registry=os.path.join(temp_dir, "registry.db"),
                project="".join(
                    random.choices(string.ascii_uppercase + string.digits,
                                   k=10)),
                provider="aws",
                online_store=DynamoDBOnlineStoreConfig(region="us-west-2"),
                offline_store=offline_store,
            ))
        else:
            raise Exception(
                "Invalid provider used as part of test configuration")

        store.apply([driver, customer, driver_fv, customer_fv])

        try:
            # Use the default timestamp column when orders_df has it,
            # otherwise fall back to "e_ts".
            event_timestamp = (DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
                               if DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
                               in orders_df.columns else "e_ts")
            expected_df = get_expected_training_df(
                customer_df,
                customer_fv,
                driver_df,
                driver_fv,
                orders_df,
                event_timestamp,
                full_feature_names,
            )

            # --- Retrieval with a SQL query as the entity dataframe ---
            job_from_sql = store.get_historical_features(
                entity_df=entity_df_query,
                features=[
                    "driver_stats:conv_rate",
                    "driver_stats:avg_daily_trips",
                    "customer_profile:current_balance",
                    "customer_profile:avg_passenger_count",
                    "customer_profile:lifetime_trip_count",
                ],
                full_feature_names=full_feature_names,
            )

            # Time only the conversion of the job into a dataframe.
            start_time = datetime.utcnow()
            actual_df_from_sql_entities = job_from_sql.to_df()
            end_time = datetime.utcnow()
            with capsys.disabled():
                print(
                    str(f"\nTime to execute job_from_sql.to_df() = '{(end_time - start_time)}'"
                        ))

            # Same set of columns, then same content once both frames are
            # sorted on the same keys and re-indexed.
            assert sorted(expected_df.columns) == sorted(
                actual_df_from_sql_entities.columns)
            assert_frame_equal(
                expected_df.sort_values(by=[
                    event_timestamp, "order_id", "driver_id", "customer_id"
                ]).reset_index(drop=True),
                actual_df_from_sql_entities[expected_df.columns].sort_values(
                    by=[
                        event_timestamp, "order_id", "driver_id", "customer_id"
                    ]).reset_index(drop=True),
                check_dtype=False,
            )

            # The arrow output of the same job must match its dataframe output.
            table_from_sql_entities = job_from_sql.to_arrow()
            assert_frame_equal(
                actual_df_from_sql_entities.sort_values(by=[
                    event_timestamp, "order_id", "driver_id", "customer_id"
                ]).reset_index(drop=True),
                table_from_sql_entities.to_pandas().sort_values(by=[
                    event_timestamp, "order_id", "driver_id", "customer_id"
                ]).reset_index(drop=True),
            )

            timestamp_column = ("e_ts" if infer_event_timestamp_col else
                                DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL)

            entity_df_query_with_invalid_join_key = (
                f"select order_id, driver_id, customer_id as customer, "
                f"order_is_success, {timestamp_column} FROM {table_name}")
            # Rename the join key; this should now raise an error.
            # The bound `to_df` method is handed to assertpy uncalled, so the
            # error is expected when assertpy invokes it.
            assertpy.assert_that(
                store.get_historical_features(
                    entity_df=entity_df_query_with_invalid_join_key,
                    features=[
                        "driver_stats:conv_rate",
                        "driver_stats:avg_daily_trips",
                        "customer_profile:current_balance",
                        "customer_profile:avg_passenger_count",
                        "customer_profile:lifetime_trip_count",
                    ],
                ).to_df).raises(errors.FeastEntityDFMissingColumnsError
                                ).when_called_with()

            # --- Retrieval with the pandas orders_df as the entity dataframe ---
            # The job is built here but only executed (to_df) after the
            # invalid-join-key check below.
            job_from_df = store.get_historical_features(
                entity_df=orders_df,
                features=[
                    "driver_stats:conv_rate",
                    "driver_stats:avg_daily_trips",
                    "customer_profile:current_balance",
                    "customer_profile:avg_passenger_count",
                    "customer_profile:lifetime_trip_count",
                ],
                full_feature_names=full_feature_names,
            )

            # Rename the join key; this should now raise an error.
            orders_df_with_invalid_join_key = orders_df.rename(
                {"customer_id": "customer"}, axis="columns")
            assertpy.assert_that(
                store.get_historical_features(
                    entity_df=orders_df_with_invalid_join_key,
                    features=[
                        "driver_stats:conv_rate",
                        "driver_stats:avg_daily_trips",
                        "customer_profile:current_balance",
                        "customer_profile:avg_passenger_count",
                        "customer_profile:lifetime_trip_count",
                    ],
                ).to_df).raises(errors.FeastEntityDFMissingColumnsError
                                ).when_called_with()

            start_time = datetime.utcnow()
            actual_df_from_df_entities = job_from_df.to_df()
            end_time = datetime.utcnow()
            with capsys.disabled():
                print(
                    str(f"Time to execute job_from_df.to_df() = '{(end_time - start_time)}'\n"
                        ))

            assert sorted(expected_df.columns) == sorted(
                actual_df_from_df_entities.columns)
            assert_frame_equal(
                expected_df.sort_values(by=[
                    event_timestamp, "order_id", "driver_id", "customer_id"
                ]).reset_index(drop=True),
                actual_df_from_df_entities[expected_df.columns].sort_values(
                    by=[
                        event_timestamp, "order_id", "driver_id", "customer_id"
                    ]).reset_index(drop=True),
                check_dtype=False,
            )

            # Again, arrow output must match the dataframe output of the same job.
            table_from_df_entities = job_from_df.to_arrow()
            assert_frame_equal(
                actual_df_from_df_entities.sort_values(by=[
                    event_timestamp, "order_id", "driver_id", "customer_id"
                ]).reset_index(drop=True),
                table_from_df_entities.to_pandas().sort_values(by=[
                    event_timestamp, "order_id", "driver_id", "customer_id"
                ]).reset_index(drop=True),
            )
        finally:
            # Always tear the store down, whether or not the checks passed.
            store.teardown()