コード例 #1
0
    def test_do_sets_name(self):
        """Do() keeps any user-supplied name and generates the missing ones.

        The unique-id generator is patched to return "-generated", so the
        generated names are deterministic. Each case pairs a spec with the
        endpoint config name and endpoint name expected after ``Do`` runs.
        """
        cases = [
            # Only the endpoint name is given; the config name is generated.
            (
                SageMakerDeploySpec(self.REQUIRED_ARGS +
                                    ["--endpoint_name", "my-endpoint"]),
                "EndpointConfig-generated",
                "my-endpoint",
            ),
            # Only the endpoint config name is given; the endpoint name is
            # generated.
            (
                SageMakerDeploySpec(
                    self.REQUIRED_ARGS +
                    ["--endpoint_config_name", "my-endpoint-config"]),
                "my-endpoint-config",
                "Endpoint-generated",
            ),
            # Neither name is given; both are generated.
            (
                SageMakerDeploySpec(self.REQUIRED_ARGS),
                "EndpointConfig-generated",
                "Endpoint-generated",
            ),
        ]
        fixed_id = MagicMock(return_value="-generated")

        with patch(
                "deploy.src.sagemaker_deploy_component.SageMakerComponent._generate_unique_timestamped_id",
                fixed_id,
        ):
            for spec, expected_config_name, expected_endpoint_name in cases:
                self.component.Do(spec)
                self.assertEqual(expected_config_name,
                                 self.component._endpoint_config_name)
                self.assertEqual(expected_endpoint_name,
                                 self.component._endpoint_name)
コード例 #2
0
    def test_create_deploy_job_requests(self):
        """With only the required args, one default variant and no tags."""
        spec = SageMakerDeploySpec(self.REQUIRED_ARGS)
        actual = self.component._create_job_request(spec.inputs, spec.outputs)

        default_variant = {
            "VariantName": "variant-name-1",
            "ModelName": "model-test",
            "InitialInstanceCount": 1,
            "InstanceType": "ml.m4.xlarge",
            "InitialVariantWeight": 1.0,
        }
        expected_config = {
            "EndpointConfigName": "endpoint-config",
            "ProductionVariants": [default_variant],
            "Tags": [],
        }
        expected_endpoint = {
            "EndpointName": "endpoint",
            "EndpointConfigName": "endpoint-config",
        }

        self.assertEqual(
            actual,
            EndpointRequests(config_request=expected_config,
                             endpoint_request=expected_endpoint),
        )
コード例 #3
0
    def test_create_deploy_job_multiple_variants(self):
        """Two fully specified variants become two ProductionVariants entries.

        No --model_name_1 flag is passed, so variant 1 is expected to use
        "model-test". Numeric CLI strings are expected to come back as
        int (instance counts) and float (weights).
        """
        variant_flags = [
            ("--variant_name_1", "variant-test-1"),
            ("--initial_instance_count_1", "1"),
            ("--instance_type_1", "t1"),
            ("--initial_variant_weight_1", "0.1"),
            ("--accelerator_type_1", "ml.eia1.medium"),
            ("--model_name_2", "model-test-2"),
            ("--variant_name_2", "variant-test-2"),
            ("--initial_instance_count_2", "2"),
            ("--instance_type_2", "t2"),
            ("--initial_variant_weight_2", "0.2"),
            ("--accelerator_type_2", "ml.eia1.large"),
        ]
        # Flatten the (flag, value) pairs into the argv-style token list.
        cli_args = [token for pair in variant_flags for token in pair]
        spec = SageMakerDeploySpec(self.REQUIRED_ARGS + cli_args)

        actual = self.component._create_job_request(spec.inputs, spec.outputs)

        first_variant = {
            "VariantName": "variant-test-1",
            "ModelName": "model-test",
            "InitialInstanceCount": 1,
            "InstanceType": "t1",
            "InitialVariantWeight": 0.1,
            "AcceleratorType": "ml.eia1.medium",
        }
        second_variant = {
            "VariantName": "variant-test-2",
            "ModelName": "model-test-2",
            "InitialInstanceCount": 2,
            "InstanceType": "t2",
            "InitialVariantWeight": 0.2,
            "AcceleratorType": "ml.eia1.large",
        }
        expected = EndpointRequests(
            config_request={
                "EndpointConfigName": "endpoint-config",
                "ProductionVariants": [first_variant, second_variant],
                "Tags": [],
            },
            endpoint_request={
                "EndpointName": "endpoint",
                "EndpointConfigName": "endpoint-config",
            },
        )

        self.assertEqual(actual, expected)
