Example #1
def create_bq_view_of_joined_features_and_entities(
        source: BigQuerySource, entity_source: BigQuerySource,
        entity_names: List[str]) -> BigQuerySource:
    """
    Creates BQ view that joins tables from `source` and `entity_source` with join key derived from `entity_names`.
    Returns BigQuerySource with reference to created view.
    """
    bq_client = bigquery.Client()

    source_ref = table_reference_from_string(source.bigquery_options.table_ref)
    entities_ref = table_reference_from_string(
        entity_source.bigquery_options.table_ref)

    destination_ref = bigquery.TableReference(
        bigquery.DatasetReference(source_ref.project, source_ref.dataset_id),
        f"_view_{source_ref.table_id}_{datetime.now():%Y%m%d%H%M%s}",
    )

    view = bigquery.Table(destination_ref)
    view.view_query = JOIN_TEMPLATE.format(
        entities=entities_ref,
        source=source_ref,
        entity_key=" AND ".join(
            [f"source.{e} = entities.{e}" for e in entity_names]),
    )
    view.expires = datetime.now() + timedelta(days=1)
    bq_client.create_table(view)

    return BigQuerySource(
        event_timestamp_column=source.event_timestamp_column,
        created_timestamp_column=source.created_timestamp_column,
        table_ref=f"{view.project}:{view.dataset_id}.{view.table_id}",
        field_mapping=source.field_mapping,
        date_partition_column=source.date_partition_column,
    )
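A minimal usage sketch for the helper above, assuming two pre-existing BigQuerySource objects; the project, dataset, and table names are placeholders:

from feast import BigQuerySource

# Hypothetical sources; the table refs below are placeholders.
features = BigQuerySource(
    table_ref="my-project:my_dataset.driver_features",
    event_timestamp_column="event_timestamp",
)
entities = BigQuerySource(
    table_ref="my-project:my_dataset.driver_entities",
    event_timestamp_column="event_timestamp",
)

# Creates a temporary view joining the two tables on driver_id and
# returns a BigQuerySource backed by that view.
joined = create_bq_view_of_joined_features_and_entities(
    features, entities, entity_names=["driver_id"])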
Example #2
def stage_entities_to_bq(entity_source: pd.DataFrame, project: str,
                         dataset: str) -> BigQuerySource:
    """
    Stores given (entity) dataframe as new table in BQ. Name of the table generated based on current time.
    Table will expire in 1 day.
    Returns BigQuerySource with reference to created table.
    """
    bq_client = bigquery.Client()
    destination = bigquery.TableReference(
        bigquery.DatasetReference(project, dataset),
        f"_entities_{datetime.now():%Y%m%d%H%M%s}",
    )

    # prevent a ns -> ms casting exception inside pyarrow
    entity_source["event_timestamp"] = entity_source[
        "event_timestamp"].dt.floor("ms")

    load_job: bigquery.LoadJob = bq_client.load_table_from_dataframe(
        entity_source, destination)
    load_job.result()  # wait until complete

    dest_table: bigquery.Table = bq_client.get_table(destination)
    dest_table.expires = datetime.now() + timedelta(days=1)
    bq_client.update_table(dest_table, fields=["expires"])

    return BigQuerySource(
        event_timestamp_column="event_timestamp",
        table_ref=
        f"{destination.project}:{destination.dataset_id}.{destination.table_id}",
    )
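A short sketch of staging an entity dataframe with this helper; the project and dataset names are placeholders, and the dataframe must carry an event_timestamp column of datetime dtype (the function floors it to milliseconds):

import pandas as pd

entity_df = pd.DataFrame({
    "driver_id": [1001, 1002],
    "event_timestamp": pd.to_datetime(
        ["2021-04-12 10:00:00", "2021-04-12 10:05:00"]),
})

# Uploads the dataframe to a temporary table (expires after one day)
# and returns a BigQuerySource pointing at it.
source = stage_entities_to_bq(entity_df, project="my-project", dataset="my_dataset")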
Example #3
def test_ingest_into_bq(
    feast_client: Client,
    customer_entity: Entity,
    driver_entity: Entity,
    bq_dataframe: pd.DataFrame,
    bq_dataset: str,
    pytestconfig,
):
    bq_project = pytestconfig.getoption("bq_project")
    bq_table_id = f"bq_staging_{datetime.now():%Y%m%d%H%M%S}"
    ft = FeatureTable(
        name="basic_featuretable",
        entities=["driver_id", "customer_id"],
        features=[
            Feature(name="dev_feature_float", dtype=ValueType.FLOAT),
            Feature(name="dev_feature_string", dtype=ValueType.STRING),
        ],
        max_age=Duration(seconds=3600),
        batch_source=BigQuerySource(
            table_ref=f"{bq_project}:{bq_dataset}.{bq_table_id}",
            event_timestamp_column="datetime",
            created_timestamp_column="timestamp",
        ),
    )

    # ApplyEntity
    feast_client.apply(customer_entity)
    feast_client.apply(driver_entity)

    # ApplyFeatureTable
    feast_client.apply(ft)
    feast_client.ingest(ft, bq_dataframe, timeout=120)

    bq_client = bigquery.Client(project=bq_project)

    # Poll BQ for table until the table has been created
    def try_get_table():
        try:
            table = bq_client.get_table(
                bigquery.TableReference(
                    bigquery.DatasetReference(bq_project, bq_dataset), bq_table_id
                )
            )
        except NotFound:
            return None, False
        else:
            return table, True

    wait_retry_backoff(
        retry_fn=try_get_table,
        timeout_secs=30,
        timeout_msg="Timed out trying to get bigquery table",
    )

    query_string = f"SELECT * FROM `{bq_project}.{bq_dataset}.{bq_table_id}`"

    job = bq_client.query(query_string)
    query_df = job.to_dataframe()

    assert_frame_equal(query_df, bq_dataframe)
Example #4
def simple_bq_source_using_query_arg(df, event_timestamp_column="") -> BigQuerySource:
    bq_source_using_table_ref = simple_bq_source_using_table_ref_arg(
        df, event_timestamp_column
    )
    return BigQuerySource(
        query=f"SELECT * FROM {bq_source_using_table_ref.table_ref}",
        event_timestamp_column=event_timestamp_column,
    )
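For reference, a hedged sketch of exercising simple_bq_source_using_query_arg (it delegates staging to simple_bq_source_using_table_ref_arg, shown in Example #9); the dataframe columns are placeholders:

import pandas as pd

df = pd.DataFrame({
    "driver_id": [1, 2],
    "value": [0.5, 0.7],
    "event_timestamp": [pd.Timestamp.now(tz="UTC")] * 2,
})

# Stages df into BQ, then wraps the staged table in a query-based source.
source = simple_bq_source_using_query_arg(df, event_timestamp_column="event_timestamp")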
Example #5
    def test_bigquery_query_to_datastore_correctness(self):
        # create dataset
        ts = pd.Timestamp.now(tz="UTC").round("ms")
        data = {
            "id": [1, 2, 1],
            "value": [0.1, 0.2, 0.3],
            "ts_1": [ts - timedelta(minutes=2), ts, ts],
            "created_ts": [ts, ts, ts],
        }
        df = pd.DataFrame.from_dict(data)

        # load dataset into BigQuery
        job_config = bigquery.LoadJobConfig()
        table_id = f"{self.gcp_project}.{self.bigquery_dataset}.query_correctness_{int(time.time())}"
        query = f"SELECT * FROM `{table_id}`"
        job = self.client.load_table_from_dataframe(df,
                                                    table_id,
                                                    job_config=job_config)
        job.result()

        # create FeatureView
        fv = FeatureView(
            name="test_bq_query_correctness",
            entities=["driver_id"],
            features=[Feature("value", ValueType.FLOAT)],
            ttl=timedelta(minutes=5),
            input=BigQuerySource(
                event_timestamp_column="ts",
                created_timestamp_column="created_ts",
                field_mapping={
                    "ts_1": "ts",
                    "id": "driver_id"
                },
                date_partition_column="",
                query=query,
            ),
        )
        config = RepoConfig(
            metadata_store="./metadata.db",
            project=f"test_bq_query_correctness_{int(time.time())}",
            provider="gcp",
        )
        fs = FeatureStore(config=config)
        fs.apply([fv])

        # run materialize()
        fs.materialize(
            [fv.name],
            datetime.utcnow() - timedelta(minutes=5),
            datetime.utcnow() - timedelta(minutes=0),
        )

        # check result of materialize()
        response_dict = fs.get_online_features([f"{fv.name}:value"],
                                               [{
                                                   "driver_id": 1
                                               }]).to_dict()
        assert abs(response_dict[f"{fv.name}:value"][0] - 0.3) < 1e-6
Example #6
def prep_bq_fs_and_fv(
        bq_source_type: str) -> Iterator[Tuple[FeatureStore, FeatureView]]:
    client = bigquery.Client()
    gcp_project = client.project
    bigquery_dataset = "test_ingestion"
    dataset = bigquery.Dataset(f"{gcp_project}.{bigquery_dataset}")
    client.create_dataset(dataset, exists_ok=True)
    dataset.default_table_expiration_ms = (1000 * 60 * 60 * 24 * 14
                                           )  # 2 weeks in milliseconds
    client.update_dataset(dataset, ["default_table_expiration_ms"])

    df = create_dataset()

    job_config = bigquery.LoadJobConfig()
    table_ref = f"{gcp_project}.{bigquery_dataset}.{bq_source_type}_correctness_{int(time.time())}"
    query = f"SELECT * FROM `{table_ref}`"
    job = client.load_table_from_dataframe(df,
                                           table_ref,
                                           job_config=job_config)
    job.result()

    bigquery_source = BigQuerySource(
        table_ref=table_ref if bq_source_type == "table" else None,
        query=query if bq_source_type == "query" else None,
        event_timestamp_column="ts",
        created_timestamp_column="created_ts",
        date_partition_column="",
        field_mapping={
            "ts_1": "ts",
            "id": "driver_id"
        },
    )

    fv = get_feature_view(bigquery_source)
    e = Entity(
        name="driver",
        description="id for driver",
        join_key="driver_id",
        value_type=ValueType.INT32,
    )
    with tempfile.TemporaryDirectory() as repo_dir_name:
        config = RepoConfig(
            registry=str(Path(repo_dir_name) / "registry.db"),
            project=f"test_bq_correctness_{str(uuid.uuid4()).replace('-', '')}",
            provider="gcp",
            online_store=DatastoreOnlineStoreConfig(
                namespace="integration_test"),
        )
        fs = FeatureStore(config=config)
        fs.apply([fv, e])

        yield fs, fv
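Because prep_bq_fs_and_fv is a generator, one way to consume it is to wrap it with contextlib.contextmanager so the temporary repo directory is cleaned up on exit; this wrapper is an assumption, not part of the source:

from contextlib import contextmanager
from datetime import datetime, timedelta

with contextmanager(prep_bq_fs_and_fv)("table") as (fs, fv):
    # Materialize the last day of features into the online store
    # (same positional call shape as used elsewhere in these examples).
    fs.materialize(
        [fv.name],
        datetime.utcnow() - timedelta(days=1),
        datetime.utcnow(),
    )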
Example #7
def bq_featuretable(bq_table_id):
    batch_source = BigQuerySource(
        table_ref=bq_table_id,
        timestamp_column="datetime",
    )
    return FeatureTable(
        name="basic_featuretable",
        entities=["driver_id", "customer_id"],
        features=[
            Feature(name="dev_feature_float", dtype=ValueType.FLOAT),
            Feature(name="dev_feature_string", dtype=ValueType.STRING),
        ],
        max_age=Duration(seconds=3600),
        batch_source=batch_source,
    )
Example #8
def create_bq_view_of_joined_features_and_entities(
    source: BigQuerySource, entity_source: BigQuerySource, entity_names: List[str]
) -> BigQuerySource:
    """
    Creates BQ view that joins tables from `source` and `entity_source` with join key derived from `entity_names`.
Returns BigQuerySource with reference to created view. The BQ view will be created in the same BQ dataset as `entity_source`.
    """
    from google.cloud import bigquery

    bq_client = bigquery.Client()

    source_ref = table_reference_from_string(source.bigquery_options.table_ref)
    entities_ref = table_reference_from_string(entity_source.bigquery_options.table_ref)

    destination_ref = bigquery.TableReference(
        bigquery.DatasetReference(entities_ref.project, entities_ref.dataset_id),
        f"_view_{source_ref.table_id}_{datetime.now():%Y%m%d%H%M%s}",
    )

    view = bigquery.Table(destination_ref)

    join_template = """
    SELECT source.* FROM
    `{entities.project}.{entities.dataset_id}.{entities.table_id}` entities
    JOIN
    `{source.project}.{source.dataset_id}.{source.table_id}` source
    ON
    ({entity_key})"""

    view.view_query = join_template.format(
        entities=entities_ref,
        source=source_ref,
        entity_key=" AND ".join([f"source.{e} = entities.{e}" for e in entity_names]),
    )
    view.expires = datetime.now() + timedelta(days=1)
    bq_client.create_table(view)

    return BigQuerySource(
        event_timestamp_column=source.event_timestamp_column,
        created_timestamp_column=source.created_timestamp_column,
        table_ref=f"{view.project}:{view.dataset_id}.{view.table_id}",
        field_mapping=source.field_mapping,
        date_partition_column=source.date_partition_column,
    )
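For concreteness, a small sketch of what join_template renders to, assuming the template is lifted to module scope; SimpleNamespace stands in for bigquery.TableReference, and the table identifiers are hypothetical:

from types import SimpleNamespace

src = SimpleNamespace(project="proj", dataset_id="ds", table_id="features")
ent = SimpleNamespace(project="proj", dataset_id="ds", table_id="entities")

# Prints (modulo the template's leading indentation):
#   SELECT source.* FROM
#   `proj.ds.entities` entities
#   JOIN
#   `proj.ds.features` source
#   ON
#   (source.driver_id = entities.driver_id)
print(join_template.format(
    entities=ent,
    source=src,
    entity_key="source.driver_id = entities.driver_id",
))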
Example #9
def simple_bq_source_using_table_ref_arg(df,
                                         event_timestamp_column=None
                                         ) -> BigQuerySource:
    client = bigquery.Client()
    gcp_project = client.project
    bigquery_dataset = "ds"
    dataset = bigquery.Dataset(f"{gcp_project}.{bigquery_dataset}")
    client.create_dataset(dataset, exists_ok=True)
    dataset.default_table_expiration_ms = (
        1000 * 60 *
        60  # 60 minutes in milliseconds (seems to be minimum limit for gcloud)
    )
    client.update_dataset(dataset, ["default_table_expiration_ms"])
    table_ref = f"{gcp_project}.{bigquery_dataset}.table_1"

    job = client.load_table_from_dataframe(df,
                                           table_ref,
                                           job_config=bigquery.LoadJobConfig())
    job.result()

    return BigQuerySource(
        table_ref=table_ref,
        event_timestamp_column=event_timestamp_column,
    )
Example #10
def test_historical_features_from_bigquery_sources(
    provider_type, infer_event_timestamp_col
):
    start_date = datetime.now().replace(microsecond=0, second=0, minute=0)
    (
        customer_entities,
        driver_entities,
        end_date,
        orders_df,
        start_date,
    ) = generate_entities(start_date, infer_event_timestamp_col)

    # bigquery_dataset = "test_hist_retrieval_static"
    bigquery_dataset = (
        f"test_hist_retrieval_{int(time.time_ns())}_{random.randint(1000, 9999)}"
    )

    with BigQueryDataSet(bigquery_dataset), TemporaryDirectory() as temp_dir:
        gcp_project = bigquery.Client().project

        # Orders Query
        table_id = f"{bigquery_dataset}.orders"
        stage_orders_bigquery(orders_df, table_id)
        entity_df_query = f"SELECT * FROM {gcp_project}.{table_id}"

        # Driver Feature View
        driver_df = driver_data.create_driver_hourly_stats_df(
            driver_entities, start_date, end_date
        )
        driver_table_id = f"{gcp_project}.{bigquery_dataset}.driver_hourly"
        stage_driver_hourly_stats_bigquery_source(driver_df, driver_table_id)
        driver_source = BigQuerySource(
            table_ref=driver_table_id,
            event_timestamp_column="datetime",
            created_timestamp_column="created",
        )
        driver_fv = create_driver_hourly_stats_feature_view(driver_source)

        # Customer Feature View
        customer_df = driver_data.create_customer_daily_profile_df(
            customer_entities, start_date, end_date
        )
        customer_table_id = f"{gcp_project}.{bigquery_dataset}.customer_profile"

        stage_customer_daily_profile_bigquery_source(customer_df, customer_table_id)
        customer_source = BigQuerySource(
            table_ref=customer_table_id,
            event_timestamp_column="datetime",
            created_timestamp_column="",
        )
        customer_fv = create_customer_daily_profile_feature_view(customer_source)

        driver = Entity(name="driver", join_key="driver_id", value_type=ValueType.INT64)
        customer = Entity(name="customer_id", value_type=ValueType.INT64)

        if provider_type == "local":
            store = FeatureStore(
                config=RepoConfig(
                    registry=os.path.join(temp_dir, "registry.db"),
                    project="default",
                    provider="local",
                    online_store=SqliteOnlineStoreConfig(
                        path=os.path.join(temp_dir, "online_store.db"),
                    ),
                    offline_store=BigQueryOfflineStoreConfig(type="bigquery",),
                )
            )
        elif provider_type == "gcp":
            store = FeatureStore(
                config=RepoConfig(
                    registry=os.path.join(temp_dir, "registry.db"),
                    project="".join(
                        random.choices(string.ascii_uppercase + string.digits, k=10)
                    ),
                    provider="gcp",
                    offline_store=BigQueryOfflineStoreConfig(type="bigquery",),
                )
            )
        elif provider_type == "gcp_custom_offline_config":
            store = FeatureStore(
                config=RepoConfig(
                    registry=os.path.join(temp_dir, "registry.db"),
                    project="".join(
                        random.choices(string.ascii_uppercase + string.digits, k=10)
                    ),
                    provider="gcp",
                    offline_store=BigQueryOfflineStoreConfig(
                        type="bigquery", dataset="foo"
                    ),
                )
            )
        else:
            raise Exception("Invalid provider used as part of test configuration")

        store.apply([driver, customer, driver_fv, customer_fv])

        event_timestamp = (
            DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
            if DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL in orders_df.columns
            else "e_ts"
        )
        expected_df = get_expected_training_df(
            customer_df, customer_fv, driver_df, driver_fv, orders_df, event_timestamp,
        )

        job_from_sql = store.get_historical_features(
            entity_df=entity_df_query,
            feature_refs=[
                "driver_stats:conv_rate",
                "driver_stats:avg_daily_trips",
                "customer_profile:current_balance",
                "customer_profile:avg_passenger_count",
                "customer_profile:lifetime_trip_count",
            ],
        )

        actual_df_from_sql_entities = job_from_sql.to_df()

        assert_frame_equal(
            expected_df.sort_values(
                by=[event_timestamp, "order_id", "driver_id", "customer_id"]
            ).reset_index(drop=True),
            actual_df_from_sql_entities.sort_values(
                by=[event_timestamp, "order_id", "driver_id", "customer_id"]
            ).reset_index(drop=True),
            check_dtype=False,
        )

        job_from_df = store.get_historical_features(
            entity_df=orders_df,
            feature_refs=[
                "driver_stats:conv_rate",
                "driver_stats:avg_daily_trips",
                "customer_profile:current_balance",
                "customer_profile:avg_passenger_count",
                "customer_profile:lifetime_trip_count",
            ],
        )

        if provider_type == "gcp_custom_offline_config":
            # Make sure that custom dataset name is being used from the offline_store config
            assertpy.assert_that(job_from_df.query).contains("foo.entity_df")
        else:
            # If the custom dataset name isn't provided in the config, use default `feast` name
            assertpy.assert_that(job_from_df.query).contains("feast.entity_df")

        actual_df_from_df_entities = job_from_df.to_df()

        assert_frame_equal(
            expected_df.sort_values(
                by=[event_timestamp, "order_id", "driver_id", "customer_id"]
            ).reset_index(drop=True),
            actual_df_from_df_entities.sort_values(
                by=[event_timestamp, "order_id", "driver_id", "customer_id"]
            ).reset_index(drop=True),
            check_dtype=False,
        )
Example #11
def test_historical_features_from_bigquery_sources(provider_type,
                                                   infer_event_timestamp_col,
                                                   capsys):
    start_date = datetime.now().replace(microsecond=0, second=0, minute=0)
    (
        customer_entities,
        driver_entities,
        end_date,
        orders_df,
        start_date,
    ) = generate_entities(start_date, infer_event_timestamp_col)

    bigquery_dataset = (
        f"test_hist_retrieval_{int(time.time_ns())}_{random.randint(1000, 9999)}"
    )

    with BigQueryDataSet(bigquery_dataset), TemporaryDirectory() as temp_dir:
        gcp_project = bigquery.Client().project

        # Orders Query
        table_id = f"{bigquery_dataset}.orders"
        stage_orders_bigquery(orders_df, table_id)
        entity_df_query = f"SELECT * FROM {gcp_project}.{table_id}"

        # Driver Feature View
        driver_df = driver_data.create_driver_hourly_stats_df(
            driver_entities, start_date, end_date)
        driver_table_id = f"{gcp_project}.{bigquery_dataset}.driver_hourly"
        stage_driver_hourly_stats_bigquery_source(driver_df, driver_table_id)
        driver_source = BigQuerySource(
            table_ref=driver_table_id,
            event_timestamp_column="datetime",
            created_timestamp_column="created",
        )
        driver_fv = create_driver_hourly_stats_feature_view(driver_source)

        # Customer Feature View
        customer_df = driver_data.create_customer_daily_profile_df(
            customer_entities, start_date, end_date)
        customer_table_id = f"{gcp_project}.{bigquery_dataset}.customer_profile"

        stage_customer_daily_profile_bigquery_source(customer_df,
                                                     customer_table_id)
        customer_source = BigQuerySource(
            table_ref=customer_table_id,
            event_timestamp_column="datetime",
            created_timestamp_column="",
        )
        customer_fv = create_customer_daily_profile_feature_view(
            customer_source)

        driver = Entity(name="driver",
                        join_key="driver_id",
                        value_type=ValueType.INT64)
        customer = Entity(name="customer_id", value_type=ValueType.INT64)

        if provider_type == "local":
            store = FeatureStore(config=RepoConfig(
                registry=os.path.join(temp_dir, "registry.db"),
                project="default",
                provider="local",
                online_store=SqliteOnlineStoreConfig(path=os.path.join(
                    temp_dir, "online_store.db"), ),
                offline_store=BigQueryOfflineStoreConfig(
                    type="bigquery", dataset=bigquery_dataset),
            ))
        elif provider_type == "gcp":
            store = FeatureStore(config=RepoConfig(
                registry=os.path.join(temp_dir, "registry.db"),
                project="".join(
                    random.choices(string.ascii_uppercase + string.digits,
                                   k=10)),
                provider="gcp",
                offline_store=BigQueryOfflineStoreConfig(
                    type="bigquery", dataset=bigquery_dataset),
            ))
        elif provider_type == "gcp_custom_offline_config":
            store = FeatureStore(config=RepoConfig(
                registry=os.path.join(temp_dir, "registry.db"),
                project="".join(
                    random.choices(string.ascii_uppercase + string.digits,
                                   k=10)),
                provider="gcp",
                offline_store=BigQueryOfflineStoreConfig(type="bigquery",
                                                         dataset="foo"),
            ))
        else:
            raise Exception(
                "Invalid provider used as part of test configuration")

        store.apply([driver, customer, driver_fv, customer_fv])

        event_timestamp = (DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
                           if DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
                           in orders_df.columns else "e_ts")
        expected_df = get_expected_training_df(
            customer_df,
            customer_fv,
            driver_df,
            driver_fv,
            orders_df,
            event_timestamp,
        )

        job_from_sql = store.get_historical_features(
            entity_df=entity_df_query,
            feature_refs=[
                "driver_stats:conv_rate",
                "driver_stats:avg_daily_trips",
                "customer_profile:current_balance",
                "customer_profile:avg_passenger_count",
                "customer_profile:lifetime_trip_count",
            ],
        )

        start_time = datetime.utcnow()
        actual_df_from_sql_entities = job_from_sql.to_df()
        end_time = datetime.utcnow()
        with capsys.disabled():
            print(
                f"\nTime to execute job_from_sql.to_df() = '{end_time - start_time}'"
            )

        assert sorted(expected_df.columns) == sorted(
            actual_df_from_sql_entities.columns)
        assert_frame_equal(
            expected_df.sort_values(
                by=[event_timestamp, "order_id", "driver_id", "customer_id"
                    ]).reset_index(drop=True),
            actual_df_from_sql_entities[expected_df.columns].sort_values(
                by=[event_timestamp, "order_id", "driver_id", "customer_id"
                    ]).reset_index(drop=True),
            check_dtype=False,
        )

        table_from_sql_entities = job_from_sql.to_arrow()
        assert_frame_equal(actual_df_from_sql_entities,
                           table_from_sql_entities.to_pandas())

        timestamp_column = ("e_ts" if infer_event_timestamp_col else
                            DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL)

        entity_df_query_with_invalid_join_key = (
            f"select order_id, driver_id, customer_id as customer, "
            f"order_is_success, {timestamp_column} FROM {gcp_project}.{table_id}"
        )
        # Rename the join key; this should now raise an error.
        assertpy.assert_that(store.get_historical_features).raises(
            errors.FeastEntityDFMissingColumnsError).when_called_with(
                entity_df=entity_df_query_with_invalid_join_key,
                feature_refs=[
                    "driver_stats:conv_rate",
                    "driver_stats:avg_daily_trips",
                    "customer_profile:current_balance",
                    "customer_profile:avg_passenger_count",
                    "customer_profile:lifetime_trip_count",
                ],
            )

        job_from_df = store.get_historical_features(
            entity_df=orders_df,
            feature_refs=[
                "driver_stats:conv_rate",
                "driver_stats:avg_daily_trips",
                "customer_profile:current_balance",
                "customer_profile:avg_passenger_count",
                "customer_profile:lifetime_trip_count",
            ],
        )

        # Rename the join key; this should now raise an error.
        orders_df_with_invalid_join_key = orders_df.rename(
            {"customer_id": "customer"}, axis="columns")
        assertpy.assert_that(store.get_historical_features).raises(
            errors.FeastEntityDFMissingColumnsError).when_called_with(
                entity_df=orders_df_with_invalid_join_key,
                feature_refs=[
                    "driver_stats:conv_rate",
                    "driver_stats:avg_daily_trips",
                    "customer_profile:current_balance",
                    "customer_profile:avg_passenger_count",
                    "customer_profile:lifetime_trip_count",
                ],
            )

        # Make sure that custom dataset name is being used from the offline_store config
        if provider_type == "gcp_custom_offline_config":
            assertpy.assert_that(job_from_df.query).contains("foo.entity_df")
        else:
            assertpy.assert_that(
                job_from_df.query).contains(f"{bigquery_dataset}.entity_df")

        start_time = datetime.utcnow()
        actual_df_from_df_entities = job_from_df.to_df()
        end_time = datetime.utcnow()
        with capsys.disabled():
            print(
                f"Time to execute job_from_df.to_df() = '{end_time - start_time}'\n"
            )

        assert sorted(expected_df.columns) == sorted(
            actual_df_from_df_entities.columns)
        assert_frame_equal(
            expected_df.sort_values(
                by=[event_timestamp, "order_id", "driver_id", "customer_id"
                    ]).reset_index(drop=True),
            actual_df_from_df_entities[expected_df.columns].sort_values(
                by=[event_timestamp, "order_id", "driver_id", "customer_id"
                    ]).reset_index(drop=True),
            check_dtype=False,
        )
        table_from_df_entities = job_from_df.to_arrow()
        assert_frame_equal(actual_df_from_df_entities,
                           table_from_df_entities.to_pandas())
Example #12
def test_historical_features_from_bigquery_sources():
    start_date = datetime.now().replace(microsecond=0, second=0, minute=0)
    (
        customer_entities,
        driver_entities,
        end_date,
        orders_df,
        start_date,
    ) = generate_entities(start_date)

    # bigquery_dataset = "test_hist_retrieval_static"
    bigquery_dataset = f"test_hist_retrieval_{int(time.time())}"

    with BigQueryDataSet(bigquery_dataset), TemporaryDirectory() as temp_dir:
        gcp_project = bigquery.Client().project

        # Orders Query
        table_id = f"{bigquery_dataset}.orders"
        stage_orders_bigquery(orders_df, table_id)
        entity_df_query = f"SELECT * FROM {gcp_project}.{table_id}"

        # Driver Feature View
        driver_df = driver_data.create_driver_hourly_stats_df(
            driver_entities, start_date, end_date)
        driver_table_id = f"{gcp_project}.{bigquery_dataset}.driver_hourly"
        stage_driver_hourly_stats_bigquery_source(driver_df, driver_table_id)
        driver_source = BigQuerySource(
            table_ref=driver_table_id,
            event_timestamp_column="datetime",
            created_timestamp_column="created",
        )
        driver_fv = create_driver_hourly_stats_feature_view(driver_source)

        # Customer Feature View
        customer_df = driver_data.create_customer_daily_profile_df(
            customer_entities, start_date, end_date)
        customer_table_id = f"{gcp_project}.{bigquery_dataset}.customer_profile"

        stage_customer_daily_profile_bigquery_source(customer_df,
                                                     customer_table_id)
        customer_source = BigQuerySource(
            table_ref=customer_table_id,
            event_timestamp_column="datetime",
            created_timestamp_column="created",
        )
        customer_fv = create_customer_daily_profile_feature_view(
            customer_source)

        driver = Entity(name="driver", value_type=ValueType.INT64)
        customer = Entity(name="customer", value_type=ValueType.INT64)

        store = FeatureStore(config=RepoConfig(
            registry=os.path.join(temp_dir, "registry.db"),
            project="default",
            provider="gcp",
            online_store=OnlineStoreConfig(local=LocalOnlineStoreConfig(
                path=os.path.join(temp_dir, "online_store.db"), )),
        ))
        store.apply([driver, customer, driver_fv, customer_fv])

        expected_df = get_expected_training_df(
            customer_df,
            customer_fv,
            driver_df,
            driver_fv,
            orders_df,
        )

        job_from_sql = store.get_historical_features(
            entity_df=entity_df_query,
            feature_refs=[
                "driver_stats:conv_rate",
                "driver_stats:avg_daily_trips",
                "customer_profile:current_balance",
                "customer_profile:avg_passenger_count",
                "customer_profile:lifetime_trip_count",
            ],
        )
        actual_df_from_sql_entities = job_from_sql.to_df()

        assert_frame_equal(
            expected_df.sort_values(by=[
                ENTITY_DF_EVENT_TIMESTAMP_COL,
                "order_id",
                "driver_id",
                "customer_id",
            ]).reset_index(drop=True),
            actual_df_from_sql_entities.sort_values(by=[
                ENTITY_DF_EVENT_TIMESTAMP_COL,
                "order_id",
                "driver_id",
                "customer_id",
            ]).reset_index(drop=True),
            check_dtype=False,
        )

        job_from_df = store.get_historical_features(
            entity_df=orders_df,
            feature_refs=[
                "driver_stats:conv_rate",
                "driver_stats:avg_daily_trips",
                "customer_profile:current_balance",
                "customer_profile:avg_passenger_count",
                "customer_profile:lifetime_trip_count",
            ],
        )
        actual_df_from_df_entities = job_from_df.to_df()

        assert_frame_equal(
            expected_df.sort_values(by=[
                ENTITY_DF_EVENT_TIMESTAMP_COL,
                "order_id",
                "driver_id",
                "customer_id",
            ]).reset_index(drop=True),
            actual_df_from_df_entities.sort_values(by=[
                ENTITY_DF_EVENT_TIMESTAMP_COL,
                "order_id",
                "driver_id",
                "customer_id",
            ]).reset_index(drop=True),
            check_dtype=False,
        )
Example #13
    def test_bigquery_ingestion_correctness(self):
        # create dataset
        ts = pd.Timestamp.now(tz="UTC").round("ms")
        checked_value = (
            random.random()
        )  # random value so the test cannot pass if nothing was written to the online store
        data = {
            "id": [1, 2, 1],
            "value": [0.1, 0.2, checked_value],
            "ts_1": [ts - timedelta(minutes=2), ts, ts],
            "created_ts": [ts, ts, ts],
        }
        df = pd.DataFrame.from_dict(data)

        # load dataset into BigQuery
        job_config = bigquery.LoadJobConfig()
        table_id = (
            f"{self.gcp_project}.{self.bigquery_dataset}.correctness_{int(time.time())}"
        )
        job = self.client.load_table_from_dataframe(df,
                                                    table_id,
                                                    job_config=job_config)
        job.result()

        # create FeatureView
        fv = FeatureView(
            name="test_bq_correctness",
            entities=["driver_id"],
            features=[Feature("value", ValueType.FLOAT)],
            ttl=timedelta(minutes=5),
            input=BigQuerySource(
                event_timestamp_column="ts",
                table_ref=table_id,
                created_timestamp_column="created_ts",
                field_mapping={
                    "ts_1": "ts",
                    "id": "driver_id"
                },
                date_partition_column="",
            ),
        )
        config = RepoConfig(
            metadata_store="./metadata.db",
            project="default",
            provider="gcp",
            online_store=OnlineStoreConfig(
                local=LocalOnlineStoreConfig("online_store.db")),
        )
        fs = FeatureStore(config=config)
        fs.apply([fv])

        # run materialize()
        fs.materialize(
            ["test_bq_correctness"],
            datetime.utcnow() - timedelta(minutes=5),
            datetime.utcnow() - timedelta(minutes=0),
        )

        # check result of materialize()
        entity_key = EntityKeyProto(entity_names=["driver_id"],
                                    entity_values=[ValueProto(int64_val=1)])
        t, val = fs._get_provider().online_read("default", fv, entity_key)
        assert abs(val["value"].double_val - checked_value) < 1e-6