Beispiel #1
0
    def _driver_rw_test(event_ts, created_ts, write, expect_read):
        """Write one (lat, lon) pair for the driver and assert it reads back."""
        lat_in, lon_in = write
        lat_out, lon_out = expect_read
        row = (
            entity_key,
            {
                "lat": ValueProto(double_val=lat_in),
                "lon": ValueProto(string_val=lon_in),
            },
            event_ts,
            created_ts,
        )
        provider.online_write_batch(
            project=store.project,
            table=table,
            data=[row],
            progress=None,
        )

        rows = provider.online_read(
            project=store.project, table=table, entity_keys=[entity_key]
        )
        assert len(rows) == 1
        _, values = rows[0]
        assert values["lon"].string_val == lon_out
        # lat is a double; compare with a tolerance rather than exact equality.
        assert abs(values["lat"].double_val - lat_out) < 1e-6
Beispiel #2
0
def _create_n_customer_test_samples(n=10):
    """Build n synthetic (entity_key, features, event_ts, created_ts) rows.

    Each row keys on the "customer" join key with a string id and carries a
    fixed set of three feature values; created_ts is left as None.
    """
    samples = []
    for i in range(n):
        key = EntityKeyProto(
            join_keys=["customer"],
            entity_values=[ValueProto(string_val=str(i))],
        )
        features = {
            "avg_orders_day": ValueProto(float_val=1.0),
            "name": ValueProto(string_val="John"),
            "age": ValueProto(int64_val=3),
        }
        samples.append((key, features, datetime.utcnow(), None))
    return samples
Beispiel #3
0
    def _get_features_for_entity(
        self,
        values: List[ByteString],
        feature_view: str,
        requested_features: List[str],
    ) -> Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]:
        """Decode raw per-feature bytes for one entity into ValueProtos.

        Returns (event_ts, features) or (None, None) when no features decoded.
        """
        feature_bytes = dict(zip(requested_features, values))

        # The event timestamp travels under a reserved "_ts:<view>" key.
        event_ts_proto = Timestamp()
        raw_ts = feature_bytes.pop(f"_ts:{feature_view}")
        if raw_ts:
            event_ts_proto.ParseFromString(raw_ts)

        features = {}
        for name, raw in feature_bytes.items():
            proto = ValueProto()
            if raw:
                proto.ParseFromString(raw)
            features[name] = proto

        if not features:
            return None, None
        return datetime.fromtimestamp(event_ts_proto.seconds), features
