from datetime import datetime

from sagemaker.s3 import S3Downloader, S3Uploader
from sagemaker.spark.processing import SparkJarProcessor


def test_sagemaker_scala_jar_multinode(
    tag, role, image_uri, configuration, sagemaker_session, sagemaker_client
):
    """Test SparkJarProcessor using Scala application jar with external runtime dependency jars staged by SDK"""
    spark = SparkJarProcessor(
        base_job_name="sm-spark-scala",
        framework_version=tag,
        image_uri=image_uri,
        role=role,
        instance_count=2,
        instance_type="ml.c5.xlarge",
        max_runtime_in_seconds=1200,
        sagemaker_session=sagemaker_session,
    )

    bucket = spark.sagemaker_session.default_bucket()
    with open("test/resources/data/files/data.jsonl") as data:
        body = data.read()
        input_data_uri = "s3://{}/spark/input/data.jsonl".format(bucket)
        S3Uploader.upload_string_as_file_body(
            body=body,
            desired_s3_uri=input_data_uri,
            sagemaker_session=sagemaker_session,
        )
    output_data_uri = "s3://{}/spark/output/sales/{}".format(
        bucket, datetime.now().isoformat()
    )

    scala_project_dir = "test/resources/code/scala/hello-scala-spark"
    spark.run(
        submit_app="{}/target/scala-2.11/hello-scala-spark_2.11-1.0.jar".format(
            scala_project_dir
        ),
        submit_class="com.amazonaws.sagemaker.spark.test.HelloScalaSparkApp",
        submit_jars=[
            "{}/lib_managed/jars/org.json4s/json4s-native_2.11/"
            "json4s-native_2.11-3.6.9.jar".format(scala_project_dir)
        ],
        arguments=["--input", input_data_uri, "--output", output_data_uri],
        configuration=configuration,
    )
    processing_job = spark.latest_job

    waiter = sagemaker_client.get_waiter("processing_job_completed_or_stopped")
    waiter.wait(
        ProcessingJobName=processing_job.job_name,
        # Poll every 15 seconds; time out after 15 minutes (15 s x 60 attempts = 900 s).
        WaiterConfig={"Delay": 15, "MaxAttempts": 60},
    )

    output_contents = S3Downloader.list(
        output_data_uri, sagemaker_session=sagemaker_session
    )
    assert len(output_contents) != 0
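

# A minimal sketch of the pytest fixtures the integration test above assumes.
# The fixture names come from the test signature; the bodies are illustrative
# assumptions (region, account, tag, and role are placeholders), not the
# project's actual conftest.py.
import boto3
import pytest
from sagemaker import Session


@pytest.fixture
def sagemaker_client():
    # Plain boto3 SageMaker client, used above for the processing-job waiter.
    return boto3.client("sagemaker", region_name="us-west-2")


@pytest.fixture
def sagemaker_session(sagemaker_client):
    return Session(sagemaker_client=sagemaker_client)


@pytest.fixture
def role():
    return "arn:aws:iam::111111111111:role/SageMakerExecutionRole"


@pytest.fixture
def tag():
    return "2.4"


@pytest.fixture
def image_uri(tag):
    return "111111111111.dkr.ecr.us-west-2.amazonaws.com/sagemaker-spark:{}".format(tag)


@pytest.fixture
def configuration():
    # EMR-style configuration list accepted by SparkJarProcessor.run().
    return [
        {
            "Classification": "spark-defaults",
            "Properties": {"spark.executor.memory": "2g"},
        }
    ]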
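

# Unit test for SparkJarProcessor.run(). The three mock_* parameters below are
# injected by @patch decorators and (config, expected) by a
# @pytest.mark.parametrize table, both omitted from this excerpt; judging by
# the mock names, the patch targets are the processor's
# _generate_current_job_name and _stage_submit_deps helpers and the parent
# class's run() method.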
def test_spark_jar_processor_run(
    mock_generate_current_job_name,
    mock_stage_submit_deps,
    mock_super_run,
    config,
    expected,
    sagemaker_session,
):
    mock_stage_submit_deps.return_value = (processing_input, "opt")
    mock_generate_current_job_name.return_value = "jobName"

    spark_jar_processor = SparkJarProcessor(
        base_job_name="sm-spark",
        role="AmazonSageMaker-ExecutionRole",
        framework_version="2.4",
        instance_count=1,
        instance_type="ml.c5.xlarge",
        image_uri="790336243319.dkr.ecr.us-west-2.amazonaws.com/sagemaker-spark:0.1",
        sagemaker_session=sagemaker_session,
    )

    if expected is ValueError:
        with pytest.raises(expected):
            spark_jar_processor.run(
                submit_app=config["submit_app"],
                submit_class=config["submit_class"],
                submit_jars=config["files"],
                submit_files=config["files"],
                inputs=config["inputs"],
            )
    else:
        spark_jar_processor.run(
            submit_app=config["submit_app"],
            submit_class=config["submit_class"],
            submit_jars=config["files"],
            submit_files=config["files"],
            inputs=config["inputs"],
        )

        mock_super_run.assert_called_with(
            submit_app=config["submit_app"],
            inputs=expected,
            outputs=None,
            arguments=None,
            wait=True,
            logs=True,
            job_name="jobName",
            experiment_config=None,
            kms_key=None,
        )