コード例 #1
0
    def test_do_sets_name(self):
        """Do() uses --job_name when given, otherwise the generated unique id."""
        explicit_name_args = self.REQUIRED_ARGS + ["--job_name", "job-name"]
        named_spec = SageMakerTrainingSpec(explicit_name_args)
        unnamed_spec = SageMakerTrainingSpec(self.REQUIRED_ARGS)

        # An explicitly supplied name is used verbatim.
        self.component.Do(named_spec)
        self.assertEqual("job-name", self.component._training_job_name)

        # Without a name, the generated id is used; the generator is stubbed
        # so the expected value is deterministic.
        id_generator = MagicMock(return_value="unique")
        unique_id_target = ("train.src.sagemaker_training_component."
                            "SageMakerComponent._generate_unique_timestamped_id")
        with patch(unique_id_target, id_generator):
            self.component.Do(unnamed_spec)
            self.assertEqual("unique", self.component._training_job_name)
コード例 #2
0
 def test_hook_max_args(self):
     """Every DebugHookConfig field supplied on the CLI reaches the request."""
     hook_config_json = '{"S3OutputPath": "s3://fake-uri/", "LocalPath": "/local/path/", "HookParameters": {"key": "value"}, "CollectionConfigurations": [{"CollectionName": "collection1", "CollectionParameters": {"key1": "value1"}}, {"CollectionName": "collection2", "CollectionParameters": {"key2": "value2", "key3": "value3"}}]}'
     spec = SageMakerTrainingSpec(
         self.REQUIRED_ARGS + ["--debug_hook_config", hook_config_json])
     response = self.component._create_job_request(spec.inputs, spec.outputs)

     hook = response["DebugHookConfig"]
     self.assertEqual(hook["S3OutputPath"], "s3://fake-uri/")
     self.assertEqual(hook["LocalPath"], "/local/path/")
     self.assertEqual(hook["HookParameters"], {"key": "value"})

     expected_collections = [
         {"CollectionName": "collection1",
          "CollectionParameters": {"key1": "value1"}},
         {"CollectionName": "collection2",
          "CollectionParameters": {"key2": "value2", "key3": "value3"}},
     ]
     self.assertEqual(hook["CollectionConfigurations"], expected_collections)
コード例 #3
0
    def test_metric_definitions(self):
        """A name->regex JSON map becomes a list of Name/Regex entries."""
        metrics_json = '{"metric1": "regexval1", "metric2": "regexval2"}'
        spec = SageMakerTrainingSpec(
            self.REQUIRED_ARGS + ["--metric_definitions", metrics_json])
        response = self.component._create_job_request(spec.inputs, spec.outputs)

        algorithm_spec = response["AlgorithmSpecification"]
        self.assertIn("MetricDefinitions", algorithm_spec)

        expected = [
            {"Name": "metric1", "Regex": "regexval1"},
            {"Name": "metric2", "Regex": "regexval2"},
        ]
        self.assertEqual(algorithm_spec["MetricDefinitions"], expected)
コード例 #4
0
    def test_object_hyperparameters(self):
        """Object-valued hyperparameters make request building fail."""
        nested_hyperparameters = '{"hp1": {"innerkey": "innerval"}}'
        args = self.REQUIRED_ARGS + ["--hyperparameters", nested_hyperparameters]
        spec = SageMakerTrainingSpec(args)

        # The spec itself must parse; the failure is expected from
        # _create_job_request.
        with self.assertRaises(Exception):
            self.component._create_job_request(spec.inputs, spec.outputs)
コード例 #5
0
 def test_rule_max_args(self):
     """Every DebugRuleConfigurations field supplied on the CLI is copied."""
     rule_config_json = '[{"InstanceType": "ml.m4.xlarge", "LocalPath": "/local/path/", "RuleConfigurationName": "rule_name", "RuleEvaluatorImage": "test-image", "RuleParameters": {"key1": "value1"}, "S3OutputPath": "s3://fake-uri/", "VolumeSizeInGB": 1}]'
     spec = SageMakerTrainingSpec(
         self.REQUIRED_ARGS + ["--debug_rule_config", rule_config_json])
     response = self.component._create_job_request(spec.inputs, spec.outputs)

     rule = response["DebugRuleConfigurations"][0]
     expected_fields = (
         ("InstanceType", "ml.m4.xlarge"),
         ("LocalPath", "/local/path/"),
         ("RuleConfigurationName", "rule_name"),
         ("RuleEvaluatorImage", "test-image"),
         ("RuleParameters", {"key1": "value1"}),
         ("S3OutputPath", "s3://fake-uri/"),
         ("VolumeSizeInGB", 1),
     )
     for field, expected in expected_fields:
         self.assertEqual(rule[field], expected)
コード例 #6
0
    def test_training_mode(self):
        """--training_input_mode is reflected in AlgorithmSpecification."""
        args = self.REQUIRED_ARGS + ["--training_input_mode", "Pipe"]
        spec = SageMakerTrainingSpec(args)
        request = self.component._create_job_request(spec.inputs, spec.outputs)

        input_mode = request["AlgorithmSpecification"]["TrainingInputMode"]
        self.assertEqual(input_mode, "Pipe")
