def _append_register_model_step(self):
    """Build the `_RegisterModelStep` for this collection and add it to `self.steps`.

    The step is named `<self.name>-<_REGISTER_MODEL_NAME_BASE>` and is configured
    from the register-model arguments, display name, retry policies and
    description already stored on the instance.
    """
    step_name = f"{self.name}-{_REGISTER_MODEL_NAME_BASE}"
    register_step = _RegisterModelStep(
        name=step_name,
        step_args=self._register_model_args,
        display_name=self.display_name,
        retry_policies=self._register_model_retry_policies,
        description=self.description,
    )

    # The collection-level dependencies are attached to this step only when no
    # runtime repack is needed (presumably an earlier step carries them
    # otherwise -- confirm against the repack path).
    if not self._need_runtime_repack:
        register_step.add_depends_on(self.depends_on)

    self.steps.append(register_step)
def __init__(
    self,
    name: str,
    estimator: EstimatorBase,
    model_data,
    content_types,
    response_types,
    inference_instances,
    transform_instances,
    depends_on: List[str] = None,
    model_package_group_name=None,
    model_metrics=None,
    approval_status=None,
    image_uri=None,
    compile_model_family=None,
    description=None,
    **kwargs,
):
    """Assemble the repack (optional) and register steps for an estimator.

    A `_RepackModelStep` is created first only when `entry_point` is present in
    `kwargs`; the `_RegisterModelStep` then consumes the repacked artifacts.
    Otherwise only the `_RegisterModelStep` is created. The resulting list is
    stored on `self.steps`.

    Args:
        name (str): Name of the register step. The repack step, when created,
            is named `{name}RepackModel`.
        estimator: The estimator instance.
        model_data: The S3 uri to the model data from training.
        content_types (list): The supported MIME types for the input data.
        response_types (list): The supported MIME types for the output data.
        inference_instances (list): A list of the instance types that are used
            to generate inferences in real-time.
        transform_instances (list): A list of the instance types on which a
            transformation job can be run or on which an endpoint can be
            deployed.
        depends_on (List[str]): The list of step names the first step in the
            collection depends on (default: None).
        model_package_group_name (str): The Model Package Group name, exclusive
            to `model_package_name`, using `model_package_group_name` makes the
            Model Package versioned (default: None).
        model_metrics (ModelMetrics): ModelMetrics object (default: None).
        approval_status (str): Model Approval Status, values can be "Approved",
            "Rejected", or "PendingManualApproval" (default: None; the
            effective default is expected to be "PendingManualApproval" --
            confirm against `_RegisterModelStep`).
        image_uri (str): The container image uri for Model Package, if not
            specified, Estimator's training container image is used
            (default: None).
        compile_model_family (str): The instance family for the compiled model.
            If specified, a compiled model is used (default: None).
        description (str): Model Package description (default: None).
        **kwargs: additional arguments to `create_model`. The keys
            `entry_point`, `source_dir` and `dependencies` are consumed by the
            repack step and are not forwarded to the register step.
    """
    needs_repack = "entry_point" in kwargs
    collected: List[Step] = []

    if needs_repack:
        # These options belong to the repack step only, so take them out of
        # kwargs before the remainder is forwarded to the register step.
        script = kwargs.pop("entry_point")
        script_dir = kwargs.pop("source_dir", None)
        extra_deps = kwargs.pop("dependencies", None)

        repacker = _RepackModelStep(
            name=f"{name}RepackModel",
            depends_on=depends_on,
            estimator=estimator,
            model_data=model_data,
            entry_point=script,
            source_dir=script_dir,
            dependencies=extra_deps,
        )
        collected.append(repacker)
        # From here on, register the repacked artifacts rather than the raw ones.
        model_data = repacker.properties.ModelArtifacts.S3ModelArtifacts

    registrar = _RegisterModelStep(
        name=name,
        estimator=estimator,
        model_data=model_data,
        content_types=content_types,
        response_types=response_types,
        inference_instances=inference_instances,
        transform_instances=transform_instances,
        model_package_group_name=model_package_group_name,
        model_metrics=model_metrics,
        approval_status=approval_status,
        image_uri=image_uri,
        compile_model_family=compile_model_family,
        description=description,
        **kwargs,
    )
    # With a repack step present, `depends_on` was already given to it (the
    # first step); otherwise the register step is the first step and takes it.
    if not needs_repack:
        registrar.add_depends_on(depends_on)
    collected.append(registrar)

    self.steps = collected
