Example #1
  def test_chain(self):
    transformations = [
        transform.scale_by_adam(),
        transform.trace(decay=0, nesterov=False),
        transform.scale(-LR)]

    # Apply updates with chain.
    chain_params = self.init_params
    chained_transforms = combine.chain(*transformations)
    state = chained_transforms.init(chain_params)

    @self.variant
    def update_fn(updates, state):
      return chained_transforms.update(updates, state)

    for _ in range(STEPS):
      updates, state = update_fn(self.per_step_updates, state)
      chain_params = update.apply_updates(chain_params, updates)

    # Manually apply sequence of transformations.
    manual_params = self.init_params
    states = [t.init(manual_params) for t in transformations]
    for _ in range(STEPS):
      updates = self.per_step_updates
      new_states = []
      for t, s in zip(transformations, states):
        updates, new_state = t.update(updates, s)
        new_states.append(new_state)
      manual_params = update.apply_updates(manual_params, updates)
      states = new_states

    # Check equivalence.
    chex.assert_tree_all_close(manual_params, chain_params, rtol=1e-4)
Example #2
    def test_flatten(self):
        def init_params():
            return (jnp.array([1., 2.]), jnp.array([3., 4.]))

        per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

        # First calculate new params without flattening
        optax_sgd_params = init_params()
        sgd = alias.sgd(1e-2, 0.0)
        state_sgd = sgd.init(optax_sgd_params)
        updates_sgd, state_sgd = sgd.update(per_step_updates, state_sgd)
        sgd_params_no_flatten = update.apply_updates(optax_sgd_params,
                                                     updates_sgd)

        # And now calculate new params with flattening
        optax_sgd_params = init_params()
        sgd = wrappers.flatten(sgd)
        state_sgd = sgd.init(optax_sgd_params)
        updates_sgd, state_sgd = sgd.update(per_step_updates, state_sgd)
        sgd_params_flatten = update.apply_updates(optax_sgd_params,
                                                  updates_sgd)

        # Test that both give the same result
        chex.assert_tree_all_close(sgd_params_no_flatten,
                                   sgd_params_flatten,
                                   atol=1e-7,
                                   rtol=1e-7)
Example #3
    def test_keep_params_nonnegative(self):
        grads = (jnp.array([500., -500., 0.]),
                 jnp.array([500., -500., 0.]),
                 jnp.array([500., -500., 0.]))

        params = (jnp.array([-1., -1., -1.]),
                  jnp.array([1., 1., 1.]),
                  jnp.array([0., 0., 0.]))

        # vanilla sgd
        opt = combine.chain(transform.trace(decay=0, nesterov=False),
                            transform.scale(-LR))
        opt_state = opt.init(params)

        updates, _ = opt.update(grads, opt_state, params)
        new_params = update.apply_updates(params, updates)

        chex.assert_tree_all_close(
            new_params,
            (jnp.array([-6., 4., -1.]),
             jnp.array([-4., 6., 1.]),
             jnp.array([-5., 5., 0.])))

        # sgd constrained to keep parameters non-negative
        opt = combine.chain(transform.trace(decay=0, nesterov=False),
                            transform.scale(-LR),
                            constrain.keep_params_nonnegative())
        opt_state = opt.init(params)

        updates, _ = opt.update(grads, opt_state, params)
        new_params = update.apply_updates(params, updates)

        chex.assert_tree_all_close(
            new_params,
            (jnp.array([0., 4., 0.]),
             jnp.array([0., 6., 1.]),
             jnp.array([0., 5., 0.])))
Example #4
    def test_apply_if_finite(self, opt_builder):
        one = jnp.ones([])
        nan = jnp.array(jnp.nan)

        def fn(x):
            return x * hk.get_parameter('p', [],
                                        init=hk.initializers.Constant(0.))

        fn = hk.without_apply_rng(hk.transform(fn))
        params = fn.init(jax.random.PRNGKey(1905), one)
        opt = wrappers.apply_if_finite(opt_builder(), 2)
        state = opt.init(params)
        grads_fn = jax.grad(self.variant(fn.apply))
        # Do one successful param update
        grads = grads_fn(params, one)
        updates, state = opt.update(grads, state, params)
        params = update.apply_updates(params, updates)
        # We know exactly what the value of params should be, since we are
        # effectively using sgd in all cases.
        self.assertEqual(-1., float(jax.tree_flatten(params)[0][0]))
        self.assertTrue(bool(state.last_finite))
        # Check 2 rejected param updates
        for step in range(2):
            grads = grads_fn(params, nan)
            updates, state = opt.update(grads, state, params)
            params = update.apply_updates(params, updates)
            self.assertEqual(-1., float(jax.tree_flatten(params)[0][0]))
            self.assertFalse(bool(state.last_finite))
            self.assertEqual(step + 1, int(state.notfinite_count))
        # Next successful param update
        grads = grads_fn(params, one)
        updates, state = opt.update(grads, state, params)
        params = update.apply_updates(params, updates)
        self.assertEqual(-2., float(jax.tree_flatten(params)[0][0]))
        self.assertTrue(bool(state.last_finite))
        # Again 2 rejected param updates
        for step in range(2):
            grads = grads_fn(params, nan)
            updates, state = opt.update(grads, state, params)
            params = update.apply_updates(params, updates)
            self.assertEqual(-2., float(jax.tree_flatten(params)[0][0]))
            self.assertFalse(bool(state.last_finite))
            self.assertEqual(step + 1, int(state.notfinite_count))
        # The next param update with NaN is accepted since the maximum number
        # of consecutive non-finite updates has been reached.
        grads = grads_fn(params, nan)
        updates, state = opt.update(grads, state, params)
        params = update.apply_updates(params, updates)
        self.assertTrue(bool(jnp.isnan(jax.tree_flatten(params)[0][0])))
        self.assertEqual(5, int(state.total_notfinite))
Example #5
  def test_update_requires_params(self):
    weight_decay = 0.1
    mask = {'a': True,
            'b': [False, True],
            'c': {'d': True, 'e': (False, True)}}
    params = {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': (5, 6)}}
    input_updates = jax.tree_util.tree_map(lambda x: x/10., params)

    correct_updates = jax.tree_util.tree_multimap(
        lambda m, u, p: u + weight_decay * p if m else u,
        mask, input_updates, params)

    init_fn, update_fn = wrappers.masked(
        transform.additive_weight_decay(weight_decay), mask)
    update_fn = self.variant(update_fn)

    state = init_fn(params)
    updates, state = update_fn(input_updates, state, params)
    chex.assert_tree_all_close(updates, correct_updates)

    params = update.apply_updates(params, updates)

    # Test repeated application
    new_correct_updates = jax.tree_util.tree_multimap(
        lambda m, u, p: u + weight_decay * p if m else u,
        mask, correct_updates, params)
    updates, state = update_fn(correct_updates, state, params)
    chex.assert_tree_all_close(updates, new_correct_updates)
Example #6
    def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer,
                                      rtol):

        # experimental/optimizers.py
        jax_params = self.init_params
        opt_init, opt_update, get_params = jax_optimizer
        state = opt_init(jax_params)
        for i in range(STEPS):
            state = opt_update(i, self.per_step_updates, state)
            jax_params = get_params(state)

        # optax
        optax_params = self.init_params
        state = optax_optimizer.init(optax_params)

        @self.variant
        def step(updates, state):
            return optax_optimizer.update(updates, state)

        for _ in range(STEPS):
            updates, state = step(self.per_step_updates, state)
            optax_params = update.apply_updates(optax_params, updates)

        # Check equivalence.
        chex.assert_tree_all_close(jax_params, optax_params, rtol=rtol)
Example #7
 def update_fn(updates, state, params):
     # The test optimizer does not use the parameters, but we check that they
     # have been passed correctly.
     chex.assert_trees_all_equal_shapes(updates, params)
     aggregate_grads = update.apply_updates(state.aggregate_grads, updates)
     updates = jax.tree_map(lambda u: step_size * u, updates)
     return updates, TestOptimizerState(aggregate_grads, is_reset=False)
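The snippet above is only the update half of a toy transformation. Below is a minimal sketch, assuming hypothetical `TestOptimizerState` fields, a zero-initialised accumulator and a `step_size` constant (none of which are defined in the snippet), of how such a fragment could be wired into a full optax GradientTransformation.

from typing import Any, NamedTuple

import jax
import jax.numpy as jnp
import optax


class TestOptimizerState(NamedTuple):
    # Assumed fields: a running sum of all updates seen so far, plus a flag.
    aggregate_grads: Any
    is_reset: bool


def build_test_optimizer(step_size: float) -> optax.GradientTransformation:
    """Toy transformation that scales updates and accumulates them in state."""

    def init_fn(params):
        # Start the accumulator at zero, with the same tree structure as params.
        zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
        return TestOptimizerState(aggregate_grads=zeros, is_reset=True)

    def update_fn(updates, state, params=None):
        del params  # unused by this toy optimizer
        aggregate_grads = optax.apply_updates(state.aggregate_grads, updates)
        updates = jax.tree_util.tree_map(lambda u: step_size * u, updates)
        return updates, TestOptimizerState(aggregate_grads, is_reset=False)

    return optax.GradientTransformation(init_fn, update_fn)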
Example #8
    def test_apply_every(self):
        # The sgd update is applied only once every k steps.
        k = 4
        zero_update = (jnp.array([0., 0.]), jnp.array([0., 0.]))

        # optax sgd
        optax_sgd_params = self.init_params
        sgd = alias.sgd(LR, 0.0)
        state_sgd = sgd.init(optax_sgd_params)

        # optax sgd plus apply every
        optax_sgd_apply_every_params = self.init_params
        sgd_apply_every = combine.chain(
            transform.apply_every(k=k),
            transform.trace(decay=0, nesterov=False),
            transform.scale(-LR))
        state_sgd_apply_every = sgd_apply_every.init(
            optax_sgd_apply_every_params)
        transform_fn = self.variant(sgd_apply_every.update)

        for i in range(STEPS):
            # Apply a step of sgd
            updates_sgd, state_sgd = sgd.update(self.per_step_updates,
                                                state_sgd)
            optax_sgd_params = update.apply_updates(optax_sgd_params,
                                                    updates_sgd)

            # Apply a step of sgd_apply_every
            updates_sgd_apply_every, state_sgd_apply_every = transform_fn(
                self.per_step_updates, state_sgd_apply_every)
            optax_sgd_apply_every_params = update.apply_updates(
                optax_sgd_apply_every_params, updates_sgd_apply_every)

            # Every k steps, check equivalence.
            if i % k == k - 1:
                chex.assert_tree_all_close(optax_sgd_apply_every_params,
                                           optax_sgd_params,
                                           atol=1e-6,
                                           rtol=1e-5)
            # Otherwise, check update is zero.
            else:
                chex.assert_tree_all_close(updates_sgd_apply_every,
                                           zero_update,
                                           atol=0.0,
                                           rtol=0.0)
Example #9
 def do_update(loss_fun, optimizer, params, opt_state):
     loss, grads = jax.value_and_grad(loss_fun)(params)
     # Complex gradients need to be conjugated before being added to parameters
     grads = jax.tree_map(lambda x: x.conj(), grads)
     updates, opt_state = self.variant(optimizer.update)(grads,
                                                         opt_state,
                                                         params)
     params = update.apply_updates(params, updates)
     return loss, grads, params, opt_state
Example #10
 def step(params, state):
     updates = get_updates(params)
     if opt_name == 'dpsgd':
         updates = updates[None]
     # Complex gradients need to be conjugated before being added to parameters
     # https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29
     updates = jax.tree_map(lambda x: x.conj(), updates)
     updates, state = opt.update(updates, state, params)
     params = update.apply_updates(params, updates)
     return params, state
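The conjugation in Examples #9 and #10 follows JAX's gradient convention for real-valued functions of complex parameters (see the gist linked above): `jax.grad` returns the conjugate of the steepest-ascent direction, so it is the conjugate of the gradient that should be fed to the optimizer. A toy illustration, not taken from the tests above:

import jax
import jax.numpy as jnp

# Minimise |z|^2 over a single complex parameter z.
loss = lambda z: jnp.abs(z) ** 2

z = jnp.array(3.0 + 4.0j)
g = jax.grad(loss)(z)       # 6 - 8j, i.e. 2 * conj(z)
step = -0.1 * g.conj()      # conjugate before building the update
z_new = z + step            # 0.8 * z: moves z towards the minimum at 0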
Example #11
    def loop(self, optimizer, num_steps, params):
        """Performs a given number of optimizer steps."""
        init_fn, update_fn = optimizer
        # Use the chex variant to check various function versions (jit, pmap, etc.).
        step = self.variant(update_fn)
        opt_state = self.variant(init_fn)(params)
        for _ in range(num_steps):
            updates, opt_state = step(self.grads, opt_state, params)
            params = update.apply_updates(params, updates)

        return params, opt_state
Example #12
    def test_float32_input_outputs(self, transform_constr, transform_kwargs):
        initial_params = (jnp.array([1., 2.], dtype=jnp.float32),
                          jnp.array([3., 4.], dtype=jnp.float32))
        updates = (jnp.array([10., 21.], dtype=jnp.float32),
                   jnp.array([33., 42.], dtype=jnp.float32))
        scaler = transform_constr(**transform_kwargs)
        init_fn = self.variant(scaler.init)
        update_fn = self.variant(scaler.update)

        initial_state = init_fn(initial_params)
        updates, new_state = update_fn(updates,
                                       initial_state,
                                       params=initial_params)
        new_params = update.apply_updates(initial_params, updates)

        self._assert_dtype_equals(initial_state, new_state)
        self._assert_dtype_equals(initial_params, new_params)
Example #13
    def test_multi_steps(self):
        batch_size = 32
        x_size = 7
        # Parameters should be updated only every `k_steps` optimisation steps.
        k_steps = 4
        data = jnp.ones([batch_size, x_size])

        def get_loss(x):
            loss = jnp.sum(hk.Linear(10)(x)**2)
            return loss

        loss_init, loss_apply = hk.without_apply_rng(hk.transform(get_loss))
        params = loss_init(jax.random.PRNGKey(1915), data)

        ms_opt = wrappers.MultiSteps(alias.adam(1e-4), k_steps)
        opt_init, opt_update = ms_opt.gradient_transformation()

        # Put the training step in one function to check that the update is
        # indeed jittable.
        def train_step(data, opt_state, params):
            grad = jax.grad(loss_apply)(params, data)
            updates, opt_state = opt_update(grad, opt_state, params)
            return updates, opt_state

        opt_state = opt_init(params)

        prev_loss = loss_apply(params, data)
        for idx in range(5 * k_steps):
            updates, opt_state = self.variant(train_step)(data, opt_state,
                                                          params)
            new_params = update.apply_updates(params, updates)
            new_loss = loss_apply(new_params, data)
            if idx % k_steps < k_steps - 1:
                # The parameters should not have changed and the loss should be
                # constant.
                jax.tree_multimap(np.testing.assert_array_equal, new_params,
                                  params)
                np.testing.assert_equal(new_loss, prev_loss)
                self.assertFalse(ms_opt.has_updated(opt_state))
            else:
                # This is a step where parameters should actually have been updated, and
                # the loss should accordingly go down.
                np.testing.assert_array_less(new_loss, prev_loss)
                prev_loss = new_loss
                self.assertTrue(ms_opt.has_updated(opt_state))
            params = new_params
Example #14
    def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):

        # flax/optim
        flax_params = self.init_params
        flax_optimizer = flax_optimizer.create(flax_params)
        for _ in range(STEPS):
            flax_optimizer = flax_optimizer.apply_gradient(
                self.per_step_updates)
            flax_params = flax_optimizer.target

        # optax
        optax_params = self.init_params
        state = optax_optimizer.init(optax_params)
        for _ in range(STEPS):
            updates, state = optax_optimizer.update(self.per_step_updates,
                                                    state, optax_params)
            optax_params = update.apply_updates(optax_params, updates)

        # Check equivalence.
        chex.assert_tree_all_close(flax_params, optax_params, rtol=1e-4)
Example #15
 def fn_update(params, opt_state, x):
     grads = jax.grad(fn.apply)(params, x)
     grads = jax.lax.psum(grads, axis_name='i')
     updates, new_opt_state = opt.update(grads, opt_state, params)
     new_params = update.apply_updates(params, updates)
     return new_params, new_opt_state
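`fn_update` above uses `jax.lax.psum(..., axis_name='i')`, which only has meaning inside a parallel transform that defines the axis `'i'`. A minimal sketch of how it might be driven, assuming `fn`, `opt`, `params`, `opt_state` and the per-device inputs `xs` are prepared elsewhere:

import jax

# Map the update over local devices; 'i' names the axis that psum reduces over.
pmapped_update = jax.pmap(fn_update, axis_name='i')

# Parameters and optimizer state need a leading device axis, e.g. by replication:
# devices = jax.local_devices()
# params = jax.device_put_replicated(params, devices)
# opt_state = jax.device_put_replicated(opt_state, devices)
# params, opt_state = pmapped_update(params, opt_state, xs)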
Example #16
 def step(params, state):
     updates = get_updates(params)
     if opt_name == 'dpsgd':
         updates = updates[None]
     updates, state = opt.update(updates, state, params)
     params = update.apply_updates(params, updates)
     return params, state
Example #17
 def step(params, state):
     updates, state = opt.update(get_updates(params), state, params)
     params = update.apply_updates(params, updates)
     return params, state
Example #18
 def update_fn(updates, state, params=None):
     del params  # unused by the test optimizer
     aggregate_grads = update.apply_updates(state.aggregate_grads, updates)
     updates = jax.tree_map(lambda u: step_size * u, updates)
     return updates, TestOptimizerState(aggregate_grads, is_reset=False)