コード例 #1
0
    def _create_docker_host(self, host, environment, optml_subdirs, command, volumes):
        """Assemble the configuration dictionary for one container host.

        Args:
            host (str): Host name; also registered as the alias on the
                ``sagemaker-local`` network.
            environment: Value stored unchanged under the ``environment`` key.
            optml_subdirs: Forwarded to ``_build_optml_volumes`` together with ``host``.
            command (str): Command stored in the config. When it equals
                ``'serve'`` a port mapping to container port 8080 is added.
            volumes (list): Additional volume objects appended to the ones
                produced by ``_build_optml_volumes``.

        Returns:
            dict: The host configuration.
        """
        all_volumes = self._build_optml_volumes(host, optml_subdirs)
        all_volumes.extend(volumes)

        network_settings = {'sagemaker-local': {'aliases': [host]}}
        host_config = {
            'image': self.image,
            'stdin_open': True,
            'tty': True,
            'volumes': [volume.map for volume in all_volumes],
            'environment': environment,
            'command': command,
            'networks': network_settings,
        }

        # Only serving containers need a published port.
        if command != 'serve':
            return host_config

        configured_port = get_config_value('local.serving_port', self.sagemaker_session.config)
        host_port = configured_port or 8080
        host_config['ports'] = ['%s:8080' % host_port]
        return host_config
コード例 #2
0
    def _prepare_for_training(self, job_name=None):
        """Fill in the job name and output path that training relies on.

        Args:
            job_name (str): Name of the training job to be created. When it is
                ``None`` a name is generated from ``base_job_name`` or, if that
                is unset, from the training image.
        """
        if job_name is None:
            # Prefer the user's base name; otherwise derive one from the image.
            prefix = self.base_job_name or base_name_from_image(self.train_image())
            job_name = name_from_base(prefix)
        self._current_job_name = job_name

        # An output_path that was already provided is left untouched.
        if self.output_path is not None:
            return

        # Local Mode with local_code=True does not need an explicit output_path.
        uses_local_code = get_config_value('local.local_code', self.sagemaker_session.config)
        if self.sagemaker_session.local_mode and uses_local_code:
            self.output_path = ''
        else:
            self.output_path = 's3://{}/'.format(self.sagemaker_session.default_bucket())
コード例 #3
0
    def _create_docker_host(self, host, environment, optml_subdirs, command,
                            volumes):
        """Return the configuration dictionary describing a single container host.

        Args:
            host (str): Host name, reused as its ``sagemaker-local`` network alias.
            environment: Stored as-is under the ``environment`` key.
            optml_subdirs: Passed to ``_build_optml_volumes`` along with ``host``.
            command (str): Container command; ``'serve'`` also adds a ``ports`` entry.
            volumes (list): Extra volume objects added after the generated ones.

        Returns:
            dict: The host configuration.
        """
        mounted = self._build_optml_volumes(host, optml_subdirs)
        mounted.extend(volumes)
        volume_maps = [entry.map for entry in mounted]

        host_config = dict([
            ('image', self.image),
            ('stdin_open', True),
            ('tty', True),
            ('volumes', volume_maps),
            ('environment', environment),
            ('command', command),
            ('networks', {'sagemaker-local': {'aliases': [host]}}),
        ])

        if command == 'serve':
            # Map the configured host port (default 8080) to container port 8080.
            configured = get_config_value('local.serving_port',
                                          self.sagemaker_session.config)
            host_config['ports'] = ['%s:8080' % (configured or 8080)]

        return host_config
