Example 1: Experiment.__init__ (Ray Tune)
    def __init__(self,
                 name,
                 run,
                 stop=None,
                 config=None,
                 resources_per_trial=None,
                 num_samples=1,
                 local_dir=None,
                 upload_dir=None,
                 trial_name_creator=None,
                 loggers=None,
                 log_to_file=False,
                 sync_to_driver=None,
                 checkpoint_freq=0,
                 checkpoint_at_end=False,
                 sync_on_checkpoint=True,
                 keep_checkpoints_num=None,
                 checkpoint_score_attr=None,
                 export_formats=None,
                 max_failures=0,
                 restore=None):

        config = config or {}
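        # A "checkpointable" run function is one whose signature accepts a
        # `checkpoint_dir` argument, e.g. `def run(config, checkpoint_dir=None)`;
        # such functions manage checkpointing themselves, so the interval-based
        # options below (`checkpoint_freq`, `checkpoint_at_end`) are rejected.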
        if callable(run) and detect_checkpoint_function(run):
            if checkpoint_at_end:
                raise ValueError("'checkpoint_at_end' cannot be used with a "
                                 "checkpointable function. You can specify "
                                 "and register checkpoints within "
                                 "your trainable function.")
            if checkpoint_freq:
                raise ValueError(
                    "'checkpoint_freq' cannot be used with a "
                    "checkpointable function. You can specify checkpoints "
                    "within your trainable function.")
        self._run_identifier = Experiment.register_if_needed(run)
        self.name = name or self._run_identifier
        if upload_dir:
            self.remote_checkpoint_dir = os.path.join(upload_dir, self.name)
        else:
            self.remote_checkpoint_dir = None

        # Normalize the stop criteria: `stop` may be a dict of metric
        # thresholds, a callable `stop(trial_id, result) -> bool`, or a
        # `ray.tune.Stopper` instance.
        self._stopper = None
        stopping_criteria = {}
        if not stop:
            pass
        elif isinstance(stop, dict):
            stopping_criteria = stop
        elif callable(stop):
            if FunctionStopper.is_valid_function(stop):
                self._stopper = FunctionStopper(stop)
            elif issubclass(type(stop), Stopper):
                self._stopper = stop
            else:
                raise ValueError("Provided stop object must be either a dict, "
                                 "a function, or a subclass of "
                                 "`ray.tune.Stopper`.")
        else:
            raise ValueError("Invalid stop criteria: {}. Must be a "
                             "callable or dict".format(stop))

        _raise_on_durable(self._run_identifier, sync_to_driver, upload_dir)

        # `log_to_file` accepts a bool, a single file name, or a
        # (stdout, stderr) pair; validation normalizes it to two file names.
        stdout_file, stderr_file = _validate_log_to_file(log_to_file)

        spec = {
            "run": self._run_identifier,
            "stop": stopping_criteria,
            "config": config,
            "resources_per_trial": resources_per_trial,
            "num_samples": num_samples,
            "local_dir": os.path.abspath(
                os.path.expanduser(local_dir or DEFAULT_RESULTS_DIR)),
            "upload_dir": upload_dir,
            "remote_checkpoint_dir": self.remote_checkpoint_dir,
            "trial_name_creator": trial_name_creator,
            "loggers": loggers,
            "log_to_file": (stdout_file, stderr_file),
            "sync_to_driver": sync_to_driver,
            "checkpoint_freq": checkpoint_freq,
            "checkpoint_at_end": checkpoint_at_end,
            "sync_on_checkpoint": sync_on_checkpoint,
            "keep_checkpoints_num": keep_checkpoints_num,
            "checkpoint_score_attr": checkpoint_score_attr,
            "export_formats": export_formats or [],
            "max_failures": max_failures,
            "restore": (os.path.abspath(os.path.expanduser(restore))
                        if restore else None),
        }
        self.spec = spec
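
A minimal usage sketch for the constructor above (hedged: it assumes this is Ray Tune's `Experiment` class and that `tune.report` and `tune.run_experiments` from the same Ray version are available; `my_trainable` and its config values are hypothetical):

from ray import tune
from ray.tune import Experiment


def my_trainable(config):
    # Hypothetical trainable: report a metric each iteration so the
    # dict-form stop criterion below can trigger.
    for step in range(100):
        tune.report(score=config["lr"] * step)


exp = Experiment(
    name="example_experiment",
    run=my_trainable,
    stop={"training_iteration": 10},  # dict form of the stop criteria
    config={"lr": 0.01},
    num_samples=2,
)
tune.run_experiments(exp)
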
Example 2: DistributedTrainableCreator (Ray Tune, distributed PyTorch)
def DistributedTrainableCreator(func,
                                use_gpu=False,
                                num_workers=1,
                                num_cpus_per_worker=1,
                                backend="gloo",
                                timeout_s=NCCL_TIMEOUT_S):
    """Creates a class that executes distributed training.

    Similar to running `torch.distributed.launch`.

    Note that you typically should not instantiate the object
    created.

    Args:
        func (callable): This function is a Tune trainable function.
            The function must accept two arguments, and the second
            argument must be named `checkpoint_dir`. For example:
            `func(config, checkpoint_dir=None)`.
        use_gpu (bool): Sets resource allocation for workers to 1 GPU
            if true. Also automatically sets CUDA_VISIBLE_DEVICES
            for each training worker.
        num_workers (int): Number of training workers to include in
            world.
        num_cpus_per_worker (int): Number of CPU resources to reserve
            per training worker.
        backend (str): One of "gloo", "nccl".
        timeout_s (float): Seconds before the torch process group
            times out. Useful when machines are unreliable. Defaults
            to `NCCL_TIMEOUT_S`.

    Returns:
        A trainable class object that can be passed to Tune. Resources
            are automatically set within the object, so users do
            not need to set `resources_per_trial`.

    Example:

    .. code-block:: python

        trainable_cls = DistributedTrainableCreator(
            train_func, num_workers=2)
        analysis = tune.run(trainable_cls)
    """
    detect_checkpoint_function(func, abort=True)

    class WrappedDistributedTorchTrainable(_TorchTrainable):
        _function = func
        _num_workers = num_workers
        _use_gpu = use_gpu
        _num_cpus_per_worker = num_cpus_per_worker

        @classmethod
        def default_process_group_parameters(cls):
            # Options for initializing the torch.distributed process group.
            # `timeout_s` is in seconds, so pass it to timedelta explicitly.
            return dict(timeout=timedelta(seconds=timeout_s), backend=backend)

        @classmethod
        def default_resource_request(cls, config):
            num_workers_ = int(config.get("num_workers", num_workers))
            num_cpus = int(
                config.get("num_cpus_per_worker", num_cpus_per_worker))
            use_gpu_ = config.get("use_gpu", use_gpu)

            # The wrapped trainable itself requests no CPU/GPU; the worker
            # resources are reserved as "extra" for the remote worker actors
            # it launches.
            return Resources(cpu=0,
                             gpu=0,
                             extra_cpu=num_cpus * num_workers_,
                             extra_gpu=num_workers_ if use_gpu_ else 0)

    return WrappedDistributedTorchTrainable
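
A minimal usage sketch for the creator above (hedged: the import path `ray.tune.integration.torch` and the `tune.report` API are assumptions tied to the Ray 1.x layout this snippet appears to come from; `train_func` is a hypothetical training function whose signature satisfies `detect_checkpoint_function`):

import torch

from ray import tune
from ray.tune.integration.torch import DistributedTrainableCreator


def train_func(config, checkpoint_dir=None):
    # Hypothetical trainable: the second `checkpoint_dir` argument is required
    # for `detect_checkpoint_function(func, abort=True)` to accept the function.
    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=config["lr"])
    for _ in range(10):
        loss = model(torch.randn(8, 4)).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        tune.report(loss=loss.item())


trainable_cls = DistributedTrainableCreator(
    train_func, num_workers=2, num_cpus_per_worker=1, backend="gloo")
analysis = tune.run(trainable_cls, config={"lr": 0.01})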