def transform(self, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None,
                  job_name=None):
        """Start a new transform job for the given S3 input.

        Args:
            data (str): Input data location in S3. Must start with ``s3://``.
            data_type (str): What the S3 location defines (default: 'S3Prefix'). Valid values:

                * 'S3Prefix' - the S3 URI defines a key name prefix. All objects with this prefix
                    will be used as inputs for the transform job.
                * 'ManifestFile' - the S3 URI points to a single manifest file listing each S3
                    object to use as an input for the transform job.

            content_type (str): MIME type of the input data (default: None).
            compression_type (str): Compression type of the input data, if compressed
                (default: None). Valid values: 'Gzip', None.
            split_type (str): The record delimiter for the input object (default: 'None').
                Valid values: 'None', 'Line', and 'RecordIO'.
            job_name (str): Job name (default: None). If not specified, one is generated from
                ``base_transform_job_name`` or, failing that, from the image name.

        Raises:
            ValueError: If ``data`` does not start with ``s3://``.
        """
        is_s3_uri = data.startswith('s3://')
        if not is_s3_uri:
            raise ValueError('Invalid S3 URI: {}'.format(data))

        if job_name is None:
            prefix = self.base_transform_job_name or base_name_from_image(self._retrieve_image_name())
            job_name = name_from_base(prefix)
        self._current_job_name = job_name

        # Only pick a default output location when the caller has not configured one.
        if self.output_path is None:
            bucket = self.sagemaker_session.default_bucket()
            self.output_path = 's3://{}/{}'.format(bucket, self._current_job_name)

        self.latest_transform_job = _TransformJob.start_new(
            self, data, data_type, content_type, compression_type, split_type)
Esempio n. 2
0
    def _prepare_for_training(self, job_name=None):
        """Set any values in the estimator that need to be set before training.

        Args:
            * job_name (str): Name of the training job to be created. If not specified, one is generated,
                using the base name given to the constructor if applicable.
        """
        if job_name is not None:
            self._current_job_name = job_name
        else:
            # honor supplied base_job_name or generate it
            base_name = self.base_job_name or base_name_from_image(
                self.train_image())
            self._current_job_name = name_from_base(base_name)

        # if output_path was specified we use it otherwise initialize here.
        # For Local Mode with local_code=True we don't need an explicit output_path
        if self.output_path is None:
            local_code = get_config_value('local.local_code',
                                          self.sagemaker_session.config)
            if self.sagemaker_session.local_mode and local_code:
                self.output_path = ''
            else:
                self.output_path = 's3://{}/'.format(
                    self.sagemaker_session.default_bucket())
Esempio n. 3
0
    def _prepare_for_training(self, job_name=None):
        """Set the tuning job name and compute the static hyperparameters.

        Args:
            job_name (str): Name of the tuning job. If None, a name is generated from
                ``base_tuning_job_name`` or, failing that, from the estimator's training image.
        """
        if job_name is None:
            prefix = self.base_tuning_job_name or base_name_from_image(self.estimator.train_image())
            job_name = name_from_base(prefix, max_length=self.TUNING_JOB_NAME_MAX_LENGTH, short=True)
        self._current_job_name = job_name

        static = self.static_hyperparameters = {
            to_str(name): to_str(value)
            for name, value in self.estimator.hyperparameters().items()
        }
        # Drop every hyperparameter that has a tuning range from the static set.
        for tuned_name in self._hyperparameter_ranges:
            static.pop(tuned_name, None)

        # For attach() to know what estimator to use for non-1P algorithms
        # (1P algorithms don't accept extra hyperparameters)
        if isinstance(self.estimator, AmazonAlgorithmEstimatorBase):
            return
        static[self.SAGEMAKER_ESTIMATOR_CLASS_NAME] = json.dumps(self.estimator.__class__.__name__)
        static[self.SAGEMAKER_ESTIMATOR_MODULE] = json.dumps(self.estimator.__module__)
