Exemple #1
0
class ExperimentalOptimizersEquivalenceTest(chex.TestCase):
    """Compares optax optimizers with `jax.experimental.optimizers`.

    Both implementations consume the same fixed per-step updates for `STEPS`
    iterations; the resulting parameters must agree up to `rtol`.
    """

    def setUp(self):
        super().setUp()
        self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
        self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

    @chex.all_variants()
    @parameterized.named_parameters(
        # Constant learning rate on the optax side.
        ('sgd', alias.sgd(LR, 0.0), optimizers.sgd(LR), 1e-5),
        ('adam', alias.adam(LR, 0.9, 0.999, 1e-8),
         optimizers.adam(LR, 0.9, 0.999), 1e-4),
        ('rmsprop', alias.rmsprop(LR, decay=.9, eps=0.1),
         optimizers.rmsprop(LR, .9, 0.1), 1e-5),
        ('rmsprop_momentum',
         alias.rmsprop(LR, decay=.9, eps=0.1, momentum=0.9),
         optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
        ('adagrad', alias.adagrad(LR, 0., 0.),
         optimizers.adagrad(LR, 0.), 1e-5),
        # `LR_SCHED` on the optax side, still compared against a constant-`LR`
        # reference (assumes LR_SCHED evaluates to LR at every step — confirm).
        # These cases need distinct names: named_parameters builds the test
        # method name from the first element, so reusing 'sgd', 'adam', ...
        # would generate duplicate test names.
        ('sgd_sched', alias.sgd(LR_SCHED, 0.0), optimizers.sgd(LR), 1e-5),
        ('adam_sched', alias.adam(LR_SCHED, 0.9, 0.999, 1e-8),
         optimizers.adam(LR, 0.9, 0.999), 1e-4),
        ('rmsprop_sched', alias.rmsprop(LR_SCHED, decay=.9, eps=0.1),
         optimizers.rmsprop(LR, .9, 0.1), 1e-5),
        ('rmsprop_momentum_sched',
         alias.rmsprop(LR_SCHED, decay=.9, eps=0.1, momentum=0.9),
         optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
        ('adagrad_sched', alias.adagrad(LR_SCHED, 0., 0.),
         optimizers.adagrad(LR, 0.), 1e-5),
    )
    def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer,
                                      rtol):
        """Run both optimizers `STEPS` times and compare the final params.

        Args:
          optax_optimizer: optax gradient transformation under test.
          jax_optimizer: `(init, update, get_params)` triple from
            `jax.experimental.optimizers`.
          rtol: relative tolerance for the final comparison.
        """
        # Reference trajectory: experimental/optimizers.py.
        jax_params = self.init_params
        opt_init, opt_update, get_params = jax_optimizer
        state = opt_init(jax_params)
        for i in range(STEPS):
            state = opt_update(i, self.per_step_updates, state)
            jax_params = get_params(state)

        # Trajectory under test: optax, with `step` wrapped by each chex variant.
        optax_params = self.init_params
        state = optax_optimizer.init(optax_params)

        @self.variant
        def step(updates, state):
            return optax_optimizer.update(updates, state)

        for _ in range(STEPS):
            updates, state = step(self.per_step_updates, state)
            optax_params = update.apply_updates(optax_params, updates)

        # Check equivalence.
        chex.assert_tree_all_close(jax_params, optax_params, rtol=rtol)
class FlaxOptimizersEquivalenceTest(chex.TestCase):
  """Compares optax aliases with their `flax.optim` counterparts."""

  def setUp(self):
    super().setUp()
    self.init_params = (
        jnp.array([1., 2.]),
        jnp.array([3., 4.]),
    )
    self.per_step_updates = (
        jnp.array([500., 5.]),
        jnp.array([300., 3.]),
    )

  @parameterized.named_parameters(
      ('sgd', alias.sgd(LR), optim.GradientDescent(LR)),
      # Different names: optax `sgd(momentum=...)` vs flax `Momentum(beta=...)`.
      ('momentum', alias.sgd(LR, momentum=0.9), optim.Momentum(LR, beta=0.9)),
      ('nesterov_momentum',
       alias.sgd(LR, momentum=0.9, nesterov=True),
       optim.Momentum(LR, beta=0.9, nesterov=True)),
      ('rmsprop', alias.rmsprop(LR), optim.RMSProp(LR)),
      ('centered_rmsprop',
       alias.rmsprop(LR, centered=True),
       optim.RMSProp(LR, centered=True)),
      ('adam', alias.adam(LR), optim.Adam(LR)),
      # Different name: optax `adamw` vs flax `Adam(weight_decay=...)`.
      ('adam_w',
       alias.adamw(LR, weight_decay=1e-4),
       optim.Adam(LR, weight_decay=1e-4)),
      # Different default for the initial accumulator, so pin it explicitly.
      ('adagrad',
       alias.adagrad(LR, initial_accumulator_value=0.),
       optim.Adagrad(LR)),
      ('lamb', alias.lamb(LR), optim.LAMB(LR)),
  )
  def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):
    """Run both optimizers for `STEPS` steps and compare final parameters.

    Args:
      optax_optimizer: optax gradient transformation under test.
      flax_optimizer: flax.optim optimizer definition used as reference.
    """
    grads = self.per_step_updates

    # Reference trajectory from flax.optim.
    flax_params = self.init_params
    flax_opt = flax_optimizer.create(flax_params)
    for _ in range(STEPS):
      flax_opt = flax_opt.apply_gradient(grads)
      flax_params = flax_opt.target

    # Same trajectory driven by optax.
    optax_params = self.init_params
    opt_state = optax_optimizer.init(optax_params)
    for _ in range(STEPS):
      step_updates, opt_state = optax_optimizer.update(
          grads, opt_state, optax_params)
      optax_params = update.apply_updates(optax_params, step_updates)

    chex.assert_tree_all_close(flax_params, optax_params, rtol=1e-4)
