예제 #1
0
    def test_spot_bad_args(self):
        """Invalid spot-instance argument combinations must raise.

        Each named case is parsed with ``SpotInstanceSpec`` and passed to
        ``SageMakerComponent._enable_spot_instance_support``, which is expected
        to raise. Every case runs inside ``subTest`` so a failure reports which
        combination unexpectedly succeeded.
        """
        bad_cases = {
            "no_max_wait": ["--spot_instance", "True"],
            "no_checkpoint": [
                "--spot_instance", "True",
                "--max_wait_time", "3601",
            ],
            # checkpoint_config is an empty object, i.e. it has no S3Uri.
            "no_s3_uri": [
                "--spot_instance", "True",
                "--max_wait_time", "3601",
                "--max_run_time", "3600",
                "--checkpoint_config", "{}",
            ],
            # max_wait_time (3600) is smaller than max_run_time (3601).
            "max_wait_too_short": [
                "--spot_instance", "True",
                "--max_wait_time", "3600",
                "--max_run_time", "3601",
                "--checkpoint_config", "{}",
            ],
        }

        for case_name, cli_args in bad_cases.items():
            with self.subTest(case=case_name):
                spec = self.SpotInstanceSpec(cli_args)
                with self.assertRaises(Exception):
                    SageMakerComponent._enable_spot_instance_support(
                        self.template, spec.inputs)
예제 #2
0
    def test_create_hyperparameters(self):
        """Check hyperparameter validation on good and bad inputs.

        A dict of string values must come back equal to the input. A dict
        holding an int value must make the validator raise.
        """
        string_params = {"tag1": "val1", "tag2": "val2"}
        int_valued_params = {"tag1": 1500}

        validated = SageMakerComponent._validate_hyperparameters(string_params)
        self.assertEqual(string_params, validated)

        with self.assertRaises(Exception):
            SageMakerComponent._validate_hyperparameters(int_valued_params)
    def Do(self, spec: SageMakerDeploySpec):
        """Resolve endpoint and endpoint-config names, then run the base ``Do``.

        Args:
            spec: Parsed deploy spec. Its ``inputs``, ``outputs`` and
                ``output_paths`` are forwarded to the parent ``Do``.

        Side effects:
            Sets ``self._endpoint_name``, ``self._should_update_existing`` and
            ``self._endpoint_config_name``. When an existing endpoint will be
            updated, it also sets ``self._existing_endpoint_config_name``.
        """
        # Manually invoke AWS client configuration so we can use it before
        # starting the reconciliation loop
        self._configure_aws_clients(spec.inputs)

        name_suffix = SageMakerComponent._generate_unique_timestamped_id()

        self._endpoint_name = (spec.inputs.endpoint_name
                               or f"Endpoint{name_suffix}")

        # Without a user-supplied endpoint_name there is no existing endpoint
        # to update. Skip the lookup rather than querying with an empty/None
        # name.
        self._should_update_existing = bool(
            spec.inputs.update_endpoint
            and spec.inputs.endpoint_name
            and self._endpoint_name_exists(self._endpoint_name))

        # Fetch existing config to delete after creation
        if self._should_update_existing:
            self._existing_endpoint_config_name = self._get_endpoint_config(
                self._endpoint_name)

        # Only use the predefined config name if we are not updating;
        # otherwise it could conflict.
        use_given_config_name = (spec.inputs.endpoint_config_name
                                 and not self._should_update_existing)
        self._endpoint_config_name = (spec.inputs.endpoint_config_name
                                      if use_given_config_name else
                                      f"EndpointConfig{name_suffix}")

        super().Do(spec.inputs, spec.outputs, spec.output_paths)
예제 #4
0
 def setUp(cls):
     """Create a component with polling disabled and Boto3Manager patched.

     NOTE(review): the patcher is started here but not stopped in this
     method; confirm a matching tearDown (or addCleanup) stops it.
     """
     component = SageMakerComponent()
     cls.component = component
     # Zero poll interval so status-polling loops do not wait during tests.
     component.STATUS_POLL_INTERVAL = 0

     patcher = patch("common.sagemaker_component.Boto3Manager")
     cls.boto3_manager_patch = patcher
     patcher.start()
