def test_reader_fn(spark_session):
    """The engine should resolve the csv reader both from a file path and
    from an explicit reader_method (reader_method always wins over path)."""
    engine = SparkDFExecutionEngine()

    # Path-based inference: a ".csv" suffix should select DataFrameReader.csv.
    csv_reader = engine._get_reader_fn(reader=spark_session.read, path="myfile.csv")
    assert "<bound method DataFrameReader.csv" in str(csv_reader)

    # Explicit reader_method: no path given, same bound method is returned.
    explicit_reader = engine._get_reader_fn(reader=spark_session.read, reader_method="csv")
    assert "<bound method DataFrameReader.csv" in str(explicit_reader)
def test_get_batch_with_split_on_whole_table_s3(spark_session):
    """Fetching an S3 batch with the _split_on_whole_table splitter should
    return every row and column of the underlying data.

    The S3 read itself is stubbed out: the engine's reader-fn lookup is
    replaced with a factory that ignores its arguments and returns a fixed
    4x2 Spark DataFrame built from pandas records.
    """

    def fake_reader_fn_factory(*args, **kwargs):
        # Stands in for _get_reader_fn; the returned callable ignores the
        # reader_options/path it receives and yields canned data.
        def fake_reader(*args, **kwargs):
            source_frame = pd.DataFrame({"a": [1, 2, 3, 4], "b": [2, 3, 4, None]})

            def sanitize(cell):
                # Spark has no NaN-for-null convention here; map NaN -> None.
                if isinstance(cell, (float, int)) and np.isnan(cell):
                    return None
                return cell

            rows = [
                tuple(sanitize(value) for value in record.tolist())
                for record in source_frame.to_records(index=False)
            ]
            return spark_session.createDataFrame(rows, source_frame.columns.tolist())

        return fake_reader

    engine = SparkDFExecutionEngine()
    engine._get_reader_fn = fake_reader_fn_factory

    batch_df = engine.get_batch_data(
        S3BatchSpec(
            s3="s3://bucket/test/test.csv",
            reader_method="csv",
            reader_options={"header": True},
            splitter_method="_split_on_whole_table",
        )
    )

    # The whole-table splitter must keep all 4 rows and both columns.
    assert batch_df.count() == 4
    assert len(batch_df.columns) == 2