Beispiel #4
0
    def online_read(
        self,
        config: RepoConfig,
        table: Union[FeatureTable, FeatureView],
        entity_keys: List[EntityKeyProto],
        requested_features: Optional[List[str]] = None,
    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
        """Fetch one Datastore row per entity key and decode stored features.

        Returns one (event_ts, feature dict) per key in input order;
        (None, None) for keys without a stored row.
        """
        online_config = config.online_store
        assert isinstance(online_config, DatastoreOnlineStoreConfig)
        client = self._get_client(online_config)

        feast_project = config.project

        result: List[Tuple[Optional[datetime],
                           Optional[Dict[str, ValueProto]]]] = []
        for entity_key in entity_keys:
            document_id = compute_datastore_entity_id(entity_key)
            row_key = client.key("Project", feast_project, "Table", table.name,
                                 "Row", document_id)
            entity_doc = client.get(row_key)
            if entity_doc is None:
                result.append((None, None))
                continue
            features = {}
            for feature_name, raw_value in entity_doc["values"].items():
                proto = ValueProto()
                proto.ParseFromString(raw_value)
                features[feature_name] = proto
            result.append((entity_doc["event_ts"], features))
        return result
Beispiel #5
0
    def online_read(
        self,
        config: RepoConfig,
        table: FeatureView,
        entity_keys: List[EntityKeyProto],
        requested_features: Optional[List[str]] = None,
    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
        """Read feature rows from DynamoDB, one GetItem per entity key.

        Args:
            config: repo config; online_store must be DynamoDBOnlineStoreConfig.
            table: feature view whose backing DynamoDB table is queried.
            entity_keys: entity keys to look up.
            requested_features: unused; all stored features are returned.

        Returns:
            One (event_ts, {feature_name: ValueProto}) per key in input order;
            (None, None) for keys without a stored item.
        """
        online_config = config.online_store
        assert isinstance(online_config, DynamoDBOnlineStoreConfig)
        dynamodb_resource = self._get_dynamodb_resource(online_config.region)

        # The table handle does not depend on the entity key; resolve it once
        # instead of on every loop iteration.
        table_instance = dynamodb_resource.Table(_get_table_name(config, table))

        result: List[Tuple[Optional[datetime],
                           Optional[Dict[str, ValueProto]]]] = []
        for entity_key in entity_keys:
            entity_id = compute_entity_id(entity_key)
            with tracing_span(name="remote_call"):
                response = table_instance.get_item(
                    Key={"entity_id": entity_id})
            value = response.get("Item")

            if value is not None:
                res = {}
                for feature_name, value_bin in value["values"].items():
                    val = ValueProto()
                    # value_bin is a boto3 Binary wrapper; .value holds the raw bytes.
                    val.ParseFromString(value_bin.value)
                    res[feature_name] = val
                result.append((datetime.fromisoformat(value["event_ts"]), res))
            else:
                result.append((None, None))
        return result
Beispiel #6
0
    def online_read(
        self,
        config: RepoConfig,
        table: Union[FeatureTable, FeatureView],
        entity_keys: List[EntityKeyProto],
        requested_features: Optional[List[str]] = None,
    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
        """Read feature rows from DynamoDB, one GetItem per entity key.

        Returns one (event_ts, {feature_name: ValueProto}) per key in input
        order; (None, None) for keys without a stored item.
        """
        online_config = config.online_store
        assert isinstance(online_config, DynamoDBOnlineStoreConfig)
        _, dynamodb_resource = self._initialize_dynamodb(online_config)

        # The table handle is loop-invariant; resolve it once instead of per key.
        table_instance = dynamodb_resource.Table(
            f"{config.project}.{table.name}")

        result: List[Tuple[Optional[datetime],
                           Optional[Dict[str, ValueProto]]]] = []
        for entity_key in entity_keys:
            entity_id = compute_entity_id(entity_key)
            response = table_instance.get_item(Key={"entity_id": entity_id})
            value = response.get("Item")

            if value is not None:
                res = {}
                for feature_name, value_bin in value["values"].items():
                    val = ValueProto()
                    # value_bin is a boto3 Binary wrapper; .value holds the raw bytes.
                    val.ParseFromString(value_bin.value)
                    res[feature_name] = val
                result.append((value["event_ts"], res))
            else:
                result.append((None, None))
        return result
def test_dynamodb_online_store_online_read_unknown_entity(
        repo_config, dynamodb_online_store):
    """Test DynamoDBOnlineStore online_read method."""
    n_samples = 2
    table_name = f"{TABLE_NAME}_unknown_entity_{n_samples}"
    _create_test_table(PROJECT, table_name, REGION)
    data = _create_n_customer_test_samples(n=n_samples)
    _insert_data_test_table(data, PROJECT, table_name, REGION)

    entity_keys, features, *rest = zip(*data)
    # Append a nonsensical entity to search for
    entity_keys = list(entity_keys)
    features = list(features)
    unknown_key = EntityKeyProto(
        join_keys=["customer"],
        entity_values=[ValueProto(string_val="12359")],
    )

    # Have the unknown entity be in the beginning, middle, and end of the list of entities.
    for pos in range(len(entity_keys)):
        keys_with_unknown = deepcopy(entity_keys)
        keys_with_unknown.insert(pos, unknown_key)
        expected_features = deepcopy(features)
        expected_features.insert(pos, None)

        returned_items = dynamodb_online_store.online_read(
            config=repo_config,
            table=MockFeatureView(name=table_name),
            entity_keys=keys_with_unknown,
        )
        assert len(returned_items) == len(keys_with_unknown)
        assert [item[1] for item in returned_items] == list(expected_features)
        # The order should match the original entity key order
        assert returned_items[pos] == (None, None)
Beispiel #8
0
    def online_read(
        self,
        project: str,
        table: Union[FeatureTable, FeatureView],
        entity_keys: List[EntityKeyProto],
    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
        """Look up one Datastore row per entity key and decode stored features.

        Returns one (event_ts, feature dict) per key in input order;
        (None, None) for keys without a stored row.
        """
        client = self._initialize_client()

        result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
        for entity_key in entity_keys:
            document_id = compute_datastore_entity_id(entity_key)
            row_key = client.key(
                "Project", project, "Table", table.name, "Row", document_id
            )
            row = client.get(row_key)
            if row is None:
                result.append((None, None))
                continue
            features = {}
            for feature_name, raw_value in row["values"].items():
                proto = ValueProto()
                proto.ParseFromString(raw_value)
                features[feature_name] = proto
            result.append((row["event_ts"], features))
        return result
Beispiel #9
0
    def online_read(
        self,
        config: RepoConfig,
        table: Union[FeatureTable, FeatureView],
        entity_keys: List[EntityKeyProto],
        requested_features: Optional[List[str]] = None,
    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
        """Read feature rows from the SQLite online store, one query per key.

        Args:
            config: repo config used to resolve the database and project name.
            table: feature view/table whose rows are queried.
            entity_keys: entity keys to look up.
            requested_features: unused; every stored feature is returned.

        Returns:
            One (event_ts, {feature_name: ValueProto}) per key in input order;
            (None, None) when no row exists for a key.
        """
        # The original body began with a stray `pass` statement — removed.
        conn = self._get_conn(config)
        cur = conn.cursor()

        result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []

        project = config.project
        for entity_key in entity_keys:
            entity_key_bin = serialize_entity_key(entity_key)

            # The key is bound as a parameter; the table name comes from
            # trusted repo config via _table_id, not from user input.
            cur.execute(
                f"SELECT feature_name, value, event_ts FROM {_table_id(project, table)} WHERE entity_key = ?",
                (entity_key_bin,),
            )

            res = {}
            res_ts = None
            for feature_name, val_bin, ts in cur.fetchall():
                val = ValueProto()
                val.ParseFromString(val_bin)
                res[feature_name] = val
                res_ts = ts  # every row for a key carries the same event_ts

            result.append((res_ts, res) if res else (None, None))
        return result
Beispiel #10
0
    def online_read(
        self,
        project: str,
        table: Union[FeatureTable, FeatureView],
        entity_key: EntityKeyProto,
    ) -> Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]:
        """Read all stored features for a single entity key from SQLite.

        Returns (event_ts, {feature_name: ValueProto}) or (None, None) when
        the key has no rows.
        """
        conn = self._get_conn()
        cur = conn.cursor()
        cur.execute(
            f"SELECT feature_name, value, event_ts FROM {_table_id(project, table)} WHERE entity_key = ?",
            (serialize_entity_key(entity_key), ),
        )

        features = {}
        event_ts = None
        for feature_name, raw_value, ts in cur.fetchall():
            proto = ValueProto()
            proto.ParseFromString(raw_value)
            features[feature_name] = proto
            event_ts = ts  # every row for a key carries the same event_ts

        if features:
            return event_ts, features
        return None, None
Beispiel #11
0
    def _driver_rw_test(event_ts, created_ts, write, expect_read):
        """Write one (lat, lon) row for the driver, read it back, and assert."""
        lat_in, lon_in = write
        lat_out, lon_out = expect_read
        row = (
            entity_key,
            {
                "lat": ValueProto(double_val=lat_in),
                "lon": ValueProto(string_val=lon_in),
            },
            event_ts,
            created_ts,
        )
        provider.online_write_batch(
            config=store.config,
            table=table,
            data=[row],
            progress=None,
        )

        if not feature_service_name:
            # Read back through the provider's raw online_read path.
            read_rows = provider.online_read(
                config=store.config, table=table, entity_keys=[entity_key]
            )
            assert len(read_rows) == 1
            _, values = read_rows[0]
            assert values["lon"].string_val == lon_out
            assert abs(values["lat"].double_val - lat_out) < 1e-6
        else:
            # Read back through the feature-service path.
            feature_service = store.get_feature_service(feature_service_name)
            features = store.get_online_features(
                features=feature_service, entity_rows=[{"driver": 1}]
            ).to_dict()
            assert len(features["driver"]) == 1
            assert features["lon"][0] == lon_out
            assert abs(features["lat"][0] - lat_out) < 1e-6
Beispiel #12
0
    def _driver_rw_test(event_ts, created_ts, write, expect_read):
        """Write one (lat, lon) row for the driver, read it back, and assert."""
        lat_in, lon_in = write
        lat_out, lon_out = expect_read
        feature_values = {
            "lat": ValueProto(double_val=lat_in),
            "lon": ValueProto(string_val=lon_in),
        }
        provider.online_write_batch(
            project=project_name,
            table=table,
            data=[(entity_key, feature_values, event_ts)],
            created_ts=created_ts,
        )

        _, val = provider.online_read(
            project=project_name, table=table, entity_key=entity_key
        )
        assert val["lon"].string_val == lon_out
        # lat is a double; compare with a tolerance rather than exact equality.
        assert abs(val["lat"].double_val - lat_out) < 1e-6
Beispiel #13
0
    def online_read(
        self,
        config: RepoConfig,
        table: Union[FeatureTable, FeatureView],
        entity_keys: List[EntityKeyProto],
        requested_features: Optional[List[str]] = None,
    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
        """Read feature rows from Redis via one HMGET per entity key.

        Args:
            config: repo config; online_store must be RedisOnlineStoreConfig.
            table: feature view whose hash fields are read.
            entity_keys: entity keys to look up.
            requested_features: feature names to fetch; defaults to all of the
                view's features.

        Returns:
            One (event_ts, {feature_name: ValueProto}) per key in input order;
            (None, None) when nothing is stored for a key.
        """
        online_store_config = config.online_store
        assert isinstance(online_store_config, RedisOnlineStoreConfig)

        client = self._get_client(online_store_config)
        feature_view = table.name
        project = config.project

        result: List[Tuple[Optional[datetime],
                           Optional[Dict[str, ValueProto]]]] = []

        if not requested_features:
            requested_features = [f.name for f in table.features]

        # BUGFIX: the hash-field list and the timestamp key are loop-invariant.
        # The original appended ts_key to requested_features on *every*
        # iteration, mutating the caller's list and hashing the ts key as a
        # feature from the second entity onward. Build the field lists once,
        # on a local copy, before the loop.
        ts_key = f"_ts:{feature_view}"
        hset_keys = [_mmh3(f"{feature_view}:{k}") for k in requested_features]
        hset_keys.append(ts_key)
        field_names = list(requested_features) + [ts_key]

        for entity_key in entity_keys:
            redis_key_bin = _redis_key(project, entity_key)
            values = client.hmget(redis_key_bin, hset_keys)
            res_val = dict(zip(field_names, values))

            res_ts = Timestamp()
            ts_val = res_val.pop(ts_key)
            if ts_val:
                res_ts.ParseFromString(ts_val)

            res = {}
            for feature_name, val_bin in res_val.items():
                val = ValueProto()
                if val_bin:
                    val.ParseFromString(val_bin)
                res[feature_name] = val

            if not res:
                result.append((None, None))
            else:
                timestamp = datetime.fromtimestamp(res_ts.seconds)
                result.append((timestamp, res))
        return result
Beispiel #14
0
    def online_read(
        self,
        config: RepoConfig,
        table: FeatureView,
        entity_keys: List[EntityKeyProto],
        requested_features: Optional[List[str]] = None,
    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
        """Batch-read feature rows for all entity keys with a single query."""
        result: List[Tuple[Optional[datetime],
                           Optional[Dict[str, ValueProto]]]] = []

        project = config.project
        with self._get_conn(config) as conn, conn.cursor() as cur:
            # Collecting all the keys to a list allows us to make fewer round trips
            # to PostgreSQL
            keys = [serialize_entity_key(entity_key) for entity_key in entity_keys]

            cur.execute(
                sql.SQL("""
                    SELECT entity_key, feature_name, value, event_ts
                    FROM {} WHERE entity_key = ANY(%s);
                    """).format(sql.Identifier(_table_id(project, table)), ),
                (keys, ),
            )

            rows = cur.fetchall()

            # PostgreSQL gives no ordering guarantee, so index the returned
            # rows by entity key and emit results in the requested key order.
            values_dict = defaultdict(list)
            for row in rows if rows is not None else []:
                # entity_key comes back as a memoryview; convert for hashing.
                values_dict[row[0].tobytes()].append(row[1:])

            for key in keys:
                if key not in values_dict:
                    result.append((None, None))
                    continue
                res = {}
                ts = None
                for feature_name, value_bin, event_ts in values_dict[key]:
                    val = ValueProto()
                    val.ParseFromString(value_bin)
                    res[feature_name] = val
                    ts = event_ts  # rows for one key share the same event_ts
                result.append((ts, res))

        return result
Beispiel #15
0
    def online_read(
        self,
        config: RepoConfig,
        table: FeatureView,
        entity_keys: List[EntityKeyProto],
        requested_features: Optional[List[str]] = None,
    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
        """Read rows for all entity keys with one SQL query, grouped per key."""
        conn = self._get_conn(config)
        cur = conn.cursor()

        result: List[Tuple[Optional[datetime],
                           Optional[Dict[str, ValueProto]]]] = []

        serialized_keys = [
            serialize_entity_key(entity_key) for entity_key in entity_keys
        ]

        with tracing_span(name="remote_call"):
            # Fetch all entities in one go
            cur.execute(
                f"SELECT entity_key, feature_name, value, event_ts "
                f"FROM {_table_id(config.project, table)} "
                f"WHERE entity_key IN ({','.join('?' * len(entity_keys))}) "
                f"ORDER BY entity_key",
                serialized_keys,
            )
            rows = cur.fetchall()

        # ORDER BY makes the rows contiguous per key, as groupby requires.
        grouped = {
            key_bin: list(group)
            for key_bin, group in itertools.groupby(rows, key=lambda r: r[0])
        }
        for entity_key_bin in serialized_keys:
            res = {}
            res_ts = None
            for _, feature_name, val_bin, ts in grouped.get(entity_key_bin, []):
                val = ValueProto()
                val.ParseFromString(val_bin)
                res[feature_name] = val
                res_ts = ts  # rows for one key share the same event_ts

            result.append((res_ts, res) if res else (None, None))
        return result
Beispiel #16
0
    def online_read(
        self,
        config: RepoConfig,
        table: Union[FeatureTable, FeatureView],
        entity_keys: List[EntityKeyProto],
        requested_features: Optional[List[str]] = None,
    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
        """Batch-read rows from Datastore with a single get_multi call."""
        online_config = config.online_store
        assert isinstance(online_config, DatastoreOnlineStoreConfig)
        client = self._get_client(online_config)

        feast_project = config.project

        keys: List[Key] = [
            client.key("Project", feast_project, "Table", table.name,
                       "Row", compute_entity_id(entity_key))
            for entity_key in entity_keys
        ]
        result: List[Tuple[Optional[datetime],
                           Optional[Dict[str, ValueProto]]]] = []

        # NOTE: get_multi doesn't return values in the same order as the keys in the request.
        # Also, len(values) can be less than len(keys) in the case of missing values.
        with tracing_span(name="remote_call"):
            values = client.get_multi(keys)
        values_dict = {v.key: v for v in values} if values is not None else {}
        for key in keys:
            value = values_dict.get(key)
            if value is None:
                result.append((None, None))
                continue
            decoded = {}
            for feature_name, value_bin in value["values"].items():
                proto = ValueProto()
                proto.ParseFromString(value_bin)
                decoded[feature_name] = proto
            result.append((value["event_ts"], decoded))

        return result
Beispiel #17
0
    def online_read(
        self,
        config: RepoConfig,
        table: Union[FeatureTable, FeatureView],
        entity_keys: List[EntityKeyProto],
        requested_features: Optional[List[str]] = None,
    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
        """Batch-read rows from Datastore with a single get_multi call.

        Returns one (event_ts, {feature_name: ValueProto}) per entity key, in
        input order; (None, None) for keys with no stored row.
        """
        online_config = config.online_store
        assert isinstance(online_config, DatastoreOnlineStoreConfig)
        client = self._get_client(online_config)

        feast_project = config.project

        keys: List[Key] = []
        result: List[Tuple[Optional[datetime],
                           Optional[Dict[str, ValueProto]]]] = []
        for entity_key in entity_keys:
            document_id = compute_entity_id(entity_key)
            key = client.key("Project", feast_project, "Table", table.name,
                             "Row", document_id)
            keys.append(key)

        values = client.get_multi(keys)

        # BUGFIX: the original returned an empty list when get_multi yielded
        # None, breaking the one-result-per-key contract, and restored order
        # via repeated keys.index() calls (O(n^2)). Index the returned
        # entities by key instead and walk the request order once; missing
        # keys yield (None, None).
        values_dict = {v.key: v for v in values} if values is not None else {}
        for key in keys:
            value = values_dict.get(key)
            if value is None:
                result.append((None, None))
                continue
            res = {}
            for feature_name, value_bin in value["values"].items():
                val = ValueProto()
                val.ParseFromString(value_bin)
                res[feature_name] = val
            result.append((value["event_ts"], res))

        return result
Beispiel #18
0
def basic_rw_test(store: FeatureStore,
                  view_name: str,
                  feature_service_name: Optional[str] = None) -> None:
    """
    This is a provider-independent test suite for reading and writing from the online store, to
    be used by provider-specific tests.

    Args:
        store: an initialized FeatureStore whose provider is under test.
        view_name: name of the feature view written to and read from.
        feature_service_name: when given, reads go through
            store.get_online_features() with this feature service; otherwise
            the provider's raw online_read() is used.
    """
    table = store.get_feature_view(name=view_name)

    provider = store._get_provider()

    # Single driver entity (join key "driver" == 1) used for every write/read.
    entity_key = EntityKeyProto(join_keys=["driver"],
                                entity_values=[ValueProto(int64_val=1)])

    def _driver_rw_test(event_ts, created_ts, write, expect_read):
        """ A helper function to write values and read them back """
        write_lat, write_lon = write
        expect_lat, expect_lon = expect_read
        provider.online_write_batch(
            config=store.config,
            table=table,
            data=[(
                entity_key,
                {
                    "lat": ValueProto(double_val=write_lat),
                    "lon": ValueProto(string_val=write_lon),
                },
                event_ts,
                created_ts,
            )],
            progress=None,
        )

        if feature_service_name:
            # Read back through the feature-service path.
            entity_dict = {"driver": 1}
            feature_service = store.get_feature_service(feature_service_name)
            features = store.get_online_features(features=feature_service,
                                                 entity_rows=[entity_dict
                                                              ]).to_dict()
            assert len(features["driver"]) == 1
            assert features["lon"][0] == expect_lon
            # lat is a double; compare with a tolerance, not exact equality.
            assert abs(features["lat"][0] - expect_lat) < 1e-6
        else:
            # Read back through the provider's raw online_read path.
            read_rows = provider.online_read(config=store.config,
                                             table=table,
                                             entity_keys=[entity_key])
            assert len(read_rows) == 1
            _, val = read_rows[0]
            assert val["lon"].string_val == expect_lon
            assert abs(val["lat"].double_val - expect_lat) < 1e-6

    """ 1. Basic test: write value, read it back """

    time_1 = datetime.utcnow()
    _driver_rw_test(event_ts=time_1,
                    created_ts=time_1,
                    write=(1.1, "3.1"),
                    expect_read=(1.1, "3.1"))

    # Note: This behavior has changed for performance. We should test that older
    # value can't overwrite over a newer value once we add the respective flag
    """ Values with an older event_ts should overwrite newer ones """
    time_2 = datetime.utcnow()
    _driver_rw_test(
        event_ts=time_1 - timedelta(hours=1),
        created_ts=time_2,
        write=(-1000, "OLD"),
        expect_read=(-1000, "OLD"),
    )
    """ Values with an new event_ts should overwrite older ones """
    time_3 = datetime.utcnow()
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3,
        write=(1123, "NEWER"),
        expect_read=(1123, "NEWER"),
    )

    # Note: This behavior has changed for performance. We should test that older
    # value can't overwrite over a newer value once we add the respective flag
    """ created_ts is used as a tie breaker, using older created_ts here, but we still overwrite """
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3 - timedelta(hours=1),
        write=(54321, "I HAVE AN OLDER created_ts SO I LOSE"),
        expect_read=(54321, "I HAVE AN OLDER created_ts SO I LOSE"),
    )
    """ created_ts is used as a tie breaker, using newer created_ts here so we should overwrite """
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3 + timedelta(hours=1),
        write=(96864, "I HAVE A NEWER created_ts SO I WIN"),
        expect_read=(96864, "I HAVE A NEWER created_ts SO I WIN"),
    )
def basic_rw_test(store: FeatureStore, view_name: str) -> None:
    """
    This is a provider-independent test suite for reading and writing from the online store, to
    be used by provider-specific tests.

    Args:
        store: an initialized FeatureStore whose provider is under test.
        view_name: name of the feature view written to and read from.

    This older variant expects last-write-wins semantics keyed on event_ts
    with created_ts as a tie breaker (older timestamps must NOT overwrite).
    """
    table = store.get_feature_view(name=view_name)

    provider = store._get_provider()

    # Single driver entity (entity name "driver" == 1) used for every write/read.
    entity_key = EntityKeyProto(entity_names=["driver"],
                                entity_values=[ValueProto(int64_val=1)])

    def _driver_rw_test(event_ts, created_ts, write, expect_read):
        """ A helper function to write values and read them back """
        write_lat, write_lon = write
        expect_lat, expect_lon = expect_read
        provider.online_write_batch(
            project=store.project,
            table=table,
            data=[(
                entity_key,
                {
                    "lat": ValueProto(double_val=write_lat),
                    "lon": ValueProto(string_val=write_lon),
                },
                event_ts,
                created_ts,
            )],
            progress=None,
        )

        read_rows = provider.online_read(project=store.project,
                                         table=table,
                                         entity_keys=[entity_key])
        assert len(read_rows) == 1
        _, val = read_rows[0]
        assert val["lon"].string_val == expect_lon
        # lat is a double; compare with a tolerance, not exact equality.
        assert abs(val["lat"].double_val - expect_lat) < 1e-6

    """ 1. Basic test: write value, read it back """

    time_1 = datetime.utcnow()
    _driver_rw_test(event_ts=time_1,
                    created_ts=time_1,
                    write=(1.1, "3.1"),
                    expect_read=(1.1, "3.1"))
    """ Values with an older event_ts should not overwrite newer ones """
    time_2 = datetime.utcnow()
    _driver_rw_test(
        event_ts=time_1 - timedelta(hours=1),
        created_ts=time_2,
        write=(-1000, "OLD"),
        expect_read=(1.1, "3.1"),
    )
    """ Values with an new event_ts should overwrite older ones """
    time_3 = datetime.utcnow()
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3,
        write=(1123, "NEWER"),
        expect_read=(1123, "NEWER"),
    )
    """ created_ts is used as a tie breaker, using older created_ts here so no overwrite """
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3 - timedelta(hours=1),
        write=(54321, "I HAVE AN OLDER created_ts SO I LOSE"),
        expect_read=(1123, "NEWER"),
    )
    """ created_ts is used as a tie breaker, using older created_ts here so no overwrite """
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3 + timedelta(hours=1),
        write=(96864, "I HAVE A NEWER created_ts SO I WIN"),
        expect_read=(96864, "I HAVE A NEWER created_ts SO I WIN"),
    )
Beispiel #20
0
def test_online() -> None:
    """
    Test reading from the online store in local mode.

    End-to-end flow against a local example repo:
      1. Write rows for three feature views directly through the provider.
      2. Read them back via ``get_online_features`` and verify values,
         including behavior for unknown entity keys and bad feature refs.
      3. Verify registry caching: a 1-second ``cache_ttl_seconds`` forces a
         reload after expiry (simulated by renaming the registry file away),
         while ``cache_ttl_seconds=0`` caches indefinitely until
         ``refresh_registry()`` is called explicitly.
    """
    runner = CliRunner()
    with runner.local_repo(
            get_example_repo("example_feature_repo_1.py")) as store:
        # Write some data to three feature views

        driver_locations_fv = store.get_feature_view(name="driver_locations")
        customer_profile_fv = store.get_feature_view(name="customer_profile")
        customer_driver_combined_fv = store.get_feature_view(
            name="customer_driver_combined")

        provider = store._get_provider()

        # Single-entity key: driver 1.
        driver_key = EntityKeyProto(join_keys=["driver"],
                                    entity_values=[ValueProto(int64_val=1)])
        provider.online_write_batch(
            project=store.config.project,
            table=driver_locations_fv,
            data=[(
                driver_key,
                {
                    "lat": ValueProto(double_val=0.1),
                    "lon": ValueProto(string_val="1.0"),
                },
                datetime.utcnow(),
                datetime.utcnow(),
            )],
            progress=None,
        )

        # Single-entity key: customer 5.
        customer_key = EntityKeyProto(join_keys=["customer"],
                                      entity_values=[ValueProto(int64_val=5)])
        provider.online_write_batch(
            project=store.config.project,
            table=customer_profile_fv,
            data=[(
                customer_key,
                {
                    "avg_orders_day": ValueProto(float_val=1.0),
                    "name": ValueProto(string_val="John"),
                    "age": ValueProto(int64_val=3),
                },
                datetime.utcnow(),
                datetime.utcnow(),
            )],
            progress=None,
        )

        # Composite key: (customer 5, driver 1).
        customer_key = EntityKeyProto(
            join_keys=["customer", "driver"],
            entity_values=[ValueProto(int64_val=5),
                           ValueProto(int64_val=1)],
        )
        provider.online_write_batch(
            project=store.config.project,
            table=customer_driver_combined_fv,
            data=[(
                customer_key,
                {
                    "trips": ValueProto(int64_val=7)
                },
                datetime.utcnow(),
                datetime.utcnow(),
            )],
            progress=None,
        )

        # Retrieve four features for two identical (valid) entity rows;
        # both rows should return the written values.
        result = store.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{
                "driver": 1,
                "customer": 5
            }, {
                "driver": 1,
                "customer": 5
            }],
        ).to_dict()

        assert "driver_locations__lon" in result
        assert "customer_profile__avg_orders_day" in result
        assert "customer_profile__name" in result
        assert result["driver"] == [1, 1]
        assert result["customer"] == [5, 5]
        assert result["driver_locations__lon"] == ["1.0", "1.0"]
        assert result["customer_profile__avg_orders_day"] == [1.0, 1.0]
        assert result["customer_profile__name"] == ["John", "John"]
        assert result["customer_driver_combined__trips"] == [7, 7]

        # Ensure features are still in result when keys not found
        result = store.get_online_features(
            feature_refs=["customer_driver_combined:trips"],
            entity_rows=[{
                "driver": 0,
                "customer": 0
            }],
        ).to_dict()

        assert "customer_driver_combined__trips" in result

        # invalid table reference
        with pytest.raises(FeatureViewNotFoundException):
            store.get_online_features(
                feature_refs=["driver_locations_bad:lon"],
                entity_rows=[{
                    "driver": 1
                }],
            )

        # Create new FeatureStore object with fast cache invalidation
        cache_ttl = 1
        fs_fast_ttl = FeatureStore(config=RepoConfig(
            registry=RegistryConfig(path=store.config.registry,
                                    cache_ttl_seconds=cache_ttl),
            online_store=store.config.online_store,
            project=store.config.project,
            provider=store.config.provider,
        ))

        # Should download the registry and cache it permanently (or until manually refreshed)
        result = fs_fast_ttl.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{
                "driver": 1,
                "customer": 5
            }],
        ).to_dict()
        assert result["driver_locations__lon"] == ["1.0"]
        assert result["customer_driver_combined__trips"] == [7]

        # Rename the registry.db so that it cant be used for refreshes
        os.rename(store.config.registry, store.config.registry + "_fake")

        # Wait for registry to expire
        time.sleep(cache_ttl)

        # Will try to reload registry because it has expired (it will fail because we deleted the actual registry file)
        with pytest.raises(FileNotFoundError):
            fs_fast_ttl.get_online_features(
                feature_refs=[
                    "driver_locations:lon",
                    "customer_profile:avg_orders_day",
                    "customer_profile:name",
                    "customer_driver_combined:trips",
                ],
                entity_rows=[{
                    "driver": 1,
                    "customer": 5
                }],
            ).to_dict()

        # Restore registry.db so that we can see if it actually reloads registry
        os.rename(store.config.registry + "_fake", store.config.registry)

        # Test if registry is actually reloaded and whether results return
        result = fs_fast_ttl.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{
                "driver": 1,
                "customer": 5
            }],
        ).to_dict()
        assert result["driver_locations__lon"] == ["1.0"]
        assert result["customer_driver_combined__trips"] == [7]

        # Create a registry with infinite cache (for users that want to manually refresh the registry)
        fs_infinite_ttl = FeatureStore(config=RepoConfig(
            registry=RegistryConfig(path=store.config.registry,
                                    cache_ttl_seconds=0),
            online_store=store.config.online_store,
            project=store.config.project,
            provider=store.config.provider,
        ))

        # Should return results (and fill the registry cache)
        result = fs_infinite_ttl.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{
                "driver": 1,
                "customer": 5
            }],
        ).to_dict()
        assert result["driver_locations__lon"] == ["1.0"]
        assert result["customer_driver_combined__trips"] == [7]

        # Wait a bit so that an arbitrary TTL would take effect
        time.sleep(2)

        # Rename the registry.db so that it cant be used for refreshes
        os.rename(store.config.registry, store.config.registry + "_fake")

        # TTL is infinite so this method should use registry cache
        result = fs_infinite_ttl.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{
                "driver": 1,
                "customer": 5
            }],
        ).to_dict()
        assert result["driver_locations__lon"] == ["1.0"]
        assert result["customer_driver_combined__trips"] == [7]

        # Force registry reload (should fail because file is missing)
        with pytest.raises(FileNotFoundError):
            fs_infinite_ttl.refresh_registry()

        # Restore registry.db so that teardown works
        os.rename(store.config.registry + "_fake", store.config.registry)