def __init__(
    self,
    name: str,
    content_types,
    response_types,
    inference_instances,
    transform_instances,
    estimator: EstimatorBase = None,
    model_data=None,
    depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
    repack_model_step_retry_policies: List[RetryPolicy] = None,
    register_model_step_retry_policies: List[RetryPolicy] = None,
    model_package_group_name=None,
    model_metrics=None,
    approval_status=None,
    image_uri=None,
    compile_model_family=None,
    display_name=None,
    description=None,
    tags=None,
    model: Union[Model, PipelineModel] = None,
    drift_check_baselines=None,
    customer_metadata_properties=None,
    **kwargs,
):
    """Construct steps `_RepackModelStep` and `_RegisterModelStep`.

    Repack steps are created in one of two ways:

    * when `entry_point` is passed in `kwargs`, a single repack step named
      `{name}RepackModel` is built from `model_data` using the estimator's
      session and role;
    * otherwise, when `model` is given, one repack step is built for every
      contained model whose `entry_point` is set, and that model's
      `model_data` is replaced with the repack step's output.

    A `_RegisterModelStep` is always appended last. The step list is stored on
    `self.steps`, and a `DeprecationWarning` recommending `ModelStep` is issued.

    Args:
        name (str): Name of the register step; also the prefix of the repack
            step name when repacking is driven by `entry_point` in `kwargs`.
        content_types (list): The supported MIME types for the input data.
        response_types (list): The supported MIME types for the output data.
        inference_instances (list): A list of the instance types that are used
            to generate inferences in real-time. When `model` is given, the
            first entry is used to prepare the container definition(s).
        transform_instances (list): A list of the instance types on which a
            transformation job can be run or on which an endpoint can be
            deployed.
        estimator: The estimator instance (default: None). Required when
            `entry_point` is passed in `kwargs`.
        model_data: The S3 uri to the model data from training (default: None).
        depends_on (List[Union[str, Step, StepCollection]]): The list of
            `Step`/`StepCollection` names or `Step` instances or
            `StepCollection` instances that the first step in the collection
            depends on (default: None).
        repack_model_step_retry_policies (List[RetryPolicy]): The list of retry
            policies for the repack model step (default: None).
        register_model_step_retry_policies (List[RetryPolicy]): The list of
            retry policies for register model step (default: None).
        model_package_group_name (str): The Model Package Group name or Arn,
            exclusive to `model_package_name`, using
            `model_package_group_name` makes the Model Package versioned
            (default: None).
        model_metrics (ModelMetrics): ModelMetrics object (default: None).
        approval_status (str): Model Approval Status, values can be "Approved",
            "Rejected", or "PendingManualApproval"
            (default: "PendingManualApproval").
        image_uri (str): The container image uri for Model Package, if not
            specified, Estimator's training container image is used
            (default: None).
        compile_model_family (str): The instance family for the compiled model.
            If specified, a compiled model is used (default: None).
        display_name (str): Display name passed to the created steps
            (default: None).
        description (str): Model Package description (default: None).
        tags (List[dict[str, str]]): The list of tags to attach to the model
            package group. Note that tags will only be applied to newly created
            model package groups; if the name of an existing group is passed to
            "model_package_group_name", tags will not be applied.
        model (object or Model): A PipelineModel object that comprises a list
            of models which gets executed as a serial inference pipeline or a
            Model object.
        drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object
            (default: None).
        customer_metadata_properties (dict[str, str]): A dictionary of
            key-value paired metadata properties (default: None).
        **kwargs: additional arguments to `create_model`.

    Raises:
        ValueError: If `entry_point` is passed in `kwargs` without an
            `estimator`; the repack step needs the estimator's
            `sagemaker_session` and `role`.
    """
    self.name = name
    steps: List[Step] = []
    repack_model = False
    self.model_list = None
    self.container_def_list = None

    # Network settings for the repack job come from the estimator when there
    # is one, otherwise from the model's VPC config if it has one.
    subnets = None
    security_group_ids = None
    if estimator is not None:
        subnets = estimator.subnets
        security_group_ids = estimator.security_group_ids
    elif model is not None and model.vpc_config is not None:
        subnets = model.vpc_config["Subnets"]
        security_group_ids = model.vpc_config["SecurityGroupIds"]

    if "entry_point" in kwargs:
        # `estimator` is optional in the signature but this branch reads its
        # session and role; fail with a clear message instead of an
        # AttributeError on None.
        if estimator is None:
            raise ValueError(
                "`estimator` must be provided when `entry_point` is passed as a "
                "keyword argument, because the repack step uses the estimator's "
                "sagemaker_session and role."
            )

        repack_model = True
        entry_point = kwargs.pop("entry_point", None)
        source_dir = kwargs.pop("source_dir", None)
        dependencies = kwargs.pop("dependencies", None)
        # NOTE(review): the `**kwargs` unpacking and the `pop` share one call
        # expression, so whether "model_kms_key" survives in the new dict
        # depends on argument evaluation order -- confirm the intended result
        # before restructuring this line.
        kwargs = dict(**kwargs, output_kms_key=kwargs.pop("model_kms_key", None))

        repack_model_step = _RepackModelStep(
            name=f"{name}RepackModel",
            depends_on=depends_on,
            retry_policies=repack_model_step_retry_policies,
            sagemaker_session=estimator.sagemaker_session,
            role=estimator.role,
            model_data=model_data,
            entry_point=entry_point,
            source_dir=source_dir,
            dependencies=dependencies,
            tags=tags,
            subnets=subnets,
            security_group_ids=security_group_ids,
            description=description,
            display_name=display_name,
            **kwargs,
        )
        steps.append(repack_model_step)
        # Register the repacked artifacts rather than the raw training output.
        model_data = repack_model_step.properties.ModelArtifacts.S3ModelArtifacts

        # remove kwargs consumed by model repacking step
        kwargs.pop("output_kms_key", None)

    elif model is not None:
        if isinstance(model, PipelineModel):
            self.model_list = model.models
        elif isinstance(model, Model):
            self.model_list = [model]

        for model_entity in self.model_list:
            if estimator is not None:
                sagemaker_session = estimator.sagemaker_session
                role = estimator.role
            else:
                sagemaker_session = model_entity.sagemaker_session
                role = model_entity.role

            # Only models that ship an inference script need repacking.
            if getattr(model_entity, "entry_point", None) is not None:
                repack_model = True
                entry_point = model_entity.entry_point
                source_dir = model_entity.source_dir
                dependencies = model_entity.dependencies
                kwargs = dict(**kwargs, output_kms_key=model_entity.model_kms_key)
                model_name = model_entity.name or model_entity._framework_name

                repack_model_step = _RepackModelStep(
                    name=f"{model_name}RepackModel",
                    depends_on=depends_on,
                    retry_policies=repack_model_step_retry_policies,
                    sagemaker_session=sagemaker_session,
                    role=role,
                    model_data=model_entity.model_data,
                    entry_point=entry_point,
                    source_dir=source_dir,
                    dependencies=dependencies,
                    tags=tags,
                    subnets=subnets,
                    security_group_ids=security_group_ids,
                    description=description,
                    display_name=display_name,
                    **kwargs,
                )
                steps.append(repack_model_step)
                # Point the model at the repacked artifacts so the container
                # definitions prepared below reference them.
                model_entity.model_data = (
                    repack_model_step.properties.ModelArtifacts.S3ModelArtifacts
                )

                # remove kwargs consumed by model repacking step
                kwargs.pop("output_kms_key", None)

        if isinstance(model, PipelineModel):
            self.container_def_list = model.pipeline_container_def(inference_instances[0])
        elif isinstance(model, Model):
            self.container_def_list = [model.prepare_container_def(inference_instances[0])]

    register_model_step = _RegisterModelStep(
        name=name,
        estimator=estimator,
        model_data=model_data,
        content_types=content_types,
        response_types=response_types,
        inference_instances=inference_instances,
        transform_instances=transform_instances,
        model_package_group_name=model_package_group_name,
        model_metrics=model_metrics,
        drift_check_baselines=drift_check_baselines,
        approval_status=approval_status,
        image_uri=image_uri,
        compile_model_family=compile_model_family,
        description=description,
        display_name=display_name,
        tags=tags,
        container_def_list=self.container_def_list,
        retry_policies=register_model_step_retry_policies,
        customer_metadata_properties=customer_metadata_properties,
        **kwargs,
    )
    # When a repack step exists it already received `depends_on`; otherwise
    # the register step is the first step and takes the dependencies itself.
    if not repack_model:
        register_model_step.add_depends_on(depends_on)
    steps.append(register_model_step)

    self.steps = steps

    # TODO: add public document link here once ready
    # stacklevel=2 attributes the warning to the code constructing this
    # object, so it is reported at the caller's line rather than here.
    warnings.warn(
        (
            "We are deprecating the use of RegisterModel. "
            "Instead, please use the ModelStep, which simply takes in the step arguments "
            "generated by model.register()."
        ),
        DeprecationWarning,
        stacklevel=2,
    )