Example #1
def _upload_entity_df(
    entity_df: Union[pd.DataFrame, str],
    snowflake_conn: SnowflakeConnection,
    config: RepoConfig,
    table_name: str,
) -> None:

    if isinstance(entity_df, pd.DataFrame):
        # Write the data from the DataFrame to the table
        write_pandas(
            snowflake_conn,
            entity_df,
            table_name,
            auto_create_table=True,
            create_temp_table=True,
        )

        return None
    elif isinstance(entity_df, str):
        # If the entity_df is a string (SQL query), create a temporary Snowflake table from it
        query = f'CREATE TEMPORARY TABLE "{table_name}" AS ({entity_df})'
        execute_snowflake_statement(snowflake_conn, query)

        return None
    else:
        raise InvalidEntityType(type(entity_df))
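As a usage sketch (not part of the original snippet): the helper accepts either a ready DataFrame or a SQL string, so both calls below land in a temporary table. The connection parameters are placeholders, and since the body above never reads `config`, a None stands in for the RepoConfig here.

import pandas as pd
import snowflake.connector

# Assumed credentials -- replace with real values
conn = snowflake.connector.connect(
    account="my_account",
    user="my_user",
    password="my_password",
    warehouse="MY_WH",
    database="MY_DB",
    schema="PUBLIC",
)

entity_df = pd.DataFrame({
    "driver_id": [1001, 1002],
    "event_timestamp": pd.to_datetime(["2022-05-01", "2022-05-02"], utc=True),
})

_upload_entity_df(entity_df, conn, None, "FEAST_ENTITY_DF")  # DataFrame path
_upload_entity_df('SELECT * FROM "DRIVERS"', conn, None, "FEAST_ENTITY_DF_2")  # SQL path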
Example #2
def _get_entity_df_event_timestamp_range(
    entity_df: Union[pd.DataFrame, str],
    entity_df_event_timestamp_col: str,
    snowflake_conn: SnowflakeConnection,
) -> Tuple[datetime, datetime]:
    if isinstance(entity_df, pd.DataFrame):
        entity_df_event_timestamp = entity_df.loc[
            :, entity_df_event_timestamp_col
        ].infer_objects()
        if pd.api.types.is_string_dtype(entity_df_event_timestamp):
            entity_df_event_timestamp = pd.to_datetime(
                entity_df_event_timestamp, utc=True)
        entity_df_event_timestamp_range = (
            entity_df_event_timestamp.min().to_pydatetime(),
            entity_df_event_timestamp.max().to_pydatetime(),
        )
    elif isinstance(entity_df, str):
        # If the entity_df is a string (SQL query), determine the timestamp
        # range with a MIN/MAX aggregation over the query itself
        query = (
            f'SELECT MIN("{entity_df_event_timestamp_col}") AS "min_value", '
            f'MAX("{entity_df_event_timestamp_col}") AS "max_value" '
            f"FROM ({entity_df})"
        )
        results = execute_snowflake_statement(snowflake_conn, query).fetchall()

        entity_df_event_timestamp_range = cast(
            Tuple[datetime, datetime], results[0]
        )
    else:
        raise InvalidEntityType(type(entity_df))

    return entity_df_event_timestamp_range
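The DataFrame branch can be exercised on its own; a minimal, self-contained rerun of that logic (the timestamps are made up) shows the string-to-datetime coercion and the min/max extraction:

import pandas as pd

df = pd.DataFrame({"event_timestamp": ["2022-01-03", "2022-01-01", "2022-01-02"]})

col = df.loc[:, "event_timestamp"].infer_objects()
if pd.api.types.is_string_dtype(col):
    # string timestamps are coerced to timezone-aware datetimes first
    col = pd.to_datetime(col, utc=True)

print(col.min().to_pydatetime(), col.max().to_pydatetime())
# 2022-01-01 00:00:00+00:00 2022-01-03 00:00:00+00:00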
Example #3
    def to_snowflake(self, table_name: str) -> None:
        """Save dataset as a new Snowflake table"""
        if self.on_demand_feature_views is not None:
            transformed_df = self.to_df()

            write_pandas(
                self.snowflake_conn,
                transformed_df,
                table_name,
                auto_create_table=True,
            )

            return None

        with self._query_generator() as query:
            query = f'CREATE TABLE IF NOT EXISTS "{table_name}" AS ({query});\n'

            execute_snowflake_statement(self.snowflake_conn, query)
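In practice this method is reached through a retrieval job; a hedged sketch of that flow, with the repo path and the feature reference invented for illustration:

from feast import FeatureStore

store = FeatureStore(repo_path=".")  # assumes a repo configured for Snowflake

job = store.get_historical_features(
    entity_df=entity_df,                  # DataFrame or SQL string
    features=["driver_stats:avg_trips"],  # hypothetical feature reference
)
# Materialize the point-in-time join server-side instead of pulling it locally
job.to_snowflake("DRIVER_TRAINING_SET")

Keeping the result in Snowflake avoids a round-trip through pandas; the one exception is the on-demand branch above, which has to transform locally before writing back with write_pandas.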
Example #4
    def _to_arrow_internal(self) -> pa.Table:
        with self._query_generator() as query:
            pa_table = execute_snowflake_statement(
                self.snowflake_conn, query
            ).fetch_arrow_all()

            if pa_table:
                return pa_table
            else:
                # fetch_arrow_all() returns None for an empty result set, so
                # re-run the query and build a zero-row table from the
                # cursor's column metadata
                empty_result = execute_snowflake_statement(self.snowflake_conn, query)
                return pa.Table.from_pandas(
                    pd.DataFrame(columns=[md.name for md in empty_result.description])
                )
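The fallback branch exists because fetch_arrow_all() yields None for an empty result set. The zero-row construction is easy to check in isolation (column names invented):

import pandas as pd
import pyarrow as pa

empty = pa.Table.from_pandas(pd.DataFrame(columns=["driver_id", "avg_trips"]))
print(empty.num_rows, empty.column_names)
# 0 ['driver_id', 'avg_trips']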
Example #5
    def _to_df_internal(self) -> pd.DataFrame:
        with self._query_generator() as query:
            df = execute_snowflake_statement(
                self.snowflake_conn, query
            ).fetch_pandas_all()

        return df
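fetch_pandas_all() pulls the entire result set into a single DataFrame via Arrow, so it is convenient but memory-bound; the same call works on a plain connector cursor, sketched here against an assumed connection and table:

cur = conn.cursor()                      # 'conn' as in the Example #1 sketch
cur.execute('SELECT * FROM "MY_TABLE"')  # hypothetical table
df = cur.fetch_pandas_all()              # whole result set as one DataFrame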
Example #6
    def get_table_column_names_and_types(
        self, config: RepoConfig
    ) -> Iterable[Tuple[str, str]]:
        """
        Returns a mapping of column names to types for this snowflake source.

        Args:
            config: A RepoConfig describing the feature repo
        """

        from feast.infra.offline_stores.snowflake import SnowflakeOfflineStoreConfig
        from feast.infra.utils.snowflake_utils import (
            execute_snowflake_statement,
            get_snowflake_conn,
        )

        assert isinstance(config.offline_store, SnowflakeOfflineStoreConfig)

        snowflake_conn = get_snowflake_conn(config.offline_store)

        if self.database and self.table:
            query = f'SELECT * FROM "{self.database}"."{self.schema}"."{self.table}" LIMIT 1'
        elif self.table:
            query = f'SELECT * FROM "{self.table}" LIMIT 1'
        else:
            query = f"SELECT * FROM ({self.query}) LIMIT 1"

        result = execute_snowflake_statement(
            snowflake_conn, query
        ).fetch_pandas_all()

        if not result.empty:
            metadata = result.dtypes.apply(str)
            return list(zip(metadata.index, metadata))
        else:
            raise ValueError(f"The following source:\n{query}\n ... is empty")
Example #7
    def to_arrow_chunks(
        self, arrow_options: Optional[Dict] = None
    ) -> Optional[List]:
        with self._query_generator() as query:
            arrow_batches = execute_snowflake_statement(
                self.snowflake_conn, query
            ).get_result_batches()

        return arrow_batches
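get_result_batches() returns a list of ResultBatch objects whose data has not been fetched yet, which makes the method suited to chunked or parallel consumption. A sketch of draining the batches (the `job` object is the same assumption as in the Example #3 sketch):

batches = job.to_arrow_chunks()
if batches is not None:
    for batch in batches:
        table = batch.to_arrow()  # materialize only this chunk
        print(batch.rowcount, table.num_rows)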
Example #8
def _get_entity_schema(
    entity_df: Union[pd.DataFrame, str],
    snowflake_conn: SnowflakeConnection,
    config: RepoConfig,
) -> Dict[str, np.dtype]:
    if isinstance(entity_df, pd.DataFrame):
        return dict(zip(entity_df.columns, entity_df.dtypes))
    else:
        # If the entity_df is a SQL query, fetch a single row and take the
        # dtypes pandas infers for it
        query = f"SELECT * FROM ({entity_df}) LIMIT 1"
        limited_entity_df = execute_snowflake_statement(
            snowflake_conn, query
        ).fetch_pandas_all()

        return dict(zip(limited_entity_df.columns, limited_entity_df.dtypes))
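The DataFrame branch is just a column-to-dtype zip, which runs standalone (frame contents invented):

import pandas as pd

df = pd.DataFrame({
    "driver_id": [1001],
    "event_timestamp": pd.to_datetime(["2022-05-01"]),
})
print(dict(zip(df.columns, df.dtypes)))
# {'driver_id': dtype('int64'), 'event_timestamp': dtype('<M8[ns]')}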