# Example no. 1
# 0
class AliasTest(chex.TestCase):
  """Tests for the optimizer aliases in `alias`."""

  def setUp(self):
    super().setUp()
    # A two-leaf parameter pytree and a fixed pseudo-gradient with the same
    # structure; the same pseudo-gradient is fed in at every step.
    self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

  def _optimize(self, opt, initial_params, grad_fn, num_steps):
    """Runs `opt` for `num_steps` jitted steps and returns the final params.

    Args:
      opt: an optimizer exposing `init(params)` and
        `update(updates, state, params)`.
      initial_params: the starting parameters.
      grad_fn: maps the current params to the updates fed to `opt.update`.
      num_steps: how many optimization steps to run.

    Returns:
      The parameters after the last step.
    """

    @jax.jit
    def step(params, state):
      updates, state = opt.update(grad_fn(params), state, params)
      params = update.apply_updates(params, updates)
      return params, state

    params = initial_params
    state = opt.init(params)
    for _ in range(num_steps):
      params, state = step(params, state)
    return params

  @chex.all_variants()
  @parameterized.named_parameters(
      ('sgd', alias.sgd(LR, 0.0),
       optimizers.sgd(LR), 1e-5),
      ('adam', alias.adam(LR, 0.9, 0.999, 1e-8),
       optimizers.adam(LR, 0.9, 0.999), 1e-4),
      ('rmsprop', alias.rmsprop(LR, .9, 0.1),
       optimizers.rmsprop(LR, .9, 0.1), 1e-5),
      ('adagrad', alias.adagrad(LR, 0., 0.,),
       optimizers.adagrad(LR, 0.), 1e-5),
  )
  def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer, rtol):
    """Checks an optax alias matches its `optimizers` counterpart.

    Both optimizers are fed `self.per_step_updates` for `STEPS` steps from
    `self.init_params`; the optax step function is wrapped by each chex
    variant, and the final parameters must agree to within `rtol`.
    """
    # Reference trajectory: experimental/optimizers.py.
    jax_params = self.init_params
    opt_init, opt_update, get_params = jax_optimizer
    jax_state = opt_init(jax_params)
    for i in range(STEPS):
      jax_state = opt_update(i, self.per_step_updates, jax_state)
      jax_params = get_params(jax_state)

    # Same trajectory through optax.
    optax_params = self.init_params
    optax_state = optax_optimizer.init(optax_params)

    @self.variant
    def step(updates, state):
      return optax_optimizer.update(updates, state)

    for _ in range(STEPS):
      updates, optax_state = step(self.per_step_updates, optax_state)
      optax_params = update.apply_updates(optax_params, updates)

    # Check equivalence.
    chex.assert_tree_all_close(jax_params, optax_params, rtol=rtol)

  @parameterized.named_parameters(
      ('sgd', lambda: alias.sgd(1e-2, 0.0)),
      ('adam', lambda: alias.adam(1e-1)),
      ('adamw', lambda: alias.adamw(1e-1)),
      # TODO(review): labelled 'lamb' but constructs `alias.adamw`, so LAMB
      # itself is not exercised by this case.
      ('lamb', lambda: alias.adamw(1e-1)),
      ('rmsprop', lambda: alias.rmsprop(1e-1)),
      ('fromage', lambda: alias.fromage(1e-2)),
      ('adabelief', lambda: alias.adabelief(1e-1)),
      ('radam', lambda: alias.radam(1e-1)),
      ('yogi', lambda: alias.yogi(1.0)),
  )
  def test_parabel(self, opt):
    """Checks `opt` minimizes sum((params - final_params)**2) in 1000 steps."""
    opt = opt()

    initial_params = jnp.array([-1.0, 10.0, 1.0])
    final_params = jnp.array([1.0, -1.0, 1.0])

    @jax.grad
    def get_updates(params):
      return jnp.sum((params - final_params)**2)

    params = self._optimize(opt, initial_params, get_updates, num_steps=1000)
    chex.assert_tree_all_close(params, final_params, rtol=1e-2, atol=1e-2)

  @parameterized.named_parameters(
      ('sgd', lambda: alias.sgd(2e-3, 0.2)),
      ('adam', lambda: alias.adam(1e-1)),
      ('adamw', lambda: alias.adamw(1e-1)),
      # TODO(review): labelled 'lamb' but constructs `alias.adamw`, so LAMB
      # itself is not exercised by this case.
      ('lamb', lambda: alias.adamw(1e-1)),
      ('rmsprop', lambda: alias.rmsprop(5e-3)),
      ('fromage', lambda: alias.fromage(5e-3)),
      ('adabelief', lambda: alias.adabelief(1e-1)),
      ('radam', lambda: alias.radam(1e-3)),
      ('yogi', lambda: alias.yogi(1.0)),
  )
  def test_rosenbrock(self, opt):
    """Checks `opt` reaches the Rosenbrock minimum (a, a**2) in 10000 steps."""
    opt = opt()

    a = 1.0
    b = 100.0
    initial_params = jnp.array([0.0, 0.0])
    final_params = jnp.array([a, a**2])

    @jax.grad
    def get_updates(params):
      return (a - params[0])**2 + b * (params[1] - params[0]**2)**2

    params = self._optimize(opt, initial_params, get_updates, num_steps=10000)
    chex.assert_tree_all_close(params, final_params, rtol=3e-2, atol=3e-2)
