Example #1
  def new_weights_and_state(self, input_signature):
    # The first input is the residual accumulator; the remaining signatures
    # form the data stack seen by the residual branch.
    stack = input_signature[1:]
    if len(stack) == 1:
      stack = stack[0]

    # Initialize the residual branch, then thread its abstract outputs through
    # the stack so the attention layer (if any) sees the right signatures.
    inputs = _inputs_from_stack(self.compute_residual, stack)
    weights, state = self.compute_residual.init(inputs)
    outputs, _ = self.compute_residual._forward_abstract(inputs)
    stack = _outputs_onto_stack(self.compute_residual, outputs, stack)

    if self.attention_layer is None:
      return (weights,), (state,)
    else:
      inputs = _inputs_from_stack(self.attention_layer, stack)
      attn_weights, attn_state = self.attention_layer.init(inputs)
      return (weights, attn_weights), (state, attn_state)
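The helpers _inputs_from_stack and _outputs_onto_stack recur throughout these examples. As a rough mental model, a layer reads its top n_in items from the data stack and its outputs replace those items. The sketch below captures only that assumed semantics (it is not the library's implementation, and the variants in Examples #9 and #10 additionally pass explicit n_out/n_in arguments to shift the gradient stack by the same amounts):

def _inputs_from_stack_sketch(layer, stack):
  """Take the top `layer.n_in` items; a single item is passed unwrapped."""
  if not isinstance(stack, (tuple, list)):
    stack = (stack,)
  return stack[0] if layer.n_in == 1 else stack[:layer.n_in]

def _outputs_onto_stack_sketch(layer, outputs, stack):
  """Pop `layer.n_in` items and push the layer's outputs in their place."""
  if not isinstance(stack, (tuple, list)):
    stack = (stack,)
  if not isinstance(outputs, (tuple, list)):
    outputs = (outputs,)
  return tuple(outputs) + tuple(stack[layer.n_in:])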
Example #2
    def forward_with_state(self, xs, weights, state, rng):
        self._validate_forward_inputs(xs)
        (step, layers_state) = state
        # Get N+1 rngs, N for running layers and one extra.
        rngs = _split_rngs(rng, self._n_layers + 1)
        rng0, rngs = rngs[0], rngs[1:]
        if not self.sublayers:  # No-op: leave args unchanged.
            return (xs, (step + 1, layers_state))

        # Prepare the stack and do some safety checks as in the parent class.
        stack = xs
        new_state = []
        n_layers = self._n_layers
        if n_layers != 1 and len(weights) != n_layers:
            raise ValueError(
                'number of weights ({}) not equal to number of layers '
                '({})'.format(len(weights), n_layers))
        if n_layers != 1 and len(layers_state) != n_layers:
            raise ValueError(
                'length of state ({}) not equal to number of layers '
                '({})'.format(len(layers_state), n_layers))

        # TODO(chowdhery): try different strategies, also try running not all
        # layers backwards by using math.stop_gradient where needed.

        # Calculate how many layers to run forward.
        if self._mode == 'train':
            # warmup goes from 1.0 at start to 0.0 at skipping_warmup_steps and after
            w_steps = float(self._skipping_warmup_steps)
            warmup = np.maximum(0.0,
                                (w_steps - step.astype(np.float32)) / w_steps)
            # low is the minimum number of layers to *not* skip, from n_layers to 0
            low = warmup * float(n_layers)
            # high should be so that (high - n_layers) / high = 1.0 - skip_fraction
            # because (high - n_layers) / high is the probability we're not skipping
            # (after warmup); so high - n_layers = high - high * skip_fraction
            high = float(n_layers) / self._skip_fraction
            # We want the same rng0 on all cores.
            if math.device_count() > 1:
                rng0 = math.psum(rng0, 'batch')
            n_forward_layers = random.uniform(rng0, (), np.float32, low, high)
        else:
            n_forward_layers = float(n_layers)
        # Run the layers, treating any whose index is >= n_forward_layers as identity.
        cur_layer_idx = 0.0
        for layer, p, s, rng in zip(self.sublayers, weights, layers_state,
                                    rngs):
            inputs = _inputs_from_stack(layer, stack)
            outputs, s = math.cond(  # Skip (do identity) if > n_forward_layers.
                pred=(math.lt(cur_layer_idx, n_forward_layers)),
                true_operand=(inputs, p, s, rng),  # This tuple is t below.
                true_fun=(lambda t: layer.pure_fn(t[0], t[1], t[2], t[3])),  # pylint: disable=cell-var-from-loop
                false_operand=(inputs, p, s, rng),
                false_fun=(lambda t: (t[0], t[2])),  # return (inputs, state)
            )
            stack = _outputs_onto_stack(layer, outputs, stack)
            new_state.append(s)
            cur_layer_idx += 1.0
        return stack, (step + 1, new_state)
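To make the warmup arithmetic above concrete, here is a small standalone sketch with hypothetical values for n_layers, skipping_warmup_steps, and skip_fraction (the names and numbers are illustrative only, not taken from the library):

n_layers = 12
skipping_warmup_steps = 1000
skip_fraction = 0.4

def skip_bounds(step):
  # Warmup decays from 1.0 at step 0 to 0.0 at skipping_warmup_steps and after.
  w_steps = float(skipping_warmup_steps)
  warmup = max(0.0, (w_steps - float(step)) / w_steps)
  # low: minimum number of layers that must run; decays from n_layers to 0.
  low = warmup * float(n_layers)
  # high: chosen so that, after warmup, P(n_forward_layers >= n_layers)
  # = (high - n_layers) / high = 1.0 - skip_fraction.
  high = float(n_layers) / skip_fraction
  return low, high

low, high = skip_bounds(0)     # low == n_layers: every layer runs early in training.
low, high = skip_bounds(2000)  # low == 0.0: skipping now happens at the target rate.

Drawing n_forward_layers uniformly from [low, high) and skipping every layer whose index exceeds it reproduces the schedule used in the method above.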
Example #3
    def forward_with_state(self,
                           xs,
                           weights=tl.EMPTY_WEIGHTS,
                           state=tl.EMPTY_STATE,
                           **kwargs):
        self._validate_forward_inputs(xs)
        # Get N+1 rngs, N for running layers and one extra.
        rngs = _pop_rng_and_split(kwargs, self._n_layers + 1)
        rng0, rngs = rngs[0], rngs[1:]
        if not self.sublayers:  # No-op: leave args unchanged.
            return (xs, state)

        # Prepare the stack and do some safety checks as in the parent class.
        stack = xs
        new_state = []
        n_layers = self._n_layers
        if n_layers != 1 and len(weights) != n_layers:
            raise ValueError(
                'number of weights ({}) not equal to number of layers '
                '({})'.format(len(weights), n_layers))
        if n_layers != 1 and len(state) != n_layers:
            raise ValueError(
                'length of state ({}) not equal to number of layers '
                '({})'.format(len(state), n_layers))

        # TODO(chowdhery): try different strategies, also try running not all
        # layers backwards by using math.stop_gradient where needed.

        # Calculate how many layers to run forward.
        if self._mode == 'train':
            n_forward_layers = random.uniform(rng0, (), np.float32, 0.0,
                                              n_layers)
        else:
            n_forward_layers = float(n_layers)
        # Run the layers, treating any whose index is >= n_forward_layers as identity.
        cur_layer_idx = 0.0
        for layer, p, s, rng in zip(self.sublayers, weights, state, rngs):
            inputs = _inputs_from_stack(layer, stack)
            outputs, s = math.cond(  # Skip (do identity) if > n_forward_layers.
                pred=(math.lt(cur_layer_idx, n_forward_layers)),
                true_operand=(inputs, p, s, rng),  # This tuple is t below.
                true_fun=(lambda t: layer.pure_fn(t[0], t[1], t[2], t[3])),  # pylint: disable=cell-var-from-loop
                false_operand=(inputs, p, s, rng),
                false_fun=(lambda t: (t[0], t[2])),  # return (inputs, state)
            )
            stack = _outputs_onto_stack(layer, outputs, stack)
            new_state.append(s)
            cur_layer_idx += 1.0
        return stack, new_state
Example #4
  def _set_input_signature_recursive(self, input_signature):
    """Sets input signatures for this layer and sublayers, recursively.

    Args:
      input_signature: A `ShapeDtype` instance (if this layer takes one input)
          or a list/tuple of `ShapeDtype` instances.
    """
    self._input_signature = input_signature

    # Infer shapes and dtypes (signatures) through the successive sublayers.
    stack = input_signature[1:]
    for layer in self.sublayers:
      inputs = _inputs_from_stack(layer, stack)
      layer._set_input_signature_recursive(inputs)
      outputs, _ = layer._forward_abstract(inputs)
      stack = _outputs_onto_stack(layer, outputs, stack)
Example #5
  def forward(self, xs):
    rngs = _split_rngs(self.rng, len(self.sublayers))
    # The first input is the accumulator; the sublayers only see the context.
    accumulator, *context = xs
    stack = context = tuple(context)
    new_state = []
    for layer, w, s, rng in zip(self.sublayers, self.weights, self.state, rngs):
      inputs = _inputs_from_stack(layer, stack)
      outputs, s = layer.pure_fn(inputs, w, s, rng)
      stack = _outputs_onto_stack(layer, outputs, stack)
      new_state.append(s)
    residual = stack[0] if isinstance(stack, (tuple, list)) else stack

    # Add the residual branch's output to the accumulator and put the result
    # back on top of the stack, leaving the context unchanged.
    output = accumulator + residual
    stack = (output,) + context
    self.state = tuple(new_state)
    return stack
Example #6
    def forward_with_state(self,
                           xs,
                           weights=tl.EMPTY_WEIGHTS,
                           state=tl.EMPTY_STATE,
                           **kwargs):
        self._validate_forward_inputs(xs)
        # Get N+1 rngs, N for running layers and one extra.
        rngs = _pop_rng_and_split(kwargs, self._n_layers + 1)
        rng0, rngs = rngs[0], rngs[1:]
        if not self.sublayers:  # No-op: leave args unchanged.
            return (xs, state)

        # Prepare the stack and do some safety checks as in the parent class.
        stack = xs
        new_state = []
        n_layers = self._n_layers
        if n_layers != 1 and len(weights) != n_layers:
            raise ValueError(
                'number of weights ({}) not equal to number of layers '
                '({})'.format(len(weights), n_layers))
        if n_layers != 1 and len(state) != n_layers:
            raise ValueError(
                'length of state ({}) not equal to number of layers '
                '({})'.format(len(state), n_layers))

        # TODO(chowdhery): try different strategies, also try running not all
        # layers backwards by using math.stop_gradient where needed.

        # Calculate how many layers to run forward.
        if self._mode == 'train':
            n_forward_layers = random.uniform(rng0, (), np.float32, 0.0,
                                              n_layers)
        else:
            n_forward_layers = float(n_layers)
        # Run the layers, treating any whose index is >= n_forward_layers as identity.
        cur_layer_idx = 0.0
        for layer, p, s, rng in zip(self.sublayers, weights, state, rngs):
            inputs = _inputs_from_stack(layer, stack)
            # TODO(chowdhery): port to jax.lax.cond once it has a JVP rule.
            outputs, s = layer._forward_internal(inputs, p, s, rng)  # pylint: disable=protected-access
            condition = math.lt(cur_layer_idx,
                                n_forward_layers).astype(np.float32)
            outputs = condition * outputs + (1 - condition) * inputs
            stack = _outputs_onto_stack(layer, outputs, stack)
            new_state.append(s)
            cur_layer_idx += 1.0
        return stack, new_state
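Example #6 replaces the math.cond of Examples #2 and #3 with an arithmetic blend, flagged in the TODO as a stand-in until jax.lax.cond gained a JVP rule. Below is a standalone sketch of that blend trick (illustrative only; maybe_apply and the toy layer are not library names):

import jax.numpy as jnp

def maybe_apply(layer_fn, x, cur_layer_idx, n_forward_layers):
  # 1.0 when this layer should run, 0.0 when it should act as identity.
  condition = jnp.asarray(cur_layer_idx < n_forward_layers, dtype=jnp.float32)
  # Both branches are evaluated; the condition selects between them, so the
  # whole expression stays differentiable without a cond primitive.
  return condition * layer_fn(x) + (1.0 - condition) * x

x = jnp.ones((2, 3))
print(maybe_apply(lambda t: 2.0 * t, x, 0.0, 3.7))  # layer runs: all entries 2.0
print(maybe_apply(lambda t: 2.0 * t, x, 4.0, 3.7))  # skipped: x unchanged

Note that both branches are still computed, so skipping saves no compute in this variant; the math.cond versions avoid that cost.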
Example #7
  def forward_with_state(self, xs, weights=base.EMPTY_WEIGHTS,
                         state=base.EMPTY_STATE, **kwargs):
    rngs = _pop_rng_and_split(kwargs, len(self.sublayers))

    accumulator, *context = xs
    stack = context = tuple(context)
    new_state = []
    for layer, w, s, rng in zip(self.sublayers, weights, state, rngs):
      inputs = _inputs_from_stack(layer, stack)
      outputs, s = layer.pure_fn(inputs, w, s, rng)
      stack = _outputs_onto_stack(layer, outputs, stack)
      new_state.append(s)
    residual = stack[0] if isinstance(stack, (tuple, list)) else stack

    output = accumulator + residual
    stack = (output,) + context
    return stack, new_state
Example #8
  def init_weights_and_state(self, input_signature):
    weights = []
    states = []
    # In the code below, stack, inputs, and outputs are abstract (shapes and
    # dtypes), but weights and states are non-abstract actual values.
    stack = input_signature
    for sublayer in self.sublayers:
      inputs = _inputs_from_stack(sublayer, stack)
      weights_or_cache_marker, state_or_cache_marker = (
          sublayer.init(inputs, use_cache=True))
      outputs, _ = sublayer._forward_abstract(inputs)
      stack = _outputs_onto_stack(sublayer, outputs, stack)

      weights.append(weights_or_cache_marker)
      states.append(state_or_cache_marker)
    self.state = (jnp.array(0, dtype=jnp.int32), states)
    self.weights = weights
Example #9
  def reverse_and_grad(self, output, ct, weights=(), state=(), new_state=(),
                       **kwargs):
    rngs = _pop_rng_and_split(kwargs, len(self.sublayers))

    accumulator_output, *context = output
    context = tuple(context)
    accumulator_output_ct, *context_ct = ct
    context_ct = tuple(context_ct)

    # Forward pass through self.compute_residual. Outputs that will not receive
    # a gradient signal from subsequent layers are moved to aux.
    def call_compute_residual(x, weights):
      res, _ = self.compute_residual.pure_fn(
          x, weights=weights, state=state[0], rng=rngs[0])
      if not isinstance(res, (tuple, list)):
        return res, None
      else:
        n_differentiable = 1
        if self.attention_layer is not None:
          n_differentiable = min(len(res), self.attention_layer.n_in)
        return res[:n_differentiable], res[n_differentiable:]

    stack = context
    inputs = _inputs_from_stack(self.compute_residual, stack)
    outputs, compute_residual_vjpfun, outputs_aux = jax.vjp(
        call_compute_residual, inputs, weights[0], has_aux=True)
    if outputs_aux is not None:
      n_differentiable_outputs = len(outputs)
      outputs = outputs + outputs_aux
    stack = _outputs_onto_stack(self.compute_residual, outputs, stack)

    stack_ct = accumulator_output_ct
    if self.attention_layer is None:
      residual = stack[0] if isinstance(stack, (tuple, list)) else stack
    else:
      inputs = _inputs_from_stack(self.attention_layer, stack)
      (residual, _, attn_inputs_ct, attn_weights_ct
      ) = self.attention_layer.forward_and_or_backward(
          inputs, weights[1], new_state[1], rngs[1],
          output_grad=accumulator_output_ct,
          compute_output=True, update_state=False)
      stack_ct = _outputs_onto_stack(
          self.attention_layer, attn_inputs_ct, stack_ct,
          self.attention_layer.n_out, self.attention_layer.n_in)

    compute_residual_ct = _inputs_from_stack(
        self.compute_residual, stack_ct, self.compute_residual.n_out)
    if outputs_aux is not None:
      if not isinstance(compute_residual_ct, (tuple, list)):
        compute_residual_ct = (compute_residual_ct,)
      compute_residual_ct = compute_residual_ct[:n_differentiable_outputs]
      assert len(compute_residual_ct) == n_differentiable_outputs
    (compute_residual_inputs_ct, compute_residual_weights_ct
    ) = compute_residual_vjpfun(compute_residual_ct)
    stack_ct = _outputs_onto_stack(
        self.compute_residual, compute_residual_inputs_ct, stack_ct,
        self.compute_residual.n_out, self.compute_residual.n_in)
    if not isinstance(stack_ct, (tuple, list)):
      stack_ct = (stack_ct,)
    stack_ct = (accumulator_output_ct,) + jax.tree_multimap(
        lambda x, y: x+y, context_ct[:len(stack_ct)], stack_ct
        ) + context_ct[len(stack_ct):]

    reconstructed_x = accumulator_output - residual
    stack = (reconstructed_x,) + context
    if self.attention_layer is None:
      weights_ct = (compute_residual_weights_ct,)
    else:
      weights_ct = (compute_residual_weights_ct, attn_weights_ct)
    return stack, (stack_ct, weights_ct)
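The line `reconstructed_x = accumulator_output - residual` is the reversibility identity both reverse_and_grad variants rely on: the residual branch reads only the context, so it can be re-run during the backward pass to recover the accumulator from the output. A tiny numeric sketch (residual_branch is a stand-in for compute_residual, not a library function):

import numpy as np

def residual_branch(context):   # stand-in for compute_residual
  return np.tanh(context)

accumulator = np.array([1.0, -2.0, 0.5])
context = np.array([0.3, 0.1, -0.7])

output = accumulator + residual_branch(context)    # forward (Examples #5 and #7)
reconstructed = output - residual_branch(context)  # reverse (Examples #9 and #10)
assert np.allclose(reconstructed, accumulator)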
Example #10
    def reverse_and_grad(self,
                         output,
                         ct,
                         weights=(),
                         state=(),
                         new_state=(),
                         **kwargs):
        rngs = _pop_rng_and_split(kwargs, len(self.sublayers))

        accumulator_output, *context = output
        context = tuple(context)
        accumulator_output_ct, *context_ct = ct
        context_ct = tuple(context_ct)

        # Forward pass through self.compute_residual
        def call_compute_residual(x, weights):
            res, _ = self.compute_residual._forward_internal(  # pylint: disable=protected-access
                x,
                weights=weights,
                state=state[0],
                rng=rngs[0])
            return res

        stack = context
        inputs = _inputs_from_stack(self.compute_residual, stack)
        outputs, compute_residual_vjpfun = jax.vjp(call_compute_residual,
                                                   inputs, weights[0])
        stack = _outputs_onto_stack(self.compute_residual, outputs, stack)

        # TODO(kitaev): handle the case of multiple outputs from
        # self.compute_residual.
        stack_ct = accumulator_output_ct
        if self.attention_layer is None:
            residual = stack[0] if isinstance(stack, (tuple, list)) else stack
        else:
            inputs = _inputs_from_stack(self.attention_layer, stack)
            (residual, _, attn_inputs_ct,
             attn_weights_ct) = self.attention_layer.forward_and_or_backward(
                 inputs,
                 weights[1],
                 new_state[1],
                 output_grad=accumulator_output_ct,
                 compute_output=True,
                 update_state=False)
            stack_ct = _outputs_onto_stack(self.attention_layer,
                                           attn_inputs_ct, stack_ct,
                                           self.attention_layer.n_out,
                                           self.attention_layer.n_in)

        compute_residual_ct = _inputs_from_stack(self.compute_residual,
                                                 stack_ct,
                                                 self.compute_residual.n_out)
        (compute_residual_inputs_ct, compute_residual_weights_ct
         ) = compute_residual_vjpfun(compute_residual_ct)
        stack_ct = _outputs_onto_stack(self.compute_residual,
                                       compute_residual_inputs_ct, stack_ct,
                                       self.compute_residual.n_out,
                                       self.compute_residual.n_in)
        if not isinstance(stack_ct, (tuple, list)):
            stack_ct = (stack_ct, )
        stack_ct = (accumulator_output_ct, ) + jax.tree_multimap(
            lambda x, y: x + y, context_ct[:len(stack_ct)],
            stack_ct) + context_ct[len(stack_ct):]

        reconstructed_x = accumulator_output - residual
        stack = (reconstructed_x, ) + context
        if self.attention_layer is None:
            weights_ct = (compute_residual_weights_ct, )
        else:
            weights_ct = (compute_residual_weights_ct, attn_weights_ct)
        return stack, (stack_ct, weights_ct)