Example #1 — component entry point: submits a SageMaker Processing Job, waits for it to finish, and writes the job name and output artifacts to files under /tmp.
import json
import logging

# create_parser and _utils are assumed to be provided by the surrounding
# component module.

def main(argv=None):
    parser = create_parser()
    args = parser.parse_args(argv)

    logging.getLogger().setLevel(logging.INFO)
    client = _utils.get_sagemaker_client(args.region, args.endpoint_url)

    logging.info('Submitting Processing Job to SageMaker...')
    job_name = _utils.create_processing_job(client, vars(args))
    logging.info('Job request submitted. Waiting for completion...')

    try:
        _utils.wait_for_processing_job(client, job_name)
    finally:
        # Print the CloudWatch logs whether the job succeeded or failed.
        cw_client = _utils.get_cloudwatch_client(args.region)
        _utils.print_logs_for_job(cw_client, '/aws/sagemaker/ProcessingJobs',
                                  job_name)

    outputs = _utils.get_processing_job_outputs(client, job_name)

    with open('/tmp/job_name.txt', 'w') as f:
        f.write(job_name)

    with open('/tmp/output_artifacts.txt', 'w') as f:
        f.write(json.dumps(outputs))

    logging.info('Job completed.')
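
create_parser is not shown in these examples. As a rough sketch (not the actual component code), assuming argparse and only the flags that main above actually reads, it might look like the following; the real parser also defines every SageMaker parameter that vars(args) later feeds into create_processing_job:

import argparse

def create_parser():
    # Hypothetical sketch: only the arguments referenced in main() above.
    parser = argparse.ArgumentParser(description='SageMaker Processing Job component')
    parser.add_argument('--region', type=str, required=True,
                        help='AWS region for the SageMaker and CloudWatch clients')
    parser.add_argument('--endpoint_url', type=str, default=None,
                        help='Optional override for the SageMaker endpoint URL')
    return parser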
Example #2 — unit test: a ClientError raised by the SageMaker client propagates out of _utils.create_processing_job.
  def test_sagemaker_exception_in_create_processing_job(self):
    # MagicMock is from unittest.mock; ClientError is from botocore.exceptions.
    mock_client = MagicMock()
    mock_exception = ClientError({"Error": {"Message": "SageMaker broke"}}, "create_processing_job")
    mock_client.create_processing_job.side_effect = mock_exception
    mock_args = self.parser.parse_args(required_args)

    # The ClientError from the mocked client should propagate to the caller.
    with self.assertRaises(Exception):
      _utils.create_processing_job(mock_client, vars(mock_args))
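
For reference, the ClientError simulated here is a real botocore exception; constructed as above, it folds the error dict and operation name into its message, roughly as in this minimal standalone sketch:

from botocore.exceptions import ClientError

err = ClientError({"Error": {"Message": "SageMaker broke"}}, "create_processing_job")
# With no Error.Code in the dict, botocore reports the code as 'Unknown':
print(str(err))
# An error occurred (Unknown) when calling the create_processing_job operation: SageMaker broke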
Example #3 — unit test: _utils.create_processing_job builds the expected boto3 request from the parsed arguments and returns the job name.
  def test_create_processing_job(self):
    mock_client = MagicMock()
    mock_args = self.parser.parse_args(required_args + ['--job_name', 'test-job'])
    response = _utils.create_processing_job(mock_client, vars(mock_args))

    # MagicMock records the call, so the test can assert on the exact
    # boto3 request built from the parsed arguments.
    mock_client.create_processing_job.assert_called_once_with(
      AppSpecification={"ImageUri": "test-image"},
      Environment={},
      NetworkConfig={
        "EnableInterContainerTrafficEncryption": False,
        "EnableNetworkIsolation": True,
      },
      ProcessingInputs=[
        {
          "InputName": "dataset-input",
          "S3Input": {
            "S3Uri": "s3://my-bucket/dataset.csv",
            "LocalPath": "/opt/ml/processing/input",
            "S3DataType": "S3Prefix",
            "S3InputMode": "File"
          },
        }
      ],
      ProcessingJobName="test-job",
      ProcessingOutputConfig={
        "Outputs": [
          {
            "OutputName": "training-outputs",
            "S3Output": {
              "S3Uri": "s3://my-bucket/outputs/train.csv",
              "LocalPath": "/opt/ml/processing/output/train",
              "S3UploadMode": "Continuous"
            },
          }
        ]
      },
      ProcessingResources={
        "ClusterConfig": {
          "InstanceType": "ml.m4.xlarge",
          "InstanceCount": 1,
          "VolumeSizeInGB": 30,
        }
      },
      RoleArn="arn:aws:iam::123456789012:user/Development/product_1234/*",
      StoppingCondition={"MaxRuntimeInSeconds": 86400},
      Tags=[],
    )
    self.assertEqual(response, 'test-job')
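
Taken together, Examples #2 and #3 pin down the contract of _utils.create_processing_job: build a boto3 request from the parsed arguments, call client.create_processing_job with it, let any ClientError propagate, and return the job name. A simplified sketch consistent with both tests (not the actual implementation; create_processing_job_request is a hypothetical helper) could be:

def create_processing_job(client, args):
    # Build the boto3 request dict (ProcessingJobName, AppSpecification,
    # ProcessingInputs, ...) from the component arguments.
    request = create_processing_job_request(args)  # hypothetical helper
    # Any ClientError raised here propagates to the caller, as Example #2 asserts.
    client.create_processing_job(**request)
    # Return the job name so the caller can wait on (or stop) the job.
    return request['ProcessingJobName']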
Example #4 — a variant of the entry point: optionally assumes an IAM role, stops the SageMaker job on SIGTERM, and writes outputs via _utils.write_output instead of hard-coded /tmp paths.
import logging
import signal

# create_parser and _utils are assumed to be provided by the surrounding
# component module.

def main(argv=None):
    parser = create_parser()
    args = parser.parse_args(argv)

    logging.getLogger().setLevel(logging.INFO)
    client = _utils.get_sagemaker_client(args.region,
                                         args.endpoint_url,
                                         assume_role_arn=args.assume_role)

    logging.info('Submitting Processing Job to SageMaker...')
    job_name = _utils.create_processing_job(client, vars(args))

    # Stop the SageMaker job if this process receives SIGTERM (e.g. when the
    # pipeline run is terminated).
    def signal_term_handler(signalNumber, frame):
        logging.info(f"Stopping Processing Job: {job_name}")
        _utils.stop_processing_job(client, job_name)
        logging.info(f"Processing Job: {job_name} request submitted to Stop")

    signal.signal(signal.SIGTERM, signal_term_handler)

    logging.info('Job request submitted. Waiting for completion...')

    try:
        _utils.wait_for_processing_job(client, job_name)
    finally:
        # Print the CloudWatch logs whether the job succeeded or failed.
    finally:
        cw_client = _utils.get_cloudwatch_client(
            args.region, assume_role_arn=args.assume_role)
        _utils.print_logs_for_job(cw_client, '/aws/sagemaker/ProcessingJobs',
                                  job_name)

    outputs = _utils.get_processing_job_outputs(client, job_name)

    _utils.write_output(args.job_name_output_path, job_name)
    _utils.write_output(args.output_artifacts_output_path,
                        outputs,
                        json_encode=True)

    logging.info('Job completed.')
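
A minimal sketch of a write_output helper, assuming only the call signature used above (an output path, a value, and optional JSON encoding) rather than the actual implementation, might be:

import json
import os

def write_output(output_path, data, json_encode=False):
    # Hypothetical sketch: serialize if requested, make sure the parent
    # directory exists, then write the artifact where the pipeline expects it.
    value = json.dumps(data) if json_encode else data
    parent = os.path.dirname(output_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(output_path, 'w') as f:
        f.write(value)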