コード例 #7
0
 def test_hook_min_args(self):
     """A DebugHookConfig with only S3OutputPath is accepted and passed on."""
     hook_args = ["--debug_hook_config", '{"S3OutputPath": "s3://fake-uri/"}']
     spec = SageMakerTrainingSpec(self.REQUIRED_ARGS + hook_args)
     response = self.component._create_job_request(spec.inputs, spec.outputs)

     output_path = response["DebugHookConfig"]["S3OutputPath"]
     self.assertEqual(output_path, "s3://fake-uri/")
コード例 #8
0
    def test_no_channels(self):
        """An empty --channels list makes request building fail."""
        args = list(self.REQUIRED_ARGS)
        # Keep the flag itself; replace only the value that follows it.
        value_position = args.index("--channels") + 1
        args[value_position] = "[]"
        spec = SageMakerTrainingSpec(args)

        with self.assertRaises(Exception):
            self.component._create_job_request(spec.inputs, spec.outputs)
コード例 #9
0
    def test_first_party_algorithm(self):
        """With --image also supplied, the request uses the image, not the name."""
        args = self.REQUIRED_ARGS + ["--algorithm_name", "first-algorithm"]
        spec = SageMakerTrainingSpec(args)

        # Building the request must succeed without raising.
        request = self.component._create_job_request(spec.inputs, spec.outputs)
        algorithm_spec = request["AlgorithmSpecification"]
        self.assertIn("TrainingImage", algorithm_spec)
        self.assertNotIn("AlgorithmName", algorithm_spec)
コード例 #10
0
    def test_empty_hyperparameters(self):
        """An empty hyperparameter object yields an empty HyperParameters map."""
        spec = SageMakerTrainingSpec(self.REQUIRED_ARGS +
                                     ["--hyperparameters", "{}"])
        request = self.component._create_job_request(spec.inputs, spec.outputs)
        self.assertEqual(request["HyperParameters"], {})
コード例 #11
0
    def test_after_job_completed(self):
        """_after_job_complete populates job name, model artifact and image."""
        # Replace the artifact and image lookups with fixed return values.
        self.component._get_model_artifacts_from_job = MagicMock(
            return_value="model")
        self.component._get_image_from_job = MagicMock(return_value="image")
        spec = SageMakerTrainingSpec(self.REQUIRED_ARGS)

        self.component._after_job_complete({}, {}, spec.inputs, spec.outputs)

        outputs = spec.outputs
        for attribute, expected in (("job_name", "test-job"),
                                    ("model_artifact_url", "model"),
                                    ("training_image", "image")):
            self.assertEqual(getattr(outputs, attribute), expected)
コード例 #12
0
    def test_no_defined_image(self):
        """Removing --image makes request building fail."""
        # Drop the --image flag together with its value. The spec must still
        # parse; only _create_job_request is expected to raise.
        image_flag_at = self.REQUIRED_ARGS.index("--image")
        args_without_image = (self.REQUIRED_ARGS[:image_flag_at] +
                              self.REQUIRED_ARGS[image_flag_at + 2:])
        spec = SageMakerTrainingSpec(args_without_image)

        with self.assertRaises(Exception):
            self.component._create_job_request(spec.inputs, spec.outputs)
コード例 #13
0
    def test_create_training_job(self):
        """The required arguments alone produce exactly this request."""
        spec = SageMakerTrainingSpec(self.REQUIRED_ARGS)
        request = self.component._create_job_request(spec.inputs, spec.outputs)

        train_channel = {
            "ChannelName": "train",
            "DataSource": {
                "S3DataSource": {
                    "S3Uri": "s3://fake-bucket/data",
                    "S3DataType": "S3Prefix",
                    "S3DataDistributionType": "FullyReplicated",
                }
            },
            "ContentType": "",
            "CompressionType": "None",
            "RecordWrapperType": "None",
            "InputMode": "File",
        }
        expected_request = {
            "AlgorithmSpecification": {
                "TrainingImage": "test-image",
                "TrainingInputMode": "File",
            },
            "EnableInterContainerTrafficEncryption": False,
            "EnableManagedSpotTraining": False,
            "EnableNetworkIsolation": True,
            "HyperParameters": {},
            "InputDataConfig": [train_channel],
            "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "test-path"},
            "ResourceConfig": {
                "InstanceType": "ml.m4.xlarge",
                "InstanceCount": 1,
                "VolumeSizeInGB": 50,
                "VolumeKmsKeyId": "",
            },
            "RoleArn":
            "arn:aws:iam::123456789012:user/Development/product_1234/*",
            "StoppingCondition": {"MaxRuntimeInSeconds": 3600},
            "Tags": [],
            "TrainingJobName": "test-job",
        }
        self.assertEqual(request, expected_request)
