示例#1
0
def main(argv=None):
    """Create a SageMaker model from command-line arguments.

    Parses the CLI flags, builds a client with ``_utils.get_client``, submits
    the model creation request through ``_utils.create_model`` and finally
    writes the model name to ``/tmp/model_name.txt`` (presumably consumed as
    a pipeline step output -- confirm against the component spec).

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse fall back to ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description='SageMaker Training Job')
    parser.add_argument('--region',
                        type=str,
                        help='The region where the cluster launches.')
    parser.add_argument(
        '--image',
        type=str,
        help=
        'The Amazon EC2 Container Registry (Amazon ECR) path where inference code is stored.'
    )
    parser.add_argument('--model_artifact_url',
                        type=str,
                        help='S3 model artifacts url')
    parser.add_argument('--model_name',
                        type=str,
                        help='The name of the new model.')
    parser.add_argument(
        '--role',
        type=str,
        help=
        'The Amazon Resource Name (ARN) that Amazon SageMaker assumes to perform tasks on your behalf.'
    )
    # Honour the ``argv`` parameter. Previously it was silently ignored and
    # sys.argv was always parsed, so callers could not drive main()
    # programmatically.
    args = parser.parse_args(argv)

    logging.getLogger().setLevel(logging.INFO)
    client = _utils.get_client(args.region)

    logging.info('Submitting model creation request to SageMaker...')
    # This older helper signature takes positional fields, not a dict.
    _utils.create_model(client, args.model_artifact_url, args.model_name,
                        args.image, args.role)

    logging.info('Model creation completed.')
    with open('/tmp/model_name.txt', 'w') as f:
        f.write(args.model_name)
示例#2
0
  def test_sagemaker_exception_in_create_model(self):
    """create_model must raise when the client's create_model call fails."""
    client = MagicMock()
    # Every call to client.create_model raises this ClientError.
    client.create_model.side_effect = ClientError(
        {"Error": {"Message": "SageMaker broke"}}, "create_model")
    parsed = self.parser.parse_args(required_args)

    with self.assertRaises(Exception):
      _utils.create_model(client, vars(parsed))
示例#3
0
def main(argv=None):
    """Create a SageMaker model from parsed command-line arguments.

    Builds the parser via ``create_parser()``, creates a SageMaker client for
    the requested region/endpoint, submits the request through
    ``_utils.create_model`` and writes the model name to
    ``/tmp/model_name.txt``.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse fall back to ``sys.argv[1:]``.
    """
    parser = create_parser()
    # Pass argv through; the original ignored it and always read sys.argv.
    args = parser.parse_args(argv)

    logging.getLogger().setLevel(logging.INFO)
    client = _utils.get_sagemaker_client(args.region, args.endpoint_url)

    logging.info('Submitting model creation request to SageMaker...')
    _utils.create_model(client, vars(args))

    logging.info('Model creation completed.')
    with open('/tmp/model_name.txt', 'w') as f:
        f.write(args.model_name)
示例#4
0
def main(argv=None):
  """Create a SageMaker model and record its name as a step output.

  Args:
    argv: Optional list of argument strings; ``None`` lets argparse read
      ``sys.argv[1:]``.
  """
  args = create_parser().parse_args(argv)

  logging.getLogger().setLevel(logging.INFO)
  sagemaker_client = _utils.get_sagemaker_client(args.region, args.endpoint_url)

  logging.info('Submitting model creation request to SageMaker...')
  _utils.create_model(sagemaker_client, vars(args))
  logging.info('Model creation completed.')

  # Persist the model name to the path supplied on the command line.
  _utils.write_output(args.model_name_output_path, args.model_name)