コード例 #4
0
    def create_endpoint(self, EndpointName, EndpointConfigName):
        """Start the serving container and wait until ``/ping`` answers with 200.

        Args:
            EndpointName (str): Endpoint name; used here only in the error message.
            EndpointConfigName (str): Not read by this method.

        Raises:
            RuntimeError: If no 200 response was received after the last
                polling attempt.
        """
        first_variant = self.variants[0]
        instance_type = first_variant['InstanceType']
        instance_count = self.variants[0]['InitialInstanceCount']
        self.serve_container = _SageMakerContainer(
            instance_type, instance_count, self.primary_container['Image'],
            self.sagemaker_session)
        self.serve_container.serve(self.primary_container)
        self.created_endpoint = True

        pool = urllib3.PoolManager()
        port = get_config_value('local.serving_port', self.sagemaker_session.config) or 8080
        ping_url = "http://localhost:%s/ping" % port

        # Attempts are numbered 1 through 9, with a one second pause after each failure.
        for attempt in range(1, 10):
            logger.info("Checking if endpoint is up, attempt: %s" % attempt)
            try:
                response = pool.request('GET', ping_url)
            except urllib3.exceptions.RequestError:
                logger.info("Container still not up")
            else:
                if response.status == 200:
                    return
                logger.info("Container still not up, got: %s" % response.status)

            time.sleep(1)

        raise RuntimeError(
            "Giving up, endpoint: %s didn't launch correctly" % EndpointName)
コード例 #5
0
    def _upload_code(self, key_prefix, repack=False):
        """Upload the user code to S3, or repack it together with the model.

        When Local Mode is active with ``local.local_code`` enabled nothing is
        uploaded and ``uploaded_code`` is cleared. Otherwise, unless ``repack``
        is requested, the entry point, source directory and dependencies are
        tarred and uploaded. With ``repack=True`` the model archive is repacked
        to ``s3://<bucket>/<key_prefix>/model.tar.gz`` and ``uploaded_code`` is
        pointed at that archive.

        Args:
            key_prefix (str): S3 key prefix under which the code or repacked
                model is stored.
            repack (bool): Whether to repack the model with the inference
                script instead of uploading the code on its own (default: False).
        """
        # S3 keys always use "/" as separator; os.path.join would insert
        # backslashes on Windows and produce an invalid URI.
        import posixpath

        local_code = utils.get_config_value("local.local_code",
                                            self.sagemaker_session.config)
        if self.sagemaker_session.local_mode and local_code:
            self.uploaded_code = None
        elif not repack:
            bucket = self.bucket or self.sagemaker_session.default_bucket()
            self.uploaded_code = fw_utils.tar_and_upload_dir(
                session=self.sagemaker_session.boto_session,
                bucket=bucket,
                s3_key_prefix=key_prefix,
                script=self.entry_point,
                directory=self.source_dir,
                dependencies=self.dependencies,
            )

        if repack:
            bucket = self.bucket or self.sagemaker_session.default_bucket()
            repacked_model_data = "s3://" + posixpath.join(
                bucket, key_prefix, "model.tar.gz")

            utils.repack_model(
                inference_script=self.entry_point,
                source_directory=self.source_dir,
                dependencies=self.dependencies,
                model_uri=self.model_data,
                repacked_model_uri=repacked_model_data,
                sagemaker_session=self.sagemaker_session,
            )

            # The repacked archive now carries the code, so it doubles as the
            # uploaded code location.
            self.repacked_model_data = repacked_model_data
            self.uploaded_code = UploadedCode(
                s3_prefix=self.repacked_model_data,
                script_name=os.path.basename(self.entry_point))
コード例 #6
0
def test_get_config_value():
    """Check dotted-key lookups, including missing keys and a missing config."""
    local_section = {'region_name': 'us-west-2', 'port': '123'}
    config = {'local': local_section, 'other': {'key': 1}}

    # Existing paths: a leaf value and a whole sub-dictionary.
    assert get_config_value('local.region_name', config) == 'us-west-2'
    assert get_config_value('local', config) == {'region_name': 'us-west-2', 'port': '123'}

    # An unknown path and an absent config both give None.
    assert get_config_value('does_not.exist', config) is None
    assert get_config_value('other.key', None) is None
コード例 #7
0
 def _default_s3_path(self, directory):
     """Return the default location for ``directory``.

     Args:
         directory (str): Name of the sub-directory to locate.

     Returns:
         str: ``/opt/ml/shared/<directory>`` in Local Mode with
         ``local.local_code`` enabled, otherwise ``directory`` joined onto
         ``output_path`` and the current job name.
     """
     uses_local_code = get_config_value('local.local_code', self.sagemaker_session.config)
     if not (self.sagemaker_session.local_mode and uses_local_code):
         return os.path.join(self.output_path, self._current_job_name, directory)
     return '/opt/ml/shared/{}'.format(directory)
