def test_spark_jar_processor_get_run_args(
    mock_generate_current_job_name,
    mock_stage_submit_deps,
    mock_super_get_run_args,
    config,
    expected,
    sagemaker_session,
):
    # Stub out dependency staging and job-name generation so get_run_args runs
    # without touching AWS. `processing_input` is defined elsewhere in this module.
    mock_stage_submit_deps.return_value = (processing_input, "opt")
    mock_generate_current_job_name.return_value = "jobName"

    spark_jar_processor = SparkJarProcessor(
        base_job_name="sm-spark",
        role="AmazonSageMaker-ExecutionRole",
        framework_version="2.4",
        instance_count=1,
        instance_type="ml.c5.xlarge",
        image_uri="790336243319.dkr.ecr.us-west-2.amazonaws.com/sagemaker-spark:0.1",
        sagemaker_session=sagemaker_session,
    )

    if expected is ValueError:
        # Invalid configs should surface as a ValueError from get_run_args.
        with pytest.raises(expected):
            spark_jar_processor.get_run_args(
                submit_app=config["submit_app"],
                submit_class=config["submit_class"],
                submit_jars=config["files"],
                submit_files=config["files"],
                inputs=config["inputs"],
                arguments=config["arguments"],
            )
    else:
        # Valid configs should delegate to the parent class's get_run_args with
        # the normalized inputs produced by dependency staging.
        spark_jar_processor.get_run_args(
            submit_app=config["submit_app"],
            submit_class=config["submit_class"],
            submit_jars=config["files"],
            submit_files=config["files"],
            inputs=config["inputs"],
            arguments=config["arguments"],
        )
        mock_super_get_run_args.assert_called_with(
            code=config["submit_app"],
            inputs=expected,
            outputs=None,
            arguments=config["arguments"],
        )
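# A minimal, hypothetical sketch of the (config, expected) pairs the parametrized
# test above consumes; the real cases are supplied by the suite's
# @pytest.mark.parametrize decorator. All names and values below are illustrative
# assumptions, not taken from the original file.
_EXAMPLE_GET_RUN_ARGS_CASES = [
    # A valid config: `expected` stands in for the normalized input list that the
    # parent class's get_run_args should receive.
    (
        {
            "submit_app": "s3://example-bucket/jars/hello-spark.jar",
            "submit_class": "com.example.HelloSparkApp",
            "files": ["s3://example-bucket/jars/dependency.jar"],
            "inputs": None,
            "arguments": ["--input", "s3://example-bucket/input"],
        },
        [],  # placeholder for the expected ProcessingInput list
    ),
    # An invalid config: a missing submit_app is assumed to make get_run_args raise.
    (
        {
            "submit_app": None,
            "submit_class": "com.example.HelloSparkApp",
            "files": None,
            "inputs": None,
            "arguments": None,
        },
        ValueError,
    ),
]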
def test_one_step_sparkjar_processing_pipeline(
    sagemaker_session,
    role,
    cpu_instance_type,
    pipeline_name,
    region_name,
    configuration,
    build_jar,
):
    instance_count = ParameterInteger(name="InstanceCount", default_value=2)
    cache_config = CacheConfig(enable_caching=True, expire_after="T30m")
    spark_path = os.path.join(DATA_DIR, "spark")

    spark_jar_processor = SparkJarProcessor(
        role=role,
        instance_count=2,
        instance_type=cpu_instance_type,
        sagemaker_session=sagemaker_session,
        framework_version="2.4",
    )
    bucket = spark_jar_processor.sagemaker_session.default_bucket()

    # Upload the test dataset so the Spark job has an S3 input to read.
    with open(os.path.join(spark_path, "files", "data.jsonl")) as data:
        body = data.read()
    input_data_uri = f"s3://{bucket}/spark/input/data.jsonl"
    S3Uploader.upload_string_as_file_body(
        body=body,
        desired_s3_uri=input_data_uri,
        sagemaker_session=sagemaker_session,
    )
    output_data_uri = f"s3://{bucket}/spark/output/sales/{datetime.now().isoformat()}"

    java_project_dir = os.path.join(spark_path, "code", "java", "hello-java-spark")
    spark_run_args = spark_jar_processor.get_run_args(
        submit_app=f"{java_project_dir}/hello-spark-java.jar",
        submit_class="com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
        arguments=["--input", input_data_uri, "--output", output_data_uri],
        configuration=configuration,
    )

    step_pyspark = ProcessingStep(
        name="sparkjar-process",
        processor=spark_jar_processor,
        inputs=spark_run_args.inputs,
        outputs=spark_run_args.outputs,
        job_arguments=spark_run_args.arguments,
        code=spark_run_args.code,
        cache_config=cache_config,
    )
    pipeline = Pipeline(
        name=pipeline_name,
        parameters=[instance_count],
        steps=[step_pyspark],
        sagemaker_session=sagemaker_session,
    )

    try:
        # NOTE: We should exercise the case where the role used in the pipeline
        # execution differs from the one required by the steps in the pipeline
        # itself. The role in the pipeline definition needs to create training and
        # processing jobs and other SageMaker entities. However, the jobs created
        # in the steps themselves execute under a potentially different role, often
        # requiring access to S3 and other artifacts not needed during creation of
        # the jobs in the pipeline steps.
        response = pipeline.create(role)
        create_arn = response["PipelineArn"]
        assert re.match(
            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
            create_arn,
        )

        # Updating a pipeline parameter should yield the same pipeline ARN.
        pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
        response = pipeline.update(role)
        update_arn = response["PipelineArn"]
        assert re.match(
            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
            update_arn,
        )

        execution = pipeline.start(parameters={})
        assert re.match(
            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
            execution.arn,
        )

        response = execution.describe()
        assert response["PipelineArn"] == create_arn

        # Check that the CacheConfig made it into the pipeline definition.
        response = json.loads(pipeline.describe()["PipelineDefinition"])["Steps"][0][
            "CacheConfig"
        ]
        assert response["Enabled"] == cache_config.enable_caching
        assert response["ExpireAfter"] == cache_config.expire_after

        try:
            execution.wait(delay=30, max_attempts=3)
        except WaiterError:
            pass
        execution_steps = execution.list_steps()
        assert len(execution_steps) == 1
        assert execution_steps[0]["StepName"] == "sparkjar-process"
    finally:
        try:
            pipeline.delete()
        except Exception:
            pass
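# A minimal, hypothetical sketch (not part of the original suite) of what the
# `configuration` fixture used above might provide. SparkJarProcessor accepts
# EMR-style configuration dictionaries keyed by "Classification" and "Properties";
# the fixture name and property values below are illustrative assumptions.
@pytest.fixture
def example_spark_configuration():
    return [
        {
            "Classification": "spark-defaults",
            "Properties": {"spark.executor.memory": "2g", "spark.executor.cores": "1"},
        }
    ]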