Exemple #3
0
    def test_flatten(self):
        """`wrappers.flatten` must not change the result of one plain SGD step."""
        grads = (jnp.array([500., 5.]), jnp.array([300., 3.]))

        def fresh_params():
            return (jnp.array([1., 2.]), jnp.array([3., 4.]))

        def one_step(opt):
            # Apply a single update of `opt` to freshly built parameters.
            params = fresh_params()
            opt_state = opt.init(params)
            step, opt_state = opt.update(grads, opt_state)
            return update.apply_updates(params, step)

        base = alias.sgd(1e-2, 0.0)
        # Reference: plain SGD on the nested params tuple.
        expected = one_step(base)
        # Same optimizer wrapped with `wrappers.flatten`.
        actual = one_step(wrappers.flatten(base))

        chex.assert_tree_all_close(expected, actual, atol=1e-7, rtol=1e-7)
Exemple #4
0
 def test_explicit_dtype(self, dtype):
     """Accumulator state must be allocated with the requested dtype."""
     want = jax.dtypes.canonicalize_dtype(dtype)  # None -> float32
     zeros = jnp.array([0.0, 0.0])

     sgd_trace, _ = alias.sgd(
         0.1, momentum=0.9, accumulator_dtype=dtype).init(zeros)
     self.assertEqual(want, sgd_trace.trace.dtype)

     adam_moments, _ = alias.adam(0.1, mu_dtype=dtype).init(zeros)
     self.assertEqual(want, adam_moments.mu.dtype)

     adamw_moments, _, _ = alias.adamw(0.1, mu_dtype=dtype).init(zeros)
     self.assertEqual(want, adamw_moments.mu.dtype)
