Пример #1
0
def benchmark_writes():
    """Write two weeks of generated driver stats to the online store of the "gcp" provider.

    A feature store with a randomly named project is created inside a temporary
    directory, a driver entity and the driver-hourly-stats feature view are
    applied, and the generated rows are written in one ``online_write_batch``
    call while a tqdm progress bar tracks progress.

    The infrastructure is torn down in a ``finally`` block so that a failure
    while generating, converting or writing the data does not leave it behind.
    """
    project_id = "test" + "".join(
        random.choice(string.ascii_lowercase + string.digits) for _ in range(10)
    )

    with tempfile.TemporaryDirectory() as temp_dir:
        store = FeatureStore(
            config=RepoConfig(
                registry=os.path.join(temp_dir, "registry.db"),
                project=project_id,
                provider="gcp",
            )
        )

        # This is just to set data source to something, we're not reading from parquet source here.
        parquet_path = os.path.join(temp_dir, "data.parquet")

        driver = Entity(name="driver_id", value_type=ValueType.INT64)
        table = create_driver_hourly_stats_feature_view(
            create_driver_hourly_stats_source(parquet_path=parquet_path)
        )
        store.apply([table, driver])

        provider = store._get_provider()

        # Everything after apply() runs under try/finally: from this point on
        # there is infrastructure that must be torn down even if the write fails.
        try:
            end_date = datetime.utcnow()
            start_date = end_date - timedelta(days=14)
            driver_ids = list(range(100))
            data = create_driver_hourly_stats_df(driver_ids, start_date, end_date)

            # Show the data for reference
            print(data)
            proto_data = _convert_arrow_to_proto(
                pa.Table.from_pandas(data), table, ["driver_id"]
            )

            # Write it
            with tqdm(total=len(proto_data)) as progress:
                provider.online_write_batch(
                    project=store.project,
                    table=table,
                    data=proto_data,
                    progress=progress.update,
                )
        finally:
            # Tear down whatever the registry currently lists for this project.
            registry_tables = store.list_feature_views()
            registry_entities = store.list_entities()
            provider.teardown_infra(
                store.project, tables=registry_tables, entities=registry_entities
            )
Пример #2
0
def basic_rw_test(
    store: FeatureStore,
    view_name: str,
    feature_service_name: Optional[str] = None,
) -> None:
    """
    Provider-independent online store read/write checks, intended to be called by
    provider-specific tests.

    Every step writes a single ("lat", "lon") row for the entity key driver=1
    through the store's provider and immediately reads it back. When
    ``feature_service_name`` is given the read goes through
    ``store.get_online_features``; otherwise it goes through
    ``provider.online_read``.
    """
    table = store.get_feature_view(name=view_name)
    provider = store._get_provider()
    entity_key = EntityKeyProto(
        join_keys=["driver"], entity_values=[ValueProto(int64_val=1)]
    )

    def _write_then_check(event_ts, created_ts, write, expect_read):
        """Write one row with the given timestamps, then assert on what is read back."""
        lat_in, lon_in = write
        lat_out, lon_out = expect_read

        row = (
            entity_key,
            {
                "lat": ValueProto(double_val=lat_in),
                "lon": ValueProto(string_val=lon_in),
            },
            event_ts,
            created_ts,
        )
        provider.online_write_batch(
            config=store.config, table=table, data=[row], progress=None
        )

        if not feature_service_name:
            # Direct read through the provider.
            rows = provider.online_read(
                config=store.config, table=table, entity_keys=[entity_key]
            )
            assert len(rows) == 1
            _, values = rows[0]
            assert values["lon"].string_val == lon_out
            assert abs(values["lat"].double_val - lat_out) < 1e-6
            return

        # Read through the feature service instead.
        service = store.get_feature_service(feature_service_name)
        online = store.get_online_features(
            features=service, entity_rows=[{"driver": 1}]
        ).to_dict()
        assert len(online["driver"]) == 1
        assert online["lon"][0] == lon_out
        assert abs(online["lat"][0] - lat_out) < 1e-6

    one_hour = timedelta(hours=1)

    # 1. Basic test: write a value, read it back.
    first_ts = datetime.utcnow()
    payload = (1.1, "3.1")
    _write_then_check(first_ts, first_ts, payload, payload)

    # Note: This behavior has changed for performance. We should test that an older
    # value can't overwrite a newer value once we add the respective flag.
    # For now a write with an older event_ts is expected to overwrite the newer one.
    second_ts = datetime.utcnow()
    payload = (-1000, "OLD")
    _write_then_check(first_ts - one_hour, second_ts, payload, payload)

    # A write with a newer event_ts should overwrite the older one.
    third_ts = datetime.utcnow()
    newer_event_ts = first_ts + one_hour
    payload = (1123, "NEWER")
    _write_then_check(newer_event_ts, third_ts, payload, payload)

    # Note: This behavior has changed for performance. We should test that an older
    # value can't overwrite a newer value once we add the respective flag.
    # Same event_ts with an older created_ts: the write is still expected to overwrite.
    payload = (54321, "I HAVE AN OLDER created_ts SO I LOSE")
    _write_then_check(newer_event_ts, third_ts - one_hour, payload, payload)

    # Same event_ts with a newer created_ts: the write should overwrite.
    payload = (96864, "I HAVE A NEWER created_ts SO I WIN")
    _write_then_check(newer_event_ts, third_ts + one_hour, payload, payload)
