Code example #1
from sagemaker.model import FrameworkModel


def model(sagemaker_session):
    # IMAGE_URI, BUCKET, ROLE, and DATA_DIR are module-level test constants;
    # sagemaker_session is supplied by the caller (e.g. a test fixture).
    return FrameworkModel(
        image_uri=IMAGE_URI,
        model_data=f"s3://{BUCKET}/model.tar.gz",
        role=ROLE,
        sagemaker_session=sagemaker_session,
        entry_point=f"{DATA_DIR}/dummy_script.py",
        name="modelName",
        vpc_config={"Subnets": ["abc", "def"], "SecurityGroupIds": ["123", "456"]},
    )
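
For context, a minimal sketch of how the model returned above could be deployed to an endpoint; the instance count, instance type, and endpoint name are illustrative assumptions, not part of the original snippet.

# Hypothetical deployment of the FrameworkModel built by model();
# the instance settings and endpoint name are assumptions for illustration.
framework_model = model(sagemaker_session)
predictor = framework_model.deploy(
    initial_instance_count=1,            # assumed
    instance_type="ml.m5.xlarge",        # assumed
    endpoint_name="modelName-endpoint",  # assumed
)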
Code example #2
    def create_model(
        self,
        role=None,
        vpc_config_override=VPC_CONFIG_DEFAULT,
        entry_point=None,
        source_dir=None,
        dependencies=None,
    ):
        """Create a SageMaker ``RLEstimatorModel`` object that can be deployed
        to an Endpoint.

        Args:
            role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
                which is also used during transform jobs. If not specified, the
                role from the Estimator will be used.
            vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on
                the model. Default: use subnets and security groups from this Estimator.
                * 'Subnets' (list[str]): List of subnet ids.
                * 'SecurityGroupIds' (list[str]): List of security group ids.
            entry_point (str): Path (absolute or relative) to the Python source
                file which should be executed as the entry point for MXNet
                hosting. This should be compatible with Python 3.5 (default:
                self.entry_point).
            source_dir (str): Path (absolute or relative) to a directory with
                any other training source code dependencies aside from the entry
                point file (default: self.source_dir). Structure within this
                directory is preserved when hosting on Amazon SageMaker.
            dependencies (list[str]): A list of paths to directories (absolute
                or relative) with any additional libraries that will be exported
                to the container (default: self.dependencies). The library
                folders will be copied to SageMaker in the same folder where the
                entry_point is copied. If the ``source_dir`` points to S3,
                code will be uploaded and the S3 location will be used instead.

        Returns:
            sagemaker.model.FrameworkModel: Depending on input parameters returns
                one of the following:

                * sagemaker.model.FrameworkModel - in case image_name was specified
                    on the estimator;
                * sagemaker.mxnet.MXNetModel - if image_name wasn't specified and
                    MXNet was used as RL backend;
                * sagemaker.tensorflow.serving.Model - if image_name wasn't specified and
                    TensorFlow was used as RL backend.
        Raises:
            ValueError: If image_name was not specified and framework enum is not valid.
        """
        base_args = dict(
            model_data=self.model_data,
            role=role or self.role,
            image=self.image_name,
            name=self._current_job_name,
            container_log_level=self.container_log_level,
            sagemaker_session=self.sagemaker_session,
            vpc_config=self.get_vpc_config(vpc_config_override),
        )

        if not entry_point and (source_dir or dependencies):
            raise AttributeError("Please provide an `entry_point`.")

        entry_point = entry_point or self.entry_point
        source_dir = source_dir or self._model_source_dir()
        dependencies = dependencies or self.dependencies

        extended_args = dict(
            entry_point=entry_point,
            source_dir=source_dir,
            code_location=self.code_location,
            dependencies=dependencies,
            enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
        )
        extended_args.update(base_args)

        if self.image_name:
            return FrameworkModel(**extended_args)

        if self.toolkit == RLToolkit.RAY.value:
            raise NotImplementedError(
                "Automatic deployment of Ray models is not currently available."
                " Train policy parameters are available in model checkpoints"
                " in the TrainingJob output.")

        if self.framework == RLFramework.TENSORFLOW.value:
            from sagemaker.tensorflow.serving import Model as tfsModel

            return tfsModel(framework_version=self.framework_version,
                            **base_args)
        if self.framework == RLFramework.MXNET.value:
            return MXNetModel(framework_version=self.framework_version,
                              py_version=PYTHON_VERSION,
                              **extended_args)
        raise ValueError(
            "An unknown RLFramework enum was passed in. framework: {}".format(
                self.framework))
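
A hedged usage sketch of create_model, assuming an already trained RLEstimator bound to the name estimator; the entry point, source directory, and instance settings are illustrative assumptions rather than values from the original code.

# Hypothetical usage, assuming `estimator` is a trained RLEstimator.
# File names and instance settings below are illustrative assumptions.
model = estimator.create_model(
    entry_point="deploy.py",  # assumed serving entry point
    source_dir="src",         # assumed directory with serving code
)
predictor = model.deploy(
    initial_instance_count=1,      # assumed
    instance_type="ml.c5.xlarge",  # assumed
)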