Example #1
0
    def _exploit_trial(self, trial_executor: RayTrialExecutor, trial: Trial,
                       trial_to_clone: Trial):
        """Clone the state of ``trial_to_clone`` into ``trial``.

        ``trial_to_clone``'s config is perturbed with ``explore`` and its last
        checkpoint is loaded into ``trial``. When ``reset_trial`` reports
        failure, ``trial`` is instead stopped and restarted from that
        checkpoint with the new config and experiment tag. If
        ``self._log_config`` is set, the config change is also recorded via
        ``_log_config_on_step``. Nothing happens (beyond a log line) when the
        donor has no checkpoint yet.

        Args:
            trial_executor: Executor used to reset, restore, stop and start
                ``trial``.
            trial: Trial that receives the cloned state.
            trial_to_clone: Donor trial whose checkpoint and (perturbed)
                config are copied.
        """
        receiver = self._trials_states_dict[trial]
        donor = self._trials_states_dict[trial_to_clone]

        # Guard: the donor has no checkpoint to hand over yet.
        if not donor.last_checkpoint:
            logger.info(("[pbt]: no checkpoint for trial. "
                         "Skip exploit for Trial {}").format(trial))
            return

        perturbed_config = explore(trial_to_clone.config,
                                   self._hyperparam_mutations,
                                   self._hyperparam_mutate_probability,
                                   self._explore_func)
        message = ("[exploit] transferring weights from trial {} (score {}) "
                   "-> {} (score {})")
        logger.info(
            message.format(trial_to_clone, donor.last_score, trial,
                           receiver.last_score))

        if self._log_config:
            self._log_config_on_step(receiver, donor, trial, trial_to_clone,
                                     perturbed_config)

        perturbed_tag = make_experiment_tag(receiver.orig_tag,
                                            perturbed_config,
                                            self._hyperparam_mutations)

        if trial_executor.reset_trial(trial, perturbed_config, perturbed_tag):
            # Reset succeeded: only the donor's checkpoint needs loading.
            trial_executor.restore(
                trial, Checkpoint.from_object(donor.last_checkpoint))
        else:
            # Reset failed: restart the trial from the donor's checkpoint,
            # applying the new config and tag by hand first.
            trial_executor.stop_trial(trial, stop_logger=False)
            trial.config = perturbed_config
            trial.experiment_tag = perturbed_tag
            trial_executor.start_trial(
                trial, Checkpoint.from_object(donor.last_checkpoint))

        # TODO: move to Exploiter
        # Both trials restart their step counters; the donor's exploration
        # count is cleared while the receiver's and the scheduler-wide count
        # each grow by one. Statement order is kept as-is in case both trials
        # map to the same state object.
        donor.num_steps = 0
        receiver.num_steps = 0
        donor.num_explorations = 0
        receiver.num_explorations += 1
        self._num_explorations += 1
        # Carry over the donor's last perturbation time as well.
        receiver.last_perturbation_time = donor.last_perturbation_time
Example #2
0
    def testPerturbationValues(self):
        """Pin the values ``explore`` yields for flat configs under seed 0.

        Each scenario reseeds ``random`` with 0, runs ``explore`` 100 times
        and compares the distinct resulting ``"v"`` values with the expected
        set.
        """

        def check_produces(make_config, expected):
            random.seed(0)
            observed = {make_config()["v"] for _ in range(100)}
            self.assertEqual(observed, expected)

        def categorical(start, prob):
            # Fresh config/spec dicts on every call, as with inline literals.
            return lambda: explore({"v": start}, {"v": [3, 4, 8, 10]}, prob,
                                   lambda x: x)

        def continuous(start, prob):
            return lambda: explore(
                {"v": start}, {"v": lambda: random.choice([10, 100])}, prob,
                lambda x: x)

        # NOTE: set equality treats 80 == 80.0, so the int/float expectations
        # below do not check the type of the produced values.

        # Categorical case
        check_produces(categorical(4, 0.0), {3, 8})
        check_produces(categorical(3, 0.0), {3, 4})
        check_produces(categorical(10, 0.0), {8, 10})
        check_produces(categorical(7, 0.0), {3, 4, 8, 10})
        check_produces(categorical(4, 1.0), {3, 4, 8, 10})

        # Continuous case
        check_produces(continuous(100, 0.0), {80, 120})
        check_produces(continuous(100.0, 0.0), {80.0, 120.0})
        check_produces(continuous(100.0, 1.0), {10.0, 100.0})