# Example no. 2
# 0
class AliasTest(chex.TestCase):
    """Tests that the optimizer aliases in `alias` minimize simple objectives."""

    def _optimize(self, opt_name, opt, initial_params, grad_fn, num_steps):
        """Runs `opt` for `num_steps` jitted steps and returns the final params.

        Args:
            opt_name: name of the test case; 'dpsgd' gets special handling.
            opt: optimizer exposing `init(params)` and
                `update(updates, state, params)`.
            initial_params: the starting parameters.
            grad_fn: maps the current params to their gradient.
            num_steps: how many optimization steps to run.

        Returns:
            The parameters after the last step.
        """

        @jax.jit
        def step(params, state):
            updates = grad_fn(params)
            if opt_name == 'dpsgd':
                # Adds a leading axis of size 1; presumably dpsgd expects
                # per-example gradients stacked along axis 0 - TODO confirm.
                updates = updates[None]
            updates, state = opt.update(updates, state, params)
            params = update.apply_updates(params, updates)
            return params, state

        params = initial_params
        state = opt.init(params)
        for _ in range(num_steps):
            params, state = step(params, state)
        return params

    @parameterized.parameters(
        ('sgd', lambda: alias.sgd(1e-2, 0.0)),
        ('adam', lambda: alias.adam(1e-1)),
        ('adamw', lambda: alias.adamw(1e-1)),
        # TODO(review): labelled 'lamb' but constructs `alias.adamw`, so LAMB
        # itself is not exercised by this case.
        ('lamb', lambda: alias.adamw(1e-1)),
        ('rmsprop', lambda: alias.rmsprop(1e-1)),
        ('rmsprop_momentum', lambda: alias.rmsprop(5e-2, momentum=0.9)),
        ('fromage', lambda: alias.fromage(1e-2)),
        ('adabelief', lambda: alias.adabelief(1e-1)),
        ('radam', lambda: alias.radam(1e-1)),
        ('sm3', lambda: alias.sm3(1.0)),
        ('yogi', lambda: alias.yogi(1.0)),
        ('dpsgd', lambda: alias.dpsgd(1e-2, 10.0, 0.001, 0)))
    def test_parabel(self, opt_name, opt):
        """Checks `opt` minimizes sum((params - final_params)**2) in 1000 steps."""
        opt = opt()

        initial_params = jnp.array([-1.0, 10.0, 1.0])
        final_params = jnp.array([1.0, -1.0, 1.0])

        @jax.grad
        def get_updates(params):
            return jnp.sum((params - final_params)**2)

        params = self._optimize(
            opt_name, opt, initial_params, get_updates, num_steps=1000)
        chex.assert_tree_all_close(params, final_params, rtol=1e-2, atol=1e-2)

    @parameterized.parameters(
        ('sgd', lambda: alias.sgd(2e-3, 0.2)),
        ('adam', lambda: alias.adam(1e-1)),
        ('adamw', lambda: alias.adamw(1e-1)),
        # TODO(review): labelled 'lamb' but constructs `alias.adamw`, so LAMB
        # itself is not exercised by this case.
        ('lamb', lambda: alias.adamw(1e-1)),
        ('rmsprop', lambda: alias.rmsprop(5e-3)),
        ('rmsprop_momentum', lambda: alias.rmsprop(5e-3, momentum=0.9)),
        ('fromage', lambda: alias.fromage(5e-3)),
        ('adabelief', lambda: alias.adabelief(1e-1)),
        ('radam', lambda: alias.radam(1e-3)),
        ('sm3', lambda: alias.sm3(1.0)),
        ('yogi', lambda: alias.yogi(1.0)),
        ('dpsgd', lambda: alias.dpsgd(2e-3, 10., 0.001, 0, 0.2)))
    def test_rosenbrock(self, opt_name, opt):
        """Checks `opt` reaches the Rosenbrock minimum (a, a**2) in 10000 steps."""
        opt = opt()

        a = 1.0
        b = 100.0
        initial_params = jnp.array([0.0, 0.0])
        final_params = jnp.array([a, a**2])

        @jax.grad
        def get_updates(params):
            return (a - params[0])**2 + b * (params[1] - params[0]**2)**2

        params = self._optimize(
            opt_name, opt, initial_params, get_updates, num_steps=10000)
        chex.assert_tree_all_close(params, final_params, rtol=3e-2, atol=3e-2)
