def test_hook_min_args(self):
    """A debug hook config holding only ``S3OutputPath`` reaches the request."""
    hook_config_json = '{"S3OutputPath": "s3://fake-uri/"}'
    cli_args = required_args + ['--debug_hook_config', hook_config_json]
    parsed = self.parser.parse_args(cli_args)

    request = _utils.create_training_job_request(vars(parsed))

    self.assertEqual(request['DebugHookConfig']['S3OutputPath'], "s3://fake-uri/")
def test_empty_hyperparameters(self):
    """An empty JSON object for ``--hyperparameters`` yields an empty dict."""
    cli_args = required_args + ['--hyperparameters', '{}']
    parsed = self.parser.parse_args(cli_args)

    request = _utils.create_training_job_request(vars(parsed))

    self.assertEqual(request['HyperParameters'], {})
def test_first_party_algorithm(self):
    """Adding ``--algorithm_name`` to the required args still builds a request.

    The resulting ``AlgorithmSpecification`` must contain ``TrainingImage``
    and must not contain ``AlgorithmName``.
    """
    cli_args = required_args + ['--algorithm_name', 'first-algorithm']
    parsed = self.parser.parse_args(cli_args)

    # Building the request is expected to succeed without raising.
    request = _utils.create_training_job_request(vars(parsed))

    self.assertIn('TrainingImage', request['AlgorithmSpecification'])
    self.assertNotIn('AlgorithmName', request['AlgorithmSpecification'])
def test_training_mode(self):
    """``--training_input_mode Pipe`` is reflected in the algorithm spec."""
    cli_args = required_args + ['--training_input_mode', 'Pipe']
    parsed = self.parser.parse_args(cli_args)

    request = _utils.create_training_job_request(vars(parsed))

    self.assertEqual(
        request['AlgorithmSpecification']['TrainingInputMode'], 'Pipe')
def test_rule_max_args(self):
    """Every field of a fully populated debug rule config is passed through."""
    # Adjacent literals concatenate into the exact single-line JSON document.
    rule_config_json = (
        '[{"InstanceType": "ml.m4.xlarge", '
        '"LocalPath": "/local/path/", '
        '"RuleConfigurationName": "rule_name", '
        '"RuleEvaluatorImage": "test-image", '
        '"RuleParameters": {"key1": "value1"}, '
        '"S3OutputPath": "s3://fake-uri/", '
        '"VolumeSizeInGB": 1}]'
    )
    parsed = self.parser.parse_args(
        required_args + ['--debug_rule_config', rule_config_json])

    request = _utils.create_training_job_request(vars(parsed))

    # Checked in this order so the first failing field is reported first.
    expected_fields = (
        ('InstanceType', 'ml.m4.xlarge'),
        ('LocalPath', '/local/path/'),
        ('RuleConfigurationName', 'rule_name'),
        ('RuleEvaluatorImage', 'test-image'),
        ('RuleParameters', {"key1": "value1"}),
        ('S3OutputPath', 's3://fake-uri/'),
        ('VolumeSizeInGB', 1),
    )
    for field, expected in expected_fields:
        self.assertEqual(
            request['DebugRuleConfigurations'][0][field], expected)
def test_hook_max_args(self):
    """Every field of a fully populated debug hook config is passed through."""
    # Adjacent literals concatenate into the exact single-line JSON document.
    hook_config_json = (
        '{"S3OutputPath": "s3://fake-uri/", '
        '"LocalPath": "/local/path/", '
        '"HookParameters": {"key": "value"}, '
        '"CollectionConfigurations": ['
        '{"CollectionName": "collection1", "CollectionParameters": {"key1": "value1"}}, '
        '{"CollectionName": "collection2", "CollectionParameters": {"key2": "value2", "key3": "value3"}}'
        ']}'
    )
    parsed = self.parser.parse_args(
        required_args + ['--debug_hook_config', hook_config_json])

    request = _utils.create_training_job_request(vars(parsed))

    expected_collections = [
        {
            "CollectionName": "collection1",
            "CollectionParameters": {"key1": "value1"},
        },
        {
            "CollectionName": "collection2",
            "CollectionParameters": {"key2": "value2", "key3": "value3"},
        },
    ]
    self.assertEqual(request['DebugHookConfig']['S3OutputPath'], "s3://fake-uri/")
    self.assertEqual(request['DebugHookConfig']['LocalPath'], "/local/path/")
    self.assertEqual(request['DebugHookConfig']['HookParameters'], {"key": "value"})
    self.assertEqual(
        request['DebugHookConfig']['CollectionConfigurations'],
        expected_collections)
