def update_endpoint(
        self,
        initial_instance_count=None,
        instance_type=None,
        accelerator_type=None,
        model_name=None,
        tags=None,
        kms_key=None,
        data_capture_config_dict=None,
        wait=True,
    ):
        """Update the existing endpoint with the provided attributes.

        This creates a new EndpointConfig in the process. If ``initial_instance_count``,
        ``instance_type``, ``accelerator_type``, or ``model_name`` is specified (not ``None``),
        then a new ProductionVariant configuration is created; values from the existing
        configuration are not preserved if any of those parameters are specified.

        Local state (``_endpoint_config_name`` and, when ``model_name`` is given,
        ``_model_names``) is only updated after the endpoint config has been created and the
        endpoint update call has returned, so a failed update leaves this object unchanged.

        Args:
            initial_instance_count (int): The initial number of instances to run in the endpoint.
                This is required if ``instance_type``, ``accelerator_type``, or ``model_name`` is
                specified. Otherwise, the values from the existing endpoint configuration's
                ProductionVariants are used.
            instance_type (str): The EC2 instance type to deploy the endpoint to.
                This is required if ``initial_instance_count``, ``accelerator_type``, or
                ``model_name`` is specified. Otherwise, the values from the existing endpoint
                configuration's ``ProductionVariants`` are used.
            accelerator_type (str): The type of Elastic Inference accelerator to attach to
                the endpoint, e.g. "ml.eia1.medium". If not specified, and
                ``initial_instance_count``, ``instance_type``, and ``model_name`` are also ``None``,
                the values from the existing endpoint configuration's ``ProductionVariants`` are
                used. Otherwise, no Elastic Inference accelerator is attached to the endpoint.
            model_name (str): The name of the model to be associated with the endpoint.
                This is required if ``initial_instance_count``, ``instance_type``, or
                ``accelerator_type`` is specified and if there is more than one model associated
                with the endpoint. Otherwise, the existing model for the endpoint is used.
            tags (list[dict[str, str]]): The list of tags to add to the endpoint
                config. If not specified, the tags of the existing endpoint configuration are used.
                If any of the existing tags are reserved AWS ones (i.e. begin with "aws"),
                they are not carried over to the new endpoint configuration.
            kms_key (str): The KMS key that is used to encrypt the data on the storage volume
                attached to the instance hosting the endpoint. If not specified,
                the KMS key of the existing endpoint configuration is used.
            data_capture_config_dict (dict): The endpoint data capture configuration
                for use with Amazon SageMaker Model Monitoring. If not specified,
                the data capture configuration of the existing endpoint configuration is used.
            wait (bool): Forwarded as ``wait`` to ``sagemaker_session.update_endpoint``
                (default: True); presumably controls whether the call blocks until the
                update completes.

        Raises:
            ValueError: If there is not enough information to create a new ``ProductionVariant``:

                - If any of ``initial_instance_count``, ``instance_type``, ``accelerator_type``,
                  or ``model_name`` is specified, but ``initial_instance_count`` or
                  ``instance_type`` is ``None``.
                - If a new ``ProductionVariant`` is needed, ``model_name`` is ``None``, and
                  there are multiple models associated with the endpoint.
        """
        production_variants = None
        # Model list to record on success; stays None when the model is unchanged.
        new_model_names = None

        # "Specified" means "not None", matching the ``is None`` checks below; a falsy but
        # explicit value (e.g. a count of 0) must be validated, not silently ignored.
        variant_args = (initial_instance_count, instance_type, accelerator_type, model_name)
        if any(arg is not None for arg in variant_args):
            if instance_type is None or initial_instance_count is None:
                raise ValueError(
                    "Missing initial_instance_count and/or instance_type. Provided values: "
                    "initial_instance_count={}, instance_type={}, accelerator_type={}, "
                    "model_name={}.".format(initial_instance_count,
                                            instance_type, accelerator_type,
                                            model_name))

            if model_name is None:
                if len(self._model_names) > 1:
                    raise ValueError(
                        "Unable to choose a default model for a new EndpointConfig because "
                        "the endpoint has multiple models: {}".format(
                            ", ".join(self._model_names)))
                model_name = self._model_names[0]
            else:
                # Deferred: don't change local state until the remote update has succeeded.
                new_model_names = [model_name]

            production_variant_config = production_variant(
                model_name,
                instance_type,
                initial_instance_count=initial_instance_count,
                accelerator_type=accelerator_type,
            )
            production_variants = [production_variant_config]

        new_endpoint_config_name = name_from_base(self._endpoint_config_name)
        self.sagemaker_session.create_endpoint_config_from_existing(
            self._endpoint_config_name,
            new_endpoint_config_name,
            new_tags=tags,
            new_kms_key=kms_key,
            new_data_capture_config_dict=data_capture_config_dict,
            new_production_variants=production_variants,
        )
        self.sagemaker_session.update_endpoint(self.endpoint_name,
                                               new_endpoint_config_name,
                                               wait=wait)

        # Both remote calls returned without raising: commit the new local state.
        self._endpoint_config_name = new_endpoint_config_name
        if new_model_names is not None:
            self._model_names = new_model_names
