Example #1
    def testLossAndGradientsAreFinite(self):
        # Test that the loss and its approximation both give finite losses and
        # derivatives everywhere that they should for a wide range of values.
        num_samples = 100000
        rng = random.PRNGKey(0)

        # Normally distributed inputs.
        rng, key = random.split(rng)
        x = random.normal(key, shape=[num_samples])

        # Uniformly distributed values in (-16, 3), quantized to the nearest 0.1
        # to ensure that we hit the special cases at 0, 2.
        rng, key = random.split(rng)
        alpha = jnp.round(
            random.uniform(key, shape=[num_samples], minval=-16, maxval=3) *
            10) / 10.

        # Random log-normally distributed values in approx (1e-5, 100000):
        rng, key = random.split(rng)
        scale = jnp.exp(random.normal(key, shape=[num_samples]) * 4.) + 1e-5

        fn = self.variant(general.lossfun)
        loss = fn(x, alpha, scale)
        d_x, d_alpha, d_scale = (jax.grad(lambda x, a, s: jnp.sum(fn(x, a, s)),
                                          [0, 1, 2])(x, alpha, scale))

        for v in [loss, d_x, d_alpha, d_scale]:
            chex.assert_tree_all_finite(v)
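
For context, a minimal sketch of my own (not part of the test above) of what chex.assert_tree_all_finite checks: it traverses every leaf of a pytree and fails if any element is NaN or infinite.

import chex
import jax.numpy as jnp

# Passes: every leaf contains only finite values.
chex.assert_tree_all_finite({'w': jnp.ones((2, 2)), 'b': jnp.zeros(2)})

# Fails: the tree contains a NaN element.
try:
    chex.assert_tree_all_finite({'w': jnp.array([1.0, jnp.nan])})
except AssertionError as e:
    print('non-finite value detected:', e)
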
Example #2
    def test_distributed_shampoo(
        self,
        best_effort_memory_usage_reduction,
        symmetric_block_size,
        block_statistics,
    ):
        params = self.init_params

        optim = distributed_shampoo.distributed_shampoo(
            0.1,
            32,
            batch_axis_name='batch',
            preconditioning_compute_steps=2,
            best_effort_memory_usage_reduction=best_effort_memory_usage_reduction,
        )
        init_fn = self.variant(optim.init)
        transform_fn = self.variant(optim.update)

        def _update(unused_batch):
            return transform_fn(self.per_step_updates, state, params)

        state = init_fn(params)
        chex.assert_tree_all_finite(state)
        pmap_fn = jax.pmap(_update, axis_name='batch')

        updates, state = pmap_fn(jnp.array([1.0]))
        chex.assert_tree_all_finite((params, updates, state))
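
The methods above rely on a chex variants harness that supplies self.variant, together with fixtures such as self.init_params and self.per_step_updates defined elsewhere in the test class. Below is a minimal sketch of how such a harness is typically wired; the class name and fixture values are assumptions of mine, not taken from the original test files.

import chex
import jax.numpy as jnp
from absl.testing import absltest


class OptimizerTest(chex.TestCase):

    def setUp(self):
        super().setUp()
        # Illustrative fixtures standing in for self.init_params / self.per_step_updates.
        self.init_params = (jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]))
        self.per_step_updates = (jnp.array([0.1, -0.1]), jnp.array([0.01, 0.01]))

    @chex.variants(with_jit=True, without_jit=True)
    def test_identity_is_finite(self):
        # self.variant wraps the function under the active variant, so the same
        # test body runs both with and without jax.jit.
        identity = self.variant(lambda tree: tree)
        chex.assert_tree_all_finite(identity(self.init_params))


if __name__ == '__main__':
    absltest.main()
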
Example #3
def test_scalers(constructor: Callable[[], GradientTransformation[Any, Any]]) -> None:
    params = init_params

    scaler = constructor()
    init_fn = variant(scaler.init)
    transform_fn = variant(scaler.update)

    state = init_fn(params)
    chex.assert_tree_all_finite(state)

    updates, state = transform_fn(per_step_updates, state, params)
    chex.assert_tree_all_finite((params, updates, state))
    tree_map(lambda *args: chex.assert_equal_shape(args), params, updates)
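
As one concrete way this parameterized test could be instantiated (an assumption on my part, not taken from the source), constructor might be something like optax.scale_by_adam; the fixtures below are illustrative only.

import chex
import jax
import jax.numpy as jnp
import optax

init_params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros(2)}
per_step_updates = jax.tree_util.tree_map(lambda p: 0.01 * jnp.ones_like(p), init_params)

scaler = optax.scale_by_adam()  # one possible constructor()
state = scaler.init(init_params)
chex.assert_tree_all_finite(state)

updates, state = scaler.update(per_step_updates, state, init_params)
chex.assert_tree_all_finite((init_params, updates, state))
jax.tree_util.tree_map(
    lambda p, u: chex.assert_equal_shape([p, u]), init_params, updates)
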
Example #4
    def test_scale_by_factored_rms(self):
        params = self.init_params

        scaler = factorized.scale_by_factored_rms()
        init_fn = self.variant(scaler.init)
        transform_fn = self.variant(scaler.update)

        state = init_fn(params)
        chex.assert_tree_all_finite(state)

        updates, state = transform_fn(self.per_step_updates, state, params)
        chex.assert_tree_all_finite((params, updates, state))
        chex.assert_tree_all_equal_shapes(params, updates)
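
For intuition, a small sketch of my own (not part of the original test): scale_by_factored_rms keeps Adafactor-style factored second-moment statistics, so for a large enough 2-D parameter the optimizer state holds row- and column-shaped accumulators rather than a full matrix. Inspecting the leaf shapes of the state makes that visible.

import jax
import jax.numpy as jnp
import optax

params = {'kernel': jnp.ones((256, 512)), 'bias': jnp.zeros(512)}
scaler = optax.scale_by_factored_rms()
state = scaler.init(params)

# Print the shape of every leaf in the optimizer state; the kernel's statistics
# should appear as row/column vectors rather than a full (256, 512) matrix.
print(jax.tree_util.tree_map(jnp.shape, state))
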
Example #5
    def testGradientsAreFiniteWithAllInputs(self, alpha):
        x_half = jnp.concatenate(
            [jnp.exp(jnp.linspace(-80, 80, 1001)),
             jnp.array([jnp.inf])])
        x = jnp.concatenate([-x_half[::-1], jnp.array([0.]), x_half])
        scale = jnp.full_like(x, 1.)

        fn = self.variant(lambda x, s: general.lossfun(x, alpha, s))
        loss = fn(x, scale)
        d_x, d_scale = jax.vmap(jax.grad(fn, [0, 1]))(x, scale)

        for v in [loss, d_x, d_scale]:
            chex.assert_tree_all_finite(v)
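
The jax.vmap(jax.grad(...)) composition above evaluates per-element derivatives. A minimal sketch on a toy loss with the same (x, scale) signature; toy_loss is my own stand-in, unrelated to general.lossfun.

import jax
import jax.numpy as jnp

def toy_loss(x, scale):
    # Simple smooth loss, differentiable in both arguments.
    return 0.5 * (x / scale) ** 2

x = jnp.linspace(-3.0, 3.0, 7)
scale = jnp.full_like(x, 1.0)

# Gradient w.r.t. both arguments, vectorized over the leading axis of the inputs.
d_x, d_scale = jax.vmap(jax.grad(toy_loss, argnums=(0, 1)))(x, scale)
print(d_x)      # x / scale**2
print(d_scale)  # -x**2 / scale**3
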
Example #6
  def test_scalers(self, scaler_constr):
    params = self.init_params

    scaler = scaler_constr()
    init_fn = self.variant(scaler.init)
    transform_fn = self.variant(scaler.update)

    state = init_fn(params)
    chex.assert_tree_all_finite(state)

    updates, state = transform_fn(self.per_step_updates, state, params)
    chex.assert_tree_all_finite((params, updates, state))
    jax.tree_util.tree_map(lambda *args: chex.assert_equal_shape(args), params,
                           updates)
Example #7
    def test_distributed_shampoo_no_pmap(self):
        params = self.init_params

        optim = distributed_shampoo.distributed_shampoo(
            0.1,
            32,
            batch_axis_name=None,
            preconditioning_compute_steps=2,
        )
        init_fn = self.variant(optim.init)
        transform_fn = self.variant(optim.update)
        state = init_fn(params)
        chex.assert_tree_all_finite(state)
        updates, state = transform_fn(self.per_step_updates, state, params)
        chex.assert_tree_all_finite((params, updates, state))
Example #8
  def test_sm3_basic(self):
    params = self.init_params

    optim = sm3.sm3(0.1, 0.9, 0.999)
    init_fn = self.variant(optim.init)
    transform_fn = self.variant(optim.update)

    def _update(unused_batch):
      return transform_fn(self.per_step_updates, state, params)
    state = init_fn(params)
    chex.assert_tree_all_finite(state)
    pmap_fn = jax.pmap(_update, axis_name='batch')

    updates, state = pmap_fn(jnp.array([1.0]))
    chex.assert_tree_all_finite((params, updates, state))
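
In training code the transformed updates would then be folded back into the parameters. A self-contained sketch using optax.apply_updates, with optax.sgd as a stand-in optimizer (an assumption of mine; the original test uses sm3):

import chex
import jax.numpy as jnp
import optax

params = {'w': jnp.ones(3)}
grads = {'w': jnp.array([0.1, -0.2, 0.3])}

opt = optax.sgd(0.1)
state = opt.init(params)
updates, state = opt.update(grads, state, params)

# apply_updates adds each update leaf to the matching parameter leaf.
new_params = optax.apply_updates(params, updates)
chex.assert_tree_all_equal_shapes(params, new_params)
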
Example #9
    def test_distributed_shampoo(self):
        params = self.init_params

        optim = distributed_shampoo.distributed_shampoo(
            0.1, 32, batch_axis_name='batch', preconditioning_compute_steps=2)
        init_fn = self.variant(optim.init)
        transform_fn = self.variant(optim.update)

        def _update(unused_batch):
            return transform_fn(self.per_step_updates, state, params)

        state = init_fn(params)
        chex.assert_tree_all_finite(state)
        pmap_fn = jax.pmap(_update, axis_name='batch')

        updates, state = pmap_fn(jnp.array([1.0]))
        chex.assert_tree_all_finite((params, updates, state))
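
The pmap-based tests above share one pattern: the update is wrapped in a closure that ignores its mapped argument and captures state and params from the enclosing scope, while jnp.array([1.0]) supplies one dummy batch element per device. Below is a minimal sketch of my own of the same pattern, with optax.scale standing in for distributed_shampoo; it runs on a single device because the mapped array has one element.

import chex
import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.ones(4)}
per_step_updates = {'w': jnp.full(4, 0.5)}

optim = optax.scale(-0.1)  # stand-in for distributed_shampoo
state = optim.init(params)

def _update(unused_batch):
    # state and params are captured by closure; the mapped argument is unused.
    return optim.update(per_step_updates, state, params)

pmap_fn = jax.pmap(_update, axis_name='batch')
updates, new_state = pmap_fn(jnp.array([1.0]))  # outputs gain a leading device axis
chex.assert_tree_all_finite((params, updates, new_state))
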