def test_empty_string(self):
    """Run the shared empty-string check over a spot-training request."""
    spot_options = [
        '--spot_instance', 'True',
        '--max_wait_time', '3600',
        '--checkpoint_config', '{"S3Uri": "s3://fake-uri/"}',
    ]
    parsed = self.parser.parse_args(required_args + spot_options)

    request = _utils.create_training_job_request(vars(parsed))

    # The actual validation lives in the shared test helper.
    test_utils.check_empty_string_values(request)
def test_vpc_configuration(self):
    """Comma-separated security groups and subnets end up in ``VpcConfig``."""
    vpc_options = [
        '--vpc_security_group_ids', 'sg1,sg2',
        '--vpc_subnets', 'subnet1,subnet2',
    ]
    parsed = self.parser.parse_args(required_args + vpc_options)

    request = _utils.create_training_job_request(vars(parsed))

    self.assertIn('VpcConfig', request)
    for group_id in ('sg1', 'sg2'):
        self.assertIn(group_id, request['VpcConfig']['SecurityGroupIds'])
    for subnet_id in ('subnet1', 'subnet2'):
        self.assertIn(subnet_id, request['VpcConfig']['Subnets'])
def test_reasonable_required_args(self):
    """With only the required arguments, optional fields get sane defaults."""
    parsed = self.parser.parse_args(required_args)
    request = _utils.create_training_job_request(vars(parsed))

    # Ensure all of the optional arguments have reasonable default values
    self.assertFalse(request['EnableManagedSpotTraining'])
    self.assertDictEqual(request['HyperParameters'], {})
    self.assertNotIn('VpcConfig', request)
    self.assertNotIn('MetricDefinitions', request)
    self.assertEqual(request['Tags'], [])

    algorithm_spec = request['AlgorithmSpecification']
    self.assertEqual(algorithm_spec['TrainingInputMode'], 'File')

    # 'test-path' is presumably the output path supplied via required_args.
    output_config = request['OutputDataConfig']
    self.assertEqual(output_config['S3OutputPath'], 'test-path')
def test_spot_local_path(self):
    """A checkpoint config with ``LocalPath`` keeps both of its fields."""
    checkpoint_json = '{"S3Uri": "s3://fake-uri/", "LocalPath": "local-path"}'
    spot_options = [
        '--spot_instance', 'True',
        '--max_wait_time', '3600',
        '--checkpoint_config', checkpoint_json,
    ]
    parsed = self.parser.parse_args(required_args + spot_options)

    request = _utils.create_training_job_request(vars(parsed))

    self.assertEqual(request['CheckpointConfig']['S3Uri'], 's3://fake-uri/')
    self.assertEqual(request['CheckpointConfig']['LocalPath'], 'local-path')
def test_spot_good_args(self):
    """Valid spot-instance options enable managed spot training."""
    spot_options = [
        '--spot_instance', 'True',
        '--max_wait_time', '3600',
        '--checkpoint_config', '{"S3Uri": "s3://fake-uri/"}',
    ]
    parsed = self.parser.parse_args(required_args + spot_options)

    request = _utils.create_training_job_request(vars(parsed))

    self.assertTrue(request['EnableManagedSpotTraining'])
    # The wait time is compared as an integer number of seconds.
    self.assertEqual(request['StoppingCondition']['MaxWaitTimeInSeconds'], 3600)
    self.assertEqual(request['CheckpointConfig']['S3Uri'], 's3://fake-uri/')
def test_rule_min_good_args(self):
    """A debug rule with only name and evaluator image is passed through."""
    rule_config_json = (
        '[{"RuleConfigurationName": "rule_name", '
        '"RuleEvaluatorImage": "test-image"}]'
    )
    parsed = self.parser.parse_args(
        required_args + ['--debug_rule_config', rule_config_json])

    request = _utils.create_training_job_request(vars(parsed))

    self.assertEqual(
        request['DebugRuleConfigurations'][0]['RuleConfigurationName'],
        'rule_name')
    self.assertEqual(
        request['DebugRuleConfigurations'][0]['RuleEvaluatorImage'],
        'test-image')
def test_valid_hyperparameters(self):
    """Flat string hyperparameters appear unchanged in the request."""
    hyperparameters_json = '{"hp1": "val1", "hp2": "val2", "hp3": "val3"}'
    parsed = self.parser.parse_args(
        required_args + ['--hyperparameters', hyperparameters_json])

    request = _utils.create_training_job_request(vars(parsed))

    expected_pairs = (("hp1", "val1"), ("hp2", "val2"), ("hp3", "val3"))

    # All presence checks run before any value check.
    for name, _ in expected_pairs:
        self.assertIn(name, request['HyperParameters'])
    for name, value in expected_pairs:
        self.assertEqual(request['HyperParameters'][name], value)
