Example #1
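All three examples below assume imports along the following lines. The great_expectations module paths match the 0.13.x-era layout and are an assumption to adjust for your installed version; Example #1 additionally relies on a pytest-provided spark_session fixture.

import os
from typing import List

import boto3
import numpy as np
import pandas as pd
import pytest
from moto import mock_s3

import great_expectations.exceptions as ge_exceptions
from great_expectations.core.batch_spec import S3BatchSpec
from great_expectations.datasource.data_connector import ConfiguredAssetS3DataConnector
from great_expectations.execution_engine import (
    PandasExecutionEngine,
    SparkDFExecutionEngine,
)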
def test_get_batch_with_split_on_whole_table_s3(spark_session):
    # Stand-in for SparkDFExecutionEngine._get_reader_fn: the reader it
    # returns builds a small in-memory Spark DataFrame instead of touching S3.
    def mocked_get_reader_function(*args, **kwargs):
        def mocked_reader_function(*args, **kwargs):
            pd_df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [2, 3, 4, None]})
            # Replace pandas NaN with None so Spark stores proper nulls.
            df = spark_session.createDataFrame(
                [
                    tuple(
                        None if isinstance(x, (float, int)) and np.isnan(x) else x
                        for x in record.tolist()
                    )
                    for record in pd_df.to_records(index=False)
                ],
                pd_df.columns.tolist(),
            )
            return df

        return mocked_reader_function

    spark_engine = SparkDFExecutionEngine()
    # Monkey-patch the reader lookup so no real S3/filesystem access happens.
    spark_engine._get_reader_fn = mocked_get_reader_function

    test_sparkdf = spark_engine.get_batch_data(
        S3BatchSpec(
            s3="s3://bucket/test/test.csv",
            reader_method="csv",
            reader_options={"header": True},
            splitter_method="_split_on_whole_table",
        )
    )
    assert test_sparkdf.count() == 4
    assert len(test_sparkdf.columns) == 2
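
In both engines, _split_on_whole_table is the no-op splitter: the entire table comes back as a single batch. A minimal sketch of the idea (an illustration, not the Great Expectations source):

# Sketch only: the whole-table splitter is effectively the identity function.
def split_on_whole_table(df):
    return df  # one batch == the entire table

Note that this example passes the file location as s3=..., while Examples #2 and #3 use path=; the accepted keyword has varied across S3BatchSpec releases, so match it to the version you are running.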
Example #2
@mock_s3  # moto's S3 mock: the boto3 calls below stay in-process, not on AWS
def test_get_batch_with_split_on_whole_table_s3():
    region_name: str = "us-east-1"
    bucket: str = "test_bucket"
    conn = boto3.resource("s3", region_name=region_name)
    conn.create_bucket(Bucket=bucket)
    client = boto3.client("s3", region_name=region_name)

    test_df: pd.DataFrame = pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})
    keys: List[str] = [
        "path/A-100.csv",
        "path/A-101.csv",
        "directory/B-1.csv",
        "directory/B-2.csv",
    ]
    for key in keys:
        client.put_object(
            Bucket=bucket, Body=test_df.to_csv(index=False).encode("utf-8"), Key=key
        )

    path = "path/A-100.csv"
    full_path = f"s3a://{os.path.join(bucket, path)}"
    test_df = PandasExecutionEngine().get_batch_data(batch_spec=S3BatchSpec(
        path=full_path,
        reader_method="read_csv",
        splitter_method="_split_on_whole_table",
    ))
    assert test_df.shape == (2, 2)

    # With no S3 client configured on the engine, the same batch spec must
    # raise an ExecutionEngineError.
    execution_engine_no_s3 = PandasExecutionEngine()
    execution_engine_no_s3._s3 = None
    with pytest.raises(ge_exceptions.ExecutionEngineError):
        execution_engine_no_s3.get_batch_data(
            batch_spec=S3BatchSpec(
                path=full_path,
                reader_method="read_csv",
                splitter_method="_split_on_whole_table",
            )
        )
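
As an aside, BatchSpec objects are dict-like (in the 0.13.x layout they inherit from a dict-based IDDict class; verify for your version), which is why the same keyword arguments reappear verbatim across engines. A quick sketch:

# Sketch only -- assumes the dict-like BatchSpec base class.
spec = S3BatchSpec(
    path="s3a://test_bucket/path/A-100.csv",
    reader_method="read_csv",
    splitter_method="_split_on_whole_table",
)
assert spec["reader_method"] == "read_csv"  # entries are reachable by key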
Example #3
@mock_s3  # moto's S3 mock: the boto3 calls below stay in-process, not on AWS
def test_get_batch_with_split_on_whole_table_s3_with_configured_asset_s3_data_connector():
    region_name: str = "us-east-1"
    bucket: str = "test_bucket"
    conn = boto3.resource("s3", region_name=region_name)
    conn.create_bucket(Bucket=bucket)
    client = boto3.client("s3", region_name=region_name)

    test_df: pd.DataFrame = pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})
    keys: List[str] = [
        "path/A-100.csv",
        "path/A-101.csv",
        "directory/B-1.csv",
        "directory/B-2.csv",
    ]
    for key in keys:
        client.put_object(
            Bucket=bucket, Body=test_df.to_csv(index=False).encode("utf-8"), Key=key
        )
    path = "path/A-100.csv"
    full_path = f"s3a://{os.path.join(bucket, path)}"

    # The connector is configured here but never queried in this test; the
    # batch below is fetched directly through the execution engine.
    my_data_connector = ConfiguredAssetS3DataConnector(
        name="my_data_connector",
        datasource_name="FAKE_DATASOURCE_NAME",
        default_regex={
            "pattern": "alpha-(.*)\\.csv",
            "group_names": ["index"],
        },
        bucket=bucket,
        prefix="",
        assets={"alpha": {}},
    )

    # Fetch the single whole-table batch back through the Pandas engine.
    batch_data = PandasExecutionEngine().get_batch_data(
        batch_spec=S3BatchSpec(
            path=full_path,
            reader_method="read_csv",
            splitter_method="_split_on_whole_table",
        )
    )
    assert batch_data.shape == (2, 2)
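
The test never queries the connector itself; to see what it discovers in the mocked bucket, something along these lines should work (method names assumed from 0.13.x DataConnector internals):

# Sketch only -- assumes 0.13.x DataConnector internals.
my_data_connector._refresh_data_references_cache()
# None of the uploaded keys fit the "alpha-(.*)\.csv" pattern configured
# above, so all four should surface as unmatched references.
print(my_data_connector.get_unmatched_data_references())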