예제 #1
0
 def _append_register_model_step(self):
     """Build the `_RegisterModelStep` for this collection and add it to `self.steps`.

     The step is named ``"<self.name>-<_REGISTER_MODEL_NAME_BASE>"`` and is
     configured from the register-model arguments, display name, retry
     policies and description stored on this instance.

     The collection-level ``depends_on`` is attached to the step only when
     ``self._need_runtime_repack`` is falsy; otherwise the step is appended
     without it.
     """
     step_name = f"{self.name}-{_REGISTER_MODEL_NAME_BASE}"
     step = _RegisterModelStep(
         name=step_name,
         step_args=self._register_model_args,
         display_name=self.display_name,
         retry_policies=self._register_model_retry_policies,
         description=self.description,
     )

     needs_repack = self._need_runtime_repack
     if not needs_repack:
         step.add_depends_on(self.depends_on)

     self.steps.append(step)
    def __init__(
        self,
        name: str,
        estimator: EstimatorBase,
        model_data,
        content_types,
        response_types,
        inference_instances,
        transform_instances,
        depends_on: List[str] = None,
        model_package_group_name=None,
        model_metrics=None,
        approval_status=None,
        image_uri=None,
        compile_model_family=None,
        description=None,
        **kwargs,
    ):
        """Assemble the `_RepackModelStep` (optional) and `_RegisterModelStep` of the collection.

        A `_RepackModelStep` is created only when ``entry_point`` is present in
        ``kwargs``. In that case the repack step receives ``depends_on`` and its
        ``ModelArtifacts.S3ModelArtifacts`` property replaces ``model_data`` for
        the register step. Without a repack step, ``depends_on`` is attached to
        the register step instead. The resulting steps are stored, in order, on
        ``self.steps``.

        Args:
            name (str): Name given to the register step; the repack step, if
                any, is named ``f"{name}RepackModel"``.
            estimator: The estimator instance.
            model_data: The S3 uri to the model data from training.
            content_types (list): The supported MIME types for the input data.
            response_types (list): The supported MIME types for the output data.
            inference_instances (list): A list of the instance types that are
                used to generate inferences in real-time.
            transform_instances (list): A list of the instance types on which a
                transformation job can be run or on which an endpoint can be
                deployed.
            depends_on (List[str]): The list of step names the first step in
                the collection depends on (default: None).
            model_package_group_name (str): The Model Package Group name,
                exclusive to `model_package_name`; using
                `model_package_group_name` makes the Model Package versioned
                (default: None).
            model_metrics (ModelMetrics): ModelMetrics object (default: None).
            approval_status (str): Model Approval Status, values can be
                "Approved", "Rejected", or "PendingManualApproval"
                (default: None).
            image_uri (str): The container image uri for Model Package
                (default: None).
            compile_model_family (str): The instance family for the compiled
                model (default: None).
            description (str): Model Package description (default: None).
            **kwargs: Additional arguments forwarded to `_RegisterModelStep`.
                ``entry_point``, ``source_dir`` and ``dependencies`` are
                consumed here and never forwarded.
        """
        # Presence of the key, not its value, decides whether a repack step is built.
        needs_repack = "entry_point" in kwargs

        # These options are only meaningful to the repack step; strip them
        # unconditionally so they never reach `_RegisterModelStep`.
        entry_point = kwargs.pop("entry_point", None)
        source_dir = kwargs.pop("source_dir", None)
        dependencies = kwargs.pop("dependencies", None)

        collected: List[Step] = []

        if needs_repack:
            repack_step = _RepackModelStep(
                name=f"{name}RepackModel",
                depends_on=depends_on,
                estimator=estimator,
                model_data=model_data,
                entry_point=entry_point,
                source_dir=source_dir,
                dependencies=dependencies,
            )
            collected.append(repack_step)
            # Register the repacked artifacts rather than the raw training output.
            model_data = repack_step.properties.ModelArtifacts.S3ModelArtifacts

        register_step = _RegisterModelStep(
            name=name,
            estimator=estimator,
            model_data=model_data,
            content_types=content_types,
            response_types=response_types,
            inference_instances=inference_instances,
            transform_instances=transform_instances,
            model_package_group_name=model_package_group_name,
            model_metrics=model_metrics,
            approval_status=approval_status,
            image_uri=image_uri,
            compile_model_family=compile_model_family,
            description=description,
            **kwargs,
        )
        # With a repack step present, it already carries `depends_on`.
        if not needs_repack:
            register_step.add_depends_on(depends_on)
        collected.append(register_step)

        self.steps = collected