コード例 #8
0
def test_get_config_value():
    """Exercise ``get_config_value`` with present, absent and ``None`` inputs."""
    config = dict(
        local=dict(region_name='us-west-2', port='123'),
        other=dict(key=1),
    )

    region = get_config_value('local.region_name', config)
    assert region == 'us-west-2'

    whole_section = get_config_value('local', config)
    assert whole_section == {'region_name': 'us-west-2', 'port': '123'}

    # Lookups that cannot be resolved return None rather than raising.
    missing = get_config_value('does_not.exist', config)
    assert missing is None
    no_config = get_config_value('other.key', None)
    assert no_config is None
コード例 #9
0
 def _default_s3_path(self, directory, mpi=False):
     """Return the default location for ``directory``, or ``None`` if unknown.

     Args:
         directory (str): Name of the sub-directory to locate.
         mpi (bool): When true (and local code is not in use) the fixed path
             ``/opt/ml/model`` is returned (default: False).

     Returns:
         str or None: The resolved path; ``None`` when no current job name is
         set and none of the earlier cases applied.
     """
     uses_local_code = utils.get_config_value("local.local_code", self.sagemaker_session.config)
     if self.sagemaker_session.local_mode and uses_local_code:
         path = "/opt/ml/shared/{}".format(directory)
     elif mpi:
         path = "/opt/ml/model"
     elif self._current_job_name:
         path = os.path.join(self.output_path, self._current_job_name, directory)
     else:
         path = None
     return path
コード例 #10
0
    def _get_working_directory(self):
        """Create a new temporary directory for intermediate data and return its path.

        The directory is created under the configured ``local.container_root``
        when one is set. Intermediate data is written here regardless of the
        final destination; at the end the files are either moved or uploaded to
        S3 and deleted.
        """
        configured_root = get_config_value('local.container_root', self.local_session.config)
        parent_dir = os.path.abspath(configured_root) if configured_root else configured_root
        return tempfile.mkdtemp(dir=parent_dir)
コード例 #11
0
ファイル: model.py プロジェクト: w601sxs/serverless-sagemaker
 def _upload_code(self, key_prefix):
     """Upload the entry point and source directory to S3 unless local code is used.

     Args:
         key_prefix (str): S3 key prefix under which the code is uploaded.
     """
     uses_local_code = get_config_value('local.local_code', self.sagemaker_session.config)
     if not (self.sagemaker_session.local_mode and uses_local_code):
         self.uploaded_code = tar_and_upload_dir(
             session=self.sagemaker_session.boto_session,
             bucket=self.bucket or self.sagemaker_session.default_bucket(),
             s3_key_prefix=key_prefix,
             script=self.entry_point,
             directory=self.source_dir)
     else:
         # Local Mode with local_code=True: nothing is uploaded.
         self.uploaded_code = None
コード例 #12
0
 def _upload_code(self, key_prefix):
     """Set ``uploaded_code``, uploading the user code to S3 when required.

     Args:
         key_prefix (str): S3 key prefix under which the code is uploaded.
     """
     uses_local_code = get_config_value('local.local_code', self.sagemaker_session.config)

     # Local Mode with local_code=True skips the upload entirely.
     if self.sagemaker_session.local_mode and uses_local_code:
         self.uploaded_code = None
         return

     upload_args = dict(
         session=self.sagemaker_session.boto_session,
         bucket=self.bucket or self.sagemaker_session.default_bucket(),
         s3_key_prefix=key_prefix,
         script=self.entry_point,
         directory=self.source_dir,
     )
     self.uploaded_code = tar_and_upload_dir(**upload_args)
コード例 #13
0
    def __init__(self, config=None):
        """Initializes a LocalSageMakerRuntimeClient

        Args:
            config (dict): Optional configuration for this client. Only
                ``local.serving_port`` is read from it.
        """
        self.http = urllib3.PoolManager()
        self.config = config
        # Use port 8080 whenever the config provides no (truthy) serving port.
        self.serving_port = get_config_value('local.serving_port', config) or 8080