Esempio n. 4
0
    def transform(self, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None,
                  job_name=None):
        """Start a new transform job for the given S3 input.

        Args:
            data (str): Input data location in S3. Must start with ``s3://``.
            data_type (str): What the S3 location defines (default: 'S3Prefix'). Valid values:

                * 'S3Prefix' - the S3 URI defines a key name prefix. All objects with this prefix
                    will be used as inputs for the transform job.
                * 'ManifestFile' - the S3 URI points to a single manifest file listing each S3
                    object to use as an input for the transform job.

            content_type (str): MIME type of the input data (default: None).
            compression_type (str): Compression type of the input data, if compressed
                (default: None). Valid values: 'Gzip', None.
            split_type (str): The record delimiter for the input object (default: 'None').
                Valid values: 'None', 'Line', and 'RecordIO'.
            job_name (str): Job name (default: None). If not specified, one is generated from
                ``base_transform_job_name`` or, failing that, from the image name.

        Raises:
            ValueError: If ``data`` does not start with ``s3://``.
        """
        if data.startswith('s3://') is False:
            raise ValueError('Invalid S3 URI: {}'.format(data))

        if job_name is None:
            name_prefix = self.base_transform_job_name or base_name_from_image(self._retrieve_image_name())
            job_name = name_from_base(name_prefix)
        self._current_job_name = job_name

        # Only pick a default output location when the caller has not configured one.
        if self.output_path is None:
            default_bucket = self.sagemaker_session.default_bucket()
            self.output_path = 's3://{}/{}'.format(default_bucket, self._current_job_name)

        transform_job = _TransformJob.start_new(
            self, data, data_type, content_type, compression_type, split_type)
        self.latest_transform_job = transform_job
Esempio n. 5
0
    def _retrieve_base_name(self):
        image_name = self._retrieve_image_name()

        if image_name:
            return base_name_from_image(image_name)

        return self.model_name
    def _prepare_for_training(self, job_name=None, include_cls_metadata=False):
        """
        Args:
            job_name:
            include_cls_metadata:
        """
        if job_name is not None:
            self._current_job_name = job_name
        else:
            base_name = self.base_tuning_job_name or base_name_from_image(
                self.estimator.train_image()
            )
            self._current_job_name = name_from_base(
                base_name, max_length=self.TUNING_JOB_NAME_MAX_LENGTH, short=True
            )

        self.static_hyperparameters = {
            to_str(k): to_str(v) for (k, v) in self.estimator.hyperparameters().items()
        }
        for hyperparameter_name in self._hyperparameter_ranges.keys():
            self.static_hyperparameters.pop(hyperparameter_name, None)

        # For attach() to know what estimator to use for frameworks
        # (other algorithms may not accept extra hyperparameters)
        if include_cls_metadata or isinstance(self.estimator, Framework):
            self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_CLASS_NAME] = json.dumps(
                self.estimator.__class__.__name__
            )
            self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_MODULE] = json.dumps(
                self.estimator.__module__
            )
