def single_update(i, opt_state, batch, state, rng):
  """Single-device counterpart of the mapped (multi-device) update below."""
  weights, slots, opt_params = opt_state
  rng, subrng = jax_random.split(rng[0])
  grad_fn = backend.grad(model_and_loss_call, has_aux=True)
  grads, state = grad_fn(weights, batch, state, rng)
  return optimizer.tree_update(
      i, grads, weights, slots, opt_params), state, [subrng]
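# A minimal, self-contained sketch of the same pattern in plain JAX (all names
# and shapes here are illustrative, not part of the code above): `grad` with
# `has_aux=True` returns the gradient of a scalar loss together with an
# auxiliary value, which stands in for the model state threaded through
# `single_update`.
import jax
import jax.numpy as jnp

def loss_and_state(weights, batch, state):
  x, y = batch
  loss = jnp.mean((x @ weights - y) ** 2)
  return loss, state  # aux output, analogous to the updated model state

grad_fn = jax.grad(loss_and_state, has_aux=True)

weights = jnp.ones((3, 1))
batch = (jnp.ones((4, 3)), jnp.zeros((4, 1)))
grads, new_state = grad_fn(weights, batch, ())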
def test_reformer_rng_consistency(self):
  with backend.use_backend('jax'):
    vocab_size = 16
    batch_size = 1
    input_sd = ShapeDtype((batch_size, 8), np.int32)
    input_signature = (input_sd, input_sd)
    model = reformer.ReformerLM(
        vocab_size, d_model=32, d_ff=64,
        d_attention_key=16, d_attention_value=16,
        n_layers=1, n_heads=2, max_len=16, n_chunks=2,
        n_attention_chunks=1, mode='train',
        attention_type=PoisonOnRNGMismatchAttention)
    rng = backend.random.get_prng(0)
    weights, state = model.initialize_once(input_signature)

    def dummy_loss_fn(weights):
      inputs = (np.zeros(input_sd.shape, dtype=np.int32),) * 2
      output = model(inputs, weights=weights, state=state, rng=rng)
      dummy_loss = backend.numpy.sum(output[0])
      return dummy_loss

    grad_fn = backend.grad(dummy_loss_fn)
    grads = grad_fn(weights)
    # PoisonOnRNGMismatchAttention uses NaNs to signal an rng mismatch.
    for grad in jax.tree_util.tree_leaves(grads):
      assert onp.all(onp.isfinite(grad))
def test_custom_id_grad(self):
  # Identity layer with a custom backward pass that returns its inputs as the
  # incoming gradient, so the gradient of the mean-based loss equals the input.
  class IdWithIdGrad(base.Layer):

    def forward(self, x, params=(), state=(), **kwargs):
      del kwargs
      return x, ()

    @property
    def has_backward(self):
      return True

    def backward(self, inputs, output, ct, params, state, **kwargs):
      return (inputs, ())

  layer = IdWithIdGrad()
  rng = backend.random.get_prng(0)
  input_shape = (9, 17)
  random_input = backend.random.uniform(rng, input_shape, minval=-1.0,
                                        maxval=1.0)
  layer.initialize_once(input_shape, random_input.dtype, rng)
  f = lambda x: backend.numpy.mean(layer(x))
  grad = backend.grad(f)(random_input)
  self.assertEqual(grad.shape, input_shape)  # Gradient for each input.
  self.assertEqual(sum(sum(grad)), sum(sum(random_input)))  # Same as input.
def test_custom_id_grad(self):
  # Same test against the newer layer API (weights instead of params,
  # init with a ShapeDtype input signature).
  class IdWithIdGrad(base.Layer):

    def forward(self, x, weights):
      return x

    @property
    def has_backward(self):
      return True

    def backward(self, inputs, output, ct, weights, state, new_state,
                 **kwargs):
      return (inputs, ())

  layer = IdWithIdGrad()
  rng = backend.random.get_prng(0)
  input_signature = ShapeDtype((9, 17))
  random_input = backend.random.uniform(rng, input_signature.shape,
                                        minval=-1.0, maxval=1.0)
  layer.init(input_signature)
  f = lambda x: backend.numpy.mean(layer(x))
  grad = backend.grad(f)(random_input)
  self.assertEqual(grad.shape, (9, 17))  # Gradient for each input.
  self.assertEqual(sum(sum(grad)), sum(sum(random_input)))  # Same as input.
def mapped_update(i, opt_state, batch, state, rng):
  """This is a multi-device version of the update function above."""
  # We assume all tensors have the first dimension = n_devices.
  weights, slots, opt_params = opt_state
  rng, subrng = jax_random.split(rng)
  grad_fn = backend.grad(model_and_loss_call, has_aux=True)
  grads, state = grad_fn(weights, batch, state, rng)
  grads = jax.tree_util.tree_map(
      lambda g: backend.psum(g, 'batch'), grads)
  return optimizer.tree_update(
      i, grads, weights, slots, opt_params), state, subrng
def mapped_update(i, opt_state, batch, state, rng):
  """This is a multi-device version of the update function above."""
  # We assume all tensors have the first dimension = n_devices.
  weights, slots, opt_params = opt_state
  rng, subrng = jax_random.split(rng)
  grad_fn = backend.grad(model_and_loss_call, has_aux=True)
  grads, state = grad_fn(weights, batch, state, rng)
  # We do a psum(1.0) here instead of `n_devices` since `n_devices` is just
  # the number of devices on this host machine, however psum goes over all
  # devices of all hosts (ex: a TPU pod) and we need to be averaging over all
  # of them.
  grads = jax.tree_util.tree_map(
      lambda g: backend.psum(g, 'batch') / backend.psum(1.0, 'batch'), grads)
  return optimizer.tree_update(
      i, grads, weights, slots, opt_params), state, subrng
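# A minimal sketch (plain JAX, illustrative names and shapes rather than the
# trax backend) of the context the function above runs in: psum over a named
# axis only works inside jax.pmap with a matching axis_name, and
# psum(1.0, 'batch') counts every device taking part in the collective, which
# is why it is used for averaging instead of the per-host n_devices.
import jax
import jax.numpy as jnp

def per_device_grads(weights, batch):
  loss = lambda w, b: jnp.mean((b @ w) ** 2)
  grads = jax.grad(loss)(weights, batch)
  # Average gradients across all participating devices.
  return jax.tree_util.tree_map(
      lambda g: jax.lax.psum(g, 'batch') / jax.lax.psum(1.0, 'batch'), grads)

n = jax.local_device_count()
weights = jnp.stack([jnp.ones((3, 1))] * n)  # weights replicated per device
batch = jnp.ones((n, 4, 3))                  # batch sharded across devices
grads = jax.pmap(per_device_grads, axis_name='batch')(weights, batch)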