def test_spot_bad_args(self):
    """Every incomplete or inconsistent spot-instance configuration raises.

    Cases:
      * no_max_wait: spot enabled with no max_wait_time.
      * no_checkpoint: max_wait_time given but no checkpoint_config.
      * no_s3_uri: checkpoint_config is an empty object (no S3Uri).
      * max_wait_too_short: max_wait_time (3600) < max_run_time (3601).
    """
    bad_arg_cases = {
        "no_max_wait": self.SpotInstanceSpec(["--spot_instance", "True"]),
        "no_checkpoint": self.SpotInstanceSpec(
            ["--spot_instance", "True", "--max_wait_time", "3601"]),
        "no_s3_uri": self.SpotInstanceSpec([
            "--spot_instance",
            "True",
            "--max_wait_time",
            "3601",
            "--max_run_time",
            "3600",
            "--checkpoint_config",
            "{}",
        ]),
        "max_wait_too_short": self.SpotInstanceSpec([
            "--spot_instance",
            "True",
            "--max_wait_time",
            "3600",
            "--max_run_time",
            "3601",
            "--checkpoint_config",
            "{}",
        ]),
    }

    for case_name, bad_args in bad_arg_cases.items():
        # subTest labels a failure with the offending case and keeps
        # checking the remaining cases instead of stopping at the first.
        with self.subTest(case=case_name):
            with self.assertRaises(Exception):
                SageMakerComponent._enable_spot_instance_support(
                    self.template, bad_args.inputs)
def test_create_hyperparameters(self):
    """String-valued hyperparameters come back unchanged; an int value raises."""
    string_valued = {"tag1": "val1", "tag2": "val2"}
    non_string_valued = {"tag1": 1500}

    validated = SageMakerComponent._validate_hyperparameters(string_valued)
    self.assertEqual(string_valued, validated)

    with self.assertRaises(Exception):
        SageMakerComponent._validate_hyperparameters(non_string_valued)
def Do(self, spec: SageMakerDeploySpec):
    """Resolve endpoint and endpoint-config names, then run the component.

    Sets:
      self._endpoint_name: spec.inputs.endpoint_name if non-empty, else
          "Endpoint<suffix>".
      self._should_update_existing: truthy only when update_endpoint is
          set and an endpoint with the resolved name already exists.
      self._existing_endpoint_config_name: the current config of that
          endpoint (only when updating), kept so it can be deleted after
          the new one is created.
      self._endpoint_config_name: spec.inputs.endpoint_config_name when
          given and not updating, else "EndpointConfig<suffix>".

    Args:
        spec: Parsed deploy spec; its inputs, outputs and output_paths are
            forwarded to the base class Do.
    """
    # Manually invoke AWS client configuration so we can use it before
    # starting the reconciliation loop
    self._configure_aws_clients(spec.inputs)

    name_suffix = SageMakerComponent._generate_unique_timestamped_id()
    self._endpoint_name = (spec.inputs.endpoint_name
                           if spec.inputs.endpoint_name else
                           f"Endpoint{name_suffix}")

    # Check the name we will actually deploy to. The raw input may be empty
    # when a name was generated above, and the lookup helper should not be
    # handed an empty/None name.
    self._should_update_existing = (
        spec.inputs.update_endpoint
        and self._endpoint_name_exists(self._endpoint_name))

    # Fetch existing config to delete after creation
    if self._should_update_existing:
        self._existing_endpoint_config_name = self._get_endpoint_config(
            self._endpoint_name)

    # Only use the predefined config name if we are not updating; when
    # updating, it could conflict with the config already in use.
    if spec.inputs.endpoint_config_name and not self._should_update_existing:
        self._endpoint_config_name = spec.inputs.endpoint_config_name
    else:
        self._endpoint_config_name = f"EndpointConfig{name_suffix}"

    super().Do(spec.inputs, spec.outputs, spec.output_paths)
def setUp(cls):
    """Create a fresh SageMakerComponent and start patching Boto3Manager.

    NOTE(review): the patcher is started here but not stopped in this
    block; presumably a matching tearDown calls
    ``cls.boto3_manager_patch.stop()`` -- confirm, otherwise the patch
    stays active for later tests.
    """
    component = SageMakerComponent()
    # Turn off polling interval for instant tests
    component.STATUS_POLL_INTERVAL = 0
    cls.component = component

    # Replace the Boto3Manager name looked up in common.sagemaker_component
    # with a mock (presumably so no real AWS clients are created).
    patcher = patch("common.sagemaker_component.Boto3Manager")
    cls.boto3_manager_patch = patcher
    patcher.start()
