示例#1
0
 def test_hook_min_args(self):
     """A debug hook config carrying only S3OutputPath is passed through."""
     cli_args = required_args + [
         '--debug_hook_config', '{"S3OutputPath": "s3://fake-uri/"}'
     ]
     request = _utils.create_training_job_request(
         vars(self.parser.parse_args(cli_args)))
     hook_config = request['DebugHookConfig']
     self.assertEqual(hook_config['S3OutputPath'], "s3://fake-uri/")
示例#2
0
  def test_empty_hyperparameters(self):
    """An empty JSON object for --hyperparameters yields an empty dict."""
    cli_args = required_args + ['--hyperparameters', '{}']
    parsed = self.parser.parse_args(cli_args)
    request = _utils.create_training_job_request(vars(parsed))

    self.assertEqual(request['HyperParameters'], {})
示例#3
0
  def test_first_party_algorithm(self):
    """--algorithm_name alongside the required --image args builds cleanly.

    The resulting AlgorithmSpecification must contain a TrainingImage and
    must not contain an AlgorithmName.
    """
    cli_args = required_args + ['--algorithm_name', 'first-algorithm']
    # Building the request is expected not to raise.
    spec = _utils.create_training_job_request(
        vars(self.parser.parse_args(cli_args)))['AlgorithmSpecification']
    self.assertIn('TrainingImage', spec)
    self.assertNotIn('AlgorithmName', spec)
示例#4
0
    def test_training_mode(self):
        """--training_input_mode is copied into AlgorithmSpecification."""
        # Named for what it holds; the old name (required_vpc_args) was a
        # copy-paste leftover from the VPC test.
        training_mode_args = self.parser.parse_args(
            required_args + ['--training_input_mode', 'Pipe'])
        response = _utils.create_training_job_request(vars(training_mode_args))

        self.assertEqual(
            response['AlgorithmSpecification']['TrainingInputMode'], 'Pipe')
示例#5
0
 def test_rule_max_args(self):
     """A fully-populated debug rule config is copied into the request field by field."""
     rule_json = '[{"InstanceType": "ml.m4.xlarge", "LocalPath": "/local/path/", "RuleConfigurationName": "rule_name", "RuleEvaluatorImage": "test-image", "RuleParameters": {"key1": "value1"}, "S3OutputPath": "s3://fake-uri/", "VolumeSizeInGB": 1}]'
     parsed = self.parser.parse_args(
         required_args + ['--debug_rule_config', rule_json])
     rule = _utils.create_training_job_request(
         vars(parsed))['DebugRuleConfigurations'][0]

     self.assertEqual(rule['InstanceType'], 'ml.m4.xlarge')
     self.assertEqual(rule['LocalPath'], '/local/path/')
     self.assertEqual(rule['RuleConfigurationName'], 'rule_name')
     self.assertEqual(rule['RuleEvaluatorImage'], 'test-image')
     self.assertEqual(rule['RuleParameters'], {"key1": "value1"})
     self.assertEqual(rule['S3OutputPath'], 's3://fake-uri/')
     self.assertEqual(rule['VolumeSizeInGB'], 1)
示例#6
0
 def test_hook_max_args(self):
     """A fully-populated debug hook config is copied into the request field by field."""
     hook_json = '{"S3OutputPath": "s3://fake-uri/", "LocalPath": "/local/path/", "HookParameters": {"key": "value"}, "CollectionConfigurations": [{"CollectionName": "collection1", "CollectionParameters": {"key1": "value1"}}, {"CollectionName": "collection2", "CollectionParameters": {"key2": "value2", "key3": "value3"}}]}'
     parsed = self.parser.parse_args(
         required_args + ['--debug_hook_config', hook_json])
     hook = _utils.create_training_job_request(vars(parsed))['DebugHookConfig']

     expected_collections = [
         {"CollectionName": "collection1",
          "CollectionParameters": {"key1": "value1"}},
         {"CollectionName": "collection2",
          "CollectionParameters": {"key2": "value2", "key3": "value3"}},
     ]
     self.assertEqual(hook['S3OutputPath'], "s3://fake-uri/")
     self.assertEqual(hook['LocalPath'], "/local/path/")
     self.assertEqual(hook['HookParameters'], {"key": "value"})
     self.assertEqual(hook['CollectionConfigurations'], expected_collections)
示例#7
0
 def test_empty_string(self):
     """Run test_utils.check_empty_string_values on a spot-training request.

     NOTE(review): presumably that helper asserts no request field is an
     empty string; confirm against test_utils.
     """
     spot_flags = [
         '--spot_instance', 'True',
         '--max_wait_time', '3600',
         '--checkpoint_config', '{"S3Uri": "s3://fake-uri/"}',
     ]
     parsed = self.parser.parse_args(required_args + spot_flags)
     request = _utils.create_training_job_request(vars(parsed))
     test_utils.check_empty_string_values(request)
