Example no. 1
0
def _fast_inference_init_state(input_signature, buffer_length):
    """Returns an initial state for causal attention layer fast inference."""
    def zeros_for(batch_size, shape_dtype):
        shape, dtype = shape_dtype.as_tuple()
        depth = shape[-1]
        return np.zeros((batch_size, buffer_length, depth), dtype=dtype)

    batch_size = input_signature[0].shape[0]
    k = zeros_for(batch_size, input_signature[1])
    v = zeros_for(batch_size, input_signature[2])
    mask = np.zeros((batch_size, 1, buffer_length))
    seq_indices = np.zeros((batch_size, ), dtype=np.int32)
    return (k, v, mask, seq_indices)
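A quick usage sketch of the initializer above, using a hypothetical SigStub class in place of the real shape/dtype signature objects (the snippet only relies on their .shape attribute and as_tuple() method); the shapes are illustrative.

import numpy as np

class SigStub:
    """Hypothetical stand-in exposing the two accessors the initializer uses."""
    def __init__(self, shape, dtype=np.float32):
        self.shape, self.dtype = shape, dtype
    def as_tuple(self):
        return self.shape, self.dtype

sig = (SigStub((8, 1, 64)),  # queries: batch of 8, depth 64
       SigStub((8, 1, 64)),  # keys
       SigStub((8, 1, 64)))  # values
k, v, mask, idx = _fast_inference_init_state(sig, buffer_length=128)
print(k.shape, v.shape, mask.shape, idx.shape)  # (8, 128, 64) (8, 128, 64) (8, 1, 128) (8,)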
Example no. 2
0
 def init(self, params):
     shape = params.shape
     slots = []
     if self._factored and len(shape) >= 2:
         v_row = np.zeros(shape[:-1], dtype=np.float32)
         v_col = np.zeros(shape[:-2] + shape[-1:], dtype=np.float32)
         slots.extend([v_row, v_col])
     else:
         v = np.zeros_like(params)
         slots.append(v)
     if self._do_momentum:
         m = np.zeros_like(params)
         slots.append(m)
     return slots
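The factored branch above keeps second-moment statistics per row and per column instead of a full matrix (the Adafactor trick). A small standalone sketch of the slot shapes it would produce for a 2-D weight matrix, assuming factoring and momentum are both enabled:

import numpy as np

params = np.zeros((512, 1024), dtype=np.float32)  # a 2-D weight matrix
shape = params.shape
v_row = np.zeros(shape[:-1], dtype=np.float32)               # (512,) per-row statistic
v_col = np.zeros(shape[:-2] + shape[-1:], dtype=np.float32)  # (1024,) per-column statistic
m = np.zeros_like(params)                                    # full-size momentum slot
print([s.shape for s in (v_row, v_col, m)])  # [(512,), (1024,), (512, 1024)]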
Example no. 3
0
 def new_weights(self, input_signature):
   d_feature = input_signature.shape[-1]
   if self._transform == 'diag':
     # Initialize it to a small value because JAX has a bug in softplus.
     scale_isoftplus = jnp.zeros((d_feature,), dtype=jnp.float32) + 1e-4
     weights = scale_isoftplus
   elif self._transform == 'any':
     ortho = trax.layers.initializers.OrthogonalInitializer()
     weights = ortho((d_feature, d_feature), self.rng)
   else:
     weights = layer_base.EMPTY_WEIGHTS
   if self._mode == 'predict':
     batch_size = input_signature.shape[0]
     self.state = jnp.zeros((batch_size,), dtype=jnp.int32), self.rng
   return weights
Example no. 4
0
def NewPositionalEncoding(x, positions=None, **kwargs):
    """Implements new positional encoding."""
    del kwargs
    x_length = np.shape(x)[1]
    pos = np.array(positions)[np.newaxis, :x_length, :]
    pos += np.zeros((np.shape(x)[0], 1, 1))  # Broadcast on batch.
    return pos
Example no. 5
0
 def new_weights_and_state(self, input_signature):
     if self._mode == 'predict':
         batch_size = input_signature.shape[0]
         return layer_base.EMPTY_WEIGHTS, np.zeros((batch_size, ),
                                                   dtype=np.int32)
     else:
         return layer_base.EMPTY_WEIGHTS, layer_base.EMPTY_STATE
Example no. 6
0
def _layer_norm_weights(input_signature, **unused_kwargs):
  """Helper: create layer norm parameters."""
  features = input_signature.shape[-1]
  scale = np.ones(features, dtype=input_signature.dtype)
  bias = np.zeros(features, dtype=input_signature.dtype)
  weights = (scale, bias)
  return weights
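The helper above only creates the (scale, bias) pair; the forward computation is not shown. As a hedged sketch, a layer norm that consumes these weights typically normalizes over the feature axis roughly as below (the epsilon value is an assumption):

import numpy as np

def layer_norm_forward(x, weights, epsilon=1e-6):
    # Normalize over the last (feature) axis, then apply the learned scale and bias.
    scale, bias = weights
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon) * scale + bias

x = np.random.uniform(size=(2, 5, 16)).astype(np.float32)
weights = (np.ones(16, dtype=x.dtype), np.zeros(16, dtype=x.dtype))
print(layer_norm_forward(x, weights).shape)  # (2, 5, 16)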
Example no. 7
0
    def _run_value_model(self, observations, dist_inputs):
        if dist_inputs is None:
            dist_inputs = jnp.zeros(observations.shape[:2] +
                                    (self._policy_dist.n_inputs, ))

        actions = None
        if self._q_value:
            dist_inputs = jnp.broadcast_to(
                dist_inputs, (self._q_value_n_samples, ) + dist_inputs.shape)
            # Swapping the n_samples and batch_size axes, so the input is split
            # between accelerators along the batch_size axis.
            dist_inputs = jnp.swapaxes(dist_inputs, 0, 1)
            actions = self._policy_dist.sample(dist_inputs)
            log_probs = self._policy_dist.log_prob(dist_inputs, actions)
            obs = observations
            obs = jnp.reshape(obs, [obs.shape[0], 1] + list(obs.shape[1:]))
            inputs = (obs, actions)
        else:
            log_probs = None
            inputs = (observations, )

        n_devices = math.device_count()
        weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
        state = tl.for_n_devices(self._value_eval_model.state, n_devices)
        rng = self._value_eval_model.rng
        values, _ = self._value_eval_jit(inputs, weights, state, rng)
        values *= self._value_network_scale
        values = jnp.squeeze(values, axis=-1)  # Remove the singleton depth dim.
        return (values, actions, log_probs)
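The broadcast_to followed by swapaxes above replicates dist_inputs across the sample axis and then moves batch_size back to the front, so splitting across accelerators happens along the batch dimension. A shape-only numpy sketch with illustrative sizes:

import numpy as np

n_samples, batch_size, length, n_inputs = 4, 8, 10, 6
dist_inputs = np.ones((batch_size, length, n_inputs))
tiled = np.broadcast_to(dist_inputs, (n_samples,) + dist_inputs.shape)
print(tiled.shape)                # (4, 8, 10, 6): n_samples leads
tiled = np.swapaxes(tiled, 0, 1)
print(tiled.shape)                # (8, 4, 10, 6): batch_size leads, so devices split on batch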
Example no. 8
0
def EncoderDecoderMask(x, **unused_kwargs):
    """Makes encoder-decoder mask from decoder input and a padding mask."""
    decoder_input, padding_mask = x
    padding_mask = np.reshape(
        padding_mask, (padding_mask.shape[0], 1, 1, padding_mask.shape[-1]))
    # Final mask shape is [batch, 1 for heads, decoder-len, encoder-len].
    return padding_mask + np.zeros((1, 1, decoder_input.shape[1], 1))
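A minimal numpy sketch of the broadcast used above: adding a zero tensor of shape (1, 1, decoder_len, 1) simply tiles the padding mask across the decoder length. Sizes below are illustrative.

import numpy as np

batch, decoder_len, encoder_len = 2, 5, 7
lengths = np.array([[4], [7]])                                # tokens before padding
padding_mask = (np.arange(encoder_len)[None, :] < lengths)    # (2, 7), 1 where not padding
padding_mask = np.reshape(padding_mask, (batch, 1, 1, encoder_len))
mask = padding_mask + np.zeros((1, 1, decoder_len, 1))        # addition only broadcasts
print(mask.shape)  # (2, 1, 5, 7): [batch, 1 for heads, decoder-len, encoder-len]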
Example no. 9
0
 def create_state_unbatched(self, input_signature, rng):
     if isinstance(input_signature, (tuple, list)):
         input_signature = input_signature[0]
     buckets = np.zeros(self.n_hashes * input_signature.shape[0],
                        dtype=np.int32)
     # TODO(kitaev): storing RNG in the state is a HACK.
     return (buckets, rng)
Example no. 10
0
def _layer_norm_weights(input_signature):
  """Helper: create layer norm parameters."""
  features = input_signature.shape[-1]
  scale = np.ones(features)
  bias = np.zeros(features)
  weights = (scale, bias)
  return weights
Example no. 11
0
    def _run_value_model(self, observations, dist_inputs):
        if dist_inputs is None:
            dist_inputs = jnp.zeros(observations.shape[:2] +
                                    (self._policy_dist.n_inputs, ))

        actions = None
        if self._q_value:
            if self._sample_all_discrete_actions:
                # Since we want to sample all actions, start by creating their list.
                act = np.arange(self._vocab_size)
                # Now act is a vector [0, ..., vocab_size-1], but we'll need to tile it.
                # Add extra dimensions so it's the same dimensionality as dist_inputs.
                act = jnp.reshape(act,
                                  [-1] + [1] * (len(dist_inputs.shape) - 1))
                # Now act is [vocab_size, 1, ..., 1], dimensionality of dist_inputs.
            dist_inputs = jnp.broadcast_to(
                dist_inputs, (self._q_value_n_samples, ) + dist_inputs.shape)
            if self._sample_all_discrete_actions:
                actions = act + jnp.zeros(dist_inputs.shape[:-1],
                                          dtype=jnp.int32)
                actions = jnp.swapaxes(actions, 0, 1)
            # Swapping the n_samples and batch_size axes, so the input is split
            # between accelerators along the batch_size axis.
            dist_inputs = jnp.swapaxes(dist_inputs, 0, 1)
            if not self._sample_all_discrete_actions:
                actions = self._policy_dist.sample(dist_inputs)
            log_probs = self._policy_dist.log_prob(dist_inputs, actions)
            obs = observations
            obs = jnp.reshape(obs, [obs.shape[0], 1] + list(obs.shape[1:]))
            inputs = (obs, actions)
        else:
            log_probs = None
            inputs = (observations, )

        n_devices = math.device_count()
        weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
        state = tl.for_n_devices(self._value_eval_model.state, n_devices)
        rng = self._value_eval_model.rng
        values, _ = self._value_eval_jit(inputs, weights, state, rng)
        values *= self._value_network_scale
        values = jnp.squeeze(values, axis=-1)  # Remove the singleton depth dim.
        return (values, actions, log_probs)
Example no. 12
0
 def F(encoder_activations, decoder_activations, input_tokens):
     keys = values = encoder_activations
     queries = decoder_activations
     # Mask is 1 where inputs are not padding (0) and 0 where they are padding.
     mask = (input_tokens != 0)
     # We need to add axes to the mask for attention heads and decoder length.
     mask = jnp.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1]))
     # Broadcast so mask is [batch, 1 for heads, decoder-len, encoder-len].
     mask = mask + jnp.zeros((1, 1, decoder_activations.shape[1], 1))
     return queries, keys, values, mask
Example no. 13
0
 def create_state_unbatched(self, input_signature, rng):
   if isinstance(input_signature, (tuple, list)):
     input_signature = input_signature[0]
   buckets = np.zeros(self.n_hashes * input_signature.shape[0], dtype=np.int32)
   # The `rng` argument passed to forward_unbatched is shared across all
   # examples and heads. This facilitates using broadcasted dropout, which
   # saves memory and hasn't been shown to hurt model quality. Even though the
   # same sharing is likely to be safe when selecting random hash functions
   # for LSH, we haven't run experiments to demonstrate this. To be on the safe
   # side we include a per-head RNG in the state for the purpose of doing LSH.
   return (buckets, rng)
Example no. 14
0
 def new_weights_and_state(self, input_signature):
   """Helper to initialize batch norm weights."""
   axis = self._axis
   axis = (axis,) if np.isscalar(axis) else axis
   input_shape = input_signature.shape
   shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
   beta = np.zeros(shape, dtype='float32') if self._center else ()
   gamma = np.ones(shape, dtype='float32') if self._scale else ()
   def get_stats_axis(i, d):
     if i in axis:
       return 1
     else:
       return d
   stats_shape = tuple(get_stats_axis(i, d) for i, d in enumerate(input_shape))
   running_mean = np.zeros(stats_shape, dtype=np.float32)
   running_var = np.ones(stats_shape, dtype=np.float32)
   n_batches = np.zeros((), dtype=np.int64)
   weights = (beta, gamma)
   state = (running_mean, running_var, n_batches)
   return weights, state
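A standalone check of the shapes this initializer produces, assuming an image-like input and reduction over every axis except channels:

import numpy as np

input_shape = (32, 28, 28, 64)   # (batch, height, width, channels), illustrative
axis = (0, 1, 2)
shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
stats_shape = tuple(1 if i in axis else d for i, d in enumerate(input_shape))
beta, gamma = np.zeros(shape, dtype='float32'), np.ones(shape, dtype='float32')
running_mean = np.zeros(stats_shape, dtype=np.float32)
running_var = np.ones(stats_shape, dtype=np.float32)
print(shape, stats_shape)  # (64,) (1, 1, 1, 64)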
Example no. 15
0
    def forward_with_state(self, inputs, weights, state, rng=None):
        """Computes this layer's output as part of a forward pass through the model.

    Args:
      inputs: Layer inputs (subclasses may use different inputs)
      weights: Layer weights
      state: Complete state of the layer
      rng: PRNG key

    Returns:
      A tuple (output, new_state).
    """
        if not self.use_reference_code:
            # By default, an efficient, batched implementation is used.
            output, new_state, _, _ = self.forward_and_or_backward(
                inputs, weights, state, compute_output=True, update_state=True)
            return output, new_state

        # The reference implementation below provides a more readable overview of
        # what this class does. It's not optimized, however, and should only be used
        # when testing this class for correctness.
        if not isinstance(inputs, (tuple, list)):
            inputs = (inputs, )
        batch_size = int(inputs[0].shape[0])
        seqlen = inputs[0].shape[-2]
        d_model = inputs[0].shape[-1]
        output_accum = [np.zeros((seqlen, d_model)) for _ in range(batch_size)]
        new_state = []
        for example_idx in range(batch_size):
            for head_idx in range(self.n_heads):
                # pylint: disable=cell-var-from-loop
                single_inputs = jax.tree_map(lambda x: x[example_idx], inputs)
                single_weights = jax.tree_map(lambda w: w[head_idx], weights)
                single_state = jax.tree_map(
                    lambda s: s[example_idx * self.n_heads + head_idx], state)
                # pylint: enable=cell-var-from-loop
                single_out, single_new_state = self.forward_unbatched(
                    *single_inputs,
                    weights=single_weights,
                    state=single_state,
                    update_state=True)
                new_state.append(single_new_state)
                output_accum[example_idx] = output_accum[example_idx] + single_out

        output = np.stack(output_accum, 0)
        if new_state and jax.tree_leaves(new_state[0]):
            new_state = jax.tree_multimap(lambda *s: np.stack(s, 0),
                                          *new_state)
        else:
            new_state = state
        return output, new_state
Example no. 16
0
    def test_weights_and_state_signature(self):
        class MyLayer(base.Layer):
            def init_weights_and_state(self, input_signature):
                self.weights = jnp.zeros((2, 3))
                self.state = jnp.ones(input_signature.shape)

            def forward(self, inputs):
                return self.weights + self.state

        layer = MyLayer()
        w, s = layer.weights_and_state_signature(jnp.zeros((3, 4)))
        self.assertEqual(w.shape, (2, 3))
        self.assertEqual(s.shape, (3, 4))
Example no. 17
0
  def new_weights(self, input_signature):
    # Usually (B, W, H, C)
    shape = input_signature.shape
    num_channels = shape[-1]

    gamma = np.ones((num_channels,), dtype=np.float32)
    beta = np.zeros((num_channels,), dtype=np.float32)

    epsilon_l = base.EMPTY_WEIGHTS
    if self._learn_epsilon:
      epsilon_l = (self._init_learnt_epsilon,)

    return gamma, beta, epsilon_l
Example no. 18
0
    def init_weights_and_state(self, input_signature):
        """Helper to initialize batch norm weights and state."""
        axis = self._axis
        axis = (axis, ) if jnp.isscalar(axis) else axis
        input_shape = input_signature.shape
        shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
        # TODO(jonni): Should beta and gamma match the dtype in the input signature?
        beta = jnp.zeros(shape, dtype='float32') if self._center else ()
        gamma = jnp.ones(shape, dtype='float32') if self._scale else ()

        def get_stats_axis(i, d):
            if i in axis:
                return 1
            else:
                return d

        stats_shape = tuple(
            get_stats_axis(i, d) for i, d in enumerate(input_shape))
        running_mean = jnp.zeros(stats_shape, dtype=jnp.float32)
        running_var = jnp.ones(stats_shape, dtype=jnp.float32)
        n_batches = jnp.zeros((), dtype=jnp.int64)
        self.weights = (beta, gamma)
        self.state = (running_mean, running_var, n_batches)
Example no. 19
0
 def new_weights(self, input_signature):
   d_feature = input_signature.shape[-1]
   pe = np.zeros((self._max_len, d_feature), dtype=np.float32)
   position = np.arange(0, self._max_len)[:, np.newaxis]
   div_term = np.exp(
       np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature))
   pe[:, 0::2] = np.sin(position * div_term)
   pe[:, 1::2] = np.cos(position * div_term)
   pe = pe[np.newaxis, :, :]  # [1, self._max_len, d_feature]
   weights = jnp.array(pe)  # Trainable parameters, initialized above.
   if self._mode == 'predict':
     batch_size = input_signature.shape[0]
     self.state = jnp.zeros((batch_size,), dtype=jnp.int32)
   return weights
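The table above is the standard sinusoidal encoding. A self-contained numpy sketch that builds the same table at a small size and adds it to a batch, which is what the (not shown) forward pass would do in 'train' mode, as far as this snippet suggests:

import numpy as np

max_len, d_feature = 16, 8
pe = np.zeros((max_len, d_feature), dtype=np.float32)
position = np.arange(0, max_len)[:, np.newaxis]
div_term = np.exp(np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature))
pe[:, 0::2] = np.sin(position * div_term)
pe[:, 1::2] = np.cos(position * div_term)
pe = pe[np.newaxis, :, :]                 # (1, max_len, d_feature)

x = np.random.uniform(size=(4, 10, d_feature)).astype(np.float32)
y = x + pe[:, :x.shape[1], :]             # add the encoding for the first 10 positions
print(y.shape)                            # (4, 10, 8)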
Example no. 20
0
 def new_weights_and_state(self, input_signature):
   d_feature = input_signature.shape[-1]
   pe = onp.zeros((self._max_len, d_feature), dtype=onp.float32)
   position = onp.arange(0, self._max_len)[:, onp.newaxis]
   div_term = onp.exp(
       onp.arange(0, d_feature, 2) * -(onp.log(10000.0) / d_feature))
   pe[:, 0::2] = onp.sin(position * div_term)
   pe[:, 1::2] = onp.cos(position * div_term)
   pe = pe[onp.newaxis, :, :]  # [1, self._max_len, d_feature]
   weights = np.array(pe)  # These are trainable parameters, initialized above.
   if self._mode == 'predict':
     batch_size = input_signature.shape[0]
     state = np.zeros((batch_size,), dtype=np.int32)
   else:
     state = base.EMPTY_STATE
   return weights, state
Example no. 21
0
 def forward(self, inp, weights):
   """Reshape input to have heads dimension and concatenate positions there."""
   x = inp[0]
   n_batches, seqlen = x.shape[0], x.shape[1]
   d_head = x.shape[-1] // self._n_heads
   res = np.reshape(x, (n_batches, seqlen, self._n_heads, d_head))
   res = np.transpose(res, (0, 2, 1, 3))  # (batch, heads, len, depth)
   if self._n_pos == 1:  # Just one position given, tile into each head.
     pos_shape = list(res.shape)[:-1] + [inp[1].shape[-1]]
     pos = inp[1][:, None, :, :] + np.zeros(pos_shape)  # Add 0 to broadcast.
   else:  # As many positions as heads, concatenate them in.
     pos = [p[:, None, :, :] for p in inp[1:]]
     pos = np.concatenate(pos, axis=1)
   res = np.concatenate([res, pos], axis=-1)
   # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
   res = np.reshape(res, (-1, seqlen, d_head + POS_VECTOR_SIZE))
   return res
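A shape-only numpy sketch of the reshape, transpose, and position concatenation above, with illustrative sizes and an assumed POS_VECTOR_SIZE:

import numpy as np

POS_VECTOR_SIZE = 16                      # assumption; the real constant is defined elsewhere
n_batches, seqlen, n_heads, d_head = 2, 10, 4, 32
x = np.ones((n_batches, seqlen, n_heads * d_head))
pos = np.ones((n_batches, seqlen, POS_VECTOR_SIZE))

res = np.reshape(x, (n_batches, seqlen, n_heads, d_head))
res = np.transpose(res, (0, 2, 1, 3))                          # (batch, heads, len, depth)
pos = pos[:, None, :, :] + np.zeros(list(res.shape)[:-1] + [POS_VECTOR_SIZE])
res = np.concatenate([res, pos], axis=-1)
res = np.reshape(res, (-1, seqlen, d_head + POS_VECTOR_SIZE))
print(res.shape)                                               # (8, 10, 48)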
Example no. 22
0
 def policy_inputs(self, trajectory, values):
     """Create inputs to policy model from a TrajectoryNp and values."""
     # How much TD to use is determined by the added policy slice length,
     # as the policy batches need to be this much longer to calculate TD.
     advantages = self._advantage_estimator(
         rewards=trajectory.rewards,
         returns=trajectory.returns,
         values=values,
         dones=trajectory.dones,
         gamma=self._task.gamma,
         n_extra_steps=self._added_policy_slice_length,
     )
     # Observations should be the same length as advantages - so if we are
     # using n_extra_steps, we need to trim the length to match.
     obs = trajectory.observations[:, :advantages.shape[1]]
     act = trajectory.actions[:, :advantages.shape[1]]
     mask = trajectory.mask[:, :advantages.shape[1]]  # Mask to zero-out padding.
     if trajectory.dist_inputs is not None:
         dist_inputs = trajectory.dist_inputs[:, :advantages.shape[1]]
     else:
         dist_inputs = jnp.zeros(advantages.shape +
                                 (self._policy_dist.n_inputs, ))
     # Shape checks to help debugging.
     if len(advantages.shape) != 2:
         raise ValueError('Advantages are expected to have shape ' +
                          '[batch_size, length], got: %s' %
                          str(advantages.shape))
     if act.shape[0:2] != advantages.shape:
         raise ValueError(
             'First 2 dimensions of actions should be the same as in '
             'advantages, %s != %s' % (act.shape[0:2], advantages.shape))
     if obs.shape[0:2] != advantages.shape:
         raise ValueError(
             'First 2 dimensions of observations should be the same '
             'as in advantages, %s != %s' %
             (obs.shape[0:2], advantages.shape))
     if dist_inputs.shape[:2] != advantages.shape:
         raise ValueError(
             'First 2 dimensions of dist_inputs should be the same '
             'as in advantages, %s != %s' %
             (dist_inputs.shape[:2], advantages.shape))
     if mask.shape != advantages.shape:
         raise ValueError('Mask and advantages shapes should be the same'
                          ', %s != %s' % (mask.shape, advantages.shape))
     return (obs, act, advantages, dist_inputs, mask)
Example no. 23
0
 def batches_stream(self):
   """Use the RLTask self._task to create inputs to the value model."""
   for np_trajectory in self._task.trajectory_batch_stream(
       self._batch_size, max_slice_length=self._max_slice_length, epochs=[-1]):
     if np_trajectory.dist_inputs is not None:
       old_dist_inputs = np_trajectory.dist_inputs
     else:
       old_dist_inputs = jnp.zeros(
           np_trajectory.rewards.shape + (self._policy_dist.n_inputs,)
       )
     old_log_probs = self._policy_dist.log_prob(
         old_dist_inputs, np_trajectory.actions
     )
     # Insert an extra depth dimension, so the target shape is consistent with
     # the network output shape.
     yield (np_trajectory.observations,         # Inputs to the value model.
            np_trajectory.returns[:, :, None],
            np_trajectory.actions,
            old_log_probs,
            np_trajectory.mask)
Example no. 24
0
  def _run_value_model(self, observations, dist_inputs):
    if dist_inputs is None:
      dist_inputs = jnp.zeros(
          observations.shape[:2] + (self._policy_dist.n_inputs,)
      )

    actions = None
    if self._q_value:
      dist_inputs = np.broadcast_to(
          dist_inputs, (self._q_value_n_samples,) + dist_inputs.shape
      )
      actions = self._policy_dist.sample(dist_inputs)
      inputs = (observations, actions)
    else:
      inputs = (observations,)

    values = self._value_eval_model(
        inputs, n_accelerators=1
    ) * self._value_network_scale
    values = np.squeeze(values, axis=-1)  # Remove the singleton depth dim.

    return (values, actions)
Example no. 25
0
  def _get_initial_state(self, inputs, targets_prefix, batch_size):
    """Get initial state for beam search."""
    if targets_prefix is None:
      prompt = np.zeros((batch_size, 1), dtype=np.int32)
    else:
      prompt = np.pad(
          targets_prefix[:, :-1], ((0, 0), (1, 0)), mode='constant')

    # Get state prior to running the encoder or incorporating targets_prefix
    if inputs is None:
      signature = ShapeDtype((batch_size, 1), prompt.dtype)
    else:
      signature = (ShapeDtype(inputs.shape, inputs.dtype),
                   ShapeDtype((batch_size, 1), prompt.dtype))
    # Trax's model.init is stateful as opposed to functional. Calling it on an
    # already-existing model instance doesn't work.
    # TODO(lukaszkaiser): add purely functional init to Trax.
    _, initial_state = self.model(mode='predict').init(signature)

    # Incorporate encoder and prompt into state
    _, prompted_state = self.model_infer.pure_fn(
        prompt if inputs is None else (inputs, prompt),
        self.model_weights,
        initial_state,
        jax.random.PRNGKey(0))
    state_structure = jax.tree_structure(prompted_state)

    if targets_prefix is not None:
      initial_state = prompted_state
    elif self.encoder_idx is not None:
      initial_state = (tuple(prompted_state[:self.encoder_idx])
                       + tuple(initial_state[self.encoder_idx:]))

    # Fix tree structure of the state (there's a tuple vs. list mismatch)
    initial_state = jax.tree_unflatten(
        state_structure, jax.tree_leaves(initial_state))

    return initial_state
Example no. 26
0
 def init_weights_and_state(self, input_signature):
   self.weights = np.zeros(input_signature.shape[-1])
Example no. 27
0
  def decode(self, inputs=None, targets_prefix=None, batch_size=None):
    """Performs decoding for a batch of examples.

    Args:
      inputs: [batch_size, encoder_input_length] int32 numpy array: Inputs to
        the encoder portion of the model. If the model does not have an encoder,
        leave this set to None.
      targets_prefix: [batch_size, target_prefix_length] int32 numpy array:
        Optional prefix to initialize the decoder with. The start token should
        never be included in the prefix. Note that all examples in the batch
        must use the same prefix length.
      batch_size: If both inputs and targets_prefix are None, the batch_size
        argument is required and will determine the batch size for decoding.
        Otherwise, this argument serves as an optional hint for the batch size
        that the inputs should be padded out to before running inference. The
        XLA computation for inference needs to be re-jitted every time a new
        batch size is encountered, so passing a constant batch_size argument can
        speed up inference by avoiding recompilation.

    Returns:
      Tuple of:
        [batch_size, beam_size, max_decode_len] top-scoring sequences
        [batch_size, beam_size] beam-search scores.
      The highest-scoring sequence will be at index -1 along the beam_size axis.
    """
    n_devices = trax.math.device_count()
    if inputs is not None and targets_prefix is not None:
      pad_to = batch_size
      batch_size = inputs.shape[0]
      assert targets_prefix.shape[0] == batch_size
    elif inputs is not None:
      pad_to = batch_size
      batch_size = inputs.shape[0]
    elif targets_prefix is not None:
      pad_to = batch_size
      batch_size = targets_prefix.shape[0]
    else:
      pad_to = None

    if pad_to is None:
      pad_amount = (n_devices - (batch_size % n_devices)) % n_devices
    else:
      assert pad_to % n_devices == 0, (
          'When specifying batch_size for the purposes of padding, '
          'batch_size must be divisible by the number of devices.')
      pad_amount = pad_to - batch_size
      assert pad_amount >= 0

    if inputs is not None:
      if pad_amount:
        inputs = onp.concatenate([inputs] + [inputs[0:1]] * pad_amount, 0)
      inputs = inputs.reshape((n_devices, -1) + inputs.shape[1:])
    if targets_prefix is not None:
      if pad_amount:
        targets_prefix = onp.concatenate(
            [targets_prefix] + [targets_prefix[0:1]] * pad_amount, 0)
      targets_prefix = targets_prefix.reshape(
          (n_devices, -1) + targets_prefix.shape[1:])

    seqs, scores = self._jit_beam_search(
        inputs, targets_prefix, (batch_size + pad_amount) // n_devices,
        dummy=np.zeros(n_devices))
    seqs = onp.asarray(seqs)
    scores = onp.asarray(scores)
    seqs = seqs.reshape((-1,) + seqs.shape[2:])
    scores = scores.reshape((-1,) + scores.shape[2:])
    seqs = seqs[:, :, 1:]  # Strip start token
    if pad_amount:
      seqs = seqs[:batch_size]
      scores = scores[:batch_size]
    return seqs, scores
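The padding logic above rounds the batch up to a multiple of the device count so the leading axis splits evenly across accelerators. A quick standalone check of the arithmetic with illustrative numbers:

n_devices, batch_size = 8, 21
pad_amount = (n_devices - (batch_size % n_devices)) % n_devices
print(pad_amount)                               # 3
print((batch_size + pad_amount) // n_devices)   # 3 examples per device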
Example no. 28
0
 def create_state_unbatched(self, input_signature, rng):
     # TODO(kitaev): storing RNG in the state is a HACK.
     buckets = np.zeros(self.n_hashes * input_signature.shape[0],
                        dtype=np.int32)
     return (buckets, rng)
Example no. 29
0
    def forward_and_or_backward(self,
                                inputs,
                                weights,
                                state,
                                output_grad=None,
                                compute_output=True,
                                update_state=True):
        """Performs batched forward and/or backward passes.

    See `forward_with_state` for a reference implementation of what this layer
    does. The reference implementation is not very efficient, however, and this
    method provides a more performant version.

    Args:
      inputs: inputs to the attention layer
      weights: weights for the attention layer
      state: state of the attention layer
      output_grad: gradient of the loss wrt the output of the layer, or None.
        This function performs the backward pass iff `output_grad` is not None.
      compute_output: bool: whether to return the output of the forward pass
        (for example, a pure backwards pass does not need to return the output).
      update_state: bool: whether to return an updated layer state.
    Returns:
      A tuple (output, new_state, inputs_grad, weights_grad).
      - output is not None iff compute_output is True
      - new_state is not None iff update_state is True
      - inputs_grad & weights_grad are not None iff output_grad is not None

    Notes regarding the implementation:
    (a) Multiple heads or examples are batched together. There are three
        different regimes possible: one head at a time (for long sequences and
        expensive attention types), several attention heads at a time (for
        long sequences but less-expensive attention types), and several
        examples at a time (for large batches of shorter sequences). For the
        time being, each of these regimes has its own code.
    (b) Python loops produce large computation graphs when jitted, so the
        default is to use a JAX loop instead.
    (c) No intermediate quantities are cached for the backward pass. Instead,
        the forward pass is re-computed when doing backprop. This approach is
        often called "checkpointing" or "rematerialization". When not all
        examples or heads fit in memory simultaneously, the implementation
        should be [FW-BW-1] and NOT [FW-BW-2], because the latter has worse
        memory locality. I don't think JAX autodiff can synthesize [FW-BW-1]
        automatically, so the looping for the backward pass is done manually.

        [FW-BW-1] for example, head in zip(examples, heads):
                    forward(example, head)
                    backward(example, head)  # uses intermediates from forward

        [FW-BW-2] for example, head in zip(examples, heads):
                    forward(example, head)
                  for example, head in zip(examples, heads):
                    backward(example, head)
    """
        # TODO(kitaev): support non-differentiable inputs (for enc-dec attn masking)
        # TODO(kitaev): support RNGs (needed for dropout and LSH). Currently LSH
        #     hacks around this by storing an RNG in its state
        # TODO(kitaev): profile ~4% speed drop compared to previous implementation
        #     in some conditions. Other conditions (e.g. the enwik8 model) appear
        #     to have the same overall training speed.
        # TODO(b/148460708): reduce memory usage further
        # TODO(kitaev): there should be a higher-level API (like vmap) that does
        #     batching, instead of needing 3 separate manual implementations here.

        have_single_input = not isinstance(inputs, (tuple, list))
        if have_single_input:
            inputs = (inputs, )
        batch_size = int(inputs[0].shape[0])
        seqlen = inputs[0].shape[-2]
        d_model = inputs[0].shape[-1]

        compute_grad = (output_grad is not None)
        assert compute_output or compute_grad, 'No work to perform!'

        # Adjust degree of parallelism based on the batch size.
        n_parallel_heads = batch_size * self.n_heads
        if self.n_parallel_heads and self.n_parallel_heads < n_parallel_heads:
            n_parallel_heads = self.n_parallel_heads

        def tree_update(tree, indices, new_values):
            return jax.tree_multimap(
                lambda x, y: jax.ops.index_update(x, jax.ops.index[indices], y),
                tree, new_values)

        def tree_add(tree, indices, new_values):
            return jax.tree_multimap(
                lambda x, y: jax.ops.index_add(x, jax.ops.index[indices], y),
                tree, new_values)

        if n_parallel_heads == 1:

            def run_inner(idx, loop_val):
                """Runs one slice of attention (for a single head)."""
                o_all, s_all, i_ct_all, w_ct_all = loop_val
                example_idx = idx // self.n_heads
                head_idx = idx % self.n_heads

                i_h = jax.tree_map(lambda x: x[example_idx], inputs)
                w_h = jax.tree_map(lambda w: w[head_idx], weights)
                s_h = jax.tree_map(lambda s: s[idx], state)

                def forward_fn(i_h, w_h):
                    return self.forward_unbatched(
                        *i_h,
                        weights=w_h,
                        state=jax.lax.stop_gradient(s_h),
                        update_state=update_state)

                if compute_grad:
                    o_h, backward_fn, s_h = jax.vjp(forward_fn,
                                                    i_h,
                                                    w_h,
                                                    has_aux=True)
                    ct_h = output_grad[example_idx]
                    assert o_h.shape == ct_h.shape
                    i_ct_h, w_ct_h = backward_fn(ct_h)
                else:
                    o_h, s_h = forward_fn(i_h, w_h)

                if compute_output:
                    o_all = jax.ops.index_add(o_all, example_idx, o_h)
                if update_state:
                    s_all = tree_update(s_all, idx, s_h)
                if compute_grad:
                    i_ct_all = tree_add(i_ct_all, example_idx, i_ct_h)
                    w_ct_all = tree_add(w_ct_all, head_idx, w_ct_h)
                return (o_all, s_all, i_ct_all, w_ct_all)
        elif n_parallel_heads < self.n_heads:
            assert self.n_heads % n_parallel_heads == 0

            def run_inner(idx, loop_val):
                """Runs one slice of attention (multiple heads, but one example)."""
                o_all, s_all, i_ct_all, w_ct_all = loop_val
                idx = idx * self.n_parallel_heads
                example_idx = idx // self.n_heads
                head_idx_lo = idx % self.n_heads
                # Use iota here instead of np.arange, because np.arange will fail to
                # infer that the slice size is a compile-time constant.
                head_range = head_idx_lo + jax.lax.iota(
                    np.int32, n_parallel_heads)
                state_range = idx + jax.lax.iota(np.int32, n_parallel_heads)

                i_mh = jax.tree_map(lambda x: x[example_idx], inputs)
                w_mh = jax.tree_map(lambda w: w[head_range], weights)
                s_mh = jax.tree_map(lambda s: s[state_range], state)

                def forward_unbatched(i_h, w_h, s_h):
                    return self.forward_unbatched(*i_h,
                                                  weights=w_h,
                                                  state=s_h,
                                                  update_state=update_state)

                def forward_fn(i_mh, w_mh):
                    o_mh, new_s_mh = jax.vmap(forward_unbatched,
                                              in_axes=(None, 0, 0),
                                              out_axes=0)(i_mh, w_mh, s_mh)
                    o_mh = o_mh.sum(0)
                    return o_mh, new_s_mh

                if compute_grad:
                    o_mh, backward_fn, s_mh = jax.vjp(forward_fn,
                                                      i_mh,
                                                      w_mh,
                                                      has_aux=True)
                    ct_mh = output_grad[example_idx]
                    assert o_mh.shape == ct_mh.shape
                    i_ct_mh, w_ct_mh = backward_fn(ct_mh)
                else:
                    o_mh, s_mh = forward_fn(i_mh, w_mh)

                if compute_output:
                    o_all = jax.ops.index_add(o_all, example_idx, o_mh)
                if update_state:
                    s_all = tree_update(s_all, state_range, s_mh)
                if compute_grad:
                    i_ct_all = tree_add(i_ct_all, example_idx, i_ct_mh)
                    w_ct_all = tree_add(w_ct_all, head_range, w_ct_mh)
                return (o_all, s_all, i_ct_all, w_ct_all)
        else:
            assert n_parallel_heads % self.n_heads == 0

            def forward_single_example(i_x, w_all, s_x):
                def forward_unbatched(i_h, w_h, s_h):
                    return self.forward_unbatched(*i_h,
                                                  weights=w_h,
                                                  state=s_h,
                                                  update_state=update_state)

                o_x, s_x = jax.vmap(forward_unbatched,
                                    in_axes=(None, 0, 0),
                                    out_axes=(0, 0))(i_x, w_all, s_x)
                o_x = o_x.sum(0)
                return o_x, s_x

            def run_inner(idx, loop_val):
                """Runs one slice of attention (all heads for one or more examples)."""
                o_all, s_all, i_ct_all, w_ct_all = loop_val
                idx = idx * n_parallel_heads
                example_idx_lo = idx // self.n_heads
                # Use iota here instead of np.arange, because np.arange will fail to
                # infer that the slice size is a compile-time constant.
                example_range = example_idx_lo + jax.lax.iota(
                    np.int32, n_parallel_heads // self.n_heads)
                state_range = idx + jax.lax.iota(np.int32, n_parallel_heads)

                i_mex = jax.tree_map(lambda x: x[example_range], inputs)
                s_mex = jax.tree_map(
                    lambda s: np.reshape(
                        s[state_range],  # pylint: disable=g-long-lambda
                        (-1, self.n_heads) + s.shape[1:]),
                    state)

                def forward_fn(i_mex, w_all):
                    o_mex, new_s_mex = jax.vmap(forward_single_example,
                                                in_axes=(0, None, 0),
                                                out_axes=(0, 0))(i_mex, w_all,
                                                                 s_mex)
                    new_s_mex = jax.tree_map(
                        lambda s: np.reshape(s, (n_parallel_heads,) + s.shape[2:]),
                        new_s_mex)
                    return o_mex, new_s_mex

                if compute_grad:
                    o_mex, backward_fn, s_mex = jax.vjp(forward_fn,
                                                        i_mex,
                                                        weights,
                                                        has_aux=True)
                    ct_mex = output_grad[example_range]
                    assert o_mex.shape == ct_mex.shape
                    i_ct_mex, w_ct_mex = backward_fn(ct_mex)
                else:
                    o_mex, s_mex = forward_fn(i_mex, weights)

                if compute_output:
                    o_all = jax.ops.index_add(o_all,
                                              jax.ops.index[example_range],
                                              o_mex)
                if update_state:
                    s_all = tree_update(s_all, state_range, s_mex)
                if compute_grad:
                    i_ct_all = tree_update(i_ct_all, example_range, i_ct_mex)
                    w_ct_all = jax.tree_multimap(
                        lambda old_all, delta_all: old_all + delta_all,
                        w_ct_all, w_ct_mex)
                return (o_all, s_all, i_ct_all, w_ct_all)

        o_all = s_all = i_ct_all = w_ct_all = None
        if compute_output:
            o_all = np.zeros((batch_size, seqlen, d_model),
                             dtype=inputs[0].dtype)
        if update_state:
            s_all = state
        if compute_grad:
            # TODO(kitaev): no gradients for non-float inputs
            i_ct_all = jax.tree_map(np.zeros_like, inputs)
            w_ct_all = jax.tree_map(np.zeros_like, weights)

        loop_val = (o_all, s_all, i_ct_all, w_ct_all)

        assert (batch_size * self.n_heads) % n_parallel_heads == 0
        loop_hi = (batch_size * self.n_heads) // n_parallel_heads
        if self.use_python_loop or loop_hi == 1:
            for idx in range(loop_hi):
                loop_val = run_inner(idx, loop_val)
        else:
            loop_val = jax.lax.fori_loop(0, loop_hi, run_inner, loop_val)

        if have_single_input and compute_grad:
            (o_all, s_all, i_ct_all, w_ct_all) = loop_val
            assert isinstance(i_ct_all, tuple) and len(i_ct_all) == 1
            return (o_all, s_all, i_ct_all[0], w_ct_all)
        else:
            return loop_val
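A toy JAX sketch of the [FW-BW-1] pattern the docstring describes: each slice's forward pass runs inside the loop and its backward pass follows immediately, reusing that slice's intermediates before moving on. The function, shapes, and head count are illustrative, not the attention layer above.

import jax
import jax.numpy as jnp

def per_head_forward(x, w):
    # Stand-in for one head's computation; the real layer calls forward_unbatched.
    return jnp.tanh(x @ w)

def fw_bw_1(x_all, w_all, ct_all):
    """Forward then immediately backward for each head, accumulating outputs and grads."""
    out, x_grads, w_grads = [], [], []
    for h in range(w_all.shape[0]):
        o_h, backward_fn = jax.vjp(per_head_forward, x_all, w_all[h])
        x_ct, w_ct = backward_fn(ct_all)   # uses intermediates from this head's forward
        out.append(o_h)
        x_grads.append(x_ct)
        w_grads.append(w_ct)
    return sum(out), sum(x_grads), jnp.stack(w_grads)

x = jnp.ones((3, 4))
w = jnp.ones((2, 4, 4))   # two "heads"
ct = jnp.ones((3, 4))     # gradient of the loss wrt the summed output
print([a.shape for a in fw_bw_1(x, w, ct)])  # [(3, 4), (3, 4), (2, 4, 4)]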
Example no. 30
0
 def new_weights(self, input_signature):
     del input_signature
     return (np.zeros((), dtype=np.float32), )