# Example #1

import json
import os
import re
from datetime import datetime
from unittest.mock import patch

import pytest
from botocore.exceptions import WaiterError

from sagemaker.s3 import S3Uploader
from sagemaker.spark.processing import SparkJarProcessor
from sagemaker.workflow.parameters import ParameterInteger
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.steps import CacheConfig, ProcessingStep

# Assumed to be defined elsewhere in the test module: DATA_DIR,
# processing_input, and the fixtures used below (sagemaker_session, role,
# cpu_instance_type, pipeline_name, region_name, configuration, build_jar).


# Assumed decorators: the three mock parameters imply three @patch decorators
# (the bottom-most maps to the first parameter), plus a
# @pytest.mark.parametrize("config,expected", ...) whose cases are defined
# elsewhere in the module.
@patch("sagemaker.spark.processing.ScriptProcessor.get_run_args")
@patch("sagemaker.spark.processing._SparkProcessorBase._stage_submit_deps")
@patch("sagemaker.spark.processing._SparkProcessorBase._generate_current_job_name")
def test_spark_jar_processor_get_run_args(
    mock_generate_current_job_name,
    mock_stage_submit_deps,
    mock_super_get_run_args,
    config,
    expected,
    sagemaker_session,
):
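    """Verify get_run_args forwards Spark JAR options to the base class."""
    # Stub dependency staging and job-name generation so that only the
    # argument plumbing in get_run_args is under test.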
    mock_stage_submit_deps.return_value = (processing_input, "opt")
    mock_generate_current_job_name.return_value = "jobName"

    spark_jar_processor = SparkJarProcessor(
        base_job_name="sm-spark",
        role="AmazonSageMaker-ExecutionRole",
        framework_version="2.4",
        instance_count=1,
        instance_type="ml.c5.xlarge",
        image_uri="790336243319.dkr.ecr.us-west-2.amazonaws.com/sagemaker-spark:0.1",
        sagemaker_session=sagemaker_session,
    )

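    # Parametrized failure cases should raise; everything else should delegate
    # to the base ScriptProcessor.get_run_args with normalized arguments.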
    if expected is ValueError:
        with pytest.raises(expected):
            spark_jar_processor.get_run_args(
                submit_app=config["submit_app"],
                submit_class=config["submit_class"],
                submit_jars=config["files"],
                submit_files=config["files"],
                inputs=config["inputs"],
                arguments=config["arguments"],
            )
    else:
        spark_jar_processor.get_run_args(
            submit_app=config["submit_app"],
            submit_class=config["submit_class"],
            submit_jars=config["files"],
            submit_files=config["files"],
            inputs=config["inputs"],
            arguments=config["arguments"],
        )

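        # The base class should receive the staged inputs as computed by
        # _stage_submit_deps, with submit_app forwarded as the code argument.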
        mock_super_get_run_args.assert_called_with(
            code=config["submit_app"],
            inputs=expected,
            outputs=None,
            arguments=config["arguments"],
        )


def test_one_step_sparkjar_processing_pipeline(
    sagemaker_session,
    role,
    cpu_instance_type,
    pipeline_name,
    region_name,
    configuration,
    build_jar,
):
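    """Create, update, and run a one-step SparkJarProcessor pipeline."""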
    instance_count = ParameterInteger(name="InstanceCount", default_value=2)
    cache_config = CacheConfig(enable_caching=True, expire_after="T30m")
    spark_path = os.path.join(DATA_DIR, "spark")

    spark_jar_processor = SparkJarProcessor(
        role=role,
        instance_count=2,
        instance_type=cpu_instance_type,
        sagemaker_session=sagemaker_session,
        framework_version="2.4",
    )
    bucket = spark_jar_processor.sagemaker_session.default_bucket()
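    # Stage the test input in the session's default bucket so the Spark job
    # can read it from S3.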
    with open(os.path.join(spark_path, "files", "data.jsonl")) as data:
        body = data.read()
        input_data_uri = f"s3://{bucket}/spark/input/data.jsonl"
        S3Uploader.upload_string_as_file_body(
            body=body,
            desired_s3_uri=input_data_uri,
            sagemaker_session=sagemaker_session,
        )
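
    # A timestamped output prefix keeps repeated test runs from colliding.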
    output_data_uri = f"s3://{bucket}/spark/output/sales/{datetime.now().isoformat()}"

    java_project_dir = os.path.join(spark_path, "code", "java", "hello-java-spark")
    spark_run_args = spark_jar_processor.get_run_args(
        submit_app=f"{java_project_dir}/hello-spark-java.jar",
        submit_class="com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
        arguments=["--input", input_data_uri, "--output", output_data_uri],
        configuration=configuration,
    )

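    # Wire the resolved run arguments into a pipeline ProcessingStep.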
    step_sparkjar = ProcessingStep(
        name="sparkjar-process",
        processor=spark_jar_processor,
        inputs=spark_run_args.inputs,
        outputs=spark_run_args.outputs,
        job_arguments=spark_run_args.arguments,
        code=spark_run_args.code,
        cache_config=cache_config,
    )
    pipeline = Pipeline(
        name=pipeline_name,
        parameters=[instance_count],
        steps=[step_sparkjar],
        sagemaker_session=sagemaker_session,
    )

    try:
        # NOTE: We should exercise the case where the role used for the pipeline
        # execution is different from the one required by the steps in the
        # pipeline itself. The role in the pipeline definition needs to create
        # training and processing jobs and other SageMaker entities. However,
        # the jobs created in the steps themselves execute under a potentially
        # different role, often requiring access to S3 and other artifacts not
        # needed during creation of the jobs in the pipeline steps.
        response = pipeline.create(role)
        create_arn = response["PipelineArn"]
        assert re.match(
            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
            create_arn,
        )

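        # Updating the pipeline with a new parameter default should return an
        # ARN for the same pipeline.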
        pipeline.parameters = [
            ParameterInteger(name="InstanceCount", default_value=1)
        ]
        response = pipeline.update(role)
        update_arn = response["PipelineArn"]
        assert re.match(
            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
            update_arn,
        )

        execution = pipeline.start(parameters={})
        assert re.match(
            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
            execution.arn,
        )

        response = execution.describe()
        assert response["PipelineArn"] == create_arn

        # Check CacheConfig
        definition = json.loads(pipeline.describe()["PipelineDefinition"])
        step_cache_config = definition["Steps"][0]["CacheConfig"]
        assert step_cache_config["Enabled"] == cache_config.enable_caching
        assert step_cache_config["ExpireAfter"] == cache_config.expire_after

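        # Poll briefly; a WaiterError here just means the execution is still
        # in progress, which is acceptable for this check.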
        try:
            execution.wait(delay=30, max_attempts=3)
        except WaiterError:
            pass
        execution_steps = execution.list_steps()
        assert len(execution_steps) == 1
        assert execution_steps[0]["StepName"] == "sparkjar-process"
    finally:
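        # Best-effort cleanup; ignore failures if the pipeline is already gone.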
        try:
            pipeline.delete()
        except Exception:
            pass