Exemple #5
0
    def test_apply_if_finite_pmap(self):
        """`apply_if_finite` under `pmap`, with NaNs produced inside the step.

        Unlike in `test_apply_if_finite`:
        * pmap is applied to the gradient computation and the optimisation;
        * the NaNs are caused inside the function and do not come from the
          inputs.
        """
        half = jnp.ones([1]) / 2.
        two = jnp.ones([1]) * 2.  # Causes a NaN in arctanh

        def fn(x):
            return jnp.arctanh(x) * hk.get_parameter(
                'p', [], init=hk.initializers.Constant(0.))

        fn = hk.without_apply_rng(hk.transform(fn))

        # NOTE(review): the second argument appears to be the number of
        # consecutive non-finite updates that get rejected (the assertions
        # below expect exactly 2 rejections before one is accepted) — confirm.
        opt = wrappers.apply_if_finite(alias.sgd(1.), 2)

        def fn_update(params, opt_state, x):
            grads = jax.grad(fn.apply)(params, x)
            grads = jax.lax.psum(grads, axis_name='i')
            updates, new_opt_state = opt.update(grads, opt_state, params)
            new_params = update.apply_updates(params, updates)
            return new_params, new_opt_state

        fn_update = jax.pmap(fn_update, axis_name='i')

        params = fn.init(jax.random.PRNGKey(1905), half)
        opt_state = opt.init(params)
        # Add a leading axis of size 1 so pmap maps over a single device.
        # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map`.
        params = jax.tree_util.tree_map(lambda x: x[None], params)
        opt_state = jax.tree_util.tree_map(lambda x: x[None], opt_state)
        # Do one successful param update
        params, opt_state = fn_update(params, opt_state, half)
        self.assertTrue(bool(opt_state.last_finite))
        # Check 2 rejected param updates
        for step in range(2):
            params, opt_state = fn_update(params, opt_state, two)
            self.assertFalse(bool(opt_state.last_finite))
            self.assertEqual(step + 1, int(opt_state.notfinite_count))
        # Next successful param update
        params, opt_state = fn_update(params, opt_state, half)
        self.assertTrue(bool(opt_state.last_finite))
        # Again 2 rejected param updates
        for step in range(2):
            params, opt_state = fn_update(params, opt_state, two)
            self.assertFalse(bool(opt_state.last_finite))
            self.assertEqual(step + 1, int(opt_state.notfinite_count))
        # Next param update with NaN is accepted since we reached maximum
        params, opt_state = fn_update(params, opt_state, two)
        # 2 + 2 rejected + 1 accepted non-finite updates.
        self.assertEqual(5, int(opt_state.total_notfinite))
Exemple #6
0
    def test_apply_every(self):
        """`apply_every(k)` + an SGD-equivalent chain matches SGD every k steps."""
        k = 4  # The frequency of the application of sgd.
        zeros = (jnp.array([0., 0.]), jnp.array([0., 0.]))
        grads = self.per_step_updates

        # Reference: plain optax sgd, updated at every step.
        ref_params = self.init_params
        ref_opt = alias.sgd(LR, 0.0)
        ref_state = ref_opt.init(ref_params)

        # Candidate: apply_every followed by the same sgd built from parts.
        acc_params = self.init_params
        acc_opt = combine.chain(
            transform.apply_every(k=k),
            transform.trace(decay=0, nesterov=False),
            transform.scale(-LR),
        )
        acc_state = acc_opt.init(acc_params)
        acc_update = self.variant(acc_opt.update)

        for step in range(STEPS):
            ref_step, ref_state = ref_opt.update(grads, ref_state)
            ref_params = update.apply_updates(ref_params, ref_step)

            acc_step, acc_state = acc_update(grads, acc_state)
            acc_params = update.apply_updates(acc_params, acc_step)

            if (step + 1) % k:
                # Between applications the emitted update must be exactly zero.
                chex.assert_tree_all_close(acc_step, zeros,
                                           atol=0.0, rtol=0.0)
            else:
                # On every k-th step both parameter trajectories must agree.
                chex.assert_tree_all_close(acc_params, ref_params,
                                           atol=1e-6, rtol=1e-5)
