Example #1
def _convert_arrow_to_proto(
    table: pyarrow.Table, feature_view: FeatureView
) -> List[Tuple[EntityKeyProto, Dict[str, ValueProto], datetime,
                Optional[datetime]]]:
    rows_to_write = []
    for row in zip(*table.to_pydict().values()):
        entity_key = EntityKeyProto()
        for entity_name in feature_view.entities:
            entity_key.entity_names.append(entity_name)
            idx = table.column_names.index(entity_name)
            value = python_value_to_proto_value(row[idx])
            entity_key.entity_values.append(value)
        feature_dict = {}
        for feature in feature_view.features:
            idx = table.column_names.index(feature.name)
            value = python_value_to_proto_value(row[idx])
            feature_dict[feature.name] = value
        event_timestamp_idx = table.column_names.index(
            feature_view.input.event_timestamp_column)
        event_timestamp = row[event_timestamp_idx]
        if feature_view.input.created_timestamp_column is not None:
            created_timestamp_idx = table.column_names.index(
                feature_view.input.created_timestamp_column)
            created_timestamp = row[created_timestamp_idx]
        else:
            created_timestamp = None

        rows_to_write.append(
            (entity_key, feature_dict, event_timestamp, created_timestamp))
    return rows_to_write
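
The `zip(*table.to_pydict().values())` idiom above is how the function walks the table row by row. A minimal, self-contained sketch of what that idiom yields (the column names are illustrative, not taken from any particular feature view):

import pyarrow
from datetime import datetime

# The kind of table the helper iterates over.
table = pyarrow.Table.from_pydict({
    "driver_id": [1001, 1002],
    "conv_rate": [0.5, 0.7],
    "event_timestamp": [datetime.utcnow(), datetime.utcnow()],
})

# to_pydict() gives {column_name: [values]}; zipping the value lists
# together produces one tuple per row, in column order.
for row in zip(*table.to_pydict().values()):
    print(row)  # (1001, 0.5, datetime(...)), then (1002, 0.7, datetime(...))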
Example #2
def _get_table_entity_keys(
    table: FeatureView, entity_keys: List[EntityKeyProto], join_key_map: Dict[str, str],
) -> List[EntityKeyProto]:
    table_join_keys = [join_key_map[entity_name] for entity_name in table.entities]
    required_entities = OrderedDict.fromkeys(sorted(table_join_keys))
    entity_key_protos = []
    for entity_key in entity_keys:
        required_entities_to_values = required_entities.copy()
        for i in range(len(entity_key.join_keys)):
            entity_name = entity_key.join_keys[i]
            entity_value = entity_key.entity_values[i]

            if entity_name in required_entities_to_values:
                if required_entities_to_values[entity_name] is not None:
                    raise ValueError(
                        f"Duplicate entity keys detected. Table {table.name} expects {table_join_keys}. The entity "
                        f"{entity_name} was provided at least twice"
                    )
                required_entities_to_values[entity_name] = entity_value

        entity_names = []
        entity_values = []
        for entity_name, entity_value in required_entities_to_values.items():
            if entity_value is None:
                raise ValueError(
                    f"Table {table.name} expects entity field {table_join_keys}. No entity value was found for "
                    f"{entity_name}"
                )
            entity_names.append(entity_name)
            entity_values.append(entity_value)
        entity_key_protos.append(
            EntityKeyProto(join_keys=entity_names, entity_values=entity_values)
        )
    return entity_key_protos
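
A hedged usage sketch: the join_key_map translates the entity names declared on the feature view into the join keys carried by the incoming EntityKeyProtos, and the returned protos list those keys in the table's canonical sorted order. The feature view, names, and import paths below are assumptions, not taken from the snippet:

from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto

# Assume driver_stats_fv is a FeatureView whose single entity is "driver",
# and that entity joins on the "driver_id" column.
keys = [EntityKeyProto(join_keys=["driver_id"],
                       entity_values=[ValueProto(int64_val=1001)])]
table_keys = _get_table_entity_keys(
    driver_stats_fv, keys, join_key_map={"driver": "driver_id"})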
Example #3
def test_dynamodb_online_store_online_read_unknown_entity(
        repo_config, dynamodb_online_store):
    """Test DynamoDBOnlineStore online_read method."""
    n_samples = 2
    _create_test_table(PROJECT, f"{TABLE_NAME}_unknown_entity_{n_samples}",
                       REGION)
    data = _create_n_customer_test_samples(n=n_samples)
    _insert_data_test_table(data, PROJECT,
                            f"{TABLE_NAME}_unknown_entity_{n_samples}", REGION)

    entity_keys, features, *rest = zip(*data)
    # Convert to lists so an unknown entity can be inserted to search for
    entity_keys = list(entity_keys)
    features = list(features)

    # Have the unknown entity be in the beginning, middle, and end of the list of entities.
    for pos in range(len(entity_keys)):
        entity_keys_with_unknown = deepcopy(entity_keys)
        entity_keys_with_unknown.insert(
            pos,
            EntityKeyProto(join_keys=["customer"],
                           entity_values=[ValueProto(string_val="12359")]),
        )
        features_with_none = deepcopy(features)
        features_with_none.insert(pos, None)
        returned_items = dynamodb_online_store.online_read(
            config=repo_config,
            table=MockFeatureView(
                name=f"{TABLE_NAME}_unknown_entity_{n_samples}"),
            entity_keys=entity_keys_with_unknown,
        )
        assert len(returned_items) == len(entity_keys_with_unknown)
        assert [item[1] for item in returned_items] == list(features_with_none)
        # The order should match the original entity key order
        assert returned_items[pos] == (None, None)
Example #4
def _convert_arrow_to_proto(
    table: Union[pyarrow.Table, pyarrow.RecordBatch],
    feature_view: FeatureView,
    join_keys: Dict[str, ValueType],
) -> List[Tuple[EntityKeyProto, Dict[str, ValueProto], datetime,
                Optional[datetime]]]:
    # Convert to a single RecordBatch to avoid ChunkedArrays, which guarantees `zero_copy_only` is available.
    if isinstance(table, pyarrow.Table):
        table = table.to_batches()[0]

    columns = [(field.name, field.dtype.to_value_type())
               for field in feature_view.schema] + list(join_keys.items())

    proto_values_by_column = {
        column: python_values_to_proto_values(
            table.column(column).to_numpy(zero_copy_only=False), value_type)
        for column, value_type in columns
    }

    entity_keys = [
        EntityKeyProto(
            join_keys=join_keys,
            entity_values=[proto_values_by_column[k][idx] for k in join_keys],
        ) for idx in range(table.num_rows)
    ]

    # Serialize the features per row
    feature_dict = {
        feature.name: proto_values_by_column[feature.name]
        for feature in feature_view.features
    }
    features = [
        dict(zip(feature_dict, vars)) for vars in zip(*feature_dict.values())
    ]

    # Convert event_timestamps
    event_timestamps = [
        _coerce_datetime(val) for val in pandas.to_datetime(
            table.column(feature_view.batch_source.timestamp_field).to_numpy(
                zero_copy_only=False))
    ]

    # Convert created_timestamps if they exist
    if feature_view.batch_source.created_timestamp_column:
        created_timestamps = [
            _coerce_datetime(val) for val in pandas.to_datetime(
                table.column(feature_view.batch_source.created_timestamp_column
                             ).to_numpy(zero_copy_only=False))
        ]
    else:
        created_timestamps = [None] * table.num_rows

    return list(
        zip(entity_keys, features, event_timestamps, created_timestamps))
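
Unlike Example #1, this variant converts values column by column rather than cell by cell. A hedged sketch of the vectorized call it leans on (the import paths are assumptions based on Feast's module layout):

import numpy as np
from feast.type_map import python_values_to_proto_values
from feast.value_type import ValueType

# One call converts an entire column; Example #1 made one call per cell.
protos = python_values_to_proto_values(np.array([1, 2, 3]), ValueType.INT64)
assert protos[0].int64_val == 1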
Example #5
def _create_n_customer_test_samples(n=10):
    return [(
        EntityKeyProto(join_keys=["customer"],
                       entity_values=[ValueProto(string_val=str(i))]),
        {
            "avg_orders_day": ValueProto(float_val=1.0),
            "name": ValueProto(string_val="John"),
            "age": ValueProto(int64_val=3),
        },
        datetime.utcnow(),
        None,
    ) for i in range(n)]
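
A small sketch of how this fixture is typically consumed, mirroring the unpacking in Example #3:

data = _create_n_customer_test_samples(n=3)

# Each sample is (entity_key, feature_dict, event_ts, created_ts);
# zip(*data) splits them into parallel sequences.
entity_keys, features, event_timestamps, created_timestamps = zip(*data)
assert entity_keys[0].entity_values[0].string_val == "0"
assert features[0]["name"].string_val == "John"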
Example #6
def _convert_arrow_to_proto(
    table: pyarrow.Table,
    feature_view: FeatureView,
    join_keys: List[str],
) -> List[Tuple[EntityKeyProto, Dict[str, ValueProto], datetime,
                Optional[datetime]]]:
    rows_to_write = []

    def _coerce_datetime(ts):
        """
        Depending on the underlying time resolution, arrow to_pydict() sometimes returns a pandas
        Timestamp (for nanosecond resolution) and sometimes a standard Python datetime
        (for microsecond resolution).

        While pandas' Timestamp is a subclass of Python's datetime, it doesn't always behave the
        same way. We convert it to a plain datetime so that consumers downstream don't have to
        deal with these quirks.
        """

        if isinstance(ts, pandas.Timestamp):
            return ts.to_pydatetime()
        else:
            return ts

    column_names_idx = {k: i for i, k in enumerate(table.column_names)}
    for row in zip(*table.to_pydict().values()):
        entity_key = EntityKeyProto()
        for join_key in join_keys:
            entity_key.join_keys.append(join_key)
            idx = column_names_idx[join_key]
            value = python_value_to_proto_value(row[idx])
            entity_key.entity_values.append(value)
        feature_dict = {}
        for feature in feature_view.features:
            idx = column_names_idx[feature.name]
            value = python_value_to_proto_value(row[idx], feature.dtype)
            feature_dict[feature.name] = value
        event_timestamp_idx = column_names_idx[
            feature_view.batch_source.event_timestamp_column]
        event_timestamp = _coerce_datetime(row[event_timestamp_idx])

        if feature_view.batch_source.created_timestamp_column:
            created_timestamp_idx = column_names_idx[
                feature_view.batch_source.created_timestamp_column]
            created_timestamp = _coerce_datetime(row[created_timestamp_idx])
        else:
            created_timestamp = None

        rows_to_write.append(
            (entity_key, feature_dict, event_timestamp, created_timestamp))
    return rows_to_write
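
The `_coerce_datetime` docstring is easy to verify directly; a self-contained sketch of the quirk it guards against:

from datetime import datetime
import pandas

ts = pandas.Timestamp("2021-01-01T00:00:00")
assert isinstance(ts, datetime)              # Timestamp subclasses datetime...
assert type(ts.to_pydatetime()) is datetime  # ...but to_pydatetime() yields the plain type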
Example #7
def _convert_arrow_to_proto(
    table: Union[pyarrow.Table, pyarrow.RecordBatch],
    feature_view: FeatureView,
    join_keys: List[str],
) -> List[Tuple[EntityKeyProto, Dict[str, ValueProto], datetime,
                Optional[datetime]]]:
    # Handle join keys
    join_key_values = {k: table.column(k).to_pylist() for k in join_keys}
    entity_keys = [
        EntityKeyProto(
            join_keys=join_keys,
            entity_values=[
                python_value_to_proto_value(join_key_values[k][idx])
                for k in join_keys
            ],
        ) for idx in range(table.num_rows)
    ]

    # Serialize the features per row
    feature_dict = {
        feature.name: [
            python_value_to_proto_value(val, feature.dtype)
            for val in table.column(feature.name).to_pylist()
        ]
        for feature in feature_view.features
    }
    features = [
        dict(zip(feature_dict, vars)) for vars in zip(*feature_dict.values())
    ]

    # Convert event_timestamps
    event_timestamps = [
        _coerce_datetime(val) for val in table.column(
            feature_view.batch_source.event_timestamp_column).to_pylist()
    ]

    # Convert created_timestamps if they exist
    if feature_view.batch_source.created_timestamp_column:
        created_timestamps = [
            _coerce_datetime(val)
            for val in table.column(feature_view.batch_source.
                                    created_timestamp_column).to_pylist()
        ]
    else:
        created_timestamps = [None] * table.num_rows

    return list(
        zip(entity_keys, features, event_timestamps, created_timestamps))
Example #8
    def teardown_infra(
        self,
        project: str,
        tables: Sequence[Union[FeatureTable, FeatureView]],
        entities: Sequence[Entity],
    ) -> None:
        # According to repos_operations.py, we can delete the whole project
        client = self._get_client()

        tables_join_keys = [[e for e in t.entities] for t in tables]
        for table_join_keys in tables_join_keys:
            redis_key_bin = _redis_key(
                project, EntityKeyProto(join_keys=table_join_keys)
            )
            keys = [k for k in client.scan_iter(match=f"{redis_key_bin}*", count=100)]
            if keys:
                client.unlink(*keys)
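
The trick above: an EntityKeyProto built with join keys but no entity values serializes to a prefix shared by every row of that table, so a match-scan plus UNLINK clears the table. A hedged redis-py sketch of the same scan-and-unlink pattern (the key prefix is illustrative, not Feast's real encoding):

import redis

r = redis.Redis()
prefix = b"my_project:driver_id"  # hypothetical serialized key prefix

# SCAN is cursor-based and non-blocking; UNLINK frees memory asynchronously.
stale = list(r.scan_iter(match=prefix + b"*", count=100))
if stale:
    r.unlink(*stale)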
Example #9
def basic_rw_test(store: FeatureStore,
                  view_name: str,
                  feature_service_name: Optional[str] = None) -> None:
    """
    This is a provider-independent test suite for reading and writing from the online store, to
    be used by provider-specific tests.
    """
    table = store.get_feature_view(name=view_name)

    provider = store._get_provider()

    entity_key = EntityKeyProto(join_keys=["driver"],
                                entity_values=[ValueProto(int64_val=1)])

    def _driver_rw_test(event_ts, created_ts, write, expect_read):
        """ A helper function to write values and read them back """
        write_lat, write_lon = write
        expect_lat, expect_lon = expect_read
        provider.online_write_batch(
            config=store.config,
            table=table,
            data=[(
                entity_key,
                {
                    "lat": ValueProto(double_val=write_lat),
                    "lon": ValueProto(string_val=write_lon),
                },
                event_ts,
                created_ts,
            )],
            progress=None,
        )

        if feature_service_name:
            entity_dict = {"driver": 1}
            feature_service = store.get_feature_service(feature_service_name)
            features = store.get_online_features(features=feature_service,
                                                 entity_rows=[entity_dict
                                                              ]).to_dict()
            assert len(features["driver"]) == 1
            assert features["lon"][0] == expect_lon
            assert abs(features["lat"][0] - expect_lat) < 1e-6
        else:
            read_rows = provider.online_read(config=store.config,
                                             table=table,
                                             entity_keys=[entity_key])
            assert len(read_rows) == 1
            _, val = read_rows[0]
            assert val["lon"].string_val == expect_lon
            assert abs(val["lat"].double_val - expect_lat) < 1e-6

    """ 1. Basic test: write value, read it back """

    time_1 = datetime.utcnow()
    _driver_rw_test(event_ts=time_1,
                    created_ts=time_1,
                    write=(1.1, "3.1"),
                    expect_read=(1.1, "3.1"))

    # Note: This behavior has changed for performance. We should test that older
    # value can't overwrite over a newer value once we add the respective flag
    """ Values with an older event_ts should overwrite newer ones """
    time_2 = datetime.utcnow()
    _driver_rw_test(
        event_ts=time_1 - timedelta(hours=1),
        created_ts=time_2,
        write=(-1000, "OLD"),
        expect_read=(-1000, "OLD"),
    )
    """ Values with an new event_ts should overwrite older ones """
    time_3 = datetime.utcnow()
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3,
        write=(1123, "NEWER"),
        expect_read=(1123, "NEWER"),
    )

    # Note: This behavior has changed for performance. We should test that older
    # value can't overwrite over a newer value once we add the respective flag
    """ created_ts is used as a tie breaker, using older created_ts here, but we still overwrite """
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3 - timedelta(hours=1),
        write=(54321, "I HAVE AN OLDER created_ts SO I LOSE"),
        expect_read=(54321, "I HAVE AN OLDER created_ts SO I LOSE"),
    )
    """ created_ts is used as a tie breaker, using newer created_ts here so we should overwrite """
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3 + timedelta(hours=1),
        write=(96864, "I HAVE A NEWER created_ts SO I WIN"),
        expect_read=(96864, "I HAVE A NEWER created_ts SO I WIN"),
    )
Example #10
def _entity_row_to_key(
        row: GetOnlineFeaturesRequestV2.EntityRow) -> EntityKeyProto:
    names, values = zip(*row.fields.items())
    return EntityKeyProto(join_keys=names,
                          entity_values=values)  # type: ignore
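
A quick sketch of the conversion (the proto import paths are assumptions matching Feast's generated modules; the field name is illustrative):

from feast.protos.feast.serving.ServingService_pb2 import GetOnlineFeaturesRequestV2
from feast.protos.feast.types.Value_pb2 import Value as ValueProto

row = GetOnlineFeaturesRequestV2.EntityRow(
    fields={"driver_id": ValueProto(int64_val=1001)})
key = _entity_row_to_key(row)
assert list(key.join_keys) == ["driver_id"]
assert key.entity_values[0].int64_val == 1001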
Example #11
def basic_rw_test(store: FeatureStore,
                  view_name: str,
                  feature_service_name: Optional[str] = None) -> None:
    """
    This is a provider-independent test suite for reading and writing from the online store, to
    be used by provider-specific tests.
    """
    table = store.get_feature_view(name=view_name)

    provider = store._get_provider()

    entity_key = EntityKeyProto(join_keys=["driver_id"],
                                entity_values=[ValueProto(int64_val=1)])

    def _driver_rw_test(event_ts, created_ts, write, expect_read):
        """A helper function to write values and read them back"""
        write_lat, write_lon = write
        expect_lat, expect_lon = expect_read
        provider.online_write_batch(
            config=store.config,
            table=table,
            data=[(
                entity_key,
                {
                    "lat": ValueProto(double_val=write_lat),
                    "lon": ValueProto(string_val=write_lon),
                },
                event_ts,
                created_ts,
            )],
            progress=None,
        )

        if feature_service_name:
            entity_dict = {"driver_id": 1}
            feature_service = store.get_feature_service(feature_service_name)
            features = store.get_online_features(features=feature_service,
                                                 entity_rows=[entity_dict
                                                              ]).to_dict()
            assert len(features["driver_id"]) == 1
            assert features["lon"][0] == expect_lon
            assert abs(features["lat"][0] - expect_lat) < 1e-6
        else:
            read_rows = provider.online_read(config=store.config,
                                             table=table,
                                             entity_keys=[entity_key])
            assert len(read_rows) == 1
            _, val = read_rows[0]
            assert val["lon"].string_val == expect_lon
            assert abs(val["lat"].double_val - expect_lat) < 1e-6

    """ 1. Basic test: write value, read it back """

    time_1 = datetime.utcnow()
    _driver_rw_test(event_ts=time_1,
                    created_ts=time_1,
                    write=(1.1, "3.1"),
                    expect_read=(1.1, "3.1"))
    """ Values with an new event_ts should overwrite older ones """
    time_3 = datetime.utcnow()
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3,
        write=(1123, "NEWER"),
        expect_read=(1123, "NEWER"),
    )
Example #12
def test_online() -> None:
    """
    Test reading from the online store in local mode.
    """
    runner = CliRunner()
    with runner.local_repo(
            get_example_repo("example_feature_repo_1.py")) as store:
        # Write some data to two tables

        driver_locations_fv = store.get_feature_view(name="driver_locations")
        customer_profile_fv = store.get_feature_view(name="customer_profile")
        customer_driver_combined_fv = store.get_feature_view(
            name="customer_driver_combined")

        provider = store._get_provider()

        driver_key = EntityKeyProto(join_keys=["driver"],
                                    entity_values=[ValueProto(int64_val=1)])
        provider.online_write_batch(
            project=store.config.project,
            table=driver_locations_fv,
            data=[(
                driver_key,
                {
                    "lat": ValueProto(double_val=0.1),
                    "lon": ValueProto(string_val="1.0"),
                },
                datetime.utcnow(),
                datetime.utcnow(),
            )],
            progress=None,
        )

        customer_key = EntityKeyProto(join_keys=["customer"],
                                      entity_values=[ValueProto(int64_val=5)])
        provider.online_write_batch(
            project=store.config.project,
            table=customer_profile_fv,
            data=[(
                customer_key,
                {
                    "avg_orders_day": ValueProto(float_val=1.0),
                    "name": ValueProto(string_val="John"),
                    "age": ValueProto(int64_val=3),
                },
                datetime.utcnow(),
                datetime.utcnow(),
            )],
            progress=None,
        )

        customer_key = EntityKeyProto(
            join_keys=["customer", "driver"],
            entity_values=[ValueProto(int64_val=5),
                           ValueProto(int64_val=1)],
        )
        provider.online_write_batch(
            project=store.config.project,
            table=customer_driver_combined_fv,
            data=[(
                customer_key,
                {
                    "trips": ValueProto(int64_val=7)
                },
                datetime.utcnow(),
                datetime.utcnow(),
            )],
            progress=None,
        )

        # Retrieve features using two identical entity rows
        result = store.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{
                "driver": 1,
                "customer": 5
            }, {
                "driver": 1,
                "customer": 5
            }],
        ).to_dict()

        assert "driver_locations__lon" in result
        assert "customer_profile__avg_orders_day" in result
        assert "customer_profile__name" in result
        assert result["driver"] == [1, 1]
        assert result["customer"] == [5, 5]
        assert result["driver_locations__lon"] == ["1.0", "1.0"]
        assert result["customer_profile__avg_orders_day"] == [1.0, 1.0]
        assert result["customer_profile__name"] == ["John", "John"]
        assert result["customer_driver_combined__trips"] == [7, 7]

        # Ensure features are still in result when keys not found
        result = store.get_online_features(
            feature_refs=["customer_driver_combined:trips"],
            entity_rows=[{
                "driver": 0,
                "customer": 0
            }],
        ).to_dict()

        assert "customer_driver_combined__trips" in result

        # invalid table reference
        with pytest.raises(FeatureViewNotFoundException):
            store.get_online_features(
                feature_refs=["driver_locations_bad:lon"],
                entity_rows=[{
                    "driver": 1
                }],
            )

        # Create new FeatureStore object with fast cache invalidation
        cache_ttl = 1
        fs_fast_ttl = FeatureStore(config=RepoConfig(
            registry=RegistryConfig(path=store.config.registry,
                                    cache_ttl_seconds=cache_ttl),
            online_store=store.config.online_store,
            project=store.config.project,
            provider=store.config.provider,
        ))

        # Should download the registry and cache it permanently (or until manually refreshed)
        result = fs_fast_ttl.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{
                "driver": 1,
                "customer": 5
            }],
        ).to_dict()
        assert result["driver_locations__lon"] == ["1.0"]
        assert result["customer_driver_combined__trips"] == [7]

        # Rename the registry.db so that it can't be used for refreshes
        os.rename(store.config.registry, store.config.registry + "_fake")

        # Wait for registry to expire
        time.sleep(cache_ttl)

        # Will try to reload the registry because it has expired (it will fail because we renamed the actual registry file)
        with pytest.raises(FileNotFoundError):
            fs_fast_ttl.get_online_features(
                feature_refs=[
                    "driver_locations:lon",
                    "customer_profile:avg_orders_day",
                    "customer_profile:name",
                    "customer_driver_combined:trips",
                ],
                entity_rows=[{
                    "driver": 1,
                    "customer": 5
                }],
            ).to_dict()

        # Restore registry.db so that we can see if it actually reloads registry
        os.rename(store.config.registry + "_fake", store.config.registry)

        # Test if registry is actually reloaded and whether results return
        result = fs_fast_ttl.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{
                "driver": 1,
                "customer": 5
            }],
        ).to_dict()
        assert result["driver_locations__lon"] == ["1.0"]
        assert result["customer_driver_combined__trips"] == [7]

        # Create a registry with infinite cache (for users that want to manually refresh the registry)
        fs_infinite_ttl = FeatureStore(config=RepoConfig(
            registry=RegistryConfig(path=store.config.registry,
                                    cache_ttl_seconds=0),
            online_store=store.config.online_store,
            project=store.config.project,
            provider=store.config.provider,
        ))

        # Should return results (and fill the registry cache)
        result = fs_infinite_ttl.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{
                "driver": 1,
                "customer": 5
            }],
        ).to_dict()
        assert result["driver_locations__lon"] == ["1.0"]
        assert result["customer_driver_combined__trips"] == [7]

        # Wait a bit so that an arbitrary TTL would take effect
        time.sleep(2)

        # Rename the registry.db so that it can't be used for refreshes
        os.rename(store.config.registry, store.config.registry + "_fake")

        # TTL is infinite so this method should use registry cache
        result = fs_infinite_ttl.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{
                "driver": 1,
                "customer": 5
            }],
        ).to_dict()
        assert result["driver_locations__lon"] == ["1.0"]
        assert result["customer_driver_combined__trips"] == [7]

        # Force registry reload (should fail because file is missing)
        with pytest.raises(FileNotFoundError):
            fs_infinite_ttl.refresh_registry()

        # Restore registry.db so that teardown works
        os.rename(store.config.registry + "_fake", store.config.registry)
Example #13
def test_online() -> None:
    """
    Test reading from the online store in local mode.
    """
    runner = CliRunner()
    with runner.local_repo(get_example_repo("example_feature_repo_1.py")) as store:

        # Write some data to two tables
        registry = store._get_registry()
        table = registry.get_feature_view(
            project=store.config.project, name="driver_locations"
        )
        table_2 = registry.get_feature_view(
            project=store.config.project, name="driver_locations_2"
        )

        provider = store._get_provider()

        entity_key = EntityKeyProto(
            entity_names=["driver"], entity_values=[ValueProto(int64_val=1)]
        )
        provider.online_write_batch(
            project=store.config.project,
            table=table,
            data=[
                (
                    entity_key,
                    {
                        "lat": ValueProto(double_val=0.1),
                        "lon": ValueProto(string_val="1.0"),
                    },
                    datetime.utcnow(),
                    datetime.utcnow(),
                )
            ],
            progress=None,
        )

        provider.online_write_batch(
            project=store.config.project,
            table=table_2,
            data=[
                (
                    entity_key,
                    {
                        "lat": ValueProto(double_val=2.0),
                        "lon": ValueProto(string_val="2.0"),
                    },
                    datetime.utcnow(),
                    datetime.utcnow(),
                )
            ],
            progress=None,
        )

        # Retrieve two features using two keys, one valid, one non-existing
        result = store.get_online_features(
            feature_refs=["driver_locations:lon", "driver_locations_2:lon"],
            entity_rows=[{"driver": 1}, {"driver": 123}],
        )

        assert "driver_locations:lon" in result.to_dict()
        assert result.to_dict()["driver_locations:lon"] == ["1.0", None]
        assert result.to_dict()["driver_locations_2:lon"] == ["2.0", None]

        # invalid table reference
        with pytest.raises(ValueError):
            store.get_online_features(
                feature_refs=["driver_locations_bad:lon"], entity_rows=[{"driver": 1}],
            )
Example #14
def test_online_to_df():
    """
    Test dataframe conversion. Make sure the response columns and rows are
    the same order as the request.
    """

    driver_ids = [1, 2, 3]
    customer_ids = [4, 5, 6]
    name = "foo"
    lon_multiply = 1.0
    lat_multiply = 0.1
    age_multiply = 10
    avg_order_day_multiply = 1.0

    runner = CliRunner()
    with runner.local_repo(get_example_repo("example_feature_repo_1.py"),
                           "bigquery") as store:
        # Write three tables to online store
        driver_locations_fv = store.get_feature_view(name="driver_locations")
        customer_profile_fv = store.get_feature_view(name="customer_profile")
        customer_driver_combined_fv = store.get_feature_view(
            name="customer_driver_combined")
        provider = store._get_provider()

        for (d, c) in zip(driver_ids, customer_ids):
            """
            driver table:
            driver driver_locations__lon  driver_locations__lat
                1                   1.0                    0.1
                2                   2.0                    0.2
                3                   3.0                    0.3
            """
            driver_key = EntityKeyProto(
                join_keys=["driver"], entity_values=[ValueProto(int64_val=d)])
            provider.online_write_batch(
                config=store.config,
                table=driver_locations_fv,
                data=[(
                    driver_key,
                    {
                        "lat": ValueProto(double_val=d * lat_multiply),
                        "lon": ValueProto(string_val=str(d * lon_multiply)),
                    },
                    datetime.utcnow(),
                    datetime.utcnow(),
                )],
                progress=None,
            )
            """
            customer table
            customer  customer_profile__avg_orders_day customer_profile__name  customer_profile__age
                4                               4.0                  foo4                     40
                5                               5.0                  foo5                     50
                6                               6.0                  foo6                     60
            """
            customer_key = EntityKeyProto(
                join_keys=["customer"],
                entity_values=[ValueProto(int64_val=c)])
            provider.online_write_batch(
                config=store.config,
                table=customer_profile_fv,
                data=[(
                    customer_key,
                    {
                        "avg_orders_day":
                        ValueProto(float_val=c * avg_order_day_multiply),
                        "name":
                        ValueProto(string_val=name + str(c)),
                        "age":
                        ValueProto(int64_val=c * age_multiply),
                    },
                    datetime.utcnow(),
                    datetime.utcnow(),
                )],
                progress=None,
            )
            """
            customer_driver_combined table
            customer  driver  customer_driver_combined__trips
                4       1                                4
                5       2                               10
                6       3                               18
            """
            combo_keys = EntityKeyProto(
                join_keys=["customer", "driver"],
                entity_values=[
                    ValueProto(int64_val=c),
                    ValueProto(int64_val=d)
                ],
            )
            provider.online_write_batch(
                config=store.config,
                table=customer_driver_combined_fv,
                data=[(
                    combo_keys,
                    {
                        "trips": ValueProto(int64_val=c * d)
                    },
                    datetime.utcnow(),
                    datetime.utcnow(),
                )],
                progress=None,
            )

        # Get online features in dataframe
        result_df = store.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "driver_locations:lat",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_profile:age",
                "customer_driver_combined:trips",
            ],
            # Reverse the row order
            entity_rows=[{
                "driver": d,
                "customer": c
            } for (d, c) in zip(reversed(driver_ids), reversed(customer_ids))],
        ).to_df()
        """
        Construct the expected dataframe with reversed row order like so:
        driver  customer driver_locations__lon  driver_locations__lat  customer_profile__avg_orders_day customer_profile__name  customer_profile__age  customer_driver_combined__trips
            3       6         3.0                   0.3                    6.0                               foo6                   60                     18
            2       5         2.0                   0.2                    5.0                               foo5                   50                     10
            1       4         1.0                   0.1                    4.0                               foo4                   40                     4
        """
        df_dict = {
            "driver":
            driver_ids,
            "customer":
            customer_ids,
            "driver_locations__lon":
            [str(d * lon_multiply) for d in driver_ids],
            "driver_locations__lat": [d * lat_multiply for d in driver_ids],
            "customer_profile__avg_orders_day":
            [c * avg_order_day_multiply for c in customer_ids],
            "customer_profile__name": [name + str(c) for c in customer_ids],
            "customer_profile__age": [c * age_multiply for c in customer_ids],
            "customer_driver_combined__trips":
            [d * c for (d, c) in zip(driver_ids, customer_ids)],
        }
        # Requested column order
        ordered_column = [
            "driver",
            "customer",
            "driver_locations__lon",
            "driver_locations__lat",
            "customer_profile__avg_orders_day",
            "customer_profile__name",
            "customer_profile__age",
            "customer_driver_combined__trips",
        ]
        expected_df = pd.DataFrame(
            {k: reversed(v)
             for (k, v) in df_dict.items()})
        assert_frame_equal(result_df[ordered_column], expected_df)
Example #15
def basic_rw_test(store: FeatureStore, view_name: str) -> None:
    """
    This is a provider-independent test suite for reading and writing from the online store, to
    be used by provider-specific tests.
    """
    table = store.get_feature_view(name=view_name)

    provider = store._get_provider()

    entity_key = EntityKeyProto(entity_names=["driver"],
                                entity_values=[ValueProto(int64_val=1)])

    def _driver_rw_test(event_ts, created_ts, write, expect_read):
        """ A helper function to write values and read them back """
        write_lat, write_lon = write
        expect_lat, expect_lon = expect_read
        provider.online_write_batch(
            project=store.project,
            table=table,
            data=[(
                entity_key,
                {
                    "lat": ValueProto(double_val=write_lat),
                    "lon": ValueProto(string_val=write_lon),
                },
                event_ts,
                created_ts,
            )],
            progress=None,
        )

        read_rows = provider.online_read(project=store.project,
                                         table=table,
                                         entity_keys=[entity_key])
        assert len(read_rows) == 1
        _, val = read_rows[0]
        assert val["lon"].string_val == expect_lon
        assert abs(val["lat"].double_val - expect_lat) < 1e-6

    """ 1. Basic test: write value, read it back """

    time_1 = datetime.utcnow()
    _driver_rw_test(event_ts=time_1,
                    created_ts=time_1,
                    write=(1.1, "3.1"),
                    expect_read=(1.1, "3.1"))
    """ Values with an older event_ts should not overwrite newer ones """
    time_2 = datetime.utcnow()
    _driver_rw_test(
        event_ts=time_1 - timedelta(hours=1),
        created_ts=time_2,
        write=(-1000, "OLD"),
        expect_read=(1.1, "3.1"),
    )
    """ Values with an new event_ts should overwrite older ones """
    time_3 = datetime.utcnow()
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3,
        write=(1123, "NEWER"),
        expect_read=(1123, "NEWER"),
    )
    """ created_ts is used as a tie breaker, using older created_ts here so no overwrite """
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3 - timedelta(hours=1),
        write=(54321, "I HAVE AN OLDER created_ts SO I LOSE"),
        expect_read=(1123, "NEWER"),
    )
    """ created_ts is used as a tie breaker, using older created_ts here so no overwrite """
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3 + timedelta(hours=1),
        write=(96864, "I HAVE A NEWER created_ts SO I WIN"),
        expect_read=(96864, "I HAVE A NEWER created_ts SO I WIN"),
    )
Example #16
    def test_bigquery_ingestion_correctness(self):
        # create dataset
        ts = pd.Timestamp.now(tz="UTC").round("ms")
        checked_value = (
            random.random()
        )  # use a random value so the test can't pass if nothing was written to the online store
        data = {
            "id": [1, 2, 1],
            "value": [0.1, 0.2, checked_value],
            "ts_1": [ts - timedelta(minutes=2), ts, ts],
            "created_ts": [ts, ts, ts],
        }
        df = pd.DataFrame.from_dict(data)

        # load dataset into BigQuery
        job_config = bigquery.LoadJobConfig()
        table_id = (
            f"{self.gcp_project}.{self.bigquery_dataset}.correctness_{int(time.time())}"
        )
        job = self.client.load_table_from_dataframe(df,
                                                    table_id,
                                                    job_config=job_config)
        job.result()

        # create FeatureView
        fv = FeatureView(
            name="test_bq_correctness",
            entities=["driver_id"],
            features=[Feature("value", ValueType.FLOAT)],
            ttl=timedelta(minutes=5),
            input=BigQuerySource(
                event_timestamp_column="ts",
                table_ref=table_id,
                created_timestamp_column="created_ts",
                field_mapping={
                    "ts_1": "ts",
                    "id": "driver_id"
                },
                date_partition_column="",
            ),
        )
        config = RepoConfig(
            metadata_store="./metadata.db",
            project="default",
            provider="gcp",
            online_store=OnlineStoreConfig(
                local=LocalOnlineStoreConfig("online_store.db")),
        )
        fs = FeatureStore(config=config)
        fs.apply([fv])

        # run materialize()
        fs.materialize(
            ["test_bq_correctness"],
            datetime.utcnow() - timedelta(minutes=5),
            datetime.utcnow() - timedelta(minutes=0),
        )

        # check result of materialize()
        entity_key = EntityKeyProto(entity_names=["driver_id"],
                                    entity_values=[ValueProto(int64_val=1)])
        t, val = fs._get_provider().online_read("default", fv, entity_key)
        assert abs(val["value"].double_val - checked_value) < 1e-6