示例#1
0
 def test_lookahead_edge_cases(self, sync_period, slow_step_size,
                               correct_result):
     """Checks special cases of the lookahead optimizer parameters.

     Args:
       sync_period: forwarded positionally to `lookahead.lookahead` as its
         sync period.
       slow_step_size: forwarded positionally to `lookahead.lookahead` as its
         slow step size.
       correct_result: expected value of `final_params.slow` after running
         the optimizer for two steps from `self.synced_initial_params`.
     """
     # These edge cases are important to check since users might use them as
     # simple ways of disabling lookahead in experiments.
     optimizer = lookahead.lookahead(_test_optimizer(-1), sync_period,
                                     slow_step_size)
     final_params, _ = self.loop(optimizer,
                                 num_steps=2,
                                 params=self.synced_initial_params)
     chex.assert_tree_all_close(final_params.slow, correct_result)
示例#2
0
    def test_lookahead(self):
        """Tests the lookahead optimizer in an analytically tractable setting.

        Uses `_test_optimizer(-0.5)` as the fast optimizer with
        `sync_period=3` and `slow_step_size=1 / 3`, then runs the loop for two
        full sync periods (six steps) and compares the slow parameters.
        """
        period = 3
        fast = _test_optimizer(-0.5)
        optimizer = lookahead.lookahead(
            fast, sync_period=period, slow_step_size=1 / 3)

        total_steps = period * 2
        params, _ = self.loop(
            optimizer, total_steps, self.synced_initial_params)

        # x steps must be: 3 -> 2 -> 1 -> 2 (sync) -> 1 -> 0 -> 1 (sync).
        # Similarly for y (with sign flipped).
        expected = dict(x=1, y=-1)
        chex.assert_tree_all_close(params.slow, expected)
示例#3
0
    def test_lookahead_state_reset(self, reset_state):
        """Checks that lookahead resets the fast optimizer state correctly.

        Runs lookahead for exactly one sync period. If `reset_state` is true,
        the fast state should equal a freshly initialized state of the fast
        optimizer. Otherwise it should equal the state reached by running the
        fast optimizer alone for the same number of steps.
        """
        sync_period = 3
        num_steps = sync_period
        inner = _test_optimizer(-0.5)
        optimizer = lookahead.lookahead(
            inner,
            sync_period=sync_period,
            slow_step_size=0.5,
            reset_state=reset_state,
        )

        _, lookahead_state = self.loop(
            optimizer, num_steps, self.synced_initial_params)
        observed = lookahead_state.fast_state

        if not reset_state:
            # Without a reset, the fast state keeps accumulating as if the
            # fast optimizer had been run on its own.
            _, expected = self.loop(inner, num_steps, self.initial_params)
        else:
            expected = inner.init(self.initial_params)

        chex.assert_tree_all_close(observed, expected)