示例#8
0
  def test_vpc_configuration(self):
    """Comma-separated VPC security groups and subnets populate VpcConfig."""
    vpc_flags = ['--vpc_security_group_ids', 'sg1,sg2',
                 '--vpc_subnets', 'subnet1,subnet2']
    request = _utils.create_training_job_request(
        vars(self.parser.parse_args(required_args + vpc_flags)))

    self.assertIn('VpcConfig', request)
    vpc_config = request['VpcConfig']
    for group_id in ('sg1', 'sg2'):
      self.assertIn(group_id, vpc_config['SecurityGroupIds'])
    for subnet in ('subnet1', 'subnet2'):
      self.assertIn(subnet, vpc_config['Subnets'])
示例#9
0
  def test_reasonable_required_args(self):
    """With only the required args, optional request fields take defaults."""
    parsed = self.parser.parse_args(required_args)
    request = _utils.create_training_job_request(vars(parsed))

    # Every optional argument should fall back to a reasonable default value.
    self.assertFalse(request['EnableManagedSpotTraining'])
    self.assertDictEqual(request['HyperParameters'], {})
    self.assertNotIn('VpcConfig', request)
    # NOTE(review): MetricDefinitions is read from AlgorithmSpecification in
    # the metric-definitions test, so this top-level check may be vacuous --
    # confirm where the default should (not) appear.
    self.assertNotIn('MetricDefinitions', request)
    self.assertEqual(request['Tags'], [])
    self.assertEqual(
        request['AlgorithmSpecification']['TrainingInputMode'], 'File')
    self.assertEqual(request['OutputDataConfig']['S3OutputPath'], 'test-path')
示例#10
0
 def test_spot_local_path(self):
     """A spot request keeps both S3Uri and LocalPath of the checkpoint config."""
     spot_flags = [
         '--spot_instance', 'True',
         '--max_wait_time', '3600',
         '--checkpoint_config',
         '{"S3Uri": "s3://fake-uri/", "LocalPath": "local-path"}',
     ]
     parsed = self.parser.parse_args(required_args + spot_flags)
     checkpoint = _utils.create_training_job_request(
         vars(parsed))['CheckpointConfig']
     self.assertEqual(checkpoint['S3Uri'], 's3://fake-uri/')
     self.assertEqual(checkpoint['LocalPath'], 'local-path')
示例#11
0
 def test_spot_good_args(self):
     """Valid spot flags enable managed spot training with wait time and checkpoint."""
     spot_flags = [
         '--spot_instance', 'True',
         '--max_wait_time', '3600',
         '--checkpoint_config', '{"S3Uri": "s3://fake-uri/"}',
     ]
     request = _utils.create_training_job_request(
         vars(self.parser.parse_args(required_args + spot_flags)))

     self.assertTrue(request['EnableManagedSpotTraining'])
     stopping_condition = request['StoppingCondition']
     self.assertEqual(stopping_condition['MaxWaitTimeInSeconds'], 3600)
     self.assertEqual(request['CheckpointConfig']['S3Uri'], 's3://fake-uri/')
示例#12
0
 def test_rule_min_good_args(self):
     """A debug rule with only RuleConfigurationName and RuleEvaluatorImage is passed through."""
     rule_json = '[{"RuleConfigurationName": "rule_name", "RuleEvaluatorImage": "test-image"}]'
     parsed = self.parser.parse_args(
         required_args + ['--debug_rule_config', rule_json])
     first_rule = _utils.create_training_job_request(
         vars(parsed))['DebugRuleConfigurations'][0]

     self.assertEqual(first_rule['RuleConfigurationName'], 'rule_name')
     self.assertEqual(first_rule['RuleEvaluatorImage'], 'test-image')
示例#13
0
  def test_valid_hyperparameters(self):
    """Flat string-valued hyperparameters are copied into HyperParameters."""
    hyperparameters_str = '{"hp1": "val1", "hp2": "val2", "hp3": "val3"}'
    expected = {'hp1': "val1", 'hp2': "val2", 'hp3': "val3"}

    parsed = self.parser.parse_args(
        required_args + ['--hyperparameters', hyperparameters_str])
    hyperparameters = _utils.create_training_job_request(
        vars(parsed))['HyperParameters']

    # Presence first, then values -- the same order as the checks they replace.
    for name in expected:
      self.assertIn(name, hyperparameters)
    for name, value in expected.items():
      self.assertEqual(hyperparameters[name], value)
