Example #1
import jax
import numpy as onp
from absl.testing import absltest
from flax import optim
# NOTE: these private state classes are assumed to live in flax's optimizer
# modules, as in the older flax releases this test comes from.
from flax.optim.momentum import _MomentumParamState
from flax.optim.weight_norm import _WeightNormParamState


class WeightNormTest(absltest.TestCase):

  def test_momentum_with_weight_norm(self):
    params = onp.ones((2, 2)) * 2.
    # WeightNorm wraps the base optimizer so each parameter column is split
    # into a direction (same shape) and a per-column scale.
    optimizer_def = optim.WeightNorm(optim.Momentum(0.1))
    state = optimizer_def.init_state(params)
    self.assertEqual(jax.tree_map(onp.shape, state), optim.OptimizerState(
        step=(),
        param_states=_WeightNormParamState(
            direction_state=_MomentumParamState(momentum=(2, 2)),
            scale_state=_MomentumParamState(momentum=(1, 2)),
            mult=(1, 2)
        )
    ))
    grads = onp.ones((2, 2))
    new_params, new_state = optimizer_def.apply_gradient(
        optimizer_def.hyper_params, params, state, grads)
    onp.testing.assert_allclose(new_params, onp.full_like(params, 1.9))
    # Here `mult` matches the per-column norm of the updated parameters.
    onp.testing.assert_allclose(new_state.param_states.mult, 1.9 * 2 ** 0.5)
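
For intuition on the final assertion: weight normalization reparameterizes each parameter column w as s * v / ||v||. After the update every entry of new_params is 1.9, so each column's L2 norm is sqrt(1.9**2 + 1.9**2) = 1.9 * sqrt(2), which matches the value the test expects in `mult`. A minimal NumPy check of that arithmetic (not part of the original test):

import numpy as onp

new_params = onp.full((2, 2), 1.9)
# Column-wise L2 norms: sqrt(1.9**2 + 1.9**2) = 1.9 * sqrt(2) per column.
col_norms = onp.linalg.norm(new_params, axis=0, keepdims=True)  # shape (1, 2)
onp.testing.assert_allclose(col_norms, onp.full((1, 2), 1.9 * 2 ** 0.5))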
Example #2

def create_optimizer(config, model, learning_rate, train_size, sampler_rng):
  """Creates the optimizer for `model` based on config flags."""
  if config.optimizer == 'adam':
    optimizer_def = optim.Adam(
        learning_rate=learning_rate, beta1=config.momentum)
  elif config.optimizer == 'momentum':
    optimizer_def = optim.Momentum(
        learning_rate=learning_rate, beta=config.momentum)
  elif config.optimizer == 'sym_euler':
    optimizer_def = sym_euler_sgmcmc.SymEulerSGMCMC(
        train_size,
        sampler_rng,
        learning_rate=learning_rate,
        beta=config.momentum,
        temperature=config.base_temp,
        step_size_factor=1.)
  else:
    raise ValueError('Invalid value %s for config.optimizer.' %
                     config.optimizer)

  if config.weight_norm == 'none':
    pass
  elif config.weight_norm == 'learned':
    # Learn the weight-norm scale through the optimizer wrapper.
    optimizer_def = optim.WeightNorm(optimizer_def)
  elif config.weight_norm in ['fixed', 'ws_sqrt', 'learned_b', 'ws']:
    # These variants are applied directly in the model layers, so no
    # optimizer wrapper is needed here.
    pass
  else:
    raise ValueError('Invalid value %s for config.weight_norm.' %
                     config.weight_norm)

  optimizer = optimizer_def.create(model)

  if not config.debug_run:
    # Replicate the optimizer across local devices for pmap-based training.
    optimizer = optimizer.replicate()
  return optimizer
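
A minimal usage sketch, assuming an ml_collections-style config whose field names mirror the flags read above; the plain parameter pytree stands in for a real model, since optimizer_def.create() accepts any pytree as its target:

import jax
import jax.numpy as jnp
import ml_collections

config = ml_collections.ConfigDict()
config.optimizer = 'momentum'
config.momentum = 0.9
config.weight_norm = 'none'
config.debug_run = True  # keep the optimizer unreplicated on a single device

# A plain parameter pytree in place of a real flax model.
params = {'w': jnp.ones((3, 3))}
optimizer = create_optimizer(
    config, params, learning_rate=0.1, train_size=50_000,
    sampler_rng=jax.random.PRNGKey(0))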