def single_update(i, opt_state, batch, rng):
  """Single-device update: one gradient step of the optimizer."""
  _, opt_update = optimizer(lr_fun)
  params = trax_opt.get_params(opt_state)
  return opt_update(
      i, backend.grad(loss_fun)(params, batch, predict_fun, rng), opt_state)
def single_update(i, opt_state, batch, state, rng):
  """Single-device update: one gradient step, also threading model state."""
  params, slots, opt_params = opt_state
  rng, subrng = jax_random.split(rng[0])
  # With has_aux=True, backend.grad returns (grads, state) because loss_fn
  # returns the updated model state alongside the loss value.
  grad_fn = backend.grad(loss_fn, has_aux=True)
  grads, state = grad_fn(params, batch, predict_fn, state, rng)
  return optimizer.tree_update(
      i, grads, params, slots, opt_params), state, [subrng]
def single_update(i, opt_state, batch, rng):
  """Single-device update: one gradient step of the optimizer."""
  rng, subrng = jax_random.split(rng[0])
  params, opt_slots = opt_state
  return optimizer.tree_update(
      i, backend.grad(loss_fn)(params, batch, predict_fn, rng),
      params, opt_slots), [subrng]
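# A minimal, runnable sketch (plain JAX, not the Trax helpers above) of the
# pattern these `single_update` variants share: split the rng, differentiate
# the loss with respect to the parameters, and apply one optimizer step
# inside a jitted function. `mse_loss` and the inline SGD step are
# hypothetical stand-ins for `loss_fn`/`loss_fun` and
# `optimizer.tree_update`/`opt_update`.
import jax
import jax.numpy as jnp


def mse_loss(params, batch):
  inputs, targets = batch
  return jnp.mean((inputs @ params - targets) ** 2)


@jax.jit
def sgd_single_update(i, params, batch, rng):
  del i  # A real optimizer would use the step count, e.g. for lr schedules.
  rng, subrng = jax.random.split(rng)
  grads = jax.grad(mse_loss)(params, batch)
  new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g,
                                      params, grads)
  return new_params, subrng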
def test_custom_id_grad(self):
  class IdWithIdGrad(base.Layer):

    def call(self, x, params, **kwargs):
      del params, kwargs
      return x, ()

    def new_parameters(self, input_shapes, input_dtype, rng):
      del input_shapes, input_dtype, rng
      return (), ()

    @property
    def has_custom_grad(self):
      return True

    def custom_grad(self, inputs, output, ct, params, state, **kwargs):
      return (inputs, ())

  layer = IdWithIdGrad()
  rng = backend.random.get_prng(0)
  params = ()
  input_shape = (9, 17)
  random_input = backend.random.uniform(
      rng, input_shape, minval=-1.0, maxval=1.0)
  f = lambda x: backend.numpy.mean(layer(x, params, rng=rng)[0])
  grad = backend.grad(f)(random_input)
  self.assertEqual(grad.shape, input_shape)  # Gradient for each input.
  self.assertEqual(sum(sum(grad)), sum(sum(random_input)))  # Same as input.
def test_custom_zero_grad(self):
  class IdWithZeroGrad(base.Layer):

    def call(self, x, params, **kwargs):
      del params, kwargs
      return x

    def new_parameters(self, input_shapes, input_dtype, rng):
      del input_shapes, input_dtype, rng
      return ()

    @property
    def has_custom_grad(self):
      return True

    def custom_grad(self, inputs, output, ct, params, **kwargs):
      return (backend.numpy.zeros_like(ct), None, None)

  layer = IdWithZeroGrad()
  rng = backend.random.get_prng(0)
  params = ()
  input_shape = (9, 17)
  random_input = backend.random.uniform(
      rng, input_shape, minval=-1.0, maxval=1.0)
  f = lambda x: backend.numpy.mean(layer(x, params, rng=rng))
  grad = backend.grad(f)(random_input)
  self.assertEqual(grad.shape, input_shape)  # Gradient for each input.
  self.assertEqual(sum(sum(grad * grad)), 0.0)  # Each one is 0.
def test_custom_id_grad(self):
  class IdWithIdGrad(base.Layer):

    def forward(self, x, params=(), state=(), **kwargs):
      del kwargs
      return x, ()

    @property
    def has_backward(self):
      return True

    def backward(self, inputs, output, ct, params, state, **kwargs):
      return (inputs, ())

  layer = IdWithIdGrad()
  rng = backend.random.get_prng(0)
  input_shape = (9, 17)
  random_input = backend.random.uniform(
      rng, input_shape, minval=-1.0, maxval=1.0)
  layer.initialize_once(input_shape, random_input.dtype, rng)
  f = lambda x: backend.numpy.mean(layer(x))
  grad = backend.grad(f)(random_input)
  self.assertEqual(grad.shape, input_shape)  # Gradient for each input.
  self.assertEqual(sum(sum(grad)), sum(sum(random_input)))  # Same as input.
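# For reference, a hedged sketch of the same identity-gradient trick written
# directly against `jax.custom_vjp`, which is what a custom `backward` /
# `custom_grad` method boils down to (plain JAX; `id_with_id_grad` is a
# hypothetical name, not part of the layer API above).
import jax
import jax.numpy as jnp


@jax.custom_vjp
def id_with_id_grad(x):
  return x


def _id_fwd(x):
  return x, x  # Save the inputs as residuals for the backward pass.


def _id_bwd(inputs, ct):
  del ct  # Ignore the incoming cotangent; return the saved inputs instead.
  return (inputs,)


id_with_id_grad.defvjp(_id_fwd, _id_bwd)

x = jnp.ones((9, 17))
grad = jax.grad(lambda y: jnp.mean(id_with_id_grad(y)))(x)
# As in the test above, `grad` equals the input rather than the usual
# mean-gradient of 1/153 per element.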
def test_reformer_rng_consistency(self):
  with backend.use_backend('jax'):
    vocab_size = 16
    batch_size = 1
    input_shape = ((batch_size, 8), (batch_size, 8))
    model = reformer.ReformerLM(
        vocab_size, d_model=32, d_ff=64,
        d_attention_key=16, d_attention_value=16,
        n_layers=1, n_heads=2, max_len=16, n_chunks=2,
        n_attention_chunks=1, mode='train',
        attention_type=PoisonOnRNGMismatchAttention)
    rng = backend.random.get_prng(0)
    params, state = model.initialize_once(
        input_shape, (np.int32, np.int32), rng)

    def dummy_loss_fn(params):
      inputs = (np.zeros(input_shape[0], dtype=np.int32),) * 2
      output = model(inputs, params=params, state=state, rng=rng)
      dummy_loss = backend.numpy.sum(output[0])
      return dummy_loss

    grad_fn = backend.grad(dummy_loss_fn)
    grads = grad_fn(params)
    # PoisonOnRNGMismatchAttention uses NaNs to signal an rng mismatch.
    for grad in jax.tree_util.tree_leaves(grads):
      assert onp.all(onp.isfinite(grad))
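# The property this test guards: ReformerLM's reversible layers recompute
# activations on the backward pass, so each layer must see the same rng in
# both passes. A hedged sketch of the usual deterministic per-layer fan-out
# (plain JAX; `apply_layer` and `apply_layers` are hypothetical names):
import jax
import jax.numpy as jnp


def apply_layer(w, x, rng):
  # Dropout-style randomness driven by the per-layer rng.
  keep = jax.random.bernoulli(rng, 0.9, x.shape)
  return jnp.where(keep, x @ w, 0.0)


def apply_layers(ws, x, rng):
  # Splitting the same parent key always yields the same per-layer keys,
  # so a recomputation pass sees identical randomness.
  layer_rngs = jax.random.split(rng, len(ws))
  for w, layer_rng in zip(ws, layer_rngs):
    x = apply_layer(w, x, layer_rng)
  return x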
def single_update(i, opt_state, batch, rng):
  """Single-device update: one gradient step of the optimizer."""
  rng, subrng = jax_random.split(rng[0])
  _, opt_update = optimizer(lr_fun)
  params = trax_opt.get_params(opt_state)
  return opt_update(
      i, backend.grad(loss_fun)(params, batch, predict_fun, rng),
      opt_state), [subrng]
def mapped_update(i, opt_state, batch, rng):
  """This is a multi-device version of the update function above."""
  # We assume all tensors have the first dimension = n_devices.
  rng, subrng = jax_random.split(rng)
  params, opt_slots = opt_state
  grads = backend.grad(loss_fn)(params, batch, predict_fn, rng)
  grads = jax.tree_util.tree_map(lambda g: lax.psum(g, "batch"), grads)
  return optimizer.tree_update(i, grads, params, opt_slots), subrng
def mapped_update(i, opt_state, batch, rng):
  """This is a multi-device version of the update function above."""
  # We assume all tensors have the first dimension = num_devices.
  _, opt_update = optimizer(lr_fun)
  params = trax_opt.get_params(opt_state)
  grads = backend.grad(loss_fun)(params, batch, predict_fun, rng)
  grads = jax.tree_util.tree_map(lambda g: lax.psum(g, "batch"), grads)
  return opt_update(i, grads, opt_state)
def mapped_update(i, opt_state, batch, state, rng):
  """This is a multi-device version of the update function above."""
  # We assume all tensors have the first dimension = n_devices.
  params, slots, opt_params = opt_state
  rng, subrng = jax_random.split(rng)
  grad_fn = backend.grad(model_and_loss_call, has_aux=True)
  grads, state = grad_fn(params, batch, state, rng)
  grads = jax.tree_util.tree_map(lambda g: lax.psum(g, "batch"), grads)
  return optimizer.tree_update(
      i, grads, params, slots, opt_params), state, subrng
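# A hedged sketch (names hypothetical, plain JAX) of how a `mapped_update`
# like the ones above is typically driven: wrapping the step in `jax.pmap`
# with `axis_name="batch"` is what gives `lax.psum(g, "batch")` its meaning,
# summing per-device gradients before the optimizer step so the replicated
# parameters stay in sync.
import functools

import jax
import jax.numpy as jnp
from jax import lax


def mse_loss(params, batch):
  inputs, targets = batch
  return jnp.mean((inputs @ params - targets) ** 2)


@functools.partial(jax.pmap, axis_name="batch")
def mapped_sgd_update(params, batch):
  grads = jax.grad(mse_loss)(params, batch)
  # All-reduce across the "batch" device axis: every device receives the
  # summed gradient and therefore applies the identical update.
  grads = jax.tree_util.tree_map(lambda g: lax.psum(g, "batch"), grads)
  return jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

# Usage: replicate params across the leading device axis and shard batches,
# e.g. with n_devices = jax.local_device_count():
#   params = jax.tree_util.tree_map(
#       lambda x: jnp.broadcast_to(x, (n_devices,) + x.shape), params)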