def test_create_algorithm(self):
        # Integration test: (re)creates the algorithm in SageMaker and then
        # launches a small random-search tuning job that uses it. The four
        # TEST_* environment variables below must all be set.
        required_vars = (
            "TEST_IMAGE_URI",
            "TEST_ALGORITHM_NAME",
            "TEST_ROLE_ARN",
            "TEST_OUTPUT_PATH",
        )
        settings = [os.getenv(var) for var in required_vars]
        for var, value in zip(required_vars, settings):
            if value is None:
                self.fail("Set {} environment variable.".format(var))
        image_uri, algorithm_name, role_arn, output_path = settings

        metrics = metrics_mod.initialize()
        hyperparameters = hpv.initialize(metrics)
        channels = cv.initialize()
        md = metadata.initialize(image_uri, hyperparameters, channels, metrics)

        client = boto3.client("sagemaker", region_name="us-west-2")

        # Best-effort removal of an algorithm with the same name; any error
        # is printed and otherwise ignored so creation is still attempted.
        try:
            client.delete_algorithm(AlgorithmName=algorithm_name)
        except Exception as err:
            print(err)

        pprint.pprint(md)
        client.create_algorithm(AlgorithmName=algorithm_name, **md)

        objective = metrics["validation:error"]
        job_name = "test-hpo-" + datetime.now().strftime("%Y%m%d-%H%M%S")

        tuning_config = {
            "Strategy": "Random",
            "ResourceLimits": {
                "MaxNumberOfTrainingJobs": 6,
                "MaxParallelTrainingJobs": 2,
            },
            "HyperParameterTuningJobObjective": objective.format_tunable(),
            "ParameterRanges": hyperparameters["alpha"].format_tunable_range(),
        }
        training_definition = {
            "AlgorithmSpecification": {
                "AlgorithmName": algorithm_name,
                "TrainingInputMode": "File",
            },
            "StaticHyperParameters": {"num_round": "3"},
            "RoleArn": role_arn,
            "OutputDataConfig": {"S3OutputPath": output_path},
            "ResourceConfig": {
                "InstanceType": "ml.m5.xlarge",
                "InstanceCount": 1,
                "VolumeSizeInGB": 5,
            },
            "StoppingCondition": {"MaxRuntimeInSeconds": 300},
        }

        client.create_hyper_parameter_tuning_job(
            HyperParameterTuningJobName=job_name,
            HyperParameterTuningJobConfig=tuning_config,
            TrainingJobDefinition=training_definition,
        )
Exemplo n.º 2
0
 def test_hyperparameters2(self):
     # Nothing is asserted on the result: the test passes as long as
     # validate() accepts this AUC / binary:logistic set without raising.
     validator = hpv.initialize(self.metrics)
     validator.validate(dict(
         eval_metric="auc",
         objective="binary:logistic",
         num_round="100",
         rate_drop="0.3",
         tweedie_variance_power="1.4",
     ))
Exemplo n.º 3
0
 def test_hyperparameters4(self):
     # Nothing is asserted on the result: the test passes as long as
     # validate() accepts this multi:softmax set (with num_class) without
     # raising.
     validator = hpv.initialize(self.metrics)
     validator.validate(dict(
         max_depth="5",
         eta="0.2",
         gamma="4",
         min_child_weight="6",
         objective="multi:softmax",
         num_class="10",
         num_round="10",
     ))
Exemplo n.º 4
0
 def test_hyperparameters3(self):
     # Nothing is asserted on the result: the test passes as long as
     # validate() accepts this reg:squarederror set without raising.
     validator = hpv.initialize(self.metrics)
     validator.validate(dict(
         max_depth="5",
         eta="0.2",
         gamma="4",
         min_child_weight="6",
         subsample="0.7",
         objective="reg:squarederror",
         num_round="50",
     ))
Exemplo n.º 5
0
 def test_hyperparameters(self):
     # Nothing is asserted on the result: the test passes as long as
     # validate() accepts this binary:logistic set without raising.
     validator = hpv.initialize(self.metrics)
     validator.validate(dict(
         max_depth="5",
         eta="0.2",
         gamma="4",
         min_child_weight="6",
         subsample="0.8",
         objective="binary:logistic",
         num_round="100",
     ))
Exemplo n.º 6
0
 def test_hyperparameters8(self):
     # Nothing is asserted on the result: the test passes as long as
     # validate() accepts this set, which includes tree_method and
     # interaction_constraints, without raising.
     validator = hpv.initialize(self.metrics)
     validator.validate(dict(
         max_depth="5",
         eta="0.2",
         min_split_loss="4",
         min_child_weight="6",
         tree_method="approx",
         objective="multi:softmax",
         num_class="10",
         num_round="10",
         interaction_constraints="[[1,2,4],[3,5]]",
     ))
