Example #1
    def restore(self, trial: Trial) -> None:
        """Restores training state from a given model checkpoint.

        Args:
            trial: The trial to be restored.

        Raises:
            RuntimeError: This error is raised if no runner is found.
            _AbortTrialExecution: This error is raised if the trial is
                ineligible for restoration, given the Tune input arguments.
        """
        checkpoint = trial.checkpoint
        if checkpoint.dir_or_data is None:
            return
        if trial.runner is None:
            raise RuntimeError(
                "Trial {}: Unable to restore - no runner found.".format(trial))
        checkpoint_dir = checkpoint.dir_or_data
        node_ip = checkpoint.node_ip
        if checkpoint.storage_mode == CheckpointStorage.MEMORY:
            logger.debug("Trial %s: Attempting restore from object", trial)
            # Note that we don't store the returned future, since in-memory
            # checkpoints don't guarantee fault tolerance and don't need to be
            # waited on.
            with self._change_working_directory(trial):
                trial.runner.restore_from_object.remote(checkpoint_dir)
        else:
            logger.debug("Trial %s: Attempting restore from %s", trial,
                         checkpoint_dir)
            if (trial.uses_cloud_checkpointing or not trial.sync_on_checkpoint
                    or not os.path.exists(checkpoint_dir)):
                # If using cloud checkpointing, the trial will fetch the
                # checkpoint from cloud storage. If not syncing to the driver,
                # assume the trial has access to the checkpoint on its local
                # filesystem.
                with self._change_working_directory(trial):
                    remote = trial.runner.restore.remote(
                        checkpoint_dir, node_ip)
            elif trial.sync_on_checkpoint:
                # This provides fault-tolerance (FT) backwards compatibility
                # in the case where no cloud checkpoints are provided.
                logger.debug("Trial %s: Reading checkpoint into memory", trial)
                checkpoint_path = TrainableUtil.find_checkpoint_dir(
                    checkpoint_dir)
                obj = Checkpoint.from_directory(checkpoint_path).to_bytes()
                with self._change_working_directory(trial):
                    remote = trial.runner.restore_from_object.remote(obj)
            else:
                raise _AbortTrialExecution(
                    "Pass in `sync_on_checkpoint=True` for driver-based trial "
                    "restoration. Pass in an `upload_dir` for remote "
                    "storage-based restoration.")

            self._futures[remote] = (_ExecutorEventType.RESTORING_RESULT,
                                     trial)
            trial.restoring_from = checkpoint
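
The snippet below is a minimal, self-contained sketch of the three restore paths above: an in-memory object handed straight to the runner, a path-based restore where the runner fetches the checkpoint itself (cloud or shared filesystem), and the driver-synced fallback that packs the checkpoint directory into bytes. Names such as FakeRunner, FakeCheckpoint and restore_trial are illustrative stand-ins, not Ray Tune's real classes.

import enum
import io
import os
import tarfile
from dataclasses import dataclass
from typing import Optional


class CheckpointStorage(enum.Enum):
    MEMORY = "memory"
    PERSISTENT = "persistent"


@dataclass
class FakeCheckpoint:
    storage_mode: CheckpointStorage
    dir_or_data: object = None
    node_ip: Optional[str] = None


class FakeRunner:
    def restore_from_object(self, obj):
        print(f"restoring from an in-memory object of {len(obj)} bytes")

    def restore(self, checkpoint_dir, node_ip=None):
        print(f"restoring from {checkpoint_dir} (node_ip={node_ip})")


def restore_trial(runner, checkpoint, sync_on_checkpoint=True,
                  uses_cloud_checkpointing=False):
    if checkpoint.dir_or_data is None:
        return  # Nothing to restore.
    if checkpoint.storage_mode is CheckpointStorage.MEMORY:
        # In-memory checkpoints are handed to the runner directly.
        runner.restore_from_object(checkpoint.dir_or_data)
    elif (uses_cloud_checkpointing or not sync_on_checkpoint
          or not os.path.exists(checkpoint.dir_or_data)):
        # The runner fetches the checkpoint itself (cloud or shared fs).
        runner.restore(checkpoint.dir_or_data, checkpoint.node_ip)
    else:
        # Driver-synced fallback: pack the checkpoint directory into bytes
        # and ship those bytes to the runner.
        buf = io.BytesIO()
        with tarfile.open(fileobj=buf, mode="w") as tar:
            tar.add(checkpoint.dir_or_data, arcname=".")
        runner.restore_from_object(buf.getvalue())


restore_trial(FakeRunner(),
              FakeCheckpoint(CheckpointStorage.MEMORY, dir_or_data=b"\x00\x01"))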
Example #2
    def _setup_remote_runner(self, trial):
        trial.init_logdir()
        # We checkpoint metadata here to try mitigating logdir duplication
        self._trials_to_cache.add(trial)
        logger_creator = partial(noop_logger_creator, logdir=trial.logdir)

        if len(self._cached_actor_pg) > 0:
            assert self._reuse_actors
            existing_runner, pg = self._cached_actor_pg.popleft()
            logger.debug(f"Trial {trial}: Reusing cached runner "
                         f"{existing_runner}")

            trial.set_runner(existing_runner)
            if pg:
                self._pg_manager.assign_cached_pg(pg, trial)

            if not self.reset_trial(trial, trial.config, trial.experiment_tag,
                                    logger_creator):
                raise _AbortTrialExecution(
                    "Trainable runner reuse requires reset_config() to be "
                    "implemented and return True.")
            return existing_runner

        trainable_cls = trial.get_trainable_cls()
        if not trainable_cls:
            raise _AbortTrialExecution(
                f"Invalid trainable: {trial.trainable_name}. If you passed "
                f"a string, make sure the trainable was registered before.")
        _actor_cls = _class_cache.get(trainable_cls)

        if not self._pg_manager.has_ready(trial):
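            # No placement group is ready for this trial yet; defer actor
            # creation until a later scheduling pass.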
            return None

        full_actor_class = self._pg_manager.get_full_actor_cls(
            trial, _actor_cls)
        # Clear the Trial's location (to be updated later on result)
        # since we don't know where the remote runner is placed.
        trial.set_location(_Location())
        logger.debug("Trial %s: Setting up new remote runner.", trial)
        # Logging for trials is handled centrally by TrialRunner, so
        # configure the remote runner to use a noop-logger.
        trial_config = copy.deepcopy(trial.config)
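        # Pass trial metadata (e.g. the trial ID and name) through the config
        # so the remote trainable can identify the trial it is running.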
        trial_config[TRIAL_INFO] = _TrialInfo(trial)

        stdout_file, stderr_file = trial.log_to_file
        trial_config[STDOUT_FILE] = stdout_file
        trial_config[STDERR_FILE] = stderr_file
        kwargs = {
            "config": trial_config,
            "logger_creator": logger_creator,
        }
        if trial.uses_cloud_checkpointing:
            # We keep these kwargs separate for backwards compatibility
            # with trainables that don't provide these keyword arguments
            kwargs["remote_checkpoint_dir"] = trial.remote_checkpoint_dir
            kwargs["custom_syncer"] = trial.custom_syncer

            # Raise a meaningful error if the trainable does not use the
            # new API.
            sig = inspect.signature(trial.get_trainable_cls())
            try:
                sig.bind_partial(**kwargs)
            except Exception as e:
                raise RuntimeError(
                    "Your trainable class does not accept a "
                    "`remote_checkpoint_dir` or `custom_syncer` argument "
                    "in its constructor, but you've passed a "
                    "`upload_dir` to your SyncConfig. Without accepting "
                    "these parameters and passing them to the base trainable "
                    "constructor in the init call, cloud checkpointing is "
                    "effectively disabled. To resolve this issue, add the "
                    "parameters to your trainable class constructor or "
                    "disable cloud checkpointing by setting `upload_dir=None`."
                ) from e

        with self._change_working_directory(trial):
            return full_actor_class.remote(**kwargs)
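
The snippet below is a minimal sketch of the signature check performed above, which fails early with a helpful error when a trainable's constructor does not accept the cloud-checkpointing keyword arguments. LegacyTrainable and ModernTrainable are hypothetical example classes, not part of Ray Tune.

import inspect


class LegacyTrainable:
    def __init__(self, config=None, logger_creator=None):
        self.config = config


class ModernTrainable:
    def __init__(self, config=None, logger_creator=None,
                 remote_checkpoint_dir=None, custom_syncer=None):
        self.config = config
        self.remote_checkpoint_dir = remote_checkpoint_dir
        self.custom_syncer = custom_syncer


def check_accepts_cloud_kwargs(trainable_cls):
    kwargs = {"remote_checkpoint_dir": "s3://bucket/exp", "custom_syncer": None}
    sig = inspect.signature(trainable_cls)
    try:
        # bind_partial() raises TypeError if any keyword is not accepted by
        # the constructor, without actually instantiating the class.
        sig.bind_partial(**kwargs)
    except TypeError as e:
        raise RuntimeError(
            f"{trainable_cls.__name__} does not accept the cloud "
            "checkpointing arguments; add them to its constructor or "
            "disable cloud checkpointing.") from e


check_accepts_cloud_kwargs(ModernTrainable)    # Passes silently.
# check_accepts_cloud_kwargs(LegacyTrainable)  # Would raise RuntimeError.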