Пример #3
0
def basic_rw_test(
    store: FeatureStore,
    view_name: str,
    feature_service_name: Optional[str] = None,
) -> None:
    """
    Provider-independent online store read/write checks, intended to be called by
    provider-specific tests.

    Each step writes a single ("lat", "lon") row for the entity key driver_id=1
    through the store's provider and immediately reads it back. When
    ``feature_service_name`` is given the read goes through
    ``store.get_online_features``; otherwise it goes through
    ``provider.online_read``.
    """
    table = store.get_feature_view(name=view_name)
    provider = store._get_provider()
    entity_key = EntityKeyProto(
        join_keys=["driver_id"], entity_values=[ValueProto(int64_val=1)]
    )

    def _write_then_check(event_ts, created_ts, write, expect_read):
        """Write one row with the given timestamps, then assert on what is read back."""
        lat_in, lon_in = write
        lat_out, lon_out = expect_read

        row = (
            entity_key,
            {
                "lat": ValueProto(double_val=lat_in),
                "lon": ValueProto(string_val=lon_in),
            },
            event_ts,
            created_ts,
        )
        provider.online_write_batch(
            config=store.config, table=table, data=[row], progress=None
        )

        if not feature_service_name:
            # Direct read through the provider.
            rows = provider.online_read(
                config=store.config, table=table, entity_keys=[entity_key]
            )
            assert len(rows) == 1
            _, values = rows[0]
            assert values["lon"].string_val == lon_out
            assert abs(values["lat"].double_val - lat_out) < 1e-6
            return

        # Read through the feature service instead.
        service = store.get_feature_service(feature_service_name)
        online = store.get_online_features(
            features=service, entity_rows=[{"driver_id": 1}]
        ).to_dict()
        assert len(online["driver_id"]) == 1
        assert online["lon"][0] == lon_out
        assert abs(online["lat"][0] - lat_out) < 1e-6

    # 1. Basic test: write a value, read it back.
    first_ts = datetime.utcnow()
    payload = (1.1, "3.1")
    _write_then_check(first_ts, first_ts, payload, payload)

    # A write with a newer event_ts should overwrite the older one.
    later_ts = datetime.utcnow()
    payload = (1123, "NEWER")
    _write_then_check(first_ts + timedelta(hours=1), later_ts, payload, payload)
