def _exploit_trial(self, trial_executor: RayTrialExecutor, trial: Trial,
                   trial_to_clone: Trial):
    """Transfers perturbed state from trial_to_clone -> trial.

    If specified, also logs the updated hyperparam state.
    """
    target_state = self._trials_states_dict[trial]
    source_state = self._trials_states_dict[trial_to_clone]

    # Nothing to clone from until the source trial has checkpointed.
    if not source_state.last_checkpoint:
        logger.info(
            "[pbt]: no checkpoint for trial. Skip exploit for Trial {}".
            format(trial))
        return

    # Perturb the source trial's config to obtain the new config.
    mutated_config = explore(trial_to_clone.config,
                             self._hyperparam_mutations,
                             self._hyperparam_mutate_probability,
                             self._explore_func)
    logger.info(
        "[exploit] transferring weights from trial {} (score {}) -> {} (score {})"
        .format(trial_to_clone, source_state.last_score, trial,
                target_state.last_score))
    if self._log_config:
        self._log_config_on_step(target_state, source_state, trial,
                                 trial_to_clone, mutated_config)

    mutated_tag = make_experiment_tag(target_state.orig_tag, mutated_config,
                                      self._hyperparam_mutations)
    # Prefer an in-place reset; fall back to stop + start with the cloned
    # checkpoint when the trainable cannot be reset.
    if trial_executor.reset_trial(trial, mutated_config, mutated_tag):
        trial_executor.restore(
            trial, Checkpoint.from_object(source_state.last_checkpoint))
    else:
        trial_executor.stop_trial(trial, stop_logger=False)
        trial.config = mutated_config
        trial.experiment_tag = mutated_tag
        trial_executor.start_trial(
            trial, Checkpoint.from_object(source_state.last_checkpoint))

    # TODO: move to Exploiter
    source_state.num_steps = 0
    target_state.num_steps = 0
    source_state.num_explorations = 0
    target_state.num_explorations += 1
    self._num_explorations += 1
    # Transfer over the last perturbation time as well
    target_state.last_perturbation_time = source_state.last_perturbation_time
def on_trial_result(self, trial_runner: "trial_runner.TrialRunner",
                    trial: Trial, result: Dict) -> str:
    """Replays the next recorded config change once the trial reaches it.

    Saves an in-memory checkpoint, swaps in the recorded config (via
    reset or stop/start), and advances the replay policy iterator.
    Always lets the trial continue.
    """
    if TRAINING_ITERATION not in result:
        # No time reported
        return TrialScheduler.CONTINUE

    if not self._next_policy:
        # No more changes in the config
        return TrialScheduler.CONTINUE

    current_step = result[TRAINING_ITERATION]
    self._current_step = current_step

    switch_at, replay_config = self._next_policy
    if current_step < switch_at:
        # Don't change the policy just yet
        return TrialScheduler.CONTINUE

    logger.info("Population Based Training replay is now at step {}. "
                "Configuration will be changed to {}.".format(
                    current_step, replay_config))

    executor = trial_runner.trial_executor
    checkpoint = executor.save(trial, Checkpoint.MEMORY, result=result)
    replay_tag = make_experiment_tag(self.experiment_tag, replay_config,
                                     replay_config)

    # Try an in-place reset first; otherwise restart with the checkpoint.
    if executor.reset_trial(trial, replay_config, replay_tag):
        executor.restore(trial, checkpoint, block=True)
    else:
        executor.stop_trial(trial, stop_logger=False)
        trial.config = replay_config
        trial.experiment_tag = replay_tag
        executor.start_trial(trial, checkpoint, train=False)

    self.current_config = replay_config
    self._num_perturbations += 1
    self._next_policy = next(self._policy_iter, None)

    return TrialScheduler.CONTINUE
