Example 1
def test_result_grid_future_checkpoint(ray_start_2_cpus, to_object):
    trainable_cls = get_trainable_cls("__fake")
    trial = Trial("__fake", stub=True)
    trial.config = {"some_config": 1}
    trial.last_result = {"some_result": 2, "config": trial.config}

    trainable = ray.remote(trainable_cls).remote()
    ray.get(trainable.set_info.remote({"info": 4}))

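    # Depending on the test parametrization, either pack the checkpoint into an
    # in-memory object or save it to the trainable's local directory.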
    if to_object:
        checkpoint_data = trainable.save_to_object.remote()
    else:
        checkpoint_data = trainable.save.remote()

    trial.on_checkpoint(
        _TrackedCheckpoint(checkpoint_data,
                           storage_mode=CheckpointStorage.MEMORY))
    trial.pickled_error_file = None
    trial.error_file = None
    result_grid = ResultGrid(None)

    # Internal result grid conversion
    result = result_grid._trial_to_result(trial)
    assert isinstance(result.checkpoint, Checkpoint)
    assert isinstance(result.metrics, dict)
    assert isinstance(result.config, dict)
    assert result.metrics_dataframe is None
    assert result.config == {"some_config": 1}
    assert result.metrics["config"] == result.config

    # Load checkpoint data (see ray.rllib.algorithms.mock.MockTrainer definition)
    with result.checkpoint.as_directory() as checkpoint_dir:
        with open(os.path.join(checkpoint_dir, "mock_agent.pkl"), "rb") as f:
            info = pickle.load(f)
            assert info["info"] == 4
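This snippet is lifted from Ray Tune's test suite and omits its imports; ray_start_2_cpus and to_object appear to be pytest fixtures/parameters supplied by the surrounding test module. A plausible preamble is sketched below; the exact module paths for these internals have shifted between Ray releases, so treat them as approximations rather than the canonical imports.

# Imports assumed by Example 1 (module paths are approximate and vary by Ray version).
import os
import pickle

import ray
from ray.air.checkpoint import Checkpoint
from ray.air._internal.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint
from ray.tune.registry import get_trainable_cls
from ray.tune.result_grid import ResultGrid
from ray.tune.experiment import Trial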
Example 2
    def _exploit(
        self,
        trial_executor: "trial_runner.RayTrialExecutor",
        trial: Trial,
        trial_to_clone: Trial,
    ):
        """Transfers perturbed state from trial_to_clone -> trial.

        If specified, also logs the updated hyperparam state.
        """
        trial_state = self._trial_state[trial]
        new_state = self._trial_state[trial_to_clone]
        logger.info("[exploit] transferring weights from trial "
                    "{} (score {}) -> {} (score {})".format(
                        trial_to_clone, new_state.last_score, trial,
                        trial_state.last_score))

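        # Build the new config by perturbing the hyperparameters of the trial
        # being cloned.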
        new_config = self._get_new_config(trial, trial_to_clone)

        # Only log mutated hyperparameters and not entire config.
        old_hparams = {
            k: v
            for k, v in trial_to_clone.config.items()
            if k in self._hyperparam_mutations
        }
        new_hparams = {
            k: v
            for k, v in new_config.items() if k in self._hyperparam_mutations
        }
        logger.info("[explore] perturbed config from {} -> {}".format(
            old_hparams, new_hparams))

        if self._log_config:
            self._log_config_on_step(trial_state, new_state, trial,
                                     trial_to_clone, new_config)

        new_tag = _make_experiment_tag(trial_state.orig_tag, new_config,
                                       self._hyperparam_mutations)
        if trial.status == Trial.PAUSED:
            # If trial is paused we update it with a new checkpoint.
            # When the trial is started again, the new checkpoint is used.
            if not self._synch:
                raise TuneError("Trials should be paused here only if in "
                                "synchronous mode. If you encounter this error"
                                " please raise an issue on Ray Github.")
        else:
            trial_executor.stop_trial(trial)
            trial_executor.set_status(trial, Trial.PAUSED)
        trial.set_experiment_tag(new_tag)
        trial.set_config(new_config)
        trial.on_checkpoint(new_state.last_checkpoint)

        self._num_perturbations += 1
        # Transfer over the last perturbation time as well
        trial_state.last_perturbation_time = new_state.last_perturbation_time
        trial_state.last_train_time = new_state.last_train_time
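_get_new_config is defined elsewhere in the scheduler and returns a perturbed copy of the cloned trial's config. Below is a rough standalone sketch of that standard PBT perturbation step; it illustrates the technique, not Ray's implementation, and the names perturb_config and resample_probability are chosen here for the example.

import random

def perturb_config(config, mutations, resample_probability=0.25):
    """Illustrative PBT-style perturbation (not Ray's implementation)."""
    new_config = dict(config)
    for key, spec in mutations.items():
        if isinstance(spec, list):
            # Categorical choices: occasionally resample, otherwise shift to a neighbor.
            if random.random() < resample_probability or config[key] not in spec:
                new_config[key] = random.choice(spec)
            else:
                idx = spec.index(config[key]) + random.choice([-1, 1])
                new_config[key] = spec[max(0, min(len(spec) - 1, idx))]
        elif callable(spec):
            # Distribution given as a zero-argument sampling function.
            if random.random() < resample_probability:
                new_config[key] = spec()
            else:
                # Nudge continuous values up or down by 20%.
                new_config[key] = config[key] * random.choice([0.8, 1.2])
    return new_config

For instance, perturb_config({"lr": 0.01, "batch_size": 32}, {"lr": lambda: random.uniform(1e-4, 1e-1), "batch_size": [16, 32, 64]}) either resamples each hyperparameter or nudges it to a nearby value, which is the behavior _exploit relies on when transferring state.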
Example 3
    def on_trial_result(self, trial_runner: "trial_runner.TrialRunner",
                        trial: Trial, result: Dict) -> str:
        if TRAINING_ITERATION not in result:
            # No time reported
            return TrialScheduler.CONTINUE

        if not self._next_policy:
            # No more changes in the config
            return TrialScheduler.CONTINUE

        step = result[TRAINING_ITERATION]
        self._current_step = step

        change_at, new_config = self._next_policy

        if step < change_at:
            # Don't change the policy just yet
            return TrialScheduler.CONTINUE

        logger.info("Population Based Training replay is now at step {}. "
                    "Configuration will be changed to {}.".format(
                        step, new_config))

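        # Snapshot the trial's current state in memory so it can be restored
        # after the configuration swap below.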
        checkpoint = trial_runner.trial_executor.save(trial,
                                                      CheckpointStorage.MEMORY,
                                                      result=result)

        new_tag = _make_experiment_tag(self.experiment_tag, new_config,
                                       new_config)

        trial_executor = trial_runner.trial_executor
        trial_executor.stop_trial(trial)
        trial_executor.set_status(trial, Trial.PAUSED)
        trial.set_experiment_tag(new_tag)
        trial.set_config(new_config)
        trial.on_checkpoint(checkpoint)

        self.current_config = new_config
        self._num_perturbations += 1
        self._next_policy = next(self._policy_iter, None)

        return TrialScheduler.NOOP
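Here self._policy_iter yields (step, config) pairs describing when the recorded schedule changed the configuration, and self._next_policy holds the upcoming change. The sketch below is a hypothetical illustration of how such an iterator could be prepared from a list of recorded change points; the record format and values are assumptions for illustration, not Ray's policy-log format.

# Hypothetical replay schedule: perturbations recorded at training iterations
# 4 and 8. Sorting by step ensures changes are consumed in order as
# on_trial_result sees increasing TRAINING_ITERATION values.
recorded_changes = [
    (8, {"lr": 0.001}),
    (4, {"lr": 0.01}),
]
policy_iter = iter(sorted(recorded_changes, key=lambda change: change[0]))
next_policy = next(policy_iter, None)  # plays the role of self._next_policy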