def test_metric_definitions(self):
    """A name-to-regex mapping becomes a list of ``Name``/``Regex`` dicts."""
    metrics_json = '{"metric1": "regexval1", "metric2": "regexval2"}'
    parsed = self.parser.parse_args(
        required_args + ['--metric_definitions', metrics_json])

    request = _utils.create_training_job_request(vars(parsed))

    self.assertIn('MetricDefinitions', request['AlgorithmSpecification'])

    expected_definitions = [
        {'Name': "metric1", 'Regex': "regexval1"},
        {'Name': "metric2", 'Regex': "regexval2"},
    ]
    actual_definitions = request['AlgorithmSpecification']['MetricDefinitions']
    self.assertEqual(actual_definitions, expected_definitions)
def test_known_algorithm_value(self):
    """A known SageMaker algorithm name is resolved to a training image URI.

    ``--image`` and its value are removed from the argument list so that only
    ``--algorithm_name`` is supplied. ``_utils.get_image_uri`` is patched to
    return a sentinel URI, which must end up as ``TrainingImage``.
    """
    # Local import keeps this method independent of the file-level imports.
    from unittest.mock import patch

    # This passes an algorithm that is a known SageMaker algorithm name
    known_algorithm_args = required_args + ['--algorithm_name', 'seq2seq']
    image_index = required_args.index('--image')
    # Cut out --image and its associated value
    known_algorithm_args = (
        known_algorithm_args[:image_index]
        + known_algorithm_args[image_index + 2:]
    )

    parsed_args = self.parser.parse_args(known_algorithm_args)

    # patch.object restores the real get_image_uri when the block exits, so
    # the mock cannot leak into tests that run after this one.
    with patch.object(
            _utils, 'get_image_uri', return_value="seq2seq-url"
    ) as mock_get_image_uri:
        response = _utils.create_training_job_request(vars(parsed_args))

    mock_get_image_uri.assert_called_with('us-west-2', 'seq2seq')
    self.assertEqual(
        response['AlgorithmSpecification']['TrainingImage'], "seq2seq-url")
def test_unknown_algorithm(self):
    """An unrecognised algorithm name is forwarded as ``AlgorithmName``.

    ``--image`` and its value are removed from the argument list so that only
    ``--algorithm_name`` is supplied. ``_utils.get_image_uri`` is patched and
    must not be called for a name that is not a known algorithm.
    """
    # Local import keeps this method independent of the file-level imports.
    from unittest.mock import patch

    unknown_algorithm_args = required_args + ['--algorithm_name', 'unknown algorithm']
    image_index = required_args.index('--image')
    # Cut out --image and its associated value
    unknown_algorithm_args = (
        unknown_algorithm_args[:image_index]
        + unknown_algorithm_args[image_index + 2:]
    )

    parsed_args = self.parser.parse_args(unknown_algorithm_args)

    # patch.object restores the real get_image_uri when the block exits, so
    # the mock cannot leak into tests that run after this one.
    with patch.object(
            _utils, 'get_image_uri', return_value="unknown-url"
    ) as mock_get_image_uri:
        response = _utils.create_training_job_request(vars(parsed_args))

    # Should just place the algorithm name in regardless
    mock_get_image_uri.assert_not_called()
    self.assertEqual(
        response['AlgorithmSpecification']['AlgorithmName'], "unknown algorithm")
def test_tags(self):
    """Each tag in the JSON mapping becomes a ``Key``/``Value`` dict."""
    tags_json = '{"key1": "val1", "key2": "val2"}'
    parsed = self.parser.parse_args(required_args + ['--tags', tags_json])

    request = _utils.create_training_job_request(vars(parsed))

    expected_tags = (
        {'Key': 'key1', 'Value': 'val1'},
        {'Key': 'key2', 'Value': 'val2'},
    )
    for tag in expected_tags:
        self.assertIn(tag, request['Tags'])
def test_spot_lesser_wait_time(self):
    """A spot request with ``--max_wait_time 3599`` is rejected.

    NOTE(review): 3599 is presumably just below the max run time configured
    in ``required_args`` -- confirm against that fixture.
    """
    checkpoint_json = '{"S3Uri": "s3://fake-uri/", "LocalPath": "local-path"}'
    spot_options = [
        '--spot_instance', 'True',
        '--max_wait_time', '3599',
        '--checkpoint_config', checkpoint_json,
    ]
    parsed = self.parser.parse_args(required_args + spot_options)
    request_kwargs = vars(parsed)

    with self.assertRaises(Exception):
        _utils.create_training_job_request(request_kwargs)
def test_object_hyperparameters(self):
    """A hyperparameter whose value is a nested object is rejected."""
    nested_hyperparameters_json = '{"hp1": {"innerkey": "innerval"}}'
    parsed = self.parser.parse_args(
        required_args + ['--hyperparameters', nested_hyperparameters_json])
    request_kwargs = vars(parsed)

    with self.assertRaises(Exception):
        _utils.create_training_job_request(request_kwargs)