Exemple #1
0
    def _prepare_for_training(self, job_name=None):
        """Fill in the framework hyperparameters before a training job starts.

        The parent class preparation runs first. A ``source_dir`` that is not an
        ``s3://`` URI is then validated, the user code location is resolved
        (either a local ``file://`` directory or the code staged in S3) and the
        result is written into ``self._hyperparameters``.

        Args:
            job_name (str): Name of the training job to be created. If not specified,
                one is generated, using the base name given to the constructor
                if applicable.
        """
        super(Framework, self)._prepare_for_training(job_name=job_name)

        # A broken source directory is a critical error, so the ValueError raised
        # by validate_source_dir is deliberately left unhandled.
        if self.source_dir:
            if not self.source_dir.lower().startswith('s3://'):
                validate_source_dir(self.entry_point, self.source_dir)

        session = self.sagemaker_session
        use_local_code = get_config_value('local.local_code', session.config)

        if not session.local_mode or not use_local_code:
            # Regular path: stage the user code in S3 and point at the upload.
            self.uploaded_code = self._stage_user_code_in_s3()
            code_dir = self.uploaded_code.s3_prefix
            script = self.uploaded_code.script_name
        else:
            # Local mode with local_code=True: the container should just mount
            # the source directory instead of uploading it to S3.
            if self.source_dir is None:
                # No source dir given: use the directory containing the entry point.
                self.source_dir = os.path.dirname(self.entry_point)
            self.entry_point = os.path.basename(self.entry_point)

            code_dir = 'file://' + self.source_dir
            script = self.entry_point

        # Update the hyperparameters in place so they reference the resolved
        # code directory and script, plus the job level settings.
        hyperparameters = self._hyperparameters
        hyperparameters[DIR_PARAM_NAME] = code_dir
        hyperparameters[SCRIPT_PARAM_NAME] = script
        hyperparameters[CLOUDWATCH_METRICS_PARAM_NAME] = self.enable_cloudwatch_metrics
        hyperparameters[CONTAINER_LOG_LEVEL_PARAM_NAME] = self.container_log_level
        hyperparameters[JOB_NAME_PARAM_NAME] = self._current_job_name
        hyperparameters[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_region_name
    def _prepare_for_training(self, job_name=None):
        """Set the hyperparameters that describe the user code for a training job.

        After delegating to the parent class, this validates a ``source_dir`` that
        is not an ``s3://`` URI, decides whether the code is mounted from a local
        directory or staged in S3, and stores the outcome in
        ``self._hyperparameters``.

        Args:
            job_name (str): Name of the training job to be created. If not specified,
                one is generated, using the base name given to the constructor
                if applicable.
        """
        super(Framework, self)._prepare_for_training(job_name=job_name)

        # validate_source_dir raises ValueError when the source directory is
        # unusable; that is a critical error and is intentionally not handled.
        if self.source_dir:
            if not self.source_dir.lower().startswith('s3://'):
                validate_source_dir(self.entry_point, self.source_dir)

        session = self.sagemaker_session
        mount_local_code = get_config_value('local.local_code', session.config)

        if not session.local_mode or not mount_local_code:
            # Default path: upload the code to S3 and reference the staged copy.
            self.uploaded_code = self._stage_user_code_in_s3()
            code_dir = self.uploaded_code.s3_prefix
            script = self.uploaded_code.script_name
        else:
            # Local mode with local_code=True: mount the source directory in the
            # container rather than uploading it to S3.
            if self.source_dir is None:
                # Fall back to the directory that contains the entry point.
                self.source_dir = os.path.dirname(self.entry_point)
            self.entry_point = os.path.basename(self.entry_point)

            code_dir = 'file://' + self.source_dir
            script = self.entry_point

        # Write the code location and job level settings into the
        # hyperparameters, modifying the existing mapping in place.
        params = self._hyperparameters
        params[DIR_PARAM_NAME] = code_dir
        params[SCRIPT_PARAM_NAME] = script
        params[CLOUDWATCH_METRICS_PARAM_NAME] = self.enable_cloudwatch_metrics
        params[CONTAINER_LOG_LEVEL_PARAM_NAME] = self.container_log_level
        params[JOB_NAME_PARAM_NAME] = self._current_job_name
        params[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_region_name
Exemple #3
0
def test_validate_source_dir_file_not_in_dir():
    """Expect ValueError for a script name that is not in the given directory."""
    # The current directory is used together with a deliberately odd file name.
    missing_script = " !@#$%^&*() .myscript. !@#$%^&*() "
    with pytest.raises(ValueError):
        fw_utils.validate_source_dir(missing_script, ".")
Exemple #4
0
def test_validate_source_dir_is_not_directory(sagemaker_session):
    """Expect ValueError when the source directory argument is a file path."""
    # inspect reports the file that defines this test, i.e. a file, not a directory.
    this_file = inspect.getfile(inspect.currentframe())
    with pytest.raises(ValueError):
        fw_utils.validate_source_dir("mnist.py", this_file)
Exemple #5
0
def test_validate_source_dir_does_not_exits(sagemaker_session):
    """Expect ValueError when the source directory path is a nonsense name."""
    # A path that is not expected to exist on any machine running the tests.
    bogus_directory = " !@#$%^&*()path probably in not there.!@#$%^&*()"
    with pytest.raises(ValueError):
        fw_utils.validate_source_dir("mnist.py", bogus_directory)
def test_validate_source_dir_file_not_in_dir():
    """Expect ValueError for a script name that is not in the given directory."""
    # The current directory is used together with a deliberately odd file name.
    missing_script = ' !@#$%^&*() .myscript. !@#$%^&*() '
    with pytest.raises(ValueError):
        validate_source_dir(missing_script, '.')
def test_validate_source_dir_does_not_exits(sagemaker_session):
    """Expect ValueError when the source directory path is a nonsense name."""
    # A path that is not expected to exist on any machine running the tests.
    bogus_directory = ' !@#$%^&*()path probably in not there.!@#$%^&*()'
    with pytest.raises(ValueError):
        validate_source_dir('mnist.py', bogus_directory)
def test_validate_source_dir_does_not_exits(sagemaker_session):
    """Check that an implausible source directory path triggers ValueError."""
    script_name = 'mnist.py'
    # Directory name chosen so that it should not be present on disk.
    absent_directory = ' !@#$%^&*()path probably in not there.!@#$%^&*()'
    with pytest.raises(ValueError):
        validate_source_dir(script_name, absent_directory)
def test_validate_source_dir_file_not_in_dir():
    """Check that a script name absent from the directory triggers ValueError."""
    current_directory = '.'
    # File name chosen so that it should not be present in the directory.
    absent_script = ' !@#$%^&*() .myscript. !@#$%^&*() '
    with pytest.raises(ValueError):
        validate_source_dir(absent_script, current_directory)
def test_validate_source_dir_is_not_directory(sagemaker_session):
    """Expect ValueError when the source directory argument is a file path."""
    # inspect reports the file that defines this test, i.e. a file, not a directory.
    this_file = inspect.getfile(inspect.currentframe())
    with pytest.raises(ValueError):
        validate_source_dir('mnist.py', this_file)
    def fit(self, inputs, wait=True, logs=True, job_name=None):
        """Train a model using the input training dataset.

        The API calls the Amazon SageMaker CreateTrainingJob API to start model training.
        It uses the configuration you provided to create the estimator and the
        specified input training data to send the CreateTrainingJob request to Amazon SageMaker.

        Before delegating to the parent class, this method resolves the job name,
        validates a ``source_dir`` that is not an ``s3://`` URI, and records the
        code location and job settings in ``self._hyperparameters``.

        After the model training successfully completes, you can call the ``deploy()``
        method to host the model using the Amazon SageMaker hosting services.

        Args:
            inputs (str or dict or sagemaker.session.s3_input): Information about the training data.
                This can be one of three types:
                (str) - the S3 location where training data is saved.
                (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple
                    channels for training data, you can specify a dict mapping channel names
                    to strings or :func:`~sagemaker.session.s3_input` objects.
                (sagemaker.session.s3_input) - channel configuration for S3 data sources that can
                    provide additional information about the training dataset.
                    See :func:`sagemaker.session.s3_input` for full details.
            wait (bool): Whether the call should wait until the job completes (default: True).
            logs (bool): Whether to show the logs produced by the job.
                Only meaningful when wait is True (default: True).
            job_name (str): Training job name. If not specified, the estimator generates a
                default job name, based on ``base_job_name`` when it is set and on the
                training image name otherwise.
        """
        # The job name is written into the hyperparameters below, before the
        # parent class runs, so it is always determined here.
        if job_name is None:
            # Honor a supplied base_job_name, otherwise derive one from the image.
            base_name = self.base_job_name
            if not base_name:
                base_name = base_name_from_image(self.train_image())
            self._current_job_name = name_from_base(base_name)
        else:
            self._current_job_name = job_name

        # A broken source directory is a critical error, so the ValueError raised
        # by validate_source_dir is deliberately left unhandled.
        if self.source_dir:
            if not self.source_dir.lower().startswith('s3://'):
                validate_source_dir(self.entry_point, self.source_dir)

        session = self.sagemaker_session
        use_local_code = get_config_value('local.local_code', session.config)

        if not session.local_mode or not use_local_code:
            # Regular path: stage the user code in S3 and point at the upload.
            self.uploaded_code = self._stage_user_code_in_s3()
            code_dir = self.uploaded_code.s3_prefix
            script = self.uploaded_code.script_name
        else:
            # Local mode with local_code=True: the container should just mount
            # the source directory instead of uploading it to S3.
            if self.source_dir is None:
                # No source dir given: use the directory containing the entry point.
                self.source_dir = os.path.dirname(self.entry_point)
            self.entry_point = os.path.basename(self.entry_point)

            code_dir = 'file://' + self.source_dir
            script = self.entry_point

        # Update the hyperparameters in place so they reference the resolved
        # code directory and script, plus the job level settings.
        hyperparameters = self._hyperparameters
        hyperparameters[DIR_PARAM_NAME] = code_dir
        hyperparameters[SCRIPT_PARAM_NAME] = script
        hyperparameters[CLOUDWATCH_METRICS_PARAM_NAME] = self.enable_cloudwatch_metrics
        hyperparameters[CONTAINER_LOG_LEVEL_PARAM_NAME] = self.container_log_level
        hyperparameters[JOB_NAME_PARAM_NAME] = self._current_job_name
        hyperparameters[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_region_name

        super(Framework, self).fit(inputs, wait, logs, self._current_job_name)