def test_lookahead_edge_cases(self, sync_period, slow_step_size,
                              correct_result):
  """Checks special cases of the lookahead optimizer parameters.

  Args:
    sync_period: Synchronization period forwarded to `lookahead.lookahead`.
    slow_step_size: Slow-weight step size forwarded to `lookahead.lookahead`.
    correct_result: Expected value of the slow parameters after two optimizer
      steps starting from `self.synced_initial_params`.
  """
  # These edge cases are important to check since users might use them as
  # simple ways of disabling lookahead in experiments.
  # Keyword arguments keep this call consistent with the other lookahead tests
  # and independent of the positional order of `lookahead.lookahead`.
  optimizer = lookahead.lookahead(
      _test_optimizer(-1),
      sync_period=sync_period,
      slow_step_size=slow_step_size)
  final_params, _ = self.loop(
      optimizer, num_steps=2, params=self.synced_initial_params)
  chex.assert_tree_all_close(final_params.slow, correct_result)
def test_lookahead(self):
  """Verifies lookahead against a trajectory that can be worked out by hand.

  The optimizer runs for exactly two synchronization periods. The resulting
  slow parameters are then compared with their analytically derived values.
  """
  period = 3
  fast_optimizer = _test_optimizer(-0.5)
  optimizer = lookahead.lookahead(
      fast_optimizer, sync_period=period, slow_step_size=1 / 3)

  # Expected x trajectory: 3 -> 2 -> 1 -> 2 (sync) -> 1 -> 0 -> 1 (sync).
  # The y trajectory is identical with the sign flipped.
  expected_slow = {'x': 1, 'y': -1}

  final_params, _ = self.loop(
      optimizer, num_steps=2 * period, params=self.synced_initial_params)
  chex.assert_tree_all_close(final_params.slow, expected_slow)
def test_lookahead_state_reset(self, reset_state):
  """Verifies how lookahead handles the fast optimizer's state.

  The step count equals the sync period, so the run ends at a synchronization
  point.

  Args:
    reset_state: Forwarded to `lookahead.lookahead`. If truthy, the fast state
      is expected to equal a freshly initialized fast-optimizer state.
      Otherwise it is expected to equal the state reached by running the fast
      optimizer alone for the same number of steps.
  """
  sync_period = 3
  num_steps = sync_period
  fast_optimizer = _test_optimizer(-0.5)
  optimizer = lookahead.lookahead(
      fast_optimizer,
      sync_period=sync_period,
      slow_step_size=0.5,
      reset_state=reset_state)

  _, opt_state = self.loop(
      optimizer, num_steps=num_steps, params=self.synced_initial_params)
  observed_fast_state = opt_state.fast_state

  if reset_state:
    # A reset should leave the fast state as if it had just been initialized.
    expected_fast_state = fast_optimizer.init(self.initial_params)
  else:
    # Without a reset, the fast state should match a standalone run of the
    # fast optimizer.
    _, expected_fast_state = self.loop(
        fast_optimizer, num_steps=num_steps, params=self.initial_params)

  chex.assert_tree_all_close(observed_fast_state, expected_fast_state)