def single_update(i, opt_state, batch, state, rng):
    """Apply one optimizer step using gradients of `model_and_loss_call`.

    Args:
        i: Step index, forwarded unchanged to `optimizer.tree_update`.
        opt_state: Triple `(weights, slots, opt_params)`.
        batch: Batch forwarded to `model_and_loss_call`.
        state: State forwarded to `model_and_loss_call`.
        rng: Indexable whose element 0 is the random key that gets split.

    Returns:
        A triple of: the result of `optimizer.tree_update`, the auxiliary
        value returned alongside the gradients (bound to `state` by the
        caller's convention), and a one-element list holding the half of the
        split key that was not consumed by this step.
    """
    params, opt_slots, hyper_params = opt_state
    # Only the first entry of `rng` is used; one half drives this step and
    # the other half is handed back (wrapped in a list) for the next one.
    step_rng, next_rng = jax_random.split(rng[0])
    compute_grads = math.grad(model_and_loss_call, has_aux=True)
    gradients, updated_state = compute_grads(params, batch, state, step_rng)
    new_opt_state = optimizer.tree_update(
        i, gradients, params, opt_slots, hyper_params)
    return new_opt_state, updated_state, [next_rng]
def test_custom_zero_grad(self):
    """Checks that a layer's custom `backward` returning zeros is honored."""

    # Forward pass is the identity; the custom backward pass replaces the
    # incoming gradient with zeros of the same shape and reports no weight
    # gradients (empty tuple).
    class IdWithZeroGrad(base.Layer):

        def forward(self, x, weights):
            return x

        @property
        def has_backward(self):
            return True

        def backward(self, inputs, output, grad, weights, state, new_state,
                     rng):
            return (jnp.zeros_like(grad), ())

    layer = IdWithZeroGrad()
    prng_key = math.random.get_prng(0)
    input_signature = shapes.ShapeDtype((9, 17))
    sample_input = math.random.uniform(
        prng_key, input_signature.shape, minval=-1.0, maxval=1.0)
    layer.init(input_signature)

    def mean_of_output(x):
        return jnp.mean(layer(x))

    input_grad = math.grad(mean_of_output)(sample_input)
    # One gradient entry per input element.
    self.assertEqual(input_grad.shape, (9, 17))
    # The sum of squares is zero only if every entry is zero.
    self.assertEqual(sum(sum(input_grad * input_grad)), 0.0)
def test_custom_id_grad(self):
    """Checks that a custom `backward` returning its inputs is honored."""

    # Forward pass is the identity; the custom backward pass returns the
    # layer's inputs as the input gradient and an empty tuple for weights.
    class IdWithIdGrad(base.Layer):

        def forward(self, x, weights):
            return x

        @property
        def has_backward(self):
            return True

        def backward(self, inputs, output, ct, weights, state, new_state,
                     **kwargs):
            return (inputs, ())

    layer = IdWithIdGrad()
    prng_key = math.random.get_prng(0)
    input_signature = ShapeDtype((9, 17))
    sample_input = math.random.uniform(
        prng_key, input_signature.shape, minval=-1.0, maxval=1.0)
    layer.init(input_signature)

    def mean_of_output(x):
        return np.mean(layer(x))

    input_grad = math.grad(mean_of_output)(sample_input)
    # One gradient entry per input element.
    self.assertEqual(input_grad.shape, (9, 17))
    # The gradient should total the same as the input it was copied from.
    self.assertEqual(sum(sum(input_grad)), sum(sum(sample_input)))
def test_reformer_rng_consistency(self):
    """Checks that every gradient leaf of a small ReformerLM is finite.

    The model is built with `PoisonOnRNGMismatchAttention` as its attention
    type and differentiated under the 'jax' backend; a non-finite gradient
    leaf fails the test.
    """
    with math.use_backend('jax'):
        vocab_size = 16
        batch_size = 1
        input_sd = ShapeDtype((batch_size, 8), np.int32)
        input_signature = (input_sd, input_sd)
        model = reformer.ReformerLM(
            vocab_size,
            d_model=32,
            d_ff=64,
            d_attention_key=16,
            d_attention_value=16,
            n_layers=1,
            n_heads=2,
            max_len=16,
            n_chunks=2,
            n_attention_chunks=1,
            mode='train',
            attention_type=PoisonOnRNGMismatchAttention)

        rng = math.random.get_prng(0)
        weights, state = model.init(input_signature)

        def dummy_loss_fn(weights):
            # Two identical all-zero int32 inputs matching `input_signature`.
            inputs = (np.zeros(input_sd.shape, dtype=np.int32),) * 2
            output = model(inputs, weights=weights, state=state, rng=rng)
            dummy_loss = math.numpy.sum(output[0])
            return dummy_loss

        grad_fn = math.grad(dummy_loss_fn)
        grads = grad_fn(weights)
        # PoisonOnRNGMismatchAttention uses NaNs to signal an rng mismatch.
        # Use the TestCase assertion rather than a bare `assert`, so the
        # check still runs under `python -O` and reports a useful message.
        for grad in jax.tree_util.tree_leaves(grads):
            self.assertTrue(
                onp.all(onp.isfinite(grad)),
                msg='Found a non-finite gradient leaf.')
