Ejemplo n.º 1
0
 def test_PlatformError_ValueError(self):
     """Expect a PlatformError built with a cause to mention the cause's type in its message."""
     cause = ValueError("abc")
     error = exc.PlatformError("Test 123", caused_by=cause)
     expected_message = "Test 123 (caused by ValueError)"
     self.assertEqual(error.message, expected_message)
Ejemplo n.º 2
0
def sagemaker_train(train_config, data_config, train_path, val_path, model_dir,
                    sm_hosts, sm_current_host, checkpoint_config):
    """Train XGBoost in a SageMaker training environment.

    Validate hyperparameters and data channels using the SageMaker Algorithm Toolkit so the job
    fails fast on bad input, load the training/validation data matrices, then run ``train_job``.
    On a single host, ``train_job`` is called directly. With more than one host, it is run
    through ``distributed.rabit_run``; a host without training data is excluded from training.

    :param train_config: raw hyperparameters, validated with ``hpv``.
    :param data_config: raw channel configuration, validated with ``cv``. The validated result
        must contain a ``'train'`` channel (its ``ContentType`` and ``TrainingInputMode`` are read).
    :param train_path: location of the training channel data.
    :param val_path: location of the validation channel data.
    :param model_dir: forwarded to ``train_job`` as ``model_dir``.
    :param sm_hosts: host names of the training cluster. Its length decides between
        single-node and distributed training.
    :param sm_current_host: name of the host this process is running on.
    :param checkpoint_config: mapping whose optional ``"LocalPath"`` entry is forwarded to
        ``train_job`` as ``checkpoint_dir``. ``None`` is treated like an empty mapping.
    :raises exc.UserError: on a single host, when the training channel has no data, or when a
        validation channel is configured but yields no data.
    :raises exc.PlatformError: when ``sm_hosts`` is empty.
    """
    metrics = metrics_mod.initialize()

    hyperparameters = hpv.initialize(metrics)
    validated_train_config = hyperparameters.validate(train_config)
    if validated_train_config.get("updater"):
        # Collapse the updater entries into a single comma-separated string before training.
        validated_train_config["updater"] = ",".join(
            validated_train_config["updater"])

    channels = cv.initialize()
    validated_data_config = channels.validate(data_config)

    # Lazy %-style args: the config dicts are only stringified if DEBUG is enabled.
    logging.debug("hyperparameters %s", validated_train_config)
    logging.debug("channels %s", validated_data_config)

    # Get Training and Validation Data Matrices
    train_channel = validated_data_config['train']
    file_type = get_content_type(train_channel.get("ContentType"))
    input_mode = train_channel.get("TrainingInputMode")
    csv_weights = validated_train_config.get("csv_weights", 0)
    is_pipe = (input_mode == Channel.PIPE_MODE)

    validation_channel = validated_data_config.get('validation')
    train_dmatrix, val_dmatrix = get_validated_dmatrices(
        train_path, val_path, file_type, csv_weights, is_pipe)

    # Tolerate a missing checkpoint configuration instead of failing with AttributeError.
    checkpoint_dir = (checkpoint_config or {}).get("LocalPath")

    train_args = dict(train_cfg=validated_train_config,
                      train_dmatrix=train_dmatrix,
                      val_dmatrix=val_dmatrix,
                      model_dir=model_dir,
                      checkpoint_dir=checkpoint_dir)

    # Obtain information about training resources to determine whether to set up Rabit or not
    num_hosts = len(sm_hosts)

    if num_hosts > 1:
        logging.info("Distributed node training with %s hosts: %s", num_hosts, sm_hosts)
        # Wait for hosts to find each other
        distributed.wait_hostname_resolution(sm_hosts)

        if not train_dmatrix:
            logging.warning(
                "Host %s does not have data. Will broadcast to cluster and will not be used in distributed"
                " training.", sm_current_host)
        distributed.rabit_run(exec_fun=train_job,
                              args=train_args,
                              include_in_training=(train_dmatrix is not None),
                              hosts=sm_hosts,
                              current_host=sm_current_host,
                              update_rabit_args=True)
    elif num_hosts == 1:
        # Guard clauses: training data is checked first, then the optional validation data.
        if not train_dmatrix:
            raise exc.UserError(
                "No data in training channel path {}".format(train_path))
        if validation_channel and not val_dmatrix:
            raise exc.UserError(
                "No data in validation channel path {}".format(val_path))
        logging.info("Single node training.")
        train_args['is_master'] = True
        train_job(**train_args)
    else:
        raise exc.PlatformError(
            "Number of hosts should be an int greater than or equal to 1")
Ejemplo n.º 3
0
 def test_PlatformError(self):
     """Expect a PlatformError built without a cause to keep its message unchanged."""
     error = exc.PlatformError("Test 123")
     expected_message = "Test 123"
     self.assertEqual(error.message, expected_message)