示例#5
0
  def test_create_model(self):
    """create_model issues exactly one client.create_model call whose
    keyword arguments are built from the parsed required arguments."""
    mock_client = MagicMock()
    mock_args = self.parser.parse_args(required_args)
    # The return value is not inspected, so it is not bound to a name.
    _utils.create_model(mock_client, vars(mock_args))

    mock_client.create_model.assert_called_once_with(
      EnableNetworkIsolation=True,
      ExecutionRoleArn='arn:aws:iam::123456789012:user/Development/product_1234/*',
      ModelName='model_test',
      PrimaryContainer={'Image': 'test-image', 'ModelDataUrl': 's3://fake-bucket/model_artifact', 'Environment': {}},
      Tags=[]
    )
示例#6
0
def _create_parser():
    """Build the argparse parser describing the create-model CLI flags."""
    parser = argparse.ArgumentParser(description='SageMaker Training Job')
    parser.add_argument('--region',
                        type=str.strip,
                        required=True,
                        help='The region where the cluster launches.')
    parser.add_argument('--model_name',
                        type=str.strip,
                        required=True,
                        help='The name of the new model.')
    parser.add_argument(
        '--role',
        type=str.strip,
        required=True,
        help=
        'The Amazon Resource Name (ARN) that Amazon SageMaker assumes to perform tasks on your behalf.'
    )
    parser.add_argument(
        '--container_host_name',
        type=str.strip,
        required=False,
        help=
        'When a ContainerDefinition is part of an inference pipeline, this value uniquely identifies the container for the purposes of logging and metrics.',
        default='')
    parser.add_argument(
        '--image',
        type=str.strip,
        required=False,
        help=
        'The Amazon EC2 Container Registry (Amazon ECR) path where inference code is stored.',
        default='')
    parser.add_argument(
        '--model_artifact_url',
        type=str.strip,
        required=False,
        help='S3 path where Amazon SageMaker to store the model artifacts.',
        default='')
    parser.add_argument(
        '--environment',
        type=_utils.str_to_json_dict,
        required=False,
        help=
        'The dictionary of the environment variables to set in the Docker container. Up to 16 key-value entries in the map.',
        default='{}')
    parser.add_argument(
        '--model_package',
        type=str.strip,
        required=False,
        help=
        'The name or Amazon Resource Name (ARN) of the model package to use to create the model.',
        default='')
    parser.add_argument(
        '--secondary_containers',
        type=_utils.str_to_json_list,
        required=False,
        help=
        'A list of dicts that specifies the additional containers in the inference pipeline.',
        # argparse passes string defaults through ``type``. Use a list
        # literal so the default matches this flag's list type. The old '{}'
        # is an object literal (presumably decoded as a dict -- confirm in
        # _utils.str_to_json_list).
        default='[]')
    parser.add_argument(
        '--vpc_security_group_ids',
        type=str.strip,
        required=False,
        help='The VPC security group IDs, in the form sg-xxxxxxxx.',
        default='')
    parser.add_argument(
        '--vpc_subnets',
        type=str.strip,
        required=False,
        help=
        'The ID of the subnets in the VPC to which you want to connect your hpo job.',
        default='')
    parser.add_argument('--network_isolation',
                        type=_utils.str_to_bool,
                        required=False,
                        help='Isolates the training container.',
                        default=True)
    parser.add_argument(
        '--tags',
        type=_utils.str_to_json_dict,
        required=False,
        help='An array of key-value pairs, to categorize AWS resources.',
        default='{}')
    return parser


def main(argv=None):
    """Create a SageMaker model from command-line arguments.

    Parses the flags defined by ``_create_parser``, submits the request
    through ``_utils.create_model`` with all parsed values as a dict, and
    writes the model name to ``/tmp/model_name.txt``.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse fall back to ``sys.argv[1:]``.
    """
    # Pass argv through; the original ignored it and always read sys.argv.
    args = _create_parser().parse_args(argv)

    logging.getLogger().setLevel(logging.INFO)
    client = _utils.get_client(args.region)

    logging.info('Submitting model creation request to SageMaker...')
    _utils.create_model(client, vars(args))

    logging.info('Model creation completed.')
    with open('/tmp/model_name.txt', 'w') as f:
        f.write(args.model_name)