class RayTrialExecutor(TrialExecutor):
    """An implementation of TrialExecutor based on Ray."""
    def __init__(self,
                 queue_trials=False,
                 reuse_actors=False,
                 ray_auto_init=None,
                 refresh_period=None):
        if ray_auto_init is None:
            if os.environ.get("TUNE_DISABLE_AUTO_INIT") == "1":
                logger.info("'TUNE_DISABLE_AUTO_INIT=1' detected.")
                ray_auto_init = False
            else:
                ray_auto_init = True

        super(RayTrialExecutor, self).__init__(queue_trials)
        # Track whether we are launching a trial without resources in order
        # to kick off the autoscaler.
        self._trial_queued = False
        self._running = {}
        # A trial that resumes after being paused should not call
        # trial.train.remote() again, since that would generate a new remote
        # object ref. We use self._paused to store paused trials (and their
        # in-flight futures) here.
        self._paused = {}

        self._trial_cleanup = _TrialCleanup()
        self._reuse_actors = reuse_actors
        self._cached_actor = None

        self._avail_resources = Resources(cpu=0, gpu=0)
        self._committed_resources = Resources(cpu=0, gpu=0)
        self._resources_initialized = False

        if refresh_period is None:
            refresh_period = float(
                os.environ.get("TUNE_STATE_REFRESH_PERIOD",
                               TUNE_STATE_REFRESH_PERIOD))
        self._refresh_period = refresh_period
        self._last_resource_refresh = float("-inf")
        self._last_ip_refresh = float("-inf")
        self._last_ip_addresses = set()
        self._last_nontrivial_wait = time.time()
        if not ray.is_initialized() and ray_auto_init:
            logger.info("Initializing Ray automatically."
                        "For cluster usage or custom Ray initialization, "
                        "call `ray.init(...)` before `tune.run`.")
            ray.init()

        if ray.is_initialized():
            self._update_avail_resources()

    def _setup_remote_runner(self, trial, reuse_allowed):
        trial.init_logdir()
        # We checkpoint metadata here to try mitigating logdir duplication
        self.try_checkpoint_metadata(trial)
        logger_creator = partial(noop_logger_creator, logdir=trial.logdir)

        if (self._reuse_actors and reuse_allowed
                and self._cached_actor is not None):
            logger.debug("Trial %s: Reusing cached runner %s", trial,
                         self._cached_actor)
            existing_runner = self._cached_actor
            self._cached_actor = None
            trial.set_runner(existing_runner)
            if not self.reset_trial(trial, trial.config, trial.experiment_tag,
                                    logger_creator):
                raise AbortTrialExecution(
                    "Trainable runner reuse requires reset_config() to be "
                    "implemented and return True.")
            return existing_runner

        if self._cached_actor:
            logger.debug("Cannot reuse cached runner {} for new trial".format(
                self._cached_actor))
            with self._change_working_directory(trial):
                self._trial_cleanup.add(trial, actor=self._cached_actor)
            self._cached_actor = None

        _actor_cls = _class_cache.get(trial.get_trainable_cls())
        full_actor_class = _actor_cls.options(
            num_cpus=trial.resources.cpu,
            num_gpus=trial.resources.gpu,
            memory=trial.resources.memory or None,
            object_store_memory=trial.resources.object_store_memory or None,
            resources=trial.resources.custom_resources)
        # Clear the Trial's location (to be updated later on result)
        # since we don't know where the remote runner is placed.
        trial.set_location(Location())
        logger.debug("Trial %s: Setting up new remote runner.", trial)
        # Logging for trials is handled centrally by TrialRunner, so
        # configure the remote runner to use a noop-logger.
        trial_config = copy.deepcopy(trial.config)
        trial_config[TRIAL_INFO] = TrialInfo(trial)

        stdout_file, stderr_file = trial.log_to_file
        trial_config[STDOUT_FILE] = stdout_file
        trial_config[STDERR_FILE] = stderr_file
        kwargs = {
            "config": trial_config,
            "logger_creator": logger_creator,
        }
        if issubclass(trial.get_trainable_cls(), DurableTrainable):
            kwargs["remote_checkpoint_dir"] = trial.remote_checkpoint_dir

        with self._change_working_directory(trial):
            return full_actor_class.remote(**kwargs)
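    # Note: `.options()` above binds the trial's resource request to the
    # actor, e.g. Resources(cpu=2, gpu=1) yields a runner that Ray will only
    # schedule on a node with two free CPUs and one free GPU.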

    def _train(self, trial):
        """Start one iteration of training and save remote id."""
        if self._find_item(self._paused, trial):
            raise TuneError(
                "Should not call `train` on PAUSED trial {}. "
                "This is an internal error - please file an issue "
                "on https://github.com/ray-project/ray/issues/.".format(
                    str(trial)))

        if self._find_item(self._running, trial):
            logger.debug(
                "Trial {} already has a queued future. Skipping this "
                "`train` call. This may occur if a trial has "
                "been unpaused within a scheduler callback.".format(
                    str(trial)))
            return

        assert trial.status == Trial.RUNNING, trial.status
        with self._change_working_directory(trial):
            remote = trial.runner.train.remote()

        # Local Mode
        if isinstance(remote, dict):
            remote = _LocalWrapper(remote)

        self._running[remote] = trial
        trial_item = self._find_item(self._running, trial)
        assert len(trial_item) < 2, trial_item
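        # self._running is the executor's core bookkeeping: each in-flight
        # train.remote() future maps to its trial, and get_next_available_trial
        # / fetch_result later pop entries from it as results arrive.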

    def _start_trial(self, trial, checkpoint=None, runner=None, train=True):
        """Starts trial and restores last result if trial was paused.

        Args:
            trial (Trial): The trial to start.
            checkpoint (Optional[Checkpoint]): The checkpoint to restore from.
                If None, and no trial checkpoint exists, the trial is started
                from the beginning.
            runner (Trainable): The remote runner to use. This can be the
                cached actor. If None, a new runner is created.
            train (bool): Whether or not to start training.

        See `RayTrialExecutor.restore` for possible errors raised.
        """
        prior_status = trial.status
        if runner is None:
            # We reuse actors when there is previously instantiated state on
            # the actor. Function API calls are also supported when there is
            # no checkpoint to continue from.
            # TODO: Check preconditions - why is previous state needed?
            reuse_allowed = checkpoint is not None or trial.has_checkpoint() \
                            or issubclass(trial.get_trainable_cls(),
                                          FunctionRunner)
            runner = self._setup_remote_runner(trial, reuse_allowed)
        trial.set_runner(runner)
        self.restore(trial, checkpoint)
        self.set_status(trial, Trial.RUNNING)

        previous_run = self._find_item(self._paused, trial)
        if prior_status == Trial.PAUSED and previous_run:
            # If Trial was in flight when paused, self._paused stores result.
            self._paused.pop(previous_run[0])
            self._running[previous_run[0]] = trial
        elif train and not trial.is_restoring:
            self._train(trial)

    def _stop_trial(self, trial, error=False, error_msg=None):
        """Stops this trial.

        Stops this trial, releasing all allocating resources. If stopping the
        trial fails, the run will be marked as terminated in error, but no
        exception will be thrown.

        Args:
            error (bool): Whether to mark this trial as terminated in error.
            error_msg (str): Optional error message.
        """
        self.set_status(trial, Trial.ERROR if error else Trial.TERMINATED)
        trial.set_location(Location())

        try:
            trial.write_error_log(error_msg)
            if hasattr(trial, "runner") and trial.runner:
                if (not error and self._reuse_actors
                        and self._cached_actor is None):
                    logger.debug("Reusing actor for %s", trial.runner)
                    self._cached_actor = trial.runner
                else:
                    logger.debug("Trial %s: Destroying actor.", trial)
                    with self._change_working_directory(trial):
                        self._trial_cleanup.add(trial, actor=trial.runner)
        except Exception:
            logger.exception("Trial %s: Error stopping runner.", trial)
            self.set_status(trial, Trial.ERROR)
        finally:
            trial.set_runner(None)

    def start_trial(self, trial, checkpoint=None, train=True):
        """Starts the trial.

        Will not return resources if trial repeatedly fails on start.

        Args:
            trial (Trial): Trial to be started.
            checkpoint (Checkpoint): A Python object or path storing the state
                of trial.
            train (bool): Whether or not to start training.
        """
        self._commit_resources(trial.resources)
        try:
            self._start_trial(trial, checkpoint, train=train)
        except AbortTrialExecution:
            logger.exception("Trial %s: Error starting runner, aborting!",
                             trial)
            time.sleep(2)
            error_msg = traceback.format_exc()
            self._stop_trial(trial, error=True, error_msg=error_msg)
        except Exception:
            logger.exception("Trial %s: Unexpected error starting runner.",
                             trial)
            time.sleep(2)
            error_msg = traceback.format_exc()
            self._stop_trial(trial, error=True, error_msg=error_msg)
            # Note that we don't return the resources, since they may
            # have been lost. TODO(ujvl): is this the right thing to do?

    def _find_item(self, dictionary, item):
        out = [rid for rid, t in dictionary.items() if t is item]
        return out

    def stop_trial(self, trial, error=False, error_msg=None):
        """Only returns resources if resources allocated."""
        prior_status = trial.status
        self._stop_trial(trial, error=error, error_msg=error_msg)
        if prior_status == Trial.RUNNING:
            logger.debug("Trial %s: Returning resources.", trial)
            self._return_resources(trial.resources)
            out = self._find_item(self._running, trial)
            for result_id in out:
                self._running.pop(result_id)

    def continue_training(self, trial):
        """Continues the training of this trial."""
        self._train(trial)

    def pause_trial(self, trial):
        """Pauses the trial.

        If trial is in-flight, preserves return value in separate queue
        before pausing, which is restored when Trial is resumed.
        """
        trial_future = self._find_item(self._running, trial)
        if trial_future:
            self._paused[trial_future[0]] = trial
        super(RayTrialExecutor, self).pause_trial(trial)

    def reset_trial(self,
                    trial,
                    new_config,
                    new_experiment_tag,
                    logger_creator=None):
        """Tries to invoke `Trainable.reset()` to reset trial.

        Args:
            trial (Trial): Trial to be reset.
            new_config (dict): New configuration for Trial trainable.
            new_experiment_tag (str): New experiment name for trial.
            logger_creator (Optional[Callable[[Dict], Logger]]): Function
                that instantiates a logger on the actor process.

        Returns:
            True if `reset_config` is successful else False.
        """
        trial.set_experiment_tag(new_experiment_tag)
        trial.set_config(new_config)
        trainable = trial.runner
        with self._change_working_directory(trial):
            with warn_if_slow("reset"):
                try:
                    reset_val = ray.get(trainable.reset.remote(
                        new_config, logger_creator),
                                        timeout=DEFAULT_GET_TIMEOUT)
                except GetTimeoutError:
                    logger.exception("Trial %s: reset timed out.", trial)
                    return False
        return reset_val

    def get_running_trials(self):
        """Returns the running trials."""
        return list(self._running.values())

    def get_alive_node_ips(self):
        now = time.time()
        if now - self._last_ip_refresh < self._refresh_period:
            return self._last_ip_addresses
        logger.debug("Checking ips from Ray state.")
        self._last_ip_refresh = now
        nodes = ray.state.nodes()
        ip_addresses = set()
        for node in nodes:
            if node["alive"]:
                ip_addresses.add(node["NodeManagerAddress"])
        self._last_ip_addresses = ip_addresses
        return ip_addresses
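        # The result is cached for _refresh_period seconds (see
        # TUNE_STATE_REFRESH_PERIOD) so that failure detection does not query
        # ray.state.nodes() on every event-loop step.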

    def get_current_trial_ips(self):
        return {t.node_ip for t in self.get_running_trials()}

    def get_next_failed_trial(self):
        """Gets the first trial found to be running on a node presumed dead.

        Returns:
            A Trial object that is ready for failure processing. None if
            no failure detected.
        """
        if ray.worker._mode() != ray.worker.LOCAL_MODE:
            live_cluster_ips = self.get_alive_node_ips()
            # Fast path: only scan trials if some trial IP is no longer alive.
            if self.get_current_trial_ips() - live_cluster_ips:
                for trial in self.get_running_trials():
                    if trial.node_ip and trial.node_ip not in live_cluster_ips:
                        return trial
        return None

    def get_next_available_trial(self):
        shuffled_results = list(self._running.keys())
        random.shuffle(shuffled_results)
        # Note: We shuffle the results because `ray.wait` by default returns
        # the first available result, and we want to guarantee that slower
        # trials (i.e. trials that run remotely) also get fairly reported.
        # See https://github.com/ray-project/ray/issues/4211 for details.
        start = time.time()
        [result_id], _ = ray.wait(shuffled_results)
        wait_time = time.time() - start
        if wait_time > NONTRIVIAL_WAIT_TIME_THRESHOLD_S:
            self._last_nontrivial_wait = time.time()
        if time.time() - self._last_nontrivial_wait > BOTTLENECK_WARN_PERIOD_S:
            logger.warning(
                "Over the last {} seconds, the Tune event loop has been "
                "backlogged processing new results. Consider increasing your "
                "period of result reporting to improve performance.".format(
                    BOTTLENECK_WARN_PERIOD_S))

            self._last_nontrivial_wait = time.time()
        return self._running[result_id]

    def fetch_result(self, trial):
        """Fetches one result of the running trials.

        Returns:
            Result of the most recent trial training run.
        """
        trial_future = self._find_item(self._running, trial)
        if not trial_future:
            raise ValueError("Trial was not running.")
        self._running.pop(trial_future[0])
        with warn_if_slow("fetch_result"):
            result = ray.get(trial_future[0], timeout=DEFAULT_GET_TIMEOUT)

        # For local mode
        if isinstance(result, _LocalWrapper):
            result = result.unwrap()
        return result

    def _commit_resources(self, resources):
        committed = self._committed_resources
        all_keys = set(resources.custom_resources).union(
            set(committed.custom_resources))

        custom_resources = {
            k: committed.get(k) + resources.get_res_total(k)
            for k in all_keys
        }

        self._committed_resources = Resources(
            committed.cpu + resources.cpu_total(),
            committed.gpu + resources.gpu_total(),
            committed.memory + resources.memory_total(),
            committed.object_store_memory +
            resources.object_store_memory_total(),
            custom_resources=custom_resources)

    def _return_resources(self, resources):
        committed = self._committed_resources

        all_keys = set(resources.custom_resources).union(
            set(committed.custom_resources))

        custom_resources = {
            k: committed.get(k) - resources.get_res_total(k)
            for k in all_keys
        }
        self._committed_resources = Resources(
            committed.cpu - resources.cpu_total(),
            committed.gpu - resources.gpu_total(),
            committed.memory - resources.memory_total(),
            committed.object_store_memory -
            resources.object_store_memory_total(),
            custom_resources=custom_resources)

        assert self._committed_resources.is_nonnegative(), (
            "Resource invalid: {}".format(resources))

    def _update_avail_resources(self, num_retries=5):
        if time.time() - self._last_resource_refresh < self._refresh_period:
            return
        logger.debug("Checking Ray cluster resources.")
        resources = None
        for i in range(num_retries):
            if i > 0:
                logger.warning(
                    "Cluster resources not detected or are 0. Attempt #"
                    "%s...", i + 1)
                time.sleep(0.5)
            try:
                resources = ray.cluster_resources()
            except Exception as exc:
                # TODO(rliaw): Remove this when local mode is fixed.
                # https://github.com/ray-project/ray/issues/4147
                logger.debug(f"{exc}: Using resources for local machine.")
                resources = ResourceSpec().resolve(True).to_resource_dict()
            if resources:
                break

        if not resources:
            # NOTE: This hides the possibility that Ray may be waiting for
            # clients to connect.
            resources.setdefault("CPU", 0)
            resources.setdefault("GPU", 0)
            logger.warning("Cluster resources cannot be detected or are 0. "
                           "You can resume this experiment by passing in "
                           "`resume=True` to `run`.")

        resources = resources.copy()
        num_cpus = resources.pop("CPU", 0)
        num_gpus = resources.pop("GPU", 0)
        memory = ray_constants.from_memory_units(resources.pop("memory", 0))
        object_store_memory = ray_constants.from_memory_units(
            resources.pop("object_store_memory", 0))
        custom_resources = resources

        self._avail_resources = Resources(
            int(num_cpus),
            int(num_gpus),
            memory=int(memory),
            object_store_memory=int(object_store_memory),
            custom_resources=custom_resources)
        self._last_resource_refresh = time.time()
        self._resources_initialized = True

    def has_resources(self, resources):
        """Returns whether this runner has at least the specified resources.

        This refreshes the Ray cluster resources if the time since last update
        has exceeded self._refresh_period. This also assumes that the
        cluster is not resizing very frequently.
        """
        self._update_avail_resources()
        currently_available = Resources.subtract(self._avail_resources,
                                                 self._committed_resources)

        have_space = (
            resources.cpu_total() <= currently_available.cpu
            and resources.gpu_total() <= currently_available.gpu
            and resources.memory_total() <= currently_available.memory
            and resources.object_store_memory_total() <=
            currently_available.object_store_memory and all(
                resources.get_res_total(res) <= currently_available.get(res)
                for res in resources.custom_resources))

        if have_space:
            # The assumption right now is that we block all trials if one
            # trial is queued.
            self._trial_queued = False
            return True

        can_overcommit = self._queue_trials and not self._trial_queued
        if can_overcommit:
            self._trial_queued = True
            logger.warning(
                "Allowing trial to start even though the "
                "cluster does not have enough free resources. Trial actors "
                "may appear to hang until enough resources are added to the "
                "cluster (e.g., via autoscaling). You can disable this "
                "behavior by specifying `queue_trials=False` in "
                "ray.tune.run().")
            return True

        return False
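        # Worked example: with 4 CPUs available and 2 already committed, a
        # trial requesting Resources(cpu=2) fits (2 <= 4 - 2); one requesting
        # Resources(cpu=4) only starts if queue_trials=True, and then sets
        # _trial_queued so that no further trial overcommits.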

    def debug_string(self):
        """Returns a human readable message for printing to the console."""
        if self._resources_initialized:
            status = ("Resources requested: {}/{} CPUs, {}/{} GPUs, "
                      "{}/{} GiB heap, {}/{} GiB objects".format(
                          self._committed_resources.cpu,
                          self._avail_resources.cpu,
                          self._committed_resources.gpu,
                          self._avail_resources.gpu,
                          _to_gb(self._committed_resources.memory),
                          _to_gb(self._avail_resources.memory),
                          _to_gb(
                              self._committed_resources.object_store_memory),
                          _to_gb(self._avail_resources.object_store_memory)))
            customs = ", ".join([
                "{}/{} {}".format(
                    self._committed_resources.get_res_total(name),
                    self._avail_resources.get_res_total(name), name)
                for name in self._avail_resources.custom_resources
                if not name.startswith(ray.resource_spec.NODE_ID_PREFIX)
            ])
            if customs:
                status += " ({})".format(customs)
            return status
        else:
            return "Resources requested: ?"

    def resource_string(self):
        """Returns a string describing the total resources available."""
        if self._resources_initialized:
            res_str = ("{} CPUs, {} GPUs, "
                       "{} GiB heap, {} GiB objects".format(
                           self._avail_resources.cpu,
                           self._avail_resources.gpu,
                           _to_gb(self._avail_resources.memory),
                           _to_gb(self._avail_resources.object_store_memory)))
            if self._avail_resources.custom_resources:
                custom = ", ".join(
                    "{} {}".format(self._avail_resources.get_res_total(name),
                                   name)
                    for name in self._avail_resources.custom_resources)
                res_str += " ({})".format(custom)
            return res_str
        else:
            return "? CPUs, ? GPUs"

    def on_step_begin(self, trial_runner):
        """Before step() called, update the available resources."""
        self._update_avail_resources()

    def save(self, trial, storage=Checkpoint.PERSISTENT, result=None):
        """Saves the trial's state to a checkpoint asynchronously.

        Args:
            trial (Trial): The trial to be saved.
            storage (str): Where to store the checkpoint. Defaults to
                PERSISTENT.
            result (dict): The state of this trial as a dictionary to be saved.
                If result is None, the trial's last result will be used.

        Returns:
             Checkpoint object, or None if an Exception occurs.
        """
        result = result or trial.last_result
        with self._change_working_directory(trial):
            if storage == Checkpoint.MEMORY:
                value = trial.runner.save_to_object.remote()
                checkpoint = Checkpoint(storage, value, result)
                trial.on_checkpoint(checkpoint)
            else:
                value = trial.runner.save.remote()
                checkpoint = Checkpoint(storage, value, result)
                trial.saving_to = checkpoint
                self._running[value] = trial
        return checkpoint
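        # Only the PERSISTENT path is tracked in self._running: the pending
        # save.remote() future is processed like a training result, whereas
        # MEMORY checkpoints are handed to trial.on_checkpoint() immediately.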

    def restore(self, trial, checkpoint=None, block=False):
        """Restores training state from a given model checkpoint.

        Args:
            trial (Trial): The trial to be restored.
            checkpoint (Checkpoint): The checkpoint to restore from. If None,
                the most recent PERSISTENT checkpoint is used. Defaults to
                None.
            block (bool): Whether or not to block on restore before returning.

        Raises:
            RuntimeError: This error is raised if no runner is found.
            AbortTrialExecution: This error is raised if the trial is
                ineligible for restoration, given the Tune input arguments.
        """
        if checkpoint is None or checkpoint.value is None:
            checkpoint = trial.checkpoint
        if checkpoint.value is None:
            return
        if trial.runner is None:
            raise RuntimeError(
                "Trial {}: Unable to restore - no runner found.".format(trial))
        value = checkpoint.value
        if checkpoint.storage == Checkpoint.MEMORY:
            logger.debug("Trial %s: Attempting restore from object", trial)
            # Note that we don't store the remote since in-memory checkpoints
            # don't guarantee fault tolerance and don't need to be waited on.
            with self._change_working_directory(trial):
                trial.runner.restore_from_object.remote(value)
        else:
            logger.debug("Trial %s: Attempting restore from %s", trial, value)
            if issubclass(trial.get_trainable_cls(),
                          DurableTrainable) or not trial.sync_on_checkpoint:
                with self._change_working_directory(trial):
                    remote = trial.runner.restore.remote(value)
            elif trial.sync_on_checkpoint:
                # This provides FT backwards compatibility in the
                # case where a DurableTrainable is not provided.
                logger.debug("Trial %s: Reading checkpoint into memory", trial)
                obj = TrainableUtil.checkpoint_to_object(value)
                with self._change_working_directory(trial):
                    remote = trial.runner.restore_from_object.remote(obj)
            else:
                raise AbortTrialExecution(
                    "Pass in `sync_on_checkpoint=True` for driver-based trial "
                    "restoration. Pass in an `upload_dir` and a Trainable "
                    "extending `DurableTrainable` for remote storage-based "
                    "restoration.")

            if block:
                ray.get(remote)
            else:
                self._running[remote] = trial
                trial.restoring_from = checkpoint

    def export_trial_if_needed(self, trial):
        """Exports model of this trial based on trial.export_formats.

        Returns:
            A dict that maps ExportFormats to successfully exported models.
        """
        if trial.export_formats and len(trial.export_formats) > 0:
            with self._change_working_directory(trial):
                return ray.get(trial.runner.export_model.remote(
                    trial.export_formats),
                               timeout=DEFAULT_GET_TIMEOUT)
        return {}

    def has_gpus(self):
        if self._resources_initialized:
            self._update_avail_resources()
            return self._avail_resources.gpu > 0
        return False

    def cleanup(self):
        self._trial_cleanup.cleanup(partial=False)

    @contextmanager
    def _change_working_directory(self, trial):
        """Context manager changing working directory to trial logdir.
        Used in local mode.

        For non-local mode it is no-op.
        """
        if ray.worker._mode() == ray.worker.LOCAL_MODE:
            old_dir = os.getcwd()
            try:
                os.chdir(trial.logdir)
                yield
            finally:
                os.chdir(old_dir)
        else:
            yield
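A minimal usage sketch of how these executor flags are normally reached (hedged: this assumes Ray ~1.0, where `tune.run` constructed a RayTrialExecutor internally and still accepted `queue_trials` and `reuse_actors`; `my_trainable` is a hypothetical function trainable):

import ray
from ray import tune

def my_trainable(config):
    # Each tune.report() call surfaces one result to the Tune event loop.
    for step in range(3):
        tune.report(score=config["lr"] * step)

ray.init()  # or rely on the executor's ray_auto_init path
tune.run(
    my_trainable,
    config={"lr": tune.grid_search([0.01, 0.1])},
    reuse_actors=True,  # exercises the cached-actor path in _setup_remote_runner
    queue_trials=True,  # lets has_resources() overcommit a single queued trial
)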
Example #2
class RayTrialExecutor(TrialExecutor):
    """An implemention of TrialExecutor based on Ray."""
    def __init__(self,
                 queue_trials=False,
                 reuse_actors=False,
                 ray_auto_init=False,
                 refresh_period=RESOURCE_REFRESH_PERIOD):
        super(RayTrialExecutor, self).__init__(queue_trials)
        # Track whether we are launching a trial without resources in order
        # to kick off the autoscaler.
        self._trial_queued = False
        self._running = {}
        # A trial that resumes after being paused should not call
        # trial.train.remote() again, since that would generate a new remote
        # object ID. We use self._paused to store paused trials (and their
        # in-flight futures) here.
        self._paused = {}
        self._reuse_actors = reuse_actors
        self._cached_actor = None

        self._avail_resources = Resources(cpu=0, gpu=0)
        self._committed_resources = Resources(cpu=0, gpu=0)
        self._resources_initialized = False
        self._refresh_period = refresh_period
        self._last_resource_refresh = float("-inf")
        self._last_nontrivial_wait = time.time()
        if not ray.is_initialized() and ray_auto_init:
            logger.info("Initializing Ray automatically."
                        "For cluster usage or custom Ray initialization, "
                        "call `ray.init(...)` before `tune.run`.")
            ray.init()

        if ray.is_initialized():
            self._update_avail_resources()

    def _setup_runner(self, trial, reuse_allowed):
        if (self._reuse_actors and reuse_allowed
                and self._cached_actor is not None):
            logger.debug("Reusing cached runner {} for {}".format(
                self._cached_actor, trial.trial_id))
            existing_runner = self._cached_actor
            self._cached_actor = None
        else:
            if self._cached_actor:
                logger.debug(
                    "Cannot reuse cached runner {} for new trial".format(
                        self._cached_actor))
                self._cached_actor.stop.remote()
                self._cached_actor.__ray_terminate__.remote()
                self._cached_actor = None
            existing_runner = None
            cls = ray.remote(
                num_cpus=trial.resources.cpu,
                num_gpus=trial.resources.gpu,
                memory=trial.resources.memory,
                object_store_memory=trial.resources.object_store_memory,
                resources=trial.resources.custom_resources)(
                    trial._get_trainable_cls())

        trial.init_logger()
        # We checkpoint metadata here to try mitigating logdir duplication
        self.try_checkpoint_metadata(trial)
        remote_logdir = trial.logdir

        if existing_runner:
            trial.runner = existing_runner
            if not self.reset_trial(trial, trial.config, trial.experiment_tag):
                raise AbortTrialExecution(
                    "Trainable runner reuse requires reset_config() to be "
                    "implemented and return True.")
            return existing_runner

        def logger_creator(config):
            # Set the working dir in the remote process, for user file writes
            if not os.path.exists(remote_logdir):
                os.makedirs(remote_logdir)
            if ray.worker._mode() != ray.worker.LOCAL_MODE:
                os.chdir(remote_logdir)
            return NoopLogger(config, remote_logdir)

        # Logging for trials is handled centrally by TrialRunner, so
        # configure the remote runner to use a noop-logger.
        return cls.remote(config=trial.config, logger_creator=logger_creator)

    def _train(self, trial):
        """Start one iteration of training and save remote id."""

        assert trial.status == Trial.RUNNING, trial.status
        remote = trial.runner.train.remote()

        # Local Mode
        if isinstance(remote, dict):
            remote = _LocalWrapper(remote)

        self._running[remote] = trial

    def _start_trial(self, trial, checkpoint=None):
        """Starts trial and restores last result if trial was paused.

        Raises:
            RuntimeError if restoring from checkpoint fails.
        """
        prior_status = trial.status
        self.set_status(trial, Trial.RUNNING)
        trial.runner = self._setup_runner(
            trial,
            reuse_allowed=checkpoint is not None
            or trial._checkpoint.value is not None)
        if not self.restore(trial, checkpoint):
            if trial.status == Trial.ERROR:
                raise RuntimeError(
                    "Restore from checkpoint failed for Trial {}.".format(
                        str(trial)))

        previous_run = self._find_item(self._paused, trial)
        if (prior_status == Trial.PAUSED and previous_run):
            # If Trial was in flight when paused, self._paused stores result.
            self._paused.pop(previous_run[0])
            self._running[previous_run[0]] = trial
        else:
            self._train(trial)

    def _stop_trial(self,
                    trial,
                    error=False,
                    error_msg=None,
                    stop_logger=True):
        """Stops this trial.

        Stops this trial, releasing all allocating resources. If stopping the
        trial fails, the run will be marked as terminated in error, but no
        exception will be thrown.

        Args:
            error (bool): Whether to mark this trial as terminated in error.
            error_msg (str): Optional error message.
            stop_logger (bool): Whether to shut down the trial logger.
        """

        if stop_logger:
            trial.close_logger()

        if error:
            self.set_status(trial, Trial.ERROR)
        else:
            self.set_status(trial, Trial.TERMINATED)

        try:
            trial.write_error_log(error_msg)
            if hasattr(trial, "runner") and trial.runner:
                if (not error and self._reuse_actors
                        and self._cached_actor is None):
                    logger.debug("Reusing actor for {}".format(trial.runner))
                    self._cached_actor = trial.runner
                else:
                    logger.debug(
                        "Destroying actor for trial {}.".format(trial))
                    trial.runner.stop.remote()
                    trial.runner.__ray_terminate__.remote()
        except Exception:
            logger.exception("Error stopping runner for Trial %s", str(trial))
            self.set_status(trial, Trial.ERROR)
        finally:
            trial.runner = None

    def start_trial(self, trial, checkpoint=None):
        """Starts the trial.

        Will not return resources if trial repeatedly fails on start.

        Args:
            trial (Trial): Trial to be started.
            checkpoint (Checkpoint): A Python object or path storing the state
                of trial.
        """

        self._commit_resources(trial.resources)
        try:
            self._start_trial(trial, checkpoint)
        except Exception as e:
            logger.exception("Error starting runner for Trial %s", str(trial))
            error_msg = traceback.format_exc()
            time.sleep(2)
            self._stop_trial(trial, error=True, error_msg=error_msg)
            if isinstance(e, AbortTrialExecution):
                return  # don't retry fatal Tune errors
            try:
                # This forces the trial to not start from checkpoint.
                trial.clear_checkpoint()
                logger.info(
                    "Trying to start runner for Trial %s without checkpoint.",
                    str(trial))
                self._start_trial(trial)
            except Exception:
                logger.exception(
                    "Error starting runner for Trial %s, aborting!",
                    str(trial))
                error_msg = traceback.format_exc()
                self._stop_trial(trial, error=True, error_msg=error_msg)
                # note that we don't return the resources, since they may
                # have been lost

    def _find_item(self, dictionary, item):
        out = [rid for rid, t in dictionary.items() if t is item]
        return out

    def stop_trial(self, trial, error=False, error_msg=None, stop_logger=True):
        """Only returns resources if resources allocated."""
        prior_status = trial.status
        self._stop_trial(trial,
                         error=error,
                         error_msg=error_msg,
                         stop_logger=stop_logger)
        if prior_status == Trial.RUNNING:
            logger.debug("Returning resources for Trial %s.", str(trial))
            self._return_resources(trial.resources)
            out = self._find_item(self._running, trial)
            for result_id in out:
                self._running.pop(result_id)

    def continue_training(self, trial):
        """Continues the training of this trial."""

        self._train(trial)

    def pause_trial(self, trial):
        """Pauses the trial.

        If trial is in-flight, preserves return value in separate queue
        before pausing, which is restored when Trial is resumed.
        """

        trial_future = self._find_item(self._running, trial)
        if trial_future:
            self._paused[trial_future[0]] = trial
        super(RayTrialExecutor, self).pause_trial(trial)

    def reset_trial(self, trial, new_config, new_experiment_tag):
        """Tries to invoke `Trainable.reset_config()` to reset trial.

        Args:
            trial (Trial): Trial to be reset.
            new_config (dict): New configuration for Trial trainable.
            new_experiment_tag (str): New experiment name for trial.

        Returns:
            True if `reset_config` is successful else False.
        """
        trial.experiment_tag = new_experiment_tag
        trial.config = new_config
        trainable = trial.runner
        with warn_if_slow("reset_config"):
            reset_val = ray.get(trainable.reset_config.remote(new_config))
        return reset_val

    def get_running_trials(self):
        """Returns the running trials."""

        return list(self._running.values())

    def get_alive_node_ips(self):
        nodes = ray.state.nodes()
        ip_addresses = set()
        for node in nodes:
            if node["alive"]:
                ip_addresses.add(node["NodeManagerAddress"])
        return ip_addresses

    def get_current_trial_ips(self):
        return {t.node_ip for t in self.get_running_trials()}

    def get_next_failed_trial(self):
        """Gets the first trial found to be running on a node presumed dead.

        Returns:
            A Trial object that is ready for failure processing. None if
            no failure detected.
        """
        if ray.worker._mode() != ray.worker.LOCAL_MODE:
            live_cluster_ips = self.get_alive_node_ips()
            # Fast path: only scan trials if some trial IP is no longer alive.
            if self.get_current_trial_ips() - live_cluster_ips:
                for trial in self.get_running_trials():
                    if trial.node_ip and trial.node_ip not in live_cluster_ips:
                        return trial
        return None

    def get_next_available_trial(self):
        shuffled_results = list(self._running.keys())
        random.shuffle(shuffled_results)
        # Note: We shuffle the results because `ray.wait` by default returns
        # the first available result, and we want to guarantee that slower
        # trials (i.e. trials that run remotely) also get fairly reported.
        # See https://github.com/ray-project/ray/issues/4211 for details.
        start = time.time()
        [result_id], _ = ray.wait(shuffled_results)
        wait_time = time.time() - start
        if wait_time > NONTRIVIAL_WAIT_TIME_THRESHOLD_S:
            self._last_nontrivial_wait = time.time()
        if time.time() - self._last_nontrivial_wait > BOTTLENECK_WARN_PERIOD_S:
            logger.warning(
                "Over the last {} seconds, the Tune event loop has been "
                "backlogged processing new results. Consider increasing your "
                "period of result reporting to improve performance.".format(
                    BOTTLENECK_WARN_PERIOD_S))

            self._last_nontrivial_wait = time.time()
        return self._running[result_id]

    def fetch_result(self, trial):
        """Fetches one result of the running trials.

        Returns:
            Result of the most recent trial training run."""
        trial_future = self._find_item(self._running, trial)
        if not trial_future:
            raise ValueError("Trial was not running.")
        self._running.pop(trial_future[0])
        with warn_if_slow("fetch_result"):
            result = ray.get(trial_future[0])

        # For local mode
        if isinstance(result, _LocalWrapper):
            result = result.unwrap()
        return result

    def _commit_resources(self, resources):
        committed = self._committed_resources
        all_keys = set(resources.custom_resources).union(
            set(committed.custom_resources))

        custom_resources = {
            k: committed.get(k) + resources.get_res_total(k)
            for k in all_keys
        }

        self._committed_resources = Resources(
            committed.cpu + resources.cpu_total(),
            committed.gpu + resources.gpu_total(),
            committed.memory + resources.memory_total(),
            committed.object_store_memory +
            resources.object_store_memory_total(),
            custom_resources=custom_resources)

    def _return_resources(self, resources):
        committed = self._committed_resources

        all_keys = set(resources.custom_resources).union(
            set(committed.custom_resources))

        custom_resources = {
            k: committed.get(k) - resources.get_res_total(k)
            for k in all_keys
        }
        self._committed_resources = Resources(
            committed.cpu - resources.cpu_total(),
            committed.gpu - resources.gpu_total(),
            committed.memory - resources.memory_total(),
            committed.object_store_memory -
            resources.object_store_memory_total(),
            custom_resources=custom_resources)

        assert self._committed_resources.is_nonnegative(), (
            "Resource invalid: {}".format(resources))

    def _update_avail_resources(self, num_retries=5):
        for i in range(num_retries):
            try:
                resources = ray.cluster_resources()
            except Exception:
                # TODO(rliaw): Remove this when local mode is fixed.
                # https://github.com/ray-project/ray/issues/4147
                logger.debug("Using resources for local machine.")
                resources = ResourceSpec().resolve(True).to_resource_dict()
            if resources:
                break
            logger.warning(
                "Cluster resources not detected or are 0. Retrying...")
            time.sleep(0.5)

        if not resources:
            # NOTE: This hides the possibility that Ray may be waiting for
            # clients to connect.
            resources.setdefault("CPU", 0)
            resources.setdefault("GPU", 0)
            logger.warning("Cluster resources cannot be detected or are 0. "
                           "You can resume this experiment by passing in "
                           "`resume=True` to `run`.")

        resources = resources.copy()
        num_cpus = resources.pop("CPU", 0)
        num_gpus = resources.pop("GPU", 0)
        memory = ray_constants.from_memory_units(resources.pop("memory", 0))
        object_store_memory = ray_constants.from_memory_units(
            resources.pop("object_store_memory", 0))
        custom_resources = resources

        self._avail_resources = Resources(
            int(num_cpus),
            int(num_gpus),
            memory=int(memory),
            object_store_memory=int(object_store_memory),
            custom_resources=custom_resources)
        self._last_resource_refresh = time.time()
        self._resources_initialized = True

    def has_resources(self, resources):
        """Returns whether this runner has at least the specified resources.

        This refreshes the Ray cluster resources if the time since last update
        has exceeded self._refresh_period. This also assumes that the
        cluster is not resizing very frequently.
        """
        if time.time() - self._last_resource_refresh > self._refresh_period:
            self._update_avail_resources()

        currently_available = Resources.subtract(self._avail_resources,
                                                 self._committed_resources)

        have_space = (
            resources.cpu_total() <= currently_available.cpu
            and resources.gpu_total() <= currently_available.gpu
            and resources.memory_total() <= currently_available.memory
            and resources.object_store_memory_total() <=
            currently_available.object_store_memory and all(
                resources.get_res_total(res) <= currently_available.get(res)
                for res in resources.custom_resources))

        if have_space:
            # The assumption right now is that we block all trials if one
            # trial is queued.
            self._trial_queued = False
            return True

        can_overcommit = self._queue_trials and not self._trial_queued
        if can_overcommit:
            self._trial_queued = True
            logger.warning(
                "Allowing trial to start even though the "
                "cluster does not have enough free resources. Trial actors "
                "may appear to hang until enough resources are added to the "
                "cluster (e.g., via autoscaling). You can disable this "
                "behavior by specifying `queue_trials=False` in "
                "ray.tune.run().")
            return True

        return False

    def debug_string(self):
        """Returns a human readable message for printing to the console."""

        if self._resources_initialized:
            status = ("Resources requested: {}/{} CPUs, {}/{} GPUs, "
                      "{}/{} GiB heap, {}/{} GiB objects".format(
                          self._committed_resources.cpu,
                          self._avail_resources.cpu,
                          self._committed_resources.gpu,
                          self._avail_resources.gpu,
                          _to_gb(self._committed_resources.memory),
                          _to_gb(self._avail_resources.memory),
                          _to_gb(
                              self._committed_resources.object_store_memory),
                          _to_gb(self._avail_resources.object_store_memory)))
            customs = ", ".join([
                "{}/{} {}".format(
                    self._committed_resources.get_res_total(name),
                    self._avail_resources.get_res_total(name), name)
                for name in self._avail_resources.custom_resources
            ])
            if customs:
                status += " ({})".format(customs)
            return status
        else:
            return "Resources requested: ?"

    def resource_string(self):
        """Returns a string describing the total resources available."""

        if self._resources_initialized:
            res_str = ("{} CPUs, {} GPUs, "
                       "{} GiB heap, {} GiB objects".format(
                           self._avail_resources.cpu,
                           self._avail_resources.gpu,
                           _to_gb(self._avail_resources.memory),
                           _to_gb(self._avail_resources.object_store_memory)))
            if self._avail_resources.custom_resources:
                custom = ", ".join(
                    "{} {}".format(self._avail_resources.get_res_total(name),
                                   name)
                    for name in self._avail_resources.custom_resources)
                res_str += " ({})".format(custom)
            return res_str
        else:
            return "? CPUs, ? GPUs"

    def on_step_begin(self, trial_runner):
        """Before step() called, update the available resources."""
        self._update_avail_resources()

    def save(self, trial, storage=Checkpoint.DISK):
        """Saves the trial's state to a checkpoint."""
        trial._checkpoint.storage = storage
        trial._checkpoint.last_result = trial.last_result
        if storage == Checkpoint.MEMORY:
            trial._checkpoint.value = trial.runner.save_to_object.remote()
        else:
            # Keeps only highest performing checkpoints if enabled
            if trial.keep_checkpoints_num:
                try:
                    last_attr_val = trial.last_result[
                        trial.checkpoint_score_attr]
                    if (trial.compare_checkpoints(last_attr_val)
                            and not math.isnan(last_attr_val)):
                        trial.best_checkpoint_attr_value = last_attr_val
                        self._checkpoint_and_erase(trial)
                except KeyError:
                    logger.warning(
                        "Result dict has no key: {}. keep"
                        "_checkpoints_num flag will not work".format(
                            trial.checkpoint_score_attr))
            else:
                with warn_if_slow("save_to_disk"):
                    trial._checkpoint.value = ray.get(
                        trial.runner.save.remote())
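        # Note: with keep_checkpoints_num set, a result lacking
        # checkpoint_score_attr (the KeyError branch above) or scoring worse
        # than the current best checkpoint is not persisted at all this step.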

        return trial._checkpoint.value

    def _checkpoint_and_erase(self, trial):
        """Checkpoints the model and erases old checkpoints if needed.

        Args:
            trial (Trial): Trial to save.
        """

        with warn_if_slow("save_to_disk"):
            trial._checkpoint.value = ray.get(trial.runner.save.remote())

        if len(trial.history) >= trial.keep_checkpoints_num:
            ray.get(trial.runner.delete_checkpoint.remote(trial.history[-1]))
            trial.history.pop()

        trial.history.insert(0, trial._checkpoint.value)

    def restore(self, trial, checkpoint=None):
        """Restores training state from a given model checkpoint.

        This will also sync the trial results to a new location
        if restoring on a different node.
        """
        if checkpoint is None or checkpoint.value is None:
            checkpoint = trial._checkpoint
        if checkpoint is None or checkpoint.value is None:
            return True
        if trial.runner is None:
            logger.error("Unable to restore - no runner.")
            self.set_status(trial, Trial.ERROR)
            return False
        try:
            value = checkpoint.value
            if checkpoint.storage == Checkpoint.MEMORY:
                assert type(value) != Checkpoint, type(value)
                trial.runner.restore_from_object.remote(value)
            else:
                # TODO: Somehow, the call to get the current IP on the
                # remote actor can be very slow - a better fix would
                # be to use an actor table to detect the IP of the Trainable
                # and rsync the files there.
                # See https://github.com/ray-project/ray/issues/5168
                with warn_if_slow("get_current_ip"):
                    worker_ip = ray.get(trial.runner.current_ip.remote())
                with warn_if_slow("sync_to_new_location"):
                    trial.sync_logger_to_new_location(worker_ip)
                with warn_if_slow("restore_from_disk"):
                    ray.get(trial.runner.restore.remote(value))
            trial.last_result = checkpoint.last_result
            return True
        except Exception:
            logger.exception("Error restoring runner for Trial %s.", trial)
            self.set_status(trial, Trial.ERROR)
            return False

    def export_trial_if_needed(self, trial):
        """Exports model of this trial based on trial.export_formats.

        Returns:
            A dict that maps ExportFormats to successfully exported models.
        """
        if trial.export_formats and len(trial.export_formats) > 0:
            return ray.get(
                trial.runner.export_model.remote(trial.export_formats))
        return {}
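For the actor-reuse path above to work, `reset_trial` requires the user's Trainable to implement `reset_config` and return True; otherwise `_setup_runner` raises AbortTrialExecution. A hedged sketch against the older class-based API this example targets (underscore method names per Ray ~0.7; the trainable is hypothetical):

from ray import tune

class MyTrainable(tune.Trainable):
    def _setup(self, config):
        self.lr = config["lr"]

    def _train(self):
        # One training iteration; the executor calls this via train.remote().
        return {"score": self.lr}

    def reset_config(self, new_config):
        # Reuse this actor for a new trial instead of tearing it down.
        self.lr = new_config["lr"]
        return True  # must return True, or actor reuse aborts the trial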
Example #3
class RayTrialExecutor(TrialExecutor):
    """An implementation of TrialExecutor based on Ray."""
    def __init__(self,
                 queue_trials: bool = False,
                 reuse_actors: bool = False,
                 result_buffer_length: Optional[int] = None,
                 refresh_period: Optional[float] = None,
                 wait_for_placement_group: Optional[float] = None):
        super(RayTrialExecutor, self).__init__(queue_trials)
        # Track whether we are launching a trial without resources in order
        # to kick off the autoscaler.
        self._trial_queued = False
        self._running = {}
        # A trial that resumes after being paused should not call
        # trial.train.remote() again, since that would generate a new remote
        # object ref. We use self._paused to store paused trials (and their
        # in-flight futures) here.
        self._paused = {}

        self._trial_cleanup = _TrialCleanup()
        self._has_cleaned_up_pgs = False
        self._reuse_actors = reuse_actors
        self._cached_actor_pg = (None, None)

        self._avail_resources = Resources(cpu=0, gpu=0)
        self._committed_resources = Resources(cpu=0, gpu=0)
        self._pg_manager = PlacementGroupManager(prefix=get_tune_pg_prefix())
        self._staged_trials = set()
        self._just_staged_trials = set()
        self._trial_just_finished = False
        self._trial_just_finished_before = False

        self._resources_initialized = False

        if refresh_period is None:
            refresh_period = float(
                os.environ.get("TUNE_STATE_REFRESH_PERIOD",
                               TUNE_STATE_REFRESH_PERIOD))
        self._refresh_period = refresh_period

        self._wait_for_pg = wait_for_placement_group or float(
            os.environ.get("TUNE_PLACEMENT_GROUP_WAIT_S", "-1"))
        if self._wait_for_pg < 0:
            self._wait_for_pg = None

        self.last_pg_recon = 0
        self.pg_recon_interval = float(
            os.environ.get("TUNE_PLACEMENT_GROUP_RECON_INTERVAL", "5"))

        self._default_buffer_length = result_buffer_length or int(
            os.getenv("TUNE_RESULT_BUFFER_LENGTH", 1000))
        self._buffer_length = result_buffer_length

        self._buffer_min_time_s = float(
            os.getenv("TUNE_RESULT_BUFFER_MIN_TIME_S", 0.))
        self._buffer_max_time_s = float(
            os.getenv("TUNE_RESULT_BUFFER_MAX_TIME_S", 100.))

        self._last_resource_refresh = float("-inf")
        self._last_ip_refresh = float("-inf")
        self._last_ip_addresses = set()
        self._last_nontrivial_wait = time.time()

        if ray.is_initialized():
            self._update_avail_resources()

    def in_staging_grace_period(self) -> bool:
        """Returns True if trials have recently been staged."""
        return self._pg_manager.in_staging_grace_period()

    def set_max_pending_trials(self, max_pending: int) -> None:
        self._pg_manager.set_max_staging(max_pending)

    def stage_and_update_status(self, trials: Iterable[Trial]):
        """Check and update statuses of scheduled placement groups.

        Stages placement groups of all trials.
        """
        if not self._has_cleaned_up_pgs:
            # Clean up existing placement groups the first time the tuning
            # run's step() method is triggered.
            self._pg_manager.cleanup_existing_pg()
            self._has_cleaned_up_pgs = True

        for trial in trials:
            if trial.status != Trial.PENDING:
                continue
            if not trial.uses_placement_groups:
                continue
            if trial in self._staged_trials:
                continue
            if self._pg_manager.trial_in_use(trial):
                continue

            if not self._pg_manager.stage_trial_pg(trial):
                # Break if we reached the limit of pending placement groups.
                break
            self._staged_trials.add(trial)
            self._just_staged_trials.add(trial)

        self._pg_manager.update_status()

    def get_staged_trial(self):
        """Get a trial whose placement group was successfully staged.

        Can also return None if no trial is available.

        Returns:
            Trial object or None.

        """
        for trial in self._staged_trials:
            if self._pg_manager.has_ready(trial):
                return trial

        return None

    def _setup_remote_runner(self, trial):
        trial.init_logdir()
        # We checkpoint metadata here to try mitigating logdir duplication
        self.try_checkpoint_metadata(trial)
        logger_creator = partial(noop_logger_creator, logdir=trial.logdir)

        if self._reuse_actors and self._cached_actor_pg[0] is not None:
            logger.debug(f"Trial {trial}: Reusing cached runner "
                         f"{self._cached_actor_pg[0]}")
            existing_runner, pg = self._cached_actor_pg
            self._cached_actor_pg = (None, None)

            trial.set_runner(existing_runner)
            if pg and trial.uses_placement_groups:
                self._pg_manager.assign_cached_pg(pg, trial)

            if not self.reset_trial(trial, trial.config, trial.experiment_tag,
                                    logger_creator):
                raise AbortTrialExecution(
                    "Trainable runner reuse requires reset_config() to be "
                    "implemented and return True.")
            return existing_runner

        if self._cached_actor_pg[0]:
            logger.debug("Cannot reuse cached runner {} for new trial".format(
                self._cached_actor_pg[0]))
            existing_runner, pg = self._cached_actor_pg

            if pg:
                self._pg_manager.return_or_clean_cached_pg(pg)

            with self._change_working_directory(trial):
                self._trial_cleanup.add(trial, actor=existing_runner)
            self._cached_actor_pg = (None, None)

        trainable_cls = trial.get_trainable_cls()
        if not trainable_cls:
            raise AbortTrialExecution(
                f"Invalid trainable: {trial.trainable_name}. If you passed "
                f"a string, make sure the trainable was registered before.")
        _actor_cls = _class_cache.get(trainable_cls)

        if trial.uses_placement_groups:
            if not self._pg_manager.has_ready(trial, update=True):
                if trial not in self._staged_trials:
                    if self._pg_manager.stage_trial_pg(trial):
                        self._staged_trials.add(trial)
                        self._just_staged_trials.add(trial)

                just_staged = trial in self._just_staged_trials

                # This part of the code is mostly here for testing
                # purposes. If self._wait_for_pg is set, we will wait here
                # for that many seconds until the placement group is ready.
                # This ensures that the trial can be started right away and
                # not just in the next step() of the trial runner.
                # We only do this if we have reason to believe that resources
                # will be ready soon, i.e. when a) we just staged the PG,
                # b) another trial just exited, freeing resources, or c)
                # when there are no currently running trials.
                if self._wait_for_pg is not None and (
                        just_staged or self._trial_just_finished_before
                        or not self.get_running_trials()):
                    logger.debug(
                        f"Waiting up to {self._wait_for_pg} seconds for "
                        f"placement group of trial {trial} to become ready.")
                    wait_end = time.monotonic() + self._wait_for_pg
                    while time.monotonic() < wait_end:
                        self._pg_manager.update_status()
                        if self._pg_manager.has_ready(trial):
                            break
                        time.sleep(0.1)
                else:
                    return None

            if not self._pg_manager.has_ready(trial):
                # PG may have become ready during waiting period
                return None

            full_actor_class = self._pg_manager.get_full_actor_cls(
                trial, _actor_cls)
        else:
            full_actor_class = _actor_cls.options(
                num_cpus=trial.resources.cpu,
                num_gpus=trial.resources.gpu,
                memory=trial.resources.memory or None,
                object_store_memory=trial.resources.object_store_memory
                or None,
                resources=trial.resources.custom_resources)
        # Clear the Trial's location (to be updated later on result)
        # since we don't know where the remote runner is placed.
        trial.set_location(Location())
        logger.debug("Trial %s: Setting up new remote runner.", trial)
        # Logging for trials is handled centrally by TrialRunner, so
        # configure the remote runner to use a noop-logger.
        trial_config = copy.deepcopy(trial.config)
        trial_config[TRIAL_INFO] = TrialInfo(trial)

        stdout_file, stderr_file = trial.log_to_file
        trial_config[STDOUT_FILE] = stdout_file
        trial_config[STDERR_FILE] = stderr_file
        kwargs = {
            "config": trial_config,
            "logger_creator": logger_creator,
        }
        if issubclass(trial.get_trainable_cls(), DurableTrainable):
            kwargs["remote_checkpoint_dir"] = trial.remote_checkpoint_dir
            kwargs["sync_function_tpl"] = trial.sync_to_cloud

        with self._change_working_directory(trial):
            return full_actor_class.remote(**kwargs)

    def _train(self, trial):
        """Start one iteration of training and save remote id."""
        if self._find_item(self._paused, trial):
            raise TuneError(
                "Should not call `train` on PAUSED trial {}. "
                "This is an internal error - please file an issue "
                "on https://github.com/ray-project/ray/issues/.".format(
                    str(trial)))

        if self._find_item(self._running, trial):
            logger.debug(
                "Trial {} already has a queued future. Skipping this "
                "`train` call. This may occur if a trial has "
                "been unpaused within a scheduler callback.".format(
                    str(trial)))
            return

        assert trial.status == Trial.RUNNING, trial.status
        buffer_time_s = max(
            self._buffer_min_time_s,
            min(self._buffer_max_time_s,
                len(self._running) // 10))
        with self._change_working_directory(trial):
            buffer_length = self._buffer_length

            # If buffer length has not been explicitly set, we determine
            # it automatically
            if buffer_length is None:
                if trial.checkpoint_at_end:
                    # If a trial checkpoint can be triggered externally,
                    # it is not safe to buffer results.
                    buffer_length = 1
                else:
                    # Else, use the default buffer length
                    buffer_length = self._default_buffer_length
            else:
                if trial.checkpoint_at_end:
                    if log_once("trial_executor_buffer_checkpoint"):
                        logger.warning(
                            "You passed `checkpoint_at_end` to `tune.run()`, "
                            "but still requested buffered training. "
                            "If used with a custom stopper or early stopping, "
                            "checkpoints may be created later than desired.")

            if buffer_length > 1:
                if trial.checkpoint_freq > 0:
                    buffer_length = min(buffer_length, trial.checkpoint_freq)
                remote = trial.runner.train_buffered.remote(
                    buffer_time_s, buffer_length)
            else:
                remote = trial.runner.train.remote()

        # Local Mode
        if isinstance(remote, dict):
            remote = _LocalWrapper(remote)

        self._running[remote] = trial
        trial_item = self._find_item(self._running, trial)
        assert len(trial_item) < 2, trial_item
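
    # Worked example of the buffering arithmetic above (hypothetical values):
    # with the default env vars (min 0s, max 100s) and 25 in-flight futures,
    #
    #     buffer_time_s = max(0.0, min(100.0, 25 // 10)) == 2 seconds,
    #
    # and `train_buffered` is only used when the effective buffer length
    # exceeds 1 (capped by `checkpoint_freq` so buffering never skips a
    # checkpoint step).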

    def _start_trial(self,
                     trial,
                     checkpoint=None,
                     runner=None,
                     train=True) -> bool:
        """Starts trial and restores last result if trial was paused.

        Args:
            trial (Trial): The trial to start.
            checkpoint (Optional[Checkpoint]): The checkpoint to restore from.
                If None, and no trial checkpoint exists, the trial is started
                from the beginning.
            runner (Trainable): The remote runner to use. This can be the
                cached actor. If None, a new runner is created.
            train (bool): Whether or not to start training.

        Returns:
            True if trial was started successfully, False otherwise.

        See `RayTrialExecutor.restore` for possible errors raised.
        """
        prior_status = trial.status
        self.set_status(trial, Trial.PENDING)
        if runner is None:
            runner = self._setup_remote_runner(trial)
            if not runner:
                return False
        trial.set_runner(runner)
        self._notify_trainable_of_new_resources_if_needed(trial)
        self.restore(trial, checkpoint)
        self.set_status(trial, Trial.RUNNING)

        if trial in self._staged_trials:
            self._staged_trials.remove(trial)

        previous_run = self._find_item(self._paused, trial)
        if prior_status == Trial.PAUSED and previous_run:
            # If Trial was in flight when paused, self._paused stores result.
            self._paused.pop(previous_run[0])
            self._running[previous_run[0]] = trial
        elif train and not trial.is_restoring:
            self._train(trial)
        return True

    def _notify_trainable_of_new_resources_if_needed(self, trial: Trial):
        if trial.has_new_resources:
            trainable = trial.runner
            trial.has_new_resources = False
            with self._change_working_directory(trial):
                with warn_if_slow("update_resources"):
                    try:
                        ray.get(trainable._update_resources.remote(
                            trial.placement_group_factory if trial.
                            uses_placement_groups else trial.resources),
                                timeout=DEFAULT_GET_TIMEOUT)
                    except GetTimeoutError:
                        logger.exception(
                            "Trial %s: updating resources timed out.", trial)

    def _stop_trial(self,
                    trial: Trial,
                    error=False,
                    error_msg=None,
                    destroy_pg_if_cannot_replace=True):
        """Stops this trial.

        Stops this trial, releasing all allocated resources. If stopping the
        trial fails, the run will be marked as terminated in error, but no
        exception will be thrown.

        If the placement group will be used right away
        (destroy_pg_if_cannot_replace=False), the trial's placement group (or
        a surrogate placement group) is not removed.

        Args:
            error (bool): Whether to mark this trial as terminated in error.
            error_msg (str): Optional error message.

        """
        self.set_status(trial, Trial.ERROR if error else Trial.TERMINATED)
        self._trial_just_finished = True
        trial.set_location(Location())

        try:
            trial.write_error_log(error_msg)
            if hasattr(trial, "runner") and trial.runner:
                if (not error and self._reuse_actors
                        and self._cached_actor_pg[0] is None):
                    logger.debug("Reusing actor for %s", trial.runner)
                    # Move PG into cache (disassociate from trial)
                    pg = self._pg_manager.cache_trial_pg(trial)
                    if pg or not trial.uses_placement_groups:
                        # True if a placement group was replaced
                        self._cached_actor_pg = (trial.runner, pg)
                        should_destroy_actor = False
                    else:
                        # False if no placement group was replaced. This should
                        # only be the case if there are no more trials with
                        # this placement group factory to run
                        logger.debug(
                            f"Could not cache actor of trial {trial} for "
                            "reuse, as there are no pending trials "
                            "requiring its resources.")
                        should_destroy_actor = True
                else:
                    should_destroy_actor = True

                if should_destroy_actor:
                    logger.debug("Trial %s: Destroying actor.", trial)

                    # Try to return the placement group for other trials to use
                    self._pg_manager.return_pg(trial,
                                               destroy_pg_if_cannot_replace)

                    with self._change_working_directory(trial):
                        self._trial_cleanup.add(trial, actor=trial.runner)

                if trial in self._staged_trials:
                    self._staged_trials.remove(trial)

        except Exception:
            logger.exception("Trial %s: Error stopping runner.", trial)
            self.set_status(trial, Trial.ERROR)
        finally:
            trial.set_runner(None)

    def start_trial(self,
                    trial: Trial,
                    checkpoint: Optional[Checkpoint] = None,
                    train: bool = True) -> bool:
        """Starts the trial.

        Will not return resources if trial repeatedly fails on start.

        Args:
            trial (Trial): Trial to be started.
            checkpoint (Checkpoint): A Python object or path storing the state
                of trial.
            train (bool): Whether or not to start training.

        Returns:
            True if the remote runner has been started. False if trial was
                not started (e.g. because of lacking resources/pending PG).
        """
        if not trial.uses_placement_groups:
            self._commit_resources(trial.resources)
        try:
            return self._start_trial(trial, checkpoint, train=train)
        except AbortTrialExecution:
            logger.exception("Trial %s: Error starting runner, aborting!",
                             trial)
            time.sleep(2)
            error_msg = traceback.format_exc()
            self._stop_trial(trial, error=True, error_msg=error_msg)
            return False
        except Exception:
            logger.exception("Trial %s: Unexpected error starting runner.",
                             trial)
            time.sleep(2)
            error_msg = traceback.format_exc()
            self._stop_trial(trial, error=True, error_msg=error_msg)
            # Note that we don't return the resources, since they may
            # have been lost. TODO(ujvl): is this the right thing to do?
            return False

    def _find_item(self, dictionary, item):
        out = [rid for rid, t in dictionary.items() if t is item]
        return out

    def stop_trial(self,
                   trial: Trial,
                   error: bool = False,
                   error_msg: Optional[str] = None,
                   destroy_pg_if_cannot_replace: bool = True) -> None:
        """Only returns resources if resources allocated.

        If destroy_pg_if_cannot_replace is False, the Trial placement group
        will not be removed if it can't replace any staging ones."""
        prior_status = trial.status
        self._stop_trial(
            trial,
            error=error,
            error_msg=error_msg,
            destroy_pg_if_cannot_replace=destroy_pg_if_cannot_replace)
        if prior_status == Trial.RUNNING:
            logger.debug("Trial %s: Returning resources.", trial)
            if not trial.uses_placement_groups:
                self._return_resources(trial.resources)
            out = self._find_item(self._running, trial)
            for result_id in out:
                self._running.pop(result_id)

    def continue_training(self, trial: Trial) -> None:
        """Continues the training of this trial."""
        self._train(trial)

    def pause_trial(self, trial: Trial) -> None:
        """Pauses the trial.

        If trial is in-flight, preserves return value in separate queue
        before pausing, which is restored when Trial is resumed.
        """
        trial_future = self._find_item(self._running, trial)
        if trial_future:
            self._paused[trial_future[0]] = trial
        super(RayTrialExecutor, self).pause_trial(trial)
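
    # Sketch of the pause/resume bookkeeping (illustrative only): pausing an
    # in-flight trial records its pending future in `self._paused`; on
    # resume, `_start_trial` moves that future back into `self._running`, so
    # the already-submitted `train.remote()` call is neither lost nor
    # duplicated:
    #
    #     self._paused[fut] = trial    # on pause (fut = in-flight future)
    #     ...
    #     self._paused.pop(fut)        # on resume, in _start_trial
    #     self._running[fut] = trial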

    def reset_trial(
        self,
        trial: Trial,
        new_config: Dict,
        new_experiment_tag: str,
        logger_creator: Optional[Callable[[Dict], "ray.tune.Logger"]] = None
    ) -> bool:
        """Tries to invoke `Trainable.reset()` to reset trial.

        Args:
            trial (Trial): Trial to be reset.
            new_config (dict): New configuration for Trial trainable.
            new_experiment_tag (str): New experiment name for trial.
            logger_creator (Optional[Callable[[Dict], Logger]]): Function
                that instantiates a logger on the actor process.

        Returns:
            True if `reset_config` is successful else False.
        """
        trial.set_experiment_tag(new_experiment_tag)
        trial.set_config(new_config)
        trainable = trial.runner

        # Pass magic variables
        extra_config = copy.deepcopy(new_config)
        extra_config[TRIAL_INFO] = TrialInfo(trial)

        stdout_file, stderr_file = trial.log_to_file
        extra_config[STDOUT_FILE] = stdout_file
        extra_config[STDERR_FILE] = stderr_file

        with self._change_working_directory(trial):
            with warn_if_slow("reset"):
                try:
                    reset_val = ray.get(trainable.reset.remote(
                        extra_config, logger_creator),
                                        timeout=DEFAULT_GET_TIMEOUT)
                except GetTimeoutError:
                    logger.exception("Trial %s: reset timed out.", trial)
                    return False
        return reset_val

    def get_running_trials(self) -> List[Trial]:
        """Returns the running trials."""
        return list(self._running.values())

    def get_alive_node_ips(self):
        now = time.time()
        if now - self._last_ip_refresh < self._refresh_period:
            return self._last_ip_addresses
        logger.debug("Checking ips from Ray state.")
        self._last_ip_refresh = now
        nodes = ray.state.nodes()
        ip_addresses = set()
        for node in nodes:
            if node["alive"]:
                ip_addresses.add(node["NodeManagerAddress"])
        self._last_ip_addresses = ip_addresses
        return ip_addresses

    def get_current_trial_ips(self):
        return {t.node_ip for t in self.get_running_trials()}

    def get_next_failed_trial(self) -> Optional[Trial]:
        """Gets the first trial found to be running on a node presumed dead.

        Returns:
            A Trial object that is ready for failure processing. None if
            no failure detected.
        """
        if ray.worker._mode() != ray.worker.LOCAL_MODE:
            live_cluster_ips = self.get_alive_node_ips()
            if live_cluster_ips - self.get_current_trial_ips():
                for trial in self.get_running_trials():
                    if trial.node_ip and trial.node_ip not in live_cluster_ips:
                        return trial
        return None

    def get_next_available_trial(self,
                                 timeout: Optional[float] = None
                                 ) -> Optional[Trial]:
        if not self._running:
            return None
        shuffled_results = list(self._running.keys())
        random.shuffle(shuffled_results)

        # Note: We shuffle the results because `ray.wait` by default returns
        # the first available result, and we want to guarantee that slower
        # trials (i.e. trials that run remotely) also get fairly reported.
        # See https://github.com/ray-project/ray/issues/4211 for details.
        start = time.time()
        ready, _ = ray.wait(shuffled_results, timeout=timeout)
        if not ready:
            return None
        result_id = ready[0]
        wait_time = time.time() - start
        if wait_time > NONTRIVIAL_WAIT_TIME_THRESHOLD_S:
            self._last_nontrivial_wait = time.time()
        if time.time() - self._last_nontrivial_wait > BOTTLENECK_WARN_PERIOD_S:
            logger.warning(
                "Over the last {} seconds, the Tune event loop has been "
                "backlogged processing new results. Consider increasing your "
                "period of result reporting to improve performance.".format(
                    BOTTLENECK_WARN_PERIOD_S))

            self._last_nontrivial_wait = time.time()
        return self._running[result_id]
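
    # Minimal sketch of the fairness trick above (assuming an initialized Ray
    # session): `ray.wait` returns the first ready object ref, so shuffling
    # the keys avoids always polling the same trial's future first:
    #
    #     futures = list(self._running.keys())
    #     random.shuffle(futures)
    #     ready, _ = ray.wait(futures, timeout=timeout)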

    def fetch_result(self, trial) -> List[Dict]:
        """Fetches result list of the running trials.

        Returns:
            Result of the most recent trial training run.
        """
        trial_future = self._find_item(self._running, trial)
        if not trial_future:
            raise ValueError("Trial was not running.")
        self._running.pop(trial_future[0])
        with warn_if_slow("fetch_result"):
            result = ray.get(trial_future[0], timeout=DEFAULT_GET_TIMEOUT)

        # For local mode
        if isinstance(result, _LocalWrapper):
            result = result.unwrap()

        if not isinstance(result, list):
            return [result]
        return result

    def _commit_resources(self, resources):
        committed = self._committed_resources
        all_keys = set(resources.custom_resources).union(
            set(committed.custom_resources))

        custom_resources = {
            k: committed.get(k) + resources.get_res_total(k)
            for k in all_keys
        }

        self._committed_resources = Resources(
            committed.cpu + resources.cpu_total(),
            committed.gpu + resources.gpu_total(),
            committed.memory + resources.memory_total(),
            committed.object_store_memory +
            resources.object_store_memory_total(),
            custom_resources=custom_resources)
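
    # Worked example of the bookkeeping above (hypothetical values):
    # committing Resources(cpu=4, gpu=1) on top of a committed total of
    # Resources(cpu=2, gpu=1) yields Resources(cpu=6, gpu=2).
    # `_return_resources` subtracts the same totals later, so every commit
    # must be paired with exactly one return.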

    def _return_resources(self, resources):
        committed = self._committed_resources

        all_keys = set(resources.custom_resources).union(
            set(committed.custom_resources))

        custom_resources = {
            k: committed.get(k) - resources.get_res_total(k)
            for k in all_keys
        }
        self._committed_resources = Resources(
            committed.cpu - resources.cpu_total(),
            committed.gpu - resources.gpu_total(),
            custom_resources=custom_resources)

        assert self._committed_resources.is_nonnegative(), (
            "Resource invalid: {}".format(resources))

    def _update_avail_resources(self, num_retries=5):
        if time.time() - self._last_resource_refresh < self._refresh_period:
            return
        logger.debug("Checking Ray cluster resources.")
        resources = None
        for i in range(num_retries):
            if i > 0:
                logger.warning(
                    "Cluster resources not detected or are 0. Attempt #"
                    "%s...", i + 1)
                time.sleep(0.5)
            try:
                resources = ray.cluster_resources()
            except Exception as exc:
                # TODO(rliaw): Remove this when local mode is fixed.
                # https://github.com/ray-project/ray/issues/4147
                logger.debug(f"{exc}: Using resources for local machine.")
                resources = ResourceSpec().resolve(True).to_resource_dict()
            if resources:
                break

        if not resources:
            # NOTE: This hides the possibility that Ray may be waiting for
            # clients to connect.
            resources.setdefault("CPU", 0)
            resources.setdefault("GPU", 0)
            logger.warning("Cluster resources cannot be detected or are 0. "
                           "You can resume this experiment by passing in "
                           "`resume=True` to `run`.")

        resources = resources.copy()
        num_cpus = resources.pop("CPU", 0)
        num_gpus = resources.pop("GPU", 0)
        memory = ray_constants.from_memory_units(resources.pop("memory", 0))
        object_store_memory = ray_constants.from_memory_units(
            resources.pop("object_store_memory", 0))
        custom_resources = resources

        self._avail_resources = Resources(
            int(num_cpus),
            int(num_gpus),
            memory=int(memory),
            object_store_memory=int(object_store_memory),
            custom_resources=custom_resources)
        self._last_resource_refresh = time.time()
        self._resources_initialized = True
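
    # For reference, `ray.cluster_resources()` returns a flat dict; a
    # hypothetical value and how the code above splits it:
    #
    #     {"CPU": 8.0, "GPU": 2.0, "memory": 190, "object_store_memory": 56,
    #      "accelerator_type:V100": 1.0}
    #
    # "CPU", "GPU", "memory" and "object_store_memory" are popped into
    # dedicated fields (memory values converted from Ray memory units);
    # whatever remains becomes `custom_resources`.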

    def has_resources_for_trial(self, trial: Trial) -> bool:
        """Returns whether this runner has resources available for this trial.

        If using placement groups, this will return True as long as we
        didn't reach the maximum number of pending trials. It will also return
        True if the trial placement group is already staged.

        Args:
            trial: Trial object which should be scheduled.

        Returns:
            boolean

        """
        if trial.uses_placement_groups:
            return trial in self._staged_trials or self._pg_manager.can_stage(
            ) or self._pg_manager.has_ready(trial, update=True)

        return self.has_resources(trial.resources)

    def has_resources(self, resources: Resources) -> bool:
        """Returns whether this runner has at least the specified resources.

        This refreshes the Ray cluster resources if the time since last update
        has exceeded self._refresh_period. This also assumes that the
        cluster is not resizing very frequently.
        """
        if resources.has_placement_group:
            return self._pg_manager.can_stage()

        self._update_avail_resources()
        currently_available = Resources.subtract(self._avail_resources,
                                                 self._committed_resources)
        have_space = (
            resources.cpu_total() <= currently_available.cpu
            and resources.gpu_total() <= currently_available.gpu
            and resources.memory_total() <= currently_available.memory
            and resources.object_store_memory_total() <=
            currently_available.object_store_memory and all(
                resources.get_res_total(res) <= currently_available.get(res)
                for res in resources.custom_resources))

        if have_space:
            # The assumption right now is that we block all trials if one
            # trial is queued.
            self._trial_queued = False
            return True

        can_overcommit = self._queue_trials and not self._trial_queued
        if can_overcommit:
            self._trial_queued = True
            logger.warning(
                "Allowing trial to start even though the "
                "cluster does not have enough free resources. Trial actors "
                "may appear to hang until enough resources are added to the "
                "cluster (e.g., via autoscaling). You can disable this "
                "behavior by specifying `queue_trials=False` in "
                "ray.tune.run().")
            return True

        return False
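
    # Worked example (hypothetical numbers): with 8 CPUs available and 6
    # committed, a trial requesting Resources(cpu=4) fails the `have_space`
    # check (4 > 8 - 6). If `queue_trials=True` and no trial is queued yet,
    # the trial is admitted anyway and `_trial_queued` blocks further
    # overcommits until resources free up.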

    def debug_string(self) -> str:
        """Returns a human readable message for printing to the console."""
        total_resources = self._pg_manager.total_used_resources(
            self._committed_resources)

        if self._resources_initialized:
            status = ("Resources requested: {}/{} CPUs, {}/{} GPUs, "
                      "{}/{} GiB heap, {}/{} GiB objects".format(
                          total_resources.pop("CPU",
                                              0), self._avail_resources.cpu,
                          total_resources.pop("GPU", 0),
                          self._avail_resources.gpu,
                          _to_gb(total_resources.pop("memory", 0.)),
                          _to_gb(self._avail_resources.memory),
                          _to_gb(total_resources.pop("object_store_memory",
                                                     0.)),
                          _to_gb(self._avail_resources.object_store_memory)))
            customs = ", ".join([
                "{}/{} {}".format(total_resources.get(name, 0.),
                                  self._avail_resources.get_res_total(name),
                                  name)
                for name in self._avail_resources.custom_resources
                if not name.startswith(ray.resource_spec.NODE_ID_PREFIX) and (
                    total_resources.get(name, 0.) > 0 or "_group_" not in name)
            ])
            if customs:
                status += " ({})".format(customs)
            return status
        else:
            return "Resources requested: ?"

    def resource_string(self) -> str:
        """Returns a string describing the total resources available."""
        if self._resources_initialized:
            res_str = ("{} CPUs, {} GPUs, "
                       "{} GiB heap, {} GiB objects".format(
                           self._avail_resources.cpu,
                           self._avail_resources.gpu,
                           _to_gb(self._avail_resources.memory),
                           _to_gb(self._avail_resources.object_store_memory)))
            if self._avail_resources.custom_resources:
                custom = ", ".join(
                    "{} {}".format(self._avail_resources.get_res_total(name),
                                   name)
                    for name in self._avail_resources.custom_resources)
                res_str += " ({})".format(custom)
            return res_str
        else:
            return "? CPUs, ? GPUs"

    def on_step_begin(self, trials: List[Trial]) -> None:
        """Before step() is called, update the available resources."""
        self._update_avail_resources()
        self._trial_just_finished_before = self._trial_just_finished
        self._trial_just_finished = False

    def on_step_end(self, trials: List[Trial]) -> None:
        self._just_staged_trials.clear()

        if time.time() > self.last_pg_recon + self.pg_recon_interval:
            # Only do this every now and then - usually the placement groups
            # should not get out of sync, and calling this often is inefficient
            self._pg_manager.reconcile_placement_groups(trials)
            self.last_pg_recon = time.time()

        self._pg_manager.cleanup()

    def force_reconcilation_on_next_step_end(self) -> None:
        self.last_pg_recon = -float("inf")

    def save(self,
             trial,
             storage=Checkpoint.PERSISTENT,
             result: Optional[Dict] = None) -> Checkpoint:
        """Saves the trial's state to a checkpoint asynchronously.

        Args:
            trial (Trial): The trial to be saved.
            storage (str): Where to store the checkpoint. Defaults to
                PERSISTENT.
            result (dict): The state of this trial as a dictionary to be saved.
                If result is None, the trial's last result will be used.

        Returns:
             Checkpoint object, or None if an Exception occurs.
        """
        result = result or trial.last_result
        with self._change_working_directory(trial):
            if storage == Checkpoint.MEMORY:
                value = trial.runner.save_to_object.remote()
                checkpoint = Checkpoint(storage, value, result)
                trial.on_checkpoint(checkpoint)
            else:
                value = trial.runner.save.remote()
                checkpoint = Checkpoint(storage, value, result)
                trial.saving_to = checkpoint
                self._running[value] = trial
        return checkpoint
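
    # Note on the two storage modes above: MEMORY checkpoints resolve
    # immediately via `trial.on_checkpoint`, while PERSISTENT saves remain
    # asynchronous -- the save future is tracked in `self._running` and the
    # checkpoint is only finalized once that future is processed like a
    # training result. A hypothetical caller:
    #
    #     checkpoint = executor.save(trial, storage=Checkpoint.PERSISTENT)
    #     # trial.saving_to now points at the pending checkpoint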

    def restore(self, trial, checkpoint=None, block=False) -> None:
        """Restores training state from a given model checkpoint.

        Args:
            trial (Trial): The trial to be restored.
            checkpoint (Checkpoint): The checkpoint to restore from. If None,
                the most recent PERSISTENT checkpoint is used. Defaults to
                None.
            block (bool): Whether or not to block on restore before returning.

        Raises:
            RuntimeError: This error is raised if no runner is found.
            AbortTrialExecution: This error is raised if the trial is
                ineligible for restoration, given the Tune input arguments.
        """
        if checkpoint is None or checkpoint.value is None:
            checkpoint = trial.checkpoint
        if checkpoint.value is None:
            return
        if trial.runner is None:
            raise RuntimeError(
                "Trial {}: Unable to restore - no runner found.".format(trial))
        value = checkpoint.value
        if checkpoint.storage == Checkpoint.MEMORY:
            logger.debug("Trial %s: Attempting restore from object", trial)
            # Note that we don't store the remote since in-memory checkpoints
            # don't guarantee fault tolerance and don't need to be waited on.
            with self._change_working_directory(trial):
                trial.runner.restore_from_object.remote(value)
        else:
            logger.debug("Trial %s: Attempting restore from %s", trial, value)
            if issubclass(trial.get_trainable_cls(),
                          DurableTrainable) or not trial.sync_on_checkpoint:
                with self._change_working_directory(trial):
                    remote = trial.runner.restore.remote(value)
            elif trial.sync_on_checkpoint:
                # This provides FT backwards compatibility in the
                # case where a DurableTrainable is not provided.
                logger.debug("Trial %s: Reading checkpoint into memory", trial)
                obj = TrainableUtil.checkpoint_to_object(value)
                with self._change_working_directory(trial):
                    remote = trial.runner.restore_from_object.remote(obj)
            else:
                raise AbortTrialExecution(
                    "Pass in `sync_on_checkpoint=True` for driver-based trial "
                    "restoration. Pass in an `upload_dir` and a Trainable "
                    "extending `DurableTrainable` for remote storage-based "
                    "restoration.")

            if block:
                ray.get(remote)
            else:
                self._running[remote] = trial
                trial.restoring_from = checkpoint

    def export_trial_if_needed(self, trial: Trial) -> Dict:
        """Exports model of this trial based on trial.export_formats.

        Returns:
            A dict that maps ExportFormats to successfully exported models.
        """
        if trial.export_formats and len(trial.export_formats) > 0:
            with self._change_working_directory(trial):
                return ray.get(trial.runner.export_model.remote(
                    trial.export_formats),
                               timeout=DEFAULT_GET_TIMEOUT)
        return {}

    def has_gpus(self) -> bool:
        if self._resources_initialized:
            self._update_avail_resources()
            return self._avail_resources.gpu > 0
        return False

    def cleanup(self, trials: List[Trial]) -> None:
        self._trial_cleanup.cleanup(partial=False)
        self._pg_manager.reconcile_placement_groups(trials)
        self._pg_manager.cleanup(force=True)
        self._pg_manager.cleanup_existing_pg(block=True)

    @contextmanager
    def _change_working_directory(self, trial):
        """Context manager changing working directory to trial logdir.
        Used in local mode.

        For non-local mode it is no-op.
        """
        if ray.worker._mode() == ray.worker.LOCAL_MODE:
            old_dir = os.getcwd()
            try:
                os.chdir(trial.logdir)
                yield
            finally:
                os.chdir(old_dir)
        else:
            yield
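
    # Usage sketch for the context manager above (illustrative): any remote
    # call that should see the trial's logdir as its working directory in
    # local mode is wrapped like this:
    #
    #     with self._change_working_directory(trial):
    #         remote = trial.runner.train.remote()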
Example #4
class FluidExecutor(TrialExecutor):
    def __init__(self, **kwargs):
        super().__init__(queue_trials=True)  # type: ignore

        # whether in testing environment without GPU
        self._fake_gpus = False

        # resources
        self._avail_resources = Resources(cpu=0, gpu=0)
        self._committed_resources = Resources(cpu=0, gpu=0)
        self._resources_initialized = False
        self._last_resource_refresh = float("-inf")
        # set of trials that have resources committed.
        # These are usually the trials in jobs_running, but a trial may be
        # in _trials_running and not in jobs_running, because fetch_result
        # was called on it.
        # This is maintained solely by _commit_resources/_return_resources.
        self._trials_running: Set[Trial] = set()

        # make sure our own GPU resources are created first in the cluster
        create_custom_gpu_res()
        self._update_avail_resources()

        logger.info(f"Init with resources: {self._avail_resources}")

        self.jobs_pending: List[PendingJob] = []

        # map from in_flight_future to the job
        self.jobs_running: Dict[ray.ObjectID, RunningJob] = {}

        # used to save the previous run fut
        self.jobs_paused: Dict[ray.ObjectID, RunningJob] = {}

        # async queue to stop runner
        self._trial_cleanup = _TrialCleanup()

        # metadata about a trial group
        self.trial_group_meta: List[TrialGroupMeta] = []
        # trialgroup assignment,
        # mapping from trial_id to group num
        self.trial_groups: Dict[str, TrialAndGroup] = {}

    @property
    def num_trial_groups(self) -> int:
        return len(self.trial_group_meta)

    def _detect_groups(self):
        """Go over pending jobs, and assign trialgroup to them if not already done.
        If new groups are discovered, otherwise run static
        """
        logger.debug(
            f"_detect_groups: self.jobs_pending={self.jobs_pending} self.trial_groups={self.trial_groups}"
        )
        # pending may already be assigned a group if it's an unpaused trial
        assigned, unassigned = partition(
            self.jobs_pending, lambda p: p.trial.trial_id in self.trial_groups)
        unassigned = list(unassigned)
        assigned = list(assigned)
        self.jobs_pending.clear()
        if unassigned:
            meta = TrialGroupMeta(
                self.num_trial_groups,
                unassigned,
            )
            self.trial_group_meta.append(meta)
            logger.debug("Assign group %d to unassigned trials: %s", meta.grp,
                         unassigned)
            for p in unassigned:
                self.trial_groups[p.trial.trial_id] = TrialAndGroup(
                    p.trial, meta.grp)
            # allocate resources
            self._fluid(meta)
        else:
            logger.debug("No new group")

        if assigned:
            # find each group with pending jobs and do dynamic
            groups = {self._find_group(p.trial) for p in assigned}
            for meta in groups:
                self._fluid(meta)
        else:
            logger.debug("No change in existing groups")

    def _dump_groups(self):
        """Dump group info for debugging"""
        logger.info("There are %d TrialGroup(s)", self.num_trial_groups)
        for grp in range(self.num_trial_groups):
            logger.info("TrialGroup %d", grp)
            for trial in self._trial_group(grp):
                if self._find_running(trial):
                    tag = "jobs_running"
                elif self._find_pending(trial):
                    tag = "jobs_pending"
                elif self._find_paused(trial):
                    tag = "jobs_paused"
                else:
                    tag = "none"
                logger.info("    Trial %s: [%s] queue [%s]", trial.trial_id,
                            trial.status, tag)
        logger.info("Idle Resources: %s",
                    self._resource_string(self.idle_resources))

    def _committed_resources_in_group(self, grp: int) -> Resources:
        """Compute all resources committed in this group"""
        used = Resources(cpu=0, gpu=0)
        for job in self.jobs_running.values():
            tg = self.trial_groups.get(job.trial.trial_id)
            if tg is not None and tg.group == grp:
                used = resources_add(used, job.trial.resources)
        return used

    def _fluid(self, meta: TrialGroupMeta):
        """Run fluid on a specific group"""
        self._dump_groups()
        # set of trials to consider
        A = {trial.trial_id for trial in self._trial_group(meta.grp)}
        logger.debug(
            f"_fluid: meta.perf.trials_missing_info={meta.perf.trials_missing_info} meta.trials={meta.trials}, meta.grp={meta.grp}, trial_groups={self.trial_groups}, A={A}"
        )
        # assignment of resources
        W: Dict[str, Resources] = {}
        # compute the new idle resources if every trial in this group were stopped
        M = resources_add(self.idle_resources,
                          self._committed_resources_in_group(meta.grp))

        if meta.perf.trials_missing_info:
            # there are still trials that need perf data,
            # restrict A to only these trials
            others = A.difference(meta.perf.trials_missing_info)
            A = meta.perf.trials_missing_info
            # set others to use 0 resource
            for tid in others:
                W[tid] = Resources(cpu=0, gpu=0)
            # use 1 gpu per trial to get reference perf
            for tid in A:
                r = Resources(cpu=1, gpu=1)
                Mp = Resources.subtract(M, r)
                if not Mp.is_nonnegative():
                    break
                M = Mp
                W[tid] = r
        else:
            # convert A to array for sorting
            A = np.array(list(A))
            # reference height (1 width)
            H1 = np.array([meta.perf.get_height(tid, 1) for tid in A])
            # sort by H1 in non-increasing order
            order = np.argsort(H1)[::-1]
            A = A[order]
            H1 = H1[order]
            # $$w_i = \min\left(
            #     \max\left(
            #         \left\lfloor \frac{h_{i,1}}{\sum_j h_{j,1}} n \right\rfloor,
            #         \frac{1}{c}
            #     \right),
            #     d
            # \right)$$
            c = 1 / 2
            d = 4
            w = np.minimum(
                np.maximum(np.floor(H1 * np.size(H1) / np.sum(H1)), 1 / c), d)
            # assign resources based on w
            w = w / w.sum() * self._avail_resources.gpu_total()
            resW = [Resources(cpu=1, gpu=g) for g in w]
            # write to W
            W = dict(zip(A, resW))

        self._ensure_W(W, meta)
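
    # Worked example of the width formula above (hypothetical heights): with
    # reference heights H1 = [9, 1, 1, 1] (n = 4, sum = 12), c = 1/2, d = 4:
    #
    #     floor(H1 * n / sum(H1)) = [3, 0, 0, 0]
    #     clipped to [1/c, d]     = [3, 2, 2, 2]
    #
    # On an 8-GPU cluster the normalized assignment w / w.sum() * 8 is then
    # approximately [2.67, 1.78, 1.78, 1.78] GPUs per trial.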

    def _ensure_W(self, W: Dict[str, Resources], meta: TrialGroupMeta):
        """Adjust group resources given in W"""
        logger.debug(f"ensure_W: W={W} meta.trials={meta.trials}")
        # stop any trials with 0 res
        # this has to be done first to free up resources for others to use
        for trial_id, res in W.items():
            trial = self.trial_groups[trial_id].trial
            if res.cpu_total() + res.gpu_total() == 0:
                # Add to paused, then ensure_stop. We do not change the
                # trial's status, which is visible from the outside.
                running = self._find_running(trial)
                if running is not None:
                    # don't call pause_trial, which will trigger another fluid reschedule
                    self.jobs_paused[running.in_flight_future] = running
                self._ensure_stop(trial)
                trial.resources = res
                # add to pending
                self.start_trial(trial)
        # adjust any trials with different res, including any not already running
        for trial_id, res in W.items():
            # use trial group to map trial_id to trial
            trial = self.trial_groups[trial_id].trial

            if res.cpu_total() + res.gpu_total() == 0:
                # already handled in the loop above
                continue

            if (
                    # current_res != res
                    Resources.subtract(trial.resources, res).is_nonnegative()
                    != Resources.subtract(res,
                                          trial.resources).is_nonnegative()):
                running = self._find_running(trial)
                if running is not None:
                    # don't call pause_trial, which will trigger another fluid reschedule
                    self.jobs_paused[running.in_flight_future] = running

                self._ensure_stop(trial)

            # at this point, the job is always stopped but not in the pending queue,
            # because fluid clears the pending queue.
            trial.resources = res
            self._kickoff(PendingJob(trial, None, True), res)

    def _find_group(self, trial: Trial) -> TrialGroupMeta:
        return self.trial_group_meta[self.trial_groups[trial.trial_id].group]

    def _trial_group(self, grp: int) -> List[Trial]:
        return [v.trial for v in self.trial_groups.values() if v.group == grp]

    def _find_paused(self, trial: Trial) -> Optional[RunningJob]:
        for job in self.jobs_paused.values():
            if job.trial == trial:
                return job

    def _pop_paused(self, trial: Trial) -> Optional[RunningJob]:
        for fut, job in self.jobs_paused.items():
            if job.trial == trial:
                assert fut == job.in_flight_future
                return self.jobs_paused.pop(fut)

    def _find_running(self, trial: Trial) -> Optional[RunningJob]:
        for _, job in self.jobs_running.items():
            if job.trial == trial:
                return job
        logger.debug(
            f"Could not find running trial: {trial}, currently running ones "
            f"are {list(self.jobs_running.values())}")

    def _find_pending(self, trial: Trial) -> Optional[PendingJob]:
        for job in self.jobs_pending:
            if job.trial == trial:
                return job

    def _setup_remote_runner(self, trial: Trial, res: Resources,
                             reuse_allowed: bool) -> Any:
        trial.init_logger()
        # We checkpoint metadata here to try mitigating logdir duplication
        self.try_checkpoint_metadata(trial)
        remote_logdir = trial.logdir

        cls = ray.remote(
            num_cpus=res.cpu,
            num_gpus=0 if self._fake_gpus else res.gpu,
            memory=res.memory,
            object_store_memory=res.object_store_memory,
            resources=res.custom_resources,
        )(trial.get_trainable_cls())

        def logger_creator(config):
            # Set the working dir in the remote process, for user file writes
            os.makedirs(remote_logdir, exist_ok=True)
            if ray.worker._mode() != ray.worker.LOCAL_MODE:
                os.chdir(remote_logdir)
            return NoopLogger(config, remote_logdir)

        # Clear the Trial's location (to be updated later on result)
        # since we don't know where the remote runner is placed.
        trial.set_location(Location())
        logger.debug("Trial %s: Setting up new remote runner.", trial)
        # Logging for trials is handled centrally by TrialRunner, so
        # configure the remote runner to use a noop-logger.
        trial_config = copy.deepcopy(trial.config)
        trial_config[TRIAL_INFO] = TrialInfo(trial)
        kwargs = {
            "config": trial_config,
            "logger_creator": logger_creator,
        }
        if issubclass(trial.get_trainable_cls(), DurableTrainable):
            kwargs["remote_checkpoint_dir"] = trial.remote_checkpoint_dir

        with _change_working_directory(trial):
            return cls.remote(**kwargs)

    def _kickoff(self, pending: PendingJob,
                 res: Resources) -> Optional[RunningJob]:
        """Turn a pending job into a running one
        The pending job may be previously paused, or completely new.
        If paused, there will be a running job saved in the jobs_paused queue

        May return None if failed to start
        """
        trial = pending.trial
        # this is needed for the Trainer to set up distributed training
        # TODO: figure out which config key is also needed to set resource info
        trial.resources = res

        self._commit_resources(trial)
        try:
            reuse_allowed = (pending.checkpoint is not None
                             or trial.has_checkpoint())
            runner = self._setup_remote_runner(trial, res, reuse_allowed)
            trial.set_runner(runner)
            restore_job = self._restore(trial, pending.checkpoint)

            # The trial's status is already RUNNING (set in start_trial), so
            # that it appears as a running trial from the outside.

            # if the trial was previously paused
            prev_run = self._pop_paused(trial)
            if prev_run is not None:
                if restore_job is not None:
                    logger.error(
                        "A previously paused job is restoring; blocking on "
                        "restore.")
                    ray.get(restore_job.in_flight_future)
                # add back to running queue
                self.jobs_running[prev_run.in_flight_future] = prev_run
                return prev_run

            # if the trial is restoring
            if trial.is_restoring:
                # assert restore_job is not None
                return restore_job

            # actually start train op
            return self._ensure_train(trial)
        except Exception as e:
            if isinstance(e, AbortTrialExecution):
                logger.exception("Trial %s: Error starting runner, aborting!",
                                 trial)
            else:
                logger.exception("Trial %s: Unexpected error starting runner.",
                                 trial)
            time.sleep(2)
            error_msg = traceback.format_exc()
            self._ensure_stop(
                trial,
                error=True,
                error_msg=error_msg,
                stop_logger=True,
                # NOTE that we don't return the resources, since they may have been lost.
                release_resources=False,
                update_status=True,
            )

    def _ensure_train(self, trial: Trial) -> RunningJob:
        """Actually invoke the train op on the runner"""
        assert trial.runner is not None
        with _change_working_directory(trial):
            fut = trial.runner.train.remote()

        if isinstance(fut, dict):
            # local mode
            fut = _LocalWrapper(fut)
        running = RunningJob(trial, fut)
        self.jobs_running[fut] = running
        logger.debug(
            f"Set trial to running: {trial}, jobs_running={self.jobs_running}")
        return running

    def _ensure_stop(
        self,
        trial,
        error=False,
        error_msg="",
        stop_logger=True,
        release_resources=True,
        update_status=False,
    ):
        """Stops the trial and its logger
        Handles any error
        """
        logger.debug(f"_ensure_stop: trial.resources={trial.resources}")
        if stop_logger:
            trial.close_logger()

        prior_status = trial.status
        trial.set_location(Location())
        if update_status:
            self.set_status(trial, Trial.ERROR if error else Trial.TERMINATED)

        # remove from running
        in_flight = [
            j for _, j in self.jobs_running.items() if j.trial == trial
        ]
        for j in in_flight:
            self.jobs_running.pop(j.in_flight_future)
        if in_flight:
            assert prior_status in (Trial.RUNNING, Trial.ERROR), \
                "trial status invalid"
        # release resources
        if release_resources:
            self._return_resources(trial)

        # remove from trial group
        # del self.trial_groups[trial.trial_id]

        try:
            trial.write_error_log(error_msg)
            if hasattr(trial, "runner") and trial.runner:
                logger.debug("Trial %s: Destroying actor.", trial)
                with _change_working_directory(trial):
                    self._trial_cleanup.add(trial, actor=trial.runner)
        except Exception:
            logger.exception("Trial %s: Error stopping runner.", trial)
            self.set_status(trial, Trial.ERROR)
        finally:
            trial.set_runner(None)

    def has_resources(self, resources):
        """Tell the schedule algorithm to always submit trials to us"""
        return True

    def start_trial(self, trial, checkpoint=None, train=True):
        """Add to pending queue and reschedule"""
        logger.debug("start_trial %s", trial)
        # the trial is considered by the outside to be running
        self.set_status(trial, Trial.RUNNING)
        self.jobs_pending.append(PendingJob(trial, checkpoint, train))
        # The actual triggering is done in on_no_available_trials()

    def stop_trial(self, trial, error=False, error_msg=None, stop_logger=True):
        """Add to to-stop queue and reschedule"""
        logger.debug("stop_trial %s", trial)
        self._ensure_stop(trial,
                          error,
                          error_msg,
                          stop_logger,
                          update_status=True)
        meta = self._find_group(trial)
        self._fluid(meta)

    def continue_training(self, trial):
        # this is called after results were fetched from a trial,
        # and should start another train op in place of the
        # finished one.
        running_job = self._find_running(trial)
        if running_job is not None:
            # skip if the trial is already running
            return

        # start new one
        self._ensure_train(trial)

    def pause_trial(self, trial):
        logger.debug("pause_trial %s", trial)
        running = self._find_running(trial)
        if running is not None:
            # add to jobs_paused
            self.jobs_paused[running.in_flight_future] = running
        # the super impl will call stop_trial, which will then remove the job
        # from the running queue
        super().pause_trial(trial)

    def unpause_trial(self, trial):
        logger.debug("unpause_trial %s", trial)
        super().unpause_trial(trial)

    def resume_trial(self, trial):
        """Resumes PAUSED trials. This is a blocking call.
        This is not used by any algorithm
        """
        logger.debug("resume_trial %s", trial)
        assert trial.status == Trial.PAUSED, trial.status
        raise NotImplementedError

    def reset_trial(self, trial: Trial, new_config, new_experiment_tag):
        """Tries to invoke `Trainable.reset_config()` to reset trial.

        Args:
            trial (Trial): Trial to be reset.
            new_config (dict): New configuration for Trial trainable.
            new_experiment_tag (str): New experiment name for trial.

        Returns:
            True if `reset_config` is successful else False.
        """
        logger.debug("reset_trial %s", trial)
        trial.experiment_tag = new_experiment_tag
        trial.config = new_config
        trainable = trial.runner
        with _change_working_directory(trial):
            try:
                reset_val = ray.get(
                    trainable.reset_config.remote(new_config),
                    timeout=DEFAULT_GET_TIMEOUT)
            except RayTimeoutError:
                logger.exception("Trial %s: reset_config timed out.", trial)
                return False
        return reset_val

    def get_running_trials(self):
        return [job.trial for job in self.jobs_running.values()]

    def get_next_available_trial(self) -> Trial:
        """Return the next trial with ready result.
        Note that this doesn't remove the trial from running, fetch_result does that
        """
        futures = list(self.jobs_running.keys())
        # shuffle the list of futures because ray.wait
        # always returns the first available future, but we want to be fair
        random.shuffle(futures)
        [ready_fut], _ = ray.wait(futures, num_returns=1)
        return self.jobs_running[ready_fut].trial

    def get_next_failed_trial(self) -> Optional[Trial]:
        if ray.worker._mode() == ray.worker.LOCAL_MODE:
            return None

        alive_node_ips = {
            node["NodeManagerAddress"]
            for node in ray.state.nodes() if node["alive"]
        }
        for trial in self.get_running_trials():
            if trial.node_ip and trial.node_ip not in alive_node_ips:
                return trial
        return None

    def fetch_result(self, trial):
        """
        Note that this will remove the trial from running queue,
        so actions must be taken later to either continue_training/stop/pause,
        to maintain consistent system state.

        This is usually called from the runner, knowning the the future for this trial is ready.
        """
        running_job = self._find_running(trial)
        assert running_job, "Trial was not running"
        self.jobs_running.pop(running_job.in_flight_future)
        result = ray.get(running_job.in_flight_future,
                         timeout=DEFAULT_GET_TIMEOUT)
        if isinstance(result, _LocalWrapper):
            result = result.unwrap()

        if isinstance(result, dict):
            # notify trial group
            meta = self._find_group(trial)
            meta.perf.on_trial_result(trial.trial_id, result)
        return result
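
    # The runner's event loop pairs these calls roughly like this (a sketch;
    # `executor` is an instance of this class):
    #
    #     trial = executor.get_next_available_trial()  # blocks in ray.wait
    #     result = executor.fetch_result(trial)        # pops the ready future
    #     if not result.get("done"):
    #         executor.continue_training(trial)        # queue the next step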

    def debug_string(self):
        # TODO debug_string
        pass

    def _resource_string(self, res: Resources) -> str:
        """Returns a string describing the total resources available."""
        res_str = (f"{res.cpu} CPUs, {res.gpu} GPUs, "
                   f"{_to_gb(res.memory)} GiB heap, "
                   f"{_to_gb(res.object_store_memory)} GiB objects")
        if res.custom_resources:
            custom = ", ".join(f"{res.get_res_total(name)} {name}"
                               for name in res.custom_resources)
            res_str += f" ({custom})"
        return res_str
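
    # For Resources(cpu=8, gpu=2, memory=..., object_store_memory=...,
    # custom_resources={"accelerator": 1}) this renders something like:
    #
    #     "8 CPUs, 2 GPUs, 16.0 GiB heap, 4.0 GiB objects (1 accelerator)"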

    def save(self, trial, storage=Checkpoint.PERSISTENT, result=None):
        """Saves the trial's state to a checkpoint asynchronously.

        Args:
            trial (Trial): The trial to be saved.
            storage (str): Where to store the checkpoint. Defaults to
                PERSISTENT.
            result (dict): The state of this trial as a dictionary to be saved.
                If result is None, the trial's last result will be used.

        Returns:
             Checkpoint object, or None if an Exception occurs.
        """
        result = result or trial.last_result
        with _change_working_directory(trial):
            if storage == Checkpoint.MEMORY:
                value = trial.runner.save_to_object.remote()
                checkpoint = Checkpoint(storage, value, result)
                trial.on_checkpoint(checkpoint)
            else:
                value = trial.runner.save.remote()
                checkpoint = Checkpoint(storage, value, result)
                trial.saving_to = checkpoint
                self.jobs_running[value] = RunningJob(trial, value)
        return checkpoint
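
    # Driver-side sketch (hypothetical `executor`/`trial` objects):
    #
    #     # async persistent checkpoint; completion is observed later via the
    #     # save future registered in jobs_running
    #     ckpt = executor.save(trial, storage=Checkpoint.PERSISTENT)
    #
    #     # in-memory checkpoint, e.g. for a quick pause/resume cycle
    #     ckpt = executor.save(trial, storage=Checkpoint.MEMORY)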

    def _restore(self,
                 trial,
                 checkpoint=None,
                 block=False) -> Optional[RunningJob]:
        """Restores training state from a given model checkpoint.

        Args:
            trial (Trial): The trial to be restored.
            checkpoint (Checkpoint): The checkpoint to restore from. If None,
                the most recent PERSISTENT checkpoint is used. Defaults to
                None.
            block (bool): Whether or not to block on restore before returning.

        Returns:
            The RunningJob tracking the in-flight restore, or None if there
            was nothing to restore, the restore was in-memory, or `block`
            was True.

        Raises:
            RuntimeError: This error is raised if no runner is found.
            AbortTrialExecution: This error is raised if the trial is
                ineligible for restoration, given the Tune input arguments.
        """
        if checkpoint is None or checkpoint.value is None:
            checkpoint = trial.checkpoint
        if checkpoint.value is None:
            return
        if trial.runner is None:
            raise RuntimeError(
                "Trial {}: Unable to restore - no runner found.".format(trial))
        value = checkpoint.value
        if checkpoint.storage == Checkpoint.MEMORY:
            logger.debug("Trial %s: Attempting restore from object", trial)
            # Note that we don't store the remote since in-memory checkpoints
            # don't guarantee fault tolerance and don't need to be waited on.
            with _change_working_directory(trial):
                trial.runner.restore_from_object.remote(value)
        else:
            logger.debug("Trial %s: Attempting restore from %s", trial, value)
            if issubclass(trial.get_trainable_cls(), DurableTrainable):
                with _change_working_directory(trial):
                    remote = trial.runner.restore.remote(value)
            elif trial.sync_on_checkpoint:
                # This provides FT backwards compatibility in the
                # case where a DurableTrainable is not provided.
                logger.warning("Trial %s: Reading checkpoint into memory.",
                               trial)
                data_dict = TrainableUtil.pickle_checkpoint(value)
                with _change_working_directory(trial):
                    remote = trial.runner.restore_from_object.remote(data_dict)
            else:
                raise AbortTrialExecution(
                    "Pass in `sync_on_checkpoint=True` for driver-based trial "
                    "restoration. Pass in an `upload_dir` and a Trainable "
                    "extending `DurableTrainable` for remote storage-based "
                    "restoration.")

            if block:
                ray.get(remote)
            else:
                trial.restoring_from = checkpoint
                running_job = RunningJob(trial, remote)
                self.jobs_running[remote] = running_job
                return running_job

    def restore(self, trial, checkpoint=None, block=False):
        return self._restore(trial, checkpoint, block)
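
    # Sketch of both modes (hypothetical `executor`/`trial` objects):
    #
    #     job = executor.restore(trial)        # async; tracked in jobs_running
    #     executor.restore(trial, block=True)  # sync; waits on the remote call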

    def export_trial_if_needed(self, trial: Trial):
        """Exports model of this trial based on trial.export_formats.

        Returns:
            A dict that maps ExportFormats to successfully exported models.
        """
        if trial.export_formats:
            with _change_working_directory(trial):
                return ray.get(
                    trial.runner.export_model.remote(trial.export_formats),
                    timeout=DEFAULT_GET_TIMEOUT,
                )
        return {}
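
    # Sketch: for a trial created with export_formats set, this invokes
    # Trainable.export_model() on the worker and returns its mapping:
    #
    #     models = executor.export_trial_if_needed(trial)
    #     # -> e.g. {"model": <exported model or path>}; {} if no formats set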

    def cleanup(self):
        self._trial_cleanup.cleanup(partial=False)

    def on_step_begin(self, trial_runner):
        """Before step() called, update the available resources."""
        self._update_avail_resources()

    def _update_avail_resources(self, num_retries=5):
        resources = None
        for i in range(num_retries):
            if i > 0:
                logger.warning(
                    "Cluster resources not detected or are 0. Attempt #"
                    "%s...", i + 1)
                time.sleep(0.5)
            try:
                resources = ray.cluster_resources()
            except Exception:
                # TODO(rliaw): Remove this when local mode is fixed.
                # https://github.com/ray-project/ray/issues/4147
                logger.debug("Using resources for local machine.")
                resources = ResourceSpec().resolve(True).to_resource_dict()
            if resources:
                break

        if not resources:
            # NOTE: This hides the possibility that Ray may be waiting for
            # clients to connect.
            resources.setdefault("CPU", 0)
            resources.setdefault("GPU", 0)
            logger.warning("Cluster resources cannot be detected or are 0. "
                           "You can resume this experiment by passing in "
                           "`resume=True` to `run`.")

        resources = resources.copy()
        num_cpus = resources.pop("CPU", 0)
        num_gpus = resources.pop("GPU", 0)
        memory = ray_constants.from_memory_units(resources.pop("memory", 0))
        object_store_memory = ray_constants.from_memory_units(
            resources.pop("object_store_memory", 0))
        custom_resources = resources

        if num_gpus == 0:
            warnings.warn(
                "No GPU resources found; assuming a local test and using "
                "the CPU count as the GPU count.")
            # local test
            num_gpus = num_cpus
            self._fake_gpus = True
        else:
            self._fake_gpus = False

        avail_resources = Resources(
            int(num_cpus),
            int(num_gpus),
            memory=int(memory),
            object_store_memory=int(object_store_memory),
            custom_resources=custom_resources,
        )

        assert (self.idle_resources.is_nonnegative()
                ), "Cluster removed resources from running trials!"

        self._avail_resources = avail_resources
        self._last_resource_refresh = time.time()
        self._resources_initialized = True
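
    # Concrete outcome of the fake-GPU fallback above: on an 8-CPU machine
    # with no GPUs, self._avail_resources ends up roughly as
    #
    #     Resources(cpu=8, gpu=8, memory=..., object_store_memory=...)
    #
    # with self._fake_gpus = True, so GPU-requesting trials can still be
    # scheduled in local tests.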

    @property
    def idle_resources(self) -> Resources:
        return Resources.subtract(self._avail_resources,
                                  self._committed_resources)

    def _commit_resources(self, trial: Trial):
        resources = trial.resources
        self._trials_running.add(trial)

        committed = self._committed_resources
        all_keys = set(resources.custom_resources).union(
            set(committed.custom_resources))

        custom_resources = {
            k: committed.get(k) + resources.get_res_total(k)
            for k in all_keys
        }

        self._committed_resources = Resources(
            committed.cpu + resources.cpu_total(),
            committed.gpu + resources.gpu_total(),
            committed.memory + resources.memory_total(),
            committed.object_store_memory +
            resources.object_store_memory_total(),
            custom_resources=custom_resources,
        )
        logger.debug(
            f"Committed res={resources} -> {self._committed_resources}")

    def _return_resources(self, trial: Trial):
        if trial not in self._trials_running:
            return
        logger.debug("Trial %s: Returning resources.", trial)
        self._trials_running.remove(trial)
        resources = trial.resources

        committed = self._committed_resources

        all_keys = set(resources.custom_resources).union(
            set(committed.custom_resources))

        custom_resources = {
            k: committed.get(k) - resources.get_res_total(k)
            for k in all_keys
        }
        self._committed_resources = Resources(
            committed.cpu - resources.cpu_total(),
            committed.gpu - resources.gpu_total(),
            committed.memory - resources.memory_total(),
            committed.object_store_memory -
            resources.object_store_memory_total(),
            custom_resources=custom_resources,
        )

        assert (self._committed_resources.is_nonnegative()
                ), "Resource invalid: {} - {} = {}".format(
                    committed, resources, self._committed_resources)
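
    # Round-trip invariant (a sketch): committing and then returning the
    # same trial's resources must leave the ledger unchanged, up to float
    # rounding (see the numerical-error test below):
    #
    #     before = executor._committed_resources
    #     executor._commit_resources(trial)
    #     executor._return_resources(trial)
    #     # executor._committed_resources == before, field-wise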

    def on_no_available_trials(self, trial_runner):
        """This is called when we get all trial from a batch from the search algo"""
        logger.debug("on_no_available_trials")
        self._detect_groups()
        super().on_no_available_trials(trial_runner)
Example #5
    def testResourceNumericalError(self):
        resource = Resources(
            cpu=0.99, gpu=0.99, custom_resources={"a": 0.99})
        small_resource = Resources(
            cpu=0.33, gpu=0.33, custom_resources={"a": 0.33})
        for _ in range(3):
            resource = Resources.subtract(resource, small_resource)
        self.assertTrue(resource.is_nonnegative())
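
    # Why this test matters: repeated float subtraction leaves an epsilon-
    # scale residue rather than an exact zero. On an IEEE-754 double:
    #
    #     >>> 0.99 - 0.33 - 0.33 - 0.33
    #     -1.1102230246251565e-16
    #
    # so is_nonnegative() has to treat tiny negative values like this as
    # zero, which is exactly what the assertion above verifies.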