Ejemplo n.º 1
0
    def build_batch_spec(self, batch_definition: BatchDefinition) -> BatchSpec:
        """
        Assemble the BatchSpec that corresponds to a BatchDefinition.

        The spec parameters generated for the batch_definition form the base. Pass-through
        parameters are layered on top of them: first those configured on the Data Connector,
        then those carried by the batch_definition itself (which win on key collisions).

        Args:
            batch_definition (BatchDefinition): definition of the batch whose spec is requested
        Returns:
            BatchSpec object constructed from the merged parameters

        """
        spec_kwargs: dict = self._generate_batch_spec_parameters_from_batch_definition(
            batch_definition=batch_definition
        )

        # Work on a private copy so the Data Connector's own configuration is never mutated.
        merged_passthrough: dict = deepcopy(self.batch_spec_passthrough)

        # Values supplied on the batch_definition override the Data Connector level values.
        definition_passthrough = batch_definition.batch_spec_passthrough
        if isinstance(definition_passthrough, dict):
            merged_passthrough.update(definition_passthrough)

        spec_kwargs.update(merged_passthrough)
        return BatchSpec(**spec_kwargs)
Ejemplo n.º 2
0
def batch_fixture() -> Batch:
    """
    Build a fully populated Batch for use in unit tests.

    The returned Batch wraps a small two-column DataFrame together with a BatchRequest,
    a BatchDefinition, a BatchSpec and BatchMarkers, all filled with fixed placeholder values.
    """
    frame: pd.DataFrame = pd.DataFrame(
        {"a": [1, 5, 22, 3, 5, 10], "b": [1, 2, 3, 4, 5, 6]}
    )

    # The request and the definition point at the same datasource / connector / asset.
    asset_coordinates: dict = {
        "datasource_name": "my_datasource",
        "data_connector_name": "my_data_connector",
        "data_asset_name": "my_data_asset_name",
    }
    request: BatchRequest = BatchRequest(**asset_coordinates)
    definition: BatchDefinition = BatchDefinition(
        batch_identifiers=IDDict({"id": "A"}),
        **asset_coordinates,
    )

    return Batch(
        data=frame,
        batch_request=request,
        batch_definition=definition,
        batch_spec=BatchSpec(path="/some/path/some.file"),
        batch_markers=BatchMarkers(ge_load_time="FAKE_LOAD_TIME"),
    )
Ejemplo n.º 3
0
 def build_batch_spec(self, batch_definition: BatchDefinition) -> BatchSpec:
     """Create a BatchSpec from the parameters generated for batch_definition.

     When the batch_definition carries a dict of pass-through parameters, they are merged
     over the generated parameters before the BatchSpec is constructed.
     """
     spec_kwargs: dict = self._generate_batch_spec_parameters_from_batch_definition(
         batch_definition=batch_definition
     )
     passthrough = batch_definition.batch_spec_passthrough
     # Anything other than a dict (e.g. None) is ignored.
     if isinstance(passthrough, dict):
         spec_kwargs.update(passthrough)
     return BatchSpec(**spec_kwargs)
    def sample_using_random(
        self,
        execution_engine: "SqlAlchemyExecutionEngine",  # noqa: F821
        batch_spec: BatchSpec,
        where_clause: Optional[Selectable] = None,
    ) -> Selectable:
        """Sample using random data with configuration provided via the batch_spec.

        Note: where_clause needs to be included at this stage since we use the where clause
        to determine the total number of rows to use in determining the rows returned in the
        sample fraction.

        A row count is executed against the database first; the returned query then orders
        the (optionally filtered) rows randomly and limits them to round(p * row_count).

        Args:
            execution_engine: Engine used to connect to the database.
            batch_spec: Batch specification describing the batch of interest. Must contain
                "table_name" and "sampling_kwargs" (with a "p" entry); "schema_name" is optional.
            where_clause: Optional clause used in WHERE clause. Typically generated by a splitter.
                When omitted, no filter is applied.

        Returns:
            Sqlalchemy selectable.

        Raises:
            KeyError: If "table_name", "sampling_kwargs" or its "p" entry is missing from batch_spec.
        """

        # TODO: AJB 20220429 WARNING THIS METHOD IS NOT COVERED BY TESTS

        table_name: str = batch_spec["table_name"]
        # The same table reference backs both the count query and the sampling query.
        table = sa.table(table_name, schema=batch_spec.get("schema_name", None))

        count_query = sa.select([sa.func.count()]).select_from(table)
        sample_query = sa.select("*").select_from(table)

        # Only attach a WHERE when one was supplied: handing None to .where() is coerced
        # to a NULL predicate, which would filter out every row.
        if where_clause is not None:
            count_query = count_query.where(where_clause)
            sample_query = sample_query.where(where_clause)

        num_rows: int = execution_engine.engine.execute(count_query).scalar()
        # A falsy p (None or 0) falls back to sampling the full row count.
        p: float = batch_spec["sampling_kwargs"]["p"] or 1.0
        sample_size: int = round(p * num_rows)
        return sample_query.order_by(sa.func.random()).limit(sample_size)