コード例 #4
0
    def test_update_endpoint_do_sets_name(self):
        """Do() naming behaviour when --update_endpoint is requested.

        With ``--update_endpoint True`` the endpoint config name is always
        generated (a user-supplied config name is ignored) and
        ``_should_update_existing`` is set. Without the flag,
        ``_should_update_existing`` stays False even though the endpoint
        is stubbed as existing.
        """
        given_endpoint_name = SageMakerDeploySpec(
            self.REQUIRED_ARGS +
            ["--endpoint_name", "my-endpoint", "--update_endpoint", "True"])
        given_endpoint_config_name = SageMakerDeploySpec(self.REQUIRED_ARGS + [
            "--endpoint_config_name",
            "my-endpoint-config",
            "--update_endpoint",
            "True",
        ])
        unnamed_spec = SageMakerDeploySpec(self.REQUIRED_ARGS)

        # Stub the existence lookups on the component instance only.
        # The id generator is deliberately NOT assigned on the class here:
        # a class-level assignment is never undone, so it would leak the
        # "-generated-update" suffix into later tests and shadow the
        # base-class patch other tests rely on. The scoped ``patch`` below
        # is restored automatically when the ``with`` block exits.
        self.component._endpoint_name_exists = MagicMock(return_value=True)
        self.component._get_endpoint_config = MagicMock(
            return_value="existing-config")

        with patch(
                "deploy.src.sagemaker_deploy_component.SageMakerComponent._generate_unique_timestamped_id",
                MagicMock(return_value="-generated-update"),
        ):
            self.component.Do(given_endpoint_name)
            self.assertEqual("EndpointConfig-generated-update",
                             self.component._endpoint_config_name)
            self.assertEqual("my-endpoint", self.component._endpoint_name)
            self.assertTrue(self.component._should_update_existing)

            # When updating, a user-supplied endpoint config name is ignored
            # and a fresh one is generated.
            self.component.Do(given_endpoint_config_name)
            self.assertEqual("EndpointConfig-generated-update",
                             self.component._endpoint_config_name)
            self.assertEqual("Endpoint-generated-update",
                             self.component._endpoint_name)
            self.assertTrue(self.component._should_update_existing)

            # No --update_endpoint flag: never marked as an update.
            self.component.Do(unnamed_spec)
            self.assertEqual("EndpointConfig-generated-update",
                             self.component._endpoint_config_name)
            self.assertEqual("Endpoint-generated-update",
                             self.component._endpoint_name)
            self.assertFalse(self.component._should_update_existing)
コード例 #5
0
            logging.info("Endpoint Config does not exist")
            ## This is not an error, end point may not exist
        return endpoint_config_name

    def _delete_endpoint_config(self, endpoint_config_name: str) -> bool:
        """Deletes an endpoint config.

        A ``ClientError`` from the SageMaker client is logged and swallowed
        rather than raised, because the config may not exist to be deleted.

        Args:
            endpoint_config_name: The name of the endpoint config to delete.

        Returns:
            True if the endpoint config was deleted, False otherwise.
        """
        try:
            self._sm_client.delete_endpoint_config(
                EndpointConfigName=endpoint_config_name)
        except ClientError as e:
            # Best-effort delete: read the message defensively so a response
            # without the usual "Error"/"Message" keys cannot raise here.
            error_message = e.response.get("Error", {}).get("Message", str(e))
            logging.info("Endpoint config may not exist to be deleted: %s",
                         error_message)
            return False
        return True


if __name__ == "__main__":
    import sys

    # Build the deploy spec from the command-line arguments (program name
    # excluded), then hand it to a fresh component to run.
    spec = SageMakerDeploySpec(sys.argv[1:])
    SageMakerDeployComponent().Do(spec)
コード例 #6
0
 def test_minimum_required_args(self):
     """Building a spec from only the required args must not raise."""
     SageMakerDeploySpec(self.REQUIRED_ARGS)
コード例 #7
0
    def test_after_job_completed(self):
        """After _after_job_complete, the output endpoint name is "endpoint"."""
        deploy_spec = SageMakerDeploySpec(self.REQUIRED_ARGS)
        empty_job, empty_request = {}, {}

        self.component._after_job_complete(empty_job, empty_request,
                                           deploy_spec.inputs,
                                           deploy_spec.outputs)

        self.assertEqual(deploy_spec.outputs.endpoint_name, "endpoint")