def test_to_make_sure_splitter_and_sampler_methods_are_optional(
    test_cases_for_sql_data_connector_sqlite_execution_engine,
):
    execution_engine = test_cases_for_sql_data_connector_sqlite_execution_engine

    batch_data, batch_markers = execution_engine.get_batch_data_and_markers(
        batch_spec=SqlAlchemyDatasourceBatchSpec(
            {
                "table_name": "table_partitioned_by_date_column__A",
                "batch_identifiers": {},
                "sampling_method": "_sample_using_mod",
                "sampling_kwargs": {
                    "column_name": "id",
                    "mod": 10,
                    "value": 8,
                },
            }
        )
    )
    execution_engine.load_batch_data("__", batch_data)
    validator = Validator(execution_engine)
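    # _sample_using_mod keeps rows where id % 10 == 8; 12 of the table's 120 rows match.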
    assert len(validator.head(fetch_all=True)) == 12

    batch_data, batch_markers = execution_engine.get_batch_data_and_markers(
        batch_spec=SqlAlchemyDatasourceBatchSpec(
            {
                "table_name": "table_partitioned_by_date_column__A",
                "batch_identifiers": {},
            }
        )
    )
    execution_engine.load_batch_data("__", batch_data)
    validator = Validator(execution_engine)
    assert len(validator.head(fetch_all=True)) == 120

    batch_data, batch_markers = execution_engine.get_batch_data_and_markers(
        batch_spec=SqlAlchemyDatasourceBatchSpec(
            {
                "table_name": "table_partitioned_by_date_column__A",
                "batch_identifiers": {},
                "splitter_method": "_split_on_whole_table",
                "splitter_kwargs": {},
            }
        )
    )

    execution_engine.load_batch_data("__", batch_data)
    validator = Validator(execution_engine)
    assert len(validator.head(fetch_all=True)) == 120
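
# Presumably registered as a pytest fixture in the original suite (decorator omitted in this excerpt).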
def ge_validator_sqlalchemy() -> Validator:
    validator = Validator(
        execution_engine=SqlAlchemyExecutionEngine(
            connection_string="postgresql://localhost:5432/test"),
        batches=[
            Batch(
                data=None,
                batch_request=BatchRequest(
                    datasource_name="my_postgresql_datasource",
                    data_connector_name="whole_table",
                    data_asset_name="foo2",
                ),
                batch_definition=BatchDefinition(
                    datasource_name="my_postgresql_datasource",
                    data_connector_name="whole_table",
                    data_asset_name="foo2",
                    batch_identifiers=IDDict(),
                ),
                batch_spec=SqlAlchemyDatasourceBatchSpec({
                    "data_asset_name": "foo2",
                    "table_name": "foo2",
                    "batch_identifiers": {},
                    "schema_name": "public",
                    "type": "table",
                }),
            )
        ],
    )
    return validator
Example #3
    def build_batch_spec(
        self, batch_definition: BatchDefinition
    ) -> SqlAlchemyDatasourceBatchSpec:
        """
        Build BatchSpec from batch_definition by calling DataConnector's build_batch_spec function.

        Args:
            batch_definition (BatchDefinition): to be used to build batch_spec

        Returns:
            BatchSpec built from batch_definition
        """
        batch_spec: BatchSpec = super().build_batch_spec(
            batch_definition=batch_definition
        )

        data_asset_name: str = batch_definition.data_asset_name
        if (
            data_asset_name in self.data_assets
            and self.data_assets[data_asset_name].get("batch_spec_passthrough")
            and isinstance(
                self.data_assets[data_asset_name].get("batch_spec_passthrough"), dict
            )
        ):
            batch_spec.update(
                self.data_assets[data_asset_name]["batch_spec_passthrough"]
            )

        return SqlAlchemyDatasourceBatchSpec(batch_spec)
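
# The variant below extends this logic: batch_spec_passthrough from the
# BatchDefinition is merged in as well and takes precedence over the data asset's.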
    def build_batch_spec(
            self, batch_definition: BatchDefinition
    ) -> SqlAlchemyDatasourceBatchSpec:
        """
        Build BatchSpec from batch_definition by calling DataConnector's build_batch_spec function.

        Args:
            batch_definition (BatchDefinition): to be used to build batch_spec

        Returns:
            BatchSpec built from batch_definition
        """

        data_asset_name: str = batch_definition.data_asset_name
        if (data_asset_name in self.assets
                and self.assets[data_asset_name].get("batch_spec_passthrough")
                and isinstance(
                    self.assets[data_asset_name].get("batch_spec_passthrough"),
                    dict)):
            # batch_spec_passthrough from data_asset
            batch_spec_passthrough = deepcopy(
                self.assets[data_asset_name]["batch_spec_passthrough"])
            batch_definition_batch_spec_passthrough = (deepcopy(
                batch_definition.batch_spec_passthrough) or {})
            # batch_spec_passthrough from Batch Definition supersedes batch_spec_passthrough from data_asset
            batch_spec_passthrough.update(
                batch_definition_batch_spec_passthrough)
            batch_definition.batch_spec_passthrough = batch_spec_passthrough

        batch_spec: BatchSpec = super().build_batch_spec(
            batch_definition=batch_definition)

        return SqlAlchemyDatasourceBatchSpec(batch_spec)
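
# A minimal sketch (hypothetical dict values, not from the source) of the precedence
# rule implemented above: keys from the BatchDefinition's batch_spec_passthrough
# override the same keys from the data asset's batch_spec_passthrough.
asset_passthrough = {"create_temp_table": False, "schema_name": "public"}
definition_passthrough = {"create_temp_table": True}
merged = {**asset_passthrough, **definition_passthrough}
assert merged == {"create_temp_table": True, "schema_name": "public"}
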
def test_sampling_method__limit(
    test_cases_for_sql_data_connector_sqlite_execution_engine,
):
    execution_engine = test_cases_for_sql_data_connector_sqlite_execution_engine

    batch_data, batch_markers = execution_engine.get_batch_data_and_markers(
        batch_spec=SqlAlchemyDatasourceBatchSpec(
            {
                "table_name": "table_partitioned_by_date_column__A",
                "batch_identifiers": {},
                "splitter_method": "_split_on_whole_table",
                "splitter_kwargs": {},
                "sampling_method": "_sample_using_limit",
                "sampling_kwargs": {"n": 20},
            }
        )
    )

    batch = Batch(data=batch_data)

    validator = Validator(execution_engine, batches=[batch])
    assert len(validator.head(fetch_all=True)) == 20

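    # The 20 sampled rows are not all from 2020-01-02, so this expectation fails.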
    assert not validator.expect_column_values_to_be_in_set(
        "date", value_set=["2020-01-02"]
    ).success
def test_sqlite_split(
    test_case, sa, in_memory_sqlite_taxi_ten_trips_per_month_execution_engine
):
    """What does this test and why?
    Splitters should work with SQLite.
    """

    engine: SqlAlchemyExecutionEngine = (
        in_memory_sqlite_taxi_ten_trips_per_month_execution_engine
    )

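    # batch_identifiers pins this batch to the first expected split value.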
    batch_spec: SqlAlchemyDatasourceBatchSpec = SqlAlchemyDatasourceBatchSpec(
        table_name="test",
        schema_name="main",
        splitter_method=test_case.splitter_method_name,
        splitter_kwargs=test_case.splitter_kwargs,
        batch_identifiers={"pickup_datetime": test_case.expected_pickup_datetimes[0]},
    )
    batch_data: SqlAlchemyBatchData = engine.get_batch_data(batch_spec=batch_spec)

    # Right number of rows?
    num_rows: int = batch_data.execution_engine.engine.execute(
        sa.select([sa.func.count()]).select_from(batch_data.selectable)
    ).scalar()
    assert num_rows == test_case.num_expected_rows_in_first_batch_definition
Example #7
def test_instantiation_via_url_and_retrieve_data_with_other_dialect(sa):
    """Ensure that we can still retrieve data when the dialect is not recognized."""

    # 1. Create engine with sqlite db
    db_file = file_relative_path(
        __file__,
        os.path.join("..", "test_sets",
                     "test_cases_for_sql_data_connector.db"),
    )
    my_execution_engine = SqlAlchemyExecutionEngine(url="sqlite:///" + db_file)
    assert my_execution_engine.connection_string is None
    assert my_execution_engine.credentials is None
    assert my_execution_engine.url[
        -36:] == "test_cases_for_sql_data_connector.db"

    # 2. Change dialect to one not listed in GESqlDialect
    my_execution_engine.engine.dialect.name = "other_dialect"

    # 3. Get data
    num_rows_in_sample: int = 10
    batch_data, _ = my_execution_engine.get_batch_data_and_markers(
        batch_spec=SqlAlchemyDatasourceBatchSpec(
            table_name="table_partitioned_by_date_column__A",
            sampling_method="_sample_using_limit",
            sampling_kwargs={"n": num_rows_in_sample},
        ))

    # 4. Assert dialect and data are as expected

    assert batch_data.dialect == GESqlDialect.OTHER

    my_execution_engine.load_batch_data("__", batch_data)
    validator = Validator(my_execution_engine)
    assert len(validator.head(fetch_all=True)) == num_rows_in_sample
def test_instantiation_via_connection_string(sa, test_db_connection_string):
    my_execution_engine = SqlAlchemyExecutionEngine(
        connection_string=test_db_connection_string)
    assert my_execution_engine.connection_string == test_db_connection_string
    assert my_execution_engine.credentials is None
    assert my_execution_engine.url is None

    my_execution_engine.get_batch_data_and_markers(
        batch_spec=SqlAlchemyDatasourceBatchSpec(
            table_name="table_1",
            schema_name="main",
            sampling_method="_sample_using_limit",
            sampling_kwargs={"n": 5},
        ))
def test_instantiation_via_url(sa):
    db_file = file_relative_path(
        __file__,
        os.path.join("..", "test_sets", "test_cases_for_sql_data_connector.db"),
    )
    my_execution_engine = SqlAlchemyExecutionEngine(url="sqlite:///" + db_file)
    assert my_execution_engine.connection_string is None
    assert my_execution_engine.credentials is None
    assert my_execution_engine.url[-36:] == "test_cases_for_sql_data_connector.db"

    my_execution_engine.get_batch_data_and_markers(
        batch_spec=SqlAlchemyDatasourceBatchSpec(
            table_name="table_partitioned_by_date_column__A",
            sampling_method="_sample_using_limit",
            sampling_kwargs={"n": 5},
        )
    )
def test_sampling_method__random(
    test_cases_for_sql_data_connector_sqlite_execution_engine,
):
    execution_engine = test_cases_for_sql_data_connector_sqlite_execution_engine

    batch_data, batch_markers = execution_engine.get_batch_data_and_markers(
        batch_spec=SqlAlchemyDatasourceBatchSpec({
            "table_name": "table_partitioned_by_date_column__A",
            "batch_identifiers": {},
            "splitter_method": "_split_on_whole_table",
            "splitter_kwargs": {},
            "sampling_method": "_sample_using_random",
            "sampling_kwargs": {
                "p": 1.0
            },
        }))

    # random.seed() is no good here: the random number generator lives in the database, not in Python
    # assert len(batch_data.head(fetch_all=True)) == 63
    pass
def test_sqlite_split_and_sample_using_limit(
        sa, in_memory_sqlite_taxi_ten_trips_per_month_execution_engine):
    """What does this test and why?
    Splitters and samplers should work together in SQLite.
    """

    engine: SqlAlchemyExecutionEngine = (
        in_memory_sqlite_taxi_ten_trips_per_month_execution_engine)

    n: int = 3
    batch_spec: SqlAlchemyDatasourceBatchSpec = SqlAlchemyDatasourceBatchSpec(
        table_name="test",
        schema_name="main",
        sampling_method="sample_using_limit",
        sampling_kwargs={"n": n},
        splitter_method="split_on_year",
        splitter_kwargs={"column_name": "pickup_datetime"},
        batch_identifiers={"pickup_datetime": "2018"},
    )
    batch_data: SqlAlchemyBatchData = engine.get_batch_data(
        batch_spec=batch_spec)

    # Right number of rows?
    num_rows: int = batch_data.execution_engine.engine.execute(
        sa.select([sa.func.count()
                   ]).select_from(batch_data.selectable)).scalar()
    assert num_rows == n

    # Right rows?
    rows: List[sa.Row] = batch_data.execution_engine.engine.execute(
        sa.select([sa.text("*")
                   ]).select_from(batch_data.selectable)).fetchall()

    row_dates: List[datetime.datetime] = [
        parse(row["pickup_datetime"]) for row in rows
    ]
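    # The limit keeps the first n rows of the year-2018 split, which in this
    # ten-trips-per-month fixture are the January 2018 trips.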
    for row_date in row_dates:
        assert row_date.month == 1
        assert row_date.year == 2018
Example #12
def test_sqlite_sample_using_limit(sa):

    csv_path: str = file_relative_path(
        os.path.dirname(os.path.dirname(__file__)),
        os.path.join(
            "test_sets",
            "taxi_yellow_tripdata_samples",
            "ten_trips_from_each_month",
            "yellow_tripdata_sample_10_trips_from_each_month.csv",
        ),
    )
    df: pd.DataFrame = pd.read_csv(csv_path)
    engine: SqlAlchemyExecutionEngine = build_sa_engine(df, sa)

    n: int = 10
    batch_spec: SqlAlchemyDatasourceBatchSpec = SqlAlchemyDatasourceBatchSpec(
        table_name="test",
        schema_name="main",
        sampling_method="sample_using_limit",
        sampling_kwargs={"n": n},
    )
    batch_data: SqlAlchemyBatchData = engine.get_batch_data(batch_spec=batch_spec)

    # Right number of rows?
    num_rows: int = batch_data.execution_engine.engine.execute(
        sa.select([sa.func.count()]).select_from(batch_data.selectable)
    ).scalar()
    assert num_rows == n

    # Right rows?
    rows: List[sa.Row] = batch_data.execution_engine.engine.execute(
        sa.select([sa.text("*")]).select_from(batch_data.selectable)
    ).fetchall()

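    # The limit sample takes the first n rows of the table; in this
    # ten-trips-per-month dataset those are the January 2018 trips.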
    row_dates: List[datetime.datetime] = [parse(row["pickup_datetime"]) for row in rows]
    for row_date in row_dates:
        assert row_date.month == 1
        assert row_date.year == 2018
def test_sampling_method__a_list(
    sampler_method_name_prefix,
    test_cases_for_sql_data_connector_sqlite_execution_engine,
):
    execution_engine = test_cases_for_sql_data_connector_sqlite_execution_engine

    batch_data, batch_markers = execution_engine.get_batch_data_and_markers(
        batch_spec=SqlAlchemyDatasourceBatchSpec(
            {
                "table_name": "table_partitioned_by_date_column__A",
                "batch_identifiers": {},
                "splitter_method": "_split_on_whole_table",
                "splitter_kwargs": {},
                "sampling_method": f"{sampler_method_name_prefix}sample_using_a_list",
                "sampling_kwargs": {
                    "column_name": "id",
                    "value_list": [10, 20, 30, 40],
                },
            }
        )
    )
    execution_engine.load_batch_data("__", batch_data)
    validator = Validator(execution_engine)
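    # Only rows whose id appears in value_list survive the sample, hence 4 rows.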
    assert len(validator.head(fetch_all=True)) == 4
Example #14
def test_sample_using_random(sqlite_view_engine, test_df):
    my_execution_engine: SqlAlchemyExecutionEngine = SqlAlchemyExecutionEngine(
        engine=sqlite_view_engine
    )

    p: float
    batch_spec: SqlAlchemyDatasourceBatchSpec
    batch_data: SqlAlchemyBatchData
    num_rows: int
    rows_0: List[tuple]
    rows_1: List[tuple]

    # First, make sure that the degenerate case (p=1.0 over a single row) always returns the same sample.

    test_df_0: pd.DataFrame = test_df.iloc[:1]
    test_df_0.to_sql("test_table_0", con=my_execution_engine.engine)

    p = 1.0
    batch_spec = SqlAlchemyDatasourceBatchSpec(
        table_name="test_table_0",
        schema_name="main",
        sampling_method="_sample_using_random",
        sampling_kwargs={"p": p},
    )

    batch_data = my_execution_engine.get_batch_data(batch_spec=batch_spec)
    num_rows = batch_data.execution_engine.engine.execute(
        sqlalchemy.select([sqlalchemy.func.count()]).select_from(batch_data.selectable)
    ).scalar()
    assert num_rows == round(p * test_df_0.shape[0])

    rows_0: List[tuple] = batch_data.execution_engine.engine.execute(
        sqlalchemy.select([sqlalchemy.text("*")]).select_from(batch_data.selectable)
    ).fetchall()

    batch_data = my_execution_engine.get_batch_data(batch_spec=batch_spec)
    num_rows = batch_data.execution_engine.engine.execute(
        sqlalchemy.select([sqlalchemy.func.count()]).select_from(batch_data.selectable)
    ).scalar()
    assert num_rows == round(p * test_df_0.shape[0])

    rows_1: List[tuple] = batch_data.execution_engine.engine.execute(
        sqlalchemy.select([sqlalchemy.text("*")]).select_from(batch_data.selectable)
    ).fetchall()

    assert len(rows_0) == len(rows_1) == 1

    assert rows_0 == rows_1

    # Second, verify that the realistic case returns a different random sample of rows each time.

    test_df_1: pd.DataFrame = test_df
    test_df_1.to_sql("test_table_1", con=my_execution_engine.engine)

    p = 2.0e-1
    batch_spec = SqlAlchemyDatasourceBatchSpec(
        table_name="test_table_1",
        schema_name="main",
        sampling_method="_sample_using_random",
        sampling_kwargs={"p": p},
    )

    batch_data = my_execution_engine.get_batch_data(batch_spec=batch_spec)
    num_rows = batch_data.execution_engine.engine.execute(
        sqlalchemy.select([sqlalchemy.func.count()]).select_from(batch_data.selectable)
    ).scalar()
    assert num_rows == round(p * test_df_1.shape[0])

    rows_0 = batch_data.execution_engine.engine.execute(
        sqlalchemy.select([sqlalchemy.text("*")]).select_from(batch_data.selectable)
    ).fetchall()

    batch_data = my_execution_engine.get_batch_data(batch_spec=batch_spec)
    num_rows = batch_data.execution_engine.engine.execute(
        sqlalchemy.select([sqlalchemy.func.count()]).select_from(batch_data.selectable)
    ).scalar()
    assert num_rows == round(p * test_df_1.shape[0])

    rows_1 = batch_data.execution_engine.engine.execute(
        sqlalchemy.select([sqlalchemy.text("*")]).select_from(batch_data.selectable)
    ).fetchall()

    assert len(rows_0) == len(rows_1)

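    # Independent random samples over the full table should (almost surely) differ.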
    assert not (rows_0 == rows_1)
def test_instantiation_with_and_without_temp_table(sqlite_view_engine, sa):
    print(get_sqlite_temp_table_names(sqlite_view_engine))
    assert len(get_sqlite_temp_table_names(sqlite_view_engine)) == 1
    assert get_sqlite_temp_table_names(sqlite_view_engine) == {
        "test_temp_view"
    }

    execution_engine: SqlAlchemyExecutionEngine = SqlAlchemyExecutionEngine(
        engine=sqlite_view_engine)
    # When the SqlAlchemyBatchData object is based on a table, a new temp table is NOT created, even if create_temp_table=True
    SqlAlchemyBatchData(
        execution_engine=execution_engine,
        table_name="test_table",
        create_temp_table=True,
    )
    assert len(get_sqlite_temp_table_names(sqlite_view_engine)) == 1

    selectable = sa.select("*").select_from(sa.text("main.test_table"))

    # If create_temp_table=False, a new temp table should NOT be created
    SqlAlchemyBatchData(
        execution_engine=execution_engine,
        selectable=selectable,
        create_temp_table=False,
    )
    assert len(get_sqlite_temp_table_names(sqlite_view_engine)) == 1

    # If create_temp_table=True, a new temp table should be created
    SqlAlchemyBatchData(
        execution_engine=execution_engine,
        selectable=selectable,
        create_temp_table=True,
    )
    assert len(get_sqlite_temp_table_names(sqlite_view_engine)) == 2

    # If create_temp_table=True, a new temp table should be created
    SqlAlchemyBatchData(
        execution_engine=execution_engine,
        selectable=selectable,
        # create_temp_table defaults to True
    )
    assert len(get_sqlite_temp_table_names(sqlite_view_engine)) == 3

    # testing whether schema is supported
    selectable = sa.select("*").select_from(
        sa.table(name="test_table", schema="main"))
    SqlAlchemyBatchData(
        execution_engine=execution_engine,
        selectable=selectable,
        # create_temp_table defaults to True
    )
    assert len(get_sqlite_temp_table_names(sqlite_view_engine)) == 4

    # test schema with execution engine
    # TODO : Will20210222 Add tests for specifying schema with non-sqlite backend that actually supports new schema creation
    my_batch_spec = SqlAlchemyDatasourceBatchSpec(
        **{
            "table_name": "test_table",
            "batch_identifiers": {},
            "schema_name": "main",
        })
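    # get_batch_data_and_markers returns a (batch_data, batch_markers) pair.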
    res = execution_engine.get_batch_data_and_markers(batch_spec=my_batch_spec)
    assert len(res) == 2
def test_sqlite_split(
    taxi_test_cases: TaxiSplittingTestCasesBase,
    sa,
):
    """What does this test and why?
    Splitters should work with SQLite.
    """
    engine: SqlAlchemyExecutionEngine = build_sa_engine(
        taxi_test_cases.test_df, sa)

    test_cases: List[TaxiSplittingTestCase] = taxi_test_cases.test_cases()
    test_case: TaxiSplittingTestCase
    batch_spec: SqlAlchemyDatasourceBatchSpec
    for test_case in test_cases:
        if test_case.table_domain_test_case:
            batch_spec = SqlAlchemyDatasourceBatchSpec(
                table_name="test",
                schema_name="main",
                splitter_method=test_case.splitter_method_name,
                splitter_kwargs=test_case.splitter_kwargs,
                batch_identifiers={},
            )
        else:
            if taxi_test_cases.test_column_name:
                batch_spec = SqlAlchemyDatasourceBatchSpec(
                    table_name="test",
                    schema_name="main",
                    splitter_method=test_case.splitter_method_name,
                    splitter_kwargs=test_case.splitter_kwargs,
                    batch_identifiers={
                        taxi_test_cases.test_column_name:
                        test_case.expected_column_values[0]
                    },
                )
            elif taxi_test_cases.test_column_names:
                column_name: str
                batch_spec = SqlAlchemyDatasourceBatchSpec(
                    table_name="test",
                    schema_name="main",
                    splitter_method=test_case.splitter_method_name,
                    splitter_kwargs=test_case.splitter_kwargs,
                    batch_identifiers={
                        column_name:
                        test_case.expected_column_values[0][column_name]
                        for column_name in taxi_test_cases.test_column_names
                    },
                )
            else:
                raise ValueError(
                    "Missing test_column_names or test_column_names attribute."
                )

        batch_data: SqlAlchemyBatchData = engine.get_batch_data(
            batch_spec=batch_spec)

        # Right number of rows?
        num_rows: int = batch_data.execution_engine.engine.execute(
            sa.select([sa.func.count()
                       ]).select_from(batch_data.selectable)).scalar()
        # noinspection PyUnresolvedReferences
        assert num_rows == test_case.num_expected_rows_in_first_batch_definition