Example #1
File: pbt.py Project: tchordia/ray
    def _exploit(
        self,
        trial_executor: "trial_executor.TrialExecutor",
        trial: Trial,
        trial_to_clone: Trial,
    ):
        """Transfers perturbed state from trial_to_clone -> trial.

        If specified, also logs the updated hyperparam state.
        """
        trial_state = self._trial_state[trial]
        new_state = self._trial_state[trial_to_clone]
        logger.info("[exploit] transferring weights from trial "
                    "{} (score {}) -> {} (score {})".format(
                        trial_to_clone, new_state.last_score, trial,
                        trial_state.last_score))

        new_config = self._get_new_config(trial, trial_to_clone)

        # Only log mutated hyperparameters and not entire config.
        old_hparams = {
            k: v
            for k, v in trial_to_clone.config.items()
            if k in self._hyperparam_mutations
        }
        new_hparams = {
            k: v
            for k, v in new_config.items() if k in self._hyperparam_mutations
        }
        logger.info("[explore] perturbed config from {} -> {}".format(
            old_hparams, new_hparams))

        if self._log_config:
            self._log_config_on_step(trial_state, new_state, trial,
                                     trial_to_clone, new_config)

        new_tag = _make_experiment_tag(trial_state.orig_tag, new_config,
                                       self._hyperparam_mutations)
        if trial.status == Trial.PAUSED:
            # If trial is paused we update it with a new checkpoint.
            # When the trial is started again, the new checkpoint is used.
            if not self._synch:
                raise TuneError("Trials should be paused here only if in "
                                "synchronous mode. If you encounter this error"
                                " please raise an issue on Ray Github.")
        else:
            trial_executor.stop_trial(trial)
            trial_executor.set_status(trial, Trial.PAUSED)
        trial.set_experiment_tag(new_tag)
        trial.set_config(new_config)
        trial.on_checkpoint(new_state.last_checkpoint)

        self._num_perturbations += 1
        # Transfer over the last perturbation time as well
        trial_state.last_perturbation_time = new_state.last_perturbation_time
        trial_state.last_train_time = new_state.last_train_time
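Note: `_exploit` is called internally by the PBT scheduler and is not part of the public API. For context, a minimal sketch of how the scheduler itself is typically wired up (the trainable, the "lr" hyperparameter, and the intervals below are illustrative assumptions, not values taken from the example above):

import random

from ray import tune
from ray.tune.schedulers import PopulationBasedTraining

# Illustrative PBT setup; "lr" and the numbers are placeholders.
pbt = PopulationBasedTraining(
    time_attr="training_iteration",
    metric="mean_accuracy",
    mode="max",
    perturbation_interval=4,
    hyperparam_mutations={"lr": lambda: random.uniform(1e-4, 1e-1)},
)

tune.run(
    my_trainable,  # hypothetical user-defined trainable
    scheduler=pbt,
    num_samples=4,
    config={"lr": 1e-3},
)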
Example #2
File: pbt.py Project: eggie5/ray
    def on_trial_result(self, trial_runner: "trial_runner.TrialRunner",
                        trial: Trial, result: Dict) -> str:
        if TRAINING_ITERATION not in result:
            # No time reported
            return TrialScheduler.CONTINUE

        if not self._next_policy:
            # No more changes in the config
            return TrialScheduler.CONTINUE

        step = result[TRAINING_ITERATION]
        self._current_step = step

        change_at, new_config = self._next_policy

        if step < change_at:
            # Don't change the policy just yet
            return TrialScheduler.CONTINUE

        logger.info("Population Based Training replay is now at step {}. "
                    "Configuration will be changed to {}.".format(
                        step, new_config))

        checkpoint = trial_runner.trial_executor.save(trial,
                                                      Checkpoint.MEMORY,
                                                      result=result)

        new_tag = make_experiment_tag(self.experiment_tag, new_config,
                                      new_config)

        trial_executor = trial_runner.trial_executor
        reset_successful = trial_executor.reset_trial(trial, new_config,
                                                      new_tag)

        if reset_successful:
            trial_executor.restore(trial, checkpoint, block=True)
        else:
            trial_executor.stop_trial(trial,
                                      destroy_pg_if_cannot_replace=False)
            trial.set_experiment_tag(new_tag)
            trial.set_config(new_config)
            trial_executor.start_trial(trial, checkpoint, train=False)

        self.current_config = new_config
        self._num_perturbations += 1
        self._next_policy = next(self._policy_iter, None)

        return TrialScheduler.CONTINUE
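This `on_trial_result` hook belongs to `PopulationBasedTrainingReplay`, which replays a recorded schedule of config changes onto a single trial. A minimal sketch of typical usage (the policy file name and the trainable are placeholders):

from ray import tune
from ray.tune.schedulers import PopulationBasedTrainingReplay

# Policy log produced by an earlier PBT run; file name is a placeholder.
replay = PopulationBasedTrainingReplay("pbt_policy_txxx.txt")

tune.run(
    my_trainable,  # hypothetical trainable
    scheduler=replay,
)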
Example #3
File: pbt.py Project: smorad/ray
    def on_trial_result(
        self, trial_runner: "trial_runner.TrialRunner", trial: Trial, result: Dict
    ) -> str:
        if TRAINING_ITERATION not in result:
            # No time reported
            return TrialScheduler.CONTINUE

        if not self._next_policy:
            # No more changes in the config
            return TrialScheduler.CONTINUE

        step = result[TRAINING_ITERATION]
        self._current_step = step

        change_at, new_config = self._next_policy

        if step < change_at:
            # Don't change the policy just yet
            return TrialScheduler.CONTINUE

        logger.info(
            "Population Based Training replay is now at step {}. "
            "Configuration will be changed to {}.".format(step, new_config)
        )

        checkpoint = trial_runner.trial_executor.save(
            trial, _TuneCheckpoint.MEMORY, result=result
        )

        new_tag = make_experiment_tag(self.experiment_tag, new_config, new_config)

        trial_executor = trial_runner.trial_executor
        trial_executor.stop_trial(trial)
        trial_executor.set_status(trial, Trial.PAUSED)
        trial.set_experiment_tag(new_tag)
        trial.set_config(new_config)
        trial.on_checkpoint(checkpoint)

        self.current_config = new_config
        self._num_perturbations += 1
        self._next_policy = next(self._policy_iter, None)

        return TrialScheduler.NOOP
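Unlike Example #2, this newer variant pauses the trial itself (stop it, set its status to PAUSED, attach the checkpoint) and returns `TrialScheduler.NOOP`, leaving the restart to the trial runner. The scheduler decisions are plain string constants; a small sketch of the contract as it appears in Ray Tune (assumed stable across the versions shown here):

from ray.tune.schedulers import TrialScheduler

# The trial runner dispatches on these constants after each result.
print(TrialScheduler.CONTINUE)  # "CONTINUE": keep training
print(TrialScheduler.PAUSE)     # "PAUSE": checkpoint and yield resources
print(TrialScheduler.STOP)      # "STOP": terminate the trial
print(TrialScheduler.NOOP)      # "NOOP": scheduler already handled the trial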
Example #4
    def _exploit(self, trial_executor: "trial_executor.TrialExecutor",
                 trial: Trial, trial_to_clone: Trial):
        """Transfers perturbed state from trial_to_clone -> trial.

        If specified, also logs the updated hyperparam state.
        """
        trial_state = self._trial_state[trial]
        new_state = self._trial_state[trial_to_clone]
        logger.info("[exploit] transferring weights from trial "
                    "{} (score {}) -> {} (score {})".format(
                        trial_to_clone, new_state.last_score, trial,
                        trial_state.last_score))

        new_config = self._get_new_config(trial, trial_to_clone)

        # Only log mutated hyperparameters and not entire config.
        old_hparams = {
            k: v
            for k, v in trial_to_clone.config.items()
            if k in self._hyperparam_mutations
        }
        new_hparams = {
            k: v
            for k, v in new_config.items() if k in self._hyperparam_mutations
        }
        logger.info("[explore] perturbed config from {} -> {}".format(
            old_hparams, new_hparams))

        if self._log_config:
            self._log_config_on_step(trial_state, new_state, trial,
                                     trial_to_clone, new_config)

        new_tag = make_experiment_tag(trial_state.orig_tag, new_config,
                                      self._hyperparam_mutations)
        if trial.status == Trial.PAUSED:
            # If trial is paused we update it with a new checkpoint.
            # When the trial is started again, the new checkpoint is used.
            if not self._synch:
                raise TuneError("Trials should be paused here only if in "
                                "synchronous mode. If you encounter this error"
                                " please raise an issue on Ray Github.")
            trial.set_experiment_tag(new_tag)
            trial.set_config(new_config)
            trial.on_checkpoint(new_state.last_checkpoint)
        else:
            # If trial is running, we first try to reset it.
            # If that is unsuccessful, then we have to stop it and start it
            # again with a new checkpoint.
            reset_successful = trial_executor.reset_trial(
                trial, new_config, new_tag)
            # TODO(ujvl): Refactor Scheduler abstraction to abstract
            #  mechanism for trial restart away. We block on restore
            #  and suppress train on start as a stop-gap fix to
            #  https://github.com/ray-project/ray/issues/7258.
            if reset_successful:
                trial_executor.restore(trial,
                                       new_state.last_checkpoint,
                                       block=True)
            else:
                trial_executor.stop_trial(trial)
                trial.set_experiment_tag(new_tag)
                trial.set_config(new_config)
                trial_executor.start_trial(trial,
                                           new_state.last_checkpoint,
                                           train=False)

        self._num_perturbations += 1
        # Transfer over the last perturbation time as well
        trial_state.last_perturbation_time = new_state.last_perturbation_time
        trial_state.last_train_time = new_state.last_train_time
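The PAUSED branch above is only reachable in synchronous PBT, where every trial is paused at the perturbation boundary before exploitation happens. A minimal sketch enabling that mode (the mutation space is an illustrative assumption):

import random

from ray.tune.schedulers import PopulationBasedTraining

pbt = PopulationBasedTraining(
    time_attr="training_iteration",
    metric="episode_reward_mean",
    mode="max",
    perturbation_interval=10,
    hyperparam_mutations={"lr": lambda: random.uniform(1e-4, 1e-1)},
    synch=True,  # pause all trials at each perturbation interval
)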
Example #5
    def _exploit(self, trial_executor: "trial_executor.TrialExecutor",
                 trial: Trial, trial_to_clone: Trial):
        """Transfers perturbed state from trial_to_clone -> trial.

        If specified, also logs the updated hyperparam state.
        """
        trial_state = self._trial_state[trial]
        new_state = self._trial_state[trial_to_clone]
        if not new_state.last_checkpoint:
            logger.info("[pbt]: no checkpoint for trial."
                        " Skip exploit for Trial {}".format(trial))
            return

        new_config = explore(trial_to_clone.config, self._hyperparam_mutations,
                             self._resample_probability,
                             self._custom_explore_fn)

        # TODO: apply new_config here and re-run with the changed parameters
        old_bad_config = trial_to_clone.config
        old_good_config = trial.config

        print('-------------------------------- perturbation---------------------------')
        print(f'trial_to_clone.trial_id = {trial_to_clone.trial_id}')
        print(f'trial.trial_id = {trial.trial_id}')

        # Custom UCB bandit (not part of Ray): assumed to choose which of
        # the perturbed hyperparameters to keep and which to revert to the
        # better trial's values.
        if self._ucb is not None:
            if self._ucb.is_need_to_reflect_reward():
                # Some workers lag behind, so skip trials whose last_result
                # does not yet contain the metric and use the average of the
                # rest as the reward.
                score = np.average([
                    t.last_result[self._metric] for t in self._trial_state
                    if self._metric in t.last_result
                ])
                self._ucb.reflect_reward(score)

            selected = self._ucb.pull()
            masks = self._ucb.bitfield(selected)
            print(f'explore!!!!!!! ucb_state n: {self._ucb.n}, selected : {self._ucb.selected}, masks : {masks}')

            for i in range(self._ucb.n_params):
                if masks[i] == 0:
                    # Relies on dict insertion order to map mask bits to keys.
                    key = list(new_config.keys())[i]
                    new_config[key] = old_good_config[key]

        # TODO: add logic to cancel the perturbation
        print(new_config)
        if self._ucb is not None:
            logger.error("[explore] perturbed ucb config from {} -> {}".format(
                old_good_config, new_config))

        # logger.info("[explore] perturbed ucb config from {} -> {}".format(
        #     old_good_config, new_config))
        

        logger.info("[exploit] transferring weights from trial "
                    "{} (score {}) -> {} (score {})".format(
                        trial_to_clone, new_state.last_score, trial,
                        trial_state.last_score))
        # Only log mutated hyperparameters and not entire config.
        old_hparams = {
            k: v
            for k, v in trial_to_clone.config.items()
            if k in self._hyperparam_mutations
        }
        new_hparams = {
            k: v
            for k, v in new_config.items() if k in self._hyperparam_mutations
        }
        logger.info("[explore] perturbed config from {} -> {}".format(
            old_hparams, new_hparams))

        if self._log_config:
            self._log_config_on_step(trial_state, new_state, trial,
                                     trial_to_clone, new_config)

        new_tag = make_experiment_tag(trial_state.orig_tag, new_config,
                                      self._hyperparam_mutations)
        if trial.status == Trial.PAUSED:
            # If trial is paused we update it with a new checkpoint.
            # When the trial is started again, the new checkpoint is used.
            if not self._synch:
                raise TuneError("Trials should be paused here only if in "
                                "synchronous mode. If you encounter this error"
                                " please raise an issue on Ray Github.")
            trial.config = new_config
            trial.experiment_tag = new_tag
            trial.on_checkpoint(new_state.last_checkpoint)
        else:
            # If trial is running, we first try to reset it.
            # If that is unsuccessful, then we have to stop it and start it
            # again with a new checkpoint.
            reset_successful = trial_executor.reset_trial(
                trial, new_config, new_tag)
            # TODO(ujvl): Refactor Scheduler abstraction to abstract
            #  mechanism for trial restart away. We block on restore
            #  and suppress train on start as a stop-gap fix to
            #  https://github.com/ray-project/ray/issues/7258.
            if reset_successful:
                trial_executor.restore(
                    trial, new_state.last_checkpoint, block=True)
            else:
                trial_executor.stop_trial(trial, stop_logger=False)
                trial.config = new_config
                trial.experiment_tag = new_tag
                trial_executor.start_trial(
                    trial, new_state.last_checkpoint, train=False)

        self._num_perturbations += 1
        # Transfer over the last perturbation time as well
        trial_state.last_perturbation_time = new_state.last_perturbation_time
        trial_state.last_train_time = new_state.last_train_time
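The `_ucb` object in this example is a custom bandit, not part of Ray. From the calls above (`pull`, `bitfield`, `reflect_reward`, `n_params`) it appears to pick a bitmask over the mutated hyperparameters, keeping the perturbed value where a bit is 1 and reverting to the better trial's value where it is 0. A speculative, self-contained sketch consistent with that interface (an assumption, not the author's implementation):

import numpy as np


class UCBMask:
    """UCB1 bandit over bitmasks of n_params hyperparameters (speculative)."""

    def __init__(self, n_params, c=2.0):
        self.n_params = n_params
        self.n_arms = 2 ** n_params
        self.c = c
        self.n = 0                      # total pulls so far
        self.counts = np.zeros(self.n_arms)
        self.values = np.zeros(self.n_arms)
        self.selected = None            # last arm pulled

    def pull(self):
        # Try every arm once, then choose by UCB1 score.
        untried = np.where(self.counts == 0)[0]
        if len(untried) > 0:
            self.selected = int(untried[0])
        else:
            ucb = self.values + np.sqrt(self.c * np.log(self.n) / self.counts)
            self.selected = int(np.argmax(ucb))
        self.n += 1
        return self.selected

    def bitfield(self, arm):
        # Bit i == 1 keeps the perturbed value of parameter i;
        # bit i == 0 reverts parameter i to the better trial's value.
        return [(arm >> i) & 1 for i in range(self.n_params)]

    def is_need_to_reflect_reward(self):
        return self.selected is not None

    def reflect_reward(self, reward):
        # Incremental mean update for the most recently pulled arm.
        self.counts[self.selected] += 1
        n = self.counts[self.selected]
        self.values[self.selected] += (reward - self.values[self.selected]) / n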