# Example no. 3
# 0
class AliasTest(chex.TestCase):
    """Tests for the optimizer aliases in `alias`."""

    @parameterized.product(
        (
            dict(opt_name='sgd', opt=lambda: alias.sgd(1e-3, 0.9)),
            dict(opt_name='adafactor', opt=lambda: alias.adafactor(5e-3)),
            dict(opt_name='adagrad', opt=lambda: alias.adagrad(1.0)),
            dict(opt_name='adam', opt=lambda: alias.adam(1e-1)),
            dict(opt_name='adamw', opt=lambda: alias.adamw(1e-1)),
            dict(opt_name='lars', opt=lambda: alias.lars(1.0)),
            dict(opt_name='lamb', opt=lambda: alias.lamb(1e-3)),
            dict(opt_name='noisy_sgd',
                 opt=lambda: alias.noisy_sgd(1e-3, eta=1e-4)),
            dict(opt_name='rmsprop', opt=lambda: alias.rmsprop(5e-3)),
            dict(opt_name='rmsprop_momentum',
                 opt=lambda: alias.rmsprop(5e-3, momentum=0.9)),
            dict(opt_name='fromage', opt=lambda: alias.fromage(5e-3)),
            dict(opt_name='adabelief', opt=lambda: alias.adabelief(1e-2)),
            dict(opt_name='radam', opt=lambda: alias.radam(5e-3)),
            dict(opt_name='sm3', opt=lambda: alias.sm3(1.0)),
            dict(opt_name='yogi', opt=lambda: alias.yogi(1e-1)),
            dict(opt_name='dpsgd',
                 opt=lambda: alias.dpsgd(1e-3, 10.0, 0.001, 0, 0.2)),
        ),
        target=(_setup_parabola, _setup_rosenbrock),
        dtype=(jnp.float32, jnp.complex64),
    )
    def test_optimization(self, opt_name, opt, target, dtype):
        """Checks each optimizer reaches the minimum of each target problem.

        `target(dtype)` yields (initial_params, final_params, get_updates);
        the optimizer runs 10000 jitted steps from `initial_params` and must
        end within rtol/atol 3e-2 of `final_params`.
        """
        # `dtype` is a dtype (jnp.float32 / jnp.complex64), so test it with a
        # dtype predicate rather than an object predicate.
        if (opt_name in ('fromage', 'noisy_sgd', 'sm3')
                and jnp.issubdtype(dtype, jnp.complexfloating)):
            raise absltest.SkipTest(
                f'{opt_name} does not support complex parameters.')

        opt = opt()
        initial_params, final_params, get_updates = target(dtype)

        @jax.jit
        def step(params, state):
            updates = get_updates(params)
            if opt_name == 'dpsgd':
                # Adds a leading axis of size 1; presumably dpsgd expects
                # per-example gradients stacked along axis 0 - TODO confirm.
                updates = updates[None]
            # Complex gradients need to be conjugated before being added to
            # parameters (a no-op for real dtypes).
            # https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29
            updates = jax.tree_util.tree_map(lambda x: x.conj(), updates)
            updates, state = opt.update(updates, state, params)
            params = update.apply_updates(params, updates)
            return params, state

        params = initial_params
        state = opt.init(params)
        for _ in range(10000):
            params, state = step(params, state)

        chex.assert_tree_all_close(params, final_params, rtol=3e-2, atol=3e-2)

    @parameterized.named_parameters([
        ('float32', 'float32'),
        ('bfloat16', 'bfloat16'),
        ('complex64', 'complex64'),
        ('None', None),
    ])
    def test_explicit_dtype(self, dtype):
        """Checks accumulator/mu dtype arguments propagate into optimizer state.

        For sgd (momentum trace), adam and adamw, the first sub-state's
        accumulator must have the canonicalized form of `dtype`.
        """
        expected_dtype = jax.dtypes.canonicalize_dtype(
            dtype)  # None -> float32
        tx = alias.sgd(0.1, momentum=0.9, accumulator_dtype=dtype)
        trace_state, _ = tx.init(jnp.array([0.0, 0.0]))
        self.assertEqual(expected_dtype, trace_state.trace.dtype)
        tx = alias.adam(0.1, mu_dtype=dtype)
        adam_state, _ = tx.init(jnp.array([0.0, 0.0]))
        self.assertEqual(expected_dtype, adam_state.mu.dtype)
        # adamw's state has one more element than adam's (three-way unpack).
        tx = alias.adamw(0.1, mu_dtype=dtype)
        adam_state, _, _ = tx.init(jnp.array([0.0, 0.0]))
        self.assertEqual(expected_dtype, adam_state.mu.dtype)