예제 #1
0
    def start_new(cls, estimator, inputs):
        """Create a new Amazon SageMaker training job from the estimator.

        Args:
            estimator (sagemaker.estimator.EstimatorBase): Estimator object created by the user.
            inputs (str): Parameters used when calling :meth:`~sagemaker.estimator.EstimatorBase.fit`.
                A string starting with ``file://`` is accepted only when the estimator's
                session is in local mode.

        Returns:
            sagemaker.estimator._TrainingJob: Constructed object that captures all information about the started
            training job.

        Raises:
            ValueError: If ``inputs`` is a ``file://`` URI string and the session is not in local mode.
        """
        local_mode = estimator.sagemaker_session.local_mode

        # Allow file:// input only in local mode
        if isinstance(inputs, str) and inputs.startswith('file://') and not local_mode:
            raise ValueError('File URIs are supported in local mode only. Please use a S3 URI instead.')

        config = _Job._load_config(inputs, estimator)

        # Bind `hyperparameters` unconditionally: when the estimator reports none, the name
        # must still exist for the train() call below (otherwise UnboundLocalError).
        hyperparameters = None
        raw_hyperparameters = estimator.hyperparameters()
        if raw_hyperparameters is not None:
            # Keys and values are stringified before being handed to the session.
            hyperparameters = {str(k): str(v) for k, v in raw_hyperparameters.items()}

        estimator.sagemaker_session.train(image=estimator.train_image(), input_mode=estimator.input_mode,
                                          input_config=config['input_config'], role=config['role'],
                                          job_name=estimator._current_job_name, output_config=config['output_config'],
                                          resource_config=config['resource_config'], hyperparameters=hyperparameters,
                                          stop_condition=config['stop_condition'], tags=estimator.tags)

        return cls(estimator.sagemaker_session, estimator._current_job_name)
예제 #2
0
    def start_new(cls, tuner, inputs):
        """Launch an Amazon SageMaker hyperparameter tuning job described by ``tuner``.

        Args:
            tuner (sagemaker.tuner.HyperparameterTuner): Tuner instance configured by the user.
            inputs (str): The inputs passed to :meth:`~sagemaker.estimator.EstimatorBase.fit`.

        Returns:
            sagemaker.tuner._TuningJob: Object wrapping the session and the name of the started job.
        """
        estimator = tuner.estimator
        job_config = _Job._load_config(inputs, estimator)
        session = estimator.sagemaker_session

        # Tuner-level settings first, then the training definition shared with the estimator,
        # then the channel/role/resource settings resolved by _load_config.
        session.tune(
            job_name=tuner._current_job_name,
            strategy=tuner.strategy,
            objective_type=tuner.objective_type,
            objective_metric_name=tuner.objective_metric_name,
            max_jobs=tuner.max_jobs,
            max_parallel_jobs=tuner.max_parallel_jobs,
            parameter_ranges=tuner.hyperparameter_ranges(),
            static_hyperparameters=tuner.static_hyperparameters,
            image=estimator.train_image(),
            input_mode=estimator.input_mode,
            metric_definitions=tuner.metric_definitions,
            role=job_config['role'],
            input_config=job_config['input_config'],
            output_config=job_config['output_config'],
            resource_config=job_config['resource_config'],
            stop_condition=job_config['stop_condition'],
            tags=tuner.tags,
        )

        return cls(tuner.sagemaker_session, tuner._current_job_name)