Esempio n. 7
0
def prepare_framework_container_def(model, instance_type, s3_operations):
    """Build the container definition for a framework model.

    Also records in ``s3_operations`` the upload Airflow has to perform when the
    model's source code is not already in S3.

    Args:
        model (sagemaker.model.FrameworkModel): The framework model. Its ``name`` and
            ``uploaded_code`` attributes are set by this function.
        instance_type (str): The EC2 instance type to deploy this Model to. For
            example, 'ml.p2.xlarge'.
        s3_operations (dict): The dict to specify S3 operations (upload
            `source_dir` ). An "S3Upload" entry is added when an upload is needed.

    Returns:
        dict: The container information of this framework model.
    """
    image = model.image
    if not image:
        # No explicit image: derive the framework image URI for the session's region.
        image = fw_utils.create_image_uri(
            model.sagemaker_session.boto_session.region_name,
            model.__framework_name__,
            instance_type,
            model.framework_version,
            model.py_version,
        )

    name_prefix = utils.base_name_from_image(image)
    model.name = model.name or utils.name_from_base(name_prefix)

    bucket = model.bucket or model.sagemaker_session._default_bucket
    script = os.path.basename(model.entry_point)
    key = "{}/source/sourcedir.tar.gz".format(model.name)

    already_in_s3 = bool(model.source_dir) and model.source_dir.lower().startswith("s3://")
    code_dir = model.source_dir if already_in_s3 else "s3://{}/{}".format(bucket, key)
    model.uploaded_code = fw_utils.UploadedCode(s3_prefix=code_dir, script_name=script)

    if not already_in_s3:
        # Local source: ask for it to be tarred and uploaded to the computed key.
        upload = {
            "Path": model.source_dir or script,
            "Bucket": bucket,
            "Key": key,
            "Tar": True
        }
        s3_operations["S3Upload"] = [upload]

    env = dict(model.env)
    env.update(model._framework_env_vars())

    try:
        workers = model.model_server_workers
        if workers:
            env[sagemaker.model.MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(workers)
    except AttributeError:
        # This applies to a FrameworkModel which is not SageMaker Deep Learning Framework Model
        pass

    return sagemaker.container_def(image, model.model_data, env)
    def _retrieve_base_name(self):
        """Placeholder docstring"""
        image_uri = self._retrieve_image_uri()

        if image_uri:
            return base_name_from_image(image_uri)

        return self.model_name
    def fit(self, inputs, wait=True, logs=True, job_name=None):
        """Train a model using the input training dataset.

        The user code (``entry_point`` / ``source_dir``) is uploaded to S3 first, the
        locations of the uploaded code are written into the hyperparameters, and the
        call is then delegated to the parent class' ``fit()``.

        Args:
            inputs (str or dict or sagemaker.session.s3_input): Information about the training data.
                This can be one of three types:

                * (str) - the S3 location where training data is saved.
                * (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple
                    channels for training data, you can specify a dict mapping channel names
                    to strings or :func:`~sagemaker.session.s3_input` objects.
                * (sagemaker.session.s3_input) - channel configuration for S3 data sources that can
                    provide additional information about the training dataset.
                    See :func:`sagemaker.session.s3_input` for full details.
            wait (bool): Whether the call should wait until the job completes (default: True).
            logs (bool): Whether to show the logs produced by the job.
                Only meaningful when wait is True (default: True).
            job_name (str): Training job name. If not specified, the estimator generates a
                default job name, based on the training image name and current timestamp.
        """
        # always determine new job name _here_ because it is used before base is called
        if job_name is not None:
            self._current_job_name = job_name
        else:
            # honor supplied base_job_name or generate it
            base_name = self.base_job_name or base_name_from_image(self.train_image())
            self._current_job_name = name_from_base(base_name)

        if self.code_location is None:
            code_bucket = self.sagemaker_session.default_bucket()
            key_prefix = ''
        else:
            code_bucket, key_prefix = parse_s3_url(self.code_location)

        # Join only the non-empty components: a code_location without a key prefix
        # (presumably parsed as an empty key - confirm against parse_s3_url) must not
        # produce an S3 key prefix that starts with '/'.
        code_s3_prefix = '/'.join(filter(None, [key_prefix, self._current_job_name, 'source']))

        self.uploaded_code = tar_and_upload_dir(session=self.sagemaker_session.boto_session,
                                                bucket=code_bucket,
                                                s3_key_prefix=code_s3_prefix,
                                                script=self.entry_point,
                                                directory=self.source_dir)

        # Modify hyperparameters in-place to add the URLs to the uploaded code.
        self._hyperparameters[DIR_PARAM_NAME] = self.uploaded_code.s3_prefix
        self._hyperparameters[SCRIPT_PARAM_NAME] = self.uploaded_code.script_name
        self._hyperparameters[CLOUDWATCH_METRICS_PARAM_NAME] = self.enable_cloudwatch_metrics
        self._hyperparameters[CONTAINER_LOG_LEVEL_PARAM_NAME] = self.container_log_level
        self._hyperparameters[JOB_NAME_PARAM_NAME] = self._current_job_name
        self._hyperparameters[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_session.region_name
        super(Framework, self).fit(inputs, wait, logs, self._current_job_name)
Esempio n. 10
0
    def fit(self, inputs, wait=True, logs=True, job_name=None):
        """Train a model using the input training dataset.

        Picks the job name and output path, starts a new training job and, when ``wait``
        is True, blocks until the job has finished.

        Args:
            inputs (str or dict or sagemaker.session.s3_input): Information about the training data.
                This can be one of three types:

                * (str) the S3 location where training data is saved.

                * (dict[str, str] or dict[str, sagemaker.session.s3_input]) If using multiple
                    channels for training data, you can specify a dict mapping channel names
                    to strings or :func:`~sagemaker.session.s3_input` objects.
                * (sagemaker.session.s3_input) - channel configuration for S3 data sources that can
                    provide additional information about the training dataset.
                    See :func:`sagemaker.session.s3_input` for full details.
            wait (bool): Whether the call should wait until the job completes (default: True).
            logs (bool): Whether to show the logs produced by the job.
                Only meaningful when wait is True (default: True).
            job_name (str): Training job name. If not specified, the estimator generates a
                default job name, based on the training image name and current timestamp.
        """
        if job_name is None:
            # make sure the job name is unique for each invocation,
            # honor supplied base_job_name or generate it
            prefix = self.base_job_name or base_name_from_image(self.train_image())
            job_name = name_from_base(prefix)
        self._current_job_name = job_name

        # if output_path was specified we use it otherwise initialize here.
        if self.output_path is None:
            session = self.sagemaker_session
            local_code = get_config_value('local.local_code', session.config)
            # For Local Mode with local_code=True we don't need an explicit output_path
            if session.local_mode and local_code:
                self.output_path = ''
            else:
                self.output_path = 's3://{}/'.format(session.default_bucket())

        self.latest_training_job = _TrainingJob.start_new(self, inputs)
        if not wait:
            return
        self.latest_training_job.wait(logs=logs)
Esempio n. 11
0
    def fit(self, inputs):
        """Start a training job for the wrapped estimator without waiting for it.

        Args:
            inputs: Training data description, passed through unchanged to
                ``_TrainingJob.start_new``.
        """
        estimator = self._estimator

        # The job name has to be unique.
        from sagemaker.utils import base_name_from_image, name_from_base
        prefix = estimator.base_job_name or base_name_from_image(estimator.train_image())
        estimator._current_job_name = name_from_base(prefix)

        # If no output location has been specified, set it here.
        if estimator.output_path is None:
            bucket = estimator.sagemaker_session.default_bucket()
            estimator.output_path = 's3://{}/'.format(bucket)

        from sagemaker.estimator import _TrainingJob
        estimator.latest_training_job = _TrainingJob.start_new(estimator, inputs)
Esempio n. 12
0
    def _prepare_for_training(self, job_name=None, include_cls_metadata=True):
        if job_name is not None:
            self._current_job_name = job_name
        else:
            base_name = self.base_tuning_job_name or base_name_from_image(self.estimator.train_image())
            self._current_job_name = name_from_base(base_name, max_length=self.TUNING_JOB_NAME_MAX_LENGTH, short=True)

        self.static_hyperparameters = {to_str(k): to_str(v) for (k, v) in self.estimator.hyperparameters().items()}
        for hyperparameter_name in self._hyperparameter_ranges.keys():
            self.static_hyperparameters.pop(hyperparameter_name, None)

        # For attach() to know what estimator to use for non-1P algorithms
        # (1P algorithms don't accept extra hyperparameters)
        if include_cls_metadata and not isinstance(self.estimator, AmazonAlgorithmEstimatorBase):
            self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_CLASS_NAME] = json.dumps(
                self.estimator.__class__.__name__)
            self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_MODULE] = json.dumps(self.estimator.__module__)