Ejemplo n.º 5
0
    def verify_batch_spec_sampling_kwargs_exists(self, batch_spec: BatchSpec) -> None:
        """Ensure the batch_spec carries a sampling_kwargs entry.

        An entry that is present but empty is accepted; only a missing entry (or one
        explicitly set to None) is rejected.

        Args:
            batch_spec: Specification that is expected to contain sampling_kwargs.

        Returns:
            None

        Raises:
            SamplerError: If sampling_kwargs is absent or None.
        """
        sampling_kwargs = batch_spec.get("sampling_kwargs")
        if sampling_kwargs is not None:
            return
        raise ge_exceptions.SamplerError(
            "Please make sure to provide sampling_kwargs in addition to your sampling_method."
        )
Ejemplo n.º 6
0
    def __init__(
        self,
        data,
        batch_request: BatchRequest = None,
        batch_definition: BatchDefinition = None,
        batch_spec: BatchSpec = None,
        batch_markers: BatchMarkers = None,
        # The remaining parameters are for backward compatibility.
        data_context=None,
        datasource_name=None,
        batch_parameters=None,
        batch_kwargs=None,
    ):
        """Store the data and its describing metadata, filling in defaults for omitted pieces.

        Args:
            data: The batch data; stored as-is.
            batch_request: Defaults to an empty dict when None.
            batch_definition: Defaults to an empty IDDict when None.
            batch_spec: Defaults to an empty BatchSpec when None.
            batch_markers: Defaults to BatchMarkers holding a "ge_load_time" stamped with
                the current UTC time when None.
            data_context: Kept for backward compatibility; stored as-is.
            datasource_name: Kept for backward compatibility; stored as-is.
            batch_parameters: Kept for backward compatibility; stored as-is.
            batch_kwargs: Kept for backward compatibility; a falsy value is replaced by an
                empty BatchKwargs.
        """
        self._data = data
        self._batch_request = dict() if batch_request is None else batch_request
        self._batch_definition = (
            IDDict() if batch_definition is None else batch_definition
        )
        self._batch_spec = BatchSpec() if batch_spec is None else batch_spec

        if batch_markers is None:
            # Stamp the batch with its load time in UTC.
            load_time: str = datetime.datetime.now(datetime.timezone.utc).strftime(
                "%Y%m%dT%H%M%S.%fZ"
            )
            batch_markers = BatchMarkers({"ge_load_time": load_time})
        self._batch_markers = batch_markers

        # Legacy attributes retained for backward compatibility.
        self._data_context = data_context
        self._datasource_name = datasource_name
        self._batch_parameters = batch_parameters
        self._batch_kwargs = batch_kwargs if batch_kwargs else BatchKwargs()
