def forward_with_state(self, xs, weights, state, rng):
  """Forward pass that stochastically skips trailing layers during training."""
  self._validate_forward_inputs(xs)
  (step, layers_state) = state
  # Get N+1 rngs, N for running layers and one extra.
  rngs = _split_rngs(rng, self._n_layers + 1)
  rng0, rngs = rngs[0], rngs[1:]
  if not self.sublayers:  # No-op: leave args unchanged.
    return (xs, (step + 1, layers_state))

  # Prepare the stack and do some safety checks as in the parent class.
  stack = xs
  new_state = []
  n_layers = self._n_layers
  if n_layers != 1 and len(weights) != n_layers:
    raise ValueError(
        'number of weights ({}) not equal to number of layers ({})'.format(
            len(weights), n_layers))
  if n_layers != 1 and len(layers_state) != n_layers:
    raise ValueError(
        'length of state ({}) not equal to number of layers ({})'.format(
            len(layers_state), n_layers))

  # TODO(chowdhery): try different strategies, also try running not all
  # layers backwards by using math.stop_gradient where needed.

  # Calculate how many layers to run forward.
  if self._mode == 'train':
    # warmup goes from 1.0 at start to 0.0 at skipping_warmup_steps and after.
    w_steps = float(self._skipping_warmup_steps)
    warmup = np.maximum(0.0, (w_steps - step.astype(np.float32)) / w_steps)
    # low is the minimum number of layers to *not* skip, from n_layers to 0.
    low = warmup * float(n_layers)
    # high should be so that (high - n_layers) / high = 1.0 - skip_fraction,
    # because (high - n_layers) / high is the probability we're not skipping
    # (after warmup); so high - n_layers = high - high * skip_fraction.
    high = float(n_layers) / self._skip_fraction
    # We want the same rng0 on all cores.
    if math.device_count() > 1:
      rng0 = math.psum(rng0, 'batch')
    n_forward_layers = random.uniform(rng0, (), np.float32, low, high)
  else:
    n_forward_layers = float(n_layers)

  # Run layers, skipping the ones whose index is past n_forward_layers.
  cur_layer_idx = 0.0
  for layer, p, s, rng in zip(self.sublayers, weights, layers_state, rngs):
    inputs = _inputs_from_stack(layer, stack)
    outputs, s = math.cond(  # Skip (do identity) if >= n_forward_layers.
        pred=(math.lt(cur_layer_idx, n_forward_layers)),
        true_operand=(inputs, p, s, rng),  # This tuple is t below.
        true_fun=(lambda t: layer.pure_fn(t[0], t[1], t[2], t[3])),  # pylint: disable=cell-var-from-loop
        false_operand=(inputs, p, s, rng),
        false_fun=(lambda t: (t[0], t[2])),  # return (inputs, state)
    )
    stack = _outputs_onto_stack(layer, outputs, stack)
    new_state.append(s)
    cur_layer_idx += 1.0
  return stack, (step + 1, new_state)
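# --- Illustrative sketch (not part of the method above) ----------------------
# How the low/high bounds of the uniform draw behave over training, under
# assumed values n_layers=6, skip_fraction=0.1 and 100 warmup steps. `low`
# starts at n_layers (nothing can be skipped) and decays to 0 over the warmup;
# `high` is chosen so that, after warmup, the draw exceeds n_layers (i.e. no
# layer is skipped) with probability 1 - skip_fraction. The helper name
# `skipping_bounds` is hypothetical.
def skipping_bounds(step, n_layers=6, skip_fraction=0.1, warmup_steps=100):
  w_steps = float(warmup_steps)
  warmup = max(0.0, (w_steps - float(step)) / w_steps)
  low = warmup * float(n_layers)          # fewest layers guaranteed to run
  high = float(n_layers) / skip_fraction  # upper end of the uniform draw
  return low, high

for _step in (0, 50, 100, 200):
  print(_step, skipping_bounds(_step))  # low goes 6.0 -> 3.0 -> 0.0 -> 0.0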
def forward_with_state(self, inputs, weights=(), state=(), rng=None):
  if self._n_sections == 1:
    results = self._layer(inputs, weights=weights, state=state, rng=rng)
  else:
    rngs = _split_rngs(rng, len(inputs))
    results = [self._layer(x, weights=weights, state=state, rng=r)
               for x, r in zip(inputs, rngs)]
    results = tuple(results)
  # TODO(kitaev): think about how to merge state across copies in the map.
  return results, self._layer.state
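# --- Illustrative sketch (not part of the method above) ----------------------
# A minimal JAX version of the map-with-per-copy-rngs pattern used above: each
# section gets its own key and the same function is applied to every section.
# The `_noisy_scale` helper and the example data are hypothetical.
import jax
import jax.numpy as jnp

def _noisy_scale(x, rng):
  return x * (1.0 + 0.01 * jax.random.normal(rng, x.shape))

_sections = (jnp.ones(3), jnp.zeros(3), jnp.arange(3.0))
_rngs = jax.random.split(jax.random.PRNGKey(0), len(_sections))
_mapped = tuple(_noisy_scale(x, r) for x, r in zip(_sections, _rngs))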
def forward(self, xs):
  """Computes the residual from the context and adds it to the accumulator."""
  rngs = _split_rngs(self.rng, len(self.sublayers))
  # The first input is the accumulator; the rest is read-only context.
  accumulator, *context = xs
  stack = context = tuple(context)
  new_state = []
  # Run the sublayers on the context stack only; the accumulator is untouched.
  for layer, w, s, rng in zip(self.sublayers, self.weights, self.state, rngs):
    inputs = _inputs_from_stack(layer, stack)
    outputs, s = layer.pure_fn(inputs, w, s, rng)
    stack = _outputs_onto_stack(layer, outputs, stack)
    new_state.append(s)
  residual = stack[0] if isinstance(stack, (tuple, list)) else stack
  output = accumulator + residual
  stack = (output,) + context
  self.state = tuple(new_state)
  return stack
def forward_with_state(self, xs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None):
  rngs = _split_rngs(rng, len(self.sublayers))
  accumulator, *context = xs
  stack = context = tuple(context)
  new_state = []
  for layer, w, s, rng in zip(self.sublayers, weights, state, rngs):
    inputs = _inputs_from_stack(layer, stack)
    outputs, s = layer.pure_fn(inputs, w, s, rng)
    stack = _outputs_onto_stack(layer, outputs, stack)
    new_state.append(s)
  residual = stack[0] if isinstance(stack, (tuple, list)) else stack
  output = accumulator + residual
  stack = (output,) + context
  return stack, new_state
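# --- Illustrative sketch (not part of the methods above) ---------------------
# The bookkeeping shared by the two forward passes above: the first stack
# element is an accumulator, the residual branch reads only the remaining
# context, and the layer's input can therefore be recovered later as
# `output - residual` (which is what reverse_and_grad below relies on).
# `_residual_branch` is a hypothetical stand-in for the sublayers.
import numpy as onp

def _residual_branch(context):
  return 0.5 * context

_accumulator = onp.ones(3)
_context = onp.arange(3.0)

_output = _accumulator + _residual_branch(_context)   # forward direction
_recovered = _output - _residual_branch(_context)     # reverse direction
assert onp.allclose(_recovered, _accumulator)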
def reverse_and_grad(self, output, ct, weights=(), state=(), new_state=(),
                     rng=None):
  rngs = _split_rngs(rng, len(self.sublayers))

  accumulator_output, *context = output
  context = tuple(context)
  accumulator_output_ct, *context_ct = ct
  context_ct = tuple(context_ct)

  # Forward pass through self.compute_residual. Outputs that will not receive
  # a gradient signal from subsequent layers are moved to aux.
  def call_compute_residual(x, weights):
    res, _ = self.compute_residual.pure_fn(
        x, weights=weights, state=state[0], rng=rngs[0])
    if not isinstance(res, (tuple, list)):
      return res, None
    else:
      n_differentiable = 1
      if self.attention_layer is not None:
        n_differentiable = min(len(res), self.attention_layer.n_in)
      return res[:n_differentiable], res[n_differentiable:]

  stack = context
  inputs = _inputs_from_stack(self.compute_residual, stack)
  outputs, compute_residual_vjpfun, outputs_aux = fastmath.vjp(
      call_compute_residual, inputs, weights[0], has_aux=True)
  if outputs_aux is not None:
    n_differentiable_outputs = len(outputs)
    outputs = outputs + outputs_aux
  stack = _outputs_onto_stack(self.compute_residual, outputs, stack)

  stack_ct = accumulator_output_ct
  if self.attention_layer is None:
    residual = stack[0] if isinstance(stack, (tuple, list)) else stack
  else:
    inputs = _inputs_from_stack(self.attention_layer, stack)
    (residual, _, attn_inputs_ct, attn_weights_ct
    ) = self.attention_layer.forward_and_or_backward(
        inputs, weights[1], new_state[1], rngs[1],
        output_grad=accumulator_output_ct,
        compute_output=True, update_state=False)
    stack_ct = _outputs_onto_stack(
        self.attention_layer, attn_inputs_ct, stack_ct,
        self.attention_layer.n_out, self.attention_layer.n_in)

  compute_residual_ct = _inputs_from_stack(
      self.compute_residual, stack_ct, self.compute_residual.n_out)
  if outputs_aux is not None:
    if not isinstance(compute_residual_ct, (tuple, list)):
      compute_residual_ct = (compute_residual_ct,)
    compute_residual_ct = compute_residual_ct[:n_differentiable_outputs]
    assert len(compute_residual_ct) == n_differentiable_outputs
  (compute_residual_inputs_ct, compute_residual_weights_ct
  ) = compute_residual_vjpfun(compute_residual_ct)
  stack_ct = _outputs_onto_stack(
      self.compute_residual, compute_residual_inputs_ct, stack_ct,
      self.compute_residual.n_out, self.compute_residual.n_in)
  if not isinstance(stack_ct, (tuple, list)):
    stack_ct = (stack_ct,)
  stack_ct = (accumulator_output_ct,) + fastmath.nested_map_multiarg(
      lambda x, y: x + y, context_ct[:len(stack_ct)], stack_ct
  ) + context_ct[len(stack_ct):]

  reconstructed_x = accumulator_output - residual
  stack = (reconstructed_x,) + context
  if self.attention_layer is None:
    weights_ct = (compute_residual_weights_ct,)
  else:
    weights_ct = (compute_residual_weights_ct, attn_weights_ct)
  return stack, (stack_ct, weights_ct)
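# --- Illustrative sketch (not part of the method above) ----------------------
# The fastmath.vjp(..., has_aux=True) call above returns the branch's
# differentiable outputs, a function mapping output cotangents back to input
# and weight cotangents, and the auxiliary (non-differentiated) outputs. A
# minimal equivalent written directly against jax; `_branch` and the example
# data are hypothetical.
import jax
import jax.numpy as jnp

def _branch(x, w):
  y = jnp.tanh(x * w)
  aux = jnp.sign(y)  # extra output that receives no gradient signal
  return y, aux

_x, _w = jnp.ones(3), jnp.array(2.0)
_y, _branch_vjp, _aux = jax.vjp(_branch, _x, _w, has_aux=True)
_x_ct, _w_ct = _branch_vjp(jnp.ones_like(_y))  # cotangents w.r.t. x and w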