Esempio n. 13
0
def model_config(instance_type, model, role=None, image=None):
    """Export Airflow model config from a SageMaker model.

    Note that ``model.image`` and, if unset, ``model.name`` are updated on the
    passed-in model.

    Args:
        instance_type (str): The EC2 instance type to deploy this Model to. For
            example, 'ml.p2.xlarge'
        model (sagemaker.model.FrameworkModel): The SageMaker model to export
            Airflow config from
        role (str): The ``ExecutionRoleArn`` IAM Role ARN for the model. Falls back
            to ``model.role`` when not given.
        image (str): A container image to use for deploying the model. Falls back
            to ``model.image`` when not given.

    Returns:
        dict: Model config that can be directly used by SageMakerModelOperator
        in Airflow. It can also be part of the config used by
        SageMakerEndpointOperator and SageMakerTransformOperator in Airflow.
    """
    pending_uploads = {}
    model.image = image or model.image

    if not isinstance(model, sagemaker.model.FrameworkModel):
        container_def = model.prepare_container_def(instance_type)
        name_prefix = utils.base_name_from_image(container_def["Image"])
        model.name = model.name or utils.name_from_base(name_prefix)
    else:
        container_def = prepare_framework_container_def(model, instance_type, pending_uploads)

    expanded_container = session._expand_container_def(container_def)

    config = dict(
        ModelName=model.name,
        PrimaryContainer=expanded_container,
        ExecutionRoleArn=role or model.role,
    )

    # Optional sections are only emitted when they carry something.
    vpc_config = model.vpc_config
    if vpc_config:
        config["VpcConfig"] = vpc_config

    if pending_uploads:
        config["S3Operations"] = pending_uploads

    return config
    def _generate_current_job_name(self, job_name=None):
        """Generates the job name before running a processing job.

        Args:
            job_name (str): Name of the processing job to be created. If not
                specified, one is generated, using the base name given to the
                constructor if applicable.

        Returns:
            str: The supplied or generated job name.
        """
        if job_name is not None:
            return job_name
        # Honor supplied base_job_name or generate it.
        if self.base_job_name:
            base_name = self.base_job_name
        else:
            base_name = base_name_from_image(self.image_uri)

        return name_from_base(base_name)
