def test_call_and_grad(self):
  layer_partial = tl.Serial(
      tl.Branch(tl.Embedding(3, 4), tl.PaddingMask()),
      sparsity.Favor(d_feature=4, n_heads=2),
      tl.Select([0], n_in=2),
  )
  layer = tl.Serial(
      tl.Branch(tl.Embedding(3, 4), tl.PaddingMask()),
      sparsity.Favor(d_feature=4, n_heads=2),
      tl.Select([0], n_in=2),
      tl.WeightedCategoryCrossEntropy(),
  )
  x = np.ones((1, 2), dtype=np.int32)
  w = np.ones_like(x).astype(np.float32)
  x_sig = shapes.signature(x)
  w_sig = shapes.signature(w)
  layer_partial.init(x_sig)
  y = layer_partial(x)
  self.assertEqual(y.shape, (1, 2, 4))
  layer.init((x_sig, x_sig, w_sig))
  y = layer((x, x, w))
  self.assertEqual(y.shape, ())
  state = layer.state
  rng = fastmath.random.get_prng(0)
  fwd = lambda weights, inp: layer.pure_fn(inp, weights, state, rng=rng)[0]
  g = fastmath.grad(fwd)(layer.weights, (x, x, w))
  self.assertEqual(g[0][1][0].shape, (3, 4))
def test_custom_zero_grad(self, backend):

  class IdWithZeroGrad(tl.Layer):

    def forward(self, x):
      return x

    @property
    def has_backward(self):
      return True

    def backward(self, inputs, output, grad, weights, state, new_state, rng):
      return (jnp.zeros_like(grad), ())

  with fastmath.use_backend(backend):
    layer = IdWithZeroGrad()
    rng = fastmath.random.get_prng(0)
    input_signature = shapes.ShapeDtype((9, 17))
    random_input = fastmath.random.uniform(
        rng, input_signature.shape, minval=-1.0, maxval=1.0)
    layer.init(input_signature)
    f = lambda x: jnp.mean(layer(x))
    grad = fastmath.grad(f)(random_input)
    self.assertEqual(grad.shape, (9, 17))  # Gradient for each input.
    self.assertEqual(sum(sum(grad * grad)), 0.0)  # Each one is 0.
def single_update(weights_and_slots, i, opt_params, batch, state, rng):
  """Updates weights and optimizer slots for one step on a single device."""
  weights, slots = weights_and_slots
  rng, subrng = jax_random.split(rng[0])
  grad_fn = fastmath.grad(model_and_loss_call, has_aux=True)
  grads, state = grad_fn(weights, batch, state, rng)
  new_weights, new_slots, stats = optimizer.tree_update(
      i, grads, weights, slots, opt_params)
  return (new_weights, new_slots), stats, state, [subrng]
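# A minimal usage sketch, not part of the original trainer code. In the
# surrounding code `single_update` closes over `model_and_loss_call` and
# `optimizer`, so this only assumes that `weights`, `slots`, `opt_params`,
# `batch`, `state`, `step`, and a single-element `rng` list already exist.
# The function is typically jit-compiled once and then called on every step.
update_fn = fastmath.jit(single_update)
(weights, slots), stats, state, rng = update_fn(
    (weights, slots), step, opt_params, batch, state, rng)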
def mock_training_step(x, weights, state, rng):
  def compute_mock_loss(weights):
    logits, new_state = model.pure_fn(x, weights, state, rng)
    loss = fastmath.numpy.mean(logits[..., 0])
    return loss, (new_state, logits)
  # has_aux=True: compute_mock_loss returns (loss, aux), so grad returns
  # (gradients, aux) where aux is the (new_state, logits) pair.
  gradients, (new_state, logits) = fastmath.grad(
      compute_mock_loss, has_aux=True)(weights)
  new_weights = fastmath.nested_map_multiarg(
      lambda w, g: w - 1e-4 * g, weights, gradients)
  return new_weights, new_state, logits
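# Illustrative only (assumes `model` has been initialized and `x`, `rng` are
# available as in the snippet above): run a few gradient-descent steps,
# threading the updated weights and state through each call.
weights, state = model.weights, model.state
for _ in range(3):
  weights, state, logits = mock_training_step(x, weights, state, rng)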
def __init__(self, model, task, eval_model=None, eval_task=None,
             output_dir=None, checkpoint_at=None, eval_at=None):
  """Configures a training `Loop`, including a random initialization.

  Args:
    model: Trax layer, representing the core model to be trained. Loss
        functions and eval functions (a.k.a. metrics) are considered to be
        outside the core model, taking core model output and data labels as
        their two inputs.
    task: TrainTask instance, which defines the training data, loss function,
        and optimizer to be used in this training loop.
    eval_model: Optional Trax layer, representing the model used for
        evaluation, e.g., with dropout turned off. If None, the training
        model (`model`) will be used.
    eval_task: EvalTask instance or None. If None, don't do any evals.
    output_dir: Path telling where to save outputs (evals and checkpoints).
        Can be None if both `eval_task` and `checkpoint_at` are None.
    checkpoint_at: Function (integer --> boolean) telling, for step n, whether
        that step should have its checkpoint saved. If None, the default is
        periodic checkpointing at `task.n_steps_per_checkpoint`.
    eval_at: Function (integer --> boolean) that says, for training step n,
        whether that step should run evals. If None, run when checkpointing.
  """
  self._task = task
  self._model = model
  self._model_in_training = tl.Serial(model, task.loss_layer)
  self._eval_model = model if eval_model is None else eval_model
  self._eval_task = eval_task
  self._output_dir = os.path.expanduser(output_dir) if output_dir else None
  default_fn = _at_step_1_and_periodically_at(task.n_steps_per_checkpoint)
  self._checkpoint_at = checkpoint_at or default_fn
  self._eval_at = eval_at or default_fn
  if eval_task is None:
    self._eval_at = _never
  self._step = 0

  batch_signature = shapes.signature(task.sample_batch)
  self._batch_signature = batch_signature
  # Initialize the model and the optimizer; discard the return values
  # (model weights/state, optimizer slots/params), since they're available
  # from the model and optimizer objects.
  _, _ = self._model_in_training.init(batch_signature)
  _, _ = task.optimizer.tree_init(self._model_in_training.weights)

  self._gradients_and_state_fn = (
      fastmath.jit(fastmath.grad(self._model_in_training.pure_fn,
                                 argnums=1,  # arg1 of pure_fn: weights
                                 has_aux=True)))  # return (gradients, state)

  if eval_task is not None:
    model_with_metrics = _model_with_metrics(self._eval_model, eval_task)
    self._eval_weights = model_with_metrics.weights[1]  # just the eval part
    self._eval_state = model_with_metrics.state[1]  # just the eval part
    self._metrics_fn = fastmath.jit(model_with_metrics.pure_fn)
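# A sketch of how the jitted gradients-and-state function set up above could
# be consumed in one training step. This is illustrative, not Loop's actual
# step implementation: `loop`, `batch`, `rng`, `step`, `slots`, and
# `opt_params` are assumed to exist, and the optimizer call mirrors the
# `tree_update` convention used elsewhere in this section.
def run_one_step_sketch(loop, batch, rng, step, weights, state, slots,
                        opt_params):
  # grad(pure_fn, argnums=1, has_aux=True): differentiate the scalar loss
  # w.r.t. weights; the auxiliary output is the updated layer state.
  gradients, new_state = loop._gradients_and_state_fn(
      batch, weights, state, rng)
  new_weights, new_slots, stats = loop._task.optimizer.tree_update(
      step, gradients, weights, slots, opt_params)
  return new_weights, new_state, new_slots, stats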
def test_causal_call_and_grad(self):
  layer = tl.Serial(tl.Dense(4),
                    sparsity.CausalFavor(d_feature=4, n_heads=2),
                    tl.L2Loss())
  x = np.random.uniform(size=(1, 2, 4)).astype(np.float32)
  w = np.ones_like(x)
  x_sig = shapes.signature(x)
  w_sig = shapes.signature(w)
  layer.init((x_sig, x_sig, w_sig))
  y = layer((x, x, w))
  self.assertEqual(y.shape, ())
  state = layer.state
  rng = fastmath.random.get_prng(0)
  fwd = lambda weights, inp: layer.pure_fn(inp, weights, state, rng=rng)[0]
  g = fastmath.grad(fwd)(layer.weights, (x, x, w))
  self.assertEqual(g[0][0].shape, (4, 4))
def mapped_update(weights_and_slots, i, opt_params, batch, state, rng):
  """This is a multi-device version of the update function above."""
  # We assume all tensors have the first dimension = n_devices.
  weights, slots = weights_and_slots
  rng, subrng = jax_random.split(rng)
  grad_fn = fastmath.grad(model_and_loss_call, has_aux=True)
  grads, state = grad_fn(weights, batch, state, rng)
  # We do a psum(1.0) here instead of `n_devices` since `n_devices` is just
  # the number of devices on this host machine, however psum goes over all
  # devices of all hosts (ex: a TPU pod) and we need to be averaging over all
  # of them.
  grads = jax.tree_util.tree_map(
      lambda g: (  # pylint: disable=g-long-lambda
          fastmath.psum(g, 'batch') / fastmath.psum(np.array(1.0), 'batch')),
      grads)
  new_weights, new_slots, stats = optimizer.tree_update(
      i, grads, weights, slots, opt_params)
  return (new_weights, new_slots), stats, state, subrng
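# Sketch, under the assumption that fastmath.pmap forwards `axis_name` to
# jax.pmap: `mapped_update` only makes sense inside a pmap whose axis name
# matches the 'batch' axis used by fastmath.psum above, so gradients are
# averaged over every participating device (across all hosts).
mapped_update_fn = fastmath.pmap(mapped_update, axis_name='batch')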
def test_custom_id_grad(self, backend):
  # After changes to some fastmath.custom_vjp functions (made so that we could
  # land JAX PR #4008), this test started failing, with a ValueError from
  # TensorFlow: ValueError: ('custom_gradient function expected to return', 2,
  # 'gradients but returned', 3, 'instead.').
  # TODO(mattjj,lukaszkaiser): revive this test after landing #4008
  if backend == fastmath.Backend.TFNP:
    raise unittest.SkipTest('temporarily skipping test so that we can '
                            'land https://github.com/google/jax/pull/4008')

  class IdWithIdGrad(tl.Layer):

    def forward(self, x):
      return x

    @property
    def has_backward(self):
      return True

    def backward(self, inputs, output, grad, weights, state, new_state, rng):
      return (inputs, ())

  with fastmath.use_backend(backend):
    layer = IdWithIdGrad()
    rng = fastmath.random.get_prng(0)
    input_signature = shapes.ShapeDtype((9, 17))
    random_input = fastmath.random.uniform(
        rng, input_signature.shape, minval=-1.0, maxval=1.0)
    layer.init(input_signature)
    f = lambda x: jnp.mean(layer(x))
    grad = fastmath.grad(f)(random_input)
    self.assertEqual(grad.shape, (9, 17))  # Gradient for each input.
    self.assertEqual(sum(sum(grad)), sum(sum(random_input)))  # Same as input.