예제 #1
0
def test_non_local_feature_repo() -> None:
    """
    Build an example feature repo in a temporary directory, run ``apply`` on
    it, check that the resulting store lists three feature views, then run
    ``teardown``. Both CLI invocations must exit with status 0.
    """
    cli = CliRunner()
    with tempfile.TemporaryDirectory() as tmp_dir:

        # Lay out a minimal feature repo inside the throwaway directory.
        repo_root = Path(tmp_dir)

        store_yaml = dedent("""
        project: foo
        registry: data/registry.db
        provider: local
        online_store:
            path: data/online_store.db
        """)
        (repo_root / "feature_store.yaml").write_text(store_yaml)

        # The repo's only module is a copy of the example definitions that
        # sit next to this test file.
        example_source = Path(__file__).parent / "example_feature_repo_1.py"
        (repo_root / "example.py").write_text(example_source.read_text())

        apply_result = cli.run(["apply"], cwd=repo_root)
        assertpy.assert_that(apply_result.returncode).is_equal_to(0)

        feature_store = FeatureStore(repo_path=str(repo_root))
        assertpy.assert_that(feature_store.list_feature_views()).is_length(3)

        teardown_result = cli.run(["teardown"], cwd=repo_root)
        assertpy.assert_that(teardown_result.returncode).is_equal_to(0)
예제 #2
0
파일: cli.py 프로젝트: qooba/feast
def feature_view_list(ctx: click.Context):
    """
    Print a plain table of every registered view in the repo that the CLI
    context points at. This covers regular, request and on-demand feature
    views, with one row each: the view's name, its entities ("n/a" when there
    are none) and the name of the view's class.
    """
    repo_dir = ctx.obj["CHDIR"]
    cli_check_repo(repo_dir)
    feature_store = FeatureStore(repo_path=str(repo_dir))

    def _entities_of(view):
        # Regular feature views carry their entities directly. On-demand
        # views pick up the entities of any inputs that are feature views.
        found = set()
        if isinstance(view, FeatureView):
            found.update(view.entities)
        elif isinstance(view, OnDemandFeatureView):
            for source in view.inputs.values():
                if isinstance(source, FeatureView):
                    found.update(source.entities)
        return found

    all_views = (
        list(feature_store.list_feature_views())
        + list(feature_store.list_request_feature_views())
        + list(feature_store.list_on_demand_feature_views())
    )

    rows = []
    for view in all_views:
        view_entities = _entities_of(view)
        rows.append([view.name, view_entities or "n/a", type(view).__name__])

    from tabulate import tabulate

    column_titles = ["NAME", "ENTITIES", "TYPE"]
    print(tabulate(rows, headers=column_titles, tablefmt="plain"))
예제 #3
0
def feature_view_list():
    """
    List all feature views of the repository in the current working directory.

    Prints a plain-format table with one row per feature view, showing the
    view's name and its entities.
    """
    # Resolve the working directory once so that the directory validated by
    # cli_check_repo is the same one the FeatureStore is loaded from.
    repo = Path.cwd()
    cli_check_repo(repo)
    store = FeatureStore(repo_path=str(repo))
    table = [
        [feature_view.name, feature_view.entities]
        for feature_view in store.list_feature_views()
    ]

    # NOTE(review): kept as a function-local import like the original;
    # presumably to defer loading tabulate until it is needed — confirm.
    from tabulate import tabulate

    print(tabulate(table, headers=["NAME", "ENTITIES"], tablefmt="plain"))
예제 #4
0
def feature_view_list(ctx: click.Context):
    """
    Print a plain table of the feature views registered in the repo that the
    CLI context points at, one row per view: its name and its entities.
    """
    repo_dir = ctx.obj["CHDIR"]
    cli_check_repo(repo_dir)
    feature_store = FeatureStore(repo_path=str(repo_dir))
    rows = [[view.name, view.entities] for view in feature_store.list_feature_views()]

    from tabulate import tabulate

    column_titles = ["NAME", "ENTITIES"]
    print(tabulate(rows, headers=column_titles, tablefmt="plain"))
예제 #5
0
def benchmark_writes():
    """
    Benchmark online-store writes through the "gcp" provider.

    Registers a driver entity and a driver hourly stats feature view in a
    freshly named project. It then generates 14 days of synthetic rows for
    ids 0..99, converts them to protos and writes them with
    ``online_write_batch`` while updating a tqdm progress bar. Finally it
    tears the provider infrastructure for the registered views and entities
    down.

    Teardown runs in a ``finally`` clause. It is therefore still attempted
    if generating, converting or writing the data raises. Previously any such
    failure skipped teardown and left the provisioned infrastructure behind.
    """
    # Random suffix so each run registers under its own project name.
    project_id = "test" + "".join(
        random.choice(string.ascii_lowercase + string.digits) for _ in range(10)
    )

    with tempfile.TemporaryDirectory() as temp_dir:
        store = FeatureStore(
            config=RepoConfig(
                registry=os.path.join(temp_dir, "registry.db"),
                project=project_id,
                provider="gcp",
            )
        )

        # This is just to set data source to something, we're not reading from parquet source here.
        parquet_path = os.path.join(temp_dir, "data.parquet")

        driver = Entity(name="driver_id", value_type=ValueType.INT64)
        table = create_driver_hourly_stats_feature_view(
            create_driver_hourly_stats_source(parquet_path=parquet_path)
        )
        store.apply([table, driver])

        # Fetched before the try block so the finally clause can always use it.
        provider = store._get_provider()

        try:
            end_date = datetime.utcnow()
            start_date = end_date - timedelta(days=14)
            customers = list(range(100))
            data = create_driver_hourly_stats_df(customers, start_date, end_date)

            # Show the data for reference
            print(data)
            proto_data = _convert_arrow_to_proto(
                pa.Table.from_pandas(data), table, ["driver_id"]
            )

            # Write it
            with tqdm(total=len(proto_data)) as progress:
                provider.online_write_batch(
                    project=store.project,
                    table=table,
                    data=proto_data,
                    progress=progress.update,
                )
        finally:
            # Tear down whatever apply() provisioned, whether or not the
            # write succeeded. Note: an exception raised here would replace
            # one raised in the try body.
            registry_tables = store.list_feature_views()
            registry_entities = store.list_entities()
            provider.teardown_infra(
                store.project, tables=registry_tables, entities=registry_entities
            )