예제 #3
0
    def start_new(cls, estimator, inputs):
        """Create a new Amazon SageMaker training job from the estimator.

        Args:
            estimator (sagemaker.estimator.EstimatorBase): Estimator object created by the user.
            inputs (str): Parameters used when calling :meth:`~sagemaker.estimator.EstimatorBase.fit`.
                A string starting with ``file://`` is accepted only when the estimator's
                session is in local mode.

        Returns:
            sagemaker.estimator._TrainingJob: Constructed object that captures all information about the started
            training job.

        Raises:
            ValueError: If ``inputs`` is a ``file://`` URI string and the session is not in local mode.
        """
        local_mode = estimator.sagemaker_session.local_mode

        # Allow file:// input only in local mode
        if isinstance(inputs, str) and inputs.startswith('file://') and not local_mode:
            raise ValueError('File URIs are supported in local mode only. Please use a S3 URI instead.')

        config = _Job._load_config(inputs, estimator)

        # Bind `hyperparameters` unconditionally: when the estimator reports none, the name
        # must still exist for the train() call below (otherwise UnboundLocalError).
        hyperparameters = None
        raw_hyperparameters = estimator.hyperparameters()
        if raw_hyperparameters is not None:
            # Keys and values are stringified before being handed to the session.
            hyperparameters = {str(k): str(v) for k, v in raw_hyperparameters.items()}

        estimator.sagemaker_session.train(image=estimator.train_image(), input_mode=estimator.input_mode,
                                          input_config=config['input_config'], role=config['role'],
                                          job_name=estimator._current_job_name, output_config=config['output_config'],
                                          resource_config=config['resource_config'], hyperparameters=hyperparameters,
                                          stop_condition=config['stop_condition'], tags=estimator.tags)

        return cls(estimator.sagemaker_session, estimator._current_job_name)
예제 #4
0
    def start_new(cls, tuner, inputs):
        """Start a hyperparameter tuning job on Amazon SageMaker for the given tuner.

        Args:
            tuner (sagemaker.tuner.HyperparameterTuner): The user's tuner object.
            inputs (str): Inputs as given to :meth:`~sagemaker.estimator.EstimatorBase.fit`.

        Returns:
            sagemaker.tuner._TuningJob: Handle describing the tuning job that was started.
        """
        config = _Job._load_config(inputs, tuner.estimator)
        estimator = tuner.estimator

        # Collect every argument for Session.tune() in one mapping, then splat it.
        tune_kwargs = dict(
            job_name=tuner._current_job_name,
            strategy=tuner.strategy,
            objective_type=tuner.objective_type,
            objective_metric_name=tuner.objective_metric_name,
            max_jobs=tuner.max_jobs,
            max_parallel_jobs=tuner.max_parallel_jobs,
            parameter_ranges=tuner.hyperparameter_ranges(),
            static_hyperparameters=tuner.static_hyperparameters,
            image=estimator.train_image(),
            input_mode=estimator.input_mode,
            metric_definitions=tuner.metric_definitions,
            role=config['role'],
            input_config=config['input_config'],
            output_config=config['output_config'],
            resource_config=config['resource_config'],
            stop_condition=config['stop_condition'],
            tags=tuner.tags,
        )
        estimator.sagemaker_session.tune(**tune_kwargs)

        return cls(tuner.sagemaker_session, tuner._current_job_name)
예제 #5
0
    def start_new(cls, tuner, inputs):
        """Start an Amazon SageMaker hyperparameter tuning job for ``tuner``.

        The keyword arguments for ``Session.tune`` are assembled from the
        resolved job config, the tuner's own settings and the estimator.

        Args:
            tuner (sagemaker.tuner.HyperparameterTuner): HyperparameterTuner
                object created by the user.
            inputs (str): Parameters used when calling
                :meth:`~sagemaker.estimator.EstimatorBase.fit`. If this is an
                ``s3_input`` whose config carries an ``InputMode``, that mode
                overrides the estimator's input mode.

        Returns:
            sagemaker.tuner._TuningJob: Constructed object that captures all
            information about the started job.
        """
        estimator = tuner.estimator
        config = _Job._load_config(inputs, estimator)

        warm_start_config_req = (
            tuner.warm_start_config.to_input_req() if tuner.warm_start_config else None
        )

        # Start from the resolved config (role, channels, resources, ...) and layer
        # the tuner-specific settings on top.
        tuner_args = config.copy()
        tuner_args.update(
            job_name=tuner._current_job_name,
            strategy=tuner.strategy,
            objective_type=tuner.objective_type,
            objective_metric_name=tuner.objective_metric_name,
            max_jobs=tuner.max_jobs,
            max_parallel_jobs=tuner.max_parallel_jobs,
            parameter_ranges=tuner.hyperparameter_ranges(),
            static_hyperparameters=tuner.static_hyperparameters,
            input_mode=estimator.input_mode,
            metric_definitions=tuner.metric_definitions,
            tags=tuner.tags,
            warm_start_config=warm_start_config_req,
            early_stopping_type=tuner.early_stopping_type,
        )

        # An s3_input that specifies its own InputMode takes precedence over the estimator's.
        if isinstance(inputs, s3_input) and "InputMode" in inputs.config:
            logging.debug(
                "Selecting s3_input's input_mode (%s) for TrainingInputMode.",
                inputs.config["InputMode"],
            )
            tuner_args["input_mode"] = inputs.config["InputMode"]

        # Marketplace algorithms are referenced by ARN; everything else by training image.
        if isinstance(estimator, sagemaker.algorithm.AlgorithmEstimator):
            tuner_args["algorithm_arn"] = estimator.algorithm_arn
        else:
            tuner_args["image"] = estimator.train_image()

        tuner_args["enable_network_isolation"] = estimator.enable_network_isolation()
        tuner_args["encrypt_inter_container_traffic"] = estimator.encrypt_inter_container_traffic

        estimator.sagemaker_session.tune(**tuner_args)

        return cls(tuner.sagemaker_session, tuner._current_job_name)