Exemple #7
0
 def test_multi_steps_every_k_schedule(self):
     """MultiSteps with an `every_k_schedule` that changes over time."""
     def every_k(grad_step):
         # 1 mini-step per update for the first two steps, 3 afterwards.
         return jnp.where(grad_step < 2, 1, 3)

     ms_opt = wrappers.MultiSteps(alias.sgd(1e-4), every_k)
     opt_init, opt_update = ms_opt.gradient_transformation()
     params = dict(a=jnp.zeros([]))
     opt_state = opt_init(params)
     grad = dict(a=jnp.zeros([]))
     self.assertFalse(ms_opt.has_updated(opt_state))

     def mini_step(state):
         _, next_state = opt_update(grad, state, params)
         return next_state

     # Warm-up: every call completes an update.
     for _ in range(2):
         opt_state = mini_step(opt_state)
         self.assertTrue(ms_opt.has_updated(opt_state))

     # Afterwards only every third call completes an update.
     for _cycle in range(5):
         for _ in range(2):
             opt_state = mini_step(opt_state)
             self.assertFalse(ms_opt.has_updated(opt_state))
         opt_state = mini_step(opt_state)
         self.assertTrue(ms_opt.has_updated(opt_state))
Exemple #8
0
    def test_labels_mismatch(self, use_extra_label, use_fn):
        """`multi_transform` init rejects labels that have no transform.

        Args:
          use_extra_label: if True, label 'a' with 3, which has no entry in
            `transforms`; init is then expected to raise ValueError.
          use_fn: if True, pass the labels as a callable returning the label
            tree instead of the tree itself.
        """
        # The labels from label_fn must be a subset of the keys for the tx.
        params = {'a': 1., 'b': [2., 3.], 'c': {'d': 4., 'e': (5., 6.)}}
        # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map`.
        params = jax.tree_util.tree_map(jnp.asarray, params)
        label_tree = {'a': 0, 'b': [1, 0], 'c': 1}  # prefix of params

        if use_extra_label:
            label_tree['a'] = 3

        transforms = {
            0: alias.sgd(1.),
            1: alias.adam(1., b1=0., b2=0.),
            2: transform.trace(1.0)
        }
        labels = (lambda _: label_tree) if use_fn else label_tree
        init_fn, update_fn = combine.multi_transform(transforms, labels)

        if use_extra_label:
            with self.assertRaises(ValueError):
                self.variant(init_fn)(params)
        else:
            state = self.variant(init_fn)(params)
            updates = jax.tree_util.tree_map(lambda x: x / 10.0, params)
            self.variant(update_fn)(updates, state)
Exemple #9
0
def _build_sgd():
    """Return a plain SGD transformation with a unit learning rate."""
    learning_rate = 1.
    return alias.sgd(learning_rate)