def mock_training_step(x, weights, state, rng):
    """Run one mock gradient-descent step on `model`.

    The mock loss is the mean of channel 0 of the logits returned by
    `model.pure_fn`.

    Args:
        x: Input forwarded to `model.pure_fn`.
        weights: Current weights; also the point at which gradients are taken.
        state: State forwarded to `model.pure_fn`.
        rng: Random key forwarded to `model.pure_fn`.

    Returns:
        Tuple `(new_weights, new_state, logits)`.
    """
    step_size = 1e-4

    def compute_mock_loss(candidate_weights):
        outputs, next_state = model.pure_fn(x, candidate_weights, state, rng)
        mock_loss = math.numpy.mean(outputs[..., 0])
        # Everything after the loss travels through as auxiliary output.
        return mock_loss, (next_state, outputs)

    grad_with_aux = math.grad(compute_mock_loss, has_aux=True)
    gradients, aux = grad_with_aux(weights)
    new_state, logits = aux

    def descend(w, g):
        return w - step_size * g

    new_weights = math.nested_map_multiarg(descend, weights, gradients)
    return new_weights, new_state, logits
def _run_one_step(self):
    """Runs a single training step, refreshing weights and optimizer slots.

    Pulls the next batch from the task, differentiates the model-plus-loss
    stack with respect to its weights, and stores the results of the
    optimizer's `tree_update` back on the model and the optimizer.
    """
    opt = self._task.optimizer
    # TODO(jonni): figure out why JAX tracer needs the following line.
    current_weights = self._model.weights
    hyper_params = opt._init_opt_params  # pylint: disable=protected-access
    batch = self._task.next_batch()
    loss_model = tl.Serial(self._model, self._task.loss_layer)

    def loss_of(w):
        return loss_model(batch, weights=w)

    grads = math.grad(loss_of)(loss_model.weights)
    new_weights, new_slots = opt.tree_update(
        self.current_step(), grads, current_weights, opt.slots, hyper_params)
    self._model.weights, opt.slots = new_weights, new_slots
def __init__(self, model, task, eval_task=None, output_dir=None,
             checkpoint_at=None):
    """Sets up a training `Loop` and randomly initializes model and optimizer.

    Args:
        model: Trax layer, representing the core model to be trained. Loss
            functions and eval functions (a.k.a. metrics) are considered to
            be outside the core model, taking core model output and data
            labels as their two inputs.
        task: TrainTask instance, which defines the training data, loss
            function, and optimizer to be used in this training loop.
        eval_task: EvalTask instance or None. If None, don't do any evals.
        output_dir: Path telling where to save outputs (evals and
            checkpoints). Can be None if both `eval_task` and
            `checkpoint_at` are None.
        checkpoint_at: Function (integer --> boolean) telling, for step n,
            whether that step should have its checkpoint saved. If None,
            don't save any checkpoints.
    """
    self._task = task
    self._model_in_training = tl.Serial(model, task.loss_layer)
    self._eval_task = eval_task
    self._output_dir = output_dir
    self._checkpoint_at = checkpoint_at or _never
    self._step = None

    # Initialize the model first and then the optimizer. Both calls return
    # pairs (model weights/state, optimizer slots/params) that are discarded
    # here because the model and optimizer objects keep them.
    sample_signature = shapes.signature(task.sample_batch)
    _, _ = self._model_in_training.init(sample_signature)
    _, _ = task.optimizer.tree_init(self._model_in_training.weights)

    # Differentiate pure_fn with respect to its argument 1 (the weights);
    # with has_aux=True the call returns (gradients, state).
    grad_fn = math.grad(
        self._model_in_training.pure_fn, argnums=1, has_aux=True)
    self._gradients_and_state_fn = math.jit(grad_fn)

    if eval_task is not None:
        model_with_metrics = _model_with_metrics(model, eval_task)
        # Index 1 selects just the eval part of the combined layer.
        self._eval_weights = model_with_metrics.weights[1]
        self._eval_state = model_with_metrics.state[1]
        self._metrics_fn = math.jit(model_with_metrics.pure_fn)
def mapped_update(i, opt_state, batch, state, rng):
    """Multi-device variant of the optimizer update step.

    Gradients are averaged across the 'batch' axis before the optimizer's
    `tree_update` is applied.

    Args:
        i: Step index, forwarded unchanged to `optimizer.tree_update`.
        opt_state: Triple `(weights, slots, opt_params)`.
        batch: Batch forwarded to `model_and_loss_call`.
        state: State forwarded to `model_and_loss_call`.
        rng: Random key; split into a key for this step and one to return.

    Returns:
        A triple of: the result of `optimizer.tree_update`, the auxiliary
        value returned alongside the gradients, and the unused half of the
        split key.
    """
    # We assume all tensors have the first dimension = n_devices.
    params, opt_slots, hyper_params = opt_state
    step_rng, next_rng = jax_random.split(rng)
    compute_grads = math.grad(model_and_loss_call, has_aux=True)
    gradients, updated_state = compute_grads(params, batch, state, step_rng)

    def mean_over_devices(g):
        # Divide by psum(1.0) rather than `n_devices`: `n_devices` only
        # counts devices on this host machine, whereas psum goes over all
        # devices of all hosts (ex: a TPU pod), and the average must cover
        # all of them.
        return math.psum(g, 'batch') / math.psum(1.0, 'batch')

    gradients = jax.tree_util.tree_map(mean_over_devices, gradients)
    new_opt_state = optimizer.tree_update(
        i, gradients, params, opt_slots, hyper_params)
    return new_opt_state, updated_state, next_rng