def test_do_sets_name(self):
        """Do() keeps an explicit --job_name, else uses the generated id."""
        generator_path = (
            "rlestimator.src.sagemaker_rlestimator_component."
            "SageMakerComponent._generate_unique_timestamped_id")
        spec_with_name = SageMakerRLEstimatorSpec(
            self.CUSTOM_IMAGE_ARGS + ["--job_name", "job-name"])
        spec_without_name = SageMakerRLEstimatorSpec(self.CUSTOM_IMAGE_ARGS)

        # A job name supplied on the command line must be used verbatim.
        self.component.Do(spec_with_name)
        self.assertEqual("job-name", self.component._rlestimator_job_name)

        # Without a name, the value returned by the (patched) id generator
        # is expected to become the job name.
        with patch(generator_path, MagicMock(return_value="unique")):
            self.component.Do(spec_without_name)
            self.assertEqual("unique", self.component._rlestimator_job_name)
    def test_training_mode(self):
        """--training_input_mode shows up as input_mode on the job request."""
        pipe_mode_args = self.CUSTOM_IMAGE_ARGS + [
            "--training_input_mode", "Pipe"]
        spec = SageMakerRLEstimatorSpec(pipe_mode_args)

        job_request = self.component._create_job_request(
            spec.inputs, spec.outputs)

        self.assertEqual(job_request.input_mode, "Pipe")
 def test_rule_max_args(self):
     """A fully populated --debug_rule_config entry reaches the job request.

     Every field of the single configured rule is compared against the first
     element of the job request's ``rules``.
     """
     spec = SageMakerRLEstimatorSpec(self.CUSTOM_IMAGE_ARGS + [
         "--debug_rule_config",
         '[{"InstanceType": "ml.m4.xlarge", "LocalPath": "/local/path/", "RuleConfigurationName": "rule_name", "RuleEvaluatorImage": "test-image", "RuleParameters": {"key1": "value1"}, "S3OutputPath": "s3://fake-uri/", "VolumeSizeInGB": 1}]',
     ])
     rlestimator = self.component._create_job_request(
         spec.inputs, spec.outputs)

     expected_rule = {
         "InstanceType": "ml.m4.xlarge",
         "LocalPath": "/local/path/",
         "RuleConfigurationName": "rule_name",
         "RuleEvaluatorImage": "test-image",
         "RuleParameters": {"key1": "value1"},
         "S3OutputPath": "s3://fake-uri/",
         "VolumeSizeInGB": 1,
     }
     first_rule = rlestimator.rules[0]
     # subTest reports every mismatching field instead of stopping at the
     # first one, which replaces the debug prints that used to live here.
     for field, expected_value in expected_rule.items():
         with self.subTest(field=field):
             self.assertEqual(first_rule[field], expected_value)
    def test_object_hyperparameters(self):
        """A hyperparameter whose value is a nested object is rejected."""
        nested_hyperparameters = '{"hp1": {"innerkey": "innerval"}}'
        hyperparameter_args = ["--hyperparameters", nested_hyperparameters]
        spec = SageMakerRLEstimatorSpec(
            self.CUSTOM_IMAGE_ARGS + hyperparameter_args)

        # Building the job request is the step expected to raise.
        with self.assertRaises(Exception):
            self.component._create_job_request(spec.inputs, spec.outputs)
    def test_empty_hyperparameters(self):
        """An empty JSON object yields empty hyperparameters on the request."""
        hyperparameter_args = ["--hyperparameters", "{}"]
        spec = SageMakerRLEstimatorSpec(
            self.CUSTOM_IMAGE_ARGS + hyperparameter_args)

        job_request = self.component._create_job_request(
            spec.inputs, spec.outputs)

        self.assertEqual(job_request._hyperparameters, {})
 def test_hook_min_args(self):
     """A hook config holding only S3OutputPath reaches the job request."""
     hook_json = '{"S3OutputPath": "s3://fake-uri/"}'
     spec = SageMakerRLEstimatorSpec(
         self.CUSTOM_IMAGE_ARGS + ["--debug_hook_config", hook_json])

     job_request = self.component._create_job_request(
         spec.inputs, spec.outputs)

     hook_config = job_request.debugger_hook_config
     self.assertEqual(hook_config["S3OutputPath"], "s3://fake-uri/")
 def test_create_rlestimator_toolkit_job(self):
     """Toolkit arguments give a request with toolkit fields and no image_uri."""
     spec = SageMakerRLEstimatorSpec(self.TOOLKIT_IMAGE_ARGS)
     job_request = self.component._create_job_request(
         spec.inputs, spec.outputs)

     # Checked in this order; the first missing attribute fails the test.
     required_attributes = (
         "role",
         "source_dir",
         "entry_point",
         "toolkit",
         "toolkit_version",
         "framework",
     )
     for attribute_name in required_attributes:
         self.assertHasAttr(job_request, attribute_name)

     self.assertAttrNone(job_request, "image_uri")
    def test_no_defined_image(self):
        """Building a job request from arguments without --image raises."""
        # Work on a copy so the shared CUSTOM_IMAGE_ARGS list stays intact.
        args_without_image = self.CUSTOM_IMAGE_ARGS.copy()
        flag_position = args_without_image.index("--image")
        # Remove the --image flag together with its associated value.
        del args_without_image[flag_position:flag_position + 2]

        spec = SageMakerRLEstimatorSpec(args_without_image)

        with self.assertRaises(Exception):
            self.component._create_job_request(spec.inputs, spec.outputs)
    def test_after_job_completed(self):
        """_after_job_complete copies job name, model and image to outputs.

        The expected job name "test-job" is not set in this method;
        presumably the test fixture assigns it -- confirm against setUp.
        """
        # Stub the two lookups so no real job has to be queried.
        canned_results = {
            "_get_model_artifacts_from_job": "model",
            "_get_image_from_job": "image",
        }
        for helper_name, canned_value in canned_results.items():
            setattr(self.component, helper_name,
                    MagicMock(return_value=canned_value))

        spec = SageMakerRLEstimatorSpec(self.CUSTOM_IMAGE_ARGS)

        self.component._after_job_complete({}, {}, spec.inputs, spec.outputs)

        expected_outputs = (
            ("job_name", "test-job"),
            ("model_artifact_url", "model"),
            ("training_image", "image"),
        )
        for output_name, expected_value in expected_outputs:
            self.assertEqual(getattr(spec.outputs, output_name),
                             expected_value)
    def test_rule_min_good_args(self):
        """A rule with only name and evaluator image reaches the job request."""
        minimal_rule_json = '[{"RuleConfigurationName": "rule_name", "RuleEvaluatorImage": "test-image"}]'
        spec = SageMakerRLEstimatorSpec(
            self.CUSTOM_IMAGE_ARGS +
            ["--debug_rule_config", minimal_rule_json])

        job_request = self.component._create_job_request(
            spec.inputs, spec.outputs)

        self.assertEqual(
            job_request.rules[0]["RuleConfigurationName"], "rule_name")
        self.assertEqual(
            job_request.rules[0]["RuleEvaluatorImage"], "test-image")
    def test_vpc_configuration(self):
        """Security group ids and subnets given as lists land on the request."""
        vpc_args = [
            "--vpc_security_group_ids",
            '["sg1", "sg2"]',
            "--vpc_subnets",
            '["subnet1", "subnet2"]',
        ]
        spec = SageMakerRLEstimatorSpec(self.CUSTOM_IMAGE_ARGS + vpc_args)

        job_request = self.component._create_job_request(
            spec.inputs, spec.outputs)

        for attribute_name in ("subnets", "security_group_ids"):
            self.assertHasAttr(job_request, attribute_name)
        for group_id in ("sg1", "sg2"):
            self.assertIn(group_id, job_request.security_group_ids)
        for subnet_id in ("subnet1", "subnet2"):
            self.assertIn(subnet_id, job_request.subnets)
    def test_valid_hyperparameters(self):
        """Flat string hyperparameters are passed through to the job request."""
        hyperparameters_json = '{"hp1": "val1", "hp2": "val2", "hp3": "val3"}'
        expected_hyperparameters = {
            "hp1": "val1",
            "hp2": "val2",
            "hp3": "val3",
        }

        spec = SageMakerRLEstimatorSpec(
            self.CUSTOM_IMAGE_ARGS +
            ["--hyperparameters", hyperparameters_json])
        job_request = self.component._create_job_request(
            spec.inputs, spec.outputs)

        # First every key must be present, then every value must match.
        for name in expected_hyperparameters:
            self.assertIn(name, job_request._hyperparameters)
        for name, value in expected_hyperparameters.items():
            self.assertEqual(job_request._hyperparameters[name], value)
    def test_metric_definitions(self):
        """--metric_definitions is parsed into a list of name/regex dicts."""
        metric_definitions_arg = '[ {"Name": "metric1", "Regex": "regexval1"},{"Name": "metric2", "Regex": "regexval2"},]'
        expected_definitions = [
            {"Name": "metric1", "Regex": "regexval1"},
            {"Name": "metric2", "Regex": "regexval2"},
        ]

        spec = SageMakerRLEstimatorSpec(
            self.CUSTOM_IMAGE_ARGS +
            ["--metric_definitions", metric_definitions_arg])
        job_request = self.component._create_job_request(
            spec.inputs, spec.outputs)

        self.assertEqual(job_request.metric_definitions, expected_definitions)
 def test_hook_max_args(self):
     """A fully populated --debug_hook_config reaches the job request."""
     hook_json = '{"S3OutputPath": "s3://fake-uri/", "LocalPath": "/local/path/", "HookParameters": {"key": "value"}, "CollectionConfigurations": [{"CollectionName": "collection1", "CollectionParameters": {"key1": "value1"}}, {"CollectionName": "collection2", "CollectionParameters": {"key2": "value2", "key3": "value3"}}]}'
     expected_collections = [
         {
             "CollectionName": "collection1",
             "CollectionParameters": {"key1": "value1"},
         },
         {
             "CollectionName": "collection2",
             "CollectionParameters": {"key2": "value2", "key3": "value3"},
         },
     ]

     spec = SageMakerRLEstimatorSpec(
         self.CUSTOM_IMAGE_ARGS + ["--debug_hook_config", hook_json])
     job_request = self.component._create_job_request(
         spec.inputs, spec.outputs)

     self.assertEqual(
         job_request.debugger_hook_config["S3OutputPath"], "s3://fake-uri/")
     self.assertEqual(
         job_request.debugger_hook_config["LocalPath"], "/local/path/")
     self.assertEqual(
         job_request.debugger_hook_config["HookParameters"],
         {"key": "value"})
     self.assertEqual(
         job_request.debugger_hook_config["CollectionConfigurations"],
         expected_collections)
 def test_toolkit_image_args(self):
     """Constructing a spec from TOOLKIT_IMAGE_ARGS must not raise."""
     # The constructor raises if the inputs are incorrect; the resulting
     # object itself is not needed, so it is not bound to a name.
     SageMakerRLEstimatorSpec(self.TOOLKIT_IMAGE_ARGS)
 def test_custom_image_args(self):
     """Constructing a spec from CUSTOM_IMAGE_ARGS must not raise."""
     # The constructor raises if the inputs are incorrect; the resulting
     # object itself is not needed, so it is not bound to a name.
     SageMakerRLEstimatorSpec(self.CUSTOM_IMAGE_ARGS)
# Example #17
# 0
                self._rlestimator_job_name,
            ))

    @staticmethod
    def _get_toolkit(toolkit_type: str) -> RLToolkit:
        """Look up the RLToolkit entry named by ``toolkit_type``.

        The name is upper-cased and used to index ``RLToolkit``. An empty
        string means "no toolkit" and yields None instead of a lookup.
        """
        return None if toolkit_type == "" else RLToolkit[toolkit_type.upper()]

    @staticmethod
    def _get_framework(framework_type: str) -> RLFramework:
        """Look up the RLFramework entry named by ``framework_type``.

        The name is upper-cased and used to index ``RLFramework``. An empty
        string means "no framework" and yields None instead of a lookup.
        """
        return (None if framework_type == ""
                else RLFramework[framework_type.upper()])

    @staticmethod
    def _nullable(value: str):
        """Return ``value`` unchanged when it is truthy, otherwise None."""
        return value or None


if __name__ == "__main__":
    # Script entry point: build the spec from the command-line arguments
    # (program name excluded), then hand it to a fresh component.
    import sys

    cli_arguments = sys.argv[1:]
    cli_spec = SageMakerRLEstimatorSpec(cli_arguments)

    SageMakerRLEstimatorComponent().Do(cli_spec)