Esempio n. 15
0
    def _prepare_for_training(self, job_name=None):
        """Set any values in the estimator that need to be set before training.

        Args:
            * job_name (str): Name of the training job to be created. If not specified, one is generated,
                using the base name given to the constructor if applicable.
        """
        if job_name is not None:
            self._current_job_name = job_name
        else:
            # honor supplied base_job_name or generate it
            base_name = self.base_job_name or base_name_from_image(self.train_image())
            self._current_job_name = name_from_base(base_name)

        # if output_path was specified we use it otherwise initialize here.
        # For Local Mode with local_code=True we don't need an explicit output_path
        if self.output_path is None:
            local_code = get_config_value('local.local_code', self.sagemaker_session.config)
            if self.sagemaker_session.local_mode and local_code:
                self.output_path = ''
            else:
                self.output_path = 's3://{}/'.format(self.sagemaker_session.default_bucket())
def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=None):
    """Export Airflow base training config from an estimator.

    Sets ``_current_job_name`` and, if unset, ``output_path`` on the estimator as a
    side effect.

    Args:
        estimator (sagemaker.estimator.EstimatorBase): The estimator to export
            training config from. Can be a BYO estimator, Framework estimator or
            Amazon algorithm estimator.
        inputs: Information about the training data. Please refer to the ``fit()``
            method of the associated estimator, as this can take any of the
            following forms:

            * (str) - The S3 location where training data is saved.

            * (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple
                  channels for training data, you can specify a dict mapping channel names to
                  strings or :func:`~sagemaker.session.s3_input` objects.

            * (sagemaker.session.s3_input) - Channel configuration for S3 data sources that can
                  provide additional information about the training dataset. See
                  :func:`sagemaker.session.s3_input` for full details.

            * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
                  Amazon :class:`~Record` objects serialized and stored in S3.
                  For use with an estimator for an Amazon algorithm.

            * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
                  :class:`~sagemaker.amazon.amazon_estimator.RecordSet` objects,
                  where each instance is a different channel of training data.
        job_name (str): Specify a training job name if needed.
        mini_batch_size (int): Specify this argument only when estimator is a
            built-in estimator of an Amazon algorithm. For other estimators,
            batch size should be specified in the estimator.

    Returns:
        dict: Training config that can be directly used by
        SageMakerTrainingOperator in Airflow.
    """
    default_bucket = estimator.sagemaker_session.default_bucket()
    s3_operations = {}

    if job_name is not None:
        estimator._current_job_name = job_name
    else:
        base_name = estimator.base_job_name or utils.base_name_from_image(estimator.train_image())
        estimator._current_job_name = utils.name_from_base(base_name)

    if estimator.output_path is None:
        estimator.output_path = "s3://{}/".format(default_bucket)

    if isinstance(estimator, sagemaker.estimator.Framework):
        prepare_framework(estimator, s3_operations)

    elif isinstance(estimator, amazon_estimator.AmazonAlgorithmEstimatorBase):
        prepare_amazon_algorithm_estimator(estimator, inputs, mini_batch_size)
    job_config = job._Job._load_config(inputs, estimator, expand_role=False, validate_uri=False)

    train_config = {
        "AlgorithmSpecification": {
            "TrainingImage": estimator.train_image(),
            "TrainingInputMode": estimator.input_mode,
        },
        "OutputDataConfig": job_config["output_config"],
        "StoppingCondition": job_config["stop_condition"],
        "ResourceConfig": job_config["resource_config"],
        "RoleArn": job_config["role"],
    }

    if job_config["input_config"] is not None:
        train_config["InputDataConfig"] = job_config["input_config"]

    if job_config["vpc_config"] is not None:
        train_config["VpcConfig"] = job_config["vpc_config"]

    # Bind the name up front: an estimator whose hyperparameters() returns None must
    # simply produce no "HyperParameters" entry instead of raising UnboundLocalError.
    hyperparameters = None
    raw_hyperparameters = estimator.hyperparameters()
    if raw_hyperparameters is not None:
        hyperparameters = {str(k): str(v) for (k, v) in raw_hyperparameters.items()}

    if hyperparameters:
        train_config["HyperParameters"] = hyperparameters

    if s3_operations:
        train_config["S3Operations"] = s3_operations

    return train_config
