Exemple #1
0
    def test_one_minus(self):
        """Check that a `one_minus_<hp>` entry matches the plain `<hp>` form.

        Two nesterov transformations are built, one from `decay=0.9` and one
        from `one_minus_decay=0.1`. After a single update step on the same
        parameters, both must produce identical results.
        """
        def nesterov_config(hps):
            # Single-element chain: nesterov with the given hyperparameters.
            return ml_collections.ConfigDict(
                {'0': {'element': 'nesterov', 'hps': hps}})

        def single_step(transform, tree):
            # The parameter tree doubles as the incoming update here.
            opt_state = transform.init(tree)
            step_updates, _ = transform.update(tree, opt_state, tree)
            return optax.apply_updates(tree, step_updates)

        tx = from_hparams(nesterov_config({'decay': 0.9}))
        tx_one_minus = from_hparams(
            nesterov_config({'one_minus_decay': 0.1}))

        params = {'a': 1.}
        result = single_step(tx, params)
        result_one_minus = single_step(tx_one_minus, params)

        chex.assert_trees_all_equal(result, result_one_minus)
Exemple #2
0
 def test_empty_mask(self):
     """Smoke test: a nesterov element whose hps have no 'mask' entry builds.

     No assertion is made; the test passes as long as `from_hparams` does
     not raise.
     """
     nesterov_hps = {'decay': 0.9}
     chain_spec = {'0': {'element': 'nesterov', 'hps': nesterov_hps}}
     from_hparams(ml_collections.ConfigDict(chain_spec))
Exemple #3
0
    def test_add_decayed_weights(self):
        """Check that omitting `mask` equals passing `mask=None`.

        Two nesterov -> add_decayed_weights chains are built that differ only
        in whether the add_decayed_weights hps contain an explicit
        `'mask': None`. One update step on the same parameters must give
        identical results for both.
        """
        def chain_config(weight_decay_hps):
            # Same nesterov front element for both chains; only the
            # add_decayed_weights hyperparameters vary.
            return ml_collections.ConfigDict({
                '0': {
                    'element': 'nesterov',
                    'hps': {'one_minus_decay': 0.1},
                },
                '1': {
                    'element': 'add_decayed_weights',
                    'hps': weight_decay_hps,
                },
            })

        def single_step(transform, tree):
            # The parameter tree doubles as the incoming update here.
            opt_state = transform.init(tree)
            step_updates, _ = transform.update(tree, opt_state, tree)
            return optax.apply_updates(tree, step_updates)

        tx_no_mask = from_hparams(chain_config({'weight_decay': 1e-4}))
        tx_none_mask = from_hparams(
            chain_config({'weight_decay': 1e-4, 'mask': None}))

        params = {'a': 1.}
        result_no_mask = single_step(tx_no_mask, params)
        result_none_mask = single_step(tx_none_mask, params)

        chex.assert_trees_all_equal(result_no_mask, result_none_mask)
Exemple #4
0
    def test_add_decayed_weights_with_mask(self):
        """Check add_decayed_weights with `mask='bias_bn'` on a conv+BN model.

        Gradients of a small conv + batch-norm model are passed through an
        add_decayed_weights transformation configured with the 'bias_bn'
        mask. The resulting updates for `BatchNorm_0` must be numerically
        close to the raw gradients for `BatchNorm_0`.

        NOTE(review): relies on `self.variant`, so the enclosing test is
        presumably wrapped by a chex variants decorator not visible here;
        confirm.
        """
        class Foo(nn.Module):
            """Minimal conv + batch-norm fixture."""

            train: bool
            filters: int

            @nn.compact
            def __call__(self, x):
                conv_out = nn.Conv(
                    self.filters,
                    (1, 1),
                    use_bias=False,
                    dtype=jnp.float32,
                )(x)
                return nn.BatchNorm(
                    use_running_average=not self.train,
                    momentum=0.9,
                    epsilon=1e-5,
                    dtype=jnp.float32,
                )(conv_out)

        weight_decay_hps = {'weight_decay': 1e-4, 'mask': 'bias_bn'}
        tx = from_hparams(
            ml_collections.ConfigDict({
                '0': {
                    'element': 'add_decayed_weights',
                    'hps': weight_decay_hps,
                },
            }))

        rng = jax.random.PRNGKey(0)
        inputs = jnp.ones((5, 4, 4, 3))
        targets = jax.random.uniform(rng, (5, 4, 4, 7))

        foo_vars = flax.core.unfreeze(
            Foo(filters=7, train=True).init(rng, inputs))
        model_params = foo_vars['params']

        @self.variant
        def train_step(variables, batch, labels):
            # L1 loss against the targets; updated batch stats are returned
            # as the auxiliary output.
            preds, new_batch_stats = Foo(filters=7, train=True).apply(
                variables, batch, mutable=['batch_stats'])
            loss = jnp.abs(labels - preds).sum()
            return loss, new_batch_stats

        opt_state = self.variant(tx.init)(model_params)
        grad_fn = jax.grad(train_step, has_aux=True)
        grads, _ = grad_fn(foo_vars, inputs, targets)
        updates, _ = self.variant(tx.update)(
            dict(grads['params']), opt_state, model_params)

        chex.assert_trees_all_close(
            updates['BatchNorm_0'], grads['params']['BatchNorm_0'])
Exemple #5
0
 def test_empty_hps(self):
     """Smoke test: an element given without any 'hps' entry still builds.

     No assertion is made; the test only requires `from_hparams` not to
     raise.
     """
     element_only_spec = {'0': {'element': 'nesterov'}}
     from_hparams(ml_collections.ConfigDict(element_only_spec))