Exemple #10
0
class AliasTest(chex.TestCase):
    """Checks that each `alias` optimizer minimises simple objectives."""

    @parameterized.parameters(
        ('sgd', lambda: alias.sgd(1e-2, 0.0)),
        ('adam', lambda: alias.adam(1e-1)),
        ('adamw', lambda: alias.adamw(1e-1)),
        # NOTE(review): no 'lamb' case here (a previous 'lamb' entry actually
        # built `alias.adamw`). Add one once an `alias.lamb` learning rate that
        # converges within this test's 1000 steps is confirmed; LAMB is
        # covered by test_rosenbrock.
        ('rmsprop', lambda: alias.rmsprop(1e-1)),
        ('rmsprop_momentum', lambda: alias.rmsprop(5e-2, momentum=0.9)),
        ('fromage', lambda: alias.fromage(1e-2)),
        ('adabelief', lambda: alias.adabelief(1e-1)),
        ('radam', lambda: alias.radam(1e-1)),
        ('sm3', lambda: alias.sm3(1.0)),
        ('yogi', lambda: alias.yogi(1.0)),
        ('dpsgd', lambda: alias.dpsgd(1e-2, 10.0, 0.001, 0)))
    def test_parabel(self, opt_name, opt):
        """Minimise sum((params - final_params)**2) in 1000 jitted steps.

        Args:
          opt_name: case label; 'dpsgd' gets an extra leading gradient axis.
          opt: zero-argument factory returning the optimizer under test.
        """
        opt = opt()

        initial_params = jnp.array([-1.0, 10.0, 1.0])
        final_params = jnp.array([1.0, -1.0, 1.0])

        @jax.grad
        def get_updates(params):
            return jnp.sum((params - final_params)**2)

        @jax.jit
        def step(params, state):
            updates = get_updates(params)
            if opt_name == 'dpsgd':
                # Leading axis of size 1; presumably dpsgd expects per-example
                # gradients — confirm.
                updates = updates[None]
            updates, state = opt.update(updates, state, params)
            params = update.apply_updates(params, updates)
            return params, state

        params = initial_params
        state = opt.init(params)
        for _ in range(1000):
            params, state = step(params, state)

        chex.assert_tree_all_close(params, final_params, rtol=1e-2, atol=1e-2)

    @parameterized.parameters(
        ('sgd', lambda: alias.sgd(2e-3, 0.2)),
        ('adam', lambda: alias.adam(1e-1)),
        ('adamw', lambda: alias.adamw(1e-1)),
        # Previously built `alias.adamw`, so LAMB was never exercised.
        # NOTE(review): lr=1e-3 — confirm convergence within 10000 steps.
        ('lamb', lambda: alias.lamb(1e-3)),
        ('rmsprop', lambda: alias.rmsprop(5e-3)),
        ('rmsprop_momentum', lambda: alias.rmsprop(5e-3, momentum=0.9)),
        ('fromage', lambda: alias.fromage(5e-3)),
        ('adabelief', lambda: alias.adabelief(1e-1)),
        ('radam', lambda: alias.radam(1e-3)),
        ('sm3', lambda: alias.sm3(1.0)),
        ('yogi', lambda: alias.yogi(1.0)),
        ('dpsgd', lambda: alias.dpsgd(2e-3, 10., 0.001, 0, 0.2)))
    def test_rosenbrock(self, opt_name, opt):
        """Minimise the Rosenbrock function (a=1, b=100) in 10000 steps.

        Args:
          opt_name: case label; 'dpsgd' gets an extra leading gradient axis.
          opt: zero-argument factory returning the optimizer under test.
        """
        opt = opt()

        a = 1.0
        b = 100.0
        initial_params = jnp.array([0.0, 0.0])
        final_params = jnp.array([a, a**2])

        @jax.grad
        def get_updates(params):
            return (a - params[0])**2 + b * (params[1] - params[0]**2)**2

        @jax.jit
        def step(params, state):
            updates = get_updates(params)
            if opt_name == 'dpsgd':
                # Leading axis of size 1; presumably dpsgd expects per-example
                # gradients — confirm.
                updates = updates[None]
            updates, state = opt.update(updates, state, params)
            params = update.apply_updates(params, updates)
            return params, state

        params = initial_params
        state = opt.init(params)
        for _ in range(10000):
            params, state = step(params, state)

        chex.assert_tree_all_close(params, final_params, rtol=3e-2, atol=3e-2)