def reset_trial(self, trial: Trial, new_config, new_experiment_tag):
    """Tries to invoke `Trainable.reset_config()` to reset trial.

    Args:
        trial (Trial): Trial to be reset.
        new_config (dict): New configuration for Trial trainable.
        new_experiment_tag (str): New experiment name for trial.

    Returns:
        True if `reset_config` is successful else False.
    """
    logger.debug("reset_trial %s", trial)
    # Update the trial's metadata up front; on failure, callers stop and
    # restart the trial with the new config anyway.
    trial.experiment_tag = new_experiment_tag
    trial.config = new_config
    trainable = trial.runner
    with _change_working_directory(trial):
        try:
            # Fix: `ray.get` takes `timeout` as a keyword-only argument;
            # passing it positionally raises TypeError on Ray >= 1.0
            # instead of enforcing the timeout.
            reset_val = ray.get(
                trainable.reset_config.remote(new_config),
                timeout=DEFAULT_GET_TIMEOUT)
        except RayTimeoutError:
            logger.exception("Trial %s: reset_config timed out.", trial)
            return False
    return reset_val
def _exploit(self, trial_executor: "trial_executor.TrialExecutor",
             trial: Trial, trial_to_clone: Trial):
    """Transfers perturbed state from trial_to_clone -> trial.

    If specified, also logs the updated hyperparam state.
    """
    trial_state = self._trial_state[trial]
    new_state = self._trial_state[trial_to_clone]
    # Fix: skip the exploit when the trial to clone has not checkpointed
    # yet -- restoring (or `on_checkpoint`) with `None` would fail. This
    # mirrors the guard used by the other `_exploit` variants in this
    # file.
    if not new_state.last_checkpoint:
        logger.info("[pbt]: no checkpoint for trial."
                    " Skip exploit for Trial {}".format(trial))
        return
    logger.info("[exploit] transferring weights from trial "
                "{} (score {}) -> {} (score {})".format(
                    trial_to_clone, new_state.last_score, trial,
                    trial_state.last_score))
    new_config = self._get_new_config(trial, trial_to_clone)

    # Only log mutated hyperparameters and not entire config.
    old_hparams = {
        k: v
        for k, v in trial_to_clone.config.items()
        if k in self._hyperparam_mutations
    }
    new_hparams = {
        k: v
        for k, v in new_config.items() if k in self._hyperparam_mutations
    }
    logger.info("[explore] perturbed config from {} -> {}".format(
        old_hparams, new_hparams))

    if self._log_config:
        self._log_config_on_step(trial_state, new_state, trial,
                                 trial_to_clone, new_config)

    new_tag = make_experiment_tag(trial_state.orig_tag, new_config,
                                  self._hyperparam_mutations)
    if trial.status == Trial.PAUSED:
        # If trial is paused we update it with a new checkpoint.
        # When the trial is started again, the new checkpoint is used.
        if not self._synch:
            raise TuneError("Trials should be paused here only if in "
                            "synchronous mode. If you encounter this error"
                            " please raise an issue on Ray Github.")
        trial.config = new_config
        trial.experiment_tag = new_tag
        trial.on_checkpoint(new_state.last_checkpoint)
    else:
        # If trial is running, we first try to reset it.
        # If that is unsuccessful, then we have to stop it and start it
        # again with a new checkpoint.
        reset_successful = trial_executor.reset_trial(
            trial, new_config, new_tag)
        # TODO(ujvl): Refactor Scheduler abstraction to abstract
        # mechanism for trial restart away. We block on restore
        # and suppress train on start as a stop-gap fix to
        # https://github.com/ray-project/ray/issues/7258.
        if reset_successful:
            trial_executor.restore(
                trial, new_state.last_checkpoint, block=True)
        else:
            trial_executor.stop_trial(trial, stop_logger=False)
            trial.config = new_config
            trial.experiment_tag = new_tag
            trial_executor.start_trial(
                trial, new_state.last_checkpoint, train=False)

    self._num_perturbations += 1
    # Transfer over the last perturbation time as well
    trial_state.last_perturbation_time = new_state.last_perturbation_time
    trial_state.last_train_time = new_state.last_train_time
