def __init__(
    self, repo_path: Optional[str] = None, config: Optional[RepoConfig] = None,
):
    """Build a feature store from an explicit config, a repo directory, or defaults.

    Args:
        repo_path: Directory of a feature repo; its configuration is read with
            ``load_repo_config``. Mutually exclusive with ``config``.
        config: A ready-made ``RepoConfig``. Mutually exclusive with ``repo_path``.
            When neither argument is given, a local config using
            ``./registry.db`` and ``online_store.db`` is used.

    Raises:
        ValueError: If both ``repo_path`` and ``config`` are supplied.
    """
    self.repo_path = repo_path
    both_given = repo_path is not None and config is not None
    if both_given:
        raise ValueError("You cannot specify both repo_path and config")

    if config is None and repo_path is None:
        # Neither source given: fall back to a purely local setup.
        config = RepoConfig(
            registry="./registry.db",
            project="default",
            provider="local",
            online_store=OnlineStoreConfig(
                local=LocalOnlineStoreConfig(path="online_store.db")
            ),
        )
    elif config is None:
        config = load_repo_config(Path(repo_path))
    self.config = config

    reg_cfg = self.config.get_registry_config()
    self._registry = Registry(
        registry_path=reg_cfg.path,
        cache_ttl=timedelta(seconds=reg_cfg.cache_ttl_seconds),
    )
    self._tele = Telemetry()
def prep_local_fs_and_fv() -> Iterator[Tuple[FeatureStore, FeatureView]]:
    """Yield a local-provider FeatureStore with a single parquet-backed feature view.

    The rows from ``create_dataset()`` are written to a parquet file that backs a
    ``FileSource``. The registry and the local online store are placed in their own
    temporary directories. All temporary files are removed once the generator is
    resumed or closed after the ``yield``.

    Yields:
        ``(feature_store, feature_view)`` with the feature view already applied.
    """
    # Stage the parquet file inside a TemporaryDirectory. With a
    # NamedTemporaryFile, close() deletes the file, and the copy that
    # to_parquet() re-creates at the same path is never cleaned up.
    with tempfile.TemporaryDirectory() as source_dir_name:
        parquet_path = Path(source_dir_name) / "data.parquet"
        df = create_dataset()
        df.to_parquet(str(parquet_path))

        file_source = FileSource(
            file_format=ParquetFormat(),
            file_url=f"file://{parquet_path}",
            event_timestamp_column="ts",
            created_timestamp_column="created_ts",
            date_partition_column="",
            field_mapping={"ts_1": "ts", "id": "driver_id"},
        )
        fv = get_feature_view(file_source)

        with tempfile.TemporaryDirectory() as repo_dir_name, tempfile.TemporaryDirectory() as data_dir_name:
            config = RepoConfig(
                registry=str(Path(repo_dir_name) / "registry.db"),
                project=f"test_bq_correctness_{str(uuid.uuid4()).replace('-', '')}",
                provider="local",
                online_store=OnlineStoreConfig(
                    local=LocalOnlineStoreConfig(
                        path=str(Path(data_dir_name) / "online_store.db")
                    )
                ),
            )
            fs = FeatureStore(config=config)
            fs.apply([fv])

            yield fs, fv
def prep_dynamodb_fs_and_fv() -> Iterator[Tuple[FeatureStore, FeatureView]]:
    """Yield an AWS-provider FeatureStore (DynamoDB online store, file offline store).

    The rows from ``create_dataset()`` are written to a parquet file that backs a
    ``FileSource``. A ``driver`` entity and the feature view are applied to a
    registry kept in a temporary directory. The local temporary files are removed
    once the generator is resumed or closed after the ``yield``.

    Yields:
        ``(feature_store, feature_view)`` with the entity and feature view applied.
    """
    # Stage the parquet file inside a TemporaryDirectory. With a
    # NamedTemporaryFile, close() deletes the file, and the copy that
    # to_parquet() re-creates at the same path is never cleaned up.
    with tempfile.TemporaryDirectory() as source_dir_name:
        parquet_path = Path(source_dir_name) / "data.parquet"
        df = create_dataset()
        df.to_parquet(str(parquet_path))

        file_source = FileSource(
            file_format=ParquetFormat(),
            file_url=f"file://{parquet_path}",
            event_timestamp_column="ts",
            created_timestamp_column="created_ts",
            date_partition_column="",
            field_mapping={"ts_1": "ts", "id": "driver_id"},
        )
        fv = get_feature_view(file_source)
        e = Entity(
            name="driver",
            description="id for driver",
            join_key="driver_id",
            value_type=ValueType.INT32,
        )

        with tempfile.TemporaryDirectory() as repo_dir_name:
            config = RepoConfig(
                registry=str(Path(repo_dir_name) / "registry.db"),
                project=f"test_bq_correctness_{str(uuid.uuid4()).replace('-', '')}",
                provider="aws",
                online_store=DynamoDBOnlineStoreConfig(region="us-west-2"),
                offline_store=FileOfflineStoreConfig(),
            )
            fs = FeatureStore(config=config)
            fs.apply([fv, e])

            yield fs, fv
def repo_config():
    """Return a RepoConfig built from the module-level REGISTRY/PROJECT/PROVIDER/REGION.

    The online store is DynamoDB in ``REGION``; the offline store is file based.
    """
    online = DynamoDBOnlineStoreConfig(region=REGION)
    offline = FileOfflineStoreConfig()
    return RepoConfig(
        registry=REGISTRY,
        project=PROJECT,
        provider=PROVIDER,
        online_store=online,
        offline_store=offline,
    )
def registry_dump(repo_config: RepoConfig, repo_path: Path):
    """Debugging aid: echo the project's registry contents as indented, key-sorted JSON."""
    registry = Registry(
        registry_config=repo_config.get_registry_config(), repo_path=repo_path
    )
    contents = registry.to_dict(project=repo_config.project)
    click.echo(json.dumps(contents, indent=2, sort_keys=True))
def test_bigquery_query_to_datastore_correctness(self):
    """Materialize a query-backed BigQuery source and check the newest online value.

    Two rows exist for driver 1 (values 0.1 and 0.3). The online lookup is
    expected to return 0.3.
    """
    now = pd.Timestamp.now(tz="UTC").round("ms")
    df = pd.DataFrame.from_dict(
        {
            "id": [1, 2, 1],
            "value": [0.1, 0.2, 0.3],
            "ts_1": [now - timedelta(minutes=2), now, now],
            "created_ts": [now, now, now],
        }
    )

    # Stage the rows in a fresh BigQuery table; the source reads it via a query.
    load_config = bigquery.LoadJobConfig()
    table_id = f"{self.gcp_project}.{self.bigquery_dataset}.query_correctness_{int(time.time())}"
    query = f"SELECT * FROM `{table_id}`"
    self.client.load_table_from_dataframe(
        df, table_id, job_config=load_config
    ).result()

    source = BigQuerySource(
        event_timestamp_column="ts",
        created_timestamp_column="created_ts",
        field_mapping={"ts_1": "ts", "id": "driver_id"},
        date_partition_column="",
        query=query,
    )
    fv = FeatureView(
        name="test_bq_query_correctness",
        entities=["driver_id"],
        features=[Feature("value", ValueType.FLOAT)],
        ttl=timedelta(minutes=5),
        input=source,
    )

    fs = FeatureStore(
        config=RepoConfig(
            metadata_store="./metadata.db",
            project=f"test_bq_query_correctness_{int(time.time())}",
            provider="gcp",
        )
    )
    fs.apply([fv])

    # Materialize the last five minutes of data.
    fs.materialize(
        [fv.name],
        datetime.utcnow() - timedelta(minutes=5),
        datetime.utcnow() - timedelta(minutes=0),
    )

    online = fs.get_online_features([f"{fv.name}:value"], [{"driver_id": 1}])
    values = online.to_dict()[f"{fv.name}:value"]
    assert abs(values[0] - 0.3) < 1e-6
def feature_store_with_s3_registry():
    """Return an AWS-provider FeatureStore whose registry lives under a unique S3 key.

    The key embeds the current time in milliseconds so concurrent runs do not collide.
    """
    registry_uri = (
        "s3://feast-integration-tests/registries/"
        f"{int(time.time() * 1000)}/registry.db"
    )
    config = RepoConfig(
        registry=registry_uri,
        project="default",
        provider="aws",
        online_store=DynamoDBOnlineStoreConfig(region="us-west-2"),
        offline_store=FileOfflineStoreConfig(),
    )
    return FeatureStore(config=config)
def test_apply_remote_repo():
    """Return a local-provider FeatureStore backed by fresh temporary files.

    ``mkstemp()`` creates the registry file and the SQLite online-store file.
    Only their paths are kept.

    Returns:
        A ``FeatureStore`` for project ``"default"``.
    """
    import os

    # mkstemp() returns an already-open OS file descriptor. Only the path is
    # needed, so close each descriptor right away instead of leaking it.
    registry_fd, registry_path = mkstemp()
    os.close(registry_fd)
    online_store_fd, online_store_path = mkstemp()
    os.close(online_store_fd)

    return FeatureStore(
        config=RepoConfig(
            registry=registry_path,
            project="default",
            provider="local",
            online_store=SqliteOnlineStoreConfig(path=online_store_path),
        )
    )