コード例 #14
0
    def start(self, input_data, output_data, transform_resources, **kwargs):
        """Run the Local Transform Job from container start-up to completion.

        Args:
            input_data (dict): Describes the dataset to be transformed and the
                location where it is stored.
            output_data (dict): Identifies the location where the results of
                the transform job are saved.
            transform_resources (dict): Compute instances for the transform
                job. Currently only supports local or local_gpu.
            **kwargs: Additional arguments coming from the boto request object.
        """
        self.transform_resources = transform_resources
        self.input_data = input_data
        self.output_data = output_data

        container_image = self.primary_container["Image"]
        instance_type = transform_resources["InstanceType"]
        container_env = self._get_container_environment(**kwargs)

        # Launch a single serving container with that environment and wait
        # until it is reachable.
        self.container = _SageMakerContainer(
            instance_type, 1, container_image, self.local_session)
        self.container.serve(self.primary_container["ModelDataUrl"], container_env)

        port = get_config_value("local.serving_port", self.local_session.config) or 8080
        _wait_for_serving_container(port)

        # Ask the container for its execution parameters and adopt the ones
        # the caller did not set explicitly.
        params_url = "http://localhost:%s/execution-parameters" % port
        response, status = _perform_request(params_url)
        if status == 200:
            advertised = json.loads(response.read())
            # MaxConcurrentTransforms is ignored because only 1 is supported.
            for option in ("BatchStrategy", "MaxPayloadInMB"):
                if option not in kwargs and option in advertised:
                    kwargs[option] = advertised[option]

        # Fill in defaults for anything still missing.
        required_defaults = self._get_required_defaults(**kwargs)
        kwargs.update(required_defaults)

        self.start_time = datetime.datetime.now()
        self.batch_strategy = kwargs["BatchStrategy"]
        if "Environment" in kwargs:
            self.environment = kwargs["Environment"]

        # Run the batch inference requests, then record completion.
        self._perform_batch_inference(input_data, output_data, **kwargs)
        self.end_time = datetime.datetime.now()
        self.state = self._COMPLETED
コード例 #15
0
    def fit(self, inputs, wait=True, logs=True, job_name=None):
        """Train a model using the input training dataset.

        The job name and output path are resolved first, then a new training
        job is started from this estimator and ``inputs``. With ``wait=True``
        the call blocks until the job finishes.

        Args:
            inputs (str or dict or sagemaker.session.s3_input): Information about the training data.
                This can be one of three types:

                * (str) the S3 location where training data is saved.

                * (dict[str, str] or dict[str, sagemaker.session.s3_input]) If using multiple channels for
                    training data, you can specify a dict mapping channel names
                    to strings or :func:`~sagemaker.session.s3_input` objects.
                * (sagemaker.session.s3_input) - channel configuration for S3 data sources that can provide
                    additional information about the training dataset. See :func:`sagemaker.session.s3_input`
                    for full details.
            wait (bool): Whether the call should wait until the job completes (default: True).
            logs (bool): Whether to show the logs produced by the job.
                Only meaningful when wait is True (default: True).
            job_name (str): Training job name. If not specified, the estimator generates a default job name,
                based on the training image name and current timestamp.
        """
        if job_name is None:
            # Generate a fresh name per invocation, honoring base_job_name when set.
            prefix = self.base_job_name or base_name_from_image(self.train_image())
            job_name = name_from_base(prefix)
        self._current_job_name = job_name

        # Only initialize output_path when the user did not specify one.
        if self.output_path is None:
            uses_local_code = get_config_value('local.local_code', self.sagemaker_session.config)
            if self.sagemaker_session.local_mode and uses_local_code:
                # Local Mode with local_code=True needs no explicit output_path.
                self.output_path = ''
            else:
                self.output_path = 's3://{}/'.format(self.sagemaker_session.default_bucket())

        self.latest_training_job = _TrainingJob.start_new(self, inputs)
        if not wait:
            return
        self.latest_training_job.wait(logs=logs)