Ejemplo n.º 7
0
def test_sample_using_limit_builds_correct_query_where_clause_none(
    dialect_name: GESqlDialect, dialect_name_to_sql_statement, sa
):
    """What does this test and why?

    sample_using_limit should build the appropriate query based on input parameters.
    This tests dialects that differ from the standard dialect, not each dialect exhaustively.

    Args:
        dialect_name: Dialect the query is built and compiled for.
        dialect_name_to_sql_statement: Callable returning the expected SQL statement
            for a given dialect.
        sa: The sqlalchemy module (used here for create_engine and the Dialect annotation).
    """

    # 1. Setup
    class MockSqlAlchemyExecutionEngine:
        """Minimal stand-in exposing only dialect_name and dialect, without a real database."""

        def __init__(self, dialect_name: GESqlDialect):
            self._dialect_name = dialect_name
            self._connection_string = self.dialect_name_to_connection_string(
                dialect_name
            )

        # Stub connection strings; they are only used to look up a dialect, never to connect.
        DIALECT_TO_CONNECTION_STRING_STUB: dict = {
            GESqlDialect.POSTGRESQL: "postgresql://",
            GESqlDialect.MYSQL: "mysql+pymysql://",
            GESqlDialect.ORACLE: "oracle+cx_oracle://",
            GESqlDialect.MSSQL: "mssql+pyodbc://",
            GESqlDialect.SQLITE: "sqlite:///",
            GESqlDialect.BIGQUERY: "bigquery://",
            GESqlDialect.SNOWFLAKE: "snowflake://",
            GESqlDialect.REDSHIFT: "redshift+psycopg2://",
            GESqlDialect.AWSATHENA: "awsathena+rest://@athena.us-east-1.amazonaws.com/some_test_db?s3_staging_dir=s3://some-s3-path/",
            GESqlDialect.DREMIO: "dremio://",
            GESqlDialect.TERADATASQL: "teradatasql://",
            GESqlDialect.TRINO: "trino://",
            GESqlDialect.HIVE: "hive://",
        }

        @property
        def dialect_name(self) -> str:
            return self._dialect_name.value

        def dialect_name_to_connection_string(self, dialect_name: GESqlDialect) -> str:
            return self.DIALECT_TO_CONNECTION_STRING_STUB.get(dialect_name)

        _BIGQUERY_MODULE_NAME = "sqlalchemy_bigquery"

        @property
        def dialect(self) -> sa.engine.Dialect:
            # TODO: AJB 20220512 move this dialect retrieval to a separate class from the SqlAlchemyExecutionEngine
            #  and then use it here.
            dialect_name: GESqlDialect = self._dialect_name
            if dialect_name == GESqlDialect.ORACLE:
                # noinspection PyUnresolvedReferences
                return import_library_module(
                    module_name="sqlalchemy.dialects.oracle"
                ).dialect()
            elif dialect_name == GESqlDialect.SNOWFLAKE:
                # noinspection PyUnresolvedReferences
                return import_library_module(
                    module_name="snowflake.sqlalchemy.snowdialect"
                ).dialect()
            elif dialect_name == GESqlDialect.DREMIO:
                # WARNING: Dremio Support is experimental, functionality is not fully under test
                # noinspection PyUnresolvedReferences
                return import_library_module(
                    module_name="sqlalchemy_dremio.pyodbc"
                ).dialect()
            # NOTE: AJB 20220512 Redshift dialect is not yet fully supported.
            # The below throws an `AttributeError: type object 'RedshiftDialect_psycopg2' has no attribute 'positional'`
            # elif dialect_name == "redshift":
            #     return import_library_module(
            #         module_name="sqlalchemy_redshift.dialect"
            #     ).RedshiftDialect
            elif dialect_name == GESqlDialect.BIGQUERY:
                # noinspection PyUnresolvedReferences
                return import_library_module(
                    module_name=self._BIGQUERY_MODULE_NAME
                ).dialect()
            elif dialect_name == GESqlDialect.TERADATASQL:
                # WARNING: Teradata Support is experimental, functionality is not fully under test
                # noinspection PyUnresolvedReferences
                return import_library_module(
                    module_name="teradatasqlalchemy.dialect"
                ).dialect()
            else:
                # Every other dialect is resolved by sqlalchemy from the stub connection string.
                return sa.create_engine(self._connection_string).dialect

    mock_execution_engine: MockSqlAlchemyExecutionEngine = (
        MockSqlAlchemyExecutionEngine(dialect_name=dialect_name)
    )

    data_sampler: SqlAlchemyDataSampler = SqlAlchemyDataSampler()

    # 2. Create query using sampler
    table_name: str = "test_table"
    batch_spec: BatchSpec = BatchSpec(
        table_name=table_name,
        schema_name="test_schema_name",
        sampling_method="sample_using_limit",
        sampling_kwargs={"n": 10},
    )
    query = data_sampler.sample_using_limit(
        execution_engine=mock_execution_engine, batch_spec=batch_spec, where_clause=None
    )

    # 3. Normalize the result: some dialects yield a ready-made string, the others a
    # sqlalchemy object that still has to be compiled for the dialect under test.
    query_str: str
    if isinstance(query, str):
        query_str = clean_query_for_comparison(query)
    else:
        query_str = clean_query_for_comparison(
            str(
                query.compile(
                    dialect=mock_execution_engine.dialect,
                    compile_kwargs={"literal_binds": True},
                )
            )
        )

    expected: str = clean_query_for_comparison(
        dialect_name_to_sql_statement(dialect_name)
    )

    assert query_str == expected
    def sample_using_limit(
        self,
        execution_engine: "SqlAlchemyExecutionEngine",  # noqa: F821
        batch_spec: BatchSpec,
        where_clause: Optional[Selectable] = None,
    ) -> Union[str, BinaryExpression, BooleanClauseList]:
        """Build a query returning at most ``sampling_kwargs["n"]`` rows of the batch's table.

        The where_clause has to be applied here rather than by the caller, because
        SqlAlchemy treats LIMIT differently from ordinary WHERE clauses. The dialect of
        the execution engine is consulted because some databases need special handling:
        for Oracle and MSSQL the query is rendered to a string, every other dialect
        receives a sqlalchemy object.

        Args:
            execution_engine: Engine used to connect to the database.
            batch_spec: Batch specification describing the batch of interest.
            where_clause: Optional clause used in WHERE clause. Typically generated by a splitter.

        Returns:
            A query as a string or sqlalchemy object.
        """

        # Without a split clause every row must qualify.
        if where_clause is None:
            where_clause = (
                sa.text("1 = 1")
                if execution_engine.dialect_name == "sqlite"
                else sa.true()
            )

        table_name: str = batch_spec["table_name"]
        dialect_name: str = execution_engine.dialect_name

        # All dialects share the same un-limited SELECT; they only differ in how the
        # row limit is attached to it.
        source_table = sa.table(table_name, schema=batch_spec.get("schema_name", None))
        unlimited_select: Selectable = (
            sa.select("*").select_from(source_table).where(where_clause)
        )

        if dialect_name == GESqlDialect.ORACLE.value:
            # TODO: AJB 20220429 WARNING THIS oracle dialect METHOD IS NOT COVERED BY TESTS
            # limit doesn't compile properly for oracle, so a rownum condition is appended
            # to the rendered query string instead.
            rendered_oracle: str = str(
                unlimited_select.compile(
                    dialect=execution_engine.dialect,
                    compile_kwargs={"literal_binds": True},
                )
            )
            return rendered_oracle + (
                "\nAND ROWNUM <= %d" % batch_spec["sampling_kwargs"]["n"]
            )

        limited_select: Selectable = unlimited_select.limit(
            batch_spec["sampling_kwargs"]["n"]
        )
        if dialect_name != GESqlDialect.MSSQL.value:
            return limited_select

        # MSSQL: the limit parameter is not rendered into the compiled query, so it is
        # substituted into the query string by hand after validating it.
        rendered_mssql: str = str(
            limited_select.compile(
                dialect=execution_engine.dialect,
                compile_kwargs={"literal_binds": True},
            )
        )
        n: Union[str, int] = batch_spec["sampling_kwargs"]["n"]
        self._validate_mssql_limit_param(n)
        return rendered_mssql.replace("?", str(n))