def feature_store_with_local_registry():
    """Return a local-provider FeatureStore backed by fresh temporary files.

    ``mkstemp()`` creates the registry file and the SQLite online-store file.
    Only their paths are kept.

    Returns:
        A ``FeatureStore`` for project ``"default"``.
    """
    import os

    # mkstemp() returns an already-open OS file descriptor. Only the path is
    # needed, so close each descriptor right away instead of leaking it.
    registry_fd, registry_path = mkstemp()
    os.close(registry_fd)
    online_store_fd, online_store_path = mkstemp()
    os.close(online_store_fd)

    return FeatureStore(
        config=RepoConfig(
            registry=registry_path,
            project="default",
            provider="local",
            online_store=SqliteOnlineStoreConfig(path=online_store_path),
        )
    )
def registry_dump(repo_config: RepoConfig, repo_path: Path):
    """Debugging aid: print every entity, then every feature view, in the project's registry."""
    project = repo_config.project
    registry = Registry(
        registry_config=repo_config.get_registry_config(), repo_path=repo_path
    )
    # Entities first, then feature views, one object per line.
    for list_objects in (registry.list_entities, registry.list_feature_views):
        for obj in list_objects(project=project):
            print(obj)
def feature_store_with_local_registry(self):
    """Return a local-provider FeatureStore backed by fresh temporary files.

    ``mkstemp()`` creates the metadata-store file and the local online-store file.
    Only their paths are kept.

    Returns:
        A ``FeatureStore`` for project ``"default"``.
    """
    import os

    # mkstemp() returns an already-open OS file descriptor. Only the path is
    # needed, so close each descriptor right away instead of leaking it.
    registry_fd, registry_path = mkstemp()
    os.close(registry_fd)
    online_store_fd, online_store_path = mkstemp()
    os.close(online_store_fd)

    return FeatureStore(
        config=RepoConfig(
            metadata_store=registry_path,
            project="default",
            provider="local",
            online_store=OnlineStoreConfig(
                local=LocalOnlineStoreConfig(path=online_store_path)
            ),
        )
    )
def prep_bq_fs_and_fv(
    bq_source_type: str,
) -> Iterator[Tuple[FeatureStore, FeatureView]]:
    """Yield a GCP FeatureStore whose feature view reads a freshly loaded BigQuery table.

    Args:
        bq_source_type: ``"table"`` to reference the staged table directly, or
            ``"query"`` to read it through a ``SELECT *`` query.

    Yields:
        ``(feature_store, feature_view)``. ``fs.teardown()`` runs when the
        generator is resumed after the ``yield``.
    """
    client = bigquery.Client()
    gcp_project = client.project
    dataset_name = "test_ingestion"

    # Make sure the dataset exists and that its tables expire after two weeks.
    dataset = bigquery.Dataset(f"{gcp_project}.{dataset_name}")
    client.create_dataset(dataset, exists_ok=True)
    dataset.default_table_expiration_ms = 14 * 24 * 60 * 60 * 1000
    client.update_dataset(dataset, ["default_table_expiration_ms"])

    df = create_dataset()
    load_config = bigquery.LoadJobConfig()
    table_ref = f"{gcp_project}.{dataset_name}.{bq_source_type}_correctness_{int(time.time_ns())}"
    query = f"SELECT * FROM `{table_ref}`"
    load_job = client.load_table_from_dataframe(df, table_ref, job_config=load_config)
    load_job.result()

    use_table = bq_source_type == "table"
    use_query = bq_source_type == "query"
    bigquery_source = BigQuerySource(
        table_ref=table_ref if use_table else None,
        query=query if use_query else None,
        event_timestamp_column="ts",
        created_timestamp_column="created_ts",
        date_partition_column="",
        field_mapping={"ts_1": "ts", "id": "driver_id"},
    )
    fv = driver_feature_view(bigquery_source)
    driver_entity = Entity(
        name="driver",
        description="id for driver",
        join_key="driver_id",
        value_type=ValueType.INT32,
    )

    with tempfile.TemporaryDirectory() as repo_dir_name:
        project_name = f"test_bq_correctness_{str(uuid.uuid4()).replace('-', '')}"
        fs = FeatureStore(
            config=RepoConfig(
                registry=str(Path(repo_dir_name) / "registry.db"),
                project=project_name,
                provider="gcp",
                online_store=DatastoreOnlineStoreConfig(namespace="integration_test"),
            )
        )
        fs.apply([fv, driver_entity])

        yield fs, fv

        fs.teardown()
def teardown(repo_config: RepoConfig, repo_path: Path):
    """Ask the configured provider to tear down infra for every table and view in the project.

    Feature tables and feature views are collected from the registry and passed to
    ``teardown_infra``. ``repo_path`` is accepted but not used by this function.
    """
    reg_cfg = repo_config.get_registry_config()
    registry = Registry(
        registry_path=reg_cfg.path,
        cache_ttl=timedelta(seconds=reg_cfg.cache_ttl_seconds),
    )
    project = repo_config.project

    registry_tables: List[Union[FeatureTable, FeatureView]] = [
        *registry.list_feature_tables(project=project),
        *registry.list_feature_views(project=project),
    ]
    get_provider(repo_config).teardown_infra(project, tables=registry_tables)
def registry_dump(repo_config: RepoConfig):
    """Debugging aid: print every entity, then every feature table, in the project's registry."""
    reg_cfg = repo_config.get_registry_config()
    project = repo_config.project
    registry = Registry(
        registry_path=reg_cfg.path,
        cache_ttl=timedelta(seconds=reg_cfg.cache_ttl_seconds),
    )
    # Entities first, then feature tables, one object per line.
    for list_objects in (registry.list_entities, registry.list_feature_tables):
        for obj in list_objects(project=project):
            print(obj)
def benchmark_writes():
    """Benchmark ``online_write_batch`` on the GCP provider with synthetic driver stats.

    Steps:
    - Register a throwaway project (random 10-character suffix) in a temp directory.
    - Generate 14 days of hourly stats for 100 drivers and convert them to protos.
    - Write them through the provider with a tqdm progress bar.
    - Tear the project's infrastructure back down.
    """
    alphabet = string.ascii_lowercase + string.digits
    project_id = "test" + "".join(random.choice(alphabet) for _ in range(10))

    with tempfile.TemporaryDirectory() as temp_dir:
        store = FeatureStore(
            config=RepoConfig(
                registry=os.path.join(temp_dir, "registry.db"),
                project=project_id,
                provider="gcp",
            )
        )

        # The parquet path only gives the data source a value; nothing reads it here.
        parquet_path = os.path.join(temp_dir, "data.parquet")
        driver = Entity(name="driver_id", value_type=ValueType.INT64)
        source = create_driver_hourly_stats_source(parquet_path=parquet_path)
        table = create_driver_hourly_stats_feature_view(source)
        store.apply([table, driver])

        provider = store._get_provider()

        end_date = datetime.utcnow()
        start_date = end_date - timedelta(days=14)
        data = create_driver_hourly_stats_df(list(range(100)), start_date, end_date)
        print(data)  # shown for reference

        proto_data = _convert_arrow_to_proto(
            pa.Table.from_pandas(data), table, ["driver_id"]
        )

        with tqdm(total=len(proto_data)) as progress:
            provider.online_write_batch(
                project=store.project,
                table=table,
                data=proto_data,
                progress=progress.update,
            )

        provider.teardown_infra(
            store.project,
            tables=store.list_feature_views(),
            entities=store.list_entities(),
        )
def _build_feast_feature_store(self):
    """Create a Feast feature store from ``self.config`` using a custom Flyte provider.

    S3 endpoint and credentials from the ``aws`` settings are first exported as
    environment variables. The registry is stored at
    ``s3://<s3_bucket>/<registry_path>`` and the online store is SQLite.
    """
    env_settings = {
        "FEAST_S3_ENDPOINT_URL": aws.S3_ENDPOINT,
        "AWS_ACCESS_KEY_ID": aws.S3_ACCESS_KEY_ID,
        "AWS_SECRET_ACCESS_KEY": aws.S3_SECRET_ACCESS_KEY,
    }
    for env_name, setting in env_settings.items():
        os.environ[env_name] = setting.get()

    cfg = self.config
    repo_config = RepoConfig(
        registry=f"s3://{cfg.s3_bucket}/{cfg.registry_path}",
        project=cfg.project,
        # A custom provider class is referenced by its import path.
        provider="custom_provider.provider.FlyteCustomProvider",
        offline_store=FileOfflineStoreConfig(),
        online_store=SqliteOnlineStoreConfig(path=cfg.online_store_path),
    )
    return FeastFeatureStore(config=repo_config)