def tuning_config(tuner, inputs, job_name=None):
    """Export Airflow tuning config from a hyperparameter tuner.

    Sets ``_current_job_name`` and ``static_hyperparameters`` on the tuner as a
    side effect.

    Args:
        tuner (sagemaker.tuner.HyperparameterTuner): The tuner to export tuning
            config from.
        inputs: Information about the training data. Please refer to the ``fit()``
            method of the associated estimator in the tuner, as this can take any of the
            following forms:

            * (str) - The S3 location where training data is saved.

            * (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple
                  channels for training data, you can specify a dict mapping channel names to
                  strings or :func:`~sagemaker.session.s3_input` objects.

            * (sagemaker.session.s3_input) - Channel configuration for S3 data sources that can
                  provide additional information about the training dataset. See
                  :func:`sagemaker.session.s3_input` for full details.

            * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
                  Amazon :class:`~Record` objects serialized and stored in S3.
                  For use with an estimator for an Amazon algorithm.

            * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
                  :class:`~sagemaker.amazon.amazon_estimator.RecordSet` objects,
                  where each instance is a different channel of training data.
        job_name (str): Specify a tuning job name if needed.

    Returns:
        dict: Tuning config that can be directly used by SageMakerTuningOperator in Airflow.
    """
    train_config = training_base_config(tuner.estimator, inputs)
    hyperparameters = train_config.pop("HyperParameters", None)
    s3_operations = train_config.pop("S3Operations", None)

    if hyperparameters:
        tuner.static_hyperparameters = {
            utils.to_str(k): utils.to_str(v) for (k, v) in hyperparameters.items()
        }
    elif getattr(tuner, "static_hyperparameters", None) is None:
        # The estimator exposes no hyperparameters and the tuner has none recorded yet:
        # use an empty dict so the pop() loop and the config below have a mapping.
        tuner.static_hyperparameters = {}

    if job_name is not None:
        tuner._current_job_name = job_name
    else:
        base_name = tuner.base_tuning_job_name or utils.base_name_from_image(
            tuner.estimator.train_image()
        )
        tuner._current_job_name = utils.name_from_base(
            base_name, tuner.TUNING_JOB_NAME_MAX_LENGTH, True
        )

    # Hyperparameters that are being tuned must not be listed as static ones too.
    for hyperparameter_name in tuner._hyperparameter_ranges.keys():
        tuner.static_hyperparameters.pop(hyperparameter_name, None)

    train_config["StaticHyperParameters"] = tuner.static_hyperparameters

    tune_config = {
        "HyperParameterTuningJobName": tuner._current_job_name,
        "HyperParameterTuningJobConfig": {
            "Strategy": tuner.strategy,
            "HyperParameterTuningJobObjective": {
                "Type": tuner.objective_type,
                "MetricName": tuner.objective_metric_name,
            },
            "ResourceLimits": {
                "MaxNumberOfTrainingJobs": tuner.max_jobs,
                "MaxParallelTrainingJobs": tuner.max_parallel_jobs,
            },
            "ParameterRanges": tuner.hyperparameter_ranges(),
        },
        "TrainingJobDefinition": train_config,
    }

    if tuner.metric_definitions is not None:
        tune_config["TrainingJobDefinition"]["AlgorithmSpecification"][
            "MetricDefinitions"
        ] = tuner.metric_definitions

    if tuner.tags is not None:
        tune_config["Tags"] = tuner.tags

    if s3_operations is not None:
        tune_config["S3Operations"] = s3_operations

    return tune_config
 def _ensure_base_name_if_needed(self, image_uri):
     """Create a base name from the image URI if there is no model name provided."""
     if self.name is None:
         self._base_name = self._base_name or utils.base_name_from_image(
             image_uri)
