Example #1
 def test_set_to_zero_returns_tree_of_correct_zero_arrays(self):
     """Tests that zero transform returns a tree of zeros of correct shape."""
     grads = ({'a': np.ones((3, 4)), 'b': 1.}, np.ones((1, 2, 3)))
     updates, _ = self.variant(base.set_to_zero().update)(grads,
                                                          base.EmptyState())
     correct_zeros = ({'a': np.zeros((3, 4)), 'b': 0.}, np.zeros((1, 2, 3)))
     chex.assert_trees_all_close(updates, correct_zeros, rtol=0)
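For context, `base` above refers to optax's base module; the same transform is also exposed at the package level. A minimal standalone version of this check, assuming a reasonably recent optax release:

    import numpy as np
    import optax

    # set_to_zero() keeps no state and replaces every incoming gradient with
    # zeros of the same shape and dtype.
    grads = ({'a': np.ones((3, 4)), 'b': 1.}, np.ones((1, 2, 3)))
    tx = optax.set_to_zero()
    state = tx.init(grads)                # EmptyState(); nothing is tracked
    updates, _ = tx.update(grads, state)  # an all-zeros tree with the same structure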
Example #2
 def test_output_modality_1(self):
     decays = [0.19, 0.75, 1.0]
     scales = [0.9, 0.5, 1.0]
     decay_distribution = [0.34, 0.34, 0.32]
     ks_opt = transform_chain(['precondition_by_layered_adaptive_rms'],
                              [{
                                  'decays': decays,
                                  'scales': scales,
                                  'decay_distribution': decay_distribution,
                                  'eps_root': 0.0
                              }],
                              learning_rate=1.0)
     scales = jnp.array([0.9, 0.5, 1.0, 1.0])
     betas = jnp.array([0.19, 0.75, 1.0, 1.0])
     one_minus_betas = jnp.array([0.81, 0.25, 1.0, 1.0])
     params = {'w': jnp.ones((4, ))}
     opt_state = ks_opt.init(params)
     # step 1
     grads = {'w': 2 * jnp.ones((4, ))}
     true_nu = one_minus_betas * (grads['w']**2)
     true_updates = {
         'w': -1.0 * jnp.array(scales) * grads['w'] / jnp.sqrt(true_nu)
     }
     opt_updates, opt_state = ks_opt.update(grads, opt_state)
     chex.assert_trees_all_close(true_updates, opt_updates)
     params = optax.apply_updates(params, opt_updates)
     # step 2
     grads = {'w': jnp.ones((4, ))}
     true_nu = one_minus_betas * (grads['w']**2) + betas * true_nu
     true_updates = {
         'w': -1.0 * jnp.array(scales) * grads['w'] / jnp.sqrt(true_nu)
     }
     opt_updates, opt_state = ks_opt.update(grads, opt_state)
     chex.assert_trees_all_close(true_updates, opt_updates)
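The hand-computed values above follow the standard RMS recursion applied with per-parameter decays and scales. A minimal sketch of just that reference math (plain jax.numpy, not the transform under test):

    import jax.numpy as jnp

    # Reference recursion encoded by the test: an EMA of squared gradients,
    # then a per-parameter scaled RMS step with eps_root = 0.
    betas = jnp.array([0.19, 0.75, 1.0, 1.0])
    scales = jnp.array([0.9, 0.5, 1.0, 1.0])
    learning_rate = 1.0

    nu = jnp.zeros(4)
    for g in (2 * jnp.ones(4), jnp.ones(4)):   # the two gradient steps above
        nu = (1 - betas) * g**2 + betas * nu   # second-moment EMA
        update = -learning_rate * scales * g / jnp.sqrt(nu)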
Example #3
    def test_correctness(self):
        """Testing correctness via optax.adam."""
        def amsgrad():
            adam = optax.scale_by_adam()

            def init_fn(params):
                return adam.init(params)

            def update_fn(updates, state, params=None):
                prev_nu = state.nu
                _, state = adam.update(updates, state, params)
                curr_nu = state.nu
                nu_hat = jax.tree_util.tree_map(jnp.maximum, curr_nu, prev_nu)
                updates = jax.tree_util.tree_map(
                    lambda m, v: m / (jnp.sqrt(v + 0.0) + 1e-8), state.mu,
                    nu_hat)

                return updates, optax.ScaleByAdamState(count=state.count,
                                                       mu=state.mu,
                                                       nu=nu_hat)

            return optax.GradientTransformation(init_fn, update_fn)

        true_amsgrad = amsgrad()
        ks_amsgrad = transform_chain(['scale_by_amsgrad'])

        targets = _optimizer_loop(true_amsgrad)
        results = _optimizer_loop(ks_amsgrad)

        for target, result in zip(targets, results):
            chex.assert_trees_all_close(target, result)
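The `_optimizer_loop` helper used here (and in several later examples) lives in the test module and is not shown in these snippets. A hypothetical minimal version that would support this style of comparison (the parameter shapes and the toy quadratic objective are assumptions, not the real helper):

    import jax
    import jax.numpy as jnp
    import optax

    def _optimizer_loop(optimizer, num_steps=5):
        """Hypothetical sketch: run `optimizer` on a toy problem and record params."""
        params = {'w': jnp.array([1.0, 2.0])}
        opt_state = optimizer.init(params)
        results = []
        for _ in range(num_steps):
            grads = jax.grad(lambda p: jnp.sum(p['w'] ** 2))(params)
            updates, opt_state = optimizer.update(grads, opt_state, params)
            params = optax.apply_updates(params, updates)
            results.append(params)
        return results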
Example #4
    def test_correctness(self):
        """Testing correctness via independent implementation."""
        def ema(decay, debias=True):
            def init_fn(params):
                del params
                return {'w': jnp.zeros((2, )), 'count': 0}

            def update_fn(updates, state, params=None):
                del params
                state['count'] += 1
                state['w'] = ((1 - decay) * updates['w'] + decay * state['w'])
                if debias:
                    update = {'w': state['w'] / (1 - decay**state['count'])}
                else:
                    update = {'w': state['w']}
                return update, state

            return optax.GradientTransformation(init_fn, update_fn)

        decay = 0.7
        learning_rate = 0.01
        true_ema = optax.chain(ema(decay), optax.scale(-1. * learning_rate))
        ks_ema = transform_chain(['first_moment_ema'], [{
            'decay': decay,
            'debias': True,
        }],
                                 learning_rate=learning_rate)
        targets = _optimizer_loop(true_ema)
        results = _optimizer_loop(ks_ema)

        for target, result in zip(targets, results):
            chex.assert_trees_all_close(target, result)
Example #5
    def test_common_clipping_norm(self, l2_norm_clip):
        l2_norms_threshold = jax.tree_map(lambda p: l2_norm_clip, self.params)
        dp_agg = optimizers.dp_aggregate(l2_norms_threshold=l2_norms_threshold,
                                         base_sensitivity=1.,
                                         noise_multiplier=0.,
                                         init_rng=jax.random.PRNGKey(42))
        state = dp_agg.init(self.params)
        update_fn = self.variant(dp_agg.update)

        # Shape of the three arrays below is (self.batch_size, )
        norms = [
            jnp.linalg.norm(g.reshape(self.batch_size, -1), axis=1)
            for g in jax.tree_leaves(self.per_eg_grads)
        ]
        divisors = [jnp.maximum(norm / l2_norm_clip, 1.) for norm in norms]
        # Since the values of all the parameters are the same within each example,
        # we can easily compute what the values of the gradients should be:
        expected_val = [
            jnp.sum(jnp.arange(self.batch_size) / div) for div in divisors
        ]
        expected_tree = jax.tree_unflatten(jax.tree_structure(self.params),
                                           expected_val)
        expected_tree = jax.tree_map(
            lambda val, p: jnp.broadcast_to(val, p.shape), expected_tree,
            self.params)

        for _ in range(3):
            updates, state = update_fn(self.per_eg_grads, state, self.params)
            chex.assert_trees_all_close(updates, expected_tree, rtol=2e-7)
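The expected values are the per-example clipping rule worked out by hand: each example's gradient is shrunk by max(norm / l2_norm_clip, 1) before aggregation. A standalone sketch of that rule for a single leaf (illustrative only, not the dp_aggregate implementation):

    import jax.numpy as jnp

    def clip_and_sum(per_example_grads, l2_norm_clip):
        """Clip each example's gradient to the L2 threshold, then sum over the batch."""
        batch = per_example_grads.shape[0]
        norms = jnp.linalg.norm(per_example_grads.reshape(batch, -1), axis=1)
        divisors = jnp.maximum(norms / l2_norm_clip, 1.0)
        shape = (batch,) + (1,) * (per_example_grads.ndim - 1)
        return jnp.sum(per_example_grads / divisors.reshape(shape), axis=0)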
Example #6
  def test_ema_accumulator(self, decay, debias):
    """Test ema_accumulator."""
    grads = jax.random.uniform(jax.random.PRNGKey(0), (10, 10))
    grads = {'updates': grads, 'variables': {}, 'moments': {}, 'output': None}

    nth_grads = preconditioner.nth_power(power=[1, 2])
    accumulator = preconditioner.ema_accumulator(decay, debias)

    grads, _ = nth_grads.update(grads, None)
    state = accumulator.init(grads['variables'])

    for i in range(5):
      moments, count = state
      grads = jax.random.uniform(jax.random.PRNGKey(i + 1), (10, 10))
      grads = {'updates': grads, 'variables': {}, 'moments': {}, 'output': None}
      grads, _ = nth_grads.update(grads, None)
      updates, state = accumulator.update(grads, state)

      actual = jax.tree_map(lambda g, t: (1 - decay) * g + decay * t,
                            grads['variables'], moments)
      if debias:
        count += jnp.array(1, dtype=jnp.int32)
        beta = jnp.array(1, dtype=jnp.int32) - decay ** count
        actual = jax.tree_map(lambda t: t / beta.astype(t.dtype), actual)  # pylint: disable=cell-var-from-loop

      chex.assert_trees_all_close(updates['moments'], actual)
Example #7
  def test_precondition_by_rms(self, decay, eps, eps_root, debias):
    """Test precondition_by_rms."""

    actual_rms = transform.precondition_by_rms(
        decay=decay, eps=eps, eps_root=eps_root, debias=debias)

    decon_rms = preconditioner.preconditioner(preconditioner.nth_power,
                                              preconditioner.ema_accumulator,
                                              preconditioner.rexp_updater,
                                              {'power': 2}, {
                                                  'decay': decay,
                                                  'debias': debias
                                              }, {
                                                  'eps': eps,
                                                  'eps_root': eps_root,
                                              })

    params = jax.random.uniform(jax.random.PRNGKey(0), (10, 10))

    init = self.variant(decon_rms.init)
    update = self.variant(decon_rms.update)

    actual_state = actual_rms.init(params)
    decon_state = init(params)

    for i in range(5):
      grads = jax.random.uniform(jax.random.PRNGKey(i + 1), (10, 10))

      actual_updates, actual_state = actual_rms.update(grads, actual_state)
      decon_updates, decon_state = update(grads, decon_state)

      actual_params = optax.apply_updates(params, actual_updates)
      decon_params = optax.apply_updates(params, decon_updates)

      chex.assert_trees_all_close(actual_params, decon_params, atol=1e-4)
Example #8
    def test_correctness(self):
        """Testing correctness via an independent flax.optim run."""

        target_solution = [
            {
                'w': jnp.array([0.40500003, 0.286])
            },
            {
                'w': jnp.array([0.255515, 0.106618])
            },
            {
                'w': jnp.array([0.31884143, 0.18260972])
            },
            {
                'w': jnp.array([0.40163627, 0.28196353])
            },
            {
                'w': jnp.array([0.43924114, 0.32708937])
            },
        ]
        optimizer = transform_chain(['nesterov'], [{
            'decay': 0.7
        }],
                                    learning_rate=0.01)
        results = _optimizer_loop(optimizer)
        for target, result in zip(target_solution, results):
            chex.assert_trees_all_close(target, result)
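Assuming the 'nesterov' element of transform_chain shares optax's momentum-trace semantics with nesterov=True (an assumption, not something this snippet establishes), targets like these could in principle be regenerated with stock optax:

    import optax

    # Hypothetical cross-check; whether it reproduces the table above also
    # depends on the toy problem inside _optimizer_loop.
    reference = optax.sgd(learning_rate=0.01, momentum=0.7, nesterov=True)
    # reference_results = _optimizer_loop(reference)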
Example #9
 def testLossIsNearZeroAtOrigin(self):
     # Check that the loss is near-zero when x is near-zero.
     _, loss, x, _, _, _, _, _ = self._precompute_lossfun_inputs()
     loss_near_zero = loss[jnp.abs(x) < 1e-5]
     chex.assert_trees_all_close(loss_near_zero,
                                 jnp.zeros_like(loss_near_zero),
                                 atol=1e-5)
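The lossfun tests in this section appear to exercise the general robust loss of Barron (CVPR 2019). As a reference point for the special-case checks that follow, a sketch of the closed form for finite, nonzero alpha (assuming `general.lossfun` follows the published formula):

    import jax.numpy as jnp

    def general_loss_sketch(x, alpha, scale):
        """Reference-only sketch of the general robust loss for finite, nonzero alpha.

        Limits handled separately in a real implementation:
          alpha == 2    -> 0.5 * (x/scale)**2             (L2)
          alpha == 0    -> log(0.5 * (x/scale)**2 + 1)    (Cauchy / Lorentzian)
          alpha == -inf -> 1 - exp(-0.5 * (x/scale)**2)   (Welsch / Leclerc)
          alpha == +inf -> exp(0.5 * (x/scale)**2) - 1
        """
        z = jnp.square(x / scale)
        b = jnp.abs(alpha - 2.0)
        return (b / alpha) * (jnp.power(z / b + 1.0, alpha / 2.0) - 1.0)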
Example #10
    def test_equivalence(self):
        hb = transform_chain(
            ['precondition_by_rms', 'polyak_hb', 'add_decayed_weights'],
            [{
                'decay': 0.3
            }, {
                'decay': 0.5
            }, {
                'weight_decay': 0.1
            }],
            learning_rate=1.0)
        ema = transform_chain(
            ['precondition_by_rms', 'first_moment_ema', 'add_decayed_weights'],
            [{
                'decay': 0.3
            }, {
                'decay': 0.5
            }, {
                'weight_decay': 0.05
            }],
            learning_rate=2.0)

        targets = _optimizer_loop(hb)
        results = _optimizer_loop(ema)

        for target, result in zip(targets, results):
            chex.assert_trees_all_close(target, result)
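The two chains are expected to coincide because a Polyak heavy-ball trace and a debias-free first-moment EMA differ only by a factor of (1 - decay), which can be folded into the learning rate, and the decayed-weights term is likewise scaled by the learning rate (0.1 at lr=1.0 versus 0.05 at lr=2.0). A tiny sketch of the accumulator arithmetic (illustrative, not the library code):

    import jax.numpy as jnp

    decay, g1, g2 = 0.5, jnp.array(1.0), jnp.array(3.0)

    # Heavy ball:        m_t = decay * m_{t-1} + g_t
    m_hb = decay * (decay * 0.0 + g1) + g2
    # First-moment EMA:  m_t = decay * m_{t-1} + (1 - decay) * g_t
    m_ema = decay * (decay * 0.0 + (1 - decay) * g1) + (1 - decay) * g2

    # lr_hb * m_hb == lr_ema * m_ema whenever lr_ema * (1 - decay) == lr_hb.
    assert jnp.allclose(1.0 * m_hb, 2.0 * m_ema)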
Example #11
    def test_correctness(self):
        """Testing correctness via an independent flax.optim run."""

        target_solution = [
            {
                'w': jnp.array([0.65, 0.58000004])
            },
            {
                'w': jnp.array([0.26849997, 0.12220004])
            },
            {
                'w': jnp.array([0.09766498, -0.08280197])
            },
            {
                'w': jnp.array([0.17850482, 0.01420582])
            },
            {
                'w': jnp.array([0.38620475, 0.2634457])
            },
        ]
        optimizer = transform_chain(['polyak_hb'], [{
            'decay': 0.7
        }],
                                    learning_rate=0.01)
        results = _optimizer_loop(optimizer)
        for target, result in zip(target_solution, results):
            chex.assert_trees_all_close(target, result)
Example #12
  def test_batch(self):
    """Test that batch layer is indeed ignored.

    Code taken from: https://github.com/google/flax/issues/932
    """
    key = jax.random.PRNGKey(0)
    x = jnp.ones((5, 4, 4, 3))
    y = jax.random.uniform(key, (5, 4, 4, 7))

    foo_vars = flax.core.unfreeze(Foo(filters=7, train=True).init(key, x))
    tx = optax.masked(optax.adam(1e-7), create_weight_decay_mask())

    @self.variant
    def train_step(params, x, y):
      y1, new_batch_stats = Foo(
          filters=7, train=True).apply(
              params, x, mutable=['batch_stats'])

      return jnp.abs(y - y1).sum(), new_batch_stats

    state = self.variant(tx.init)(foo_vars['params'])
    grads, _ = jax.grad(train_step, has_aux=True)(foo_vars, x, y)
    updates, state = self.variant(tx.update)(dict(grads['params']), state)

    chex.assert_trees_all_close(updates['BatchNorm_0'],
                                grads['params']['BatchNorm_0'])
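`create_weight_decay_mask` is defined elsewhere in the test module. A hypothetical version consistent with what this test needs (exclude BatchNorm and bias-like parameters so optax.masked passes their gradients through untouched; the exact rule is an assumption):

    import flax

    def create_weight_decay_mask():
        """Hypothetical mask factory: True only for leaves the wrapped transform should touch."""
        def mask_fn(params):
            flat = flax.traverse_util.flatten_dict(flax.core.unfreeze(params))
            mask = {
                path: ('BatchNorm' not in '/'.join(path) and leaf.ndim > 1)
                for path, leaf in flat.items()
            }
            return flax.traverse_util.unflatten_dict(mask)
        return mask_fn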
Example #13
 def testDerivativeIsMonotonicWrtX(self):
     # Check that the loss increases monotonically with |x|.
     _, _, x, alpha, _, d_x, _, _ = self._precompute_lossfun_inputs()
     # This is just to suppress a warning below.
     d_x = jnp.where(jnp.isfinite(d_x), d_x, jnp.zeros_like(d_x))
     mask = jnp.isfinite(alpha) & (jnp.abs(d_x) >
                                   (300. * jnp.finfo(jnp.float32).eps))
     chex.assert_trees_all_close(jnp.sign(d_x[mask]), jnp.sign(x[mask]))
Example #14
    def test_stateless_with_tree_map_no_params(self):
        updates = {'linear': jnp.full((5, 3), 3.0)}

        opt = base.stateless_with_tree_map(lambda g, _: g * 2.0)
        state = opt.init(None)
        update_fn = self.variant(opt.update)
        new_updates, _ = update_fn(updates, state)
        expected_updates = {'linear': jnp.full((5, 3), 6.0)}
        chex.assert_trees_all_close(new_updates, expected_updates)
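`base.stateless_with_tree_map` lifts a leaf-wise function fn(g, p) into a GradientTransformation with empty state; `base.stateless` (used in later examples) does the same for a function over whole update/param trees. A minimal sketch of the leaf-wise variant, assuming optax-style interfaces:

    import jax
    import optax

    def stateless_with_tree_map_sketch(fn):
        """Sketch: wrap a leaf-wise fn(g, p) as a stateless GradientTransformation."""
        def init_fn(params):
            del params
            return optax.EmptyState()

        def update_fn(updates, state, params=None):
            del state
            if params is None:
                # No params supplied: still call fn with a placeholder second argument.
                return jax.tree_util.tree_map(lambda g: fn(g, None), updates), optax.EmptyState()
            return jax.tree_util.tree_map(fn, updates, params), optax.EmptyState()

        return optax.GradientTransformation(init_fn, update_fn)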
Example #15
 def testLossIsQuadraticNearOrigin(self):
     # Check that the loss is well-approximated by a quadratic bowl when
     # |x| < scale
     _, loss, x, _, scale, _, _, _ = self._precompute_lossfun_inputs()
     mask = jnp.abs(x) < (0.5 * scale)
     loss_quad = 0.5 * jnp.square(x / scale)
     chex.assert_trees_all_close(loss_quad[mask],
                                 loss[mask],
                                 rtol=1e-5,
                                 atol=1e-2)
Example #16
    def test_no_op(self):
        """Test no-op."""
        optimizer = transform_chain(
            ['nesterov'],
            masks=[lambda p: jax.tree_map(lambda x: x.ndim != 1, p)])
        params = {'w': jnp.array([1, 2, 3])}
        state = optimizer.init(params)
        update, state = optimizer.update(params, state, params)

        chex.assert_trees_all_close(params, update)
Example #17
    def test_stateless_with_tree_map(self):
        params = {'a': jnp.zeros((1, 2)), 'b': jnp.ones((1, ))}
        updates = {'a': jnp.ones((1, 2)), 'b': jnp.full((1, ), 2.0)}

        opt = base.stateless_with_tree_map(lambda g, p: g + 0.1 * p)
        state = opt.init(params)
        update_fn = self.variant(opt.update)
        new_updates, _ = update_fn(updates, state, params)
        expected_updates = {'a': jnp.ones((1, 2)), 'b': jnp.array([2.1])}
        chex.assert_trees_all_close(new_updates, expected_updates)
Example #18
    def test_debias_true(self):
        adam = transform_chain(['scale_by_adam'], [{'b1': 0.0}])
        precondition_by_rms = transform_chain(['precondition_by_rms'],
                                              [{
                                                  'debias': True
                                              }])
        targets = _optimizer_loop(adam)
        results = _optimizer_loop(precondition_by_rms)

        for target, result in zip(targets, results):
            chex.assert_trees_all_close(target, result)
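The comparison works because with b1 = 0 Adam's first moment is just the current gradient, so scale_by_adam reduces to a bias-corrected RMS preconditioner. A one-step sketch of that reduction (optax's default b2 and eps assumed):

    import jax.numpy as jnp
    import optax

    b2, eps = 0.999, 1e-8
    grads = {'w': jnp.array([0.5, 0.25])}

    adam = optax.scale_by_adam(b1=0.0, b2=b2, eps=eps)
    state = adam.init(grads)
    adam_updates, _ = adam.update(grads, state)

    # At step 1: mu_hat = g and nu_hat = ((1 - b2) * g**2) / (1 - b2) = g**2.
    manual = grads['w'] / (jnp.sqrt(grads['w'] ** 2) + eps)
    assert jnp.allclose(adam_updates['w'], manual)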
Example #19
    def test_stateless_no_params(self):
        updates = {'linear': jnp.full((5, 3), 3.0)}

        @base.stateless
        def opt(g, _):
            return jax.tree_map(lambda g_: g_ * 2, g)

        state = opt.init(None)
        update_fn = self.variant(opt.update)
        new_updates, _ = update_fn(updates, state)
        expected_updates = {'linear': jnp.full((5, 3), 6.0)}
        chex.assert_trees_all_close(new_updates, expected_updates)
Example #20
    def test_stateless(self):
        params = {'a': jnp.zeros((1, 2)), 'b': jnp.ones((1, ))}
        updates = {'a': jnp.ones((1, 2)), 'b': jnp.full((1, ), 2.0)}

        @base.stateless
        def opt(g, p):
            return jax.tree_map(lambda g_, p_: g_ + 0.1 * p_, g, p)

        state = opt.init(params)
        update_fn = self.variant(opt.update)
        new_updates, _ = update_fn(updates, state, params)
        expected_updates = {'a': jnp.ones((1, 2)), 'b': jnp.array([2.1])}
        chex.assert_trees_all_close(new_updates, expected_updates)
Example #21
    def testAlphaEqualsNegativeInfinity(self):
        # Check that alpha == -Infinity reproduces Welsch aka Leclerc loss.
        x = np.linspace(-15, 15, 1000, dtype=np.float64)
        alpha = -float('inf')
        scale = 1.7

        # Our loss.
        loss = self.variant(general.lossfun)(x, alpha, scale)

        # Welsch/Leclerc loss.
        loss_true = (1. - np.exp(-0.5 * np.square(x / scale)))

        chex.assert_trees_all_close(loss, loss_true, atol=1e-5, rtol=1e-5)
Example #22
    def testLossIsScaleInvariant(self):
        # Check that loss(mult * x, alpha, mult * scale) == loss(x, alpha, scale)
        (num_samples, loss, x, alpha, scale, _, _,
         _) = (self._precompute_lossfun_inputs())
        # Random log-normally distributed scalings in ~(0.2, 20)

        rng = random.PRNGKey(0)
        mult = jnp.maximum(0.2, jnp.exp(random.normal(rng,
                                                      shape=[num_samples])))

        # Compute the scaled loss.
        loss_scaled = general.lossfun(mult * x, alpha, mult * scale)
        chex.assert_trees_all_close(loss, loss_scaled, atol=1e-4, rtol=1e-4)
Example #23
    def testAlphaEqualsZero(self):
        # Check that alpha == 0 reproduces Cauchy aka Lorentzian loss.
        x = np.linspace(-15, 15, 1000, dtype=np.float64)
        alpha = 0.
        scale = 1.7

        # Our loss.
        loss = self.variant(general.lossfun)(x, alpha, scale)

        # Cauchy/Lorentzian loss.
        loss_true = (np.log(0.5 * np.square(x / scale) + 1))

        chex.assert_trees_all_close(loss, loss_true, atol=1e-5, rtol=1e-5)
Example #24
    def testAlphaEqualsNegativeTwo(self):
        # Check that alpha == -2 reproduces Geman-McClure loss.
        x = np.linspace(-15, 15, 1000, dtype=np.float64)
        alpha = -2.
        scale = 1.7

        # Our loss.
        loss = self.variant(general.lossfun)(x, alpha, scale)

        # Geman-McClure loss.
        loss_true = (2. * np.square(x / scale) / (np.square(x / scale) + 4.))

        chex.assert_trees_all_close(loss, loss_true, atol=1e-5, rtol=1e-5)
Example #25
    def testAlphaEqualsOne(self):
        # Check that alpha == 1 reproduces Charbonnier aka pseudo-Huber loss.
        x = np.linspace(-15, 15, 1000, dtype=np.float64)
        alpha = 1.
        scale = 1.7

        # Our loss.
        loss = self.variant(general.lossfun)(x, alpha, scale)

        # Charbonnier loss.
        loss_true = (np.sqrt(np.square(x / scale) + 1) - 1)

        chex.assert_trees_all_close(loss, loss_true, atol=1e-5, rtol=1e-5)
Example #26
    def testAlphaEqualsTwo(self):
        # Check that alpha == 2 reproduces L2 loss.
        x = np.linspace(-15, 15, 1000, dtype=np.float64)
        alpha = 2.
        scale = 1.7

        # Our loss.
        loss = self.variant(general.lossfun)(x, alpha, scale)

        # L2 Loss.
        loss_true = 0.5 * np.square(x / scale)

        chex.assert_trees_all_close(loss, loss_true, atol=1e-5, rtol=1e-5)
Example #27
    def testAlphaEqualsInfinity(self):
        # Check that alpha == Infinity takes the correct form.
        x = np.linspace(-15, 15, 1000, dtype=np.float64)
        alpha = float('inf')
        scale = 1.7

        # Our loss.
        loss = self.variant(general.lossfun)(x, alpha, scale)

        # The true loss.
        loss_true = (jnp.exp(0.5 * jnp.square(x / scale)) - 1.)

        chex.assert_trees_all_close(loss, loss_true, atol=1e-4, rtol=1e-4)
Example #28
    def test_debias_false(self):
        rms_prop = optax.scale_by_rms()
        precondition_by_rms = transform_chain(['precondition_by_rms'],
                                              [{
                                                  'eps': 0,
                                                  'eps_root': 1e-8,
                                                  'decay': 0.9,
                                                  'debias': False
                                              }])
        targets = _optimizer_loop(rms_prop)
        results = _optimizer_loop(precondition_by_rms)

        for target, result in zip(targets, results):
            chex.assert_trees_all_close(target, result)
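The particular hyperparameters reflect how optax's scale_by_rms applies its epsilon: inside the square root, which corresponds to the eps_root slot of precondition_by_rms (hence eps=0, eps_root=1e-8 above). A one-step sketch of that update form, assuming current optax semantics:

    import jax
    import jax.numpy as jnp
    import optax

    decay, eps = 0.9, 1e-8
    grads = {'w': jnp.array([0.3, -0.7])}

    rms = optax.scale_by_rms(decay=decay, eps=eps)
    state = rms.init(grads)
    rms_updates, _ = rms.update(grads, state)

    # First step: nu = (1 - decay) * g**2, update = g * rsqrt(nu + eps).
    nu = (1 - decay) * grads['w'] ** 2
    manual = grads['w'] * jax.lax.rsqrt(nu + eps)
    assert jnp.allclose(rms_updates['w'], manual)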
Example #29
    def testAlphaEqualsFour(self):
        # Check that alpha == 4 reproduces a quartic.
        x = np.linspace(-15, 15, 1000, dtype=np.float64)
        alpha = 4.
        scale = 1.7

        # Our loss.
        loss = self.variant(general.lossfun)(x, alpha, scale)

        # The true loss.
        loss_true = np.square(np.square(x / scale)) / 8 + np.square(
            x / scale) / 2

        chex.assert_trees_all_close(loss, loss_true, atol=1e-5, rtol=1e-5)
Example #30
    def test_add_decayed_weights_with_mask(self):
        """Test mask is not added for add_decayed_weights if specified in hps."""
        class Foo(nn.Module):
            """Dummy model."""

            train: bool
            filters: int

            @nn.compact
            def __call__(self, x):
                x = nn.Conv(self.filters, (1, 1),
                            use_bias=False,
                            dtype=jnp.float32)(x)
                x = nn.BatchNorm(use_running_average=not self.train,
                                 momentum=0.9,
                                 epsilon=1e-5,
                                 dtype=jnp.float32)(x)
                return x

        tx = from_hparams(
            ml_collections.ConfigDict({
                '0': {
                    'element': 'add_decayed_weights',
                    'hps': {
                        'weight_decay': 1e-4,
                        'mask': 'bias_bn'
                    }
                }
            }))
        key = jax.random.PRNGKey(0)
        x = jnp.ones((5, 4, 4, 3))
        y = jax.random.uniform(key, (5, 4, 4, 7))

        foo_vars = flax.core.unfreeze(Foo(filters=7, train=True).init(key, x))

        @self.variant
        def train_step(params, x, y):
            y1, new_batch_stats = Foo(filters=7, train=True).apply(
                params, x, mutable=['batch_stats'])

            return jnp.abs(y - y1).sum(), new_batch_stats

        state = self.variant(tx.init)(foo_vars['params'])
        grads, _ = jax.grad(train_step, has_aux=True)(foo_vars, x, y)
        updates, state = self.variant(tx.update)(dict(grads['params']), state,
                                                 foo_vars['params'])

        chex.assert_trees_all_close(updates['BatchNorm_0'],
                                    grads['params']['BatchNorm_0'])