def registry_dump(repo_config: RepoConfig, repo_path: Path):
    """Debugging aid: echo a yellow stability warning, then the registry contents as JSON."""
    from colorama import Fore, Style

    registry = Registry(
        registry_config=repo_config.get_registry_config(), repo_path=repo_path
    )
    registry_dict = registry.to_dict(project=repo_config.project)

    banner = (
        "Warning: The registry-dump command is for debugging only and may contain "
        "breaking changes in the future. No guarantees are made on this interface."
    )
    click.echo(Style.BRIGHT + Fore.YELLOW + banner + Style.RESET_ALL)
    click.echo(json.dumps(registry_dict, indent=2))
def feature_store_with_gcs_registry(self):
    """Return a GCP FeatureStore whose metadata store is a file in a brand-new GCS bucket.

    Each call creates a bucket named ``feast-registry-test-<unix seconds>``. A
    lifecycle rule on the bucket deletes objects once they are 14 days old.

    Returns:
        A ``FeatureStore`` for project ``"default"`` using
        ``gs://<bucket>/metadata.db`` as its metadata store.
    """
    from google.cloud import storage

    storage_client = storage.Client()
    bucket_name = f"feast-registry-test-{int(time.time())}"
    bucket = storage_client.create_bucket(storage_client.bucket(bucket_name))

    # The lifecycle delete rule removes *objects* older than 14 days; it does
    # not delete the bucket itself.
    bucket.add_lifecycle_delete_rule(age=14)
    bucket.patch()

    return FeatureStore(
        config=RepoConfig(
            metadata_store=f"gs://{bucket_name}/metadata.db",
            project="default",
            provider="gcp",
        )
    )
def __init__(
    self, repo_path: Optional[str] = None, config: Optional[RepoConfig] = None,
):
    """Resolve the store configuration from a repo path, an explicit config, or defaults.

    Args:
        repo_path: Directory of a feature repo; its configuration is read with
            ``load_repo_config``. Mutually exclusive with ``config``.
        config: A ready-made ``RepoConfig``. Mutually exclusive with ``repo_path``.
            When neither is given, a local config using ``./metadata.db`` and
            ``online_store.db`` is used.

    Raises:
        ValueError: If both ``repo_path`` and ``config`` are supplied.
    """
    if repo_path is not None and config is not None:
        raise ValueError("You cannot specify both repo_path and config")

    if config is not None:
        self.config = config
    elif repo_path is not None:
        self.config = load_repo_config(Path(repo_path))
    else:
        self.config = RepoConfig(
            metadata_store="./metadata.db",
            project="default",
            provider="local",
            # Pass ``path`` by keyword, as other LocalOnlineStoreConfig call sites
            # do. Keyword form also works if the config model is keyword-only.
            online_store=OnlineStoreConfig(
                local=LocalOnlineStoreConfig(path="online_store.db")
            ),
        )
def prep_redis_fs_and_fv() -> Iterator[Tuple[FeatureStore, FeatureView]]:
    """Yield a local-provider FeatureStore that uses a Redis online store at localhost:6379.

    The rows from ``create_dataset()`` are written to a parquet file that backs a
    ``FileSource``. A ``driver`` entity and the feature view are applied to a
    registry in a temporary directory. ``fs.teardown()`` runs and temporary files
    are removed when the generator is resumed after the ``yield``.

    Yields:
        ``(feature_store, feature_view)`` with the entity and feature view applied.
    """
    # Stage the parquet file inside a TemporaryDirectory. With a
    # NamedTemporaryFile, close() deletes the file, and the copy that
    # to_parquet() re-creates at the same path is never cleaned up.
    with tempfile.TemporaryDirectory() as source_dir_name:
        parquet_path = Path(source_dir_name) / "data.parquet"
        df = create_dataset()
        df.to_parquet(str(parquet_path))

        file_source = FileSource(
            file_format=ParquetFormat(),
            path=f"file://{parquet_path}",
            event_timestamp_column="ts",
            created_timestamp_column="created_ts",
            date_partition_column="",
            field_mapping={"ts_1": "ts", "id": "driver_id"},
        )
        fv = driver_feature_view(file_source)
        e = Entity(
            name="driver",
            description="id for driver",
            join_key="driver_id",
            value_type=ValueType.INT32,
        )

        project = f"test_redis_correctness_{str(uuid.uuid4()).replace('-', '')}"
        print(f"Using project: {project}")

        with tempfile.TemporaryDirectory() as repo_dir_name:
            config = RepoConfig(
                registry=str(Path(repo_dir_name) / "registry.db"),
                project=project,
                provider="local",
                online_store=RedisOnlineStoreConfig(
                    type="redis",
                    redis_type=RedisType.redis,
                    connection_string="localhost:6379,db=0",
                ),
            )
            fs = FeatureStore(config=config)
            fs.apply([fv, e])

            yield fs, fv

            fs.teardown()
def apply_total(repo_config: RepoConfig, repo_path: Path):
    """Synchronize the registry and provider infrastructure with the repo at ``repo_path``.

    Steps:
    - Parse the repo.
    - Register its entities, feature tables and feature views.
    - Delete registry feature tables/views whose names are no longer defined in the repo.
    - Call the provider's ``update_infra`` with the objects to keep and to delete.

    Each step is echoed to the terminal.

    Side effects: changes the process working directory to ``repo_path`` and
    appends ``""`` to ``sys.path`` before parsing the repo.
    """
    from colorama import Fore, Style

    def highlight(name: str) -> str:
        # Bright-green rendering used for every object name in the output.
        return f"{Style.BRIGHT + Fore.GREEN}{name}{Style.RESET_ALL}"

    os.chdir(repo_path)
    sys.path.append("")

    registry_config = repo_config.get_registry_config()
    project = repo_config.project
    registry = Registry(
        registry_path=registry_config.path,
        repo_path=repo_path,
        cache_ttl=timedelta(seconds=registry_config.cache_ttl_seconds),
    )
    registry._initialize_registry()

    # Suppress .pyc generation while the repo is parsed. Restore the previous
    # setting even if parsing raises.
    previous_dont_write_bytecode = sys.dont_write_bytecode
    sys.dont_write_bytecode = True
    try:
        repo = parse_repo(repo_path)
    finally:
        sys.dont_write_bytecode = previous_dont_write_bytecode

    for entity in repo.entities:
        registry.apply_entity(entity, project=project)
        click.echo(f"Registered entity {highlight(entity.name)}")

    repo_table_names = {t.name for t in repo.feature_tables}
    repo_table_names.update(v.name for v in repo.feature_views)

    tables_to_delete = [
        registry_table
        for registry_table in registry.list_feature_tables(project=project)
        if registry_table.name not in repo_table_names
    ]
    views_to_delete = [
        registry_view
        for registry_view in registry.list_feature_views(project=project)
        if registry_view.name not in repo_table_names
    ]

    # Delete tables that should not exist
    for registry_table in tables_to_delete:
        registry.delete_feature_table(registry_table.name, project=project)
        click.echo(
            f"Deleted feature table {highlight(registry_table.name)} from registry"
        )

    # Create tables that should. Echo the table just applied, not a leftover
    # variable from the deletion loop above.
    for table in repo.feature_tables:
        registry.apply_feature_table(table, project)
        click.echo(f"Registered feature table {highlight(table.name)}")

    # Delete views that should not exist
    for registry_view in views_to_delete:
        registry.delete_feature_view(registry_view.name, project=project)
        click.echo(
            f"Deleted feature view {highlight(registry_view.name)} from registry"
        )

    # Create views that should
    for view in repo.feature_views:
        registry.apply_feature_view(view, project)
        click.echo(f"Registered feature view {highlight(view.name)}")

    infra_provider = get_provider(repo_config, repo_path)

    all_to_delete: List[Union[FeatureTable, FeatureView]] = [
        *tables_to_delete,
        *views_to_delete,
    ]
    all_to_keep: List[Union[FeatureTable, FeatureView]] = [
        *repo.feature_tables,
        *repo.feature_views,
    ]

    for kept in all_to_keep:
        click.echo(f"Deploying infrastructure for {highlight(kept.name)}")
    for removed in [*views_to_delete, *tables_to_delete]:
        click.echo(f"Removing infrastructure for {highlight(removed.name)}")

    infra_provider.update_infra(
        project,
        tables_to_delete=all_to_delete,
        tables_to_keep=all_to_keep,
        partial=False,
    )
