import json
import logging

def main(argv=None):
    parser = create_parser()
    args = parser.parse_args(argv)

    logging.getLogger().setLevel(logging.INFO)
    client = _utils.get_sagemaker_client(args.region, args.endpoint_url)

    logging.info('Submitting Processing Job to SageMaker...')
    job_name = _utils.create_processing_job(client, vars(args))
    logging.info('Job request submitted. Waiting for completion...')

    try:
        _utils.wait_for_processing_job(client, job_name)
    finally:
        # Always surface the CloudWatch logs, even when the job fails.
        cw_client = _utils.get_cloudwatch_client(args.region)
        _utils.print_logs_for_job(cw_client, '/aws/sagemaker/ProcessingJobs', job_name)

    outputs = _utils.get_processing_job_outputs(client, job_name)

    with open('/tmp/job_name.txt', 'w') as f:
        f.write(job_name)
    with open('/tmp/output_artifacts.txt', 'w') as f:
        f.write(json.dumps(outputs))

    logging.info('Job completed.')
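# For context, a minimal sketch (not the component's actual implementation) of
# what _utils.get_sagemaker_client could look like, assuming it wraps boto3.
# The signature below mirrors the call site above; the real helper also grows
# an assume_role_arn parameter in the later version of main().
import boto3

def get_sagemaker_client(region, endpoint_url=None):
    """Build a boto3 SageMaker client for the given region and optional endpoint."""
    return boto3.client('sagemaker', region_name=region, endpoint_url=endpoint_url)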
from unittest.mock import MagicMock

from botocore.exceptions import ClientError

def test_sagemaker_exception_in_create_processing_job(self):
    mock_client = MagicMock()
    # Simulate the SageMaker API rejecting the request.
    mock_exception = ClientError(
        {"Error": {"Message": "SageMaker broke"}}, "create_processing_job"
    )
    mock_client.create_processing_job.side_effect = mock_exception
    mock_args = self.parser.parse_args(required_args)

    with self.assertRaises(Exception):
        _utils.create_processing_job(mock_client, vars(mock_args))
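# The test above relies on _utils.create_processing_job converting a boto3
# ClientError into a plain Exception. A minimal sketch of that pattern,
# assuming a hypothetical request-builder helper; everything except the boto3
# call itself is an assumption, not the component's actual code:
def create_processing_job_sketch(client, args):
    request = create_processing_job_request(args)  # hypothetical request builder
    try:
        client.create_processing_job(**request)
        return request['ProcessingJobName']
    except ClientError as e:
        # Re-raise with just the service's error message, which is what lets
        # the test assert on a generic Exception.
        raise Exception(e.response['Error']['Message'])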
def test_create_processing_job(self):
    mock_client = MagicMock()
    mock_args = self.parser.parse_args(required_args + ['--job_name', 'test-job'])
    response = _utils.create_processing_job(mock_client, vars(mock_args))

    mock_client.create_processing_job.assert_called_once_with(
        AppSpecification={"ImageUri": "test-image"},
        Environment={},
        NetworkConfig={
            "EnableInterContainerTrafficEncryption": False,
            "EnableNetworkIsolation": True,
        },
        ProcessingInputs=[
            {
                "InputName": "dataset-input",
                "S3Input": {
                    "S3Uri": "s3://my-bucket/dataset.csv",
                    "LocalPath": "/opt/ml/processing/input",
                    "S3DataType": "S3Prefix",
                    "S3InputMode": "File",
                },
            }
        ],
        ProcessingJobName="test-job",
        ProcessingOutputConfig={
            "Outputs": [
                {
                    "OutputName": "training-outputs",
                    "S3Output": {
                        "S3Uri": "s3://my-bucket/outputs/train.csv",
                        "LocalPath": "/opt/ml/processing/output/train",
                        "S3UploadMode": "Continuous",
                    },
                }
            ]
        },
        ProcessingResources={
            "ClusterConfig": {
                "InstanceType": "ml.m4.xlarge",
                "InstanceCount": 1,
                "VolumeSizeInGB": 30,
            }
        },
        RoleArn="arn:aws:iam::123456789012:user/Development/product_1234/*",
        StoppingCondition={"MaxRuntimeInSeconds": 86400},
        Tags=[],
    )
    self.assertEqual(response, 'test-job')
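# Hypothetical reconstruction of the shared `required_args` fixture these
# tests parse. The flag names and JSON shapes are guesses inferred from the
# expected request in test_create_processing_job; the component's real parser
# may use different flags.
import json

required_args = [
    '--region', 'us-west-2',
    '--role', 'arn:aws:iam::123456789012:user/Development/product_1234/*',
    '--image', 'test-image',
    '--instance_type', 'ml.m4.xlarge',
    '--instance_count', '1',
    '--volume_size', '30',
    '--max_run_time', '86400',
    '--input_config', json.dumps([{
        'InputName': 'dataset-input',
        'S3Input': {
            'S3Uri': 's3://my-bucket/dataset.csv',
            'LocalPath': '/opt/ml/processing/input',
            'S3DataType': 'S3Prefix',
            'S3InputMode': 'File',
        },
    }]),
    '--output_config', json.dumps([{
        'OutputName': 'training-outputs',
        'S3Output': {
            'S3Uri': 's3://my-bucket/outputs/train.csv',
            'LocalPath': '/opt/ml/processing/output/train',
            'S3UploadMode': 'Continuous',
        },
    }]),
]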
import logging
import signal

def main(argv=None):
    parser = create_parser()
    args = parser.parse_args(argv)

    logging.getLogger().setLevel(logging.INFO)
    client = _utils.get_sagemaker_client(
        args.region, args.endpoint_url, assume_role_arn=args.assume_role
    )

    logging.info('Submitting Processing Job to SageMaker...')
    job_name = _utils.create_processing_job(client, vars(args))

    def signal_term_handler(signalNumber, frame):
        # Stop the remote SageMaker job when the pipeline pod is terminated,
        # so no orphaned job keeps running (and billing) in AWS.
        logging.info(f"Stopping Processing Job: {job_name}")
        _utils.stop_processing_job(client, job_name)
        logging.info(f"Processing Job: {job_name} request submitted to Stop")

    signal.signal(signal.SIGTERM, signal_term_handler)

    logging.info('Job request submitted. Waiting for completion...')
    try:
        _utils.wait_for_processing_job(client, job_name)
    finally:
        # Always surface the CloudWatch logs, even when the job fails.
        cw_client = _utils.get_cloudwatch_client(
            args.region, assume_role_arn=args.assume_role
        )
        _utils.print_logs_for_job(cw_client, '/aws/sagemaker/ProcessingJobs', job_name)

    outputs = _utils.get_processing_job_outputs(client, job_name)

    _utils.write_output(args.job_name_output_path, job_name)
    _utils.write_output(args.output_artifacts_output_path, outputs, json_encode=True)

    logging.info('Job completed.')
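# A minimal sketch of _utils.write_output consistent with the two call sites
# above (a plain string output vs. a JSON-encoded one); the component's real
# helper may differ in details such as directory creation.
import json
from pathlib import Path

def write_output(output_path, output_value, json_encode=False):
    """Write a component output file, JSON-encoding the value when requested."""
    value = json.dumps(output_value) if json_encode else output_value
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    Path(output_path).write_text(value)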