Ejemplo n.º 1
 def _reshape_to_batch_and_copy_targets(preds, targets):
   batched_preds = jnp.reshape(preds, [-1, preds.shape[-1]])
   batched_targets = jnp.reshape(targets, [-1])
   return batched_preds, batched_targets, targets
Ejemplo n.º 2
def MergeLastTwoAxes():
    """Returns a layer that merges the last two axes of its input into one.

    An input of shape (..., a, b) is reshaped to (..., a * b).
    """
    # The layer used to be named 'SplitLastAxis' (copy-paste from the split
    # layer), which made model summaries and debugging output misleading.
    return tl.Fn('MergeLastTwoAxes',
                 lambda x: jnp.reshape(x, x.shape[:-2] + (-1, )))
Ejemplo n.º 3
    def forward(self, x):
        """Executes this layer as part of a forward pass through the model.

        Args:
          x: Tensor of same shape and dtype as the input signature used to
            initialize this layer.

        Returns:
          Tensor of same shape and dtype as the input.
        """
        m1, m2, mb, w1, w2, b2 = self.weights
        if self._mode != 'predict':
            # Train/eval run dense matmuls, so view w1/w2 as 2-D matrices with a
            # d_ff axis; predict mode indexes the un-reshaped weights instead.
            w1 = jnp.reshape(w1.T, (-1, self._d_ff))
            w2 = jnp.reshape(w2, (self._d_ff, -1))
        x_shape = x.shape
        x = jnp.reshape(x,
                        [-1, x_shape[-1]])  # Easier to operate on flattened x.

        # Q: should we add bias and/or put relu after the low-rank m1 dot?
        mask_logits = jnp.dot(jnp.dot(x, m1), m2) + mb
        mask_logits = jnp.reshape(mask_logits, [-1, self._d1, self._d2])
        # Log-softmax over the last (d2) axis; keepdims=True so the result
        # broadcasts against mask_logits in the subtraction below.
        mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1,
                                            keepdims=True)
        log_mask = mask_logits - mask_logsumexp
        mask = jnp.exp(log_mask)
        # Gumbel-softmax with straight-through discretization.
        rng1, rng2 = fastmath.random.split(self.rng, 2)
        u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6,
                                    1.0 - 1e-6)
        g = -jnp.log(-jnp.log(u))
        # Integer index of the chosen d2 element for every (row, d1 block).
        quant_mask = jnp.argmax(log_mask + g * self._temperature, axis=-1)
        if self._mode == 'train':
            # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
            quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
            quant_mask = fastmath.stop_gradient(quant_mask)
            quant_mask += mask - fastmath.stop_gradient(
                mask)  # straight-through
            # We will sometimes (quant_prob of the batches) use the soft-mask instead
            # of the quantized mask to improve training stability (see paper above).
            select = fastmath.random.uniform(rng2, (), jnp.float32, 0.0, 1.0)
            quant_mask = jnp.where(select < self._quant_prob, quant_mask, mask)
            quant_mask = jnp.reshape(quant_mask, [-1, self._d_ff])

        if self._mode == 'train':
            # In training, run full matmul to get benefits from the above tricks.
            mid = jnp.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
            res = jnp.dot(relu, w2) + b2
        elif self._mode == 'predict':
            # w1 = jnp.reshape(w1.T, (self._d1, self._d2, -1))
            # w2 = jnp.reshape(w2, (self._d1, self._d2, -1))
            # This implementation mimicks inference. It's not efficient for large
            # size of joint_batch, but at inference that will be 1 most of the time.
            # Shapes:
            # quant_mask is [joint_batch, self._d1]
            # w1: the indexing w1[idx1, idx2, :] below treats its first two axes
            # as (self._d1, self._d2) and its last axis as d_model.
            # NOTE(review): confirm against this layer's weight initializer.
            # we'll index w1 with advanced numpy indexing, first range over
            # self._d1 times the batch size, second range being quant_mask
            batch_size = quant_mask.shape[0]
            idx1 = jnp.array([jnp.arange(self._d1)] * batch_size)
            # flatten indices and select from w1
            idx1 = jnp.reshape(idx1, [-1])
            idx2 = jnp.reshape(quant_mask, [-1])
            w = w1[idx1,
                   idx2, :]  # now we have per-element weights with batch dim
            w = jnp.reshape(w, [batch_size, self._d1, -1])
            mid = jnp.einsum('ai,aji->aj', x, w)
            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
            # w2 is [self._d1, self._d2, d_model]
            v = w2[idx1, idx2, :]
            v = jnp.reshape(v, [batch_size, self._d1, -1])
            res = jnp.einsum('ai,aij->aj', relu, v) + b2
        else:
            # Any other mode (e.g. eval): hard one-hot mask on the dense path.
            # Without this branch, res would be unbound here and predict mode
            # would overwrite its result with a dense matmul on 3-D weights.
            quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
            quant_mask = jnp.reshape(quant_mask, [-1, self._d_ff])
            mid = jnp.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
            res = jnp.dot(relu, w2) + b2

        return jnp.reshape(res, x_shape)  # un-flatten if needed