def prep_redshift_fs_and_fv(
    source_type: str,
) -> Iterator[Tuple[FeatureStore, FeatureView]]:
    """Yield a local FeatureStore whose feature view reads from a staged Redshift table.

    The test dataset is uploaded (via S3 staging) to a uniquely named Redshift table.
    The feature view reads that table directly when ``source_type == "table"``, or
    through a ``SELECT *`` query when ``source_type == "query"``. After the
    ``yield``, the store is torn down and the table is dropped.
    """
    region = "us-west-2"
    client = aws_utils.get_redshift_data_client(region)
    s3 = aws_utils.get_s3_resource(region)

    df = create_dataset()
    table_name = (
        f"test_ingestion_{source_type}_correctness_"
        f"{int(time.time_ns())}_{random.randint(1000, 9999)}"
    )

    offline_store = RedshiftOfflineStoreConfig(
        cluster_id="feast-integration-tests",
        region=region,
        user="******",
        database="feast",
        s3_staging_location="s3://feast-integration-tests/redshift/tests/ingestion",
        iam_role="arn:aws:iam::402087665549:role/redshift_s3_access_role",
    )

    staging_path = f"{offline_store.s3_staging_location}/copy/{table_name}.parquet"
    aws_utils.upload_df_to_redshift(
        client,
        offline_store.cluster_id,
        offline_store.database,
        offline_store.user,
        s3,
        staging_path,
        offline_store.iam_role,
        table_name,
        df,
    )

    # At most one of table/query is set, depending on source_type.
    source_args = {"table": None, "query": None}
    if source_type == "table":
        source_args["table"] = table_name
    elif source_type == "query":
        source_args["query"] = f"SELECT * FROM {table_name}"

    redshift_source = RedshiftSource(
        **source_args,
        event_timestamp_column="ts",
        created_timestamp_column="created_ts",
        date_partition_column="",
        field_mapping={"ts_1": "ts", "id": "driver_id"},
    )
    fv = driver_feature_view(redshift_source)
    driver_entity = Entity(
        name="driver",
        description="id for driver",
        join_key="driver_id",
        value_type=ValueType.INT32,
    )

    with tempfile.TemporaryDirectory() as repo_dir_name, tempfile.TemporaryDirectory() as data_dir_name:
        fs = FeatureStore(
            config=RepoConfig(
                registry=str(Path(repo_dir_name) / "registry.db"),
                project=f"test_bq_correctness_{str(uuid.uuid4()).replace('-', '')}",
                provider="local",
                online_store=SqliteOnlineStoreConfig(
                    path=str(Path(data_dir_name) / "online_store.db")
                ),
                offline_store=offline_store,
            )
        )
        fs.apply([fv, driver_entity])

        yield fs, fv

        fs.teardown()

    # Clean up the uploaded Redshift table
    aws_utils.execute_redshift_statement(
        client,
        offline_store.cluster_id,
        offline_store.database,
        offline_store.user,
        f"DROP TABLE {table_name}",
    )
def apply_total(repo_config: RepoConfig, repo_path: Path):
    """Make the registry and provider infrastructure match the repo at ``repo_path``.

    Steps:
    - Register the repo's entities, feature tables and feature views.
    - Remove registry feature tables/views whose names the repo no longer defines.
    - Hand both sets to the provider's ``update_infra``.
    - Print ``Done!``.

    Side effects: changes the working directory to ``repo_path`` and appends
    ``""`` to ``sys.path``.
    """
    os.chdir(repo_path)
    sys.path.append("")

    reg_cfg = repo_config.get_registry_config()
    project = repo_config.project
    registry = Registry(
        registry_path=reg_cfg.path,
        cache_ttl=timedelta(seconds=reg_cfg.cache_ttl_seconds),
    )

    repo = parse_repo(repo_path)

    for entity in repo.entities:
        registry.apply_entity(entity, project=project)

    wanted_names = {t.name for t in repo.feature_tables} | {
        v.name for v in repo.feature_views
    }
    stale_tables = [
        t
        for t in registry.list_feature_tables(project=project)
        if t.name not in wanted_names
    ]
    stale_views = [
        v
        for v in registry.list_feature_views(project=project)
        if v.name not in wanted_names
    ]

    # Drop stale tables, register current ones, then do the same for views.
    for stale in stale_tables:
        registry.delete_feature_table(stale.name, project=project)
    for table in repo.feature_tables:
        registry.apply_feature_table(table, project)
    for stale in stale_views:
        registry.delete_feature_view(stale.name, project=project)
    for view in repo.feature_views:
        registry.apply_feature_view(view, project)

    infra_provider = get_provider(repo_config)
    to_delete: List[Union[FeatureTable, FeatureView]] = [*stale_tables, *stale_views]
    to_keep: List[Union[FeatureTable, FeatureView]] = [
        *repo.feature_tables,
        *repo.feature_views,
    ]
    infra_provider.update_infra(
        project, tables_to_delete=to_delete, tables_to_keep=to_keep, partial=False,
    )

    print("Done!")