コード例 #14
0
 def test_rule_min_good_args(self):
     """A rule with only a name and an evaluator image is accepted."""
     rule_json = '[{"RuleConfigurationName": "rule_name", "RuleEvaluatorImage": "test-image"}]'
     spec = SageMakerTrainingSpec(
         self.REQUIRED_ARGS + ["--debug_rule_config", rule_json])
     response = self.component._create_job_request(spec.inputs, spec.outputs)

     first_rule = response["DebugRuleConfigurations"][0]
     self.assertEqual(first_rule["RuleConfigurationName"], "rule_name")
     self.assertEqual(first_rule["RuleEvaluatorImage"], "test-image")
コード例 #15
0
    def test_valid_hyperparameters(self):
        """String-valued hyperparameters are passed through unchanged."""
        expected = {"hp1": "val1", "hp2": "val2", "hp3": "val3"}
        spec = SageMakerTrainingSpec(self.REQUIRED_ARGS + [
            "--hyperparameters", '{"hp1": "val1", "hp2": "val2", "hp3": "val3"}'
        ])
        request = self.component._create_job_request(spec.inputs, spec.outputs)
        hyperparameters = request["HyperParameters"]

        # Check that every key is present first, then check each value.
        for name in expected:
            self.assertIn(name, hyperparameters)
        for name, value in expected.items():
            self.assertEqual(hyperparameters[name], value)
コード例 #16
0
    def test_vpc_configuration(self):
        """Comma-separated security group ids and subnets reach VpcConfig."""
        spec = SageMakerTrainingSpec(self.REQUIRED_ARGS + [
            "--vpc_security_group_ids", "sg1,sg2",
            "--vpc_subnets", "subnet1,subnet2",
        ])
        response = self.component._create_job_request(spec.inputs, spec.outputs)

        self.assertIn("VpcConfig", response)
        vpc_config = response["VpcConfig"]
        for group_id in ("sg1", "sg2"):
            self.assertIn(group_id, vpc_config["SecurityGroupIds"])
        for subnet in ("subnet1", "subnet2"):
            self.assertIn(subnet, vpc_config["Subnets"])
コード例 #17
0
    def test_known_algorithm_value(self):
        """A known algorithm name is turned into an image URL via retrieve()."""
        image_at = self.REQUIRED_ARGS.index("--image")
        # Remove --image and its value so the image must come from the
        # algorithm name instead.
        known_algorithm_args = (
            self.REQUIRED_ARGS[:image_at] + self.REQUIRED_ARGS[image_at + 2:] +
            ["--algorithm_name", "seq2seq"])
        spec = SageMakerTrainingSpec(known_algorithm_args)

        # Stub the image lookup so no real resolution happens.
        fake_retrieve = MagicMock(return_value="seq2seq-url")
        with patch("train.src.sagemaker_training_component.retrieve",
                   fake_retrieve):
            response = self.component._create_job_request(
                spec.inputs, spec.outputs)

        fake_retrieve.assert_called_with("seq2seq", "us-west-2")
        training_image = response["AlgorithmSpecification"]["TrainingImage"]
        self.assertEqual(training_image, "seq2seq-url")
コード例 #18
0
    def test_unknown_algorithm(self):
        """An unrecognised algorithm name is passed through as AlgorithmName."""
        image_at = self.REQUIRED_ARGS.index("--image")
        # Remove --image and its value; only the algorithm name remains.
        unknown_algorithm_args = (
            self.REQUIRED_ARGS[:image_at] + self.REQUIRED_ARGS[image_at + 2:] +
            ["--algorithm_name", "unknown algorithm"])
        spec = SageMakerTrainingSpec(unknown_algorithm_args)

        fake_retrieve = MagicMock(return_value="unknown-url")
        with patch("train.src.sagemaker_training_component.retrieve",
                   fake_retrieve):
            response = self.component._create_job_request(
                spec.inputs, spec.outputs)

        # The name is placed in the request as-is, without any image lookup.
        fake_retrieve.assert_not_called()
        algorithm_name = response["AlgorithmSpecification"]["AlgorithmName"]
        self.assertEqual(algorithm_name, "unknown algorithm")
コード例 #19
0
        job: object,
        request: Dict,
        inputs: SageMakerTrainingInputs,
        outputs: SageMakerTrainingOutputs,
    ):
        logging.info(
            f"Created Training Job with name: {self._training_job_name}")
        logging.info(
            "Training job in SageMaker: https://{}.console.aws.amazon.com/sagemaker/home?region={}#/jobs/{}"
            .format(
                inputs.region,
                inputs.region,
                self._training_job_name,
            ))
        logging.info(
            "CloudWatch logs: https://{}.console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/TrainingJobs;prefix={};streamFilter=typeLogStreamPrefix"
            .format(
                inputs.region,
                inputs.region,
                self._training_job_name,
            ))


if __name__ == "__main__":
    # Script entry point: build the training spec from the command-line
    # arguments (everything after the program name) and run the component.
    import sys

    spec = SageMakerTrainingSpec(sys.argv[1:])

    component = SageMakerTrainingComponent()
    component.Do(spec)
コード例 #20
0
 def test_minimum_required_args(self):
     # Will raise if the inputs are incorrect
     spec = SageMakerTrainingSpec(self.REQUIRED_ARGS)