Ejemplo n.º 4
    def forward(self, x):
        """Executes this layer as part of a forward pass through the model.

        Args:
          x: Tensor of same shape and dtype as the input signature used to
            initialize this layer.

        Returns:
          Tensor of same shape and dtype as the input.
        """
        m1, w1, w2, b2 = self.weights
        x_shape = x.shape
        x = jnp.reshape(x,
                        [-1, x_shape[-1]])  # Easier to operate on flattened x.

        # Q: check if we need bias and/or put relu after the m1 dot?
        mask_logits = jnp.dot(x, m1)
        # Log-softmax over the last (expert) axis; keepdims=True so the result
        # broadcasts against mask_logits in the subtraction below.
        mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1,
                                            keepdims=True)
        log_mask = mask_logits - mask_logsumexp
        mask = jnp.exp(log_mask)
        # Gumbel-softmax with straight-through discretization.
        # TODO(lukaszkaiser, chowdhery): Extract this block and share
        rng1, rng2 = fastmath.random.split(self.rng, 2)
        u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6,
                                    1.0 - 1e-6)
        g = -jnp.log(-jnp.log(u))
        selected_experts = jnp.argmax(log_mask + g * self._temperature,
                                      axis=-1)
        if self._mode == 'train':
            # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
            quant_mask = tl.one_hot(selected_experts, self._num_experts)
            quant_mask = fastmath.stop_gradient(quant_mask)
            quant_mask += mask - fastmath.stop_gradient(
                mask)  # straight-through
            # We will sometimes (50% of the batches) use the soft-mask instead of
            # the quantized mask to improve training stability (see the paper above).
            # Q: is selecting 50% of batches the best? Other %? Mixed in-batch?
            select = fastmath.random.uniform(rng2, (), jnp.float32, -1.0, 1.0)
            quant_mask = jnp.where(select > 0.0, quant_mask, mask)
        else:
            # Outside training only the hard one-hot selection is used. This
            # must be an else-branch: unconditionally recomputing it would
            # discard the straight-through mask built above.
            quant_mask = tl.one_hot(selected_experts, self._num_experts)
        quant_mask = jnp.reshape(quant_mask, [-1, self._num_experts, 1])
        quant_mask_shape = quant_mask.shape
        batch_size = quant_mask.shape[0]

        if self._mode == 'predict' and batch_size == 1:
            # This implementation mimicks inference for batch_size 1.
            start_idx = selected_experts[0] * self._n_elements_in_block
            # w1 is [d_model, d_ff], w is [d_model, n_elements_in_block]
            w = fastmath.dynamic_slice(
                w1, [0, start_idx], [w1.shape[0], self._n_elements_in_block])
            mid = jnp.dot(x, w)
            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
            # w2 is [d_ff, d_model], v is [n_elements_in_block, d_model]
            v = fastmath.dynamic_slice(
                w2, [start_idx, 0], [self._n_elements_in_block, w2.shape[-1]])
            v = jnp.reshape(v, [self._n_elements_in_block, -1])
            res = jnp.dot(relu, v) + b2
        else:
            # Dense path: repeat each expert's mask value over its
            # n_elements_in_block columns so it lines up with the d_ff axis.
            expanded_mask = jnp.broadcast_to(
                quant_mask, (quant_mask_shape[0], quant_mask.shape[1],
                             self._n_elements_in_block))
            expanded_mask = jnp.reshape(expanded_mask, (-1, self._d_ff))
            mid = jnp.dot(x, w1) * expanded_mask  # [joint_batch, d_ff]
            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
            res = jnp.dot(relu, w2) + b2

        return jnp.reshape(res, x_shape)  # un-flatten if needed
