コード例 #1
0
    def testUtilityClipGrads(self):
        g = (jnp.ones(2), (jnp.ones(3), jnp.ones(4)))
        norm = optimizers.l2_norm(g)

        ans = optimizers.clip_grads(g, 1.1 * norm)
        expected = g
        self.assertAllClose(ans, expected, check_dtypes=False)

        ans = optimizers.l2_norm(optimizers.clip_grads(g, 0.9 * norm))
        expected = 0.9 * norm
        self.assertAllClose(ans, expected, check_dtypes=False)
コード例 #2
0
    def update(self, params, opt_state, batch: util.Transition):
        """Apply one optimizer step to ``params`` using gradients on ``batch``.

        Args:
          params: Current parameters; differentiated through ``self._loss``.
          opt_state: Current state of ``self._opt``.
          batch: Transition batch forwarded to ``self._loss``.

        Returns:
          ``(params, opt_state, logs)``: the stepped parameters, the new
          optimizer state, and the auxiliary logs mapping returned by
          ``self._loss`` extended with ``'grad_norm_unclipped'`` (L2 norm of
          the raw gradients) and ``'weight_norm'`` (L2 norm of the stepped
          parameters).
        """
        (_, logs), grads = jax.value_and_grad(self._loss, has_aux=True)(
            params, batch)

        # Measured on the raw gradients, before self._opt applies whatever
        # clipping/rescaling it may contain.
        grad_norm_unclipped = optimizers.l2_norm(grads)
        # Forward params: optax transforms such as decoupled weight decay
        # require them and raise when they are missing; transforms that do
        # not use params simply ignore the argument.
        updates, updated_opt_state = self._opt.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        weight_norm = optimizers.l2_norm(params)
        logs.update({
            'grad_norm_unclipped': grad_norm_unclipped,
            'weight_norm': weight_norm,
        })
        return params, updated_opt_state, logs
コード例 #3
0
 def testUtilityNorm(self):
     x0 = (jnp.ones(2), (jnp.ones(3), jnp.ones(4)))
     norm = optimizers.l2_norm(x0)
     expected = np.sqrt(np.sum(np.ones(2 + 3 + 4)**2))
     self.assertAllClose(norm, expected, check_dtypes=False)