示例#14
0
  def test_metric_definitions(self):
    """--metric_definitions JSON becomes a Name/Regex list in AlgorithmSpecification."""
    metrics_json = '{"metric1": "regexval1", "metric2": "regexval2"}'
    parsed = self.parser.parse_args(
        required_args + ['--metric_definitions', metrics_json])
    spec = _utils.create_training_job_request(
        vars(parsed))['AlgorithmSpecification']

    self.assertIn('MetricDefinitions', spec)
    expected = [
      {'Name': "metric1", 'Regex': "regexval1"},
      {'Name': "metric2", 'Regex': "regexval2"},
    ]
    self.assertEqual(spec['MetricDefinitions'], expected)
示例#15
0
  def test_known_algorithm_value(self):
    """A known SageMaker algorithm name is resolved through get_image_uri.

    --image and its value are removed from the required args. The test then
    checks that _utils.get_image_uri is called with the algorithm name and
    that its result becomes AlgorithmSpecification.TrainingImage.
    """
    from unittest.mock import patch

    # This passes an algorithm that is a known SageMaker algorithm name.
    known_algorithm_args = required_args + ['--algorithm_name', 'seq2seq']
    image_index = required_args.index('--image')
    # Cut out --image and its associated value.
    known_algorithm_args = (known_algorithm_args[:image_index] +
                            known_algorithm_args[image_index + 2:])

    parsed_args = self.parser.parse_args(known_algorithm_args)

    # Patch get_image_uri only for this call. patch.object restores the real
    # function on exit, so the stub does not leak into other tests (the
    # previous bare assignment replaced it for the rest of the test run).
    with patch.object(_utils, 'get_image_uri',
                      return_value="seq2seq-url") as mock_get_image_uri:
      response = _utils.create_training_job_request(vars(parsed_args))

    mock_get_image_uri.assert_called_with('us-west-2', 'seq2seq')
    self.assertEqual(response['AlgorithmSpecification']['TrainingImage'],
                     "seq2seq-url")
示例#16
0
  def test_unknown_algorithm(self):
    """An unrecognized algorithm name is used verbatim as AlgorithmName.

    --image and its value are removed from the required args. The test then
    checks that _utils.get_image_uri is never consulted.
    """
    from unittest.mock import patch

    unknown_algorithm_args = required_args + ['--algorithm_name', 'unknown algorithm']
    image_index = required_args.index('--image')
    # Cut out --image and its associated value.
    unknown_algorithm_args = (unknown_algorithm_args[:image_index] +
                              unknown_algorithm_args[image_index + 2:])

    parsed_args = self.parser.parse_args(unknown_algorithm_args)

    # Scope the stub to this call so the real get_image_uri is restored
    # afterwards instead of staying replaced for later tests.
    with patch.object(_utils, 'get_image_uri',
                      return_value="unknown-url") as mock_get_image_uri:
      response = _utils.create_training_job_request(vars(parsed_args))

    # Should just place the algorithm name in regardless.
    mock_get_image_uri.assert_not_called()
    self.assertEqual(response['AlgorithmSpecification']['AlgorithmName'],
                     "unknown algorithm")
示例#17
0
 def test_tags(self):
   """Each --tags JSON entry appears as a {'Key', 'Value'} dict in Tags."""
   tags_json = '{"key1": "val1", "key2": "val2"}'
   request = _utils.create_training_job_request(
       vars(self.parser.parse_args(required_args + ['--tags', tags_json])))
   for key, value in (('key1', 'val1'), ('key2', 'val2')):
     self.assertIn({'Key': key, 'Value': value}, request['Tags'])
示例#18
0
 def test_spot_lesser_wait_time(self):
   """Building a spot request with --max_wait_time 3599 must raise.

   NOTE(review): 3599 is presumably just below the job's max run time;
   confirm against the required_args defaults.
   """
   cli_args = required_args + [
       '--spot_instance', 'True',
       '--max_wait_time', '3599',
       '--checkpoint_config', '{"S3Uri": "s3://fake-uri/", "LocalPath": "local-path"}',
   ]
   parsed = self.parser.parse_args(cli_args)
   with self.assertRaises(Exception):
     _utils.create_training_job_request(vars(parsed))
示例#19
0
  def test_object_hyperparameters(self):
    """A hyperparameter whose value is a nested object is rejected."""
    nested_json = '{"hp1": {"innerkey": "innerval"}}'
    parsed = self.parser.parse_args(
        required_args + ['--hyperparameters', nested_json])

    with self.assertRaises(Exception):
      _utils.create_training_job_request(vars(parsed))