def multi_variant_endpoint(sagemaker_session):
    """
    Sets up the multi variant endpoint before the integration tests run.
    Cleans up the multi variant endpoint after the integration tests run.

    The endpoint is created inside ``timeout_and_delete_endpoint_by_name``, which is
    responsible for the endpoint itself. The model and endpoint config are deleted
    explicitly here. After cleanup, this fixture verifies that neither of them can still
    be described.
    """
    multi_variant_endpoint.endpoint_name = unique_name_from_base(
        "integ-test-multi-variant-endpoint")
    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(
            endpoint_name=multi_variant_endpoint.endpoint_name,
            sagemaker_session=sagemaker_session,
            hours=2,
    ):

        # Creating a model
        bucket = sagemaker_session.default_bucket()
        prefix = "sagemaker/DEMO-VariantTargeting"
        model_url = S3Uploader.upload(
            local_path=XG_BOOST_MODEL_LOCAL_PATH,
            desired_s3_uri="s3://" + bucket + "/" + prefix,
            session=sagemaker_session,
        )

        image_uri = get_image_uri(sagemaker_session.boto_session.region_name,
                                  "xgboost", "0.90-1")

        # ``create_model`` returns the model *name*; that is what ``delete_model`` and
        # ``describe_model(ModelName=...)`` take below.
        multi_variant_endpoint_model = sagemaker_session.create_model(
            name=MODEL_NAME,
            role=ROLE,
            container_defs={
                "Image": image_uri,
                "ModelDataUrl": model_url
            },
        )

        try:
            # Creating a multi variant endpoint
            variant1 = production_variant(
                model_name=MODEL_NAME,
                instance_type=DEFAULT_INSTANCE_TYPE,
                initial_instance_count=DEFAULT_INSTANCE_COUNT,
                variant_name=TEST_VARIANT_1,
                initial_weight=TEST_VARIANT_1_WEIGHT,
            )
            variant2 = production_variant(
                model_name=MODEL_NAME,
                instance_type=DEFAULT_INSTANCE_TYPE,
                initial_instance_count=DEFAULT_INSTANCE_COUNT,
                variant_name=TEST_VARIANT_2,
                initial_weight=TEST_VARIANT_2_WEIGHT,
            )
            sagemaker_session.endpoint_from_production_variants(
                name=multi_variant_endpoint.endpoint_name,
                production_variants=[variant1, variant2])

            # Yield to run the integration tests
            yield multi_variant_endpoint

            # Cleanup resources: the endpoint config is deleted under the endpoint's name,
            # as in the original fixture (it is created under that name).
            sagemaker_session.sagemaker_client.delete_endpoint_config(
                EndpointConfigName=multi_variant_endpoint.endpoint_name)
        finally:
            # Delete the model even if endpoint setup failed, so it is never leaked.
            sagemaker_session.delete_model(multi_variant_endpoint_model)

    # Validate resource cleanup. Each lookup needs its own ``pytest.raises`` block: once
    # a call raises inside the block, no later statement in that block runs, so the
    # asserts must sit outside it.
    with pytest.raises(Exception) as exception:
        sagemaker_session.sagemaker_client.describe_model(
            ModelName=multi_variant_endpoint_model)
    assert "Could not find model" in str(exception.value)

    with pytest.raises(Exception) as exception:
        sagemaker_session.sagemaker_client.describe_endpoint_config(
            EndpointConfigName=multi_variant_endpoint.endpoint_name)
    assert "Could not find endpoint" in str(exception.value)