Exemple #11
0
class AliasTest(chex.TestCase):
    """Equivalence and convergence tests for the `alias` optimizers."""

    def setUp(self):
        super().setUp()
        self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
        self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

    @chex.all_variants()
    @parameterized.named_parameters(
        ('sgd', alias.sgd(LR, 0.0), optimizers.sgd(LR), 1e-5),
        ('adam', alias.adam(LR, 0.9, 0.999, 1e-8),
         optimizers.adam(LR, 0.9, 0.999), 1e-4),
        ('rmsprop', alias.rmsprop(LR, .9, 0.1),
         optimizers.rmsprop(LR, .9, 0.1), 1e-5),
        ('adagrad', alias.adagrad(LR, 0., 0.),
         optimizers.adagrad(LR, 0.), 1e-5),
    )
    def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer,
                                      rtol):
        """Compare optax with `jax.experimental.optimizers` after STEPS steps.

        Args:
          optax_optimizer: optax gradient transformation under test.
          jax_optimizer: `(init, update, get_params)` reference triple.
          rtol: relative tolerance for the final comparison.
        """
        # experimental/optimizers.py
        jax_params = self.init_params
        opt_init, opt_update, get_params = jax_optimizer
        state = opt_init(jax_params)
        for i in range(STEPS):
            state = opt_update(i, self.per_step_updates, state)
            jax_params = get_params(state)

        # optax, with `step` wrapped by each chex variant.
        optax_params = self.init_params
        state = optax_optimizer.init(optax_params)

        @self.variant
        def step(updates, state):
            return optax_optimizer.update(updates, state)

        for _ in range(STEPS):
            updates, state = step(self.per_step_updates, state)
            optax_params = update.apply_updates(optax_params, updates)

        # Check equivalence.
        chex.assert_tree_all_close(jax_params, optax_params, rtol=rtol)

    @parameterized.named_parameters(
        ('sgd', alias.sgd(1e-2, 0.0)),
        ('adam', alias.adam(1e-1)),
        ('adamw', alias.adamw(1e-1)),
        # NOTE(review): no 'lamb' case here (a previous 'lamb' entry actually
        # built `alias.adamw`). Add one once an `alias.lamb` learning rate that
        # converges within 1000 steps is confirmed; LAMB is covered by
        # test_rosenbrock.
        ('rmsprop', alias.rmsprop(1e-1)),
        ('fromage', transform.scale_by_fromage(-1e-2)),
        ('adabelief', alias.adabelief(1e-1)),
    )
    def test_parabel(self, opt):
        """Minimise sum((params - final_params)**2) in 1000 jitted steps."""
        initial_params = jnp.array([-1.0, 10.0, 1.0])
        final_params = jnp.array([1.0, -1.0, 1.0])

        @jax.grad
        def get_updates(params):
            return jnp.sum((params - final_params)**2)

        @jax.jit
        def step(params, state):
            updates, state = opt.update(get_updates(params), state, params)
            params = update.apply_updates(params, updates)
            return params, state

        params = initial_params
        state = opt.init(params)
        for _ in range(1000):
            params, state = step(params, state)

        chex.assert_tree_all_close(params, final_params, rtol=1e-2, atol=1e-2)

    @parameterized.named_parameters(
        ('sgd', alias.sgd(2e-3, 0.2)),
        ('adam', alias.adam(1e-1)),
        ('adamw', alias.adamw(1e-1)),
        # Previously built `alias.adamw`, so LAMB was never exercised.
        # NOTE(review): lr=1e-3 — confirm convergence within 10000 steps.
        ('lamb', alias.lamb(1e-3)),
        ('rmsprop', alias.rmsprop(5e-3)),
        ('fromage', transform.scale_by_fromage(-5e-3)),
        ('adabelief', alias.adabelief(1e-1)),
    )
    def test_rosenbrock(self, opt):
        """Minimise the Rosenbrock function (a=1, b=100) in 10000 steps."""
        a = 1.0
        b = 100.0
        initial_params = jnp.array([0.0, 0.0])
        final_params = jnp.array([a, a**2])

        @jax.grad
        def get_updates(params):
            return (a - params[0])**2 + b * (params[1] - params[0]**2)**2

        @jax.jit
        def step(params, state):
            updates, state = opt.update(get_updates(params), state, params)
            params = update.apply_updates(params, updates)
            return params, state

        params = initial_params
        state = opt.init(params)
        for _ in range(10000):
            params, state = step(params, state)

        chex.assert_tree_all_close(params, final_params, rtol=3e-2, atol=3e-2)