class FeatureStore:
    """
    A FeatureStore object is used to define, create, and retrieve features.
    """

    config: RepoConfig
    repo_path: Optional[str]
    _registry: Registry

    def __init__(
        self, repo_path: Optional[str] = None, config: Optional[RepoConfig] = None,
    ):
        """Create a feature store from a repo path, an explicit config, or defaults.

        Args:
            repo_path: Directory of a feature repo; its configuration is read with
                ``load_repo_config``. Mutually exclusive with ``config``.
            config: A ready-made ``RepoConfig``. Mutually exclusive with
                ``repo_path``. When neither is given, a local config using
                ``./registry.db`` and ``online_store.db`` is used.

        Raises:
            ValueError: If both ``repo_path`` and ``config`` are supplied.
        """
        self.repo_path = repo_path
        if repo_path is not None and config is not None:
            raise ValueError("You cannot specify both repo_path and config")
        if config is not None:
            self.config = config
        elif repo_path is not None:
            self.config = load_repo_config(Path(repo_path))
        else:
            self.config = RepoConfig(
                registry="./registry.db",
                project="default",
                provider="local",
                online_store=OnlineStoreConfig(
                    local=LocalOnlineStoreConfig(path="online_store.db")
                ),
            )

        self._registry = self._build_registry()
        self._tele = Telemetry()

    def _build_registry(self) -> Registry:
        """Construct a Registry from this store's registry config (path and cache TTL)."""
        registry_config = self.config.get_registry_config()
        return Registry(
            registry_path=registry_config.path,
            cache_ttl=timedelta(seconds=registry_config.cache_ttl_seconds),
        )

    def version(self) -> str:
        """Returns the version of the current Feast SDK/CLI"""
        return get_version()

    @property
    def project(self) -> str:
        """Name of the project this store operates on, taken from ``self.config``."""
        return self.config.project

    def _get_provider(self) -> Provider:
        """Return the provider selected by ``self.config``."""
        return get_provider(self.config)

    def refresh_registry(self):
        """Fetches and caches a copy of the feature registry in memory.

        Explicitly calling this method allows for direct control of the state of the registry cache. Every time this
        method is called the complete registry state will be retrieved from the remote registry store backend
        (e.g., GCS, S3), and the cache timer will be reset. If refresh_registry() is run before get_online_features()
        is called, then get_online_features() will use the cached registry instead of retrieving (and caching) the
        registry itself.

        Additionally, the TTL for the registry cache can be set to infinity (by setting it to 0), which means that
        refresh_registry() will become the only way to update the cached registry. If the TTL is set to a value
        greater than 0, then once the cache becomes stale (more time than the TTL has passed), a new cache will be
        downloaded synchronously, which may increase latencies if the triggering method is get_online_features()
        """
        self._tele.log("refresh_registry")
        self._registry = self._build_registry()
        self._registry.refresh()

    def list_entities(self, allow_cache: bool = False) -> List[Entity]:
        """
        Retrieve a list of entities from the registry

        Args:
            allow_cache (bool): Whether to allow returning entities from a cached registry

        Returns:
            List of entities
        """
        self._tele.log("list_entities")
        return self._registry.list_entities(self.project, allow_cache=allow_cache)

    def list_feature_views(self) -> List[FeatureView]:
        """
        Retrieve a list of feature views from the registry

        Returns:
            List of feature views
        """
        self._tele.log("list_feature_views")
        return self._registry.list_feature_views(self.project)

    def get_entity(self, name: str) -> Entity:
        """
        Retrieves an entity.

        Args:
            name: Name of entity

        Returns:
            Returns either the specified entity, or raises an exception if
            none is found
        """
        self._tele.log("get_entity")
        return self._registry.get_entity(name, self.project)

    def get_feature_view(self, name: str) -> FeatureView:
        """
        Retrieves a feature view.

        Args:
            name: Name of feature view

        Returns:
            Returns either the specified feature view, or raises an exception if
            none is found
        """
        self._tele.log("get_feature_view")
        return self._registry.get_feature_view(name, self.project)

    def delete_feature_view(self, name: str):
        """
        Deletes a feature view or raises an exception if not found.

        Args:
            name: Name of feature view
        """
        self._tele.log("delete_feature_view")
        return self._registry.delete_feature_view(name, self.project)

    def apply(self, objects: List[Union[FeatureView, Entity]]):
        """Register objects to metadata store and update related infrastructure.

        The apply method registers one or more definitions (e.g., Entity, FeatureView) and registers or updates these
        objects in the Feast registry. Once the registry has been updated, the apply method will update related
        infrastructure (e.g., create tables in an online store) in order to reflect these new definitions. All
        operations are idempotent, meaning they can safely be rerun.

        Args:
            objects (List[Union[FeatureView, Entity]]): A list of FeatureView or Entity objects that should be
                registered

        Raises:
            ValueError: If an object is neither a FeatureView nor an Entity.

        Examples:
            Register a single Entity and FeatureView.

            >>> from feast.feature_store import FeatureStore
            >>> from feast import Entity, FeatureView, Feature, ValueType, FileSource
            >>> from datetime import timedelta
            >>>
            >>> fs = FeatureStore()
            >>> customer_entity = Entity(name="customer", value_type=ValueType.INT64, description="customer entity")
            >>> customer_feature_view = FeatureView(
            ...     name="customer_fv",
            ...     entities=["customer"],
            ...     features=[Feature(name="age", dtype=ValueType.INT64)],
            ...     input=FileSource(path="file.parquet", event_timestamp_column="timestamp"),
            ...     ttl=timedelta(days=1)
            ... )
            >>> fs.apply([customer_entity, customer_feature_view])
        """
        self._tele.log("apply")

        # TODO: Add locking
        # TODO: Optimize by only making a single call (read/write)

        views_to_update = []
        for ob in objects:
            if isinstance(ob, FeatureView):
                self._registry.apply_feature_view(ob, project=self.config.project)
                views_to_update.append(ob)
            elif isinstance(ob, Entity):
                self._registry.apply_entity(ob, project=self.config.project)
            else:
                raise ValueError(
                    f"Unknown object type ({type(ob)}) provided as part of apply() call"
                )
        # partial=True: only the views applied here are passed as tables to keep.
        self._get_provider().update_infra(
            project=self.config.project,
            tables_to_delete=[],
            tables_to_keep=views_to_update,
            partial=True,
        )

    def get_historical_features(
        self, entity_df: Union[pd.DataFrame, str], feature_refs: List[str],
    ) -> RetrievalJob:
        """Enrich an entity dataframe with historical feature values for either training or batch scoring.

        This method joins historical feature data from one or more feature views to an entity dataframe by using a time
        travel join.

        Each feature view is joined to the entity dataframe using all entities configured for the respective feature
        view. All configured entities must be available in the entity dataframe. Therefore, the entity dataframe must
        contain all entities found in all feature views, but the individual feature views can have different entities.

        Time travel is based on the configured TTL for each feature view. A shorter TTL will limit the
        amount of scanning that will be done in order to find feature data for a specific entity key. Setting a short
        TTL may result in null values being returned.

        Args:
            entity_df (Union[pd.DataFrame, str]): An entity dataframe is a collection of rows containing all entity
                columns (e.g., customer_id, driver_id) on which features need to be joined, as well as a event_timestamp
                column used to ensure point-in-time correctness. Either a Pandas DataFrame can be provided or a string
                SQL query. The query must be of a format supported by the configured offline store (e.g., BigQuery)
            feature_refs: A list of features that should be retrieved from the offline store. Feature references are of
                the format "feature_view:feature", e.g., "customer_fv:daily_transactions".

        Returns:
            RetrievalJob which can be used to materialize the results.

        Examples:
            Retrieve historical features using a BigQuery SQL entity dataframe

            >>> from feast.feature_store import FeatureStore
            >>>
            >>> fs = FeatureStore(config=RepoConfig(provider="gcp"))
            >>> retrieval_job = fs.get_historical_features(
            ...     entity_df="SELECT event_timestamp, order_id, customer_id from gcp_project.my_ds.customer_orders",
            ...     feature_refs=["customer:age", "customer:avg_orders_1d", "customer:avg_orders_7d"]
            ... )
            >>> feature_data = retrieval_job.to_df()
            >>> model.fit(feature_data) # insert your modeling framework here.
        """
        self._tele.log("get_historical_features")
        all_feature_views = self._registry.list_feature_views(
            project=self.config.project
        )
        feature_views = _get_requested_feature_views(feature_refs, all_feature_views)
        provider = self._get_provider()
        job = provider.get_historical_features(
            self.config,
            feature_views,
            feature_refs,
            entity_df,
            self._registry,
            self.project,
        )
        return job

    def _get_feature_views_to_materialize(
        self, feature_views: Optional[List[str]]
    ) -> List[FeatureView]:
        """Resolve feature view names to objects; ``None`` selects every view in the project."""
        if feature_views is None:
            return self._registry.list_feature_views(self.config.project)
        return [
            self._registry.get_feature_view(name, self.config.project)
            for name in feature_views
        ]

    def materialize_incremental(
        self, end_date: datetime, feature_views: Optional[List[str]] = None,
    ) -> None:
        """
        Materialize incremental new data from the offline store into the online store.

        This method loads incremental new feature data up to the specified end time from either
        the specified feature views, or all feature views if none are specified,
        into the online store where it is available for online serving. The start time of
        the interval materialized is either the most recent end time of a prior materialization or
        (now - ttl) if no such prior materialization exists.

        Args:
            end_date (datetime): End date for time range of data to materialize into the online store
            feature_views (List[str]): Optional list of feature view names. If selected, will only run
                materialization for the specified feature views.

        Raises:
            Exception: If a feature view has neither a prior materialization end time nor a ttl.

        Examples:
            Materialize all features into the online store up to 5 minutes ago.

            >>> from datetime import datetime, timedelta
            >>> from feast.feature_store import FeatureStore
            >>>
            >>> fs = FeatureStore(config=RepoConfig(provider="gcp", registry="gs://my-fs/", project="my_fs_proj"))
            >>> fs.materialize_incremental(end_date=datetime.utcnow() - timedelta(minutes=5))
        """
        self._tele.log("materialize_incremental")
        feature_views_to_materialize = self._get_feature_views_to_materialize(
            feature_views
        )
        # The provider depends only on self.config, so look it up once.
        provider = self._get_provider()

        # TODO paging large loads
        for feature_view in feature_views_to_materialize:
            start_date = feature_view.most_recent_end_time
            if start_date is None:
                if feature_view.ttl is None:
                    raise Exception(
                        f"No start time found for feature view {feature_view.name}. materialize_incremental() requires either a ttl to be set or for materialize() to have been run at least once."
                    )
                start_date = datetime.utcnow() - feature_view.ttl
            provider.materialize_single_feature_view(
                feature_view, start_date, end_date, self._registry, self.project
            )

    def materialize(
        self,
        start_date: datetime,
        end_date: datetime,
        feature_views: Optional[List[str]] = None,
    ) -> None:
        """
        Materialize data from the offline store into the online store.

        This method loads feature data in the specified interval from either
        the specified feature views, or all feature views if none are specified,
        into the online store where it is available for online serving.

        Args:
            start_date (datetime): Start date for time range of data to materialize into the online store
            end_date (datetime): End date for time range of data to materialize into the online store
            feature_views (List[str]): Optional list of feature view names. If selected, will only run
                materialization for the specified feature views.

        Examples:
            Materialize all features into the online store over the interval
            from 3 hours ago to 10 minutes ago.

            >>> from datetime import datetime, timedelta
            >>> from feast.feature_store import FeatureStore
            >>>
            >>> fs = FeatureStore(config=RepoConfig(provider="gcp"))
            >>> fs.materialize(
            ...     start_date=datetime.utcnow() - timedelta(hours=3), end_date=datetime.utcnow() - timedelta(minutes=10)
            ... )
        """
        self._tele.log("materialize")
        feature_views_to_materialize = self._get_feature_views_to_materialize(
            feature_views
        )
        # The provider depends only on self.config, so look it up once.
        provider = self._get_provider()

        # TODO paging large loads
        for feature_view in feature_views_to_materialize:
            provider.materialize_single_feature_view(
                feature_view, start_date, end_date, self._registry, self.project
            )

    def get_online_features(
        self, feature_refs: List[str], entity_rows: List[Dict[str, Any]],
    ) -> OnlineResponse:
        """
        Retrieves the latest online feature data.

        Note: This method will download the full feature registry the first time it is run. If you are using a
        remote registry like GCS or S3 then that may take a few seconds. The registry remains cached up to a TTL
        duration (which can be set to infinity). If the cached registry is stale (more time than the TTL has
        passed), then a new registry will be downloaded synchronously by this method. This download may
        introduce latency to online feature retrieval. In order to avoid synchronous downloads, please call
        refresh_registry() prior to the TTL being reached. Remember it is possible to set the cache TTL to
        infinity (cache forever).

        Args:
            feature_refs: List of feature references that will be returned for each entity, in the
                format "feature_view:feature", e.g. "sales:daily_transactions".
            entity_rows: A list of dictionaries where each key-value is an entity-name, entity-value pair.

        Returns:
            OnlineResponse containing the feature data in records.

        Raises:
            Exception: If an entity name in ``entity_rows`` is not registered in this project.

        Examples:
            >>> from feast import FeatureStore
            >>>
            >>> store = FeatureStore(repo_path="...")
            >>> feature_refs = ["sales:daily_transactions"]
            >>> entity_rows = [{"customer_id": 0},{"customer_id": 1}]
            >>>
            >>> online_response = store.get_online_features(feature_refs, entity_rows)
            >>> online_response_dict = online_response.to_dict()
            >>> print(online_response_dict)
            {'sales:daily_transactions': [1.1,1.2], 'sales:customer_id': [0,1]}
        """
        self._tele.log("get_online_features")
        provider = self._get_provider()

        # Callers address entities by name; the online store is keyed by join key.
        entities = self.list_entities(allow_cache=True)
        entity_name_to_join_key_map = {entity.name: entity.join_key for entity in entities}

        join_key_rows = []
        for row in entity_rows:
            join_key_row = {}
            for entity_name, entity_value in row.items():
                try:
                    join_key = entity_name_to_join_key_map[entity_name]
                except KeyError:
                    raise Exception(
                        f"Entity {entity_name} does not exist in project {self.project}"
                    )
                join_key_row[join_key] = entity_value
            join_key_rows.append(join_key_row)

        entity_row_proto_list = _infer_online_entity_rows(join_key_rows)

        union_of_entity_keys = []
        result_rows: List[GetOnlineFeaturesResponse.FieldValues] = []

        for entity_row_proto in entity_row_proto_list:
            union_of_entity_keys.append(_entity_row_to_key(entity_row_proto))
            result_rows.append(_entity_row_to_field_values(entity_row_proto))

        all_feature_views = self._registry.list_feature_views(
            project=self.config.project, allow_cache=True
        )

        grouped_refs = _group_refs(feature_refs, all_feature_views)
        for table, requested_features in grouped_refs:
            entity_keys = _get_table_entity_keys(
                table, union_of_entity_keys, entity_name_to_join_key_map
            )
            read_rows = provider.online_read(
                project=self.project, table=table, entity_keys=entity_keys,
            )
            # read_rows is indexed in parallel with result_rows.
            for row_idx, read_row in enumerate(read_rows):
                _, feature_data = read_row  # the event timestamp is not used here
                result_row = result_rows[row_idx]

                if feature_data is None:
                    for feature_name in requested_features:
                        feature_ref = f"{table.name}__{feature_name}"
                        result_row.statuses[
                            feature_ref
                        ] = GetOnlineFeaturesResponse.FieldStatus.NOT_FOUND
                else:
                    for feature_name in feature_data:
                        feature_ref = f"{table.name}__{feature_name}"
                        if feature_name in requested_features:
                            result_row.fields[feature_ref].CopyFrom(
                                feature_data[feature_name]
                            )
                            result_row.statuses[
                                feature_ref
                            ] = GetOnlineFeaturesResponse.FieldStatus.PRESENT

        return OnlineResponse(GetOnlineFeaturesResponse(field_values=result_rows))
