def _exploit_trial(self, trial_executor: RayTrialExecutor, trial: Trial,
                   trial_to_clone: Trial):
    """Clone ``trial_to_clone`` into ``trial`` using a perturbed config.

    The donor's config is passed through ``explore`` together with this
    scheduler's hyperparameter mutations, mutate probability and explore
    function. If the donor's state has a falsy ``last_checkpoint``, an info
    message is logged and the method returns without changing anything.

    Otherwise ``trial`` is reset in place through ``trial_executor`` and then
    restored from the donor checkpoint. If the executor reports that the
    reset did not succeed, the trial is stopped, its ``config`` and
    ``experiment_tag`` are overwritten, and it is started again from the
    donor checkpoint. When ``self._log_config`` is truthy the transition is
    also recorded via ``self._log_config_on_step``. Finally the step and
    exploration counters are updated and the donor's
    ``last_perturbation_time`` is copied onto the receiving trial's state.

    Args:
        trial_executor: Executor used to reset/restore or stop/start ``trial``.
        trial: Trial that receives the cloned state.
        trial_to_clone: Donor trial whose checkpoint and config are used.
    """
    target_state = self._trials_states_dict[trial]
    donor_state = self._trials_states_dict[trial_to_clone]

    # Nothing to transfer if the donor has never been checkpointed.
    if not donor_state.last_checkpoint:
        logger.info(
            "[pbt]: no checkpoint for trial. Skip exploit for Trial {}".format(
                trial))
        return

    explored_config = explore(
        trial_to_clone.config,
        self._hyperparam_mutations,
        self._hyperparam_mutate_probability,
        self._explore_func,
    )
    logger.info(
        "[exploit] transferring weights from trial {} (score {}) -> {} (score {})"
        .format(trial_to_clone, donor_state.last_score, trial,
                target_state.last_score))

    if self._log_config:
        self._log_config_on_step(target_state, donor_state, trial,
                                 trial_to_clone, explored_config)

    tag = make_experiment_tag(target_state.orig_tag, explored_config,
                              self._hyperparam_mutations)

    if not trial_executor.reset_trial(trial, explored_config, tag):
        # In-place reset was not possible: do a full stop/start cycle and
        # update the Trial object's config and tag ourselves.
        trial_executor.stop_trial(trial, stop_logger=False)
        trial.config = explored_config
        trial.experiment_tag = tag
        trial_executor.start_trial(
            trial, Checkpoint.from_object(donor_state.last_checkpoint))
    else:
        trial_executor.restore(
            trial, Checkpoint.from_object(donor_state.last_checkpoint))

    # TODO: move to Exploiter
    # Keep this exact order: if both trials ever shared one state object,
    # reordering these updates would change the resulting counters.
    donor_state.num_steps = 0
    target_state.num_steps = 0
    donor_state.num_explorations = 0
    target_state.num_explorations += 1
    self._num_explorations += 1
    # The receiving trial inherits the donor's last perturbation time.
    target_state.last_perturbation_time = donor_state.last_perturbation_time
def testPerturbationValues(self):
    """Check the values ``explore`` yields for flat categorical/continuous specs.

    Every scenario reseeds ``random`` with 0, calls ``explore`` 100 times and
    compares the set of distinct ``"v"`` values with the expected set.
    """

    def assertProduces(fn, values):
        # Reseed so each scenario sees the same pseudo-random sequence.
        random.seed(0)
        observed = {fn()["v"] for _ in range(100)}
        self.assertEqual(observed, values)

    # Categorical case: (starting value, resample probability, expected).
    categorical_cases = [
        (4, 0.0, {3, 8}),
        (3, 0.0, {3, 4}),
        (10, 0.0, {8, 10}),
        (7, 0.0, {3, 4, 8, 10}),
        (4, 1.0, {3, 4, 8, 10}),
    ]
    for start, resample_prob, expected in categorical_cases:
        # Default arguments pin the loop variables; config and spec dicts
        # are still built fresh on every call.
        assertProduces(
            lambda start=start, resample_prob=resample_prob: explore(
                {"v": start}, {"v": [3, 4, 8, 10]}, resample_prob,
                lambda x: x),
            expected)

    # Continuous case: the spec is a sampling function.
    continuous_cases = [
        (100, 0.0, {80, 120}),
        (100.0, 0.0, {80.0, 120.0}),
        (100.0, 1.0, {10.0, 100.0}),
    ]
    for start, resample_prob, expected in continuous_cases:
        assertProduces(
            lambda start=start, resample_prob=resample_prob: explore(
                {"v": start}, {"v": lambda: random.choice([10, 100])},
                resample_prob, lambda x: x),
            expected)