def test_tags(self):
    """Tags given as a JSON object are added to self.template["Tags"].

    Each key/value pair is expected as a {"Key": ..., "Value": ...} entry.
    """
    spec = self.CommonInputsSpec([
        "--region", "us-east-1", "--tags", '{"key1": "val1", "key2": "val2"}'
    ])
    # The assertions below inspect self.template directly, so the return
    # value of _enable_tag_support is not kept.
    SageMakerComponent._enable_tag_support(self.template, spec.inputs)

    tags = self.template["Tags"]
    self.assertIn({"Key": "key1", "Value": "val1"}, tags)
    self.assertIn({"Key": "key2", "Value": "val2"}, tags)
def test_spot_local_path(self):
    """A LocalPath in checkpoint_config is carried through next to S3Uri."""
    spot_args = self.SpotInstanceSpec([
        "--spot_instance",
        "True",
        "--max_wait_time",
        "3601",
        "--max_run_time",
        "3600",
        "--checkpoint_config",
        '{"S3Uri": "s3://fake-uri/", "LocalPath": "local-path"}',
    ])
    request = SageMakerComponent._enable_spot_instance_support(
        self.template, spot_args.inputs)

    checkpoint = request["CheckpointConfig"]
    self.assertEqual(checkpoint["S3Uri"], "s3://fake-uri/")
    self.assertEqual(checkpoint["LocalPath"], "local-path")
def test_spot_good_args(self):
    """A complete spot configuration enables managed spot training.

    Checks that the flag is set, that max_wait_time ends up as the value
    3601 under StoppingCondition, and that the checkpoint S3Uri is kept.
    """
    spot_args = self.SpotInstanceSpec([
        "--spot_instance",
        "True",
        "--max_wait_time",
        "3601",
        "--max_run_time",
        "3600",
        "--checkpoint_config",
        '{"S3Uri": "s3://fake-uri/"}',
    ])
    request = SageMakerComponent._enable_spot_instance_support(
        self.template, spot_args.inputs)

    self.assertTrue(request["EnableManagedSpotTraining"])
    stopping_condition = request["StoppingCondition"]
    self.assertEqual(stopping_condition["MaxWaitTimeInSeconds"], 3601)
    checkpoint = request["CheckpointConfig"]
    self.assertEqual(checkpoint["S3Uri"], "s3://fake-uri/")
def Do(self, spec: SageMakerTrainingSpec):
    """Choose the training job name, then delegate to the base class Do.

    The name is spec.inputs.job_name when it is non-empty; otherwise it is
    generated by _generate_unique_timestamped_id with prefix "TrainingJob".
    """
    requested_name = spec.inputs.job_name
    self._training_job_name = (
        requested_name
        or SageMakerComponent._generate_unique_timestamped_id(
            prefix="TrainingJob"))
    super().Do(spec.inputs, spec.outputs, spec.output_paths)
def test_spot_disabled_deletes_args(self):
    """With spot_instance False, spot-only settings are absent from the result.

    Neither StoppingCondition.MaxWaitTimeInSeconds nor CheckpointConfig may
    appear in the returned request.
    """
    disabled_args = self.SpotInstanceSpec(["--spot_instance", "False"])
    request = SageMakerComponent._enable_spot_instance_support(
        self.template, disabled_args.inputs)

    stopping_condition = request["StoppingCondition"]
    self.assertNotIn("MaxWaitTimeInSeconds", stopping_condition)
    self.assertNotIn("CheckpointConfig", request)
def setUp(cls):
    """Load the "train" request template to use as the example request."""
    train_template = SageMakerComponent._get_request_template("train")
    cls.template = train_template
def Do(self, spec: SageMakerRLEstimatorSpec):
    """Choose the RL estimator job name, then delegate to the base class Do.

    The name is spec.inputs.job_name when it is non-empty; otherwise it is
    generated by _generate_unique_timestamped_id with prefix
    "RLEstimatorJob".
    """
    job_name = spec.inputs.job_name
    if not job_name:
        job_name = SageMakerComponent._generate_unique_timestamped_id(
            prefix="RLEstimatorJob")
    self._rlestimator_job_name = job_name
    super().Do(spec.inputs, spec.outputs, spec.output_paths)