def test_online() -> None:
    """
    Test reading from the online store in local mode.

    Writes a (lat, lon) row for driver 1 into two feature views, then
    verifies reads for an existing key, a missing key (returns None), and
    an invalid feature-view reference (raises ValueError).
    """
    runner = CliRunner()
    with runner.local_repo(get_example_repo("example_feature_repo_1.py")) as store:

        registry = store._get_registry()
        provider = store._get_provider()

        # Feature views to populate, with the feature values each should
        # hold for driver 1.
        fixtures = [
            ("driver_locations", 0.1, "1.0"),
            ("driver_locations_2", 2.0, "2.0"),
        ]

        driver_key = EntityKeyProto(
            entity_names=["driver"], entity_values=[ValueProto(int64_val=1)]
        )
        for view_name, lat, lon in fixtures:
            view = registry.get_feature_view(
                project=store.config.project, name=view_name
            )
            provider.online_write_batch(
                project=store.config.project,
                table=view,
                data=[
                    (
                        driver_key,
                        {
                            "lat": ValueProto(double_val=lat),
                            "lon": ValueProto(string_val=lon),
                        },
                        datetime.utcnow(),
                        datetime.utcnow(),
                    )
                ],
                progress=None,
            )

        # Read back with one existing key (driver 1) and one that was never
        # written (driver 123); the latter must yield None.
        result = store.get_online_features(
            feature_refs=["driver_locations:lon", "driver_locations_2:lon"],
            entity_rows=[{"driver": 1}, {"driver": 123}],
        )
        result_dict = result.to_dict()

        assert "driver_locations:lon" in result_dict
        assert result_dict["driver_locations:lon"] == ["1.0", None]
        assert result_dict["driver_locations_2:lon"] == ["2.0", None]

        # A reference to an unknown feature view must raise.
        with pytest.raises(ValueError):
            store.get_online_features(
                feature_refs=["driver_locations_bad:lon"],
                entity_rows=[{"driver": 1}],
            )