Exemple #12
0
class FlaxOptimizersEquivalenceTest(chex.TestCase):
    """Compares optax aliases with their `flax.optim` counterparts."""

    def setUp(self):
        super().setUp()
        self.init_params = (
            jnp.array([1., 0.1, 1., 2.]),
            jnp.array([3., 4.]),
        )
        self.per_step_updates = (
            jnp.array([0., 0.3, 500., 5.]),
            jnp.array([300., 3.]),
        )

    @parameterized.named_parameters(
        ('sgd', alias.sgd(LR), optim.GradientDescent(LR)),
        # Different names: optax `sgd(momentum=...)` vs flax `Momentum(beta=...)`.
        ('momentum',
         alias.sgd(LR, momentum=0.9),
         optim.Momentum(LR, beta=0.9)),
        ('nesterov_momentum',
         alias.sgd(LR, momentum=0.9, nesterov=True),
         optim.Momentum(LR, beta=0.9, nesterov=True)),
        ('rmsprop', alias.rmsprop(LR), optim.RMSProp(LR)),
        ('centered_rmsprop',
         alias.rmsprop(LR, centered=True),
         optim.RMSProp(LR, centered=True)),
        ('adam', alias.adam(LR), optim.Adam(LR)),
        # Different name: optax `adamw` vs flax `Adam(weight_decay=...)`.
        ('adam_w',
         alias.adamw(LR, weight_decay=1e-4),
         optim.Adam(LR, weight_decay=1e-4)),
        # Different default for the initial accumulator, so pin it explicitly.
        ('adagrad',
         alias.adagrad(LR, initial_accumulator_value=0.),
         optim.Adagrad(LR)),
        ('lamb', alias.lamb(LR), optim.LAMB(LR)),
        # optax `momentum` is passed as flax `beta`.
        ('lars',
         alias.lars(LR, weight_decay=.5, trust_coefficient=0.003,
                    momentum=0.9, eps=1e-3),
         optim.LARS(LR, weight_decay=.5, trust_coefficient=0.003,
                    beta=0.9, eps=1e-3)),
        ('adafactor',
         alias.adafactor(learning_rate=LR / 10.,
                         factored=True,
                         multiply_by_parameter_scale=True,
                         clipping_threshold=1.0,
                         decay_rate=0.8,
                         min_dim_size_to_factor=2),
         optim.Adafactor(learning_rate=LR / 10.,
                         factored=True,
                         multiply_by_parameter_scale=True,
                         clipping_threshold=1.0,
                         decay_rate=0.8,
                         min_dim_size_to_factor=2)),
    )
    def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):
        """Run both optimizers for `STEPS` steps and compare final parameters.

        Args:
          optax_optimizer: optax gradient transformation under test.
          flax_optimizer: flax.optim optimizer definition used as reference.
        """
        grads = self.per_step_updates

        # Reference trajectory from flax.optim.
        flax_params = self.init_params
        flax_opt = flax_optimizer.create(flax_params)
        for _ in range(STEPS):
            flax_opt = flax_opt.apply_gradient(grads)
            flax_params = flax_opt.target

        # Same trajectory driven by optax.
        optax_params = self.init_params
        opt_state = optax_optimizer.init(optax_params)
        for _ in range(STEPS):
            step_updates, opt_state = optax_optimizer.update(
                grads, opt_state, optax_params)
            optax_params = update.apply_updates(optax_params, step_updates)

        chex.assert_tree_all_close(flax_params, optax_params, rtol=2e-4)