Esempio n. 19
0
def training_config(
        estimator,
        inputs=None,
        job_name=None
):  # noqa: C901 - suppress complexity warning for this method
    """Export Airflow training config from an estimator.

    Note that this function mutates ``estimator``: it sets
    ``estimator._current_job_name`` and, when ``estimator.output_path`` is
    None, points it at the session's default bucket.

    Args:
        estimator (sagemaker.estimator.EstimatorBase):
            The estimator to export training config from. Can be a BYO estimator,
            Framework estimator or Amazon algorithm estimator.
        inputs (str, dict, single or list of sagemaker.amazon.amazon_estimator.RecordSet):
            The training data.
        job_name (str): Specify a training job name if needed. When omitted, a
            name is generated from ``estimator.base_job_name`` or, failing that,
            from the training image name.

    Returns:
        A dict of training config that can be directly used by SageMakerTrainingOperator
            in Airflow. The optional keys 'InputDataConfig', 'VpcConfig',
            'HyperParameters', 'Tags' and 'S3Operations' are only present when
            the corresponding value is available.
    """
    default_bucket = estimator.sagemaker_session.default_bucket()
    s3_operations = {}

    if job_name is not None:
        estimator._current_job_name = job_name
    else:
        base_name = estimator.base_job_name or utils.base_name_from_image(
            estimator.train_image())
        estimator._current_job_name = utils.airflow_name_from_base(base_name)

    if estimator.output_path is None:
        estimator.output_path = 's3://{}/'.format(default_bucket)

    # Estimator-specific preparation; presumably prepare_framework records the
    # code upload steps in s3_operations -- confirm against its definition.
    if isinstance(estimator, sagemaker.estimator.Framework):
        prepare_framework(estimator, s3_operations)

    elif isinstance(estimator, amazon_estimator.AmazonAlgorithmEstimatorBase):
        prepare_amazon_algorithm_estimator(estimator, inputs)

    job_config = job._Job._load_config(inputs,
                                       estimator,
                                       expand_role=False,
                                       validate_uri=False)

    train_config = {
        'AlgorithmSpecification': {
            'TrainingImage': estimator.train_image(),
            'TrainingInputMode': estimator.input_mode
        },
        'OutputDataConfig': job_config['output_config'],
        'TrainingJobName': estimator._current_job_name,
        'StoppingCondition': job_config['stop_condition'],
        'ResourceConfig': job_config['resource_config'],
        'RoleArn': job_config['role'],
    }

    if job_config['input_config'] is not None:
        train_config['InputDataConfig'] = job_config['input_config']

    if job_config['vpc_config'] is not None:
        train_config['VpcConfig'] = job_config['vpc_config']

    # Read the hyperparameters once and always bind the local name: previously
    # it was left unbound (UnboundLocalError) when hyperparameters() was None.
    raw_hyperparameters = estimator.hyperparameters()
    if raw_hyperparameters is None:
        hyperparameters = {}
    else:
        hyperparameters = {
            str(k): str(v)
            for (k, v) in raw_hyperparameters.items()
        }

    if hyperparameters:
        train_config['HyperParameters'] = hyperparameters

    if estimator.tags is not None:
        train_config['Tags'] = estimator.tags

    if s3_operations:
        train_config['S3Operations'] = s3_operations

    return train_config