Example #22
0
def test_online_to_df():
    """
    Test dataframe conversion. Make sure the response columns and rows are
    the same order as the request.
    """

    driver_ids = [1, 2, 3]
    customer_ids = [4, 5, 6]
    name = "foo"
    lon_multiply = 1.0
    lat_multiply = 0.1
    age_multiply = 10
    avg_order_day_multiply = 1.0

    runner = CliRunner()
    with runner.local_repo(get_example_repo("example_feature_repo_1.py"),
                           "bigquery") as store:
        # Resolve the three feature views to populate.
        driver_locations_fv = store.get_feature_view(name="driver_locations")
        customer_profile_fv = store.get_feature_view(name="customer_profile")
        customer_driver_combined_fv = store.get_feature_view(
            name="customer_driver_combined")
        provider = store._get_provider()

        for driver_id, customer_id in zip(driver_ids, customer_ids):
            # driver table:
            # driver  driver_locations__lon  driver_locations__lat
            #     1                    1.0                    0.1
            #     2                    2.0                    0.2
            #     3                    3.0                    0.3
            provider.online_write_batch(
                config=store.config,
                table=driver_locations_fv,
                data=[(
                    EntityKeyProto(
                        join_keys=["driver"],
                        entity_values=[ValueProto(int64_val=driver_id)],
                    ),
                    {
                        "lat": ValueProto(double_val=driver_id * lat_multiply),
                        "lon": ValueProto(
                            string_val=str(driver_id * lon_multiply)),
                    },
                    datetime.utcnow(),
                    datetime.utcnow(),
                )],
                progress=None,
            )

            # customer table:
            # customer  __avg_orders_day  __name  __age
            #     4                 4.0    foo4     40
            #     5                 5.0    foo5     50
            #     6                 6.0    foo6     60
            provider.online_write_batch(
                config=store.config,
                table=customer_profile_fv,
                data=[(
                    EntityKeyProto(
                        join_keys=["customer"],
                        entity_values=[ValueProto(int64_val=customer_id)],
                    ),
                    {
                        "avg_orders_day": ValueProto(
                            float_val=customer_id * avg_order_day_multiply),
                        "name": ValueProto(string_val=name + str(customer_id)),
                        "age": ValueProto(
                            int64_val=customer_id * age_multiply),
                    },
                    datetime.utcnow(),
                    datetime.utcnow(),
                )],
                progress=None,
            )

            # customer_driver_combined table:
            # customer  driver  customer_driver_combined__trips
            #     4       1                                4
            #     5       2                               10
            #     6       3                               18
            provider.online_write_batch(
                config=store.config,
                table=customer_driver_combined_fv,
                data=[(
                    EntityKeyProto(
                        join_keys=["customer", "driver"],
                        entity_values=[
                            ValueProto(int64_val=customer_id),
                            ValueProto(int64_val=driver_id),
                        ],
                    ),
                    {"trips": ValueProto(int64_val=customer_id * driver_id)},
                    datetime.utcnow(),
                    datetime.utcnow(),
                )],
                progress=None,
            )

        # Request the rows in REVERSE order so we can verify the response
        # preserves the request's row ordering.
        result_df = store.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "driver_locations:lat",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_profile:age",
                "customer_driver_combined:trips",
            ],
            entity_rows=[
                {"driver": driver_id, "customer": customer_id}
                for driver_id, customer_id in zip(reversed(driver_ids),
                                                  reversed(customer_ids))
            ],
        ).to_df()

        # Expected frame (before reversing the rows to match the request):
        # driver customer lon  lat  avg_orders_day name  age  trips
        #   1      4      1.0  0.1       4.0       foo4   40     4
        #   2      5      2.0  0.2       5.0       foo5   50    10
        #   3      6      3.0  0.3       6.0       foo6   60    18
        expected = {
            "driver": driver_ids,
            "customer": customer_ids,
            "driver_locations__lon":
                [str(d * lon_multiply) for d in driver_ids],
            "driver_locations__lat":
                [d * lat_multiply for d in driver_ids],
            "customer_profile__avg_orders_day":
                [c * avg_order_day_multiply for c in customer_ids],
            "customer_profile__name":
                [name + str(c) for c in customer_ids],
            "customer_profile__age":
                [c * age_multiply for c in customer_ids],
            "customer_driver_combined__trips":
                [d * c for d, c in zip(driver_ids, customer_ids)],
        }
        expected_df = pd.DataFrame(
            {column: list(reversed(values))
             for column, values in expected.items()})
        # Compare with columns in the requested order; dict insertion order
        # of `expected` matches that order.
        assert_frame_equal(result_df[list(expected)], expected_df)