Ejemplo n.º 9
0
    def sample_using_limit(
        self,
        execution_engine: "SqlAlchemyExecutionEngine",  # noqa: F821
        batch_spec: BatchSpec,
        where_clause: Optional[Selectable] = None,
    ) -> Union[str, BinaryExpression, BooleanClauseList]:
        """Sample using a limit with configuration provided via the batch_spec.

        Note: where_clause needs to be included at this stage since SqlAlchemy's semantics
        for LIMIT are different than normal WHERE clauses.

        Also this requires an engine to find the dialect since certain databases require
        different handling: for oracle and mssql the query is rendered to a string, all
        other dialects receive a sqlalchemy object.

        Args:
            execution_engine: Engine used to connect to the database.
            batch_spec: Batch specification describing the batch of interest. Must contain
                "table_name" and "sampling_kwargs" (with an "n" entry); "schema_name" is optional.
            where_clause: Optional clause used in WHERE clause. Typically generated by a splitter.

        Returns:
            A query as a string or sqlalchemy object.

        Raises:
            InvalidConfigError: For mssql, if "n" is neither an int nor a string of digits.
            KeyError: If "table_name", "sampling_kwargs" or its "n" entry is missing from batch_spec.
        """

        # Split clause should be permissive of all values if not supplied.
        # The check must be against None: SQLAlchemy expressions define their own truthiness
        # (an equality comparison of two different operands is falsy, most other expressions
        # raise TypeError), so `not where_clause` could silently drop or reject a real filter.
        if where_clause is None:
            where_clause = True

        table_name: str = batch_spec["table_name"]

        # SQLalchemy's semantics for LIMIT are different than normal WHERE clauses,
        # so the business logic for building the query needs to be different.
        dialect: str = execution_engine.engine.dialect.name.lower()

        # Every dialect starts from the same un-limited SELECT.
        base_query: Selectable = (
            sa.select("*")
            .select_from(
                sa.table(table_name, schema=batch_spec.get("schema_name", None))
            )
            .where(where_clause)
        )

        if dialect == "oracle":
            # TODO: AJB 20220429 WARNING THIS oracle dialect METHOD IS NOT COVERED BY TESTS
            # limit doesn't compile properly for oracle so we will append rownum to query string later
            # Compile against the underlying sqlalchemy engine, as the mssql branch does.
            query: str = str(
                base_query.compile(
                    execution_engine.engine,
                    compile_kwargs={"literal_binds": True},
                )
            )
            query += "\nAND ROWNUM <= %d" % batch_spec["sampling_kwargs"]["n"]
            return query
        elif dialect == "mssql":
            # TODO: AJB 20220429 WARNING THIS mssql dialect METHOD IS NOT COVERED BY TESTS
            # Validate n before it reaches the query, so a bad value surfaces as a
            # configuration error rather than as a failure inside sqlalchemy.
            n: Union[str, int] = batch_spec["sampling_kwargs"]["n"]
            if not isinstance(n, (str, int)):
                raise ge_exceptions.InvalidConfigError(
                    "Please specify your sampling kwargs 'n' parameter as a string or int."
                )
            if isinstance(n, str) and not n.isdigit():
                raise ge_exceptions.InvalidConfigError(
                    "If specifying your sampling kwargs 'n' parameter as a string please ensure it is "
                    "parseable as an integer."
                )
            string_of_query: str = str(
                base_query.limit(n).compile(
                    execution_engine.engine,
                    compile_kwargs={"literal_binds": True},
                )
            )
            # TODO: AJB 20220504 REMOVE THIS HACK!
            # This hack is here because the limit parameter is not substituted during query.compile()
            string_of_query = string_of_query.replace("?", str(n))
            return string_of_query
        else:
            return base_query.limit(batch_spec["sampling_kwargs"]["n"])