예제 #6
0
def test_load_config(estimator):
    """_load_config maps a single s3_input and the estimator's settings into the job config."""
    inputs = s3_input(BUCKET_NAME)
    config = _Job._load_config(inputs, estimator)

    channel = config['input_config'][0]
    assert channel['DataSource']['S3DataSource']['S3Uri'] == BUCKET_NAME
    assert config['role'] == ROLE

    output = config['output_config']
    assert output['S3OutputPath'] == S3_OUTPUT_PATH
    assert 'KmsKeyId' not in output

    resources = config['resource_config']
    assert resources['InstanceCount'] == INSTANCE_COUNT
    assert resources['InstanceType'] == INSTANCE_TYPE
    assert resources['VolumeSizeInGB'] == VOLUME_SIZE

    assert config['stop_condition']['MaxRuntimeInSeconds'] == MAX_RUNTIME
예제 #7
0
def test_load_config(estimator):
    """_load_config builds input/output/resource/stop settings from an s3_input and the estimator."""
    inputs = s3_input(BUCKET_NAME)
    config = _Job._load_config(inputs, estimator)

    data_source = config["input_config"][0]["DataSource"]
    assert data_source["S3DataSource"]["S3Uri"] == BUCKET_NAME
    assert config["role"] == ROLE

    output_config = config["output_config"]
    assert output_config["S3OutputPath"] == S3_OUTPUT_PATH
    assert "KmsKeyId" not in output_config

    resource_config = config["resource_config"]
    assert resource_config["InstanceCount"] == INSTANCE_COUNT
    assert resource_config["InstanceType"] == INSTANCE_TYPE
    assert resource_config["VolumeSizeInGB"] == VOLUME_SIZE

    stop_condition = config["stop_condition"]
    assert stop_condition["MaxRuntimeInSeconds"] == MAX_RUNTIME
예제 #8
0
def test_load_config_with_code_channel_no_code_uri(framework):
    """A framework with a model channel and network isolation but no code URI yields two channels."""
    inputs = s3_input(BUCKET_NAME)

    framework.model_uri = MODEL_URI
    framework.model_channel_name = MODEL_CHANNEL_NAME
    framework._enable_network_isolation = True
    config = _Job._load_config(inputs, framework)

    channels = config["input_config"]
    assert len(channels) == 2
    assert channels[0]["DataSource"]["S3DataSource"]["S3Uri"] == BUCKET_NAME
    assert config["role"] == ROLE

    output_config = config["output_config"]
    assert output_config["S3OutputPath"] == S3_OUTPUT_PATH
    assert "KmsKeyId" not in output_config

    resource_config = config["resource_config"]
    assert resource_config["InstanceCount"] == INSTANCE_COUNT
    assert resource_config["InstanceType"] == INSTANCE_TYPE
