コード例 #1
0
 def test_lookahead_edge_cases(self, sync_period, slow_step_size,
                               correct_result):
   """Verifies the lookahead optimizer for special-case parameter values.

   Such settings matter because users might rely on them as a simple way of
   switching lookahead off in experiments.

   Args:
     sync_period: Forwarded to `wrappers.lookahead` as the second argument.
     slow_step_size: Forwarded to `wrappers.lookahead` as the third argument.
     correct_result: Tree the slow parameters are expected to match after two
       steps.
   """
   fast_optimizer = _test_optimizer(-1)
   lookahead_optimizer = wrappers.lookahead(
       fast_optimizer, sync_period, slow_step_size)

   # Only the parameters are inspected; the optimizer state is discarded.
   params_after_loop, _ = self.loop(
       lookahead_optimizer, num_steps=2, params=self.synced_initial_params)

   slow_params = params_after_loop.slow
   chex.assert_tree_all_close(slow_params, correct_result)
コード例 #2
0
  def test_lookahead(self):
    """Runs lookahead for two sync periods and checks the slow parameters."""
    sync_period = 3
    total_steps = 2 * sync_period
    fast_optimizer = _test_optimizer(-0.5)
    lookahead_optimizer = wrappers.lookahead(
        fast_optimizer, sync_period=sync_period, slow_step_size=1 / 3)

    params_after_loop, _ = self.loop(
        lookahead_optimizer, total_steps, self.synced_initial_params)

    # Expected trajectory of x:
    #   3 -> 2 -> 1 -> 2 (sync) -> 1 -> 0 -> 1 (sync).
    # y follows the same trajectory with the sign flipped.
    expected_slow_params = dict(x=1, y=-1)
    chex.assert_tree_all_close(params_after_loop.slow, expected_slow_params)
コード例 #3
0
ファイル: wrappers_test.py プロジェクト: stjordanis/optax
    def test_lookahead_state_reset(self, reset_state):
        """Verifies how lookahead treats the fast optimizer state at a sync.

        The lookahead optimizer is run for exactly one sync period. With
        `reset_state` set, the resulting fast state must equal a freshly
        initialised fast optimizer state; otherwise it must equal the state
        reached by running the fast optimizer on its own for the same number
        of steps.

        Args:
            reset_state: Forwarded to `wrappers.lookahead`; also selects which
                reference state the result is compared against.
        """
        sync_period = 3
        num_steps = sync_period
        fast_optimizer = test_optimizer(-0.5)
        lookahead_optimizer = wrappers.lookahead(
            fast_optimizer,
            sync_period=sync_period,
            slow_step_size=0.5,
            reset_state=reset_state,
        )

        # Only the optimizer state is of interest here, not the parameters.
        _, lookahead_state = self.loop(
            lookahead_optimizer, num_steps, self.synced_initial_params)
        actual_fast_state = lookahead_state.fast_state

        if not reset_state:
            # Reference: the fast optimizer stepped alone, without lookahead.
            _, expected_fast_state = self.loop(
                fast_optimizer, num_steps, self.initial_params)
        else:
            # Reference: a fast optimizer state that has never been updated.
            expected_fast_state = fast_optimizer.init(self.initial_params)

        chex.assert_tree_all_close(actual_fast_state, expected_fast_state)