Example #1
 def single_update(i, opt_state, batch, rng):
     _, opt_update = optimizer(lr_fun)  # only the update function is needed here
     params = trax_opt.get_params(opt_state)
     return opt_update(
         i,
         backend.grad(loss_fun)(params, batch, predict_fun, rng),
         opt_state)
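Examples #1, #3, and #8 unpack optimizer(lr_fun) into an (init, update) pair and read the parameters back with trax_opt.get_params, which mirrors the (opt_init, opt_update, get_params) triple of JAX's reference optimizers module. The sketch below restates the same grad-then-update step directly against jax.example_libraries.optimizers, with a hypothetical toy loss standing in for loss_fun/predict_fun; it is an illustration of the pattern, not the Trax code.

    import jax
    import jax.numpy as jnp
    from jax.example_libraries import optimizers

    def toy_loss(params, batch):          # stand-in for loss_fun(params, batch, ...)
        x, y = batch
        return jnp.mean((x @ params - y) ** 2)

    opt_init, opt_update, get_params = optimizers.sgd(step_size=0.1)

    @jax.jit
    def single_update(i, opt_state, batch):
        params = get_params(opt_state)
        grads = jax.grad(toy_loss)(params, batch)
        return opt_update(i, grads, opt_state)

    opt_state = opt_init(jnp.zeros((3,)))
    batch = (jnp.ones((4, 3)), jnp.ones((4,)))
    opt_state = single_update(0, opt_state, batch)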
Example #2
 def single_update(i, opt_state, batch, state, rng):
     params, slots, opt_params = opt_state
     rng, subrng = jax_random.split(rng[0])
     # has_aux=True: loss_fn returns (loss, new_state), so grad_fn returns
     # (grads, new_state).
     grad_fn = backend.grad(loss_fn, has_aux=True)
     grads, state = grad_fn(params, batch, predict_fn, state, rng)
     return optimizer.tree_update(i, grads, params, slots,
                                  opt_params), state, [subrng]
Example #3
 def single_update(i, opt_state, batch, rng):
     # The rng is carried as a one-element list; split it so this step gets
     # fresh randomness and the new subrng is passed forward.
     rng, subrng = jax_random.split(rng[0])
     params, opt_slots = opt_state
     return optimizer.tree_update(
         i,
         backend.grad(loss_fn)(params, batch, predict_fn, rng), params,
         opt_slots), [subrng]
Example #4
  def test_custom_id_grad(self):

    class IdWithIdGrad(base.Layer):

      def call(self, x, params, **kwargs):
        del params, kwargs
        return x, ()

      def new_parameters(self, input_shapes, input_dtype, rng):
        del input_shapes, input_dtype, rng
        return (), ()

      @property
      def has_custom_grad(self):
        return True

      def custom_grad(self, inputs, output, ct, params, state, **kwargs):
        return (inputs, ())

    layer = IdWithIdGrad()
    rng = backend.random.get_prng(0)
    params = ()
    input_shape = (9, 17)
    random_input = backend.random.uniform(rng, input_shape, minval=-1.0,
                                          maxval=1.0)
    f = lambda x: backend.numpy.mean(layer(x, params, rng=rng)[0])
    grad = backend.grad(f)(random_input)
    self.assertEqual(grad.shape, input_shape)  # Gradient for each input.
    self.assertEqual(sum(sum(grad)), sum(sum(random_input)))  # Same as input.
Example #5
    def test_custom_zero_grad(self):
        class IdWithZeroGrad(base.Layer):
            def call(self, x, params, **kwargs):
                del params, kwargs
                return x

            def new_parameters(self, input_shapes, input_dtype, rng):
                del input_shapes, input_dtype, rng
                return ()

            @property
            def has_custom_grad(self):
                return True

            def custom_grad(self, inputs, output, ct, params, **kwargs):
                return (backend.numpy.zeros_like(ct), None, None)

        layer = IdWithZeroGrad()
        rng = backend.random.get_prng(0)
        params = ()
        input_shape = (9, 17)
        random_input = backend.random.uniform(rng,
                                              input_shape,
                                              minval=-1.0,
                                              maxval=1.0)
        f = lambda x: backend.numpy.mean(layer(x, params, rng=rng))
        grad = backend.grad(f)(random_input)
        self.assertEqual(grad.shape, input_shape)  # Gradient for each input.
        self.assertEqual(sum(sum(grad * grad)), 0.0)  # Each one is 0.
Example #6
  def test_custom_id_grad(self):

    class IdWithIdGrad(base.Layer):

      def forward(self, x, params=(), state=(), **kwargs):
        del kwargs
        return x, ()

      @property
      def has_backward(self):
        return True

      def backward(self, inputs, output, ct, params, state, **kwargs):
        return (inputs, ())

    layer = IdWithIdGrad()
    rng = backend.random.get_prng(0)
    input_shape = (9, 17)
    random_input = backend.random.uniform(rng, input_shape, minval=-1.0,
                                          maxval=1.0)
    layer.initialize_once(input_shape, random_input.dtype, rng)
    f = lambda x: backend.numpy.mean(layer(x))
    grad = backend.grad(f)(random_input)
    self.assertEqual(grad.shape, input_shape)  # Gradient for each input.
    self.assertEqual(sum(sum(grad)), sum(sum(random_input)))  # Same as input.
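The "identity gradient" behavior that has_backward/backward (or has_custom_grad/custom_grad in the earlier examples) wires into the layer can also be written directly with jax.custom_vjp. The sketch below, with hypothetical names, is a hedged illustration of that mechanism rather than the Trax implementation: the backward rule ignores the incoming cotangent and returns the saved inputs, which is exactly what the test asserts.

    import jax
    import jax.numpy as jnp

    @jax.custom_vjp
    def id_with_id_grad(x):
        return x

    def _fwd(x):
        return x, x                      # save the inputs as residuals

    def _bwd(saved_inputs, ct):
        del ct                           # ignore the incoming cotangent
        return (saved_inputs,)           # "gradient" is the saved inputs

    id_with_id_grad.defvjp(_fwd, _bwd)

    x = jnp.ones((9, 17))
    grad = jax.grad(lambda v: jnp.mean(id_with_id_grad(v)))(x)
    assert grad.shape == x.shape         # one gradient entry per input
    assert jnp.allclose(grad, x)         # gradient equals the inputs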
Example #7
    def test_reformer_rng_consistency(self):
        with backend.use_backend('jax'):
            vocab_size = 16
            batch_size = 1
            input_shape = ((batch_size, 8), (batch_size, 8))
            model = reformer.ReformerLM(
                vocab_size,
                d_model=32,
                d_ff=64,
                d_attention_key=16,
                d_attention_value=16,
                n_layers=1,
                n_heads=2,
                max_len=16,
                n_chunks=2,
                n_attention_chunks=1,
                mode='train',
                attention_type=PoisonOnRNGMismatchAttention)

            rng = backend.random.get_prng(0)
            params, state = model.initialize_once(input_shape,
                                                  (np.int32, np.int32), rng)

            def dummy_loss_fn(params):
                inputs = (np.zeros(input_shape[0], dtype=np.int32), ) * 2
                output = model(inputs, params=params, state=state, rng=rng)
                dummy_loss = backend.numpy.sum(output[0])
                return dummy_loss

            grad_fn = backend.grad(dummy_loss_fn)
            grads = grad_fn(params)
            # PoisonOnRNGMismatchAttention uses NaNs to signal an rng mismatch.
            for grad in jax.tree_util.tree_leaves(grads):
                assert onp.all(onp.isfinite(grad))
Example #8
 def single_update(i, opt_state, batch, rng):
     rng, subrng = jax_random.split(rng[0])
     _, opt_update = optimizer(lr_fun)
     params = trax_opt.get_params(opt_state)
     return opt_update(
         i,
         backend.grad(loss_fun)(params, batch, predict_fun, rng),
         opt_state), [subrng]
Example #9
 def mapped_update(i, opt_state, batch, rng):
     """This is a multi-device version of the update function above."""
     # We assume all tensors have the first dimension = n_devices.
     rng, subrng = jax_random.split(rng)
     params, opt_slots = opt_state
     grads = backend.grad(loss_fn)(params, batch, predict_fn, rng)
     grads = jax.tree_util.tree_map(lambda g: lax.psum(g, "batch"), grads)
     return optimizer.tree_update(i, grads, params, opt_slots), subrng
Example #10
 def mapped_update(i, opt_state, batch, rng):
   """This is a multi-device version of the update function above."""
   # We assume all tensors have the first dimension = num_devices.
   _, opt_update = optimizer(lr_fun)
   params = trax_opt.get_params(opt_state)
   grads = backend.grad(loss_fun)(params, batch, predict_fun, rng)
   grads = jax.tree_util.tree_map(
       lambda g: lax.psum(g, "batch"), grads)
   return opt_update(i, grads, opt_state)
Example #11
 def mapped_update(i, opt_state, batch, state, rng):
     """This is a multi-device version of the update function above."""
     # We assume all tensors have the first dimension = n_devices.
     params, slots, opt_params = opt_state
     rng, subrng = jax_random.split(rng)
     grad_fn = backend.grad(model_and_loss_call, has_aux=True)
     grads, state = grad_fn(params, batch, state, rng)
     grads = jax.tree_util.tree_map(lambda g: lax.psum(g, "batch"), grads)
     return optimizer.tree_update(i, grads, params, slots,
                                  opt_params), state, subrng
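Examples #9, #10, and #11 only work inside jax.pmap: lax.psum(g, "batch") refers to a mapped axis that must be named "batch", and every argument carries a leading n_devices dimension. Below is a self-contained sketch of that pattern with a toy loss and hypothetical names; it shows the replication/sharding and the cross-device gradient sum, not the actual Trax training loop.

    import jax
    import jax.numpy as jnp
    from jax import lax

    n_devices = jax.local_device_count()

    def toy_loss(params, batch):
        x, y = batch
        return jnp.mean((x * params - y) ** 2)

    def mapped_grads(params, batch):
        # Runs once per device; psum adds the per-device gradients over the
        # pmap axis named "batch", so every device holds the same total.
        grads = jax.grad(toy_loss)(params, batch)
        return lax.psum(grads, "batch")

    # Replicate params across devices, shard the batch along the leading axis.
    replicated_params = jnp.broadcast_to(jnp.float32(2.0), (n_devices,))
    batch = (jnp.ones((n_devices, 8)), jnp.zeros((n_devices, 8)))
    grads = jax.pmap(mapped_grads, axis_name="batch")(replicated_params, batch)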