def test_historical_features_from_parquet_sources():
    """Check training data retrieved from parquet-backed feature views.

    Driver hourly stats and customer daily profiles are staged as parquet files
    in a temporary directory and registered against a local-provider
    ``FeatureStore``. The frame returned by ``get_historical_features`` must
    equal the frame built by ``get_expected_training_df`` once both are sorted
    on the same key columns.
    """
    start_date = datetime.now().replace(microsecond=0, second=0, minute=0)
    (
        customer_entities,
        driver_entities,
        end_date,
        orders_df,
        start_date,
    ) = generate_entities(start_date)

    with TemporaryDirectory() as temp_dir:
        driver_df = create_driver_hourly_stats_df(driver_entities, start_date, end_date)
        driver_source = stage_driver_hourly_stats_parquet_source(temp_dir, driver_df)
        driver_fv = create_driver_hourly_stats_feature_view(driver_source)

        customer_df = create_customer_daily_profile_df(
            customer_entities, start_date, end_date
        )
        customer_source = stage_customer_daily_profile_parquet_source(
            temp_dir, customer_df
        )
        customer_fv = create_customer_daily_profile_feature_view(customer_source)

        driver = Entity(name="driver", value_type=ValueType.INT64, description="")
        customer = Entity(name="customer", value_type=ValueType.INT64, description="")

        # Both the registry file and the sqlite online store live in temp_dir so the
        # test leaves nothing behind. The registry location is given via
        # ``registry=`` and the online-store path via the ``path=`` keyword, like
        # every other RepoConfig construction in these tests.
        store = FeatureStore(
            config=RepoConfig(
                registry=os.path.join(temp_dir, "registry.db"),
                project="default",
                provider="local",
                online_store=OnlineStoreConfig(
                    local=LocalOnlineStoreConfig(
                        path=os.path.join(temp_dir, "online_store.db"),
                    )
                ),
            )
        )
        store.apply([driver, customer, driver_fv, customer_fv])

        job = store.get_historical_features(
            entity_df=orders_df,
            feature_refs=[
                "driver_stats:conv_rate",
                "driver_stats:avg_daily_trips",
                "customer_profile:current_balance",
                "customer_profile:avg_passenger_count",
                "customer_profile:lifetime_trip_count",
            ],
        )
        actual_df = job.to_df()
        expected_df = get_expected_training_df(
            customer_df,
            customer_fv,
            driver_df,
            driver_fv,
            orders_df,
        )

        # Row order of the retrieval result is not guaranteed; compare canonically.
        sort_keys = [
            ENTITY_DF_EVENT_TIMESTAMP_COL,
            "order_id",
            "driver_id",
            "customer_id",
        ]
        assert_frame_equal(
            expected_df.sort_values(by=sort_keys).reset_index(drop=True),
            actual_df.sort_values(by=sort_keys).reset_index(drop=True),
        )
def test_historical_features_from_bigquery_sources(
    provider_type, infer_event_timestamp_col
):
    """Compare BigQuery historical retrieval with a locally computed expected frame.

    Orders, driver stats and customer profiles are staged in a throw-away BigQuery
    dataset. Retrieval runs twice: once with the entity rows supplied as a SQL
    query, once with the in-memory ``orders_df``. Both results are compared with
    ``get_expected_training_df`` after sorting on the same key columns.

    Args:
        provider_type: ``"local"``, ``"gcp"`` or ``"gcp_custom_offline_config"``;
            selects the ``RepoConfig`` used. Any other value raises ``Exception``.
        infer_event_timestamp_col: Forwarded to ``generate_entities``.
    """
    hour_start = datetime.now().replace(microsecond=0, second=0, minute=0)
    customer_entities, driver_entities, end_date, orders_df, start_date = (
        generate_entities(hour_start, infer_event_timestamp_col)
    )

    # Nanosecond timestamp plus a random suffix keeps concurrent runs apart.
    bigquery_dataset = (
        f"test_hist_retrieval_{int(time.time_ns())}_{random.randint(1000, 9999)}"
    )

    feature_refs = [
        "driver_stats:conv_rate",
        "driver_stats:avg_daily_trips",
        "customer_profile:current_balance",
        "customer_profile:avg_passenger_count",
        "customer_profile:lifetime_trip_count",
    ]

    with BigQueryDataSet(bigquery_dataset), TemporaryDirectory() as temp_dir:
        gcp_project = bigquery.Client().project

        # Entity rows: orders staged in BigQuery and referenced through a query.
        orders_table_id = f"{bigquery_dataset}.orders"
        stage_orders_bigquery(orders_df, orders_table_id)
        entity_df_query = f"SELECT * FROM {gcp_project}.{orders_table_id}"

        # Driver feature view backed by a BigQuery table.
        driver_df = driver_data.create_driver_hourly_stats_df(
            driver_entities, start_date, end_date
        )
        driver_table_id = f"{gcp_project}.{bigquery_dataset}.driver_hourly"
        stage_driver_hourly_stats_bigquery_source(driver_df, driver_table_id)
        driver_fv = create_driver_hourly_stats_feature_view(
            BigQuerySource(
                table_ref=driver_table_id,
                event_timestamp_column="datetime",
                created_timestamp_column="created",
            )
        )

        # Customer feature view; its source declares no created-timestamp column.
        customer_df = driver_data.create_customer_daily_profile_df(
            customer_entities, start_date, end_date
        )
        customer_table_id = f"{gcp_project}.{bigquery_dataset}.customer_profile"
        stage_customer_daily_profile_bigquery_source(customer_df, customer_table_id)
        customer_fv = create_customer_daily_profile_feature_view(
            BigQuerySource(
                table_ref=customer_table_id,
                event_timestamp_column="datetime",
                created_timestamp_column="",
            )
        )

        driver = Entity(name="driver", join_key="driver_id", value_type=ValueType.INT64)
        customer = Entity(name="customer_id", value_type=ValueType.INT64)

        registry_path = os.path.join(temp_dir, "registry.db")
        if provider_type == "local":
            repo_config = RepoConfig(
                registry=registry_path,
                project="default",
                provider="local",
                online_store=SqliteOnlineStoreConfig(
                    path=os.path.join(temp_dir, "online_store.db"),
                ),
                offline_store=BigQueryOfflineStoreConfig(type="bigquery"),
            )
        elif provider_type in ("gcp", "gcp_custom_offline_config"):
            # Only the custom variant overrides the offline store's dataset.
            offline_options = {"type": "bigquery"}
            if provider_type == "gcp_custom_offline_config":
                offline_options["dataset"] = "foo"
            repo_config = RepoConfig(
                registry=registry_path,
                project="".join(
                    random.choices(string.ascii_uppercase + string.digits, k=10)
                ),
                provider="gcp",
                offline_store=BigQueryOfflineStoreConfig(**offline_options),
            )
        else:
            raise Exception("Invalid provider used as part of test configuration")
        store = FeatureStore(config=repo_config)

        store.apply([driver, customer, driver_fv, customer_fv])

        if DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL in orders_df.columns:
            event_timestamp = DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
        else:
            event_timestamp = "e_ts"

        expected_df = get_expected_training_df(
            customer_df,
            customer_fv,
            driver_df,
            driver_fv,
            orders_df,
            event_timestamp,
        )

        sort_keys = [event_timestamp, "order_id", "driver_id", "customer_id"]

        def _canonical(frame):
            # Retrieval does not promise a row order; compare sorted copies.
            return frame.sort_values(by=sort_keys).reset_index(drop=True)

        job_from_sql = store.get_historical_features(
            entity_df=entity_df_query, feature_refs=feature_refs
        )
        actual_df_from_sql_entities = job_from_sql.to_df()
        assert_frame_equal(
            _canonical(expected_df),
            _canonical(actual_df_from_sql_entities),
            check_dtype=False,
        )

        job_from_df = store.get_historical_features(
            entity_df=orders_df, feature_refs=feature_refs
        )
        # The uploaded entity table goes to the configured dataset, or to the
        # default `feast` dataset when none is configured.
        expected_entity_table = (
            "foo.entity_df"
            if provider_type == "gcp_custom_offline_config"
            else "feast.entity_df"
        )
        assertpy.assert_that(job_from_df.query).contains(expected_entity_table)

        actual_df_from_df_entities = job_from_df.to_df()
        assert_frame_equal(
            _canonical(expected_df),
            _canonical(actual_df_from_df_entities),
            check_dtype=False,
        )