Example #3
0
    def testPerturbationValues(self):
        """Pin the values ``explore`` produces under a fixed random seed.

        Every scenario reseeds ``random`` with 0, calls ``explore`` 100 times
        and compares the distinct values observed with an expected set (flat
        configs) or an expected nested dict of sets (nested configs).
        """

        def assertProduces(fn, values):
            # Distinct top-level "v" values over 100 seeded runs.
            random.seed(0)
            seen = {fn()["v"] for _ in range(100)}
            self.assertEqual(seen, values)

        def categorical(start, prob):
            # Build fresh config/spec dicts on every call, like inline
            # literals would.
            return lambda: explore({"v": start}, {"v": [3, 4, 8, 10]}, prob,
                                   lambda x: x)

        def continuous(start, prob):
            return lambda: explore(
                {"v": start}, {"v": lambda: random.choice([10, 100])}, prob,
                lambda x: x)

        # NOTE: set equality treats 80 == 80.0, so the int/float expectations
        # below do not check the type of the produced values.

        # Categorical case
        assertProduces(categorical(4, 0.0), {3, 8})
        assertProduces(categorical(3, 0.0), {3, 4})
        assertProduces(categorical(10, 0.0), {8, 10})
        assertProduces(categorical(7, 0.0), {3, 4, 8, 10})
        assertProduces(categorical(4, 1.0), {3, 4, 8, 10})

        # Continuous case
        assertProduces(continuous(100, 0.0), {80, 120})
        assertProduces(continuous(100.0, 0.0), {80.0, 120.0})
        assertProduces(continuous(100.0, 1.0), {10.0, 100.0})

        def deep_add(seen, new_values):
            """Merge the leaves of ``new_values`` into per-key sets in ``seen``.

            Nested dicts are mirrored in ``seen``; ``seen`` is mutated in
            place and returned.
            """
            for key, value in new_values.items():
                if isinstance(value, dict):
                    # The recursive call mutates the sub-dict in place, so no
                    # extra merge step is needed.
                    deep_add(seen.setdefault(key, {}), value)
                else:
                    seen.setdefault(key, set()).add(value)
            return seen

        def assertNestedProduces(fn, values):
            random.seed(0)
            seen = {}
            for _ in range(100):
                deep_add(seen, fn())
            self.assertEqual(seen, values)

        # Factories keep handing ``explore`` fresh dicts on every call; key
        # insertion order ("a" before "1") matches the original literals.
        def nested_config():
            return {"a": {"b": 4}, "1": {"2": {"3": 100}}}

        def nested_mutations():
            return {
                "a": {"b": [3, 4, 8, 10]},
                "1": {"2": {"3": lambda: random.choice([10, 100])}},
            }

        nested_expected = {"a": {"b": {3, 8}}, "1": {"2": {"3": {80, 120}}}}

        # Nested mutation and spec
        assertNestedProduces(
            lambda: explore(nested_config(), nested_mutations(), 0.0,
                            lambda x: x), nested_expected)

        custom_explore_fn = MagicMock(side_effect=lambda x: x)

        # Nested mutation and spec, routed through a custom explore function
        assertNestedProduces(
            lambda: explore(nested_config(), nested_mutations(), 0.0,
                            custom_explore_fn), nested_expected)

        # Expect call count to be 100 because we call explore 100 times
        self.assertEqual(custom_explore_fn.call_count, 100)
Example #4
0
    def testPerturbationValues(self):
        """Pin the values ``explore`` produces under a fixed random seed.

        Every scenario reseeds ``random`` with 0, calls ``explore`` 100 times
        and compares the distinct values observed with an expected set (flat
        configs) or an expected nested dict of sets (nested configs).
        """

        def assertProduces(fn, values):
            # Distinct top-level "v" values over 100 seeded runs.
            random.seed(0)
            seen = {fn()["v"] for _ in range(100)}
            self.assertEqual(seen, values)

        def categorical(start, prob):
            # Build fresh config/spec dicts on every call, like inline
            # literals would.
            return lambda: explore({"v": start}, {"v": [3, 4, 8, 10]}, prob,
                                   lambda x: x)

        def continuous(start, prob):
            return lambda: explore(
                {"v": start}, {"v": lambda: random.choice([10, 100])}, prob,
                lambda x: x)

        # NOTE: set equality treats 80 == 80.0, so the int/float expectations
        # below do not check the type of the produced values.

        # Categorical case
        assertProduces(categorical(4, 0.0), {3, 8})
        assertProduces(categorical(3, 0.0), {3, 4})
        assertProduces(categorical(10, 0.0), {8, 10})
        assertProduces(categorical(7, 0.0), {3, 4, 8, 10})
        assertProduces(categorical(4, 1.0), {3, 4, 8, 10})

        # Continuous case
        assertProduces(continuous(100, 0.0), {80, 120})
        assertProduces(continuous(100.0, 0.0), {80.0, 120.0})
        assertProduces(continuous(100.0, 1.0), {10.0, 100.0})

        def deep_add(seen, new_values):
            """Merge the leaves of ``new_values`` into per-key sets in ``seen``.

            Nested dicts are mirrored in ``seen``; ``seen`` is mutated in
            place and returned.
            """
            for key, value in new_values.items():
                if isinstance(value, dict):
                    # The recursive call mutates the sub-dict in place, so no
                    # extra merge step is needed.
                    deep_add(seen.setdefault(key, {}), value)
                else:
                    seen.setdefault(key, set()).add(value)
            return seen

        def assertNestedProduces(fn, values):
            random.seed(0)
            seen = {}
            for _ in range(100):
                deep_add(seen, fn())
            self.assertEqual(seen, values)

        # Factories keep handing ``explore`` fresh dicts on every call; key
        # insertion order ("a" before "1") matches the original literals.
        def nested_config():
            return {"a": {"b": 4}, "1": {"2": {"3": 100}}}

        def nested_mutations():
            return {
                "a": {"b": [3, 4, 8, 10]},
                "1": {"2": {"3": lambda: random.choice([10, 100])}},
            }

        nested_expected = {"a": {"b": {3, 8}}, "1": {"2": {"3": {80, 120}}}}

        # Nested mutation and spec
        assertNestedProduces(
            lambda: explore(nested_config(), nested_mutations(), 0.0,
                            lambda x: x), nested_expected)

        custom_explore_fn = MagicMock(side_effect=lambda x: x)

        # Nested mutation and spec, routed through a custom explore function
        assertNestedProduces(
            lambda: explore(nested_config(), nested_mutations(), 0.0,
                            custom_explore_fn), nested_expected)

        # Expect call count to be 100 because we call explore 100 times
        self.assertEqual(custom_explore_fn.call_count, 100)