def test_spark_app_error(tag, role, image_uri, sagemaker_session):
    """Submits a PySpark app which is scripted to exit with error code 1"""
    spark = PySparkProcessor(
        base_job_name="sm-spark-app-error",
        framework_version=tag,
        image_uri=image_uri,
        role=role,
        instance_count=1,
        instance_type="ml.c5.xlarge",
        max_runtime_in_seconds=1200,
        sagemaker_session=sagemaker_session,
    )

    try:
        spark.run(
            submit_app="test/resources/code/python/py_spark_app_error/py_spark_app_error.py",
            wait=True,
            logs=False,
        )
    except Exception:
        pass  # this job is expected to fail
    processing_job = spark.latest_job

    describe_response = processing_job.describe()
    assert "AlgorithmError: See job logs for more information" == describe_response[
        "FailureReason"]
    assert "Algorithm Error: (caused by CalledProcessError)" in describe_response[
        "ExitMessage"]
    assert "returned non-zero exit status 1" in describe_response[
        "ExitMessage"]
def test_sagemaker_pyspark_sse_s3(tag, role, image_uri, sagemaker_session,
                                  region, sagemaker_client):
    """Test that Spark container can read and write S3 data encrypted with SSE-S3 (default AES256 encryption)"""
    spark = PySparkProcessor(
        base_job_name="sm-spark-py",
        framework_version=tag,
        image_uri=image_uri,
        role=role,
        instance_count=2,
        instance_type="ml.c5.xlarge",
        max_runtime_in_seconds=1200,
        sagemaker_session=sagemaker_session,
    )
    bucket = sagemaker_session.default_bucket()
    timestamp = datetime.now().isoformat()
    input_data_key = f"spark/input/sales/{timestamp}/data.jsonl"
    input_data_uri = f"s3://{bucket}/{input_data_key}"
    output_data_uri = f"s3://{bucket}/spark/output/sales/{timestamp}"
    s3_client = sagemaker_session.boto_session.client("s3", region_name=region)
    with open("test/resources/data/files/data.jsonl") as data:
        body = data.read()
        s3_client.put_object(Body=body,
                             Bucket=bucket,
                             Key=input_data_key,
                             ServerSideEncryption="AES256")

    spark.run(
        submit_app="test/resources/code/python/hello_py_spark/hello_py_spark_app.py",
        submit_py_files=[
            "test/resources/code/python/hello_py_spark/hello_py_spark_udfs.py"
        ],
        arguments=["--input", input_data_uri, "--output", output_data_uri],
        configuration={
            "Classification": "core-site",
            "Properties": {
                "fs.s3a.server-side-encryption-algorithm": "AES256"
            },
        },
    )
    processing_job = spark.latest_job

    waiter = sagemaker_client.get_waiter("processing_job_completed_or_stopped")
    waiter.wait(
        ProcessingJobName=processing_job.job_name,
        # poll every 15 seconds. timeout after 15 minutes.
        WaiterConfig={
            "Delay": 15,
            "MaxAttempts": 60
        },
    )

    output_contents = S3Downloader.list(output_data_uri,
                                        sagemaker_session=sagemaker_session)
    assert len(output_contents) != 0
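# The hello_py_spark_app.py resource is not reproduced on this page. A minimal,
# hypothetical sketch of an app honoring the same --input/--output contract
# (the real app and its UDFs in hello_py_spark_udfs.py may differ):
import argparse

from pyspark.sql import SparkSession
from pyspark.sql import functions as F


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True)
    parser.add_argument("--output", required=True)
    args = parser.parse_args()

    spark = SparkSession.builder.appName("hello-py-spark").getOrCreate()
    df = spark.read.json(args.input)
    # Placeholder transformation standing in for the real UDF logic.
    df = df.withColumn("processed_at", F.current_timestamp())
    df.write.mode("overwrite").json(args.output)
    spark.stop()


if __name__ == "__main__":
    main()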
def test_sagemaker_spark_processor_default_tag(spark_version, role,
                                               sagemaker_session,
                                               sagemaker_client):
    """Test that spark processor works with default tag"""
    spark = PySparkProcessor(
        base_job_name="sm-spark-py",
        framework_version=spark_version,
        role=role,
        instance_count=1,
        instance_type="ml.c5.xlarge",
        max_runtime_in_seconds=1200,
        sagemaker_session=sagemaker_session,
    )
    bucket = spark.sagemaker_session.default_bucket()
    timestamp = datetime.now().isoformat()
    output_data_uri = "s3://{}/spark/output/sales/{}".format(bucket, timestamp)
    spark_event_logs_key_prefix = "spark/spark-events/{}".format(timestamp)
    spark_event_logs_s3_uri = "s3://{}/{}".format(bucket,
                                                  spark_event_logs_key_prefix)

    with open("test/resources/data/files/data.jsonl") as data:
        body = data.read()
        input_data_uri = "s3://{}/spark/input/data.jsonl".format(bucket)
        S3Uploader.upload_string_as_file_body(
            body=body,
            desired_s3_uri=input_data_uri,
            sagemaker_session=sagemaker_session)

    spark.run(
        submit_app="test/resources/code/python/hello_py_spark/hello_py_spark_app.py",
        submit_py_files=[
            "test/resources/code/python/hello_py_spark/hello_py_spark_udfs.py"
        ],
        arguments=["--input", input_data_uri, "--output", output_data_uri],
        spark_event_logs_s3_uri=spark_event_logs_s3_uri,
        wait=True,
    )

    processing_job = spark.latest_job
    waiter = sagemaker_client.get_waiter("processing_job_completed_or_stopped")
    waiter.wait(
        ProcessingJobName=processing_job.job_name,
        # poll every 15 seconds. timeout after 15 minutes.
        WaiterConfig={
            "Delay": 15,
            "MaxAttempts": 60
        },
    )

    output_contents = S3Downloader.list(output_data_uri,
                                        sagemaker_session=sagemaker_session)
    assert len(output_contents) != 0
def test_history_server_with_expected_failure(tag, role, image_uri, sagemaker_session, caplog):
    spark = PySparkProcessor(
        base_job_name="sm-spark",
        framework_version=tag,
        image_uri=image_uri,
        role=role,
        instance_count=1,
        instance_type="ml.c5.xlarge",
        max_runtime_in_seconds=1200,
        sagemaker_session=sagemaker_session,
    )

    caplog.set_level(logging.ERROR)
    spark.start_history_server(spark_event_logs_s3_uri="invalids3uri")
    response = _request_with_retry(HISTORY_SERVER_ENDPOINT, max_retries=5)
    assert response is None
    assert "History server failed to start. Please run 'docker logs history_server' to see logs" in caplog.text
def spark_py_processor(sagemaker_session, cpu_instance_type):
    spark_py_processor = PySparkProcessor(
        role="SageMakerRole",
        instance_count=2,
        instance_type=cpu_instance_type,
        sagemaker_session=sagemaker_session,
        framework_version="2.4",
    )

    return spark_py_processor
Example #6
def test_pyspark_processor_instantiation(sagemaker_session):
    # This just tests that the import is right and that the processor can be instantiated
    # Functionality is tested in project root container directory.
    PySparkProcessor(
        base_job_name="sm-spark",
        role="AmazonSageMaker-ExecutionRole",
        framework_version="2.4",
        instance_count=1,
        instance_type="ml.c5.xlarge",
        sagemaker_session=sagemaker_session,
    )
def test_history_server(tag, role, image_uri, sagemaker_session, region):
    spark = PySparkProcessor(
        base_job_name="sm-spark",
        framework_version=tag,
        image_uri=image_uri,
        role=role,
        instance_count=1,
        instance_type="ml.c5.xlarge",
        max_runtime_in_seconds=1200,
        sagemaker_session=sagemaker_session,
    )
    bucket = sagemaker_session.default_bucket()
    spark_event_logs_key_prefix = "spark/spark-history-fs"
    spark_event_logs_s3_uri = "s3://{}/{}".format(bucket, spark_event_logs_key_prefix)
    spark_event_log_local_path = "test/resources/data/files/sample_spark_event_logs"
    file_name = "sample_spark_event_logs"
    file_size = os.path.getsize(spark_event_log_local_path)

    with open("test/resources/data/files/sample_spark_event_logs") as data:
        body = data.read()
        S3Uploader.upload_string_as_file_body(
            body=body, desired_s3_uri=f"{spark_event_logs_s3_uri}/{file_name}", sagemaker_session=sagemaker_session,
        )

    _wait_for_file_to_be_uploaded(region, bucket, spark_event_logs_key_prefix, file_name, file_size)
    spark.start_history_server(spark_event_logs_s3_uri=spark_event_logs_s3_uri)

    try:
        response = _request_with_retry(HISTORY_SERVER_ENDPOINT)
        assert response.status == 200

        response = _request_with_retry(f"{HISTORY_SERVER_ENDPOINT}{SPARK_APPLICATION_URL_SUFFIX}", max_retries=15)
        print(f"Subpage response status code: {response.status}")
    finally:
        spark.terminate_history_server()
Example #8
def py_spark_processor(sagemaker_session) -> PySparkProcessor:
    spark = PySparkProcessor(
        base_job_name="sm-spark",
        role="AmazonSageMaker-ExecutionRole",
        framework_version="2.4",
        instance_count=1,
        instance_type="ml.c5.xlarge",
        image_uri="790336243319.dkr.ecr.us-west-2.amazonaws.com/sagemaker-spark:0.1",
        sagemaker_session=sagemaker_session,
    )

    return spark
Example #9
def test_history_server(tag, role, image_uri, sagemaker_session, region):
    spark = PySparkProcessor(
        base_job_name="sm-spark",
        framework_version=tag,
        image_uri=image_uri,
        role=role,
        instance_count=1,
        instance_type="ml.c5.xlarge",
        max_runtime_in_seconds=1200,
        sagemaker_session=sagemaker_session,
    )
    bucket = sagemaker_session.default_bucket()
    spark_event_logs_key_prefix = "spark/spark-history-fs"
    spark_event_logs_s3_uri = "s3://{}/{}".format(bucket,
                                                  spark_event_logs_key_prefix)
    spark_event_log_local_path = "test/resources/data/files/sample_spark_event_logs"
    file_name = "sample_spark_event_logs"
    file_size = os.path.getsize(spark_event_log_local_path)

    with open("test/resources/data/files/sample_spark_event_logs") as data:
        body = data.read()
        S3Uploader.upload_string_as_file_body(
            body=body,
            desired_s3_uri=f"{spark_event_logs_s3_uri}/{file_name}",
            sagemaker_session=sagemaker_session,
        )

    _wait_for_file_to_be_uploaded(region, bucket, spark_event_logs_key_prefix,
                                  file_name, file_size)
    spark.start_history_server(spark_event_logs_s3_uri=spark_event_logs_s3_uri)

    try:
        response = _request_with_retry(HISTORY_SERVER_ENDPOINT)
        assert response.status == 200

        # Spark's history server has redirect behavior; this request verifies that page navigation works through the redirect
        response = _request_with_retry(
            f"{HISTORY_SERVER_ENDPOINT}{SPARK_APPLICATION_URL_SUFFIX}")
        if response.status != 200:
            print(subprocess.run(["docker", "logs", "history_server"]))

        assert response.status == 200

        html_content = response.data.decode("utf-8")
        assert "Completed Jobs (4)" in html_content
        assert "collect at /opt/ml/processing/input/code/test_long_duration.py:32" in html_content
    finally:
        spark.terminate_history_server()
Example #10
def test_configuration_validation(config, expected, sagemaker_session) -> None:
    # This just tests that the import is right and that the processor can be instantiated
    # Functionality is tested in project root container directory.
    spark = PySparkProcessor(
        base_job_name="sm-spark",
        role="AmazonSageMaker-ExecutionRole",
        framework_version="2.4",
        instance_count=1,
        instance_type="ml.c5.xlarge",
        sagemaker_session=sagemaker_session,
    )

    if expected is None:
        spark._validate_configuration(config)
    else:
        with pytest.raises(expected):
            spark._validate_configuration(config)
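# The config and expected values for the test above come from a parametrization that is
# not shown on this page. As an illustration only -- assuming the validator accepts
# EMR-style classifications such as "spark-defaults" and "core-site" and raises
# ValueError for unrecognized ones -- the cases might look like:
EXAMPLE_CONFIGURATION_CASES = [
    ({"Classification": "spark-defaults",
      "Properties": {"spark.executor.memory": "2g"}}, None),
    ({"Classification": "core-site",
      "Properties": {"fs.s3a.server-side-encryption-algorithm": "AES256"}}, None),
    ({"Classification": "unknown-classification", "Properties": {}}, ValueError),
]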
def test_two_processing_job_depends_on(
    sagemaker_session,
    role,
    pipeline_name,
    region_name,
    cpu_instance_type,
):
    instance_count = ParameterInteger(name="InstanceCount", default_value=2)
    script_path = os.path.join(DATA_DIR, "dummy_script.py")

    pyspark_processor = PySparkProcessor(
        base_job_name="sm-spark",
        framework_version="2.4",
        role=role,
        instance_count=instance_count,
        instance_type=cpu_instance_type,
        max_runtime_in_seconds=1200,
        sagemaker_session=sagemaker_session,
    )

    spark_run_args = pyspark_processor.get_run_args(
        submit_app=script_path,
        arguments=[
            "--s3_input_bucket",
            sagemaker_session.default_bucket(),
            "--s3_input_key_prefix",
            "spark-input",
            "--s3_output_bucket",
            sagemaker_session.default_bucket(),
            "--s3_output_key_prefix",
            "spark-output",
        ],
    )

    step_pyspark_1 = ProcessingStep(
        name="pyspark-process-1",
        processor=pyspark_processor,
        inputs=spark_run_args.inputs,
        outputs=spark_run_args.outputs,
        job_arguments=spark_run_args.arguments,
        code=spark_run_args.code,
    )

    step_pyspark_2 = ProcessingStep(
        name="pyspark-process-2",
        depends_on=[step_pyspark_1],
        processor=pyspark_processor,
        inputs=spark_run_args.inputs,
        outputs=spark_run_args.outputs,
        job_arguments=spark_run_args.arguments,
        code=spark_run_args.code,
    )

    pipeline = Pipeline(
        name=pipeline_name,
        parameters=[instance_count],
        steps=[step_pyspark_1, step_pyspark_2],
        sagemaker_session=sagemaker_session,
    )

    try:
        response = pipeline.create(role)
        create_arn = response["PipelineArn"]
        assert re.match(
            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
            create_arn,
        )

        pipeline.parameters = [
            ParameterInteger(name="InstanceCount", default_value=1)
        ]
        response = pipeline.update(role)
        update_arn = response["PipelineArn"]
        assert re.match(
            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
            update_arn,
        )

        execution = pipeline.start(parameters={})
        assert re.match(
            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
            execution.arn,
        )

        response = execution.describe()
        assert response["PipelineArn"] == create_arn

        try:
            execution.wait(delay=60)
        except WaiterError:
            pass

        execution_steps = execution.list_steps()
        assert len(execution_steps) == 2
        time_stamp = {}
        for execution_step in execution_steps:
            name = execution_step["StepName"]
            if name == "pyspark-process-1":
                time_stamp[name] = execution_step["EndTime"]
            else:
                time_stamp[name] = execution_step["StartTime"]
        assert time_stamp["pyspark-process-1"] < time_stamp["pyspark-process-2"]
    finally:
        try:
            pipeline.delete()
        except Exception:
            pass
def test_one_step_pyspark_processing_pipeline(
    sagemaker_session,
    role,
    cpu_instance_type,
    pipeline_name,
    region_name,
):
    instance_count = ParameterInteger(name="InstanceCount", default_value=2)
    script_path = os.path.join(DATA_DIR, "dummy_script.py")

    cache_config = CacheConfig(enable_caching=True, expire_after="T30m")

    pyspark_processor = PySparkProcessor(
        base_job_name="sm-spark",
        framework_version="2.4",
        role=role,
        instance_count=instance_count,
        instance_type=cpu_instance_type,
        max_runtime_in_seconds=1200,
        sagemaker_session=sagemaker_session,
    )

    spark_run_args = pyspark_processor.get_run_args(
        submit_app=script_path,
        arguments=[
            "--s3_input_bucket",
            sagemaker_session.default_bucket(),
            "--s3_input_key_prefix",
            "spark-input",
            "--s3_output_bucket",
            sagemaker_session.default_bucket(),
            "--s3_output_key_prefix",
            "spark-output",
        ],
    )

    step_pyspark = ProcessingStep(
        name="pyspark-process",
        processor=pyspark_processor,
        inputs=spark_run_args.inputs,
        outputs=spark_run_args.outputs,
        job_arguments=spark_run_args.arguments,
        code=spark_run_args.code,
        cache_config=cache_config,
    )
    pipeline = Pipeline(
        name=pipeline_name,
        parameters=[instance_count],
        steps=[step_pyspark],
        sagemaker_session=sagemaker_session,
    )

    try:
        # NOTE: We should exercise the case where the role used for the pipeline execution
        # differs from the one required by the steps in the pipeline itself. The role in
        # the pipeline definition needs to create training and processing jobs and other
        # SageMaker entities. However, the jobs created by the steps themselves execute
        # under a potentially different role, often requiring access to S3 and other
        # artifacts that are not needed when the jobs are created by the pipeline steps.
        response = pipeline.create(role)
        create_arn = response["PipelineArn"]
        assert re.match(
            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
            create_arn,
        )

        pipeline.parameters = [
            ParameterInteger(name="InstanceCount", default_value=1)
        ]
        response = pipeline.update(role)
        update_arn = response["PipelineArn"]
        assert re.match(
            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
            update_arn,
        )

        execution = pipeline.start(parameters={})
        assert re.match(
            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
            execution.arn,
        )

        response = execution.describe()
        assert response["PipelineArn"] == create_arn

        # Check CacheConfig
        pipeline_definition = json.loads(pipeline.describe()["PipelineDefinition"])
        response = pipeline_definition["Steps"][0]["CacheConfig"]
        assert response["Enabled"] == cache_config.enable_caching
        assert response["ExpireAfter"] == cache_config.expire_after

        try:
            execution.wait(delay=30, max_attempts=3)
        except WaiterError:
            pass
        execution_steps = execution.list_steps()
        assert len(execution_steps) == 1
        assert execution_steps[0]["StepName"] == "pyspark-process"
    finally:
        try:
            pipeline.delete()
        except Exception:
            pass
Example #13
@pytest.mark.parametrize(
    "framework_processor",
    [
        (
            SparkJarProcessor(
                role=sagemaker.get_execution_role(),
                framework_version="2.4",
                instance_count=1,
                instance_type=INSTANCE_TYPE,
            ),
            {
                "submit_app": "s3://my-jar",
                "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
                "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
            },
        ),
        (
            PySparkProcessor(
                role=sagemaker.get_execution_role(),
                framework_version="2.4",
                instance_count=1,
                instance_type=INSTANCE_TYPE,
            ),
            {
                "submit_app": "s3://my-jar",
                "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
            },
        ),
    ],
)
def test_processing_step_with_framework_processor(
    framework_processor, pipeline_session, processing_input, network_config
):

    processor, run_inputs = framework_processor
    processor.sagemaker_session = pipeline_session
Example #14
from sagemaker.workflow.steps import ProcessingStep
from sagemaker.workflow.pipeline import Pipeline

sagemaker_role = sagemaker.get_execution_role()

# ###### #
# Params #
BUCKET = ''
PREFIX = ''

# ###### #
# Node 1 #
node_1_proc = PySparkProcessor(
    base_job_name='spark-proc-name',
    framework_version='2.4',
    role=sagemaker_role,
    instance_count=1,
    instance_type='ml.r5.8xlarge',
    env={'AWS_DEFAULT_REGION': boto3.Session().region_name},
    max_runtime_in_seconds=1800)

configuration = [{
    "Classification": "spark-defaults",
    "Properties": {
        "spark.executor.memory": "200g",
        "spark.driver.memory": "200g",
        "spark.executor.cores": "20",
        "spark.cores.memmaxory": "20"
    }
}]

node_1_run_args = node_1_proc.get_run_args(
    submit_app='node_1_script.py',  # placeholder script path; the original arguments are not shown
    configuration=configuration,
)
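# The ProcessingStep and Pipeline imports above suggest these run args feed a pipeline
# step. A minimal, hypothetical sketch of that wiring (step and pipeline names are made up):
step_node_1 = ProcessingStep(
    name='node-1-spark-process',
    processor=node_1_proc,
    inputs=node_1_run_args.inputs,
    outputs=node_1_run_args.outputs,
    job_arguments=node_1_run_args.arguments,
    code=node_1_run_args.code,
)

pipeline = Pipeline(
    name='spark-processing-pipeline',
    steps=[step_node_1],
)
pipeline.upsert(role_arn=sagemaker_role)
execution = pipeline.start()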
def test_sagemaker_pyspark_multinode(tag, role, image_uri, configuration,
                                     sagemaker_session, region,
                                     sagemaker_client):
    """Test that basic multinode case works on 32KB of data"""
    spark = PySparkProcessor(
        base_job_name="sm-spark-py",
        framework_version=tag,
        image_uri=image_uri,
        role=role,
        instance_count=2,
        instance_type="ml.c5.xlarge",
        max_runtime_in_seconds=1200,
        sagemaker_session=sagemaker_session,
    )
    bucket = spark.sagemaker_session.default_bucket()
    timestamp = datetime.now().isoformat()
    output_data_uri = "s3://{}/spark/output/sales/{}".format(bucket, timestamp)
    spark_event_logs_key_prefix = "spark/spark-events/{}".format(timestamp)
    spark_event_logs_s3_uri = "s3://{}/{}".format(bucket,
                                                  spark_event_logs_key_prefix)

    with open("test/resources/data/files/data.jsonl") as data:
        body = data.read()
        input_data_uri = "s3://{}/spark/input/data.jsonl".format(bucket)
        S3Uploader.upload_string_as_file_body(
            body=body,
            desired_s3_uri=input_data_uri,
            sagemaker_session=sagemaker_session)

    spark.run(
        submit_app="test/resources/code/python/hello_py_spark/hello_py_spark_app.py",
        submit_py_files=[
            "test/resources/code/python/hello_py_spark/hello_py_spark_udfs.py"
        ],
        arguments=["--input", input_data_uri, "--output", output_data_uri],
        configuration=configuration,
        spark_event_logs_s3_uri=spark_event_logs_s3_uri,
        wait=False,
    )
    processing_job = spark.latest_job

    s3_client = boto3.client("s3", region_name=region)

    file_size = 0
    latest_file_size = None
    updated_times_count = 0
    time_out = time.time() + 900

    while not processing_job_not_fail_or_complete(sagemaker_client,
                                                  processing_job.job_name):
        response = s3_client.list_objects(Bucket=bucket,
                                          Prefix=spark_event_logs_key_prefix)
        if "Contents" in response:
            # For some reason, the first object returned by list_objects always has size 0;
            # this loop skips zero-size entries.
            for event_log_file in response["Contents"]:
                if event_log_file["Size"] != 0:
                    print("\n##### Latest file size is " +
                          str(event_log_file["Size"]))
                    latest_file_size = event_log_file["Size"]

        # update the file size if it increased
        if latest_file_size and latest_file_size > file_size:
            print("\n##### S3 file updated.")
            updated_times_count += 1
            file_size = latest_file_size

        if time.time() > time_out:
            raise RuntimeError("Timeout")

        time.sleep(20)

    # verify that spark event logs are periodically written to s3
    print("\n##### file_size {} updated_times_count {}".format(
        file_size, updated_times_count))
    assert file_size != 0

    # Commenting this assert because it's flaky.
    # assert updated_times_count > 1

    output_contents = S3Downloader.list(output_data_uri,
                                        sagemaker_session=sagemaker_session)
    assert len(output_contents) != 0
def test_sagemaker_pyspark_sse_kms_s3(role, image_uri, sagemaker_session,
                                      region, sagemaker_client, account_id,
                                      partition):
    spark = PySparkProcessor(
        base_job_name="sm-spark-py",
        image_uri=image_uri,
        role=role,
        instance_count=2,
        instance_type="ml.c5.xlarge",
        max_runtime_in_seconds=1200,
        sagemaker_session=sagemaker_session,
    )

    # This test expects the AWS managed S3 KMS key to be present. The key is listed
    # under KMS > AWS managed keys > aws/s3.
    kms_key_id = None
    kms_client = sagemaker_session.boto_session.client("kms",
                                                       region_name=region)
    for alias in kms_client.list_aliases()["Aliases"]:
        if "s3" in alias["AliasName"]:
            kms_key_id = alias["TargetKeyId"]

    if not kms_key_id:
        raise ValueError(
            "AWS managed s3 kms key(alias: aws/s3) does not exist")

    bucket = sagemaker_session.default_bucket()
    timestamp = datetime.now().isoformat()
    input_data_key = f"spark/input/sales/{timestamp}/data.jsonl"
    input_data_uri = f"s3://{bucket}/{input_data_key}"
    output_data_uri_prefix = f"spark/output/sales/{timestamp}"
    output_data_uri = f"s3://{bucket}/{output_data_uri_prefix}"
    s3_client = sagemaker_session.boto_session.client("s3", region_name=region)
    with open("test/resources/data/files/data.jsonl") as data:
        body = data.read()
        s3_client.put_object(Body=body,
                             Bucket=bucket,
                             Key=input_data_key,
                             ServerSideEncryption="aws:kms",
                             SSEKMSKeyId=kms_key_id)

    spark.run(
        submit_app="test/resources/code/python/hello_py_spark/hello_py_spark_app.py",
        submit_py_files=[
            "test/resources/code/python/hello_py_spark/hello_py_spark_udfs.py"
        ],
        arguments=["--input", input_data_uri, "--output", output_data_uri],
        configuration={
            "Classification": "core-site",
            "Properties": {
                "fs.s3a.server-side-encryption-algorithm":
                "SSE-KMS",
                "fs.s3a.server-side-encryption.key":
                f"arn:{partition}:kms:{region}:{account_id}:key/{kms_key_id}",
            },
        },
    )
    processing_job = spark.latest_job
    waiter = sagemaker_client.get_waiter("processing_job_completed_or_stopped")
    waiter.wait(
        ProcessingJobName=processing_job.job_name,
        # poll every 15 seconds. timeout after 15 minutes.
        WaiterConfig={
            "Delay": 15,
            "MaxAttempts": 60
        },
    )

    s3_objects = s3_client.list_objects(
        Bucket=bucket, Prefix=output_data_uri_prefix)["Contents"]
    assert len(s3_objects) != 0
    for s3_object in s3_objects:
        object_metadata = s3_client.get_object(Bucket=bucket,
                                               Key=s3_object["Key"])
        assert object_metadata["ServerSideEncryption"] == "aws:kms"
        assert object_metadata[
            "SSEKMSKeyId"] == f"arn:{partition}:kms:{region}:{account_id}:key/{kms_key_id}"
import boto3
import sagemaker

from sagemaker.spark.processing import PySparkProcessor

sm = boto3.Session().client(service_name='sagemaker')
sagemaker_role = sagemaker.get_execution_role()

# ############################ #
# Pyspark Processor definition #
spark_processor = PySparkProcessor(
    base_job_name='spark-proc-name',
    framework_version='2.4',
    role=sagemaker_role,
    instance_count=1,
    instance_type='ml.r5.8xlarge',
    env={'AWS_DEFAULT_REGION': boto3.Session().region_name},
    max_runtime_in_seconds=1800)

configuration = [{
    "Classification": "spark-defaults",
    "Properties": {
        "spark.executor.memory": "200g",
        "spark.driver.memory": "200g",
        "spark.executor.cores": "20",
        "spark.cores.memmaxory": "20"
    }
}]

# #################################### #
# Launch Pyspark Processor with script #
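# A minimal, hypothetical sketch of the launch step this comment introduces; the script
# path and S3 locations are placeholders rather than the original values.
spark_processor.run(
    submit_app='processing_script.py',
    arguments=['--input', 's3://<bucket>/<prefix>/input',
               '--output', 's3://<bucket>/<prefix>/output'],
    configuration=configuration,
    wait=True,
    logs=True,
)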