def test_historical_features_from_parquet_sources(infer_event_timestamp_col):
    """Compare historical retrieval over parquet sources with an expected frame.

    ``infer_event_timestamp_col`` is forwarded to ``generate_entities``. The
    entity dataframe's timestamp column is then either
    ``DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL`` or ``"e_ts"``, and that column
    leads the sort keys used for the final comparison.
    """
    hour_start = datetime.now().replace(microsecond=0, second=0, minute=0)
    customer_entities, driver_entities, end_date, orders_df, start_date = (
        generate_entities(hour_start, infer_event_timestamp_col)
    )

    with TemporaryDirectory() as temp_dir:
        # Driver hourly stats, staged as a parquet file under temp_dir.
        driver_df = driver_data.create_driver_hourly_stats_df(
            driver_entities, start_date, end_date
        )
        driver_fv = create_driver_hourly_stats_feature_view(
            stage_driver_hourly_stats_parquet_source(temp_dir, driver_df)
        )

        # Customer daily profiles, staged the same way.
        customer_df = driver_data.create_customer_daily_profile_df(
            customer_entities, start_date, end_date
        )
        customer_fv = create_customer_daily_profile_feature_view(
            stage_customer_daily_profile_parquet_source(temp_dir, customer_df)
        )

        driver = Entity(name="driver", join_key="driver_id", value_type=ValueType.INT64)
        customer = Entity(name="customer_id", value_type=ValueType.INT64)

        repo_config = RepoConfig(
            registry=os.path.join(temp_dir, "registry.db"),
            project="default",
            provider="local",
            online_store=SqliteOnlineStoreConfig(
                path=os.path.join(temp_dir, "online_store.db")
            ),
        )
        store = FeatureStore(config=repo_config)
        store.apply([driver, customer, driver_fv, customer_fv])

        actual_df = store.get_historical_features(
            entity_df=orders_df,
            feature_refs=[
                "driver_stats:conv_rate",
                "driver_stats:avg_daily_trips",
                "customer_profile:current_balance",
                "customer_profile:avg_passenger_count",
                "customer_profile:lifetime_trip_count",
            ],
        ).to_df()

        if DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL in orders_df.columns:
            event_timestamp = DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
        else:
            event_timestamp = "e_ts"

        expected_df = get_expected_training_df(
            customer_df,
            customer_fv,
            driver_df,
            driver_fv,
            orders_df,
            event_timestamp,
        )

        sort_keys = [event_timestamp, "order_id", "driver_id", "customer_id"]

        def _canonical(frame):
            # Retrieval does not promise a row order; compare sorted copies.
            return frame.sort_values(by=sort_keys).reset_index(drop=True)

        assert_frame_equal(_canonical(expected_df), _canonical(actual_df))
def test_bigquery_ingestion_correctness(self):
    """Materialize a small BigQuery table and read the newest value back online.

    Three rows are loaded into a new BigQuery table: two for driver 1 (the one
    with the later ``ts_1`` carries a random ``checked_value``) and one for
    driver 2. After ``materialize`` over the last five minutes, the online value
    read for driver 1 must match ``checked_value``.
    """
    import os
    import tempfile

    # create dataset
    ts = pd.Timestamp.now(tz="UTC").round("ms")
    # Random value so the test cannot pass if no values reach the online store.
    checked_value = random.random()
    data = {
        "id": [1, 2, 1],
        "value": [0.1, 0.2, checked_value],
        "ts_1": [ts - timedelta(minutes=2), ts, ts],
        "created_ts": [ts, ts, ts],
    }
    df = pd.DataFrame.from_dict(data)

    # load dataset into BigQuery
    job_config = bigquery.LoadJobConfig()
    table_id = (
        f"{self.gcp_project}.{self.bigquery_dataset}.correctness_{int(time.time())}"
    )
    job = self.client.load_table_from_dataframe(df, table_id, job_config=job_config)
    job.result()

    # create FeatureView
    fv = FeatureView(
        name="test_bq_correctness",
        entities=["driver_id"],
        features=[Feature("value", ValueType.FLOAT)],
        ttl=timedelta(minutes=5),
        input=BigQuerySource(
            event_timestamp_column="ts",
            table_ref=table_id,
            created_timestamp_column="created_ts",
            field_mapping={"ts_1": "ts", "id": "driver_id"},
            date_partition_column="",
        ),
    )

    # Keep the registry and the sqlite online store in a temporary directory so
    # the test neither leaves files in the current working directory nor picks
    # up registry state left over from a previous run.
    with tempfile.TemporaryDirectory() as repo_dir:
        config = RepoConfig(
            registry=os.path.join(repo_dir, "registry.db"),
            project="default",
            provider="gcp",
            online_store=OnlineStoreConfig(
                local=LocalOnlineStoreConfig(
                    path=os.path.join(repo_dir, "online_store.db")
                )
            ),
        )
        fs = FeatureStore(config=config)
        fs.apply([fv])

        # run materialize()
        fs.materialize(
            ["test_bq_correctness"],
            datetime.utcnow() - timedelta(minutes=5),
            datetime.utcnow() - timedelta(minutes=0),
        )

        # check result of materialize()
        entity_key = EntityKeyProto(
            entity_names=["driver_id"], entity_values=[ValueProto(int64_val=1)]
        )
        t, val = fs._get_provider().online_read("default", fv, entity_key)
        assert abs(val["value"].double_val - checked_value) < 1e-6
