Example #1
def test_get_batch_with_split_on_multi_column_values(test_df):
    split_df = PandasExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_df,
            splitter_method="_split_on_multi_column_values",
            splitter_kwargs={
                "column_names": ["y", "m", "d"],
                "partition_definition": {
                    "y": 2020,
                    "m": 1,
                    "d": 5,
                },
            },
        ))
    assert split_df.shape == (4, 10)
    assert (split_df.date == datetime.date(2020, 1, 5)).all()

    with pytest.raises(ValueError):
        split_df = PandasExecutionEngine().get_batch_data(
            RuntimeDataBatchSpec(
                batch_data=test_df,
                splitter_method="_split_on_multi_column_values",
                splitter_kwargs={
                    "column_names": ["I", "dont", "exist"],
                    "partition_definition": {
                        "y": 2020,
                        "m": 1,
                        "d": 5,
                    },
                },
            ))
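All of these snippets come from the same Great Expectations execution-engine test modules, so their module-level imports are omitted. A minimal import block matching the names used in the examples might look like the one below; the exact module paths are an assumption and can differ between Great Expectations releases.

import datetime
import random

import numpy as np
import pandas as pd
import pytest
from pyspark.sql import functions as F

import great_expectations.exceptions as ge_exceptions
from great_expectations.core.batch_spec import RuntimeDataBatchSpec
from great_expectations.execution_engine import (
    PandasExecutionEngine,
    SparkDFExecutionEngine,
)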
Example #2
def test_get_batch_with_split_on_column_value(test_df):
    split_df = PandasExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_df,
            splitter_method="_split_on_column_value",
            splitter_kwargs={
                "column_name": "batch_id",
                "partition_definition": {
                    "batch_id": 2
                },
            },
        ))
    assert split_df.shape == (12, 10)
    assert (split_df.batch_id == 2).all()

    split_df = PandasExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_df,
            splitter_method="_split_on_column_value",
            splitter_kwargs={
                "column_name": "date",
                "partition_definition": {
                    "date": datetime.date(2020, 1, 30)
                },
            },
        ))
    assert split_df.shape == (3, 10)
Example #3
def test_get_batch_with_split_on_multi_column_values(test_sparkdf):
    split_df = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_sparkdf,
            splitter_method="_split_on_multi_column_values",
            splitter_kwargs={
                "column_names": ["y", "m", "d"],
                "partition_definition": {
                    "y": 2020,
                    "m": 1,
                    "d": 5,
                },
            },
        ))
    assert split_df.count() == 4
    assert len(split_df.columns) == 10
    collected = split_df.collect()
    for val in collected:
        assert val.date == datetime.date(2020, 1, 5)

    with pytest.raises(ValueError):
        split_df = SparkDFExecutionEngine().get_batch_data(
            RuntimeDataBatchSpec(
                batch_data=test_sparkdf,
                splitter_method="_split_on_multi_column_values",
                splitter_kwargs={
                    "column_names": ["I", "dont", "exist"],
                    "partition_definition": {
                        "y": 2020,
                        "m": 1,
                        "d": 5,
                    },
                },
            ))
Example #4
def test_sample_using_md5(test_df):
    with pytest.raises(ge_exceptions.ExecutionEngineError):
        sampled_df = PandasExecutionEngine().get_batch_data(
            RuntimeDataBatchSpec(
                batch_data=test_df,
                sampling_method="_sample_using_hash",
                sampling_kwargs={
                    "column_name": "date",
                    "hash_function_name": "I_am_not_valid",
                },
            ))

    sampled_df = PandasExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_df,
            sampling_method="_sample_using_hash",
            sampling_kwargs={
                "column_name": "date",
                "hash_function_name": "md5"
            },
        ))
    assert sampled_df.shape == (10, 10)
    assert sampled_df.date.isin([
        datetime.date(2020, 1, 15),
        datetime.date(2020, 1, 29),
    ]).all()
Example #5
def test_get_batch_with_split_on_column_value(test_sparkdf):
    split_df = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_sparkdf,
            splitter_method="_split_on_column_value",
            splitter_kwargs={
                "column_name": "batch_id",
                "partition_definition": {"batch_id": 2},
            },
        )
    )
    # The input dataframe keeps its original shape; the split result is checked below.
    assert test_sparkdf.count() == 120
    assert len(test_sparkdf.columns) == 10
    collected = split_df.collect()
    for val in collected:
        assert val.batch_id == 2

    split_df = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_sparkdf,
            splitter_method="_split_on_column_value",
            splitter_kwargs={
                "column_name": "date",
                "partition_definition": {"date": datetime.date(2020, 1, 30)},
            },
        )
    )
    assert split_df.count() == 3
    assert len(split_df.columns) == 10
Example #6
def test_get_batch_with_split_on_hashed_column(test_df):
    with pytest.raises(ge_exceptions.ExecutionEngineError):
        split_df = PandasExecutionEngine().get_batch_data(
            RuntimeDataBatchSpec(
                batch_data=test_df,
                splitter_method="_split_on_hashed_column",
                splitter_kwargs={
                    "column_name": "favorite_color",
                    "hash_digits": 1,
                    "partition_definition": {
                        "hash_value": "a",
                    },
                    "hash_function_name": "I_am_not_valid",
                },
            ))

    split_df = PandasExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_df,
            splitter_method="_split_on_hashed_column",
            splitter_kwargs={
                "column_name": "favorite_color",
                "hash_digits": 1,
                "partition_definition": {
                    "hash_value": "a",
                },
                "hash_function_name": "sha256",
            },
        ))
    assert split_df.shape == (8, 10)
Example #7
def test_get_batch_data(test_df):
    split_df = PandasExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(batch_data=test_df))
    assert split_df.shape == (120, 10)

    # No dataset passed to RuntimeDataBatchSpec
    with pytest.raises(ge_exceptions.InvalidBatchSpecError):
        PandasExecutionEngine().get_batch_data(RuntimeDataBatchSpec())
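Most of these tests receive a pytest fixture named test_df (or a Spark counterpart, test_sparkdf) that yields a 120-row, 10-column dataframe whose columns include id, date, y, m, d, timestamp, batch_id, and favorite_color. The sketch below is only a hypothetical stand-in showing that shape; the real fixture's values are tuned so that the specific counts asserted in these examples hold, and this sketch does not reproduce all of them.

@pytest.fixture
def test_df():
    # Hypothetical stand-in: 120 rows and 10 columns with the column names
    # referenced by the examples.  Values are illustrative only.
    n = 120
    dates = [
        datetime.date(2020, 1, 1) + datetime.timedelta(days=i % 30) for i in range(n)
    ]
    return pd.DataFrame(
        {
            "id": range(n),
            "date": dates,
            "y": [d.year for d in dates],
            "m": [d.month for d in dates],
            "d": [d.day for d in dates],
            "timestamp": [datetime.datetime(d.year, d.month, d.day) for d in dates],
            "batch_id": [i // 12 for i in range(n)],
            "favorite_color": [random.choice(["red", "green", "blue"]) for _ in range(n)],
            "session_id": list(range(n)),  # placeholder column to reach 10
            "event_type": ["click"] * n,  # placeholder column to reach 10
        }
    )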
Example #8
def test_split_on_multi_column_values_and_sample_using_random(test_sparkdf):
    returned_df = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_sparkdf,
            splitter_method="_split_on_multi_column_values",
            splitter_kwargs={
                "column_names": ["y", "m", "d"],
                "partition_definition": {
                    "y": 2020,
                    "m": 1,
                    "d": 5,
                },
            },
            sampling_method="_sample_using_random",
            sampling_kwargs={
                "p": 0.5,
            },
        )
    )

    # The test dataframe contains 10 columns and 120 rows.
    assert len(returned_df.columns) == 10
    # The number of returned rows corresponding to the value of "partition_definition" above is 4.
    assert 0 <= returned_df.count() <= 4
    # The sampling probability "p" used in "SparkDFExecutionEngine._sample_using_random()" is 0.5 (the equivalent of a
    # fair coin with a 50% chance of coming up "heads").  Hence, on average we should get 50% of the rows, which is
    # 2; however, for such a small sample (of 4 rows), the number of rows returned by an individual run can deviate from
    # this average.  Still, in the majority of trials, the number of rows should not be fewer than 2 or greater than 3.
    # The assertion in the next line, supporting this reasoning, is commented out to ensure zero failures.  Developers
    # are encouraged to uncomment it whenever the "_sample_using_random" feature is the main focus of a given effort.
    # assert 2 <= returned_df.count() <= 3

    for val in returned_df.collect():
        assert val.date == datetime.date(2020, 1, 5)
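The "majority of trials" estimate in the comments above follows from a short binomial calculation: each of the 4 candidate rows is kept independently with probability 0.5, so the returned row count is Binomial(4, 0.5) and P(2 <= count <= 3) = (6 + 4) / 16 = 0.625. A quick stand-alone check, not part of the test:

from math import comb

# Probability that a Binomial(4, 0.5) draw lands in [2, 3].
p = sum(comb(4, k) * 0.5**k * 0.5 ** (4 - k) for k in (2, 3))
print(p)  # 0.625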
Example #9
def test_get_batch_with_split_on_whole_table(test_sparkdf):
    test_sparkdf = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_sparkdf, splitter_method="_split_on_whole_table"
        )
    )
    assert test_sparkdf.count() == 120
    assert len(test_sparkdf.columns) == 10
Example #10
def test_sample_using_random(test_sparkdf):
    sampled_df = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(batch_data=test_sparkdf,
                             sampling_method="_sample_using_random"))
    # The test dataframe contains 10 columns and 120 rows.
    assert len(sampled_df.columns) == 10
    assert 0 <= sampled_df.count() <= 120
    # The sampling probability "p" used in "SparkDFExecutionEngine._sample_using_random()" is 0.1 (the equivalent of an
    # unfair coin with a 10% chance of coming up "heads").  Hence, we should never get as many as 20% of the rows.
    assert sampled_df.count() < 25
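The bound of fewer than 25 rows in the last assertion is conservative: with 120 rows each kept independently with probability 0.1, the returned count is Binomial(120, 0.1) with mean 12, and the probability of drawing 25 or more rows is far below 1%. A quick stand-alone check, not part of the test:

from math import comb

# Probability that a Binomial(120, 0.1) draw reaches 25 or more.
p_tail = sum(comb(120, k) * 0.1**k * 0.9 ** (120 - k) for k in range(25, 121))
print(p_tail)  # well under 1%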
Example #11
def test_sample_using_a_list(test_df):
    sampled_df = PandasExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_df,
            sampling_method="_sample_using_a_list",
            sampling_kwargs={
                "column_name": "id",
                "value_list": [3, 5, 7, 11],
            },
        ))
    assert sampled_df.shape == (4, 10)
Example #12
def test_sample_using_md5_wrong_hash_function_name(test_sparkdf):
    with pytest.raises(ge_exceptions.ExecutionEngineError):
        sampled_df = SparkDFExecutionEngine().get_batch_data(
            RuntimeDataBatchSpec(
                batch_data=test_sparkdf,
                sampling_method="_sample_using_hash",
                sampling_kwargs={
                    "column_name": "date",
                    "hash_function_name": "I_wont_work",
                },
            ))
Example #13
def test_sample_using_mod(test_df):
    sampled_df = PandasExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_df,
            sampling_method="_sample_using_mod",
            sampling_kwargs={
                "column_name": "id",
                "mod": 5,
                "value": 4,
            },
        ))
    assert sampled_df.shape == (24, 10)
Example #14
def test_sample_using_a_list(test_sparkdf):
    sampled_df = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_sparkdf,
            sampling_method="_sample_using_a_list",
            sampling_kwargs={
                "column_name": "id",
                "value_list": [3, 5, 7, 11],
            },
        ))
    assert sampled_df.count() == 4
    assert len(sampled_df.columns) == 10
Example #15
def test_get_batch_with_split_on_converted_datetime(test_sparkdf):
    split_df = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_sparkdf,
            splitter_method="_split_on_converted_datetime",
            splitter_kwargs={
                "column_name": "timestamp",
                "partition_definition": {"timestamp": "2020-01-03"},
            },
        )
    )
    assert split_df.count() == 2
    assert len(split_df.columns) == 10
Example #16
def test_sample_using_mod(test_sparkdf):
    sampled_df = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_sparkdf,
            sampling_method="_sample_using_mod",
            sampling_kwargs={
                "column_name": "id",
                "mod": 5,
                "value": 4,
            },
        ))
    assert sampled_df.count() == 24
    assert len(sampled_df.columns) == 10
Example #17
def test_get_batch_with_split_on_converted_datetime(test_df):
    split_df = PandasExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_df,
            splitter_method="_split_on_converted_datetime",
            splitter_kwargs={
                "column_name": "timestamp",
                "partition_definition": {
                    "timestamp": "2020-01-30"
                },
            },
        ))
    assert split_df.shape == (3, 10)
Example #18
def test_get_batch_with_split_on_mod_integer(test_df):
    split_df = PandasExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_df,
            splitter_method="_split_on_mod_integer",
            splitter_kwargs={
                "column_name": "id",
                "mod": 10,
                "partition_definition": {
                    "id": 5
                },
            },
        ))
    assert split_df.shape == (12, 10)
    assert split_df.id.min() == 5
    assert split_df.id.max() == 115
Example #19
def test_basic_setup(spark_session):
    pd_df = pd.DataFrame({"x": range(10)})
    df = spark_session.createDataFrame(
        [
            tuple(None if isinstance(x, (float, int)) and np.isnan(x) else x
                  for x in record.tolist())
            for record in pd_df.to_records(index=False)
        ],
        pd_df.columns.tolist(),
    )
    batch_data = SparkDFExecutionEngine().get_batch_data(
        batch_spec=RuntimeDataBatchSpec(
            batch_data=df,
            data_asset_name="DATA_ASSET",
        ))
    assert batch_data is not None
Example #20
def test_get_batch_with_split_on_hashed_column_incorrect_hash_function_name(
    test_sparkdf,
):
    with pytest.raises(ge_exceptions.ExecutionEngineError):
        split_df = SparkDFExecutionEngine().get_batch_data(
            RuntimeDataBatchSpec(
                batch_data=test_sparkdf,
                splitter_method="_split_on_hashed_column",
                splitter_kwargs={
                    "column_name": "favorite_color",
                    "hash_digits": 1,
                    "hash_function_name": "I_wont_work",
                    "partition_definition": {
                        "hash_value": "a",
                    },
                },
            ))
Example #21
def test_get_batch_with_split_on_hashed_column(test_sparkdf):
    split_df = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_sparkdf,
            splitter_method="_split_on_hashed_column",
            splitter_kwargs={
                "column_name": "favorite_color",
                "hash_digits": 1,
                "hash_function_name": "sha256",
                "partition_definition": {
                    "hash_value": "a",
                },
            },
        ))
    assert split_df.count() == 8
    assert len(split_df.columns) == 10
Example #22
def test_sample_using_md5(test_sparkdf):
    sampled_df = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_sparkdf,
            sampling_method="_sample_using_hash",
            sampling_kwargs={
                "column_name": "date",
                "hash_function_name": "md5",
            },
        )
    )
    assert sampled_df.count() == 10
    assert len(sampled_df.columns) == 10

    collected = sampled_df.collect()
    for val in collected:
        assert val.date in [datetime.date(2020, 1, 15), datetime.date(2020, 1, 29)]
Example #23
def test_get_batch_with_split_on_divided_integer(test_sparkdf):
    split_df = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_sparkdf,
            splitter_method="_split_on_divided_integer",
            splitter_kwargs={
                "column_name": "id",
                "divisor": 10,
                "partition_definition": {"id": 5},
            },
        )
    )
    assert split_df.count() == 10
    assert len(split_df.columns) == 10
    max_result = split_df.select([F.max("id")])
    assert max_result.collect()[0]["max(id)"] == 59
    min_result = split_df.select([F.min("id")])
    assert min_result.collect()[0]["min(id)"] == 50
Example #24
def test_get_batch_with_split_on_divided_integer_and_sample_on_list(test_df):
    split_df = PandasExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_df,
            splitter_method="_split_on_divided_integer",
            splitter_kwargs={
                "column_name": "id",
                "divisor": 10,
                "partition_definition": {
                    "id": 5
                },
            },
            sampling_method="_sample_using_mod",
            sampling_kwargs={
                "column_name": "id",
                "mod": 5,
                "value": 4,
            },
        ))
    assert split_df.shape == (2, 10)
    assert split_df.id.min() == 54
    assert split_df.id.max() == 59
Example #25
def test_get_batch_empty_sampler(test_sparkdf):
    sampled_df = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(batch_data=test_sparkdf, sampling_method=None)
    )
    assert sampled_df.count() == 120
    assert len(sampled_df.columns) == 10
Example #26
def test_sample_using_random(test_df):
    random.seed(1)
    sampled_df = PandasExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(batch_data=test_df,
                             sampling_method="_sample_using_random"))
    assert sampled_df.shape == (13, 10)
Example #27
def test_get_batch_with_split_on_whole_table(test_df):
    split_df = PandasExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(batch_data=test_df,
                             splitter_method="_split_on_whole_table"))
    assert split_df.shape == (120, 10)
Example #28
def test_get_batch_data(test_sparkdf):
    test_sparkdf = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(batch_data=test_sparkdf, data_asset_name="DATA_ASSET")
    )
    assert test_sparkdf.count() == 120
    assert len(test_sparkdf.columns) == 10