예제 #3
0
    def __init__(
        self,
        name: str,
        content_types,
        response_types,
        inference_instances,
        transform_instances,
        estimator: EstimatorBase = None,
        model_data=None,
        depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
        repack_model_step_retry_policies: List[RetryPolicy] = None,
        register_model_step_retry_policies: List[RetryPolicy] = None,
        model_package_group_name=None,
        model_metrics=None,
        approval_status=None,
        image_uri=None,
        compile_model_family=None,
        display_name=None,
        description=None,
        tags=None,
        model: Union[Model, PipelineModel] = None,
        drift_check_baselines=None,
        customer_metadata_properties=None,
        **kwargs,
    ):
        """Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator.

        Zero or more `_RepackModelStep` instances are created, followed by exactly
        one `_RegisterModelStep`; all of them are stored in order on ``self.steps``:

        * If ``entry_point`` is present in ``kwargs``, a single repack step named
          ``f"{name}RepackModel"`` is created and its
          ``ModelArtifacts.S3ModelArtifacts`` property replaces ``model_data``
          for the register step.
        * Otherwise, if ``model`` is given, one repack step is created for every
          contained model that has a non-None ``entry_point`` attribute, and that
          model's ``model_data`` is overwritten with the repack step's
          ``ModelArtifacts.S3ModelArtifacts`` property.

        ``depends_on`` is passed to every repack step; it is attached to the
        register step only when no repack step was created.

        A `DeprecationWarning` is emitted at the end of construction.

        Args:
            name (str): The name of the register step; also the prefix of the
                repack step name when ``entry_point`` is passed via ``kwargs``.
            content_types (list): The supported MIME types for the input data.
            response_types (list): The supported MIME types for the output data.
            inference_instances (list): A list of the instance types that are used to
                generate inferences in real-time. When ``model`` is used without an
                ``entry_point`` in ``kwargs``, element ``0`` is read to build the
                container definitions.
            transform_instances (list): A list of the instance types on which a transformation
                job can be run or on which an endpoint can be deployed.
            estimator: The estimator instance (default: None).
            model_data: The S3 uri to the model data from training (default: None).
            depends_on (List[Union[str, Step, StepCollection]]): The list of `Step`/`StepCollection`
                names or `Step` instances or `StepCollection` instances that the first step
                in the collection depends on (default: None).
            repack_model_step_retry_policies (List[RetryPolicy]): The list of retry policies
                for the repack model step (default: None).
            register_model_step_retry_policies (List[RetryPolicy]): The list of retry policies
                for register model step (default: None).
            model_package_group_name (str): The Model Package Group name or Arn, exclusive to
                `model_package_name`, using `model_package_group_name` makes the Model Package
                versioned (default: None).
            model_metrics (ModelMetrics): ModelMetrics object (default: None).
            approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
                or "PendingManualApproval" (default: None).
            image_uri (str): The container image uri for Model Package (default: None).
            compile_model_family (str): The instance family for the compiled model
                (default: None).
            display_name (str): Display name forwarded to both the repack step(s) and the
                register step (default: None).
            description (str): Model Package description; forwarded to both the repack
                step(s) and the register step (default: None).
            tags (List[dict[str, str]]): The list of tags to attach to the model package group. Note
                that tags will only be applied to newly created model package groups; if the
                name of an existing group is passed to "model_package_group_name",
                tags will not be applied (default: None).
            model (object or Model): A PipelineModel object that comprises a list of models
                which gets executed as a serial inference pipeline or a Model object
                (default: None).
            drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
            customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
                metadata properties (default: None).

            **kwargs: additional arguments to `create_model`. ``entry_point``,
                ``source_dir`` and ``dependencies`` are consumed by the repack step when
                ``entry_point`` is present.
        """
        self.name = name
        steps: List[Step] = []
        repack_model = False
        self.model_list = None
        self.container_def_list = None
        subnets = None
        security_group_ids = None

        # Networking settings for the repack step(s): the estimator takes
        # precedence, the model's VPC config is the fallback.
        if estimator is not None:
            subnets = estimator.subnets
            security_group_ids = estimator.security_group_ids
        elif model is not None and model.vpc_config is not None:
            subnets = model.vpc_config["Subnets"]
            security_group_ids = model.vpc_config["SecurityGroupIds"]

        if "entry_point" in kwargs:
            repack_model = True
            entry_point = kwargs.pop("entry_point", None)
            source_dir = kwargs.pop("source_dir", None)
            dependencies = kwargs.pop("dependencies", None)
            # Hand the caller's "model_kms_key" to the repack step as "output_kms_key".
            # NOTE(review): `kwargs` is unpacked and popped from inside the same call
            # expression, so whether "model_kms_key" survives in the rebuilt dict
            # depends on argument evaluation order -- confirm the intended result.
            kwargs = dict(**kwargs, output_kms_key=kwargs.pop("model_kms_key", None))

            # NOTE(review): `estimator` defaults to None but is dereferenced here
            # unconditionally; passing `entry_point` without an estimator raises
            # AttributeError -- confirm callers always supply one on this path.
            repack_model_step = _RepackModelStep(
                name=f"{name}RepackModel",
                depends_on=depends_on,
                retry_policies=repack_model_step_retry_policies,
                sagemaker_session=estimator.sagemaker_session,
                role=estimator.role,
                model_data=model_data,
                entry_point=entry_point,
                source_dir=source_dir,
                dependencies=dependencies,
                tags=tags,
                subnets=subnets,
                security_group_ids=security_group_ids,
                description=description,
                display_name=display_name,
                **kwargs,
            )
            steps.append(repack_model_step)
            # Register the repacked artifacts instead of the original model data.
            model_data = repack_model_step.properties.ModelArtifacts.S3ModelArtifacts

            # remove kwargs consumed by model repacking step
            kwargs.pop("output_kms_key", None)

        elif model is not None:
            # Normalise to a list so single models and pipelines share one loop.
            if isinstance(model, PipelineModel):
                self.model_list = model.models
            elif isinstance(model, Model):
                self.model_list = [model]

            for model_entity in self.model_list:
                # Session and role come from the estimator when one is given,
                # otherwise from the individual model.
                if estimator is not None:
                    sagemaker_session = estimator.sagemaker_session
                    role = estimator.role
                else:
                    sagemaker_session = model_entity.sagemaker_session
                    role = model_entity.role
                # Only models that carry their own entry point are repacked.
                if hasattr(model_entity, "entry_point") and model_entity.entry_point is not None:
                    repack_model = True
                    entry_point = model_entity.entry_point
                    source_dir = model_entity.source_dir
                    dependencies = model_entity.dependencies
                    kwargs = dict(**kwargs, output_kms_key=model_entity.model_kms_key)
                    model_name = model_entity.name or model_entity._framework_name

                    repack_model_step = _RepackModelStep(
                        name=f"{model_name}RepackModel",
                        depends_on=depends_on,
                        retry_policies=repack_model_step_retry_policies,
                        sagemaker_session=sagemaker_session,
                        role=role,
                        model_data=model_entity.model_data,
                        entry_point=entry_point,
                        source_dir=source_dir,
                        dependencies=dependencies,
                        tags=tags,
                        subnets=subnets,
                        security_group_ids=security_group_ids,
                        description=description,
                        display_name=display_name,
                        **kwargs,
                    )
                    steps.append(repack_model_step)
                    # Point the model at the repacked artifacts; this mutates the
                    # caller's model object.
                    model_entity.model_data = (
                        repack_model_step.properties.ModelArtifacts.S3ModelArtifacts
                    )

                    # remove kwargs consumed by model repacking step
                    kwargs.pop("output_kms_key", None)

            # Container definitions are built after the loop so they see any
            # `model_data` values rewritten above.
            if isinstance(model, PipelineModel):
                self.container_def_list = model.pipeline_container_def(inference_instances[0])
            elif isinstance(model, Model):
                self.container_def_list = [model.prepare_container_def(inference_instances[0])]

        register_model_step = _RegisterModelStep(
            name=name,
            estimator=estimator,
            model_data=model_data,
            content_types=content_types,
            response_types=response_types,
            inference_instances=inference_instances,
            transform_instances=transform_instances,
            model_package_group_name=model_package_group_name,
            model_metrics=model_metrics,
            drift_check_baselines=drift_check_baselines,
            approval_status=approval_status,
            image_uri=image_uri,
            compile_model_family=compile_model_family,
            description=description,
            display_name=display_name,
            tags=tags,
            container_def_list=self.container_def_list,
            retry_policies=register_model_step_retry_policies,
            customer_metadata_properties=customer_metadata_properties,
            **kwargs,
        )
        # When a repack step exists it already carries `depends_on`.
        if not repack_model:
            register_model_step.add_depends_on(depends_on)

        steps.append(register_model_step)
        self.steps = steps

        # TODO: add public document link here once ready
        # stacklevel=2 attributes the warning to the code that instantiated this
        # class rather than to this module, so it is reported at the call site
        # that needs to migrate.
        warnings.warn(
            (
                "We are deprecating the use of RegisterModel. "
                "Instead, please use the ModelStep, which simply takes in the step arguments "
                "generated by model.register()."
            ),
            DeprecationWarning,
            stacklevel=2,
        )