def test_historical_features_from_bigquery_sources():
    """Compare BigQuery historical retrieval with a locally computed expected frame.

    Orders, driver stats and customer profiles are staged in a throw-away BigQuery
    dataset. Retrieval runs twice: with the entity rows given as a SQL query and
    as the in-memory ``orders_df``. Both results must match
    ``get_expected_training_df`` after sorting on the same key columns.
    """
    start_date = datetime.now().replace(microsecond=0, second=0, minute=0)
    (
        customer_entities,
        driver_entities,
        end_date,
        orders_df,
        start_date,
    ) = generate_entities(start_date)

    # A seconds-resolution name collides when two runs start within the same
    # second (e.g. parallel CI jobs), making them share and then delete each
    # other's dataset. Nanoseconds plus a random suffix keep runs apart.
    bigquery_dataset = (
        f"test_hist_retrieval_{time.time_ns()}_{random.randint(1000, 9999)}"
    )

    feature_refs = [
        "driver_stats:conv_rate",
        "driver_stats:avg_daily_trips",
        "customer_profile:current_balance",
        "customer_profile:avg_passenger_count",
        "customer_profile:lifetime_trip_count",
    ]

    with BigQueryDataSet(bigquery_dataset), TemporaryDirectory() as temp_dir:
        gcp_project = bigquery.Client().project

        # Orders Query
        table_id = f"{bigquery_dataset}.orders"
        stage_orders_bigquery(orders_df, table_id)
        entity_df_query = f"SELECT * FROM {gcp_project}.{table_id}"

        # Driver Feature View
        driver_df = driver_data.create_driver_hourly_stats_df(
            driver_entities, start_date, end_date
        )
        driver_table_id = f"{gcp_project}.{bigquery_dataset}.driver_hourly"
        stage_driver_hourly_stats_bigquery_source(driver_df, driver_table_id)
        driver_source = BigQuerySource(
            table_ref=driver_table_id,
            event_timestamp_column="datetime",
            created_timestamp_column="created",
        )
        driver_fv = create_driver_hourly_stats_feature_view(driver_source)

        # Customer Feature View
        customer_df = driver_data.create_customer_daily_profile_df(
            customer_entities, start_date, end_date
        )
        customer_table_id = f"{gcp_project}.{bigquery_dataset}.customer_profile"
        stage_customer_daily_profile_bigquery_source(customer_df, customer_table_id)
        customer_source = BigQuerySource(
            table_ref=customer_table_id,
            event_timestamp_column="datetime",
            created_timestamp_column="created",
        )
        customer_fv = create_customer_daily_profile_feature_view(customer_source)

        driver = Entity(name="driver", value_type=ValueType.INT64)
        customer = Entity(name="customer", value_type=ValueType.INT64)

        store = FeatureStore(
            config=RepoConfig(
                registry=os.path.join(temp_dir, "registry.db"),
                project="default",
                provider="gcp",
                online_store=OnlineStoreConfig(
                    local=LocalOnlineStoreConfig(
                        path=os.path.join(temp_dir, "online_store.db"),
                    )
                ),
            )
        )
        store.apply([driver, customer, driver_fv, customer_fv])

        expected_df = get_expected_training_df(
            customer_df,
            customer_fv,
            driver_df,
            driver_fv,
            orders_df,
        )

        sort_keys = [
            ENTITY_DF_EVENT_TIMESTAMP_COL,
            "order_id",
            "driver_id",
            "customer_id",
        ]

        def _canonical(frame):
            # Retrieval does not promise a row order; compare sorted copies.
            return frame.sort_values(by=sort_keys).reset_index(drop=True)

        job_from_sql = store.get_historical_features(
            entity_df=entity_df_query, feature_refs=feature_refs
        )
        actual_df_from_sql_entities = job_from_sql.to_df()
        assert_frame_equal(
            _canonical(expected_df),
            _canonical(actual_df_from_sql_entities),
            check_dtype=False,
        )

        job_from_df = store.get_historical_features(
            entity_df=orders_df, feature_refs=feature_refs
        )
        actual_df_from_df_entities = job_from_df.to_df()
        assert_frame_equal(
            _canonical(expected_df),
            _canonical(actual_df_from_df_entities),
            check_dtype=False,
        )
def apply_total(repo_config: RepoConfig, repo_path: Path, skip_source_validation: bool):
    """Apply the feature repository at ``repo_path`` to the registry and infrastructure.

    Steps, in order:
      1. Change the process working directory to ``repo_path`` (not restored).
      2. Exit with status 1 if the project name fails ``is_valid_name``.
      3. Parse the repo, validate its feature views, optionally validate their
         batch sources, and run entity / timestamp / feature inference.
      4. Register entities, feature tables and feature views. Delete registry
         feature tables and views whose names no longer appear in the repo.
         Commit the registry, then apply feature services.
      5. Ask the provider to update infrastructure for the kept and removed
         tables/views and entities.

    Args:
        repo_config: Configuration of the repository being applied.
        repo_path: Root directory of the feature repository.
        skip_source_validation: If True, the feature views' batch sources are
            not ``validate``-d.
    """
    from colorama import Fore, Style

    os.chdir(repo_path)
    registry_config = repo_config.get_registry_config()
    project = repo_config.project
    if not is_valid_name(project):
        print(
            f"{project} is not valid. Project name should only have "
            f"alphanumerical values and underscores but not start with an underscore."
        )
        sys.exit(1)
    registry = Registry(
        registry_path=registry_config.path,
        repo_path=repo_path,
        cache_ttl=timedelta(seconds=registry_config.cache_ttl_seconds),
    )
    registry._initialize_registry()

    # Presumably parse_repo imports the repo's Python files; suppress .pyc writing
    # while it does. Restore the caller's setting (which may already be True,
    # e.g. under `python -B`) even if parsing, validation or inference raises.
    previous_dont_write_bytecode = sys.dont_write_bytecode
    sys.dont_write_bytecode = True
    try:
        repo = parse_repo(repo_path)
        _validate_feature_views(repo.feature_views)
        data_sources = [t.batch_source for t in repo.feature_views]

        if not skip_source_validation:
            # Make sure the data source used by this feature view is supported by Feast
            for data_source in data_sources:
                data_source.validate(repo_config)

        # Make inferences
        update_entities_with_inferred_types_from_feature_views(
            repo.entities, repo.feature_views, repo_config
        )
        update_data_sources_with_inferred_event_timestamp_col(data_sources, repo_config)
        for view in repo.feature_views:
            view.infer_features_from_batch_source(repo_config)

        # Feature tables and feature views share one namespace of names here.
        repo_table_names = {t.name for t in repo.feature_tables}
        repo_table_names.update(v.name for v in repo.feature_views)

        tables_to_delete = [
            registry_table
            for registry_table in registry.list_feature_tables(project=project)
            if registry_table.name not in repo_table_names
        ]
        views_to_delete = [
            registry_view
            for registry_view in registry.list_feature_views(project=project)
            if registry_view.name not in repo_table_names
        ]
    finally:
        sys.dont_write_bytecode = previous_dont_write_bytecode

    # All registry writes below use commit=False and are committed together.
    for entity in repo.entities:
        registry.apply_entity(entity, project=project, commit=False)
        click.echo(
            f"Registered entity {Style.BRIGHT + Fore.GREEN}{entity.name}{Style.RESET_ALL}"
        )

    # Delete tables that should not exist
    for registry_table in tables_to_delete:
        registry.delete_feature_table(registry_table.name, project=project, commit=False)
        click.echo(
            f"Deleted feature table {Style.BRIGHT + Fore.GREEN}{registry_table.name}{Style.RESET_ALL} from registry"
        )

    # Create tables that should exist
    for table in repo.feature_tables:
        registry.apply_feature_table(table, project, commit=False)
        click.echo(
            f"Registered feature table {Style.BRIGHT + Fore.GREEN}{table.name}{Style.RESET_ALL}"
        )

    # Delete views that should not exist
    for registry_view in views_to_delete:
        registry.delete_feature_view(registry_view.name, project=project, commit=False)
        click.echo(
            f"Deleted feature view {Style.BRIGHT + Fore.GREEN}{registry_view.name}{Style.RESET_ALL} from registry"
        )

    # Create views that should exist
    for view in repo.feature_views:
        registry.apply_feature_view(view, project, commit=False)
        click.echo(
            f"Registered feature view {Style.BRIGHT + Fore.GREEN}{view.name}{Style.RESET_ALL}"
        )

    registry.commit()

    apply_feature_services(registry, project, repo)

    infra_provider = get_provider(repo_config, repo_path)

    all_to_delete: List[Union[FeatureTable, FeatureView]] = [
        *tables_to_delete,
        *views_to_delete,
    ]
    all_to_keep: List[Union[FeatureTable, FeatureView]] = [
        *repo.feature_tables,
        *repo.feature_views,
    ]

    # Entities present in the registry but no longer defined in the repo.
    repo_entities_names = {e.name for e in repo.entities}
    entities_to_delete: List[Entity] = [
        registry_entity
        for registry_entity in registry.list_entities(project=project)
        if registry_entity.name not in repo_entities_names
    ]
    entities_to_keep: List[Entity] = repo.entities

    for name in [table.name for table in repo.feature_tables] + [
        view.name for view in repo.feature_views
    ]:
        click.echo(
            f"Deploying infrastructure for {Style.BRIGHT + Fore.GREEN}{name}{Style.RESET_ALL}"
        )
    for name in [view.name for view in views_to_delete] + [
        table.name for table in tables_to_delete
    ]:
        click.echo(
            f"Removing infrastructure for {Style.BRIGHT + Fore.GREEN}{name}{Style.RESET_ALL}"
        )

    infra_provider.update_infra(
        project,
        tables_to_delete=all_to_delete,
        tables_to_keep=all_to_keep,
        entities_to_delete=entities_to_delete,
        entities_to_keep=entities_to_keep,
        partial=False,
    )