コード例 #16
0
    def _create_tmp_folder(self):
        """Create a temporary directory and return its absolute path.

        The directory is created under ``local.container_root`` when that
        setting is present in the session config.
        """
        configured_root = get_config_value('local.container_root', self.sagemaker_session.config)
        root_dir = os.path.abspath(configured_root) if configured_root else configured_root

        tmp_dir = tempfile.mkdtemp(dir=root_dir)

        # Docker cannot mount the Mac OS /var folder properly, see
        # https://forums.docker.com/t/var-folders-isnt-mounted-properly/9600
        # The workaround only applies when no alternate root dir was provided.
        on_mac = platform.system() == 'Darwin' if root_dir is None else False
        if on_mac:
            tmp_dir = '/private{}'.format(tmp_dir)

        return os.path.abspath(tmp_dir)
コード例 #17
0
ファイル: image.py プロジェクト: mklissa/sagemaker-python-sdk
    def _create_tmp_folder(self):
        """Return the absolute path of a newly created temporary directory.

        When ``local.container_root`` is configured, the directory is created
        beneath it.
        """
        root_dir = get_config_value('local.container_root', self.sagemaker_session.config)
        if root_dir:
            root_dir = os.path.abspath(root_dir)

        tmp_path = tempfile.mkdtemp(dir=root_dir)

        if root_dir is not None or platform.system() != 'Darwin':
            return os.path.abspath(tmp_path)

        # Docker cannot mount the Mac OS /var folder properly, see
        # https://forums.docker.com/t/var-folders-isnt-mounted-properly/9600
        # so the path is prefixed when the user gave no alternate root dir.
        return os.path.abspath('/private{}'.format(tmp_path))
コード例 #18
0
    def serve(self):
        """Start the serving container and mark the endpoint as in service.

        The state is set only after ``_wait_for_serving_container`` returns.
        """
        container_def = self.primary_container
        variant = self.production_variant
        image = container_def['Image']
        instance_type = variant['InstanceType']
        instance_count = variant['InitialInstanceCount']

        self.create_time = datetime.datetime.now()
        self.container = _SageMakerContainer(
            instance_type, instance_count, image, self.local_session)
        self.container.serve(
            self.primary_container['ModelDataUrl'],
            self.primary_container['Environment'])

        port = get_config_value('local.serving_port', self.local_session.config) or 8080
        _wait_for_serving_container(port)

        # The container is running and passed the health check: InService.
        self.state = _LocalEndpoint._IN_SERVICE
コード例 #19
0
    def hyperparameters(self):
        """Return hyperparameters used by your custom MXNet code during model training.

        If no ``checkpoint_path`` is set, one is chosen first: a fixed shared
        directory in Local Mode with ``local.local_code`` enabled, otherwise a
        ``checkpoints`` folder under ``output_path`` and the current job name.

        Returns:
            dict: The parent class hyperparameters extended with the encoded
            ``checkpoint_path``.
        """
        hyperparameters = super(MXNet, self).hyperparameters()

        if not self.checkpoint_path:
            local_code = get_config_value('local.local_code', self.sagemaker_session.config)
            if self.sagemaker_session.local_mode and local_code:
                self.checkpoint_path = '/opt/ml/shared/checkpoints'
            else:
                self.checkpoint_path = os.path.join(self.output_path,
                                                    self._current_job_name, 'checkpoints')

        additional_hyperparameters = {'checkpoint_path': self.checkpoint_path}

        hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
        return hyperparameters
コード例 #20
0
    def __init__(self, config=None):
        """Initializes a LocalSageMakerRuntimeClient

        Args:
            config (dict): Optional configuration for this client. Only
                ``local.serving_port`` is read from it.

        Raises:
            ImportError: If ``urllib3`` cannot be imported; the error is logged
                before it is re-raised.
        """
        try:
            import urllib3
        except ImportError:
            logging.error(_module_import_error("urllib3", "Local mode", "local"))
            # Bare raise re-raises the active exception with its traceback intact.
            raise

        self.http = urllib3.PoolManager()
        self.config = config
        # Use port 8080 whenever the config provides no (truthy) serving port.
        self.serving_port = get_config_value("local.serving_port", config) or 8080
