def test_get_batch_with_split_on_multi_column_values(test_df):
    split_df = PandasExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_df,
            splitter_method="_split_on_multi_column_values",
            splitter_kwargs={
                "column_names": ["y", "m", "d"],
                "partition_definition": {
                    "y": 2020,
                    "m": 1,
                    "d": 5,
                },
            },
        )
    )
    assert split_df.shape == (4, 10)
    assert (split_df.date == datetime.date(2020, 1, 5)).all()

    # Splitting on columns that do not exist should raise.
    with pytest.raises(ValueError):
        split_df = PandasExecutionEngine().get_batch_data(
            RuntimeDataBatchSpec(
                batch_data=test_df,
                splitter_method="_split_on_multi_column_values",
                splitter_kwargs={
                    "column_names": ["I", "dont", "exist"],
                    "partition_definition": {
                        "y": 2020,
                        "m": 1,
                        "d": 5,
                    },
                },
            )
        )

def test_get_batch_with_split_on_column_value(test_df):
    split_df = PandasExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_df,
            splitter_method="_split_on_column_value",
            splitter_kwargs={
                "column_name": "batch_id",
                "partition_definition": {"batch_id": 2},
            },
        )
    )
    assert split_df.shape == (12, 10)
    assert (split_df.batch_id == 2).all()

    split_df = PandasExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_df,
            splitter_method="_split_on_column_value",
            splitter_kwargs={
                "column_name": "date",
                "partition_definition": {"date": datetime.date(2020, 1, 30)},
            },
        )
    )
    assert split_df.shape == (3, 10)

def test_get_batch_with_split_on_multi_column_values(test_sparkdf):
    split_df = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_sparkdf,
            splitter_method="_split_on_multi_column_values",
            splitter_kwargs={
                "column_names": ["y", "m", "d"],
                "partition_definition": {
                    "y": 2020,
                    "m": 1,
                    "d": 5,
                },
            },
        )
    )
    assert split_df.count() == 4
    assert len(split_df.columns) == 10
    collected = split_df.collect()
    for val in collected:
        assert val.date == datetime.date(2020, 1, 5)

    # Splitting on columns that do not exist should raise.
    with pytest.raises(ValueError):
        split_df = SparkDFExecutionEngine().get_batch_data(
            RuntimeDataBatchSpec(
                batch_data=test_sparkdf,
                splitter_method="_split_on_multi_column_values",
                splitter_kwargs={
                    "column_names": ["I", "dont", "exist"],
                    "partition_definition": {
                        "y": 2020,
                        "m": 1,
                        "d": 5,
                    },
                },
            )
        )

def test_sample_using_md5(test_df):
    # An unrecognized hash function name should raise.
    with pytest.raises(ge_exceptions.ExecutionEngineError):
        sampled_df = PandasExecutionEngine().get_batch_data(
            RuntimeDataBatchSpec(
                batch_data=test_df,
                sampling_method="_sample_using_hash",
                sampling_kwargs={
                    "column_name": "date",
                    "hash_function_name": "I_am_not_valid",
                },
            )
        )

    sampled_df = PandasExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_df,
            sampling_method="_sample_using_hash",
            sampling_kwargs={
                "column_name": "date",
                "hash_function_name": "md5",
            },
        )
    )
    assert sampled_df.shape == (10, 10)
    assert sampled_df.date.isin(
        [
            datetime.date(2020, 1, 15),
            datetime.date(2020, 1, 29),
        ]
    ).all()

def test_get_batch_with_split_on_column_value(test_sparkdf):
    split_df = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_sparkdf,
            splitter_method="_split_on_column_value",
            splitter_kwargs={
                "column_name": "batch_id",
                "partition_definition": {"batch_id": 2},
            },
        )
    )
    # The input DataFrame itself is left intact by the split.
    assert test_sparkdf.count() == 120
    assert len(test_sparkdf.columns) == 10
    collected = split_df.collect()
    for val in collected:
        assert val.batch_id == 2

    split_df = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_sparkdf,
            splitter_method="_split_on_column_value",
            splitter_kwargs={
                "column_name": "date",
                "partition_definition": {"date": datetime.date(2020, 1, 30)},
            },
        )
    )
    assert split_df.count() == 3
    assert len(split_df.columns) == 10

def test_get_batch_with_split_on_hashed_column(test_df):
    # An unrecognized hash function name should raise.
    with pytest.raises(ge_exceptions.ExecutionEngineError):
        split_df = PandasExecutionEngine().get_batch_data(
            RuntimeDataBatchSpec(
                batch_data=test_df,
                splitter_method="_split_on_hashed_column",
                splitter_kwargs={
                    "column_name": "favorite_color",
                    "hash_digits": 1,
                    "partition_definition": {
                        "hash_value": "a",
                    },
                    "hash_function_name": "I_am_not_valid",
                },
            )
        )

    split_df = PandasExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_df,
            splitter_method="_split_on_hashed_column",
            splitter_kwargs={
                "column_name": "favorite_color",
                "hash_digits": 1,
                "partition_definition": {
                    "hash_value": "a",
                },
                "hash_function_name": "sha256",
            },
        )
    )
    assert split_df.shape == (8, 10)

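# A self-contained sketch (not the engine's code) of the bucketing scheme the
# splitter name implies, assumed here to compare the trailing `hash_digits`
# hex characters of each value's digest against the requested "hash_value".
# With hash_digits=1 there are 16 buckets, so a 120-row frame averages
# 120 / 16 = 7.5 rows per bucket, consistent with the 8 rows asserted above.
def test_hashed_column_bucketing_sketch():
    import hashlib
    from collections import Counter

    values = [f"color_{i}" for i in range(120)]  # hypothetical stand-in data
    counts = Counter(
        hashlib.sha256(value.encode()).hexdigest()[-1] for value in values
    )
    # Every row lands in exactly one of the 16 hex-digit buckets.
    assert sum(counts.values()) == 120
    assert set(counts) <= set("0123456789abcdef")
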
def test_get_batch_data(test_df):
    split_df = PandasExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(batch_data=test_df)
    )
    assert split_df.shape == (120, 10)

    # No dataset passed to RuntimeDataBatchSpec
    with pytest.raises(ge_exceptions.InvalidBatchSpecError):
        PandasExecutionEngine().get_batch_data(RuntimeDataBatchSpec())

def test_split_on_multi_column_values_and_sample_using_random(test_sparkdf):
    returned_df = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_sparkdf,
            splitter_method="_split_on_multi_column_values",
            splitter_kwargs={
                "column_names": ["y", "m", "d"],
                "partition_definition": {
                    "y": 2020,
                    "m": 1,
                    "d": 5,
                },
            },
            sampling_method="_sample_using_random",
            sampling_kwargs={
                "p": 0.5,
            },
        )
    )
    # The test dataframe contains 10 columns and 120 rows.
    assert len(returned_df.columns) == 10
    # The number of rows matching the "partition_definition" above is 4.
    assert 0 <= returned_df.count() <= 4
    # The sampling probability "p" used in
    # "SparkDFExecutionEngine._sample_using_random()" is 0.5 (the equivalent of
    # a fair coin with a 50% chance of coming up "heads"). Hence, on average we
    # should get 50% of the rows, i.e. 2; however, for such a small sample (of
    # 4 rows), the count returned by an individual run can deviate from this
    # average. Still, in the majority of trials, the number of rows should be
    # between 2 and 3. The assertion on the next line, supporting this
    # reasoning, is commented out to ensure zero failures. Developers are
    # encouraged to uncomment it whenever the "_sample_using_random" feature is
    # the main focus of a given effort.
    # assert 2 <= returned_df.count() <= 3
    for val in returned_df.collect():
        assert val.date == datetime.date(2020, 1, 5)

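# Quantifying the comment above with a self-contained sketch: the number of
# rows retained by "_sample_using_random" with p=0.5 over 4 candidate rows is
# Binomial(4, 0.5), so P(2 <= count <= 3) = (C(4,2) + C(4,3)) / 2**4 = 0.625,
# a majority of trials but far from all, which is why the strict assertion
# above stays commented out.
def test_random_sampling_majority_probability_sketch():
    from math import comb

    n, p = 4, 0.5
    prob = sum(comb(n, k) * p**k * (1 - p) ** (n - k) for k in (2, 3))
    assert abs(prob - 0.625) < 1e-12
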
def test_get_batch_with_split_on_whole_table(test_sparkdf):
    test_sparkdf = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_sparkdf, splitter_method="_split_on_whole_table"
        )
    )
    assert test_sparkdf.count() == 120
    assert len(test_sparkdf.columns) == 10

def test_sample_using_random(test_sparkdf):
    sampled_df = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_sparkdf, sampling_method="_sample_using_random"
        )
    )
    # The test dataframe contains 10 columns and 120 rows.
    assert len(sampled_df.columns) == 10
    assert 0 <= sampled_df.count() <= 120
    # The sampling probability "p" used in
    # "SparkDFExecutionEngine._sample_using_random()" is 0.1 (the equivalent of
    # an unfair coin with a 10% chance of coming up "heads"). Hence, we should
    # practically never get as much as 20% of the rows.
    assert sampled_df.count() < 25

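# Companion sketch for the bound above: with 120 rows and p=0.1, the retained
# count is Binomial(120, 0.1) with mean 12, and the exact tail probability
# P(count >= 25), computed below, comes out to roughly 3e-4, so the hard
# "< 25" assertion fails only a few times in ten thousand runs.
def test_random_sampling_tail_bound_sketch():
    from math import comb

    n, p = 120, 0.1
    tail = sum(
        comb(n, k) * p**k * (1 - p) ** (n - k) for k in range(25, n + 1)
    )
    assert tail < 1e-3
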
def test_sample_using_a_list(test_df):
    sampled_df = PandasExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_df,
            sampling_method="_sample_using_a_list",
            sampling_kwargs={
                "column_name": "id",
                "value_list": [3, 5, 7, 11],
            },
        )
    )
    assert sampled_df.shape == (4, 10)

def test_sample_using_md5_wrong_hash_function_name(test_sparkdf):
    with pytest.raises(ge_exceptions.ExecutionEngineError):
        sampled_df = SparkDFExecutionEngine().get_batch_data(
            RuntimeDataBatchSpec(
                batch_data=test_sparkdf,
                sampling_method="_sample_using_hash",
                sampling_kwargs={
                    "column_name": "date",
                    "hash_function_name": "I_wont_work",
                },
            )
        )

def test_sample_using_mod(test_df):
    sampled_df = PandasExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_df,
            sampling_method="_sample_using_mod",
            sampling_kwargs={
                "column_name": "id",
                "mod": 5,
                "value": 4,
            },
        )
    )
    # Rows with id % 5 == 4 account for 24 of the 120 rows.
    assert sampled_df.shape == (24, 10)

def test_sample_using_a_list(test_sparkdf):
    sampled_df = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_sparkdf,
            sampling_method="_sample_using_a_list",
            sampling_kwargs={
                "column_name": "id",
                "value_list": [3, 5, 7, 11],
            },
        )
    )
    assert sampled_df.count() == 4
    assert len(sampled_df.columns) == 10

def test_get_batch_with_split_on_converted_datetime(test_sparkdf):
    split_df = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_sparkdf,
            splitter_method="_split_on_converted_datetime",
            splitter_kwargs={
                "column_name": "timestamp",
                "partition_definition": {"timestamp": "2020-01-03"},
            },
        )
    )
    assert split_df.count() == 2
    assert len(split_df.columns) == 10

def test_sample_using_mod(test_sparkdf):
    sampled_df = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_sparkdf,
            sampling_method="_sample_using_mod",
            sampling_kwargs={
                "column_name": "id",
                "mod": 5,
                "value": 4,
            },
        )
    )
    assert sampled_df.count() == 24
    assert len(sampled_df.columns) == 10

def test_get_batch_with_split_on_converted_datetime(test_df):
    split_df = PandasExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_df,
            splitter_method="_split_on_converted_datetime",
            splitter_kwargs={
                "column_name": "timestamp",
                "partition_definition": {"timestamp": "2020-01-30"},
            },
        )
    )
    assert split_df.shape == (3, 10)

def test_get_batch_with_split_on_mod_integer(test_df):
    split_df = PandasExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_df,
            splitter_method="_split_on_mod_integer",
            splitter_kwargs={
                "column_name": "id",
                "mod": 10,
                "partition_definition": {"id": 5},
            },
        )
    )
    # Keeps the 12 rows whose id % 10 == 5, i.e. ids 5, 15, ..., 115.
    assert split_df.shape == (12, 10)
    assert split_df.id.min() == 5
    assert split_df.id.max() == 115

def test_basic_setup(spark_session):
    pd_df = pd.DataFrame({"x": range(10)})
    df = spark_session.createDataFrame(
        [
            tuple(
                None if isinstance(x, (float, int)) and np.isnan(x) else x
                for x in record.tolist()
            )
            for record in pd_df.to_records(index=False)
        ],
        pd_df.columns.tolist(),
    )
    batch_data = SparkDFExecutionEngine().get_batch_data(
        batch_spec=RuntimeDataBatchSpec(
            batch_data=df,
            data_asset_name="DATA_ASSET",
        )
    )
    assert batch_data is not None

def test_get_batch_with_split_on_hashed_column_incorrect_hash_function_name(
    test_sparkdf,
):
    with pytest.raises(ge_exceptions.ExecutionEngineError):
        split_df = SparkDFExecutionEngine().get_batch_data(
            RuntimeDataBatchSpec(
                batch_data=test_sparkdf,
                splitter_method="_split_on_hashed_column",
                splitter_kwargs={
                    "column_name": "favorite_color",
                    "hash_digits": 1,
                    "hash_function_name": "I_wont_work",
                    "partition_definition": {
                        "hash_value": "a",
                    },
                },
            )
        )

def test_get_batch_with_split_on_hashed_column(test_sparkdf):
    split_df = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_sparkdf,
            splitter_method="_split_on_hashed_column",
            splitter_kwargs={
                "column_name": "favorite_color",
                "hash_digits": 1,
                "hash_function_name": "sha256",
                "partition_definition": {
                    "hash_value": "a",
                },
            },
        )
    )
    assert split_df.count() == 8
    assert len(split_df.columns) == 10

def test_sample_using_md5(test_sparkdf):
    sampled_df = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_sparkdf,
            sampling_method="_sample_using_hash",
            sampling_kwargs={
                "column_name": "date",
                "hash_function_name": "md5",
            },
        )
    )
    assert sampled_df.count() == 10
    assert len(sampled_df.columns) == 10
    collected = sampled_df.collect()
    for val in collected:
        assert val.date in [
            datetime.date(2020, 1, 15),
            datetime.date(2020, 1, 29),
        ]

def test_get_batch_with_split_on_divided_integer(test_sparkdf):
    split_df = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_sparkdf,
            splitter_method="_split_on_divided_integer",
            splitter_kwargs={
                "column_name": "id",
                "divisor": 10,
                "partition_definition": {"id": 5},
            },
        )
    )
    assert split_df.count() == 10
    assert len(split_df.columns) == 10
    max_result = split_df.select([F.max("id")])
    assert max_result.collect()[0]["max(id)"] == 59
    min_result = split_df.select([F.min("id")])
    assert min_result.collect()[0]["min(id)"] == 50

def test_get_batch_with_split_on_divided_integer_and_sample_on_list(test_df):
    split_df = PandasExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_df,
            splitter_method="_split_on_divided_integer",
            splitter_kwargs={
                "column_name": "id",
                "divisor": 10,
                "partition_definition": {"id": 5},
            },
            sampling_method="_sample_using_mod",
            sampling_kwargs={
                "column_name": "id",
                "mod": 5,
                "value": 4,
            },
        )
    )
    # See the arithmetic sketch after this test for why exactly ids 54 and 59
    # survive the combined split and sample.
    assert split_df.shape == (2, 10)
    assert split_df.id.min() == 54
    assert split_df.id.max() == 59

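# A minimal, self-contained illustration of the arithmetic behind the
# assertions above (assuming the semantics the method names imply, not quoting
# the engine's code): "_split_on_divided_integer" with divisor=10 and
# partition id=5 keeps rows whose id // 10 == 5, i.e. ids 50..59, and
# "_sample_using_mod" with mod=5, value=4 then keeps id % 5 == 4, leaving
# exactly ids 54 and 59, hence the (2, 10) shape.
def test_divided_integer_then_mod_arithmetic_sketch():
    surviving_ids = [i for i in range(120) if i // 10 == 5 and i % 5 == 4]
    assert surviving_ids == [54, 59]
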
def test_get_batch_empty_sampler(test_sparkdf):
    sampled_df = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(batch_data=test_sparkdf, sampling_method=None)
    )
    assert sampled_df.count() == 120
    assert len(sampled_df.columns) == 10

def test_sample_using_random(test_df):
    # Seed the module-level RNG so the sampled row count is deterministic.
    random.seed(1)
    sampled_df = PandasExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_df, sampling_method="_sample_using_random"
        )
    )
    assert sampled_df.shape == (13, 10)

def test_get_batch_with_split_on_whole_table(test_df):
    split_df = PandasExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(
            batch_data=test_df, splitter_method="_split_on_whole_table"
        )
    )
    assert split_df.shape == (120, 10)

def test_get_batch_data(test_sparkdf):
    test_sparkdf = SparkDFExecutionEngine().get_batch_data(
        RuntimeDataBatchSpec(batch_data=test_sparkdf, data_asset_name="DATA_ASSET")
    )
    assert test_sparkdf.count() == 120
    assert len(test_sparkdf.columns) == 10