import os
from typing import List

import boto3
import numpy as np
import pandas as pd
import pytest
from moto import mock_s3

import great_expectations.exceptions as ge_exceptions
from great_expectations.core.batch_spec import S3BatchSpec
from great_expectations.datasource.data_connector import ConfiguredAssetS3DataConnector
from great_expectations.execution_engine import (
    PandasExecutionEngine,
    SparkDFExecutionEngine,
)


def test_get_batch_with_split_on_whole_table_s3_spark(spark_session):
    # Patch the engine's reader lookup so no real S3 access happens: the mocked
    # reader ignores the batch spec's path and returns an in-memory DataFrame.
    # (Renamed with a _spark suffix so it does not shadow the pandas test of
    # the same name below.)
    def mocked_get_reader_function(*args, **kwargs):
        def mocked_reader_function(*args, **kwargs):
            pd_df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [2, 3, 4, None]})
            # Rebuild each record with None in place of NaN so missing values
            # arrive in Spark as proper nulls rather than NaN doubles.
            df = spark_session.createDataFrame(
                [
                    tuple(
                        None if isinstance(x, (float, int)) and np.isnan(x) else x
                        for x in record.tolist()
                    )
                    for record in pd_df.to_records(index=False)
                ],
                pd_df.columns.tolist(),
            )
            return df

        return mocked_reader_function

    spark_engine = SparkDFExecutionEngine()
    spark_engine._get_reader_fn = mocked_get_reader_function
    test_sparkdf = spark_engine.get_batch_data(
        S3BatchSpec(
            s3="s3://bucket/test/test.csv",
            reader_method="csv",
            reader_options={"header": True},
            splitter_method="_split_on_whole_table",
        )
    )
    assert test_sparkdf.count() == 4
    assert len(test_sparkdf.columns) == 2
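# A minimal standalone sketch of the NaN -> None conversion used inside the
# mocked reader above; _records_with_nulls is a hypothetical helper for
# illustration, not part of Great Expectations or of this suite. The
# conversion matters because a pandas NaN passed straight to createDataFrame
# would land in Spark as a NaN double instead of a proper null.
def _records_with_nulls(pd_df: pd.DataFrame) -> list:
    return [
        tuple(
            None if isinstance(x, (float, int)) and np.isnan(x) else x
            for x in record.tolist()
        )
        for record in pd_df.to_records(index=False)
    ]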
@mock_s3
def test_get_batch_with_split_on_whole_table_s3():
    # Stand up a moto-backed bucket and upload the same small CSV under four keys.
    region_name: str = "us-east-1"
    bucket: str = "test_bucket"
    conn = boto3.resource("s3", region_name=region_name)
    conn.create_bucket(Bucket=bucket)
    client = boto3.client("s3", region_name=region_name)
    test_df: pd.DataFrame = pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})
    keys: List[str] = [
        "path/A-100.csv",
        "path/A-101.csv",
        "directory/B-1.csv",
        "directory/B-2.csv",
    ]
    for key in keys:
        client.put_object(
            Bucket=bucket, Body=test_df.to_csv(index=False).encode("utf-8"), Key=key
        )

    path = "path/A-100.csv"
    full_path = f"s3a://{os.path.join(bucket, path)}"

    test_df = PandasExecutionEngine().get_batch_data(
        batch_spec=S3BatchSpec(
            path=full_path,
            reader_method="read_csv",
            splitter_method="_split_on_whole_table",
        )
    )
    assert test_df.shape == (2, 2)

    # An engine whose S3 client was never configured should raise an
    # ExecutionEngineError rather than attempt the download.
    execution_engine_no_s3 = PandasExecutionEngine()
    execution_engine_no_s3._s3 = None
    with pytest.raises(ge_exceptions.ExecutionEngineError):
        execution_engine_no_s3.get_batch_data(
            batch_spec=S3BatchSpec(
                path=full_path,
                reader_method="read_csv",
                splitter_method="_split_on_whole_table",
            )
        )
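# The bucket scaffolding above reappears verbatim in the next test; a hedged
# sketch of how it could be factored into a pytest fixture. The fixture name
# s3_test_bucket is hypothetical and is not used elsewhere in this suite.
@pytest.fixture
def s3_test_bucket():
    with mock_s3():
        bucket = "test_bucket"
        boto3.resource("s3", region_name="us-east-1").create_bucket(Bucket=bucket)
        yield bucket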
@mock_s3
def test_get_batch_with_split_on_whole_table_s3_with_configured_asset_s3_data_connector():
    # Same moto-backed bucket setup as the previous test.
    region_name: str = "us-east-1"
    bucket: str = "test_bucket"
    conn = boto3.resource("s3", region_name=region_name)
    conn.create_bucket(Bucket=bucket)
    client = boto3.client("s3", region_name=region_name)
    test_df: pd.DataFrame = pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})
    keys: List[str] = [
        "path/A-100.csv",
        "path/A-101.csv",
        "directory/B-1.csv",
        "directory/B-2.csv",
    ]
    for key in keys:
        client.put_object(
            Bucket=bucket, Body=test_df.to_csv(index=False).encode("utf-8"), Key=key
        )

    path = "path/A-100.csv"
    full_path = f"s3a://{os.path.join(bucket, path)}"

    # The connector is instantiated against the same bucket to exercise its
    # configuration; the batch itself is still fetched directly through the engine.
    my_data_connector = ConfiguredAssetS3DataConnector(
        name="my_data_connector",
        datasource_name="FAKE_DATASOURCE_NAME",
        default_regex={
            "pattern": "alpha-(.*)\\.csv",
            "group_names": ["index"],
        },
        bucket=bucket,
        prefix="",
        assets={"alpha": {}},
    )
    test_df = PandasExecutionEngine().get_batch_data(
        batch_spec=S3BatchSpec(
            path=full_path,
            reader_method="read_csv",
            splitter_method="_split_on_whole_table",
        )
    )
    assert test_df.shape == (2, 2)
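# A small illustration in plain re (no Great Expectations API) of how the
# default_regex above would map a matching data reference to its "index"
# group. Note that "alpha-(.*)\.csv" does not match any of the keys uploaded
# in these tests, which is harmless here because the batch is fetched directly
# through the engine rather than resolved by the connector.
def test_default_regex_group_extraction_sketch():
    import re

    match = re.match(r"alpha-(.*)\.csv", "alpha-1.csv")
    assert match is not None
    assert match.group(1) == "1"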