コード例 #21
0
    def serve(self):
        """Start the serving container and mark the endpoint as in service.

        When the production variant's accelerator type is
        ``local_sagemaker_notebook``, the container environment is flagged
        accordingly before the container starts.
        """
        container_def = self.primary_container
        variant = self.production_variant
        image = container_def['Image']
        instance_type = variant['InstanceType']
        instance_count = variant['InitialInstanceCount']

        if variant.get('AcceleratorType') == 'local_sagemaker_notebook':
            env = self.primary_container['Environment']
            env['SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT'] = 'true'

        self.create_time = datetime.datetime.now()
        self.container = _SageMakerContainer(
            instance_type, instance_count, image, self.local_session)
        self.container.serve(
            self.primary_container['ModelDataUrl'],
            self.primary_container['Environment'])

        port = get_config_value('local.serving_port', self.local_session.config) or 8080
        _wait_for_serving_container(port)

        # The container is running and passed the health check: InService.
        self.state = _LocalEndpoint._IN_SERVICE
コード例 #22
0
    def hyperparameters(self):
        """Return hyperparameters used by your custom TensorFlow code during model training.

        A ``checkpoint_path`` is chosen first when none is set: a fixed shared
        directory in Local Mode with ``local.local_code`` enabled, otherwise a
        ``checkpoints`` folder under ``output_path`` and the current job name.
        """
        hyperparameters = super(TensorFlow, self).hyperparameters()

        if not self.checkpoint_path:
            uses_local_code = get_config_value('local.local_code', self.sagemaker_session.config)
            if not (self.sagemaker_session.local_mode and uses_local_code):
                self.checkpoint_path = os.path.join(
                    self.output_path, self._current_job_name, 'checkpoints')
            else:
                self.checkpoint_path = '/opt/ml/shared/checkpoints'

        extra = {
            'checkpoint_path': self.checkpoint_path,
            'training_steps': self.training_steps,
            'evaluation_steps': self.evaluation_steps,
            'sagemaker_requirements': self.requirements_file,
        }
        encoded_extra = Framework._json_encode_hyperparameters(extra)
        hyperparameters.update(encoded_extra)
        return hyperparameters