Exemplo n.º 7
0
def sagemaker_train(train_config, data_config, train_path, val_path, model_dir,
                    sm_hosts, sm_current_host, checkpoint_config):
    """Validate the job configuration and launch XGBoost training on SageMaker.

    Hyperparameters and data channels are validated with the SageMaker
    Algorithm Toolkit before any data is loaded. With exactly one host the
    training job is run directly; with several hosts it is dispatched through
    ``distributed.rabit_run``, and a host that has no training data is flagged
    as not included in training.

    :param train_config: hyperparameters to validate; the validated result is
        handed to the training job
    :param data_config: channel configuration to validate; the validated
        result must contain a 'train' entry and may contain 'validation'
    :param train_path: training data location, passed to
        get_validated_dmatrices
    :param val_path: validation data location, passed to
        get_validated_dmatrices
    :param model_dir: passed through to the training job
    :param sm_hosts: sized collection of the hosts taking part in the job
    :param sm_current_host: the host this process runs on
    :param checkpoint_config: mapping whose optional "LocalPath" entry is
        passed to the training job as the checkpoint directory
    :raises exc.UserError: on a single host, when there is no training data,
        or a validation channel is configured but has no data
    :raises exc.PlatformError: when sm_hosts is empty
    """
    metrics = metrics_mod.initialize()

    hp_config = hpv.initialize(metrics).validate(train_config)
    if hp_config.get("updater"):
        # The validated updater value is a sequence; join it into one
        # comma-separated string.
        hp_config["updater"] = ",".join(hp_config["updater"])

    channel_config = cv.initialize().validate(data_config)

    logging.debug("hyperparameters {}".format(hp_config))
    logging.debug("channels {}".format(channel_config))

    # Build the training and validation data matrices.
    train_channel = channel_config['train']
    file_type = get_content_type(train_channel.get("ContentType"))
    is_pipe = train_channel.get("TrainingInputMode") == Channel.PIPE_MODE
    csv_weights = hp_config.get("csv_weights", 0)

    validation_channel = channel_config.get('validation', None)
    train_dmatrix, val_dmatrix = get_validated_dmatrices(
        train_path, val_path, file_type, csv_weights, is_pipe)

    train_args = {
        'train_cfg': hp_config,
        'train_dmatrix': train_dmatrix,
        'val_dmatrix': val_dmatrix,
        'model_dir': model_dir,
        'checkpoint_dir': checkpoint_config.get("LocalPath", None),
    }

    # The host count decides whether Rabit has to be set up.
    num_hosts = len(sm_hosts)

    if num_hosts < 1:
        raise exc.PlatformError(
            "Number of hosts should be an int greater than or equal to 1")

    if num_hosts > 1:
        # Wait for hosts to find each other before starting Rabit.
        logging.info("Distributed node training with {} hosts: {}".format(
            num_hosts, sm_hosts))
        distributed.wait_hostname_resolution(sm_hosts)

        if not train_dmatrix:
            logging.warning(
                "Host {} does not have data. Will broadcast to cluster and will not be used in distributed"
                " training.".format(sm_current_host))
        distributed.rabit_run(exec_fun=train_job,
                              args=train_args,
                              include_in_training=(train_dmatrix is not None),
                              hosts=sm_hosts,
                              current_host=sm_current_host,
                              update_rabit_args=True)
        return

    # Exactly one host from here on.
    if not train_dmatrix:
        raise exc.UserError(
            "No data in training channel path {}".format(train_path))
    if validation_channel and not val_dmatrix:
        raise exc.UserError(
            "No data in validation channel path {}".format(val_path))

    logging.info("Single node training.")
    train_args['is_master'] = True
    train_job(**train_args)
Exemplo n.º 8
0
#     http://aws.amazon.com/apache2.0/
#
# or in the 'license' file accompanying this file. This file is
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import unittest

from sagemaker_algorithm_toolkit import exceptions as exc

from sagemaker_xgboost_container.algorithm_mode import hyperparameter_validation as hpv
from sagemaker_xgboost_container.algorithm_mode import metrics as metrics_mod

# Built once, when this module is imported.
# NOTE(review): presumably shared by the test cases below -- confirm against
# their bodies.
metrics = metrics_mod.initialize()
hyperparameters = hpv.initialize(metrics)


class TestHyperparameterValidation(unittest.TestCase):
    def test_auc_invalid_objective(self):
        test_hp = {'eval_metric': 'auc'}

        auc_invalid_objectives = [
            'count:poisson', 'reg:gamma', 'reg:logistic', 'reg:squarederror',
            'reg:tweedie', 'multi:softmax', 'multi:softprob', 'survival:cox'
        ]

        for invalid_objective in auc_invalid_objectives:
            test_hp['objective'] = invalid_objective

            with self.assertRaises(exc.UserError):