Example #23
0
    def online_read(
        self,
        config: RepoConfig,
        table: FeatureView,
        entity_keys: List[EntityKeyProto],
        requested_features: Optional[List[str]] = None,
    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
        """
        Retrieve feature values from the online DynamoDB store.

        Entity keys are looked up in batches of ``online_config.batch_size``
        via ``batch_get_item``. The returned list is aligned with
        ``entity_keys``: entities with no stored row yield ``(None, None)``.

        Args:
            config: The RepoConfig for the current FeatureStore.
            table: Feast FeatureView.
            entity_keys: a list of entity keys that should be read from the FeatureStore.
            requested_features: accepted for interface compatibility; all
                stored values for each entity are returned regardless.
        """
        online_config = config.online_store
        assert isinstance(online_config, DynamoDBOnlineStoreConfig)
        dynamodb_resource = self._get_dynamodb_resource(
            online_config.region, online_config.endpoint_url
        )
        table_instance = dynamodb_resource.Table(
            _get_table_name(online_config, config, table)
        )

        result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
        entity_ids = [compute_entity_id(entity_key) for entity_key in entity_keys]
        batch_size = online_config.batch_size
        entity_ids_iter = iter(entity_ids)
        while True:
            batch = list(itertools.islice(entity_ids_iter, batch_size))
            # No more items to read
            if len(batch) == 0:
                break
            # Remember how many rows were already accumulated so the padding
            # below is computed per-batch. The previous code compared
            # len(batch) against the GLOBAL len(result), which goes negative
            # after the first batch and silently dropped the (None, None)
            # placeholders for missing entities, misaligning the output.
            batch_start = len(result)
            batch_entity_ids = {
                table_instance.name: {
                    "Keys": [{"entity_id": entity_id} for entity_id in batch]
                }
            }
            with tracing_span(name="remote_call"):
                response = dynamodb_resource.batch_get_item(
                    RequestItems=batch_entity_ids
                )
            response = response.get("Responses")
            table_responses = response.get(table_instance.name)
            if table_responses:
                # Responses come back unordered; sort them to match the
                # request order before aligning with `batch`.
                table_responses = self._sort_dynamodb_response(
                    table_responses, entity_ids
                )
                entity_idx = 0
                for tbl_res in table_responses:
                    entity_id = tbl_res["entity_id"]
                    # Entities before this response in the batch had no row.
                    while entity_id != batch[entity_idx]:
                        result.append((None, None))
                        entity_idx += 1
                    res = {}
                    for feature_name, value_bin in tbl_res["values"].items():
                        val = ValueProto()
                        val.ParseFromString(value_bin.value)
                        res[feature_name] = val
                    result.append((datetime.fromisoformat(tbl_res["event_ts"]), res))
                    entity_idx += 1

            # Not all entities in a batch may have responses.
            # Pad with (None, None) for every entity at the tail of THIS
            # batch that produced no row.
            missing = len(batch) - (len(result) - batch_start)
            result.extend(((None, None),) * missing)
        return result
Example #24
0
def basic_rw_test(store: FeatureStore,
                  view_name: str,
                  feature_service_name: Optional[str] = None) -> None:
    """
    Provider-independent round-trip test for the online store, to be used
    by provider-specific tests.

    Writes (lat, lon) values for driver 1 and reads them back — through a
    feature service when ``feature_service_name`` is given, otherwise
    directly via the provider's ``online_read``.
    """
    feature_view = store.get_feature_view(name=view_name)

    provider = store._get_provider()

    driver_key = EntityKeyProto(join_keys=["driver_id"],
                                entity_values=[ValueProto(int64_val=1)])

    def _write_then_read(event_ts, created_ts, write, expect_read):
        """Write one (lat, lon) row, then assert the values read back."""
        lat_in, lon_in = write
        lat_out, lon_out = expect_read
        provider.online_write_batch(
            config=store.config,
            table=feature_view,
            data=[(
                driver_key,
                {
                    "lat": ValueProto(double_val=lat_in),
                    "lon": ValueProto(string_val=lon_in),
                },
                event_ts,
                created_ts,
            )],
            progress=None,
        )

        if feature_service_name:
            service = store.get_feature_service(feature_service_name)
            features = store.get_online_features(
                features=service,
                entity_rows=[{"driver_id": 1}],
            ).to_dict()
            assert len(features["driver_id"]) == 1
            assert features["lon"][0] == lon_out
            assert abs(features["lat"][0] - lat_out) < 1e-6
        else:
            rows = provider.online_read(
                config=store.config,
                table=feature_view,
                entity_keys=[driver_key],
            )
            assert len(rows) == 1
            _, values = rows[0]
            assert values["lon"].string_val == lon_out
            assert abs(values["lat"].double_val - lat_out) < 1e-6

    # 1. Basic test: write a value and read it back.
    time_1 = datetime.utcnow()
    _write_then_read(event_ts=time_1,
                     created_ts=time_1,
                     write=(1.1, "3.1"),
                     expect_read=(1.1, "3.1"))

    # 2. A newer event_ts must overwrite the older value.
    time_3 = datetime.utcnow()
    _write_then_read(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3,
        write=(1123, "NEWER"),
        expect_read=(1123, "NEWER"),
    )
Example #25
0
    def test_bigquery_ingestion_correctness(self):
        """
        End-to-end materialization correctness check against BigQuery.

        Loads a small dataframe (with a random sentinel value on the latest
        row for id 1) into a fresh BigQuery table, materializes the last
        five minutes through a local FeatureStore, and asserts the sentinel
        is what comes back from the online store.
        """
        # create dataset; the sentinel sits on the newest (ts, created_ts)
        # row for id=1, so materialization must pick it over the older row
        ts = pd.Timestamp.now(tz="UTC").round("ms")
        checked_value = (
            random.random()
        )  # random value so test doesn't still work if no values written to online store
        data = {
            "id": [1, 2, 1],
            "value": [0.1, 0.2, checked_value],
            "ts_1": [ts - timedelta(minutes=2), ts, ts],
            "created_ts": [ts, ts, ts],
        }
        df = pd.DataFrame.from_dict(data)

        # load dataset into BigQuery; unique table name per run avoids
        # collisions between test invocations
        job_config = bigquery.LoadJobConfig()
        table_id = (
            f"{self.gcp_project}.{self.bigquery_dataset}.correctness_{int(time.time())}"
        )
        job = self.client.load_table_from_dataframe(df,
                                                    table_id,
                                                    job_config=job_config)
        job.result()

        # create FeatureView; field_mapping renames the BQ columns
        # (ts_1 -> ts, id -> driver_id) to the view's schema
        fv = FeatureView(
            name="test_bq_correctness",
            entities=["driver_id"],
            features=[Feature("value", ValueType.FLOAT)],
            ttl=timedelta(minutes=5),
            input=BigQuerySource(
                event_timestamp_column="ts",
                table_ref=table_id,
                created_timestamp_column="created_ts",
                field_mapping={
                    "ts_1": "ts",
                    "id": "driver_id"
                },
                date_partition_column="",
            ),
        )
        config = RepoConfig(
            metadata_store="./metadata.db",
            project="default",
            provider="gcp",
            online_store=OnlineStoreConfig(
                local=LocalOnlineStoreConfig("online_store.db")),
        )
        fs = FeatureStore(config=config)
        fs.apply([fv])

        # run materialize() over the window containing all three rows
        fs.materialize(
            ["test_bq_correctness"],
            datetime.utcnow() - timedelta(minutes=5),
            datetime.utcnow() - timedelta(minutes=0),
        )

        # check result of materialize()
        # NOTE(review): positional (project, table, entity_key) call — this
        # matches an older provider online_read signature; confirm against
        # the provider API in use before reusing this snippet.
        entity_key = EntityKeyProto(entity_names=["driver_id"],
                                    entity_values=[ValueProto(int64_val=1)])
        t, val = fs._get_provider().online_read("default", fv, entity_key)
        assert abs(val["value"].double_val - checked_value) < 1e-6