def fbo(inputs, weights, state, slots, opt_params, rng, step, grads):
  """FBO of the layer."""
  # We need a layer pure_fn but only for inputs and weights.
  def pure_fn_without_state_and_rng(x, w):
    return layer.pure_fn(x, w, state, rng)

  # Calculate the vector-Jacobian product of the reduced pure fn.
  activations, vjp_fn, new_state = fastmath.vjp(
      pure_fn_without_state_and_rng, inputs, weights, has_aux=True)

  # In the loss layer, set gradients to 1 with the dtype of activations=loss.
  if grads is None and stats_name is not None:
    grads = jnp.ones((), dtype=activations.dtype)

  # The vjp function returns gradients with respect to inputs and weights.
  grads_inputs, grads_weights = vjp_fn(grads)

  # For non-trainable layers, return the calculated arguments.
  if _is_empty_tuple(weights):
    stats = {}
    if stats_name is not None:
      stats[stats_name] = activations
    return weights, new_state, slots, grads_inputs, stats

  # In multi-device setting, average gradients from multiple devices.
  if n_devices > 1:
    grads_weights = _average_multidevice_gradients(grads_weights)

  # Run the optimizer.
  new_weights, new_slots, stats = optimizer.tree_update(
      step, grads_weights, weights, slots, opt_params)
  if stats_name is not None:
    stats[stats_name] = activations
  return new_weights, new_state, new_slots, grads_inputs, stats
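# Illustrative sketch (not part of the Trax source): a minimal, self-contained
# version of the FBO pattern above, using plain `jax.vjp` and a hand-written
# SGD step in place of `fastmath.vjp` and `optimizer.tree_update`. The toy
# `toy_pure_fn`, the SGD update and the shapes are assumptions for
# illustration only.
import jax
import jax.numpy as jnp

def toy_pure_fn(x, w, state, rng):  # stand-in for layer.pure_fn
  del rng
  return jnp.dot(x, w), state       # (activations, new_state)

def toy_fbo(x, w, state, rng, grads, lr=0.1):
  # Reduce the pure function to (inputs, weights) only, exactly as `fbo` does.
  fn = lambda x_, w_: toy_pure_fn(x_, w_, state, rng)
  activations, vjp_fn, new_state = jax.vjp(fn, x, w, has_aux=True)
  grads_inputs, grads_weights = vjp_fn(grads)   # cotangents w.r.t. x and w
  new_w = w - lr * grads_weights                # stand-in for tree_update
  return new_w, new_state, grads_inputs, activations

x, w = jnp.ones((2, 3)), jnp.ones((3, 4))
g = jnp.ones((2, 4))                  # gradient flowing in from later layers
new_w, _, gx, acts = toy_fbo(x, w, {}, jax.random.PRNGKey(0), g)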
def reverse_and_grad(self, output, grad, weights=(), state=(), new_state=(),
                     rng=None):
  """Backward pass: computes the inverse of a layer and propagates gradients.

  While you may choose to only implement reverse, some layers implement this
  function directly, as computation may be shared between reversing and
  computing gradients.

  Args:
    output: Output activations; can be a (possibly nested) tuple.
    grad: gradient signal (cotangent) computed based on subsequent layers.
      The structure and shape must match the output.
    weights: layer weights
    state: start state
    new_state: updated state computed by the forward pass
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    A tuple (x, (x_grad, weights_grad)), where x is the reconstructed input,
    x_grad is the gradient signal for the input, and weights_grad is the
    gradient signal for the weights.
  """
  def _do_forward(x, weights):
    old_weights, old_state, old_rng = self.weights, self.state, self._rng
    self.state, self._rng = state, rng
    self.weights = weights
    res = self.forward(x)
    self.weights, self.state, self._rng = old_weights, old_state, old_rng
    return res

  reconstructed_x = self.reverse(output, weights, state, new_state, rng)
  _, vjpfun = fastmath.vjp(_do_forward, reconstructed_x, weights)
  x_weights_grad = vjpfun(grad)
  return reconstructed_x, x_weights_grad
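# Illustrative sketch (not a Trax layer): the reverse-and-grad contract on a
# hypothetical invertible "scale" layer y = w * x. The input is reconstructed
# from the output, and the forward pass is replayed from the reconstructed
# input under `jax.vjp` to obtain the cotangents, mirroring `_do_forward`.
import jax
import jax.numpy as jnp

def scale_forward(x, w):
  return w * x

def scale_reverse(y, w):
  return y / w

def scale_reverse_and_grad(y, grad, w):
  x = scale_reverse(y, w)                    # reconstruct the input
  _, vjp_fn = jax.vjp(scale_forward, x, w)   # replay forward from x
  x_grad, w_grad = vjp_fn(grad)
  return x, (x_grad, w_grad)

y = jnp.array([2.0, 4.0])
x, (x_grad, w_grad) = scale_reverse_and_grad(y, jnp.ones(2), w=2.0)
# x is reconstructed as [1., 2.]; no stored activations were needed.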
def loss_fbo(inputs, weights, state, slots, opt_params, rng, step):
  """FBO of the final loss layer."""
  # We need a loss layer pure_fn but only for inputs and weights.
  def loss_pure_fn_without_state_and_rng(x, w):
    return loss_layer.pure_fn(x, w, state, rng)

  # Calculate the vector-Jacobian product of the reduced loss pure fn.
  loss, vjp_fn, new_state = fastmath.vjp(
      loss_pure_fn_without_state_and_rng, inputs, weights, has_aux=True)

  # The vjp function returns gradients with respect to inputs and weights.
  # Since loss is scalar and there are no other layers, run it at 1.0.
  grads_inputs, grads_weights = vjp_fn(jnp.ones((), dtype=loss.dtype))

  # In multi-device setting, average gradients from multiple devices.
  if self._n_devices > 1:
    grads_weights = _average_multidevice_gradients(grads_weights)

  # Run the loss optimizer, which is the last one since it's the last layer.
  new_weights, new_slots, stats = self._optimizers[-1].tree_update(
      step, grads_weights, weights, slots, opt_params)
  stats['loss'] = loss
  return new_weights, new_state, new_slots, grads_inputs, stats
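# Side note (illustrative, not Trax code): seeding the VJP of a scalar loss
# with `jnp.ones((), dtype=loss.dtype)` is what makes it equal to the plain
# gradient, as `jax.grad` would compute it. A small self-contained check,
# where `toy_loss` and the shapes are assumptions:
import jax
import jax.numpy as jnp

def toy_loss(w, x):
  return jnp.sum((x @ w) ** 2)

w, x = jnp.ones((3, 2)), jnp.ones((4, 3))
loss, vjp_fn = jax.vjp(lambda w_: toy_loss(w_, x), w)
(g_vjp,) = vjp_fn(jnp.ones((), dtype=loss.dtype))   # seed the cotangent at 1
g_grad = jax.grad(toy_loss)(w, x)
assert jnp.allclose(g_vjp, g_grad)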
def first_fbo(inputs, weights, state, slots, opt_params, rng, step, grads):
  """FBO of the first layer."""
  # We need the first layer's pure_fn but only for inputs and weights.
  def first_layer_pure_fn_without_state_and_rng(x, w):
    return first_layer.pure_fn(x, w, state, rng)

  # Calculate vector-Jacobian product of the reduced first layer pure fn.
  activations_after_first_layer, vjp_fn, new_state = fastmath.vjp(
      first_layer_pure_fn_without_state_and_rng, inputs, weights,
      has_aux=True)
  del activations_after_first_layer  # unused

  # The vjp function returns gradients with respect to inputs and weights.
  _, grads_weights = vjp_fn(grads)

  # In multi-device setting, average gradients from multiple devices.
  if self._n_devices > 1:
    grads_weights = _average_multidevice_gradients(grads_weights)

  # Run the first layer optimizer, which is the first one.
  new_weights, new_slots, stats = self._optimizers[0].tree_update(
      step, grads_weights, weights, slots, opt_params)
  return new_weights, new_state, new_slots, stats
def forward_and_or_backward(inputs, weights, state, rng, output_grad=None,
                            compute_output=True, update_state=True):
  """Performs batched forward and/or backward passes.

  Args:
    inputs: inputs to the attention layer
    weights: weights for the attention layer
    state: state of the attention layer
    rng: PRNG key for the layer (shared across all examples and heads)
    output_grad: gradient of the loss wrt the output of the layer, or None.
      This function performs the backward pass iff `output_grad` is not None.
    compute_output: bool: whether to return the output of the forward pass
      (for example, a pure backwards pass does not need to return the output).
    update_state: bool: whether to return an updated layer state.

  Returns:
    A tuple (output, new_state, inputs_grad, weights_grad).
    - output is not None iff compute_output is True
    - new_state is not None iff update_state is True
    - inputs_grad & weights_grad are not None iff output_grad is not None
  """
  # We need a layer pure_fn but only for inputs and weights.
  def pure_fn_without_state_and_rng(x, w):
    return layer.pure_fn(x, w, state, rng)

  # Calculate the vector-Jacobian product of the layer pure_fn.
  output, vjp_fn, new_state = fastmath.vjp(
      pure_fn_without_state_and_rng, inputs, weights, has_aux=True)
  output = output if compute_output else None
  new_state = new_state if update_state else None

  # The vjp function returns gradients with respect to inputs and weights.
  if output_grad is not None:
    grads_inputs, grads_weights = vjp_fn(output_grad)
  else:
    grads_inputs, grads_weights = None, None

  return (output, new_state, grads_inputs, grads_weights)
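# Illustrative sketch (not Trax code): the reason a single `vjp` call suffices
# here is that it returns both the forward output (the residual needed for
# input reconstruction in `reverse_and_grad` below) and a function producing
# the cotangents, so recomputation and backprop share one pass.
# `toy_attention` and the shapes are assumptions:
import jax
import jax.numpy as jnp

def toy_attention(x, w):
  return jnp.tanh(x @ w)

x, w = jnp.ones((2, 4)), jnp.ones((4, 4))
g = jnp.ones((2, 4))                          # cotangent from the accumulator

out, vjp_fn = jax.vjp(toy_attention, x, w)    # like compute_output=True
dx, dw = vjp_fn(g)                            # like output_grad=g

# The same quantities computed separately agree:
assert jnp.allclose(out, toy_attention(x, w))
assert jnp.allclose(
    dw, jax.grad(lambda w_: jnp.vdot(g, toy_attention(x, w_)))(w))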
def reverse_and_grad(self, output, ct, weights=(), state=(), new_state=(),
                     rng=None):
  rngs = _split_rngs(rng, len(self.sublayers))

  accumulator_output, *context = output
  context = tuple(context)
  accumulator_output_ct, *context_ct = ct
  context_ct = tuple(context_ct)

  # Forward pass through self.compute_residual. Outputs that will not receive
  # a gradient signal from subsequent layers are moved to aux.
  def call_compute_residual(x, weights):
    res, _ = self.compute_residual.pure_fn(
        x, weights=weights, state=state[0], rng=rngs[0])
    if not isinstance(res, (tuple, list)):
      return res, None
    else:
      n_differentiable = 1
      if self.attention_layer is not None:
        n_differentiable = min(len(res), self.attention_layer.n_in)
      return res[:n_differentiable], res[n_differentiable:]

  stack = context
  inputs = cb.inputs_from_stack(stack, self.compute_residual.n_in)
  outputs, compute_residual_vjpfun, outputs_aux = fastmath.vjp(
      call_compute_residual, inputs, weights[0], has_aux=True)
  if outputs_aux is not None:
    n_differentiable_outputs = len(outputs)
    outputs = outputs + outputs_aux
  stack = cb.outputs_onto_stack(outputs, stack, self.compute_residual.n_in)

  stack_ct = accumulator_output_ct
  if self.attention_layer is None:
    residual = stack[0] if isinstance(stack, (tuple, list)) else stack
  else:
    inputs = cb.inputs_from_stack(stack, self.attention_layer.n_in)
    (residual, _, attn_inputs_ct, attn_weights_ct
    ) = self._forward_and_or_backward(
        inputs, weights[1], new_state[1], rngs[1],
        output_grad=accumulator_output_ct,
        compute_output=True, update_state=False)
    stack_ct = cb.outputs_onto_stack(
        attn_inputs_ct, stack_ct, self.attention_layer.n_out)

  compute_residual_ct = cb.inputs_from_stack(
      stack_ct, self.compute_residual.n_out)
  if outputs_aux is not None:
    if not isinstance(compute_residual_ct, (tuple, list)):
      compute_residual_ct = (compute_residual_ct,)
    compute_residual_ct = compute_residual_ct[:n_differentiable_outputs]
    assert len(compute_residual_ct) == n_differentiable_outputs
  (compute_residual_inputs_ct, compute_residual_weights_ct
  ) = compute_residual_vjpfun(compute_residual_ct)
  stack_ct = cb.outputs_onto_stack(
      compute_residual_inputs_ct, stack_ct, self.compute_residual.n_out)
  if not isinstance(stack_ct, (tuple, list)):
    stack_ct = (stack_ct,)
  stack_ct = (accumulator_output_ct,) + fastmath.nested_map_multiarg(
      lambda x, y: x + y, context_ct[:len(stack_ct)], stack_ct
      ) + context_ct[len(stack_ct):]

  reconstructed_x = accumulator_output - residual
  stack = (reconstructed_x,) + context
  if self.attention_layer is None:
    weights_ct = (compute_residual_weights_ct,)
  else:
    weights_ct = (compute_residual_weights_ct, attn_weights_ct)
  return stack, (stack_ct, weights_ct)
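# Side note (illustrative, not Trax code): the `context` entries feed the
# residual branch (and, when present, the attention branch) while also passing
# through unchanged to later layers, so the elementwise `x + y` over
# `context_ct` accumulates the per-path cotangents. The branch functions below
# are assumptions:
import jax
import jax.numpy as jnp

def branch_a(c):
  return jnp.sum(jnp.sin(c))

def branch_b(c):
  return jnp.sum(c ** 2)

c = jnp.array([0.5, 1.0, 1.5])
total = jax.grad(lambda c_: branch_a(c_) + branch_b(c_))(c)
summed = jax.grad(branch_a)(c) + jax.grad(branch_b)(c)
assert jnp.allclose(total, summed)   # shared input => cotangents add up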
def reverse_and_grad(self, output, ct, weights=(), state=(), new_state=(),
                     rng=None):
  rngs = _split_rngs(rng, len(self.sublayers))

  accumulator_output, *context = output
  context = tuple(context)
  accumulator_output_ct, *context_ct = ct
  context_ct = tuple(context_ct)

  # Forward pass through self._compute_residual. Outputs that will not
  # receive a gradient signal from subsequent layers are moved to aux.
  def call_compute_residual(x, weights):
    state_to_pass = state[0]  # old_state

    # _replace_second_time is currently used exclusively in _RememberInReverse
    # layer to combat numerical instability in Reformer2 when quantizing
    # the mask in SparseFF.
    def _replace_second_time(stt, nstt):
      if (isinstance(stt, tuple) and len(stt) == 2 and
          isinstance(stt[1], dict) and 'running_second_time' in stt[1]):
        return (nstt[0], {'running_second_time_yes': ()})
      elif isinstance(stt, (tuple, list)):
        assert isinstance(nstt, (tuple, list)) and len(nstt) == len(stt)
        return type(stt)([
            _replace_second_time(s, ns) for s, ns in zip(stt, nstt)
        ])
      else:
        return stt

    state_to_pass = _replace_second_time(state_to_pass, new_state[0])
    res, _ = self._compute_residual.pure_fn(
        x, weights=weights, state=state_to_pass, rng=rngs[0])
    if not isinstance(res, (tuple, list)):
      return res, None
    else:
      n_differentiable = 1
      if self._attention_layer is not None:
        n_differentiable = min(len(res), self._attention_layer.n_in)
      return res[:n_differentiable], res[n_differentiable:]

  stack = context
  inputs = cb.inputs_from_stack(stack, self._compute_residual.n_in)
  outputs, compute_residual_vjpfun, outputs_aux = fastmath.vjp(
      call_compute_residual, inputs, weights[0], has_aux=True)
  if outputs_aux is not None:
    n_differentiable_outputs = len(outputs)
    outputs = outputs + outputs_aux
  stack = cb.outputs_onto_stack(outputs, stack, self._compute_residual.n_in)

  stack_ct = accumulator_output_ct
  if self._attention_layer is None:
    residual = stack[0] if isinstance(stack, (tuple, list)) else stack
  else:
    inputs = cb.inputs_from_stack(stack, self._attention_layer.n_in)
    (residual, _, attn_inputs_ct, attn_weights_ct
    ) = self._forward_and_or_backward(
        inputs, weights[1], new_state[1], rngs[1],
        output_grad=accumulator_output_ct,
        compute_output=True, update_state=False)
    stack_ct = cb.outputs_onto_stack(
        attn_inputs_ct, stack_ct, self._attention_layer.n_out)

  compute_residual_ct = cb.inputs_from_stack(
      stack_ct, self._compute_residual.n_out)
  if outputs_aux is not None:
    if not isinstance(compute_residual_ct, (tuple, list)):
      compute_residual_ct = (compute_residual_ct,)
    compute_residual_ct = compute_residual_ct[:n_differentiable_outputs]
    assert len(compute_residual_ct) == n_differentiable_outputs
  (compute_residual_inputs_ct, compute_residual_weights_ct
  ) = compute_residual_vjpfun(compute_residual_ct)
  stack_ct = cb.outputs_onto_stack(
      compute_residual_inputs_ct, stack_ct, self._compute_residual.n_out)
  if not isinstance(stack_ct, (tuple, list)):
    stack_ct = (stack_ct,)

  def _add(x, y):
    # `None` is for TFNP backend, which uses `None` as the gradient of
    # int/bool instead of an array of dtype `float0`.
    if x is None or x.dtype == jax.float0:
      return y
    if y is None or y.dtype == jax.float0:
      return x
    return x + y

  stack_ct = (accumulator_output_ct,) + fastmath.nested_map_multiarg(
      _add, context_ct[:len(stack_ct)], stack_ct
      ) + context_ct[len(stack_ct):]

  reconstructed_x = accumulator_output - residual
  stack = (reconstructed_x,) + context
  if self._attention_layer is None:
    weights_ct = (compute_residual_weights_ct,)
  else:
    weights_ct = (compute_residual_weights_ct, attn_weights_ct)
  return stack, (stack_ct, weights_ct)
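# Side note (illustrative, not Trax code): both versions of reverse_and_grad
# rest on the reversible-residual identity. The forward pass computed
# accumulator_output = x + residual(context), so given the saved output and a
# recomputation of the residual, the input is recovered as
# x = accumulator_output - residual, with no stored activations. `residual_fn`
# below is an assumption for illustration; `_add` above additionally skips
# `None`/`float0` cotangents of non-differentiable (int/bool) context entries.
import jax.numpy as jnp

def residual_fn(context):              # stand-in for compute_residual
  return jnp.sin(context)

x = jnp.array([1.0, 2.0, 3.0])
context = jnp.array([0.1, 0.2, 0.3])

accumulator_output = x + residual_fn(context)                 # forward
reconstructed_x = accumulator_output - residual_fn(context)   # reverse
assert jnp.allclose(reconstructed_x, x)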