Esempio n. 20
0
    def fit(self, inputs, wait=True, logs=True, job_name=None):
        """Prepare the user code and train a model on the given input dataset.

        The job name is resolved first, the source directory is validated, the
        user code is either referenced locally (local mode with
        ``local.local_code`` enabled) or staged in S3, the framework
        hyperparameters are filled in, and finally the parent class ``fit`` is
        invoked, which starts the training job through the Amazon SageMaker
        CreateTrainingJob API.

        After the training job completes successfully, ``deploy()`` can be
        called to host the model using the Amazon SageMaker hosting services.

        Args:
            inputs (str or dict or sagemaker.session.s3_input): Information about the training data.
                This can be one of three types:
                (str) - the S3 location where training data is saved.
                (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple channels for
                    training data, you can specify a dict mapping channel names
                    to strings or :func:`~sagemaker.session.s3_input` objects.
                (sagemaker.session.s3_input) - channel configuration for S3 data sources that can provide
                    additional information about the training dataset. See :func:`sagemaker.session.s3_input`
                    for full details.
            wait (bool): Whether the call should wait until the job completes (default: True).
            logs (bool): Whether to show the logs produced by the job.
                Only meaningful when wait is True (default: True).
            job_name (str): Training job name. If not specified, the estimator generates a default job name,
                based on the training image name and current timestamp.
        """
        # The job name must be settled here rather than in the parent class,
        # because it is consumed below before the parent fit() is reached.
        if job_name is None:
            # Prefer the user-supplied base_job_name; otherwise derive one from the image.
            name_seed = self.base_job_name or base_name_from_image(self.train_image())
            resolved_name = name_from_base(name_seed)
        else:
            resolved_name = job_name
        self._current_job_name = resolved_name

        # A bad local source directory is a critical error: the ValueError from
        # validate_source_dir is deliberately allowed to propagate. S3 locations are not checked.
        is_local_source = self.source_dir and not self.source_dir.lower().startswith('s3://')
        if is_local_source:
            validate_source_dir(self.entry_point, self.source_dir)

        # In local mode with local_code=True the container simply mounts the
        # source directory, so nothing needs to be uploaded to S3.
        use_local_code = get_config_value('local.local_code', self.sagemaker_session.config)
        stage_in_s3 = not (self.sagemaker_session.local_mode and use_local_code)

        if stage_in_s3:
            self.uploaded_code = self._stage_user_code_in_s3()
            code_dir = self.uploaded_code.s3_prefix
            script = self.uploaded_code.script_name
        else:
            # Without an explicit source dir, fall back to the entry point's directory.
            if self.source_dir is None:
                self.source_dir = os.path.dirname(self.entry_point)
            self.entry_point = os.path.basename(self.entry_point)

            code_dir = 'file://' + self.source_dir
            script = self.entry_point

        # Update the hyperparameters in place so the container can find the
        # code directory, the script and the job-level settings.
        framework_settings = (
            (DIR_PARAM_NAME, code_dir),
            (SCRIPT_PARAM_NAME, script),
            (CLOUDWATCH_METRICS_PARAM_NAME, self.enable_cloudwatch_metrics),
            (CONTAINER_LOG_LEVEL_PARAM_NAME, self.container_log_level),
            (JOB_NAME_PARAM_NAME, self._current_job_name),
            (SAGEMAKER_REGION_PARAM_NAME, self.sagemaker_session.boto_region_name),
        )
        for param_name, param_value in framework_settings:
            self._hyperparameters[param_name] = param_value

        super(Framework, self).fit(inputs, wait, logs, self._current_job_name)