예제 #9
0
def test_load_config_with_model_channel_no_inputs(estimator):
    """Without data inputs, the model channel becomes the first (and only checked) input channel."""
    estimator.model_uri = MODEL_URI
    estimator.model_channel_name = MODEL_CHANNEL_NAME

    config = _Job._load_config(inputs=None, estimator=estimator)

    model_channel = config["input_config"][0]
    assert model_channel["DataSource"]["S3DataSource"]["S3Uri"] == MODEL_URI
    assert model_channel["ChannelName"] == MODEL_CHANNEL_NAME
    assert config["role"] == ROLE

    output_config = config["output_config"]
    assert output_config["S3OutputPath"] == S3_OUTPUT_PATH
    assert "KmsKeyId" not in output_config

    resource_config = config["resource_config"]
    assert resource_config["InstanceCount"] == INSTANCE_COUNT
    assert resource_config["InstanceType"] == INSTANCE_TYPE
    assert resource_config["VolumeSizeInGB"] == VOLUME_SIZE

    assert config["stop_condition"]["MaxRuntimeInSeconds"] == MAX_RUNTIME
예제 #10
0
def test_load_config_with_model_channel_no_inputs(estimator):
    """When inputs is None, _load_config still emits the model channel as input_config[0]."""
    estimator.model_uri = MODEL_URI
    estimator.model_channel_name = CHANNEL_NAME

    config = _Job._load_config(inputs=None, estimator=estimator)

    first_channel = config['input_config'][0]
    assert first_channel['DataSource']['S3DataSource']['S3Uri'] == MODEL_URI
    assert first_channel['ChannelName'] == CHANNEL_NAME
    assert config['role'] == ROLE

    output_settings = config['output_config']
    assert output_settings['S3OutputPath'] == S3_OUTPUT_PATH
    assert 'KmsKeyId' not in output_settings

    resource_settings = config['resource_config']
    assert resource_settings['InstanceCount'] == INSTANCE_COUNT
    assert resource_settings['InstanceType'] == INSTANCE_TYPE
    assert resource_settings['VolumeSizeInGB'] == VOLUME_SIZE

    assert config['stop_condition']['MaxRuntimeInSeconds'] == MAX_RUNTIME
예제 #11
0
    def start_new(cls, tuner, inputs):
        """Start an Amazon SageMaker hyperparameter tuning job from a HyperparameterTuner.

        The arguments passed to ``Session.tune`` are the resolved job config
        plus the tuner's settings and either the estimator's algorithm ARN or
        its training image.

        Args:
            tuner (sagemaker.tuner.HyperparameterTuner): Tuner object created by the user.
            inputs (str): Parameters used when calling :meth:`~sagemaker.estimator.EstimatorBase.fit`.

        Returns:
            sagemaker.tuner._TuningJob: Object capturing the session and name of the started job.
        """
        estimator = tuner.estimator
        config = _Job._load_config(inputs, estimator)

        if tuner.warm_start_config:
            warm_start_config_req = tuner.warm_start_config.to_input_req()
        else:
            warm_start_config_req = None

        tuner_args = config.copy()
        tuner_args.update({
            'job_name': tuner._current_job_name,
            'strategy': tuner.strategy,
            'objective_type': tuner.objective_type,
            'objective_metric_name': tuner.objective_metric_name,
            'max_jobs': tuner.max_jobs,
            'max_parallel_jobs': tuner.max_parallel_jobs,
            'parameter_ranges': tuner.hyperparameter_ranges(),
            'static_hyperparameters': tuner.static_hyperparameters,
            'input_mode': estimator.input_mode,
            'metric_definitions': tuner.metric_definitions,
            'tags': tuner.tags,
            'warm_start_config': warm_start_config_req,
            'early_stopping_type': tuner.early_stopping_type,
        })

        # Algorithm estimators are identified by ARN; others by their training image.
        if isinstance(estimator, sagemaker.algorithm.AlgorithmEstimator):
            tuner_args['algorithm_arn'] = estimator.algorithm_arn
        else:
            tuner_args['image'] = estimator.train_image()

        tuner_args['enable_network_isolation'] = estimator.enable_network_isolation()
        tuner_args['encrypt_inter_container_traffic'] = estimator.encrypt_inter_container_traffic

        estimator.sagemaker_session.tune(**tuner_args)

        return cls(tuner.sagemaker_session, tuner._current_job_name)