Exemple #13
0
class AliasTest(chex.TestCase):
    """Convergence and dtype tests for the `alias` optimizers."""

    @parameterized.product(
        (
            dict(opt_name='sgd', opt=lambda: alias.sgd(1e-3, 0.9)),
            dict(opt_name='adafactor', opt=lambda: alias.adafactor(5e-3)),
            dict(opt_name='adagrad', opt=lambda: alias.adagrad(1.0)),
            dict(opt_name='adam', opt=lambda: alias.adam(1e-1)),
            dict(opt_name='adamw', opt=lambda: alias.adamw(1e-1)),
            dict(opt_name='lars', opt=lambda: alias.lars(1.0)),
            dict(opt_name='lamb', opt=lambda: alias.lamb(1e-3)),
            dict(opt_name='noisy_sgd',
                 opt=lambda: alias.noisy_sgd(1e-3, eta=1e-4)),
            dict(opt_name='rmsprop', opt=lambda: alias.rmsprop(5e-3)),
            dict(opt_name='rmsprop_momentum',
                 opt=lambda: alias.rmsprop(5e-3, momentum=0.9)),
            dict(opt_name='fromage', opt=lambda: alias.fromage(5e-3)),
            dict(opt_name='adabelief', opt=lambda: alias.adabelief(1e-2)),
            dict(opt_name='radam', opt=lambda: alias.radam(5e-3)),
            dict(opt_name='sm3', opt=lambda: alias.sm3(1.0)),
            dict(opt_name='yogi', opt=lambda: alias.yogi(1e-1)),
            dict(opt_name='dpsgd',
                 opt=lambda: alias.dpsgd(1e-3, 10.0, 0.001, 0, 0.2)),
        ),
        target=(_setup_parabola, _setup_rosenbrock),
        dtype=(jnp.float32, jnp.complex64),
    )
    def test_optimization(self, opt_name, opt, target, dtype):
        """Run 10000 jitted steps on `target` and check convergence.

        Args:
          opt_name: case label; selects complex-dtype skips and dpsgd handling.
          opt: zero-argument factory returning the optimizer under test.
          target: callable taking `dtype` and returning
            `(initial_params, final_params, get_updates)`.
          dtype: parameter dtype (real or complex).
        """
        if (opt_name in ('fromage', 'noisy_sgd', 'sm3')
                and jnp.iscomplexobj(dtype)):
            raise absltest.SkipTest(
                f'{opt_name} does not support complex parameters.')

        opt = opt()
        initial_params, final_params, get_updates = target(dtype)

        @jax.jit
        def step(params, state):
            updates = get_updates(params)
            if opt_name == 'dpsgd':
                # Leading axis of size 1; presumably dpsgd expects per-example
                # gradients — confirm.
                updates = updates[None]
            # Complex gradients need to be conjugated before being added to
            # parameters:
            # https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29
            # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map`.
            updates = jax.tree_util.tree_map(lambda x: x.conj(), updates)
            updates, state = opt.update(updates, state, params)
            params = update.apply_updates(params, updates)
            return params, state

        params = initial_params
        state = opt.init(params)
        for _ in range(10000):
            params, state = step(params, state)

        chex.assert_tree_all_close(params, final_params, rtol=3e-2, atol=3e-2)

    @parameterized.named_parameters([
        ('float32', 'float32'),
        ('bfloat16', 'bfloat16'),
        ('complex64', 'complex64'),
        ('None', None),
    ])
    def test_explicit_dtype(self, dtype):
        """Accumulator state must be allocated with the requested dtype."""
        expected_dtype = jax.dtypes.canonicalize_dtype(
            dtype)  # None -> float32
        tx = alias.sgd(0.1, momentum=0.9, accumulator_dtype=dtype)
        trace_state, _ = tx.init(jnp.array([0.0, 0.0]))
        self.assertEqual(expected_dtype, trace_state.trace.dtype)
        tx = alias.adam(0.1, mu_dtype=dtype)
        adam_state, _ = tx.init(jnp.array([0.0, 0.0]))
        self.assertEqual(expected_dtype, adam_state.mu.dtype)
        tx = alias.adamw(0.1, mu_dtype=dtype)
        adam_state, _, _ = tx.init(jnp.array([0.0, 0.0]))
        self.assertEqual(expected_dtype, adam_state.mu.dtype)
Exemple #14
0
 def test_empty(self, container):
     """multi_transform over an empty container yields an empty update."""
     tx_init, tx_update = combine.multi_transform(
         {0: alias.sgd(1.)}, lambda _: 0)
     empty_grads = container()
     state = tx_init(container())
     updates, _ = tx_update(empty_grads, state)
     self.assertEqual(updates, container())