Ejemplo n.º 5
def SplitLastAxis(num_splits):
    """Returns a layer that splits the last axis into (num_splits, -1).

    An input of shape (..., d) becomes (..., num_splits, d // num_splits);
    the reshape fails when d is not divisible by num_splits.
    """
    layer_name = f'SplitLastAxis_{num_splits}'

    def _split(x):
        target_shape = x.shape[:-1] + (num_splits, -1)
        return jnp.reshape(x, target_shape)

    return tl.Fn(layer_name, _split)
Ejemplo n.º 6
 def reshape_to_chunks(x):
   """Regroups the merged first two axes of x into chunks of `chunk_size`.

   Shape (b, l, *rest) becomes (b * l // chunk_size, 1, chunk_size, *rest).
   `chunk_size` is read from the enclosing scope, and b * l must be a
   multiple of it.
   """
   n_tokens = x.shape[0] * x.shape[1]
   assert n_tokens % chunk_size == 0
   new_shape = [n_tokens // chunk_size, 1, chunk_size]
   new_shape.extend(x.shape[2:])
   return jnp.reshape(x, new_shape)
Ejemplo n.º 7
def _fast_inference_update_state(inputs, state, mask_for_predict=None):
    """Updates state of a causal attention layer for fast inference.

    The layer state stores arrays with cached values of keys and values,
    as well as an index. To make shapes static, keys and values in the state are
    long, and the index indicates where the new keys and values from inputs need
    to be appended.

    During update, we append new_keys and new_values to keys and values at
    position given by index. And we increment index by length of new keys.
    We also create a mask to be 1 at appropriate positions (causal mask).

    Args:
      inputs: a triple (new_queries, new_keys, new_values)
      state: layer state with (keys, values, index), or
        (state_mask_for_predict, keys, values, index) when mask_for_predict
        is not None.
      mask_for_predict: optional mask used only in predict mode. When given,
        the stored per-position mask in the state is updated and combined
        with the causal mask.

    Returns:
      Updated state and mask to be used.
    """
    # Fast inference: run step-by-step, storing the sequence
    # of keys and values calculated so far in state.
    (_, new_k, new_v) = inputs
    if mask_for_predict is not None:
        (state_mask_for_predict, ks, vs, idx) = state
    else:
        # Without this else-branch the 3-way unpack would also run on the
        # 4-element state above and fail.
        (ks, vs, idx) = state
    length = new_k.shape[1]
    # TODO(lukaszkaiser): benchmark speed and decide if using a separate code path
    # with index_update when length == 1 is worth it.
    # Keys and values are of shape [batch_size, length, d_kv].
    ks = fastmath.dynamic_update_slice_in_dim(ks, new_k, idx, axis=1)
    vs = fastmath.dynamic_update_slice_in_dim(vs, new_v, idx, axis=1)
    k_length = ks.shape[1]

    # Mask is of shape [1, q_length, k_length].
    # Mask should be true for every pair of (query_token, key_token) such that
    # index of query_token is equal or larger to index of key_token.
    mask = (jnp.reshape(jnp.arange(k_length), (1, 1, k_length)) <= jnp.reshape(
        jnp.arange(length) + idx, (1, length, 1)))
    if mask_for_predict is None:
        return (ks, vs, idx + length), mask

    # Copy the caller-provided mask into the start of the stored mask.
    state_mask_for_predict = fastmath.dynamic_update_slice_in_dim(
        state_mask_for_predict != 0,
        mask_for_predict.reshape((-1)) != 0,
        0,
        axis=0)
    # Enable the position right after the provided prefix, i.e. at offset
    # sum(mask_for_predict).
    state_mask_for_predict = fastmath.dynamic_update_slice_in_dim(
        state_mask_for_predict != 0,
        jnp.ones((1, )) != 0,
        jnp.sum(mask_for_predict, dtype=jnp.int32),
        axis=0)
    # Enable the position currently being written (idx).
    state_mask_for_predict = fastmath.dynamic_update_slice_in_dim(
        state_mask_for_predict != 0, jnp.ones((1, )) != 0, idx, axis=0)
    # Broadcast the per-key mask over queries: [1, 1, k_length]. Assumes the
    # stored mask has exactly k_length entries (the reshape fails otherwise).
    placeholder = jnp.reshape(state_mask_for_predict != 0,
                              (1, 1, mask.shape[2]))
    mask = mask * placeholder

    return (state_mask_for_predict, ks, vs, idx + length), mask
Ejemplo n.º 8
 def f(x):
   """Merges the first two axes of x; lower-rank inputs pass through as-is.

   Rank < 2 means there is no extra batch dimension (the devices act as the
   batch), so x is returned unchanged. Otherwise shape (a, b, *rest) becomes
   (a * b, *rest).
   """
   if len(x.shape) >= 2:
     merged = x.shape[0] * x.shape[1]
     return jnp.reshape(x, [merged, *x.shape[2:]])
   return x
Ejemplo n.º 9
 def _unflatten_inputs(self, inputs):
     return jnp.reshape(
         inputs, inputs.shape[:-1] + self._shape + (self._n_categories, ))
Ejemplo n.º 10
 def serialize(x):
   """Serializes every (batch, position) entry of x with `serializer`.

   The first two axes are merged before calling serializer.serialize. The
   output is reshaped to (batch_size, -1, serializer.representation_length).
   `serializer` is read from the enclosing scope.
   """
   lead_shape = x.shape[:2]
   feature_shape = x.shape[2:]
   batch_size, length = lead_shape
   flat = jnp.reshape(x, (batch_size * length, *feature_shape))
   encoded = serializer.serialize(flat)
   out_shape = (batch_size, -1, serializer.representation_length)
   return jnp.reshape(encoded, out_shape)
Ejemplo n.º 11
def _InsertAxes12():
    """Returns a layer that reshapes (a, b, ...) to (a, 1, 1, b).

    Only the sizes of the first two axes are used in the target shape, so it
    matches the input size only for rank-2 inputs.
    """
    def _insert(x):
        n_rows, n_cols = x.shape[0], x.shape[1]
        return jnp.reshape(x, (n_rows, 1, 1, n_cols))

    return tl.Fn('InsertAxes12', _insert)
Ejemplo n.º 12
 def f(x):
   """Keeps the first two axes of x and reshapes the rest to (n_controls, n_actions).

   `n_controls` and `n_actions` are read from the enclosing scope.
   """
   lead = x.shape[:2]
   return jnp.reshape(x, (*lead, n_controls, n_actions))
Ejemplo n.º 13
def MergeLastTwoAxes():  # pylint: disable=invalid-name
    """Returns a layer that merges the last two axes of its input into one.

    An input of shape (..., a, b) is reshaped to (..., a * b).
    """
    # The layer used to be named 'SplitLastAxis' (copy-paste from the split
    # layer), which made model summaries and debugging output misleading.
    return tl.Fn('MergeLastTwoAxes',
                 lambda x: np.reshape(x, x.shape[:-2] + (-1, )))
Ejemplo n.º 14
def SplitLastAxis(num_splits):  # pylint: disable=invalid-name
    """Returns a layer that splits the last axis into (num_splits, -1).

    An input of shape (..., d) becomes (..., num_splits, d // num_splits);
    the reshape fails when d is not divisible by num_splits.
    """
    def _split(x):
        target_shape = x.shape[:-1] + (num_splits, -1)
        return np.reshape(x, target_shape)

    return tl.Fn(f'SplitLastAxis_{num_splits}', _split)
Ejemplo n.º 15
 def _reshape_xent_back(xent, targets):
   return jnp.reshape(xent, targets.shape)
Ejemplo n.º 16
 def reshape_mask(mask):
     """Reshapes a (batch, length) mask to (batch, 1, 1, length)."""
     batch, length = mask.shape[0], mask.shape[1]
     return jnp.reshape(mask, (batch, 1, 1, length))
Ejemplo n.º 17
def ReformerShortenLM(vocab_size,
  """Reversible transformer language model with shortening.

  When shorten_factor is F and processing an input of shape [batch, length],
  we embed the (shifted-right) input and then group each F elements (on length)
  into a single vector -- so that in the end we process a tensor of shape ::

      [batch, length // F, d_model]

  almost until the end -- at the end it's un-shortend and a SRU is applied.
  This reduces the length processed inside the main model body, effectively
  making the model faster but possibly slightly less accurate.

    vocab_size: int: vocab size
    shorten_factor: by how much to shorten, see above
    d_embedding: the depth of the embedding layer and final logits
    d_model: int:  depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    attention_type: class: attention class to use, such as SelfAttention.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, values must sum to d_embedding.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    mode: str: 'train' or 'eval'

    the layer.
  assert mode != 'predict'  # TODO(lukaszkaiser,kitaev): fast inference

  if not axial_pos_shape:
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
    assert d_axial_pos_embs is not None
    positional_encoding = tl.AxialPositionalEncoding(
        shape=axial_pos_shape, d_embs=d_axial_pos_embs,
        dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
        dropout=dropout, mode=mode)

  positional_embedder = [
      tl.Embedding(vocab_size, d_embedding),
      tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),  # pylint: disable=no-value-for-parameter

  decoder_blocks = []

  if isinstance(attention_type, (tuple, list)):
    assert n_layers % len(attention_type) == 0
    attention_type = [attention_type]
  for layer_idx in range(n_layers):
    layer_attention_type = attention_type[layer_idx % len(attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,

  # pylint: disable=g-long-lambda
  return tl.Serial(
      tl.Dup(),              # Stack has (x, x), the first will be shortened
      # Before shortening, we need to pad by shorten factor so as not to leak
      # information into the future. To understand why, imagine shorten factor
      # of 2 and sequence of length 4, so ABCD. If we shift just by 1, then we
      # would have 0ABC, which gets grouped to [0A][BC] on input, which is
      # predicting ABCD as targets. The problem is that [0A] has access to A
      # and [BC] has access to C -- it will learn to copy it, peek into
      # the future. Shifting twice to [00][AB] solves the problem as the first
      # "big" symbol becomes all-0 and the rest is shifted enough.
      tl.ShiftRight(n_shifts=shorten_factor - 1),
      tl.Fn('Shorten', lambda x: jnp.reshape(  # Shorten -- move to depth.
          x, (x.shape[0], x.shape[1] // shorten_factor, -1)), n_out=1),
      tl.Dup(),  # Stack has (short_x, short_x, x)
      tl.Select([0], n_in=2),
      tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),  # pylint: disable=no-value-for-parameter
      tl.Dense(shorten_factor * d_embedding),
      tl.Fn('ProlongBack', lambda x: jnp.reshape(  # Prolong back.
          x, (x.shape[0], x.shape[1] * shorten_factor, -1)), n_out=1),
      tl.Concatenate(),  # Concatenate with just the embeddings.
      tl.SRU(d_embedding),  # One RNN layer for conditional dependence.
Ejemplo n.º 18
 def reshape_from_chunks(x):
     """Merges the first two axes of x: (a, b, *rest) -> (a * b, *rest)."""
     first, second = x.shape[0], x.shape[1]
     merged_shape = [first * second, *x.shape[2:]]
     return jnp.reshape(x, merged_shape)
Ejemplo n.º 19
 def f(x):  # pylint: disable=invalid-name
   """Flattens every axis after the first `n_axes_to_keep` axes into one.

   Raises:
     ValueError: if the rank of x does not exceed `n_axes_to_keep` (read from
       the enclosing scope).
   """
   rank = len(x.shape)
   if rank > n_axes_to_keep:
     kept = x.shape[:n_axes_to_keep]
     return jnp.reshape(x, kept + (-1,))
   raise ValueError(f'Input rank ({rank}) must exceed the number of '
                    f'axes to keep ({n_axes_to_keep}) after flattening.')
Ejemplo n.º 20
 def compute_attention_output(x):
     """Merges attention heads back into the feature axis.

     x is viewed as (-1, n_heads, seqlen, d_head), the head and sequence axes
     are swapped, and the result is reshaped to (-1, seqlen, n_heads * d_head).
     `n_heads` and `d_head` are read from the enclosing scope.
     """
     seq_len = x.shape[1]
     per_head = jnp.reshape(x, (-1, n_heads, seq_len, d_head))
     heads_last = jnp.swapaxes(per_head, 1, 2)
     return jnp.reshape(heads_last, (-1, seq_len, n_heads * d_head))