Пример #4
0
def basic_rw_test(store: FeatureStore, view_name: str) -> None:
    """
    Provider-independent online store read/write checks, intended to be called by
    provider-specific tests.

    Each step writes a single ("lat", "lon") row for the entity key driver=1
    through ``provider.online_write_batch`` and then asserts on the row returned
    by ``provider.online_read``. The expected row differs from the written one
    in the steps where the write is supposed to be rejected.
    """
    table = store.get_feature_view(name=view_name)
    provider = store._get_provider()
    entity_key = EntityKeyProto(
        entity_names=["driver"], entity_values=[ValueProto(int64_val=1)]
    )

    def _write_then_check(event_ts, created_ts, write, expect_read):
        """Write one row with the given timestamps, then assert on what is read back."""
        lat_in, lon_in = write
        lat_out, lon_out = expect_read

        row = (
            entity_key,
            {
                "lat": ValueProto(double_val=lat_in),
                "lon": ValueProto(string_val=lon_in),
            },
            event_ts,
            created_ts,
        )
        provider.online_write_batch(
            project=store.project, table=table, data=[row], progress=None
        )

        rows = provider.online_read(
            project=store.project, table=table, entity_keys=[entity_key]
        )
        assert len(rows) == 1
        _, values = rows[0]
        assert values["lon"].string_val == lon_out
        assert abs(values["lat"].double_val - lat_out) < 1e-6

    one_hour = timedelta(hours=1)
    initial = (1.1, "3.1")
    newer = (1123, "NEWER")
    winner = (96864, "I HAVE A NEWER created_ts SO I WIN")

    # 1. Basic test: write a value, read it back.
    first_ts = datetime.utcnow()
    _write_then_check(first_ts, first_ts, initial, initial)

    # A write with an older event_ts should not overwrite the newer value.
    second_ts = datetime.utcnow()
    _write_then_check(first_ts - one_hour, second_ts, (-1000, "OLD"), initial)

    # A write with a newer event_ts should overwrite the older value.
    third_ts = datetime.utcnow()
    newer_event_ts = first_ts + one_hour
    _write_then_check(newer_event_ts, third_ts, newer, newer)

    # created_ts is the tie breaker: same event_ts with an older created_ts, so no overwrite.
    _write_then_check(
        newer_event_ts,
        third_ts - one_hour,
        (54321, "I HAVE AN OLDER created_ts SO I LOSE"),
        newer,
    )

    # created_ts is the tie breaker: same event_ts with a newer created_ts, so it overwrites.
    _write_then_check(newer_event_ts, third_ts + one_hour, winner, winner)
Пример #5
0
    def test_bigquery_ingestion_correctness(self):
        """Load a small table into BigQuery, materialize it, and check the online value.

        The table holds two rows for id 1 (value 0.1 at ``ts - 2 minutes`` and a
        random value at ``ts``) and one row for id 2. After ``materialize()`` the
        online read for driver_id 1 must return the random value.
        """
        # --- build the dataset ---
        ts = pd.Timestamp.now(tz="UTC").round("ms")
        # Random, so the final assertion cannot pass if nothing was written
        # to the online store.
        checked_value = random.random()
        df = pd.DataFrame.from_dict(
            {
                "id": [1, 2, 1],
                "value": [0.1, 0.2, checked_value],
                "ts_1": [ts - timedelta(minutes=2), ts, ts],
                "created_ts": [ts, ts, ts],
            }
        )

        # --- load it into a BigQuery table named after the current epoch second ---
        table_id = (
            f"{self.gcp_project}.{self.bigquery_dataset}.correctness_{int(time.time())}"
        )
        load_job = self.client.load_table_from_dataframe(
            df, table_id, job_config=bigquery.LoadJobConfig()
        )
        load_job.result()

        # --- define the feature view on top of that table ---
        # field_mapping renames the source columns "ts_1" and "id" to "ts" and "driver_id".
        source = BigQuerySource(
            event_timestamp_column="ts",
            table_ref=table_id,
            created_timestamp_column="created_ts",
            field_mapping={"ts_1": "ts", "id": "driver_id"},
            date_partition_column="",
        )
        fv = FeatureView(
            name="test_bq_correctness",
            entities=["driver_id"],
            features=[Feature("value", ValueType.FLOAT)],
            ttl=timedelta(minutes=5),
            input=source,
        )

        # --- create the store and register the view ---
        local_online_store = OnlineStoreConfig(
            local=LocalOnlineStoreConfig("online_store.db")
        )
        fs = FeatureStore(
            config=RepoConfig(
                metadata_store="./metadata.db",
                project="default",
                provider="gcp",
                online_store=local_online_store,
            )
        )
        fs.apply([fv])

        # --- materialize the last five minutes ---
        window_start = datetime.utcnow() - timedelta(minutes=5)
        window_end = datetime.utcnow() - timedelta(minutes=0)
        fs.materialize(["test_bq_correctness"], window_start, window_end)

        # --- read driver_id 1 back and compare with the random value ---
        entity_key = EntityKeyProto(
            entity_names=["driver_id"], entity_values=[ValueProto(int64_val=1)]
        )
        _, val = fs._get_provider().online_read("default", fv, entity_key)
        assert abs(val["value"].double_val - checked_value) < 1e-6