def test_spark_app_error(tag, role, image_uri, sagemaker_session):
    """Submits a PySpark app which is scripted to exit with error code 1"""
    spark = PySparkProcessor(
        base_job_name="sm-spark-app-error",
        framework_version=tag,
        image_uri=image_uri,
        role=role,
        instance_count=1,
        instance_type="ml.c5.xlarge",
        max_runtime_in_seconds=1200,
        sagemaker_session=sagemaker_session,
    )

    try:
        spark.run(
            submit_app=
            "test/resources/code/python/py_spark_app_error/py_spark_app_error.py",
            wait=True,
            logs=False,
        )
    except Exception:
        pass  # this job is expected to fail
    processing_job = spark.latest_job

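    # The failed job's describe() response surfaces a generic FailureReason plus an
    # ExitMessage detailing the non-zero exit status raised by the submitted app.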
    describe_response = processing_job.describe()
    assert "AlgorithmError: See job logs for more information" == describe_response[
        "FailureReason"]
    assert "Algorithm Error: (caused by CalledProcessError)" in describe_response[
        "ExitMessage"]
    assert "returned non-zero exit status 1" in describe_response[
        "ExitMessage"]


def test_sagemaker_pyspark_sse_s3(tag, role, image_uri, sagemaker_session,
                                  region, sagemaker_client):
    """Test that Spark container can read and write S3 data encrypted with SSE-S3 (default AES256 encryption)"""
    spark = PySparkProcessor(
        base_job_name="sm-spark-py",
        framework_version=tag,
        image_uri=image_uri,
        role=role,
        instance_count=2,
        instance_type="ml.c5.xlarge",
        max_runtime_in_seconds=1200,
        sagemaker_session=sagemaker_session,
    )
    bucket = sagemaker_session.default_bucket()
    timestamp = datetime.now().isoformat()
    input_data_key = f"spark/input/sales/{timestamp}/data.jsonl"
    input_data_uri = f"s3://{bucket}/{input_data_key}"
    output_data_uri = f"s3://{bucket}/spark/output/sales/{timestamp}"
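    # Stage the input object with SSE-S3 (AES256) server-side encryption so the job
    # reads data that is encrypted at rest.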
    s3_client = sagemaker_session.boto_session.client("s3", region_name=region)
    with open("test/resources/data/files/data.jsonl") as data:
        body = data.read()
        s3_client.put_object(Body=body,
                             Bucket=bucket,
                             Key=input_data_key,
                             ServerSideEncryption="AES256")

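    # The core-site classification below sets fs.s3a.server-side-encryption-algorithm
    # so the job writes its output with the same AES256 encryption.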
    spark.run(
        submit_app=
        "test/resources/code/python/hello_py_spark/hello_py_spark_app.py",
        submit_py_files=[
            "test/resources/code/python/hello_py_spark/hello_py_spark_udfs.py"
        ],
        arguments=["--input", input_data_uri, "--output", output_data_uri],
        configuration={
            "Classification": "core-site",
            "Properties": {
                "fs.s3a.server-side-encryption-algorithm": "AES256"
            },
        },
    )
    processing_job = spark.latest_job

    waiter = sagemaker_client.get_waiter("processing_job_completed_or_stopped")
    waiter.wait(
        ProcessingJobName=processing_job.job_name,
        # poll every 15 seconds. timeout after 15 minutes.
        WaiterConfig={
            "Delay": 15,
            "MaxAttempts": 60
        },
    )

    output_contents = S3Downloader.list(output_data_uri,
                                        sagemaker_session=sagemaker_session)
    assert len(output_contents) != 0


def test_sagemaker_spark_processor_default_tag(spark_version, role,
                                               sagemaker_session,
                                               sagemaker_client):
    """Test that spark processor works with default tag"""
    spark = PySparkProcessor(
        base_job_name="sm-spark-py",
        framework_version=spark_version,
        role=role,
        instance_count=1,
        instance_type="ml.c5.xlarge",
        max_runtime_in_seconds=1200,
        sagemaker_session=sagemaker_session,
    )
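    # No image_uri is passed, so the SDK resolves the container image tag from
    # spark_version alone (the default-tag behavior under test).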
    bucket = spark.sagemaker_session.default_bucket()
    timestamp = datetime.now().isoformat()
    output_data_uri = "s3://{}/spark/output/sales/{}".format(bucket, timestamp)
    spark_event_logs_key_prefix = "spark/spark-events/{}".format(timestamp)
    spark_event_logs_s3_uri = "s3://{}/{}".format(bucket,
                                                  spark_event_logs_key_prefix)

    with open("test/resources/data/files/data.jsonl") as data:
        body = data.read()
        input_data_uri = "s3://{}/spark/input/data.jsonl".format(bucket)
        S3Uploader.upload_string_as_file_body(
            body=body,
            desired_s3_uri=input_data_uri,
            sagemaker_session=sagemaker_session)

    spark.run(
        submit_app=
        "test/resources/code/python/hello_py_spark/hello_py_spark_app.py",
        submit_py_files=[
            "test/resources/code/python/hello_py_spark/hello_py_spark_udfs.py"
        ],
        arguments=["--input", input_data_uri, "--output", output_data_uri],
        spark_event_logs_s3_uri=spark_event_logs_s3_uri,
        wait=True,
    )

    processing_job = spark.latest_job
    waiter = sagemaker_client.get_waiter("processing_job_completed_or_stopped")
    waiter.wait(
        ProcessingJobName=processing_job.job_name,
        # poll every 15 seconds. timeout after 15 minutes.
        WaiterConfig={
            "Delay": 15,
            "MaxAttempts": 60
        },
    )

    output_contents = S3Downloader.list(output_data_uri,
                                        sagemaker_session=sagemaker_session)
    assert len(output_contents) != 0


def test_sagemaker_pyspark_multinode(tag, role, image_uri, configuration,
                                     sagemaker_session, region,
                                     sagemaker_client):
    """Test that basic multinode case works on 32KB of data"""
    spark = PySparkProcessor(
        base_job_name="sm-spark-py",
        framework_version=tag,
        image_uri=image_uri,
        role=role,
        instance_count=2,
        instance_type="ml.c5.xlarge",
        max_runtime_in_seconds=1200,
        sagemaker_session=sagemaker_session,
    )
    bucket = spark.sagemaker_session.default_bucket()
    timestamp = datetime.now().isoformat()
    output_data_uri = "s3://{}/spark/output/sales/{}".format(bucket, timestamp)
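    # Spark event logs are published under this prefix; the polling loop below
    # verifies that event-log files show up in S3 while the job is still running.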
    spark_event_logs_key_prefix = "spark/spark-events/{}".format(timestamp)
    spark_event_logs_s3_uri = "s3://{}/{}".format(bucket,
                                                  spark_event_logs_key_prefix)

    with open("test/resources/data/files/data.jsonl") as data:
        body = data.read()
        input_data_uri = "s3://{}/spark/input/data.jsonl".format(bucket)
        S3Uploader.upload_string_as_file_body(
            body=body,
            desired_s3_uri=input_data_uri,
            sagemaker_session=sagemaker_session)

    spark.run(
        submit_app=
        "test/resources/code/python/hello_py_spark/hello_py_spark_app.py",
        submit_py_files=[
            "test/resources/code/python/hello_py_spark/hello_py_spark_udfs.py"
        ],
        arguments=["--input", input_data_uri, "--output", output_data_uri],
        configuration=configuration,
        spark_event_logs_s3_uri=spark_event_logs_s3_uri,
        wait=False,
    )
    processing_job = spark.latest_job

    s3_client = boto3.client("s3", region_name=region)

    file_size = 0
    latest_file_size = None
    updated_times_count = 0
    time_out = time.time() + 900

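    # processing_job_not_fail_or_complete (a helper assumed to be defined elsewhere in
    # this module) should return True once the job has completed, failed, or stopped;
    # until then, keep polling the event-log prefix in S3.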
    while not processing_job_not_fail_or_complete(sagemaker_client,
                                                  processing_job.job_name):
        response = s3_client.list_objects(Bucket=bucket,
                                          Prefix=spark_event_logs_key_prefix)
        if "Contents" in response:
            # the first file returned by list_objects always has size 0, so skip
            # zero-sized entries and track the size of the latest non-empty log file
            for event_log_file in response["Contents"]:
                if event_log_file["Size"] != 0:
                    print("\n##### Latest file size is " +
                          str(event_log_file["Size"]))
                    latest_file_size = event_log_file["Size"]

        # update the file size if it increased
        if latest_file_size and latest_file_size > file_size:
            print("\n##### S3 file updated.")
            updated_times_count += 1
            file_size = latest_file_size

        if time.time() > time_out:
            raise RuntimeError("Timeout")

        time.sleep(20)

    # verify that spark event logs are periodically written to s3
    print("\n##### file_size {} updated_times_count {}".format(
        file_size, updated_times_count))
    assert file_size != 0

    # Commenting this assert because it's flaky.
    # assert updated_times_count > 1

    output_contents = S3Downloader.list(output_data_uri,
                                        sagemaker_session=sagemaker_session)
    assert len(output_contents) != 0


def test_sagemaker_pyspark_sse_kms_s3(role, image_uri, sagemaker_session,
                                      region, sagemaker_client, account_id,
                                      partition):
    spark = PySparkProcessor(
        base_job_name="sm-spark-py",
        image_uri=image_uri,
        role=role,
        instance_count=2,
        instance_type="ml.c5.xlarge",
        max_runtime_in_seconds=1200,
        sagemaker_session=sagemaker_session,
    )

    # This test expects the AWS-managed S3 KMS key to be present. The key is listed
    # under KMS > AWS managed keys > aws/s3.
    kms_key_id = None
    kms_client = sagemaker_session.boto_session.client("kms",
                                                       region_name=region)
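    # Find the AWS-managed aws/s3 key by scanning the account's KMS aliases.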
    for alias in kms_client.list_aliases()["Aliases"]:
        if "s3" in alias["AliasName"]:
            kms_key_id = alias["TargetKeyId"]

    if not kms_key_id:
        raise ValueError(
            "AWS managed S3 KMS key (alias: aws/s3) does not exist")

    bucket = sagemaker_session.default_bucket()
    timestamp = datetime.now().isoformat()
    input_data_key = f"spark/input/sales/{timestamp}/data.jsonl"
    input_data_uri = f"s3://{bucket}/{input_data_key}"
    output_data_uri_prefix = f"spark/output/sales/{timestamp}"
    output_data_uri = f"s3://{bucket}/{output_data_uri_prefix}"
    s3_client = sagemaker_session.boto_session.client("s3", region_name=region)
    with open("test/resources/data/files/data.jsonl") as data:
        body = data.read()
        s3_client.put_object(Body=body,
                             Bucket=bucket,
                             Key=input_data_key,
                             ServerSideEncryption="aws:kms",
                             SSEKMSKeyId=kms_key_id)

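    # Configure s3a for SSE-KMS with the same managed key so the job's output objects
    # are written encrypted with it (verified below via get_object metadata).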
    spark.run(
        submit_app=
        "test/resources/code/python/hello_py_spark/hello_py_spark_app.py",
        submit_py_files=[
            "test/resources/code/python/hello_py_spark/hello_py_spark_udfs.py"
        ],
        arguments=["--input", input_data_uri, "--output", output_data_uri],
        configuration={
            "Classification": "core-site",
            "Properties": {
                "fs.s3a.server-side-encryption-algorithm":
                "SSE-KMS",
                "fs.s3a.server-side-encryption.key":
                f"arn:{partition}:kms:{region}:{account_id}:key/{kms_key_id}",
            },
        },
    )
    processing_job = spark.latest_job
    waiter = sagemaker_client.get_waiter("processing_job_completed_or_stopped")
    waiter.wait(
        ProcessingJobName=processing_job.job_name,
        # poll every 15 seconds. timeout after 15 minutes.
        WaiterConfig={
            "Delay": 15,
            "MaxAttempts": 60
        },
    )

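    # Every object written under the output prefix should report aws:kms encryption
    # with the expected key ARN.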
    s3_objects = s3_client.list_objects(
        Bucket=bucket, Prefix=output_data_uri_prefix)["Contents"]
    assert len(s3_objects) != 0
    for s3_object in s3_objects:
        object_metadata = s3_client.get_object(Bucket=bucket,
                                               Key=s3_object["Key"])
        assert object_metadata["ServerSideEncryption"] == "aws:kms"
        assert object_metadata[
            "SSEKMSKeyId"] == f"arn:{partition}:kms:{region}:{account_id}:key/{kms_key_id}"


sm = boto3.Session().client(service_name='sagemaker')
sagemaker_role = sagemaker.get_execution_role()

# ############################ #
# PySpark Processor definition #
spark_processor = PySparkProcessor(
    base_job_name='spark-proc-name',
    framework_version='2.4',
    role=sagemaker_role,
    instance_count=1,
    instance_type='ml.r5.8xlarge',
    env={'AWS_DEFAULT_REGION': boto3.Session().region_name},
    max_runtime_in_seconds=1800)

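# EMR-style configuration classifications; "spark-defaults" properties are applied
# to the Spark job at submission time.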
configuration = [{
    "Classification": "spark-defaults",
    "Properties": {
        "spark.executor.memory": "200g",
        "spark.driver.memory": "200g",
        "spark.executor.cores": "20",
        "spark.cores.max": "20"
    }
}]

# #################################### #
# Launch PySpark Processor with script #
spark_processor.run(submit_app='script.py',
                    configuration=configuration,
                    wait=False)
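
# With wait=False, run() returns immediately (and returns None), so the submitted job
# is tracked through spark_processor.latest_job, e.g. spark_processor.latest_job.describe(),
# or with the processing_job_completed_or_stopped waiter used in the tests above.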