def testPerturbationValues(self):
    """Check the values ``explore`` yields for flat and nested specs.

    Every scenario reseeds ``random`` with 0 and calls ``explore`` 100 times.
    Flat scenarios compare the distinct ``"v"`` values with an expected set.
    Nested scenarios collect each leaf into a set at the same path and
    compare the whole tree. The last scenario also checks that a custom
    explore function is invoked once per ``explore`` call.
    """

    def assertProduces(fn, values):
        # Reseed so each scenario sees the same pseudo-random sequence.
        random.seed(0)
        observed = {fn()["v"] for _ in range(100)}
        self.assertEqual(observed, values)

    # Categorical case: (starting value, resample probability, expected).
    categorical_cases = [
        (4, 0.0, {3, 8}),
        (3, 0.0, {3, 4}),
        (10, 0.0, {8, 10}),
        (7, 0.0, {3, 4, 8, 10}),
        (4, 1.0, {3, 4, 8, 10}),
    ]
    for start, resample_prob, expected in categorical_cases:
        assertProduces(
            lambda start=start, resample_prob=resample_prob: explore(
                {"v": start}, {"v": [3, 4, 8, 10]}, resample_prob,
                lambda x: x),
            expected)

    # Continuous case: the spec is a sampling function.
    continuous_cases = [
        (100, 0.0, {80, 120}),
        (100.0, 0.0, {80.0, 120.0}),
        (100.0, 1.0, {10.0, 100.0}),
    ]
    for start, resample_prob, expected in continuous_cases:
        assertProduces(
            lambda start=start, resample_prob=resample_prob: explore(
                {"v": start}, {"v": lambda: random.choice([10, 100])},
                resample_prob, lambda x: x),
            expected)

    def deep_add(seen, new_values):
        # Merge new_values into seen in place: nested dicts recurse into
        # sub-dicts, leaf values are accumulated into sets.
        for key, value in new_values.items():
            if isinstance(value, dict):
                deep_add(seen.setdefault(key, {}), value)
            else:
                seen.setdefault(key, set()).add(value)
        return seen

    def assertNestedProduces(fn, values):
        random.seed(0)
        observed = {}
        for _ in range(100):
            deep_add(observed, fn())
        self.assertEqual(observed, values)

    # Fresh nested config/spec per explore call; key order ("a" before "1")
    # determines the order in which random numbers are drawn.
    def nested_config():
        return {"a": {"b": 4}, "1": {"2": {"3": 100}}}

    def nested_spec():
        return {
            "a": {"b": [3, 4, 8, 10]},
            "1": {"2": {"3": lambda: random.choice([10, 100])}},
        }

    nested_expected = {"a": {"b": {3, 8}}, "1": {"2": {"3": {80, 120}}}}

    # Nested mutation and spec, identity explore function.
    assertNestedProduces(
        lambda: explore(nested_config(), nested_spec(), 0.0, lambda x: x),
        nested_expected)

    # Same nested scenario routed through a mock custom explore function so
    # its invocations can be counted.
    custom_explore_fn = MagicMock(side_effect=lambda x: x)
    assertNestedProduces(
        lambda: explore(nested_config(), nested_spec(), 0.0,
                        custom_explore_fn),
        nested_expected)
    # Expect call count to be 100 because we call explore 100 times
    self.assertEqual(custom_explore_fn.call_count, 100)
def testPerturbationValues(self):
    """Exercise ``explore`` on flat and nested hyperparameter specs.

    Each scenario reseeds ``random`` with 0 and samples ``explore`` 100
    times. Flat scenarios check the set of distinct ``"v"`` values. Nested
    scenarios check a tree of per-leaf value sets. A mocked custom explore
    function must be invoked exactly once per ``explore`` call.
    """

    def assertProduces(fn, values):
        random.seed(0)
        observed = set()
        for _ in range(100):
            observed.add(fn()["v"])
        self.assertEqual(observed, values)

    def categorical_spec():
        return {"v": [3, 4, 8, 10]}

    def continuous_spec():
        return {"v": lambda: random.choice([10, 100])}

    # (starting value, spec factory, resample probability, expected values)
    flat_scenarios = [
        # Categorical case
        (4, categorical_spec, 0.0, {3, 8}),
        (3, categorical_spec, 0.0, {3, 4}),
        (10, categorical_spec, 0.0, {8, 10}),
        (7, categorical_spec, 0.0, {3, 4, 8, 10}),
        (4, categorical_spec, 1.0, {3, 4, 8, 10}),
        # Continuous case
        (100, continuous_spec, 0.0, {80, 120}),
        (100.0, continuous_spec, 0.0, {80.0, 120.0}),
        (100.0, continuous_spec, 1.0, {10.0, 100.0}),
    ]
    for start, make_spec, prob, expected in flat_scenarios:
        # Config and spec are rebuilt on every call, as in a literal lambda.
        assertProduces(
            lambda start=start, make_spec=make_spec, prob=prob: explore(
                {"v": start}, make_spec(), prob, lambda x: x),
            expected)

    def deep_add(seen, new_values):
        # In-place merge: recurse into dicts, gather leaves into sets.
        for key, value in new_values.items():
            if not isinstance(value, dict):
                seen.setdefault(key, set()).add(value)
                continue
            deep_add(seen.setdefault(key, {}), value)
        return seen

    def assertNestedProduces(fn, values):
        random.seed(0)
        observed = {}
        for _ in range(100):
            observed = deep_add(observed, fn())
        self.assertEqual(observed, values)

    # Keys stay in the order "a", "1" so random draws happen in that order.
    def nested_config():
        return {"a": {"b": 4}, "1": {"2": {"3": 100}}}

    def nested_spec():
        return {
            "a": {"b": [3, 4, 8, 10]},
            "1": {"2": {"3": lambda: random.choice([10, 100])}},
        }

    nested_expected = {"a": {"b": {3, 8}}, "1": {"2": {"3": {80, 120}}}}

    # Nested mutation and spec, identity explore function.
    assertNestedProduces(
        lambda: explore(nested_config(), nested_spec(), 0.0, lambda x: x),
        nested_expected)

    # Nested mutation and spec with a countable (mocked) custom explore fn.
    custom_explore_fn = MagicMock(side_effect=lambda x: x)
    assertNestedProduces(
        lambda: explore(nested_config(), nested_spec(), 0.0,
                        custom_explore_fn),
        nested_expected)
    # Expect call count to be 100 because we call explore 100 times
    self.assertEqual(custom_explore_fn.call_count, 100)