Example #1
 def single_update(i, opt_state, batch, state, rng):
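     # One optimizer step on a single device: split the rng, take gradients of
     # the model loss, and let the optimizer produce the new weights and slots.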
     weights, slots, opt_params = opt_state
     rng, subrng = jax_random.split(rng[0])
     grad_fn = backend.grad(model_and_loss_call, has_aux=True)
     grads, state = grad_fn(weights, batch, state, rng)
     return optimizer.tree_update(i, grads, weights, slots,
                                  opt_params), state, [subrng]
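A usage sketch (an assumption, not taken from the source): in the single-device case this update function is normally compiled once with jit and then driven from the training loop. step, opt_state, batch and state are hypothetical stand-ins for whatever the surrounding trainer provides; note that single_update expects the rng wrapped in a one-element list.

  import jax

  jit_update = jax.jit(single_update)   # the closure over model_and_loss_call/optimizer is baked in

  rngs = [jax.random.PRNGKey(0)]        # single_update splits rng[0], so pass a one-element list
  update_result, state, rngs = jit_update(step, opt_state, batch, state, rngs)
  # update_result is whatever optimizer.tree_update returned (the refreshed
  # weights/slots); how it is folded back into opt_state is up to the caller.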
Example #2
    def test_reformer_rng_consistency(self):
        with backend.use_backend('jax'):
            vocab_size = 16
            batch_size = 1
            input_sd = ShapeDtype((batch_size, 8), np.int32)
            input_signature = (input_sd, input_sd)
            model = reformer.ReformerLM(
                vocab_size,
                d_model=32,
                d_ff=64,
                d_attention_key=16,
                d_attention_value=16,
                n_layers=1,
                n_heads=2,
                max_len=16,
                n_chunks=2,
                n_attention_chunks=1,
                mode='train',
                attention_type=PoisonOnRNGMismatchAttention)

            rng = backend.random.get_prng(0)
            weights, state = model.initialize_once(input_signature)

            def dummy_loss_fn(weights):
                inputs = (np.zeros(input_sd.shape, dtype=np.int32), ) * 2
                output = model(inputs, weights=weights, state=state, rng=rng)
                dummy_loss = backend.numpy.sum(output[0])
                return dummy_loss

            grad_fn = backend.grad(dummy_loss_fn)
            grads = grad_fn(weights)
            # PoisonOnRNGMismatchAttention uses NaNs to signal an rng mismatch.
            for grad in jax.tree_util.tree_leaves(grads):
                assert onp.all(onp.isfinite(grad))
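A minimal, self-contained sketch of the poisoning idea this test relies on (plain JAX, independent of PoisonOnRNGMismatchAttention, whose implementation is not shown here): if a layer injects NaNs when some internal check fails, the NaNs propagate into every gradient leaf, so asserting that all gradients are finite is enough to catch the failure.

  import jax
  import jax.numpy as jnp
  import numpy as onp

  def layer(x, poison):
    # Multiplying by NaN taints every element; multiplying by 1.0 is the identity.
    return x * jnp.where(poison, jnp.nan, 1.0)

  def loss(x, poison):
    return jnp.sum(layer(x, poison))

  x = jnp.ones((2, 3))
  assert onp.all(onp.isfinite(jax.grad(loss)(x, False)))      # healthy layer: finite grads
  assert not onp.all(onp.isfinite(jax.grad(loss)(x, True)))   # poisoned layer: NaNs everywhere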
Example #3
  def test_custom_id_grad(self):

    class IdWithIdGrad(base.Layer):

      def forward(self, x, params=(), state=(), **kwargs):
        del kwargs
        return x, ()

      @property
      def has_backward(self):
        return True

      def backward(self, inputs, output, ct, params, state, **kwargs):
        return (inputs, ())

    layer = IdWithIdGrad()
    rng = backend.random.get_prng(0)
    input_shape = (9, 17)
    random_input = backend.random.uniform(rng, input_shape, minval=-1.0,
                                          maxval=1.0)
    layer.initialize_once(input_shape, random_input.dtype, rng)
    f = lambda x: backend.numpy.mean(layer(x))
    grad = backend.grad(f)(random_input)
    self.assertEqual(grad.shape, input_shape)  # Gradient for each input.
    self.assertEqual(sum(sum(grad)), sum(sum(random_input)))  # Same as input.
Example #4
    def test_custom_id_grad(self):
        class IdWithIdGrad(base.Layer):
            def forward(self, x, weights):
                return x

            @property
            def has_backward(self):
                return True

            def backward(self, inputs, output, ct, weights, state, new_state,
                         **kwargs):
                return (inputs, ())

        layer = IdWithIdGrad()
        rng = backend.random.get_prng(0)
        input_signature = ShapeDtype((9, 17))
        random_input = backend.random.uniform(rng,
                                              input_signature.shape,
                                              minval=-1.0,
                                              maxval=1.0)
        layer.init(input_signature)
        f = lambda x: backend.numpy.mean(layer(x))
        grad = backend.grad(f)(random_input)
        self.assertEqual(grad.shape, (9, 17))  # Gradient for each input.
        self.assertEqual(sum(sum(grad)),
                         sum(sum(random_input)))  # Same as input.
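Examples #3 and #4 are the same test written against two versions of the Trax layer API. The identity-forward, input-as-gradient trick they exercise can also be expressed with plain JAX's custom_vjp; the sketch below illustrates that mechanism under that assumption and is not code from the Trax test suite.

  import jax
  import jax.numpy as jnp

  @jax.custom_vjp
  def id_with_id_grad(x):
    return x

  def _fwd(x):
    return x, x          # output, plus the input saved as the residual

  def _bwd(saved_x, ct):
    del ct               # ignore the incoming cotangent...
    return (saved_x,)    # ...and hand back the saved input as the "gradient"

  id_with_id_grad.defvjp(_fwd, _bwd)

  x = jnp.arange(6.0).reshape(2, 3)
  grad = jax.grad(lambda v: jnp.mean(id_with_id_grad(v)))(x)
  assert grad.shape == x.shape
  assert (grad == x).all()   # gradient equals the input, as the tests above assert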
Example #5
 def mapped_update(i, opt_state, batch, state, rng):
   """This is a multi-device version of the update function above."""
   # We assume all tensors have the first dimension = n_devices.
   weights, slots, opt_params = opt_state
   rng, subrng = jax_random.split(rng)
   grad_fn = backend.grad(model_and_loss_call, has_aux=True)
   grads, state = grad_fn(weights, batch, state, rng)
   grads = jax.tree_util.tree_map(
       lambda g: backend.psum(g, 'batch'), grads)
   return optimizer.tree_update(
       i, grads, weights, slots, opt_params), state, subrng
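Example #6 is the same multi-device update with one change: the psum-summed gradients are additionally divided by psum(1.0, 'batch'), i.e. averaged over every participating device across all hosts rather than merely summed.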
Example #6
 def mapped_update(i, opt_state, batch, state, rng):
   """This is a multi-device version of the update function above."""
   # We assume all tensors have the first dimension = n_devices.
   weights, slots, opt_params = opt_state
   rng, subrng = jax_random.split(rng)
   grad_fn = backend.grad(model_and_loss_call, has_aux=True)
   grads, state = grad_fn(weights, batch, state, rng)
   # We do a psum(1.0) here instead of `n_devices` since `n_devices` is just
   # the number of devices on this host machine, however psum goes over all
   # devices of all hosts (ex: a TPU pod) and we need to be averaging over all
   # of them.
   grads = jax.tree_util.tree_map(
       lambda g: backend.psum(g, 'batch') / backend.psum(1.0, 'batch'), grads)
   return optimizer.tree_update(
       i, grads, weights, slots, opt_params), state, subrng
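A usage sketch (an assumption, not shown in the source): psum(..., 'batch') is only meaningful inside a pmap whose axis_name matches, so mapped_update is typically wrapped as below. opt_state, batch and state are hypothetical stand-ins that must already be replicated with a leading axis of size n_devices, as the comment in the function says.

  import jax
  import jax.numpy as jnp

  pmapped_update = jax.pmap(mapped_update, axis_name='batch')

  n_devices = jax.local_device_count()
  rngs = jax.random.split(jax.random.PRNGKey(0), n_devices)   # one rng per device
  steps = jnp.zeros(n_devices, dtype=jnp.int32)               # replicated step counter
  update_result, state, rngs = pmapped_update(steps, opt_state, batch, state, rngs)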