예제 #5
0
 def test_tags(self):
     """Each --tags JSON entry shows up as a Key/Value dict in template Tags."""
     spec = self.CommonInputsSpec([
         "--region", "us-east-1", "--tags",
         '{"key1": "val1", "key2": "val2"}'
     ])
     # The assertions inspect self.template itself, which the helper is
     # expected to update in place; its return value is not used.
     SageMakerComponent._enable_tag_support(self.template, spec.inputs)

     tags = self.template["Tags"]
     self.assertIn({"Key": "key1", "Value": "val1"}, tags)
     self.assertIn({"Key": "key2", "Value": "val2"}, tags)
예제 #6
0
 def test_spot_local_path(self):
     """A LocalPath in --checkpoint_config reaches the request's CheckpointConfig."""
     checkpoint_json = '{"S3Uri": "s3://fake-uri/", "LocalPath": "local-path"}'
     cli_args = [
         "--spot_instance", "True",
         "--max_wait_time", "3601",
         "--max_run_time", "3600",
         "--checkpoint_config", checkpoint_json,
     ]
     spec = self.SpotInstanceSpec(cli_args)

     request = SageMakerComponent._enable_spot_instance_support(
         self.template, spec.inputs)
     checkpoint = request["CheckpointConfig"]

     self.assertEqual(checkpoint["S3Uri"], "s3://fake-uri/")
     self.assertEqual(checkpoint["LocalPath"], "local-path")
예제 #7
0
 def test_spot_good_args(self):
     """Valid spot arguments enable managed spot training on the request.

     Checks the enable flag, MaxWaitTimeInSeconds, and the checkpoint S3Uri.
     """
     cli_args = [
         "--spot_instance", "True",
         "--max_wait_time", "3601",
         "--max_run_time", "3600",
         "--checkpoint_config", '{"S3Uri": "s3://fake-uri/"}',
     ]
     spec = self.SpotInstanceSpec(cli_args)

     request = SageMakerComponent._enable_spot_instance_support(
         self.template, spec.inputs)

     self.assertTrue(request["EnableManagedSpotTraining"])
     self.assertEqual(
         request["StoppingCondition"]["MaxWaitTimeInSeconds"], 3601)
     self.assertEqual(
         request["CheckpointConfig"]["S3Uri"], "s3://fake-uri/")
 def Do(self, spec: SageMakerTrainingSpec):
     """Pick the training job name, then delegate to the parent ``Do``.

     A truthy ``spec.inputs.job_name`` is used as-is. Otherwise the name comes
     from ``_generate_unique_timestamped_id(prefix="TrainingJob")``.
     """
     self._training_job_name = spec.inputs.job_name or (
         SageMakerComponent._generate_unique_timestamped_id(
             prefix="TrainingJob"))
     super().Do(spec.inputs, spec.outputs, spec.output_paths)
예제 #9
0
 def test_spot_disabled_deletes_args(self):
     """With --spot_instance False, spot-only request fields are absent."""
     spec = self.SpotInstanceSpec(["--spot_instance", "False"])
     request = SageMakerComponent._enable_spot_instance_support(
         self.template, spec.inputs)

     self.assertNotIn("MaxWaitTimeInSeconds", request["StoppingCondition"])
     self.assertNotIn("CheckpointConfig", request)
예제 #10
0
 def setUp(cls):
     """Use the "train" request template as the example request for tests."""
     train_template = SageMakerComponent._get_request_template("train")
     cls.template = train_template
예제 #11
0
 def Do(self, spec: SageMakerRLEstimatorSpec):
     """Pick the RL estimator job name, then delegate to the parent ``Do``.

     A truthy ``spec.inputs.job_name`` is used as-is. Otherwise the name comes
     from ``_generate_unique_timestamped_id(prefix="RLEstimatorJob")``.
     """
     self._rlestimator_job_name = spec.inputs.job_name or (
         SageMakerComponent._generate_unique_timestamped_id(
             prefix="RLEstimatorJob"))
     super().Do(spec.inputs, spec.outputs, spec.output_paths)