Example #1
    def __init__(
        self, repo_path: Optional[str] = None, config: Optional[RepoConfig] = None,
    ):
        self.repo_path = repo_path
        if repo_path is not None and config is not None:
            raise ValueError("You cannot specify both repo_path and config")
        if config is not None:
            self.config = config
        elif repo_path is not None:
            self.config = load_repo_config(Path(repo_path))
        else:
            self.config = RepoConfig(
                registry="./registry.db",
                project="default",
                provider="local",
                online_store=OnlineStoreConfig(
                    local=LocalOnlineStoreConfig(path="online_store.db")
                ),
            )

        registry_config = self.config.get_registry_config()
        self._registry = Registry(
            registry_path=registry_config.path,
            cache_ttl=timedelta(seconds=registry_config.cache_ttl_seconds),
        )
        self._tele = Telemetry()
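
A minimal construction sketch tying the three paths above together (the config values are illustrative, mirroring the defaults in this snippet):

# 1. Explicit config object
fs = FeatureStore(config=RepoConfig(registry="./registry.db", project="default", provider="local"))

# 2. Load the config from a feature repository directory (assumes it contains a feature_store.yaml)
fs = FeatureStore(repo_path="my_feature_repo/")

# 3. No arguments: falls back to the default local RepoConfig built above
fs = FeatureStore()

# Passing both repo_path and config raises ValueError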
Example #2
def prep_local_fs_and_fv() -> Iterator[Tuple[FeatureStore, FeatureView]]:
    with tempfile.NamedTemporaryFile(suffix=".parquet") as f:
        df = create_dataset()
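        # NamedTemporaryFile deletes its file on close, so closing here frees
        # the path for to_parquet() to write a real parquet file in its place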
        f.close()
        df.to_parquet(f.name)
        file_source = FileSource(
            file_format=ParquetFormat(),
            file_url=f"file://{f.name}",
            event_timestamp_column="ts",
            created_timestamp_column="created_ts",
            date_partition_column="",
            field_mapping={
                "ts_1": "ts",
                "id": "driver_id"
            },
        )
        fv = get_feature_view(file_source)
        with tempfile.TemporaryDirectory() as repo_dir_name, \
                tempfile.TemporaryDirectory() as data_dir_name:
            config = RepoConfig(
                registry=str(Path(repo_dir_name) / "registry.db"),
                project=f"test_bq_correctness_{str(uuid.uuid4()).replace('-', '')}",
                provider="local",
                online_store=OnlineStoreConfig(local=LocalOnlineStoreConfig(
                    path=str(Path(data_dir_name) / "online_store.db"))),
            )
            fs = FeatureStore(config=config)
            fs.apply([fv])

            yield fs, fv
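
A sketch of how a test could consume the generator above, assuming it is wired up as a pytest fixture (the pytest plumbing is not part of the original snippet):

import pytest

@pytest.fixture
def local_fs_and_fv():
    yield from prep_local_fs_and_fv()

def test_feature_view_is_registered(local_fs_and_fv):
    fs, fv = local_fs_and_fv
    assert fs.get_feature_view(fv.name).name == fv.name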
Example #3
def prep_dynamodb_fs_and_fv() -> Iterator[Tuple[FeatureStore, FeatureView]]:
    with tempfile.NamedTemporaryFile(suffix=".parquet") as f:
        df = create_dataset()
        f.close()
        df.to_parquet(f.name)
        file_source = FileSource(
            file_format=ParquetFormat(),
            file_url=f"file://{f.name}",
            event_timestamp_column="ts",
            created_timestamp_column="created_ts",
            date_partition_column="",
            field_mapping={"ts_1": "ts", "id": "driver_id"},
        )
        fv = get_feature_view(file_source)
        e = Entity(
            name="driver",
            description="id for driver",
            join_key="driver_id",
            value_type=ValueType.INT32,
        )
        with tempfile.TemporaryDirectory() as repo_dir_name:
            config = RepoConfig(
                registry=str(Path(repo_dir_name) / "registry.db"),
                project=f"test_bq_correctness_{str(uuid.uuid4()).replace('-', '')}",
                provider="aws",
                online_store=DynamoDBOnlineStoreConfig(region="us-west-2"),
                offline_store=FileOfflineStoreConfig(),
            )
            fs = FeatureStore(config=config)
            fs.apply([fv, e])

            yield fs, fv
Example #4
def repo_config():
    return RepoConfig(
        registry=REGISTRY,
        project=PROJECT,
        provider=PROVIDER,
        online_store=DynamoDBOnlineStoreConfig(region=REGION),
        offline_store=FileOfflineStoreConfig(),
    )
Example #5
def registry_dump(repo_config: RepoConfig, repo_path: Path):
    """For debugging only: output contents of the metadata registry"""
    registry_config = repo_config.get_registry_config()
    project = repo_config.project
    registry = Registry(registry_config=registry_config, repo_path=repo_path)
    registry_dict = registry.to_dict(project=project)

    click.echo(json.dumps(registry_dict, indent=2, sort_keys=True))
Example #6
    def test_bigquery_query_to_datastore_correctness(self):
        # create dataset
        ts = pd.Timestamp.now(tz="UTC").round("ms")
        data = {
            "id": [1, 2, 1],
            "value": [0.1, 0.2, 0.3],
            "ts_1": [ts - timedelta(minutes=2), ts, ts],
            "created_ts": [ts, ts, ts],
        }
        df = pd.DataFrame.from_dict(data)

        # load dataset into BigQuery
        job_config = bigquery.LoadJobConfig()
        table_id = f"{self.gcp_project}.{self.bigquery_dataset}.query_correctness_{int(time.time())}"
        query = f"SELECT * FROM `{table_id}`"
        job = self.client.load_table_from_dataframe(df,
                                                    table_id,
                                                    job_config=job_config)
        job.result()

        # create FeatureView
        fv = FeatureView(
            name="test_bq_query_correctness",
            entities=["driver_id"],
            features=[Feature("value", ValueType.FLOAT)],
            ttl=timedelta(minutes=5),
            input=BigQuerySource(
                event_timestamp_column="ts",
                created_timestamp_column="created_ts",
                field_mapping={
                    "ts_1": "ts",
                    "id": "driver_id"
                },
                date_partition_column="",
                query=query,
            ),
        )
        config = RepoConfig(
            metadata_store="./metadata.db",
            project=f"test_bq_query_correctness_{int(time.time())}",
            provider="gcp",
        )
        fs = FeatureStore(config=config)
        fs.apply([fv])

        # run materialize()
        fs.materialize(
            [fv.name],
            datetime.utcnow() - timedelta(minutes=5),
            datetime.utcnow() - timedelta(minutes=0),
        )

        # check result of materialize()
        response_dict = fs.get_online_features([f"{fv.name}:value"],
                                               [{
                                                   "driver_id": 1
                                               }]).to_dict()
        assert abs(response_dict[f"{fv.name}:value"][0] - 0.3) < 1e-6
Example #7
def feature_store_with_s3_registry():
    return FeatureStore(config=RepoConfig(
        registry=f"s3://feast-integration-tests/registries/{int(time.time() * 1000)}/registry.db",
        project="default",
        provider="aws",
        online_store=DynamoDBOnlineStoreConfig(region="us-west-2"),
        offline_store=FileOfflineStoreConfig(),
    ))
Example #8
def test_apply_remote_repo():
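    # mkstemp() returns (open OS-level fd, path); only the paths are used below,
    # and the first fd is shadowed by the second call without being closed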
    fd, registry_path = mkstemp()
    fd, online_store_path = mkstemp()
    return FeatureStore(config=RepoConfig(
        registry=registry_path,
        project="default",
        provider="local",
        online_store=SqliteOnlineStoreConfig(path=online_store_path),
    ))
Example #9
def feature_store_with_local_registry():
    fd, registry_path = mkstemp()
    fd, online_store_path = mkstemp()
    return FeatureStore(config=RepoConfig(
        registry=registry_path,
        project="default",
        provider="local",
        online_store=SqliteOnlineStoreConfig(path=online_store_path),
    ))
Example #10
def registry_dump(repo_config: RepoConfig, repo_path: Path):
    """ For debugging only: output contents of the metadata registry """
    registry_config = repo_config.get_registry_config()
    project = repo_config.project
    registry = Registry(registry_config=registry_config, repo_path=repo_path)

    for entity in registry.list_entities(project=project):
        print(entity)
    for feature_view in registry.list_feature_views(project=project):
        print(feature_view)
Example #11
    def feature_store_with_local_registry(self):
        fd, registry_path = mkstemp()
        fd, online_store_path = mkstemp()
        return FeatureStore(config=RepoConfig(
            metadata_store=registry_path,
            project="default",
            provider="local",
            online_store=OnlineStoreConfig(
                local=LocalOnlineStoreConfig(path=online_store_path)),
        ))
Example #12
def prep_bq_fs_and_fv(
    bq_source_type: str,
) -> Iterator[Tuple[FeatureStore, FeatureView]]:
    client = bigquery.Client()
    gcp_project = client.project
    bigquery_dataset = "test_ingestion"
    dataset = bigquery.Dataset(f"{gcp_project}.{bigquery_dataset}")
    client.create_dataset(dataset, exists_ok=True)
    dataset.default_table_expiration_ms = 1000 * 60 * 60 * 24 * 14  # 2 weeks in milliseconds
    client.update_dataset(dataset, ["default_table_expiration_ms"])

    df = create_dataset()

    job_config = bigquery.LoadJobConfig()
    table_ref = f"{gcp_project}.{bigquery_dataset}.{bq_source_type}_correctness_{int(time.time_ns())}"
    query = f"SELECT * FROM `{table_ref}`"
    job = client.load_table_from_dataframe(df,
                                           table_ref,
                                           job_config=job_config)
    job.result()

    bigquery_source = BigQuerySource(
        table_ref=table_ref if bq_source_type == "table" else None,
        query=query if bq_source_type == "query" else None,
        event_timestamp_column="ts",
        created_timestamp_column="created_ts",
        date_partition_column="",
        field_mapping={
            "ts_1": "ts",
            "id": "driver_id"
        },
    )

    fv = driver_feature_view(bigquery_source)
    e = Entity(
        name="driver",
        description="id for driver",
        join_key="driver_id",
        value_type=ValueType.INT32,
    )
    with tempfile.TemporaryDirectory() as repo_dir_name:
        config = RepoConfig(
            registry=str(Path(repo_dir_name) / "registry.db"),
            project=f"test_bq_correctness_{str(uuid.uuid4()).replace('-', '')}",
            provider="gcp",
            online_store=DatastoreOnlineStoreConfig(
                namespace="integration_test"),
        )
        fs = FeatureStore(config=config)
        fs.apply([fv, e])

        yield fs, fv

        fs.teardown()
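
Because the fixture above is parameterized by bq_source_type, both source kinds can be exercised with a parametrized wrapper (a sketch; the pytest wiring is assumed, not part of the original):

import pytest

@pytest.fixture(params=["table", "query"])
def bq_fs_and_fv(request):
    yield from prep_bq_fs_and_fv(request.param)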
Example #13
def teardown(repo_config: RepoConfig, repo_path: Path):
    registry_config = repo_config.get_registry_config()
    registry = Registry(
        registry_path=registry_config.path,
        cache_ttl=timedelta(seconds=registry_config.cache_ttl_seconds),
    )
    project = repo_config.project
    registry_tables: List[Union[FeatureTable, FeatureView]] = []
    registry_tables.extend(registry.list_feature_tables(project=project))
    registry_tables.extend(registry.list_feature_views(project=project))
    infra_provider = get_provider(repo_config)
    infra_provider.teardown_infra(project, tables=registry_tables)
Example #14
def registry_dump(repo_config: RepoConfig):
    """ For debugging only: output contents of the metadata registry """
    registry_config = repo_config.get_registry_config()
    project = repo_config.project
    registry = Registry(
        registry_path=registry_config.path,
        cache_ttl=timedelta(seconds=registry_config.cache_ttl_seconds),
    )

    for entity in registry.list_entities(project=project):
        print(entity)
    for table in registry.list_feature_tables(project=project):
        print(table)
Example #15
def benchmark_writes():
    project_id = "test" + "".join(
        random.choice(string.ascii_lowercase + string.digits) for _ in range(10)
    )

    with tempfile.TemporaryDirectory() as temp_dir:
        store = FeatureStore(
            config=RepoConfig(
                registry=os.path.join(temp_dir, "registry.db"),
                project=project_id,
                provider="gcp",
            )
        )

        # This is just to set the data source to something; we're not actually reading from the parquet source here.
        parquet_path = os.path.join(temp_dir, "data.parquet")

        driver = Entity(name="driver_id", value_type=ValueType.INT64)
        table = create_driver_hourly_stats_feature_view(
            create_driver_hourly_stats_source(parquet_path=parquet_path)
        )
        store.apply([table, driver])

        provider = store._get_provider()

        end_date = datetime.utcnow()
        start_date = end_date - timedelta(days=14)
        customers = list(range(100))
        data = create_driver_hourly_stats_df(customers, start_date, end_date)

        # Show the data for reference
        print(data)
        proto_data = _convert_arrow_to_proto(
            pa.Table.from_pandas(data), table, ["driver_id"]
        )

        # Write it
        with tqdm(total=len(proto_data)) as progress:
            provider.online_write_batch(
                project=store.project,
                table=table,
                data=proto_data,
                progress=progress.update,
            )

        registry_tables = store.list_feature_views()
        registry_entities = store.list_entities()
        provider.teardown_infra(
            store.project, tables=registry_tables, entities=registry_entities
        )
Example #16
    def _build_feast_feature_store(self):
        os.environ["FEAST_S3_ENDPOINT_URL"] = aws.S3_ENDPOINT.get()
        os.environ["AWS_ACCESS_KEY_ID"] = aws.S3_ACCESS_KEY_ID.get()
        os.environ["AWS_SECRET_ACCESS_KEY"] = aws.S3_SECRET_ACCESS_KEY.get()

        config = RepoConfig(
            registry=f"s3://{self.config.s3_bucket}/{self.config.registry_path}",
            project=self.config.project,
            # Notice the use of a custom provider.
            provider="custom_provider.provider.FlyteCustomProvider",
            offline_store=FileOfflineStoreConfig(),
            online_store=SqliteOnlineStoreConfig(path=self.config.online_store_path),
        )
        return FeastFeatureStore(config=config)
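
The provider string above is resolved by Feast to a user-supplied class. A rough skeleton of such a class, inferred only from the provider calls visible elsewhere in these examples (the LocalProvider import path and exact signature are assumptions, not taken from this snippet):

from feast.infra.local import LocalProvider

class FlyteCustomProvider(LocalProvider):
    def materialize_single_feature_view(self, feature_view, start_date, end_date, registry, project):
        # delegate the actual materialization, then hook in custom behavior
        super().materialize_single_feature_view(feature_view, start_date, end_date, registry, project)
        # e.g., emit a notification or record lineage here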
Example #17
def registry_dump(repo_config: RepoConfig, repo_path: Path):
    """ For debugging only: output contents of the metadata registry """
    from colorama import Fore, Style

    registry_config = repo_config.get_registry_config()
    project = repo_config.project
    registry = Registry(registry_config=registry_config, repo_path=repo_path)
    registry_dict = registry.to_dict(project=project)

    warning = (
        "Warning: The registry-dump command is for debugging only and may contain "
        "breaking changes in the future. No guarantees are made on this interface."
    )
    click.echo(f"{Style.BRIGHT}{Fore.YELLOW}{warning}{Style.RESET_ALL}")
    click.echo(json.dumps(registry_dict, indent=2))
Example #18
    def feature_store_with_gcs_registry(self):
        from google.cloud import storage

        storage_client = storage.Client()
        bucket_name = f"feast-registry-test-{int(time.time())}"
        bucket = storage_client.bucket(bucket_name)
        bucket = storage_client.create_bucket(bucket)
        bucket.add_lifecycle_delete_rule(
            age=14)  # delete bucket contents automatically after 14 days
        bucket.patch()
        bucket.blob("metadata.db")

        return FeatureStore(config=RepoConfig(
            metadata_store=f"gs://{bucket_name}/metadata.db",
            project="default",
            provider="gcp",
        ))
Example #19
    def __init__(
        self,
        repo_path: Optional[str] = None,
        config: Optional[RepoConfig] = None,
    ):
        if repo_path is not None and config is not None:
            raise ValueError("You cannot specify both repo_path and config")
        if config is not None:
            self.config = config
        elif repo_path is not None:
            self.config = load_repo_config(Path(repo_path))
        else:
            self.config = RepoConfig(
                metadata_store="./metadata.db",
                project="default",
                provider="local",
                online_store=OnlineStoreConfig(
                    local=LocalOnlineStoreConfig("online_store.db")),
            )
Example #20
def prep_redis_fs_and_fv() -> Iterator[Tuple[FeatureStore, FeatureView]]:
    with tempfile.NamedTemporaryFile(suffix=".parquet") as f:
        df = create_dataset()
        f.close()
        df.to_parquet(f.name)
        file_source = FileSource(
            file_format=ParquetFormat(),
            path=f"file://{f.name}",
            event_timestamp_column="ts",
            created_timestamp_column="created_ts",
            date_partition_column="",
            field_mapping={
                "ts_1": "ts",
                "id": "driver_id"
            },
        )
        fv = driver_feature_view(file_source)
        e = Entity(
            name="driver",
            description="id for driver",
            join_key="driver_id",
            value_type=ValueType.INT32,
        )
        project = f"test_redis_correctness_{str(uuid.uuid4()).replace('-', '')}"
        print(f"Using project: {project}")
        with tempfile.TemporaryDirectory() as repo_dir_name:
            config = RepoConfig(
                registry=str(Path(repo_dir_name) / "registry.db"),
                project=project,
                provider="local",
                online_store=RedisOnlineStoreConfig(
                    type="redis",
                    redis_type=RedisType.redis,
                    connection_string="localhost:6379,db=0",
                ),
            )
            fs = FeatureStore(config=config)
            fs.apply([fv, e])

            yield fs, fv

            fs.teardown()
Example #21
def apply_total(repo_config: RepoConfig, repo_path: Path):
    from colorama import Fore, Style

    os.chdir(repo_path)
    sys.path.append("")
    registry_config = repo_config.get_registry_config()
    project = repo_config.project
    registry = Registry(
        registry_path=registry_config.path,
        repo_path=repo_path,
        cache_ttl=timedelta(seconds=registry_config.cache_ttl_seconds),
    )
    registry._initialize_registry()
    sys.dont_write_bytecode = True
    repo = parse_repo(repo_path)
    sys.dont_write_bytecode = False
    for entity in repo.entities:
        registry.apply_entity(entity, project=project)
        click.echo(
            f"Registered entity {Style.BRIGHT + Fore.GREEN}{entity.name}{Style.RESET_ALL}"
        )

    repo_table_names = set(t.name for t in repo.feature_tables)

    for t in repo.feature_views:
        repo_table_names.add(t.name)

    tables_to_delete = []
    for registry_table in registry.list_feature_tables(project=project):
        if registry_table.name not in repo_table_names:
            tables_to_delete.append(registry_table)

    views_to_delete = []
    for registry_view in registry.list_feature_views(project=project):
        if registry_view.name not in repo_table_names:
            views_to_delete.append(registry_view)

    # Delete tables that should not exist
    for registry_table in tables_to_delete:
        registry.delete_feature_table(registry_table.name, project=project)
        click.echo(
            f"Deleted feature table {Style.BRIGHT + Fore.GREEN}{registry_table.name}{Style.RESET_ALL} from registry"
        )

    # Create tables that should exist
    for table in repo.feature_tables:
        registry.apply_feature_table(table, project)
        click.echo(
            f"Registered feature table {Style.BRIGHT + Fore.GREEN}{table.name}{Style.RESET_ALL}"
        )

    # Delete views that should not exist
    for registry_view in views_to_delete:
        registry.delete_feature_view(registry_view.name, project=project)
        click.echo(
            f"Deleted feature view {Style.BRIGHT + Fore.GREEN}{registry_view.name}{Style.RESET_ALL} from registry"
        )

    # Create views that should exist
    for view in repo.feature_views:
        registry.apply_feature_view(view, project)
        click.echo(
            f"Registered feature view {Style.BRIGHT + Fore.GREEN}{view.name}{Style.RESET_ALL}"
        )

    infra_provider = get_provider(repo_config, repo_path)

    all_to_delete: List[Union[FeatureTable, FeatureView]] = []
    all_to_delete.extend(tables_to_delete)
    all_to_delete.extend(views_to_delete)

    all_to_keep: List[Union[FeatureTable, FeatureView]] = []
    all_to_keep.extend(repo.feature_tables)
    all_to_keep.extend(repo.feature_views)

    for name in [table.name for table in repo.feature_tables] + [
        view.name for view in repo.feature_views
    ]:
        click.echo(
            f"Deploying infrastructure for {Style.BRIGHT + Fore.GREEN}{name}{Style.RESET_ALL}"
        )
    for name in [view.name for view in views_to_delete] + [
        table.name for table in tables_to_delete
    ]:
        click.echo(
            f"Removing infrastructure for {Style.BRIGHT + Fore.GREEN}{name}{Style.RESET_ALL}"
        )
    infra_provider.update_infra(
        project,
        tables_to_delete=all_to_delete,
        tables_to_keep=all_to_keep,
        partial=False,
    )
Example #22
def prep_redshift_fs_and_fv(
    source_type: str,
) -> Iterator[Tuple[FeatureStore, FeatureView]]:
    client = aws_utils.get_redshift_data_client("us-west-2")
    s3 = aws_utils.get_s3_resource("us-west-2")

    df = create_dataset()

    table_name = f"test_ingestion_{source_type}_correctness_{int(time.time_ns())}_{random.randint(1000, 9999)}"

    offline_store = RedshiftOfflineStoreConfig(
        cluster_id="feast-integration-tests",
        region="us-west-2",
        user="******",
        database="feast",
        s3_staging_location="s3://feast-integration-tests/redshift/tests/ingestion",
        iam_role="arn:aws:iam::402087665549:role/redshift_s3_access_role",
    )

    aws_utils.upload_df_to_redshift(
        client,
        offline_store.cluster_id,
        offline_store.database,
        offline_store.user,
        s3,
        f"{offline_store.s3_staging_location}/copy/{table_name}.parquet",
        offline_store.iam_role,
        table_name,
        df,
    )

    redshift_source = RedshiftSource(
        table=table_name if source_type == "table" else None,
        query=f"SELECT * FROM {table_name}"
        if source_type == "query" else None,
        event_timestamp_column="ts",
        created_timestamp_column="created_ts",
        date_partition_column="",
        field_mapping={
            "ts_1": "ts",
            "id": "driver_id"
        },
    )

    fv = driver_feature_view(redshift_source)
    e = Entity(
        name="driver",
        description="id for driver",
        join_key="driver_id",
        value_type=ValueType.INT32,
    )
    with tempfile.TemporaryDirectory() as repo_dir_name, \
            tempfile.TemporaryDirectory() as data_dir_name:
        config = RepoConfig(
            registry=str(Path(repo_dir_name) / "registry.db"),
            project=f"test_bq_correctness_{str(uuid.uuid4()).replace('-', '')}",
            provider="local",
            online_store=SqliteOnlineStoreConfig(
                path=str(Path(data_dir_name) / "online_store.db")),
            offline_store=offline_store,
        )
        fs = FeatureStore(config=config)
        fs.apply([fv, e])

        yield fs, fv

        fs.teardown()

    # Clean up the uploaded Redshift table
    aws_utils.execute_redshift_statement(
        client,
        offline_store.cluster_id,
        offline_store.database,
        offline_store.user,
        f"DROP TABLE {table_name}",
    )
Example #23
def apply_total(repo_config: RepoConfig, repo_path: Path):
    os.chdir(repo_path)
    sys.path.append("")
    registry_config = repo_config.get_registry_config()
    project = repo_config.project
    registry = Registry(
        registry_path=registry_config.path,
        cache_ttl=timedelta(seconds=registry_config.cache_ttl_seconds),
    )
    repo = parse_repo(repo_path)

    for entity in repo.entities:
        registry.apply_entity(entity, project=project)

    repo_table_names = set(t.name for t in repo.feature_tables)

    for t in repo.feature_views:
        repo_table_names.add(t.name)

    tables_to_delete = []
    for registry_table in registry.list_feature_tables(project=project):
        if registry_table.name not in repo_table_names:
            tables_to_delete.append(registry_table)

    views_to_delete = []
    for registry_view in registry.list_feature_views(project=project):
        if registry_view.name not in repo_table_names:
            views_to_delete.append(registry_view)

    # Delete tables that should not exist
    for registry_table in tables_to_delete:
        registry.delete_feature_table(registry_table.name, project=project)

    # Create tables that should exist
    for table in repo.feature_tables:
        registry.apply_feature_table(table, project)

    # Delete views that should not exist
    for registry_view in views_to_delete:
        registry.delete_feature_view(registry_view.name, project=project)

    # Create views that should exist
    for view in repo.feature_views:
        registry.apply_feature_view(view, project)

    infra_provider = get_provider(repo_config)

    all_to_delete: List[Union[FeatureTable, FeatureView]] = []
    all_to_delete.extend(tables_to_delete)
    all_to_delete.extend(views_to_delete)

    all_to_keep: List[Union[FeatureTable, FeatureView]] = []
    all_to_keep.extend(repo.feature_tables)
    all_to_keep.extend(repo.feature_views)

    infra_provider.update_infra(
        project,
        tables_to_delete=all_to_delete,
        tables_to_keep=all_to_keep,
        partial=False,
    )

    print("Done!")
Example #24
class FeatureStore:
    """
    A FeatureStore object is used to define, create, and retrieve features.
    """

    config: RepoConfig
    repo_path: Optional[str]
    _registry: Registry

    def __init__(
        self, repo_path: Optional[str] = None, config: Optional[RepoConfig] = None,
    ):
        self.repo_path = repo_path
        if repo_path is not None and config is not None:
            raise ValueError("You cannot specify both repo_path and config")
        if config is not None:
            self.config = config
        elif repo_path is not None:
            self.config = load_repo_config(Path(repo_path))
        else:
            self.config = RepoConfig(
                registry="./registry.db",
                project="default",
                provider="local",
                online_store=OnlineStoreConfig(
                    local=LocalOnlineStoreConfig(path="online_store.db")
                ),
            )

        registry_config = self.config.get_registry_config()
        self._registry = Registry(
            registry_path=registry_config.path,
            cache_ttl=timedelta(seconds=registry_config.cache_ttl_seconds),
        )
        self._tele = Telemetry()

    def version(self) -> str:
        """Returns the version of the current Feast SDK/CLI"""

        return get_version()

    @property
    def project(self) -> str:
        return self.config.project

    def _get_provider(self) -> Provider:
        return get_provider(self.config)

    def refresh_registry(self):
        """Fetches and caches a copy of the feature registry in memory.

        Explicitly calling this method allows for direct control of the state of the registry cache. Every time this
        method is called the complete registry state will be retrieved from the remote registry store backend
        (e.g., GCS, S3), and the cache timer will be reset. If refresh_registry() is run before get_online_features()
        is called, then get_online_features() will use the cached registry instead of retrieving (and caching) the
        registry itself.

        Additionally, the TTL for the registry cache can be set to infinity (by setting it to 0), which means that
        refresh_registry() will become the only way to update the cached registry. If the TTL is set to a value
        greater than 0, then once the cache becomes stale (more time than the TTL has passed), a new cache will be
        downloaded synchronously, which may increase latencies if the triggering method is get_online_features().
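
        Example (a sketch; assumes a repo with a feature_store.yaml in the current directory):
            >>> fs = FeatureStore(repo_path=".")
            >>> fs.refresh_registry()  # registry is re-downloaded and the cache timer is reset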
        """
        self._tele.log("refresh_registry")

        registry_config = self.config.get_registry_config()
        self._registry = Registry(
            registry_path=registry_config.path,
            cache_ttl=timedelta(seconds=registry_config.cache_ttl_seconds),
        )
        self._registry.refresh()

    def list_entities(self, allow_cache: bool = False) -> List[Entity]:
        """
        Retrieve a list of entities from the registry

        Args:
            allow_cache (bool): Whether to allow returning entities from a cached registry

        Returns:
            List of entities
        """
        self._tele.log("list_entities")

        return self._registry.list_entities(self.project, allow_cache=allow_cache)

    def list_feature_views(self) -> List[FeatureView]:
        """
        Retrieve a list of feature views from the registry

        Returns:
            List of feature views
        """
        self._tele.log("list_feature_views")

        return self._registry.list_feature_views(self.project)

    def get_entity(self, name: str) -> Entity:
        """
        Retrieves an entity.

        Args:
            name: Name of entity

        Returns:
            Returns either the specified entity, or raises an exception if
            none is found
        """
        self._tele.log("get_entity")

        return self._registry.get_entity(name, self.project)

    def get_feature_view(self, name: str) -> FeatureView:
        """
        Retrieves a feature view.

        Args:
            name: Name of feature view

        Returns:
            Returns either the specified feature view, or raises an exception if
            none is found
        """
        self._tele.log("get_feature_view")

        return self._registry.get_feature_view(name, self.project)

    def delete_feature_view(self, name: str):
        """
        Deletes a feature view or raises an exception if not found.

        Args:
            name: Name of feature view
        """
        self._tele.log("delete_feature_view")

        return self._registry.delete_feature_view(name, self.project)

    def apply(self, objects: List[Union[FeatureView, Entity]]):
        """Register objects to metadata store and update related infrastructure.

        The apply method takes one or more definitions (e.g., Entity, FeatureView) and registers or updates them
        in the Feast registry. Once the registry has been updated, the apply method will update related
        infrastructure (e.g., create tables in an online store) in order to reflect these new definitions. All
        operations are idempotent, meaning they can safely be rerun.

        Args:
            objects (List[Union[FeatureView, Entity]]): A list of FeatureView or Entity objects that should be
                registered

        Examples:
            Register a single Entity and FeatureView.
            >>> from feast.feature_store import FeatureStore
            >>> from feast import Entity, FeatureView, Feature, ValueType, FileSource
            >>> from datetime import timedelta
            >>>
            >>> fs = FeatureStore()
            >>> customer_entity = Entity(name="customer", value_type=ValueType.INT64, description="customer entity")
            >>> customer_feature_view = FeatureView(
            >>>     name="customer_fv",
            >>>     entities=["customer"],
            >>>     features=[Feature(name="age", dtype=ValueType.INT64)],
            >>>     input=FileSource(path="file.parquet", event_timestamp_column="timestamp"),
            >>>     ttl=timedelta(days=1)
            >>> )
            >>> fs.apply([customer_entity, customer_feature_view])
        """

        self._tele.log("apply")

        # TODO: Add locking
        # TODO: Optimize by only making a single call (read/write)

        views_to_update = []
        for ob in objects:
            if isinstance(ob, FeatureView):
                self._registry.apply_feature_view(ob, project=self.config.project)
                views_to_update.append(ob)
            elif isinstance(ob, Entity):
                self._registry.apply_entity(ob, project=self.config.project)
            else:
                raise ValueError(
                    f"Unknown object type ({type(ob)}) provided as part of apply() call"
                )
        self._get_provider().update_infra(
            project=self.config.project,
            tables_to_delete=[],
            tables_to_keep=views_to_update,
            partial=True,
        )

    def get_historical_features(
        self, entity_df: Union[pd.DataFrame, str], feature_refs: List[str],
    ) -> RetrievalJob:
        """Enrich an entity dataframe with historical feature values for either training or batch scoring.

        This method joins historical feature data from one or more feature views to an entity dataframe by using a time
        travel join.

        Each feature view is joined to the entity dataframe using all entities configured for the respective feature
        view. All configured entities must be available in the entity dataframe. Therefore, the entity dataframe must
        contain all entities found in all feature views, but the individual feature views can have different entities.

        Time travel is based on the configured TTL for each feature view. A shorter TTL will limit the
        amount of scanning that will be done in order to find feature data for a specific entity key. Setting a short
        TTL may result in null values being returned.

        Args:
            entity_df (Union[pd.DataFrame, str]): An entity dataframe is a collection of rows containing all entity
                columns (e.g., customer_id, driver_id) on which features need to be joined, as well as an event_timestamp
                column used to ensure point-in-time correctness. Either a Pandas DataFrame can be provided or a string
                SQL query. The query must be of a format supported by the configured offline store (e.g., BigQuery)
            feature_refs: A list of features that should be retrieved from the offline store. Feature references are of
                the format "feature_view:feature", e.g., "customer_fv:daily_transactions".

        Returns:
            RetrievalJob which can be used to materialize the results.

        Examples:
            Retrieve historical features using a BigQuery SQL entity dataframe
            >>> from feast.feature_store import FeatureStore
            >>>
            >>> fs = FeatureStore(config=RepoConfig(provider="gcp"))
            >>> retrieval_job = fs.get_historical_features(
            >>>     entity_df="SELECT event_timestamp, order_id, customer_id from gcp_project.my_ds.customer_orders",
            >>>     feature_refs=["customer:age", "customer:avg_orders_1d", "customer:avg_orders_7d"]
            >>> )
            >>> feature_data = retrieval_job.to_df()
            >>> model.fit(feature_data) # insert your modeling framework here.
        """
        self._tele.log("get_historical_features")

        all_feature_views = self._registry.list_feature_views(
            project=self.config.project
        )
        feature_views = _get_requested_feature_views(feature_refs, all_feature_views)
        provider = self._get_provider()
        job = provider.get_historical_features(
            self.config,
            feature_views,
            feature_refs,
            entity_df,
            self._registry,
            self.project,
        )
        return job

    def materialize_incremental(
        self, end_date: datetime, feature_views: Optional[List[str]] = None,
    ) -> None:
        """
        Materialize incremental new data from the offline store into the online store.

        This method incrementally loads new feature data up to the specified end time, from either
        the specified feature views or all feature views if none are specified,
        into the online store where it is available for online serving. The start time of
        the materialized interval is either the most recent end time of a prior materialization or
        (now - ttl) if no such prior materialization exists.

        Args:
            end_date (datetime): End date for time range of data to materialize into the online store
            feature_views (List[str]): Optional list of feature view names. If selected, will only run
                materialization for the specified feature views.

        Examples:
            Materialize all features into the online store up to 5 minutes ago.
            >>> from datetime import datetime, timedelta
            >>> from feast.feature_store import FeatureStore
            >>>
            >>> fs = FeatureStore(config=RepoConfig(provider="gcp", registry="gs://my-fs/", project="my_fs_proj"))
            >>> fs.materialize_incremental(end_date=datetime.utcnow() - timedelta(minutes=5))
        """
        self._tele.log("materialize_incremental")

        feature_views_to_materialize = []
        if feature_views is None:
            feature_views_to_materialize = self._registry.list_feature_views(
                self.config.project
            )
        else:
            for name in feature_views:
                feature_view = self._registry.get_feature_view(
                    name, self.config.project
                )
                feature_views_to_materialize.append(feature_view)

        # TODO paging large loads
        for feature_view in feature_views_to_materialize:
            start_date = feature_view.most_recent_end_time
            if start_date is None:
                if feature_view.ttl is None:
                    raise Exception(
                        f"No start time found for feature view {feature_view.name}. materialize_incremental() requires either a ttl to be set or for materialize() to have been run at least once."
                    )
                start_date = datetime.utcnow() - feature_view.ttl
            provider = self._get_provider()
            provider.materialize_single_feature_view(
                feature_view, start_date, end_date, self._registry, self.project
            )

    def materialize(
        self,
        start_date: datetime,
        end_date: datetime,
        feature_views: Optional[List[str]] = None,
    ) -> None:
        """
        Materialize data from the offline store into the online store.

        This method loads feature data in the specified interval from either
        the specified feature views, or all feature views if none are specified,
        into the online store where it is available for online serving.

        Args:
            start_date (datetime): Start date for time range of data to materialize into the online store
            end_date (datetime): End date for time range of data to materialize into the online store
            feature_views (List[str]): Optional list of feature view names. If selected, will only run
                materialization for the specified feature views.

        Examples:
            Materialize all features into the online store over the interval
            from 3 hours ago to 10 minutes ago.
            >>> from datetime import datetime, timedelta
            >>> from feast.feature_store import FeatureStore
            >>>
            >>> fs = FeatureStore(config=RepoConfig(provider="gcp"))
            >>> fs.materialize(
            >>>   start_date=datetime.utcnow() - timedelta(hours=3), end_date=datetime.utcnow() - timedelta(minutes=10)
            >>> )
        """
        self._tele.log("materialize")

        feature_views_to_materialize = []
        if feature_views is None:
            feature_views_to_materialize = self._registry.list_feature_views(
                self.config.project
            )
        else:
            for name in feature_views:
                feature_view = self._registry.get_feature_view(
                    name, self.config.project
                )
                feature_views_to_materialize.append(feature_view)

        # TODO paging large loads
        for feature_view in feature_views_to_materialize:
            provider = self._get_provider()
            provider.materialize_single_feature_view(
                feature_view, start_date, end_date, self._registry, self.project
            )

    def get_online_features(
        self, feature_refs: List[str], entity_rows: List[Dict[str, Any]],
    ) -> OnlineResponse:
        """
        Retrieves the latest online feature data.

        Note: This method will download the full feature registry the first time it is run. If you are using a
        remote registry like GCS or S3 then that may take a few seconds. The registry remains cached up to a TTL
        duration (which can be set to infinity). If the cached registry is stale (more time than the TTL has
        passed), then a new registry will be downloaded synchronously by this method. This download may
        introduce latency to online feature retrieval. In order to avoid synchronous downloads, please call
        refresh_registry() prior to the TTL being reached. Remember it is possible to set the cache TTL to
        infinity (cache forever).

        Args:
            feature_refs: List of feature references that will be returned for each entity.
                Each feature reference should have the following format:
                "feature_table:feature" where "feature_table" & "feature" refer to
                the feature table and feature names respectively.
                Only the feature name is required.
            entity_rows: A list of dictionaries where each key-value is an entity-name, entity-value pair.
        Returns:
            OnlineResponse containing the feature data in records.
        Examples:
            >>> from feast import FeatureStore
            >>>
            >>> store = FeatureStore(repo_path="...")
            >>> feature_refs = ["sales:daily_transactions"]
            >>> entity_rows = [{"customer_id": 0},{"customer_id": 1}]
            >>>
            >>> online_response = store.get_online_features(
            >>>     feature_refs, entity_rows)
            >>> online_response_dict = online_response.to_dict()
            >>> print(online_response_dict)
            {'sales:daily_transactions': [1.1,1.2], 'sales:customer_id': [0,1]}
        """
        self._tele.log("get_online_features")

        provider = self._get_provider()
        entities = self.list_entities(allow_cache=True)
        entity_name_to_join_key_map = {}
        for entity in entities:
            entity_name_to_join_key_map[entity.name] = entity.join_key

        join_key_rows = []
        for row in entity_rows:
            join_key_row = {}
            for entity_name, entity_value in row.items():
                try:
                    join_key = entity_name_to_join_key_map[entity_name]
                except KeyError:
                    raise Exception(
                        f"Entity {entity_name} does not exist in project {self.project}"
                    )
                join_key_row[join_key] = entity_value
            join_key_rows.append(join_key_row)

        entity_row_proto_list = _infer_online_entity_rows(join_key_rows)

        union_of_entity_keys = []
        result_rows: List[GetOnlineFeaturesResponse.FieldValues] = []

        for entity_row_proto in entity_row_proto_list:
            union_of_entity_keys.append(_entity_row_to_key(entity_row_proto))
            result_rows.append(_entity_row_to_field_values(entity_row_proto))

        all_feature_views = self._registry.list_feature_views(
            project=self.config.project, allow_cache=True
        )

        grouped_refs = _group_refs(feature_refs, all_feature_views)
        for table, requested_features in grouped_refs:
            entity_keys = _get_table_entity_keys(
                table, union_of_entity_keys, entity_name_to_join_key_map
            )
            read_rows = provider.online_read(
                project=self.project, table=table, entity_keys=entity_keys,
            )
            for row_idx, read_row in enumerate(read_rows):
                row_ts, feature_data = read_row
                result_row = result_rows[row_idx]

                if feature_data is None:
                    for feature_name in requested_features:
                        feature_ref = f"{table.name}__{feature_name}"
                        result_row.statuses[
                            feature_ref
                        ] = GetOnlineFeaturesResponse.FieldStatus.NOT_FOUND
                else:
                    for feature_name in feature_data:
                        feature_ref = f"{table.name}__{feature_name}"
                        if feature_name in requested_features:
                            result_row.fields[feature_ref].CopyFrom(
                                feature_data[feature_name]
                            )
                            result_row.statuses[
                                feature_ref
                            ] = GetOnlineFeaturesResponse.FieldStatus.PRESENT

        return OnlineResponse(GetOnlineFeaturesResponse(field_values=result_rows))
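
Pulling the documented methods together, a compact end-to-end sketch (definitions and paths are illustrative, borrowed from the docstring examples above):

from datetime import datetime, timedelta
from feast import Entity, Feature, FeatureView, FileSource, ValueType
from feast.feature_store import FeatureStore

fs = FeatureStore()  # default local RepoConfig, as in __init__ above
customer = Entity(name="customer", value_type=ValueType.INT64)
customer_fv = FeatureView(
    name="customer_fv",
    entities=["customer"],
    features=[Feature(name="age", dtype=ValueType.INT64)],
    input=FileSource(path="file.parquet", event_timestamp_column="timestamp"),
    ttl=timedelta(days=1),
)
fs.apply([customer, customer_fv])
fs.materialize_incremental(end_date=datetime.utcnow())
response = fs.get_online_features(["customer_fv:age"], [{"customer": 1}])
print(response.to_dict())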
Example #25
def test_historical_features_from_parquet_sources():
    start_date = datetime.now().replace(microsecond=0, second=0, minute=0)
    (
        customer_entities,
        driver_entities,
        end_date,
        orders_df,
        start_date,
    ) = generate_entities(start_date)

    with TemporaryDirectory() as temp_dir:
        driver_df = create_driver_hourly_stats_df(driver_entities, start_date,
                                                  end_date)
        driver_source = stage_driver_hourly_stats_parquet_source(
            temp_dir, driver_df)
        driver_fv = create_driver_hourly_stats_feature_view(driver_source)
        customer_df = create_customer_daily_profile_df(customer_entities,
                                                       start_date, end_date)
        customer_source = stage_customer_daily_profile_parquet_source(
            temp_dir, customer_df)
        customer_fv = create_customer_daily_profile_feature_view(
            customer_source)
        driver = Entity(name="driver",
                        value_type=ValueType.INT64,
                        description="")
        customer = Entity(name="customer",
                          value_type=ValueType.INT64,
                          description="")

        store = FeatureStore(config=RepoConfig(
            metadata_store=os.path.join(temp_dir, "metadata.db"),
            project="default",
            provider="local",
            online_store=OnlineStoreConfig(local=LocalOnlineStoreConfig(
                os.path.join(temp_dir, "online_store.db"))),
        ))

        store.apply([driver, customer, driver_fv, customer_fv])

        job = store.get_historical_features(
            entity_df=orders_df,
            feature_refs=[
                "driver_stats:conv_rate",
                "driver_stats:avg_daily_trips",
                "customer_profile:current_balance",
                "customer_profile:avg_passenger_count",
                "customer_profile:lifetime_trip_count",
            ],
        )

        actual_df = job.to_df()
        expected_df = get_expected_training_df(
            customer_df,
            customer_fv,
            driver_df,
            driver_fv,
            orders_df,
        )
        assert_frame_equal(
            expected_df.sort_values(by=[
                ENTITY_DF_EVENT_TIMESTAMP_COL,
                "order_id",
                "driver_id",
                "customer_id",
            ]).reset_index(drop=True),
            actual_df.sort_values(by=[
                ENTITY_DF_EVENT_TIMESTAMP_COL,
                "order_id",
                "driver_id",
                "customer_id",
            ]).reset_index(drop=True),
        )
Example #26
def test_historical_features_from_bigquery_sources(
    provider_type, infer_event_timestamp_col
):
    start_date = datetime.now().replace(microsecond=0, second=0, minute=0)
    (
        customer_entities,
        driver_entities,
        end_date,
        orders_df,
        start_date,
    ) = generate_entities(start_date, infer_event_timestamp_col)

    # bigquery_dataset = "test_hist_retrieval_static"
    bigquery_dataset = (
        f"test_hist_retrieval_{int(time.time_ns())}_{random.randint(1000, 9999)}"
    )

    with BigQueryDataSet(bigquery_dataset), TemporaryDirectory() as temp_dir:
        gcp_project = bigquery.Client().project

        # Orders Query
        table_id = f"{bigquery_dataset}.orders"
        stage_orders_bigquery(orders_df, table_id)
        entity_df_query = f"SELECT * FROM {gcp_project}.{table_id}"

        # Driver Feature View
        driver_df = driver_data.create_driver_hourly_stats_df(
            driver_entities, start_date, end_date
        )
        driver_table_id = f"{gcp_project}.{bigquery_dataset}.driver_hourly"
        stage_driver_hourly_stats_bigquery_source(driver_df, driver_table_id)
        driver_source = BigQuerySource(
            table_ref=driver_table_id,
            event_timestamp_column="datetime",
            created_timestamp_column="created",
        )
        driver_fv = create_driver_hourly_stats_feature_view(driver_source)

        # Customer Feature View
        customer_df = driver_data.create_customer_daily_profile_df(
            customer_entities, start_date, end_date
        )
        customer_table_id = f"{gcp_project}.{bigquery_dataset}.customer_profile"

        stage_customer_daily_profile_bigquery_source(customer_df, customer_table_id)
        customer_source = BigQuerySource(
            table_ref=customer_table_id,
            event_timestamp_column="datetime",
            created_timestamp_column="",
        )
        customer_fv = create_customer_daily_profile_feature_view(customer_source)

        driver = Entity(name="driver", join_key="driver_id", value_type=ValueType.INT64)
        customer = Entity(name="customer_id", value_type=ValueType.INT64)

        if provider_type == "local":
            store = FeatureStore(
                config=RepoConfig(
                    registry=os.path.join(temp_dir, "registry.db"),
                    project="default",
                    provider="local",
                    online_store=SqliteOnlineStoreConfig(
                        path=os.path.join(temp_dir, "online_store.db"),
                    ),
                    offline_store=BigQueryOfflineStoreConfig(type="bigquery",),
                )
            )
        elif provider_type == "gcp":
            store = FeatureStore(
                config=RepoConfig(
                    registry=os.path.join(temp_dir, "registry.db"),
                    project="".join(
                        random.choices(string.ascii_uppercase + string.digits, k=10)
                    ),
                    provider="gcp",
                    offline_store=BigQueryOfflineStoreConfig(type="bigquery",),
                )
            )
        elif provider_type == "gcp_custom_offline_config":
            store = FeatureStore(
                config=RepoConfig(
                    registry=os.path.join(temp_dir, "registry.db"),
                    project="".join(
                        random.choices(string.ascii_uppercase + string.digits, k=10)
                    ),
                    provider="gcp",
                    offline_store=BigQueryOfflineStoreConfig(
                        type="bigquery", dataset="foo"
                    ),
                )
            )
        else:
            raise Exception("Invalid provider used as part of test configuration")

        store.apply([driver, customer, driver_fv, customer_fv])

        event_timestamp = (
            DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
            if DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL in orders_df.columns
            else "e_ts"
        )
        expected_df = get_expected_training_df(
            customer_df, customer_fv, driver_df, driver_fv, orders_df, event_timestamp,
        )

        job_from_sql = store.get_historical_features(
            entity_df=entity_df_query,
            feature_refs=[
                "driver_stats:conv_rate",
                "driver_stats:avg_daily_trips",
                "customer_profile:current_balance",
                "customer_profile:avg_passenger_count",
                "customer_profile:lifetime_trip_count",
            ],
        )

        actual_df_from_sql_entities = job_from_sql.to_df()

        assert_frame_equal(
            expected_df.sort_values(
                by=[event_timestamp, "order_id", "driver_id", "customer_id"]
            ).reset_index(drop=True),
            actual_df_from_sql_entities.sort_values(
                by=[event_timestamp, "order_id", "driver_id", "customer_id"]
            ).reset_index(drop=True),
            check_dtype=False,
        )

        job_from_df = store.get_historical_features(
            entity_df=orders_df,
            feature_refs=[
                "driver_stats:conv_rate",
                "driver_stats:avg_daily_trips",
                "customer_profile:current_balance",
                "customer_profile:avg_passenger_count",
                "customer_profile:lifetime_trip_count",
            ],
        )

        if provider_type == "gcp_custom_offline_config":
            # Make sure that custom dataset name is being used from the offline_store config
            assertpy.assert_that(job_from_df.query).contains("foo.entity_df")
        else:
            # If the custom dataset name isn't provided in the config, use default `feast` name
            assertpy.assert_that(job_from_df.query).contains("feast.entity_df")

        actual_df_from_df_entities = job_from_df.to_df()

        assert_frame_equal(
            expected_df.sort_values(
                by=[event_timestamp, "order_id", "driver_id", "customer_id"]
            ).reset_index(drop=True),
            actual_df_from_df_entities.sort_values(
                by=[event_timestamp, "order_id", "driver_id", "customer_id"]
            ).reset_index(drop=True),
            check_dtype=False,
        )
Example #27
def test_historical_features_from_parquet_sources(infer_event_timestamp_col):
    start_date = datetime.now().replace(microsecond=0, second=0, minute=0)
    (
        customer_entities,
        driver_entities,
        end_date,
        orders_df,
        start_date,
    ) = generate_entities(start_date, infer_event_timestamp_col)

    with TemporaryDirectory() as temp_dir:
        driver_df = driver_data.create_driver_hourly_stats_df(
            driver_entities, start_date, end_date
        )
        driver_source = stage_driver_hourly_stats_parquet_source(temp_dir, driver_df)
        driver_fv = create_driver_hourly_stats_feature_view(driver_source)
        customer_df = driver_data.create_customer_daily_profile_df(
            customer_entities, start_date, end_date
        )
        customer_source = stage_customer_daily_profile_parquet_source(
            temp_dir, customer_df
        )
        customer_fv = create_customer_daily_profile_feature_view(customer_source)
        driver = Entity(name="driver", join_key="driver_id", value_type=ValueType.INT64)
        customer = Entity(name="customer_id", value_type=ValueType.INT64)

        store = FeatureStore(
            config=RepoConfig(
                registry=os.path.join(temp_dir, "registry.db"),
                project="default",
                provider="local",
                online_store=SqliteOnlineStoreConfig(
                    path=os.path.join(temp_dir, "online_store.db")
                ),
            )
        )

        store.apply([driver, customer, driver_fv, customer_fv])

        job = store.get_historical_features(
            entity_df=orders_df,
            feature_refs=[
                "driver_stats:conv_rate",
                "driver_stats:avg_daily_trips",
                "customer_profile:current_balance",
                "customer_profile:avg_passenger_count",
                "customer_profile:lifetime_trip_count",
            ],
        )

        actual_df = job.to_df()
        event_timestamp = (
            DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
            if DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL in orders_df.columns
            else "e_ts"
        )
        expected_df = get_expected_training_df(
            customer_df, customer_fv, driver_df, driver_fv, orders_df, event_timestamp,
        )
        assert_frame_equal(
            expected_df.sort_values(
                by=[event_timestamp, "order_id", "driver_id", "customer_id"]
            ).reset_index(drop=True),
            actual_df.sort_values(
                by=[event_timestamp, "order_id", "driver_id", "customer_id"]
            ).reset_index(drop=True),
        )
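
# The staging helpers above are test utilities; a rough sketch of what
# stage_driver_hourly_stats_parquet_source might look like. The helper name
# here is hypothetical and the timestamp column names are assumed from the
# BigQuery example below (imports as in the surrounding examples):
def _stage_driver_stats_parquet(directory: str, df: pd.DataFrame) -> FileSource:
    path = os.path.join(directory, "driver_stats.parquet")
    df.to_parquet(path)  # stage the dataframe as a local parquet file
    return FileSource(
        file_format=ParquetFormat(),
        file_url=f"file://{path}",
        event_timestamp_column="datetime",
        created_timestamp_column="created",
    )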
Example #28
0
    def test_bigquery_ingestion_correctness(self):
        # create dataset
        ts = pd.Timestamp.now(tz="UTC").round("ms")
        checked_value = (
            random.random()
        )  # random value so the test cannot pass if nothing was written to the online store
        data = {
            "id": [1, 2, 1],
            "value": [0.1, 0.2, checked_value],
            "ts_1": [ts - timedelta(minutes=2), ts, ts],
            "created_ts": [ts, ts, ts],
        }
        df = pd.DataFrame.from_dict(data)

        # load dataset into BigQuery
        job_config = bigquery.LoadJobConfig()
        table_id = (
            f"{self.gcp_project}.{self.bigquery_dataset}.correctness_{int(time.time())}"
        )
        job = self.client.load_table_from_dataframe(df,
                                                    table_id,
                                                    job_config=job_config)
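        # block until the BigQuery load job finishes before building on the table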
        job.result()

        # create FeatureView
        fv = FeatureView(
            name="test_bq_correctness",
            entities=["driver_id"],
            features=[Feature("value", ValueType.FLOAT)],
            ttl=timedelta(minutes=5),
            input=BigQuerySource(
                event_timestamp_column="ts",
                table_ref=table_id,
                created_timestamp_column="created_ts",
                field_mapping={
                    "ts_1": "ts",
                    "id": "driver_id"
                },
                date_partition_column="",
            ),
        )
        config = RepoConfig(
            registry="./registry.db",
            project="default",
            provider="gcp",
            online_store=OnlineStoreConfig(
                local=LocalOnlineStoreConfig(path="online_store.db")
            ),
        )
        fs = FeatureStore(config=config)
        fs.apply([fv])

        # run materialize() over the last five minutes of data
        fs.materialize(
            ["test_bq_correctness"],
            datetime.utcnow() - timedelta(minutes=5),
            datetime.utcnow(),
        )

        # check result of materialize()
        entity_key = EntityKeyProto(entity_names=["driver_id"],
                                    entity_values=[ValueProto(int64_val=1)])
        t, val = fs._get_provider().online_read("default", fv, entity_key)
        assert abs(val["value"].double_val - checked_value) < 1e-6
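
        # A rough public-API equivalent of the internal check above; the
        # get_online_features call shape and the returned key naming are
        # assumptions that vary across Feast versions:
        online = fs.get_online_features(
            feature_refs=["test_bq_correctness:value"],
            entity_rows=[{"driver_id": 1}],
        ).to_dict()
        # online["test_bq_correctness__value"][0] should be close to checked_value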
Example #29
0
def test_historical_features_from_bigquery_sources():
    start_date = datetime.now().replace(microsecond=0, second=0, minute=0)
    (
        customer_entities,
        driver_entities,
        end_date,
        orders_df,
        start_date,
    ) = generate_entities(start_date)

    # A static dataset name can be substituted here to reuse results across runs:
    # bigquery_dataset = "test_hist_retrieval_static"
    bigquery_dataset = f"test_hist_retrieval_{int(time.time())}"

    with BigQueryDataSet(bigquery_dataset), TemporaryDirectory() as temp_dir:
        gcp_project = bigquery.Client().project

        # Orders Query
        table_id = f"{bigquery_dataset}.orders"
        stage_orders_bigquery(orders_df, table_id)
        entity_df_query = f"SELECT * FROM {gcp_project}.{table_id}"

        # Driver Feature View
        driver_df = driver_data.create_driver_hourly_stats_df(
            driver_entities, start_date, end_date)
        driver_table_id = f"{gcp_project}.{bigquery_dataset}.driver_hourly"
        stage_driver_hourly_stats_bigquery_source(driver_df, driver_table_id)
        driver_source = BigQuerySource(
            table_ref=driver_table_id,
            event_timestamp_column="datetime",
            created_timestamp_column="created",
        )
        driver_fv = create_driver_hourly_stats_feature_view(driver_source)

        # Customer Feature View
        customer_df = driver_data.create_customer_daily_profile_df(
            customer_entities, start_date, end_date)
        customer_table_id = f"{gcp_project}.{bigquery_dataset}.customer_profile"

        stage_customer_daily_profile_bigquery_source(customer_df,
                                                     customer_table_id)
        customer_source = BigQuerySource(
            table_ref=customer_table_id,
            event_timestamp_column="datetime",
            created_timestamp_column="created",
        )
        customer_fv = create_customer_daily_profile_feature_view(
            customer_source)

        driver = Entity(name="driver", value_type=ValueType.INT64)
        customer = Entity(name="customer", value_type=ValueType.INT64)

        store = FeatureStore(config=RepoConfig(
            registry=os.path.join(temp_dir, "registry.db"),
            project="default",
            provider="gcp",
            online_store=OnlineStoreConfig(
                local=LocalOnlineStoreConfig(
                    path=os.path.join(temp_dir, "online_store.db")
                )
            ),
        ))
        store.apply([driver, customer, driver_fv, customer_fv])

        expected_df = get_expected_training_df(
            customer_df,
            customer_fv,
            driver_df,
            driver_fv,
            orders_df,
            ENTITY_DF_EVENT_TIMESTAMP_COL,
        )

        job_from_sql = store.get_historical_features(
            entity_df=entity_df_query,
            feature_refs=[
                "driver_stats:conv_rate",
                "driver_stats:avg_daily_trips",
                "customer_profile:current_balance",
                "customer_profile:avg_passenger_count",
                "customer_profile:lifetime_trip_count",
            ],
        )
        actual_df_from_sql_entities = job_from_sql.to_df()

        assert_frame_equal(
            expected_df.sort_values(by=[
                ENTITY_DF_EVENT_TIMESTAMP_COL,
                "order_id",
                "driver_id",
                "customer_id",
            ]).reset_index(drop=True),
            actual_df_from_sql_entities.sort_values(by=[
                ENTITY_DF_EVENT_TIMESTAMP_COL,
                "order_id",
                "driver_id",
                "customer_id",
            ]).reset_index(drop=True),
            check_dtype=False,
        )

        job_from_df = store.get_historical_features(
            entity_df=orders_df,
            feature_refs=[
                "driver_stats:conv_rate",
                "driver_stats:avg_daily_trips",
                "customer_profile:current_balance",
                "customer_profile:avg_passenger_count",
                "customer_profile:lifetime_trip_count",
            ],
        )
        actual_df_from_df_entities = job_from_df.to_df()

        assert_frame_equal(
            expected_df.sort_values(by=[
                ENTITY_DF_EVENT_TIMESTAMP_COL,
                "order_id",
                "driver_id",
                "customer_id",
            ]).reset_index(drop=True),
            actual_df_from_df_entities.sort_values(by=[
                ENTITY_DF_EVENT_TIMESTAMP_COL,
                "order_id",
                "driver_id",
                "customer_id",
            ]).reset_index(drop=True),
            check_dtype=False,
        )
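
# BigQueryDataSet above is a test-scoped helper; a minimal sketch of such a
# context manager (name and exact behaviour assumed, not taken from this
# example):
from contextlib import contextmanager

@contextmanager
def _bigquery_dataset(dataset_name: str):
    client = bigquery.Client()
    dataset_id = f"{client.project}.{dataset_name}"
    client.create_dataset(dataset_id, exists_ok=True)  # create the scratch dataset
    try:
        yield dataset_id
    finally:
        # drop the dataset and everything staged into it during the test
        client.delete_dataset(dataset_id, delete_contents=True, not_found_ok=True)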
Example #30
0
def apply_total(repo_config: RepoConfig, repo_path: Path,
                skip_source_validation: bool):
    from colorama import Fore, Style

    os.chdir(repo_path)
    registry_config = repo_config.get_registry_config()
    project = repo_config.project
    if not is_valid_name(project):
        print(
            f"{project} is not a valid project name. Project names may contain "
            f"only alphanumeric characters and underscores, and may not start with an underscore."
        )
        sys.exit(1)
    registry = Registry(
        registry_path=registry_config.path,
        repo_path=repo_path,
        cache_ttl=timedelta(seconds=registry_config.cache_ttl_seconds),
    )
    registry._initialize_registry()
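    # avoid writing .pyc files while importing the user's repo definition modules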
    sys.dont_write_bytecode = True
    repo = parse_repo(repo_path)
    _validate_feature_views(repo.feature_views)
    data_sources = [t.batch_source for t in repo.feature_views]

    if not skip_source_validation:
        # Make sure the data sources used by these feature views are supported by Feast
        for data_source in data_sources:
            data_source.validate(repo_config)

    # Make inferences
    update_entities_with_inferred_types_from_feature_views(
        repo.entities, repo.feature_views, repo_config)
    update_data_sources_with_inferred_event_timestamp_col(
        data_sources, repo_config)
    for view in repo.feature_views:
        view.infer_features_from_batch_source(repo_config)

    repo_table_names = set(t.name for t in repo.feature_tables)

    for t in repo.feature_views:
        repo_table_names.add(t.name)

    tables_to_delete = []
    for registry_table in registry.list_feature_tables(project=project):
        if registry_table.name not in repo_table_names:
            tables_to_delete.append(registry_table)

    views_to_delete = []
    for registry_view in registry.list_feature_views(project=project):
        if registry_view.name not in repo_table_names:
            views_to_delete.append(registry_view)

    sys.dont_write_bytecode = False
    for entity in repo.entities:
        registry.apply_entity(entity, project=project, commit=False)
        click.echo(
            f"Registered entity {Style.BRIGHT + Fore.GREEN}{entity.name}{Style.RESET_ALL}"
        )

    # Delete tables that should not exist
    for registry_table in tables_to_delete:
        registry.delete_feature_table(registry_table.name,
                                      project=project,
                                      commit=False)
        click.echo(
            f"Deleted feature table {Style.BRIGHT + Fore.GREEN}{registry_table.name}{Style.RESET_ALL} from registry"
        )

    # Create tables that should exist
    for table in repo.feature_tables:
        registry.apply_feature_table(table, project, commit=False)
        click.echo(
            f"Registered feature table {Style.BRIGHT + Fore.GREEN}{table.name}{Style.RESET_ALL}"
        )

    # Delete views that should not exist
    for registry_view in views_to_delete:
        registry.delete_feature_view(registry_view.name,
                                     project=project,
                                     commit=False)
        click.echo(
            f"Deleted feature view {Style.BRIGHT + Fore.GREEN}{registry_view.name}{Style.RESET_ALL} from registry"
        )

    # Create views that should exist
    for view in repo.feature_views:
        registry.apply_feature_view(view, project, commit=False)
        click.echo(
            f"Registered feature view {Style.BRIGHT + Fore.GREEN}{view.name}{Style.RESET_ALL}"
        )
    registry.commit()

    apply_feature_services(registry, project, repo)

    infra_provider = get_provider(repo_config, repo_path)

    all_to_delete: List[Union[FeatureTable, FeatureView]] = []
    all_to_delete.extend(tables_to_delete)
    all_to_delete.extend(views_to_delete)

    all_to_keep: List[Union[FeatureTable, FeatureView]] = []
    all_to_keep.extend(repo.feature_tables)
    all_to_keep.extend(repo.feature_views)

    entities_to_delete: List[Entity] = []
    repo_entities_names = {e.name for e in repo.entities}
    for registry_entity in registry.list_entities(project=project):
        if registry_entity.name not in repo_entities_names:
            entities_to_delete.append(registry_entity)

    entities_to_keep: List[Entity] = repo.entities

    for name in [table.name for table in repo.feature_tables
                 ] + [view.name for view in repo.feature_views]:
        click.echo(
            f"Deploying infrastructure for {Style.BRIGHT + Fore.GREEN}{name}{Style.RESET_ALL}"
        )
    for name in [view.name for view in views_to_delete
                 ] + [table.name for table in tables_to_delete]:
        click.echo(
            f"Removing infrastructure for {Style.BRIGHT + Fore.GREEN}{name}{Style.RESET_ALL}"
        )

    infra_provider.update_infra(
        project,
        tables_to_delete=all_to_delete,
        tables_to_keep=all_to_keep,
        entities_to_delete=entities_to_delete,
        entities_to_keep=entities_to_keep,
        partial=False,
    )
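
# A hypothetical invocation of apply_total, e.g. from a CLI entry point; this
# assumes load_repo_config resolves feature_store.yaml under repo_path
# (imports as in the surrounding examples):
repo_path = Path(".").absolute()
repo_config = load_repo_config(repo_path)
apply_total(repo_config, repo_path, skip_source_validation=False)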