def _exploit(self, trial_executor: "trial_executor.TrialExecutor",
             trial: Trial, trial_to_clone: Trial):
    """Transfers perturbed state from trial_to_clone -> trial.

    A UCB bandit selects which hyperparameters keep their perturbed
    value; unselected ones are reverted to the exploiting trial's own
    values. If specified, also logs the updated hyperparam state.
    """
    trial_state = self._trial_state[trial]
    new_state = self._trial_state[trial_to_clone]
    if not new_state.last_checkpoint:
        logger.info("[pbt]: no checkpoint for trial."
                    " Skip exploit for Trial {}".format(trial))
        return
    new_config = explore(trial_to_clone.config, self._hyperparam_mutations,
                         self._resample_probability,
                         self._custom_explore_fn)

    # TODO: revisit how the perturbed parameters produced by `explore`
    # are partially reverted via the bandit below.
    old_good_config = trial.config
    # Fix: use the module logger for diagnostics instead of bare print()
    # so output honors the configured log level/handlers.
    logger.debug("----- perturbation -----")
    logger.debug("trial_to_clone.trial_id = %s", trial_to_clone.trial_id)
    logger.debug("trial.trial_id = %s", trial.trial_id)

    if self._ucb is not None:
        if self._ucb.is_need_to_reflect_reward():
            # Some workers report late; only average the metric over
            # trials that have already reported it.
            scores = [
                t.last_result[self._metric] for t in self._trial_state
                if self._metric in t.last_result
            ]
            # NOTE(review): np.average([]) yields NaN -- presumably at
            # least one trial has reported by now; confirm upstream.
            self._ucb.reflect_reward(np.average(scores))
        selected = self._ucb.pull()
        masks = self._ucb.bitfield(selected)
        logger.debug(
            "explore ucb_state n: %s, selected : %s, masks : %s",
            self._ucb.n, self._ucb.selected, masks)
        # For every hyperparameter the bandit did NOT select (mask 0),
        # keep the exploiting trial's current value instead of the
        # perturbed one. NOTE(review): relies on dict insertion order of
        # `new_config` matching the bandit's parameter order -- confirm.
        keys = list(new_config)
        for i in range(self._ucb.n_params):
            if masks[i] == 0:
                new_config[keys[i]] = old_good_config[keys[i]]
        # TODO: add logic to cancel the perturbation entirely.
        logger.debug("%s", new_config)
        # Fix: this is an informational message, not an error, so log it
        # at INFO level (removes the dead commented-out duplicate).
        logger.info("[explore] perturbed ucb config from {} -> {}".format(
            old_good_config, new_config))

    logger.info("[exploit] transferring weights from trial "
                "{} (score {}) -> {} (score {})".format(
                    trial_to_clone, new_state.last_score, trial,
                    trial_state.last_score))

    # Only log mutated hyperparameters and not entire config.
    old_hparams = {
        k: v
        for k, v in trial_to_clone.config.items()
        if k in self._hyperparam_mutations
    }
    new_hparams = {
        k: v
        for k, v in new_config.items() if k in self._hyperparam_mutations
    }
    logger.info("[explore] perturbed config from {} -> {}".format(
        old_hparams, new_hparams))

    if self._log_config:
        self._log_config_on_step(trial_state, new_state, trial,
                                 trial_to_clone, new_config)

    new_tag = make_experiment_tag(trial_state.orig_tag, new_config,
                                  self._hyperparam_mutations)
    if trial.status == Trial.PAUSED:
        # If trial is paused we update it with a new checkpoint.
        # When the trial is started again, the new checkpoint is used.
        if not self._synch:
            raise TuneError("Trials should be paused here only if in "
                            "synchronous mode. If you encounter this error"
                            " please raise an issue on Ray Github.")
        trial.config = new_config
        trial.experiment_tag = new_tag
        trial.on_checkpoint(new_state.last_checkpoint)
    else:
        # If trial is running, we first try to reset it.
        # If that is unsuccessful, then we have to stop it and start it
        # again with a new checkpoint.
        reset_successful = trial_executor.reset_trial(
            trial, new_config, new_tag)
        # TODO(ujvl): Refactor Scheduler abstraction to abstract
        # mechanism for trial restart away. We block on restore
        # and suppress train on start as a stop-gap fix to
        # https://github.com/ray-project/ray/issues/7258.
        if reset_successful:
            trial_executor.restore(
                trial, new_state.last_checkpoint, block=True)
        else:
            trial_executor.stop_trial(trial, stop_logger=False)
            trial.config = new_config
            trial.experiment_tag = new_tag
            trial_executor.start_trial(
                trial, new_state.last_checkpoint, train=False)

    self._num_perturbations += 1
    # Transfer over the last perturbation time as well
    trial_state.last_perturbation_time = new_state.last_perturbation_time
    trial_state.last_train_time = new_state.last_train_time