コード例 #23
0
    def _prepare_for_training(self, job_name=None):
        """Set hyperparameters needed for training. This method will also validate ``source_dir``.

        Args:
            job_name (str): Name of the training job to be created. If not specified, one is generated,
                using the base name given to the constructor if applicable.
        """
        super(Framework, self)._prepare_for_training(job_name=job_name)

        # validate_source_dir raises ValueError when something is wrong with the
        # source directory. It is deliberately not handled: this is a critical error.
        if self.source_dir and not self.source_dir.lower().startswith('s3://'):
            validate_source_dir(self.entry_point, self.source_dir)

        uses_local_code = get_config_value('local.local_code', self.sagemaker_session.config)
        if not (self.sagemaker_session.local_mode and uses_local_code):
            self.uploaded_code = self._stage_user_code_in_s3()
            code_dir = self.uploaded_code.s3_prefix
            script = self.uploaded_code.script_name
        else:
            # Local mode with local_code=True: the container just mounts the
            # source dir instead of the code being uploaded to S3.
            if self.source_dir is None:
                # No source dir given: use the directory containing the entry point.
                self.source_dir = os.path.dirname(self.entry_point)
            self.entry_point = os.path.basename(self.entry_point)

            code_dir = 'file://' + self.source_dir
            script = self.entry_point

        # Update the hyperparameters in place so they point to the right code
        # directory and script.
        params = self._hyperparameters
        params[DIR_PARAM_NAME] = code_dir
        params[SCRIPT_PARAM_NAME] = script
        params[CLOUDWATCH_METRICS_PARAM_NAME] = self.enable_cloudwatch_metrics
        params[CONTAINER_LOG_LEVEL_PARAM_NAME] = self.container_log_level
        params[JOB_NAME_PARAM_NAME] = self._current_job_name
        params[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_region_name
コード例 #24
0
    def serve(self):
        """Launch the serving container, wait for it, then set the state to in service.

        A production variant whose accelerator type is
        ``local_sagemaker_notebook`` gets an extra flag in the container
        environment before launch.
        """
        image = self.primary_container["Image"]
        instance_type = self.production_variant["InstanceType"]
        instance_count = self.production_variant["InitialInstanceCount"]

        uses_notebook_accelerator = (
            self.production_variant.get("AcceleratorType") == "local_sagemaker_notebook")
        if uses_notebook_accelerator:
            self.primary_container["Environment"].__setitem__(
                "SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT", "true")

        self.create_time = datetime.datetime.now()
        self.container = _SageMakerContainer(
            instance_type, instance_count, image, self.local_session)

        model_data_url = self.primary_container["ModelDataUrl"]
        container_env = self.primary_container["Environment"]
        self.container.serve(model_data_url, container_env)

        configured_port = get_config_value("local.serving_port", self.local_session.config)
        _wait_for_serving_container(configured_port or 8080)

        # The container is running and passed the health check: InService.
        self.state = _LocalEndpoint._IN_SERVICE
コード例 #25
0
    def _prepare_for_training(self, job_name=None):
        """Set any values in the estimator that need to be set before training.

        Args:
            job_name (str): Name of the training job to be created. If not specified, one is generated,
                using the base name given to the constructor if applicable.
        """
        if job_name is not None:
            chosen_name = job_name
        else:
            # Honor the supplied base_job_name, or derive one from the training image.
            chosen_name = name_from_base(
                self.base_job_name or base_name_from_image(self.train_image()))
        self._current_job_name = chosen_name

        # A user-specified output_path is used as-is; otherwise it is set here.
        if self.output_path is None:
            uses_local_code = get_config_value('local.local_code', self.sagemaker_session.config)
            if not (self.sagemaker_session.local_mode and uses_local_code):
                self.output_path = 's3://{}/'.format(self.sagemaker_session.default_bucket())
            else:
                # Local Mode with local_code=True needs no explicit output_path.
                self.output_path = ''
コード例 #26
0
    def _prepare_for_training(self, job_name=None):
        """Set hyperparameters needed for training. This method will also validate ``source_dir``.

        Args:
            job_name (str): Name of the training job to be created. If not specified, one is generated,
                using the base name given to the constructor if applicable.
        """
        super(Framework, self)._prepare_for_training(job_name=job_name)

        # A ValueError from validate_source_dir is intentionally left unhandled:
        # a broken source directory is a critical error.
        if self.source_dir and not self.source_dir.lower().startswith('s3://'):
            validate_source_dir(self.entry_point, self.source_dir)

        # In local mode with local_code=True the container mounts the source
        # dir directly, so nothing is uploaded to S3.
        uses_local_code = get_config_value('local.local_code', self.sagemaker_session.config)
        mount_local_code = self.sagemaker_session.local_mode and uses_local_code
        if mount_local_code:
            if self.source_dir is None:
                # Fall back to the directory that contains the entry point.
                self.source_dir = os.path.dirname(self.entry_point)
            self.entry_point = os.path.basename(self.entry_point)
            code_dir, script = 'file://' + self.source_dir, self.entry_point
        else:
            self.uploaded_code = self._stage_user_code_in_s3()
            code_dir = self.uploaded_code.s3_prefix
            script = self.uploaded_code.script_name

        # Point the hyperparameters at the resolved code directory and script,
        # modifying the existing mapping in place.
        hyperparams = self._hyperparameters
        hyperparams[DIR_PARAM_NAME] = code_dir
        hyperparams[SCRIPT_PARAM_NAME] = script
        hyperparams[CLOUDWATCH_METRICS_PARAM_NAME] = self.enable_cloudwatch_metrics
        hyperparams[CONTAINER_LOG_LEVEL_PARAM_NAME] = self.container_log_level
        hyperparams[JOB_NAME_PARAM_NAME] = self._current_job_name
        hyperparams[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_region_name
コード例 #27
0
    def serve(self):
        """Start the serving container and poll ``/ping`` until it returns 200.

        On success the state becomes in service. If the attempt counter
        reaches ``HEALTH_CHECK_TIMEOUT_LIMIT`` first, the state is set to
        failed.

        Raises:
            RuntimeError: If the container never answered with status 200.
        """
        container_def = self.primary_container
        variant = self.production_variant
        image = container_def['Image']
        instance_type = variant['InstanceType']
        instance_count = variant['InitialInstanceCount']

        self.create_time = datetime.datetime.now()
        self.container = _SageMakerContainer(
            instance_type, instance_count, image, self.local_session)
        self.container.serve(
            self.primary_container['ModelDataUrl'],
            self.primary_container['Environment'])

        pool = urllib3.PoolManager()
        port = get_config_value('local.serving_port', self.local_session.config) or 8080
        ping_url = 'http://localhost:%s/ping' % port

        attempt = 1
        while attempt < HEALTH_CHECK_TIMEOUT_LIMIT:
            logger.info('Checking if endpoint is up, attempt: %s' % attempt)
            try:
                response = pool.request('GET', ping_url)
            except urllib3.exceptions.RequestError:
                logger.info('Container still not up')
            else:
                if response.status == 200:
                    # The container is running and passed the health check: InService.
                    self.state = _LocalEndpoint._IN_SERVICE
                    return
                logger.info('Container still not up, got: %s' % response.status)

            time.sleep(1)
            attempt += 1

        self.state = _LocalEndpoint._FAILED
        raise RuntimeError(
            'Giving up, endpoint: %s didn\'t launch correctly' % self.name)
コード例 #28
0
    def fit(self, inputs, wait=True, logs=True, job_name=None):
        """Train a model using the input training dataset.

        The job name is resolved, the source directory validated and the code
        location recorded in the hyperparameters before the call is delegated
        to the parent class ``fit``.

        Args:
            inputs (str or dict or sagemaker.session.s3_input): Information about the training data.
                This can be one of three types:
                (str) - the S3 location where training data is saved.
                (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple channels for
                    training data, you can specify a dict mapping channel names
                    to strings or :func:`~sagemaker.session.s3_input` objects.
                (sagemaker.session.s3_input) - channel configuration for S3 data sources that can provide
                    additional information about the training dataset. See :func:`sagemaker.session.s3_input`
                    for full details.
            wait (bool): Whether the call should wait until the job completes (default: True).
            logs (bool): Whether to show the logs produced by the job.
                Only meaningful when wait is True (default: True).
            job_name (str): Training job name. If not specified, the estimator generates a default job name,
                based on the training image name and current timestamp.
        """
        # The job name must be settled here because it is used below, before
        # the parent class is called.
        if job_name is None:
            # Honor the supplied base_job_name, or derive one from the training image.
            prefix = self.base_job_name or base_name_from_image(self.train_image())
            job_name = name_from_base(prefix)
        self._current_job_name = job_name

        # A ValueError from validate_source_dir is intentionally left unhandled:
        # a broken source directory is a critical error.
        if self.source_dir and not self.source_dir.lower().startswith('s3://'):
            validate_source_dir(self.entry_point, self.source_dir)

        uses_local_code = get_config_value('local.local_code', self.sagemaker_session.config)
        if not (self.sagemaker_session.local_mode and uses_local_code):
            self.uploaded_code = self._stage_user_code_in_s3()
            code_dir = self.uploaded_code.s3_prefix
            script = self.uploaded_code.script_name
        else:
            # Local mode with local_code=True: the container mounts the source
            # dir instead of the code being uploaded to S3.
            if self.source_dir is None:
                # No source dir given: use the directory containing the entry point.
                self.source_dir = os.path.dirname(self.entry_point)
            self.entry_point = os.path.basename(self.entry_point)

            code_dir = 'file://' + self.source_dir
            script = self.entry_point

        # Update the hyperparameters in place so they point to the right code
        # directory and script.
        params = self._hyperparameters
        params[DIR_PARAM_NAME] = code_dir
        params[SCRIPT_PARAM_NAME] = script
        params[CLOUDWATCH_METRICS_PARAM_NAME] = self.enable_cloudwatch_metrics
        params[CONTAINER_LOG_LEVEL_PARAM_NAME] = self.container_log_level
        params[JOB_NAME_PARAM_NAME] = self._current_job_name
        params[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_region_name
        super(Framework, self).fit(inputs, wait, logs, self._current_job_name)