Example #1
  def _compute_merge_qkv(self, antecedent):
    """Computes qkv all in one call using MoE layer."""
    def _replace_d_model_dim(t):
      """Used to replace the `d_model` dim with `heads`."""
      new_last_dim = mtf.Dimension(self.q_shape[-1].name, t.shape[-1].size)
      return mtf.reshape(
          t, new_shape=mtf.Shape(t.shape[:-1] + [new_last_dim]))
    if self.expert_computation == "qkv":
      # NOTE: This assumes query and memory antecedent are the same
      qk = self.moe_layer.call(self.context, antecedent)
      # Split qk here since they went through the expert layers
      q, k = qk
      q = _replace_d_model_dim(q)
      k = _replace_d_model_dim(k)
    elif self.expert_computation == "q":
      q = self.moe_layer.call(self.context, antecedent)
      q = _replace_d_model_dim(q)
      # Compute key/value normally
      k = mtf.layers.us_einsum(
          [antecedent, self.wkv], reduced_dims=[self.memory_input_dim])
    elif self.expert_computation == "kv":
      k = self.moe_layer.call(self.context, antecedent)
      k = _replace_d_model_dim(k)
      # Compute query normally
      q = mtf.layers.us_einsum(
          [antecedent, self.wq], reduced_dims=[self.query_input_dim])
    else:
      raise ValueError("Invalid expert computation mode: {}".format(
          self.expert_computation))

    # Scale query
    q *= self.key_dim.size ** -0.5
    self._q = mtf.replace_dimensions(q, q.shape.dims[-1], self.q_dims)
    self._k = mtf.replace_dimensions(k, k.shape.dims[-1], self.k_dims)
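The snippet above scales the query by 1/sqrt(d_k) and then uses mtf.replace_dimensions to swap the single combined projection dimension of q and k for the per-head q_dims / k_dims. As a rough illustration of that bookkeeping (a NumPy analogue with assumed sizes, not the Mesh TensorFlow API):

import numpy as np

# Assumed sizes standing in for the mtf Dimensions above.
batch, length, heads, d_k = 2, 8, 4, 16

q = np.random.randn(batch, length, heads * d_k).astype(np.float32)
q *= d_k ** -0.5                          # q *= self.key_dim.size ** -0.5
q = q.reshape(batch, length, heads, d_k)  # analogue of replace_dimensions(..., self.q_dims)
assert q.shape == (batch, length, heads, d_k)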
Example #2
    def get_indices(self, keys: mtf.Tensor,
                    query: mtf.Tensor) -> Tuple[mtf.Tensor, mtf.Tensor]:
        """Generate score and indices for the query."""
        score_shape = mtf.Shape(query.shape.dims[:-1] + keys.shape.dims[2:3])
        scores = mtf.einsum([query, keys],
                            output_shape=score_shape)  # [b, l, h, 2, n_keys]
        knn_dim = mtf.Dimension("knn", self.knn)
        scores, indices = mtf.top_k(scores, score_shape.dims[-1],
                                    knn_dim)  # [b, l, h, 2, knn]

        # Computes the top cartesian products and their indices
        knn_square_dim = mtf.Dimension("knn_square_dim", self.knn**2)
        scores1, scores2 = mtf.unstack(scores, scores.shape.dims[-2])
        scores2 = mtf.rename_dimension(scores2, "knn", "knn2")
        out_shape = mtf.Shape(scores1.shape.dims + scores2.shape.dims[-1:])
        all_scores = mtf.add(scores1, scores2, output_shape=out_shape)
        all_scores = mtf.replace_dimensions(all_scores, out_shape[-2:],
                                            knn_square_dim)

        indices1, indices2 = mtf.unstack(indices, indices.shape.dims[-2])
        indices1 = mtf.multiply(indices1, self.n_keys)
        indices2 = mtf.rename_dimension(indices2, "knn", "knn2")
        all_indices = mtf.add(indices1, indices2, output_shape=out_shape)
        all_indices = mtf.replace_dimensions(all_indices, out_shape[-2:],
                                             knn_square_dim)

        scores, best_indices = mtf.top_k(all_scores, all_scores.shape.dims[-1],
                                         knn_dim)
        return scores, mtf.gather(all_indices, best_indices, knn_square_dim)
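Example #2 implements a product-key lookup: top-knn scores are taken for each of the two sub-key sets, and the top-knn over the full n_keys x n_keys product set is then recovered from only knn^2 candidate sums. A small NumPy sketch of why that works (all names and sizes here are assumptions for illustration, not the Mesh TensorFlow code):

import numpy as np

n_keys, knn, d_half = 16, 4, 8
rng = np.random.default_rng(0)

query = rng.standard_normal((2, d_half))              # the two query halves
sub_keys = rng.standard_normal((2, n_keys, d_half))   # two sub-key tables

# Scores of each query half against its own sub-key table: [2, n_keys].
scores = np.einsum("hd,hnd->hn", query, sub_keys)

# Top-knn per half, as mtf.top_k does above.
top_idx = np.argsort(-scores, axis=-1)[:, :knn]
top_val = np.take_along_axis(scores, top_idx, axis=-1)

# All knn*knn candidate pairs: score sums and flat indices i * n_keys + j.
all_scores = top_val[0][:, None] + top_val[1][None, :]
all_indices = top_idx[0][:, None] * n_keys + top_idx[1][None, :]

# Final top-knn over the knn^2 candidates.
flat = np.argsort(-all_scores.ravel())[:knn]
best_scores = all_scores.ravel()[flat]
best_indices = all_indices.ravel()[flat]   # flat indices into the n_keys^2 product set

# This matches the true top-knn over the full product set, because the best
# pairs must each use a top-knn element from both halves.
full = scores[0][:, None] + scores[1][None, :]
assert np.allclose(np.sort(best_scores), np.sort(full.ravel())[-knn:])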
Example #3
    def mdha_shared_qk(self, query_antecedent, context):
        """MDHA QK shared projection."""
        ret = mtf.layers.us_einsum([query_antecedent, self.wq],
                                   reduced_dims=[self.query_input_dim])
        with tf.variable_scope("qk_dconv"):
            len_dim = context.length_dim
            context.length_dim = ret.shape.dims[-2]
            ret = causal_depthwise_conv(ret, context=context, kernel_size=3)
            context.length_dim = len_dim

        q = mtf.layers.dense(ret,
                             ret.shape.dims[-1:],
                             use_bias=False,
                             activation=None,
                             variable_dtype=context.variable_dtype,
                             reduced_dims=ret.shape.dims[-1:],
                             name="q_solo_project",
                             expert_dims=context.model.ensemble_dims)

        k = ret

        if self.combine_dims:
            q = mtf.replace_dimensions(q, q.shape.dims[-1], self.q_dims)
            k = mtf.replace_dimensions(k, k.shape.dims[-1], self.k_dims)
        if not self.fold_scaling_into_initializer:
            q *= self.key_dim.size**-0.5

        return q, k
Example #4
  def _compute_merge_qkv(self, antecedent):
    """Computes qkv all in one call using MoE layer."""
    # NOTE: This assumes query and memory antecedent are the same
    qk = self.moe_layer.call(self.context, antecedent)
    # Split qk here since they went through the expert layers
    q, k = qk

    # Scale query
    q *= self.key_dim.size ** -0.5
    self._q = mtf.replace_dimensions(q, q.shape.dims[-1], self.q_dims)
    self._k = mtf.replace_dimensions(k, k.shape.dims[-1], self.k_dims)
Example #5
 def call(self, context, x, losses=None):
   """Call the layer."""
   memory_length = self.memory_length(context)
   q = self.compute_q(context, x)
   if context.mode == "incremental":
     m = x
   else:
     m = mtf.replace_dimensions(x, context.length_dim, memory_length)
   k = self.compute_k(context, m)
   v = self.compute_v(context, m)
   if context.mode == "incremental":
     one_hot = mtf.one_hot(
         context.position, memory_length, dtype=context.activation_dtype)
     inv_one_hot = 1.0 - one_hot
     old_k, old_v = context.get_states(2)
     k = old_k * inv_one_hot + k * one_hot
     v = old_v * inv_one_hot + v * one_hot
     memory_position = mtf.range(context.mesh, memory_length, tf.int32)
   else:
     memory_position = self.rename_length_to_memory_length(
         context.position, context)
   if context.mode == "incremental" or context.mode == "first_part":
     context.record_new_states([k, v])
   bias = self.compute_bias(context, memory_position, x,
                            self.softmax_heads_dims, q)
   return self.attention_internal(context, x, m, q, k, v, memory_length, bias)
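In incremental mode above, the cached keys and values are updated by blending with a one-hot mask at context.position, so only the slot for the current step changes. A minimal NumPy sketch of that update rule (sizes assumed for illustration):

import numpy as np

memory_length, d_k, position = 8, 4, 3

old_k = np.zeros((memory_length, d_k), dtype=np.float32)  # cached keys so far
new_k = np.ones((d_k,), dtype=np.float32)                 # key for the current step

one_hot = np.eye(memory_length, dtype=np.float32)[position][:, None]
inv_one_hot = 1.0 - one_hot

# k = old_k * inv_one_hot + k * one_hot: only row `position` is overwritten.
k = old_k * inv_one_hot + new_k * one_hot
assert (k[position] == 1.0).all()
assert (np.delete(k, position, axis=0) == 0.0).all()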
Example #6
 def call(self, context, x, losses=None):
   """Call the layer."""
   params = self.make_params(context)
   q = params.compute_q(x)
   memory_length = self.memory_length(context)
   if context.mode == "incremental":
     m = x
   else:
     m = mtf.replace_dimensions(x, context.length_dim, memory_length)
   if self.shared_kv:
     kv = params.compute_kv(m)
   else:
     k = params.compute_k(m)
     v = params.compute_v(m)
   if context.mode == "incremental":
     one_hot = mtf.one_hot(
         context.position, memory_length, dtype=context.activation_dtype)
     inv_one_hot = 1.0 - one_hot
     if self.shared_kv:
       old_kv = context.get_states(1)
       kv = old_kv * inv_one_hot + kv * one_hot
     else:
       old_k, old_v = context.get_states(2)
       k = old_k * inv_one_hot + k * one_hot
       v = old_v * inv_one_hot + v * one_hot
     memory_position = mtf.range(context.mesh, memory_length, tf.int32)
   else:
     memory_position = self.rename_length_to_memory_length(
         context.position, context)
   if context.mode == "incremental" or context.mode == "first_part":
     context.record_new_states([kv] if self.shared_kv else [k, v])
   if self.shared_kv:
     k = kv
     v = kv
   if self.attention_func == "hybrid":
     o = attention.hybrid_attention(
         q, k, v, context,
         memory_length,
         self.kv_dim,
         self.kv_dim,
         self.compute_bias(
             context, memory_position, x, params.query_heads_dims),
         **self.attention_kwargs_from_context(context))
   else:
     o = attention.attention(
         q, k, v,
         memory_length,
         self.kv_dim,
         self.kv_dim,
         self.compute_bias(
             context, memory_position, x, params.query_heads_dims),
         **self.attention_kwargs_from_context(context))
   return params.compute_output(o, output_shape=x.shape)
Example #7
 def mdha_v(self, memory_antecedent, context):
     """MDHA V projection."""
     ret = mtf.layers.us_einsum([memory_antecedent, self.wv],
                                reduced_dims=[self.memory_input_dim])
     with tf.variable_scope("v_dconv"):
         len_dim = context.length_dim
         context.length_dim = ret.shape.dims[-2]
         ret = causal_depthwise_conv(ret, context=context, kernel_size=3)
         context.length_dim = len_dim
     if self.combine_dims:
         ret = mtf.replace_dimensions(ret, ret.shape.dims[-1], self.v_dims)
     return ret
Example #8
 def mdha_q(self, query_antecedent, context):
     """MDHA Q projection."""
     ret = mtf.layers.us_einsum([query_antecedent, self.wq],
                                reduced_dims=[self.query_input_dim])
     with tf.variable_scope("q_dconv"):
         len_dim = context.length_dim
         context.length_dim = ret.shape.dims[-2]
         ret = causal_depthwise_conv(ret, context=context, kernel_size=3)
         context.length_dim = len_dim
     if self.combine_dims:
         ret = mtf.replace_dimensions(ret, ret.shape.dims[-1], self.q_dims)
     if not self.fold_scaling_into_initializer:
         ret *= self.key_dim.size**-0.5
     return ret
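Examples #7 and #8 call causal_depthwise_conv, which is not shown in these snippets. As an assumption about its behavior, a typical causal depthwise convolution with kernel_size=3 filters each channel independently and lets position t see only positions t, t-1 and t-2 (no look-ahead). A minimal NumPy reference (a hypothetical helper, not the Mesh TensorFlow implementation):

import numpy as np

def causal_depthwise_conv_ref(x, kernel):
    # x: [length, channels], kernel: [kernel_size, channels]
    length, channels = x.shape
    kernel_size = kernel.shape[0]
    out = np.zeros_like(x)
    for shift in range(kernel_size):
        if shift == 0:
            shifted = x
        else:
            shifted = np.zeros_like(x)
            shifted[shift:] = x[:-shift]   # left (causal) zero padding
        out += shifted * kernel[shift]     # one scalar weight per channel
    return out

x = np.random.randn(10, 4).astype(np.float32)
k = np.random.randn(3, 4).astype(np.float32)
y = causal_depthwise_conv_ref(x, k)
# Output at position 0 depends only on x[0]: the convolution is causal.
assert np.allclose(y[0], x[0] * k[0])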
Example #9
    def compute_q(self, query_antecedent):
        """Compute query Tensor q.

        Args:
          query_antecedent: a Tensor with dimensions
            {query_input_dim} + other_dims
        Returns:
          a Tensor with dimensions
            query_heads_dims + {key_dim} + other_dims
        """
        ret = mtf.einsum([query_antecedent, self.wq],
                         reduced_dims=[self.query_input_dim])
        if self.combine_dims:
            ret = mtf.replace_dimensions(ret, ret.shape.dims[-1], self.q_dims)
        return ret
Example #10
    def compute_kv(self, memory_antecedent):
        """Compute key/value Tensor kv.

        Args:
          memory_antecedent: a Tensor with dimensions
            {memory_input_dim} + other_dims
        Returns:
          a Tensor with dimensions
            memory_heads_dims + {key_dim} + other_dims
        """
        if not self.shared_kv:
            raise ValueError("compute_kv can only be called with shared_kv")
        ret = mtf.einsum([memory_antecedent, self.wkv],
                         reduced_dims=[self.memory_input_dim])
        if self.combine_dims:
            ret = mtf.replace_dimensions(ret, ret.shape.dims[-1], self.k_dims)
        return ret
Example #11
  def compute_q(self, query_antecedent):
    """Compute query Tensor q.

    Args:
      query_antecedent: a Tensor with dimensions
         {query_input_dim} + other_dims
    Returns:
      a Tensor with dimensions
         query_heads_dims + {key_dim} + other_dims
    """
    ret = mtf.layers.us_einsum(
        [query_antecedent, self.wq], reduced_dims=[self.query_input_dim])
    if self.combine_dims:
      ret = mtf.replace_dimensions(ret, ret.shape.dims[-1], self.q_dims)
    if not self.fold_scaling_into_initializer:
      ret *= self.key_dim.size ** -0.5
    return ret
Example #12
  def compute_v(self, memory_antecedent):
    """Compute value Tensor v.

    Args:
      memory_antecedent: a Tensor with dimensions
        {memory_input_dim} + other_dims
    Returns:
      a Tensor with dimensions
        memory_heads_dims + {value_dim} + other_dims
    """
    if self.shared_kv:
      raise ValueError("compute_v cannot be called with shared_kv")
    ret = mtf.layers.us_einsum(
        [memory_antecedent, self.wv], reduced_dims=[self.memory_input_dim])
    if self.combine_dims:
      ret = mtf.replace_dimensions(ret, ret.shape.dims[-1], self.v_dims)
    return ret
Example #13
  def compute_output(self, o, output_shape=None):
    """Compute output of multihead attention.

    Args:
      o: a Tensor with dimensions
         query_heads_dims + {value_dim} + other_dims
      output_shape: an optional Shape
    Returns:
      a Tensor with shape:
         {output_dim} + other_dims
    """
    if self.combine_dims:
      o = mtf.transpose(o, o.shape - self.o_dims + self.o_dims)
      o = mtf.replace_dimensions(o, self.o_dims, self.wo.shape.dims[-2])
      reduced_dims = [self.wo.shape.dims[-2]]
    else:
      reduced_dims = self.o_dims
    return mtf.einsum(
        [o, self.wo], output_shape=output_shape, reduced_dims=reduced_dims)
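When combine_dims is set, compute_output above moves the per-head value dimensions to the end, flattens them into the single combined input dimension of self.wo, and contracts that dimension in the final einsum. A NumPy analogue of that branch, with assumed sizes:

import numpy as np

batch, length, heads, d_v, d_model = 2, 5, 4, 8, 24

o = np.random.randn(batch, length, heads, d_v).astype(np.float32)
wo = np.random.randn(heads * d_v, d_model).astype(np.float32)

o_combined = o.reshape(batch, length, heads * d_v)  # replace_dimensions analogue
out = o_combined @ wo                               # einsum reducing the combined dim
assert out.shape == (batch, length, d_model)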
Example #14
 def call(self, context, x, losses=None):
   """Call the layer."""
   memory_length = self.memory_length(context)
   q = self.compute_q(context, x)
   if context.mode == "incremental":
     m = x
   else:
     m = mtf.replace_dimensions(x, context.length_dim, memory_length)
   if context.mode == "incremental":
     one_hot = mtf.one_hot(
         context.position, memory_length, dtype=context.activation_dtype)
     inv_one_hot = 1.0 - one_hot
     old_m, = context.get_states(1)
     m = old_m * inv_one_hot + one_hot * m
     memory_position = mtf.range(context.mesh, memory_length, tf.int32)
   else:
     memory_position = self.rename_length_to_memory_length(
         context.position, context)
   if context.mode == "incremental" or context.mode == "first_part":
     context.record_new_states([m])
   bias = self.compute_bias(context, memory_position, x, self.heads_dims, q)
   return self.attention_internal(context, q, m, memory_length, bias)
Example #15
  def self_attention(self, x, attention_bias):
    """Performs multi-headed self-attention with output projection.

    Args:
      x: output of previous layer
      attention_bias: optional float32 Tensor broadcastable to shape
        x.shape - self.model_dim + self.memory_seq_dim
        to be added to attention logits.
        This may be used to mask out padding regions of the memory.

    Returns:
      float Tensor with the same shape as x
    """

    queries = mtf.layers.dense(
        x,
        reduced_dims=[self.model_dim],
        new_dims=[self.num_heads_dim, self.size_per_head_dim],
        kernel_initializer=self.dense_initializer,
        name="query",
        use_bias=self.config.use_bias)
    keys = mtf.layers.dense(
        mtf.replace_dimensions(x, self.seq_dim, self.memory_seq_dim),
        reduced_dims=[self.model_dim],
        new_dims=[self.num_heads_dim, self.size_per_head_dim],
        kernel_initializer=self.dense_initializer,
        name="key",
        use_bias=self.config.use_bias)
    values = mtf.layers.dense(
        mtf.replace_dimensions(x, self.seq_dim, self.memory_seq_dim),
        reduced_dims=[self.model_dim],
        new_dims=[self.num_heads_dim, self.size_per_head_dim],
        kernel_initializer=self.dense_initializer,
        name="value",
        use_bias=self.config.use_bias)

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    attention_scores = mtf.einsum(
        [queries, keys], reduced_dims=[self.size_per_head_dim])
    attention_scores *= self.size_per_head_dim.size ** -0.5

    if attention_bias is not None:
      attention_scores += attention_bias

    # Normalize the attention scores to probabilities.
    attention_probs = mtf.softmax(attention_scores, self.memory_seq_dim)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = mtf.dropout(
        attention_probs,
        is_training=(self.config.attention_probs_dropout_prob > 0.0),
        keep_prob=1.0 - self.config.attention_probs_dropout_prob)

    output = mtf.einsum([attention_probs, values],
                        reduced_dims=[self.memory_seq_dim])

    # linear transformation back to shape of query_antecedent
    output = mtf.layers.dense(
        output,
        reduced_dims=[self.num_heads_dim, self.size_per_head_dim],
        new_dims=[self.model_dim],
        kernel_initializer=self.dense_initializer,
        name="output",
        use_bias=self.config.use_bias)
    output = mtf.transpose(output, x.shape)
    return output
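Stripped of the Mesh TensorFlow dimension bookkeeping, self_attention above is standard scaled dot-product attention: an einsum of queries and keys scaled by size_per_head**-0.5, an optional additive bias, a softmax over the memory axis, and a weighted sum of values. A plain-NumPy restatement with assumed shapes:

import numpy as np

batch, heads, seq, mem, d_head = 2, 4, 6, 6, 16
rng = np.random.default_rng(0)
q = rng.standard_normal((batch, heads, seq, d_head))
k = rng.standard_normal((batch, heads, mem, d_head))
v = rng.standard_normal((batch, heads, mem, d_head))
bias = np.zeros((batch, heads, seq, mem))     # e.g. -10000.0 on padding positions

scores = np.einsum("bhqd,bhmd->bhqm", q, k) * d_head ** -0.5 + bias
probs = np.exp(scores - scores.max(-1, keepdims=True))
probs /= probs.sum(-1, keepdims=True)         # softmax over the memory axis
out = np.einsum("bhqm,bhmd->bhqd", probs, v)
assert out.shape == (batch, heads, seq, d_head)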
Example #16
  def __init__(self,
               config,
               is_training,
               input_ids,
               input_mask=None,
               token_type_ids=None,
               scope=None,
               mesh_shape="",
               layout=""):
    self.config = copy.deepcopy(config)
    del config
    if not is_training:
      self.config.layer_output_dropout_prob = 0.0
      self.config.attention_probs_dropout_prob = 0.0
      self.config.feedforward_intermediate_dropout_prob = 0.0
    input_shape = input_ids.shape
    assert input_shape.ndims == 2

    self._seq_dim = input_shape.dims[1]
    self._memory_seq_dim = mtf.Dimension("memory_seq", self.seq_dim.size)
    self._extra_losses = []
    mesh = input_ids.mesh

    if token_type_ids is None:
      token_type_ids = mtf.zeros(mesh, input_shape, dtype=tf.int32)

    with tf.variable_scope(scope, default_name="bert"):
      with tf.variable_scope("embeddings"):
        # Perform embedding lookup on the word ids.
        self.embedding_table = mtf.get_variable(
            mesh, "word_embeddings",
            mtf.Shape([self.vocab_dim, self.model_dim]),
            initializer=self.embedding_initializer)
        self.word_embedding_output = mtf.gather(
            self.embedding_table, input_ids, self.vocab_dim)

        # Add positional embeddings and token type embeddings, then layer
        # normalize and perform dropout.
        self.embedding_output = self.word_embedding_output

        token_type_table = mtf.get_variable(
            mesh, "token_type_embeddings",
            mtf.Shape([self.token_type_vocab_dim, self.model_dim]),
            initializer=self.embedding_initializer)
        if token_type_ids is not None:
          self.embedding_output += mtf.gather(
              token_type_table, token_type_ids, self.token_type_vocab_dim)
        if self.config.position_signal == "embedding":
          full_position_table = mtf.get_variable(
              mesh, "position_embeddings",
              mtf.Shape([self.max_position_embeddings_dim, self.model_dim]),
              initializer=self.embedding_initializer)
          short_position_table = mtf.rename_dimension(
              mtf.slice(full_position_table, 0, self.seq_dim.size,
                        self.max_position_embeddings_dim.name),
              self.max_position_embeddings_dim.name, self.seq_dim.name)
          self.embedding_output += short_position_table
        self.embedding_output = self.normalize(self.embedding_output)
        self.embedding_output = mtf.dropout(
            self.embedding_output, is_training,
            keep_prob=1.0 - self.config.layer_output_dropout_prob)

      with tf.variable_scope("encoder"):
        attention_biases = []
        if input_mask:
          # [batch_dim, memory_seq_dim]
          attention_biases.append(
              (1.0 - mtf.to_float(mtf.replace_dimensions(
                  input_mask, self.seq_dim, self.memory_seq_dim))) * -10000.0)
        if self.config.position_signal == "relative_attention_bias":
          buckets_dim = mtf.Dimension("buckets", 32)
          rp_bucket = _relative_position_bucket(
              mtf.range(mesh, self.memory_seq_dim, tf.int32)
              - mtf.range(mesh, self.seq_dim, tf.int32),
              num_buckets=buckets_dim.size)
          bias_var = mtf.get_variable(
              mesh, "relative_attention_bias",
              [self.num_heads_dim, buckets_dim],
              initializer=tf.zeros_initializer())
          attention_biases.append(mtf.gather(bias_var, rp_bucket, buckets_dim))
        attention_bias = mtf.add_n(attention_biases)
        prev_layer_output = self.embedding_output
        self.all_encoder_layers = []
        for block_num in range(self.config.num_blocks):
          with tf.variable_scope("block_%d" % block_num):
            for layer_idx, layer_type in enumerate(self.config.block_layers):
              layer_name = layer_type
              count = self.config.block_layers[:layer_idx].count(layer_type)
              if count:
                layer_name += "_%d" % count
              with tf.variable_scope(layer_name):
                x = prev_layer_output
                if self.config.residual_structure == "direct":
                  x = self.normalize(x)
                if layer_type == "attention":
                  x = self.self_attention(x, attention_bias)
                elif layer_type == "feedforward":
                  x = self.feedforward(x)
                elif layer_type == "moe":
                  x = self.moe(x, layout, mesh_shape, input_mask, is_training)
                else:
                  raise ValueError("unknown layer type " + layer_type)
                x = mtf.dropout(
                    x, is_training,
                    keep_prob=1.0 - self.config.layer_output_dropout_prob)
                layer_output = prev_layer_output + x
                if self.config.residual_structure == "original":
                  layer_output = self.normalize(layer_output)
                prev_layer_output = layer_output
          self.all_encoder_layers.append(layer_output)

      self.sequence_output = prev_layer_output
      if self.config.residual_structure == "direct":
        self.sequence_output = self.normalize(self.sequence_output)

      # The "pooler" converts the encoded sequence tensor of shape
      # [batch_dim, seq_dim, hidden_size] to a tensor of shape
      # [batch_dim, hidden_size]. This is necessary for segment-level
      # (or segment-pair-level) classification tasks where we need a fixed
      # dimensional representation of the segment.
      with tf.variable_scope("pooler"):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token. We assume that this has been pre-trained
        first_token_tensor = mtf.gather(self.sequence_output, 0, self.seq_dim)
        self.pooled_output = mtf.layers.dense(
            first_token_tensor,
            reduced_dims=[self.model_dim],
            new_dims=[self.model_dim],
            activation=mtf.tanh,
            kernel_initializer=self.dense_initializer,
            use_bias=self.config.use_bias)
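The encoder above turns input_mask into an additive attention bias: padding positions (mask 0) receive -10000.0, so they get essentially zero weight after the softmax. A tiny NumPy illustration with an assumed mask:

import numpy as np

input_mask = np.array([1, 1, 1, 0, 0], dtype=np.float32)   # last two positions are padding
attention_bias = (1.0 - input_mask) * -10000.0

logits = np.zeros_like(input_mask) + attention_bias
probs = np.exp(logits - logits.max())
probs /= probs.sum()
assert probs[:3].sum() > 0.999 and probs[3:].sum() < 1e-3   # padding is ignored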
Example #17
 def rename_length_to_memory_length(self, x, context):
     return mtf.replace_dimensions(x, context.length_dim,
                                   self.memory_length(context))
Example #18
 def _my_reshape(x):
   if x and resplittable_dim in x.shape.dims:
     return mtf.replace_dimensions(
         x, resplittable_dim, [new_dim_high, new_dim_low])
   else:
     return x
Example #19
 def _reshape_memory(x):
   x = mtf.replace_dimensions(
       x, length_dim, [num_blocks, memory_block_length])
   return (mtf.left_halo_exchange if fully_autoregressive
           else mtf.halo_exchange)(
               x, num_blocks, memory_block_length, radius)
Example #20
 def _reshape_query(x):
   return mtf.replace_dimensions(
       x, length_dim, [num_blocks, query_block_length])
Example #21
def local_attention_1d(q,
                       k,
                       v,
                       length_dim,
                       key_dim,
                       value_dim,
                       fully_autoregressive=True,
                       length_dim_num_splits=1,
                       radius=128,
                       sequence_id=1,
                       write_priority=None,
                       read_priority=None,
                       attention_kwargs=None):
  """Attention to the a neighborood around the source.

  If fully_autoregressive, then query position p can only see memory positions
  in the range (p - radius, p].

  If not fully_autoregressive, then query position p can only see memory
  positions in the range (p - radius, p + radius].

  In addition, if write_priority and read_priority are provided, then attention
  is limited to position pairs where
  read_priority[query position] >= write_priority[memory position]

  Args:
    q: a Tensor containing length_dim
    k: a Tensor containing length_dim
    v: an optional Tensor containing length_dim.  If none then uses v=k.
    length_dim: a Dimension
    key_dim: a Dimension (the channels dimension of q and k)
    value_dim: a Dimension (the channels dimension of v)
    fully_autoregressive: a boolean
    length_dim_num_splits: an optional integer indicating how many ways the
      length dimension is split
    radius: an integer
    sequence_id: a Tensor or an integer
    write_priority: an optional Tensor containing length_dim
    read_priority: an optional Tensor containing length_dim
    attention_kwargs: optional keyword arguments for attention()

  Returns:
    a Tensor with the shape q.shape - key_dim + value_dim

  Raises:
    ValueError: if channels or depth don't match.
  """
  # Choose a suitable block size.
  # We choose the greatest divisor of length_per_split less than or equal
  # to max(radius, 128)
  length_per_split = length_dim.size // length_dim_num_splits
  block_length = max(radius, 128)
  while length_per_split % block_length != 0:
    block_length -= 1
  query_block_length = mtf.Dimension("query_block_length", block_length)
  memory_block_length = mtf.Dimension("memory_block_length", block_length)
  # The num_blocks dimension gets the same name as the length dimension,
  # so it will be split in the same way.
  num_blocks = mtf.Dimension(length_dim.name, length_dim.size // block_length)
  def _reshape_query(x):
    return mtf.replace_dimensions(
        x, length_dim, [num_blocks, query_block_length])
  def _reshape_memory(x):
    x = mtf.replace_dimensions(
        x, length_dim, [num_blocks, memory_block_length])
    return (mtf.left_halo_exchange if fully_autoregressive
            else mtf.halo_exchange)(
                x, num_blocks, memory_block_length, radius)
  q = _reshape_query(q)
  k = _reshape_memory(k)
  if v:
    v = _reshape_memory(v)
  else:
    v = k
  if sequence_id is None:
    sequence_id = 1
  if (not isinstance(sequence_id, mtf.Tensor) or
      length_dim not in sequence_id.shape.dims):
    sequence_id += mtf.zeros(q.mesh, [length_dim], tf.int32)
  q_sequence_id = _reshape_query(sequence_id)
  m_sequence_id = _reshape_memory(sequence_id)
  pos = mtf.range(q.mesh, length_dim, dtype=tf.int32)
  q_pos = _reshape_query(pos)
  m_pos = _reshape_memory(pos)

  padded_memory_block_length = mtf.Dimension(
      "memory_block_length",
      (1 if fully_autoregressive else 2) * radius + block_length)

  relative_position = m_pos - q_pos
  visible = mtf.equal(q_sequence_id, m_sequence_id)
  visible = mtf.logical_and(visible, mtf.greater(relative_position, -radius))
  visible = mtf.logical_and(visible, mtf.less_equal(
      relative_position, 0 if fully_autoregressive else radius))
  if read_priority is not None:
    write_priority = _reshape_memory(write_priority)
    read_priority = _reshape_query(read_priority)
    visible = mtf.logical_and(
        visible, mtf.greater_equal(read_priority, write_priority))

  bias = visibility_mask_to_attention_bias(visible, q.dtype)
  o = attention(q, k, v, padded_memory_block_length,
                key_dim, value_dim, bias, **attention_kwargs)
  return mtf.replace_dimensions(o, [num_blocks, query_block_length], length_dim)
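The block-size heuristic at the top of local_attention_1d picks the largest divisor of length_per_split that does not exceed max(radius, 128), so the blocks tile each length split exactly. A small standalone sketch of that choice:

# Sketch of the block-size choice above (pure Python, no Mesh TensorFlow).
def choose_block_length(length_per_split, radius):
    block_length = max(radius, 128)
    while length_per_split % block_length != 0:
        block_length -= 1
    return block_length

assert choose_block_length(1024, 128) == 128
assert choose_block_length(1000, 128) == 125   # largest divisor of 1000 that is <= 128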
Example #22
 def combine_batch_dims(self, x):
     if len(self.batch_dims) <= 1:
         return x
     return mtf.replace_dimensions(x, self.batch_dims,
                                   mtf.combined_dimension(self.batch_dims))
Example #23
def attn(x, scope, n_state, *, attention_type, params, bias, dim_seq, memory_length_dim, variable_dtype, context=None):
    # x :: [batch, seq, n_embd]
    x_shape, dim_batch, *_, dim_embd, mesh = x.shape, *x.shape, x.mesh

    # n_state is the same as config["n_embd"], which is also the same as dim_embd.
    assert n_state.size % params["n_head"] == 0

    dim_heads = mtf.Dimension("heads", params["n_head"])

    num_mem_kv = params.get("num_mem_kv", 0)
    use_num_mem_kv = num_mem_kv > 0

    with tf.variable_scope(scope):
        # Compute attention inputs
        dim_kv = mtf.Dimension("features_per_head", params["n_embd"] // params["n_head"])
        mtfparams = mtf.transformer.attention.attention_params_simple(
            x.mesh,
            io_dim=dim_embd,
            kv_dim=dim_kv,
            heads_dim=dim_heads,
            variable_dtype=variable_dtype
        )
        q = mtfparams.compute_q(x)
        k = mtfparams.compute_k(x)
        v = mtfparams.compute_v(x)

        if is_incremental_inference(context):
            one_hot = mtf.one_hot(context.position - 1, dim_seq, dtype=variable_dtype.master_dtype)
            inv_one_hot = 1.0 - one_hot
            old_k, old_v = context.get_states(2)
            k = old_k * inv_one_hot + k * one_hot
            v = old_v * inv_one_hot + v * one_hot

        if exists(context):
            context.record_new_states([k, v])

        with tf.variable_scope("attention"):
            if attention_type == "local":
                # `local_attention_1d` has built in autoregressive masking, so we don't need mask_attn_weights.
                radius = params.get("local_attention_radius", 256)

                if is_incremental_inference(context):
                    q *= one_hot

                a = mtf_transformer.attention.local_attention_1d(
                    q, k, v,
                    length_dim=k.shape[1],
                    key_dim=dim_kv,
                    value_dim=dim_kv,
                    radius=radius,
                    length_dim_num_splits=1,
                    fully_autoregressive=params["causal"],
                    attention_kwargs={},
                )

                if is_incremental_inference(context):
                    a = mtf.gather(a, context.position - 1, dim_seq)

            elif attention_type == "global":

                # TODO: pass in fake context
                # Broadcast mask bias across batch and heads
                if exists(bias):
                    if not is_incremental_inference(context):
                        broadcasted_bias = mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-2], bias.shape[-1]])
                    else:
                        # In the incremental case, a custom mask needs to be built that masks out all key/values that are greater than the current position
                        bias = mtf.gather(bias, context.position - 1, dim_seq)
                        broadcasted_bias = mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-1]])

                # memory key / values, from all-attention paper
                if use_num_mem_kv:
                    k, v = memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_dtype, mesh)

                k = mtf.replace_dimensions(k, k.shape[1], memory_length_dim)
                v = mtf.replace_dimensions(v, v.shape[1], memory_length_dim)

                attn_dropout_rate = params["attn_dropout"] if params["mode"] == "train" else 0

                a = mtf_transformer.attention.attention(
                    q, k, v,
                    memory_length_dim=memory_length_dim,
                    key_dim=dim_kv,
                    value_dim=dim_kv,
                    bias=broadcasted_bias,
                    dropout_rate=attn_dropout_rate
                )

            elif attention_type == "linear":
                linear_attn_fn = causal_linear_attention if params["causal"] else linear_attention
                a = linear_attn_fn(q, k, v)

            else:
                raise NotImplementedError("Unknown attention type {}!".format(attention_type))

        with tf.variable_scope("compute_output"):
            a = mtfparams.compute_output(a, x_shape)

        with tf.variable_scope("compute_output_bias"):
            b = mtf.get_variable(x.mesh, "o_b", [dim_embd], initializer=tf.constant_initializer(0),
                                 master_dtype=variable_dtype.master_dtype,
                                 slice_dtype=variable_dtype.slice_dtype,
                                 activation_dtype=variable_dtype.activation_dtype)
            a += b

        if params["mode"] == "train" and params["res_dropout"] > 0:
            a = mtf.dropout(a, rate=params["res_dropout"], name="res_dropout")
        return a
Example #24
 def call(self, context, x, losses=None):
     """Call the layer."""
     params = self.make_params(context)
     if self.share_qk_rep:
         q, k = params.mdha_shared_qk(x, context)
     else:
         q = params.mdha_q(x, context)
     memory_length = self.memory_length(context)
     if context.mode == "incremental":
         m = x
     else:
         if self.share_qk_rep:
             k = mtf.replace_dimensions(k, context.length_dim,
                                        memory_length)
         m = mtf.replace_dimensions(x, context.length_dim, memory_length)
     if self.shared_kv:
         kv = params.compute_kv(m)
     else:
         if not self.share_qk_rep:
             k = params.mdha_k(m, context)
         v = params.mdha_v(m, context)
     if context.mode == "incremental":
         one_hot = mtf.one_hot(context.position,
                               memory_length,
                               dtype=context.activation_dtype)
         inv_one_hot = 1.0 - one_hot
         if self.shared_kv:
             old_kv = context.get_states(1)
             kv = old_kv * inv_one_hot + kv * one_hot
         else:
             old_k, old_v = context.get_states(2)
             k = old_k * inv_one_hot + k * one_hot
             v = old_v * inv_one_hot + v * one_hot
         memory_position = mtf.range(context.mesh, memory_length, tf.int32)
     else:
         memory_position = self.rename_length_to_memory_length(
             context.position, context)
     if context.mode == "incremental" or context.mode == "first_part":
         context.record_new_states([kv] if self.shared_kv else [k, v])
     if self.shared_kv:
         k = kv
         v = kv
     o = self.attention_fn(q,
                           k,
                           v,
                           context=context,
                           memory_length_dim=memory_length,
                           key_dim=self.kv_dim,
                           value_dim=self.kv_dim,
                           bias=self.compute_bias(context, memory_position,
                                                  x,
                                                  params.query_heads_dims,
                                                  q),
                           **self.attention_kwargs_from_context(context))
     attention_output_shape = self.expected_attention_output_shape(
         x, params)
     attention_output = params.compute_output(
         o, output_shape=attention_output_shape)
     return self.layer_output_from_attention_output(context,
                                                    attention_output,
                                                    losses)
Example #25
    def attention(self,
                  x,
                  n_state,
                  mask,
                  attention_type="global",
                  name="attn"):
        # x :: [batch, seq, n_embd]
        batch_dim, seq_dim, embd_dim = x_shape = x.shape
        assert n_state.size % self.n_heads == 0, "n_state must be divisible by n_heads"
        with tf.variable_scope(name):
            # Compute attention inputs
            mtfparams = mtf.transformer.attention.attention_params_simple(
                x.mesh,
                io_dim=self.dimensions["embed_dim"],
                kv_dim=self.dimensions["kv_dim"],
                heads_dim=self.dimensions["heads_dim"],
                variable_dtype=self.variable_dtype)
            q = mtfparams.compute_q(x)
            k = mtfparams.compute_k(x)
            v = mtfparams.compute_v(x)

            if self.is_incremental_inference:
                one_hot = mtf.one_hot(self.context.position - 1,
                                      seq_dim,
                                      dtype=self.variable_dtype.master_dtype)
                inv_one_hot = 1.0 - one_hot
                old_k, old_v = self.context.get_states(2)
                k = old_k * inv_one_hot + k * one_hot
                v = old_v * inv_one_hot + v * one_hot

            if exists(self.context):
                self.context.record_new_states([k, v])

            with tf.variable_scope("attention"):
                if attention_type == "local":
                    # `local_attention_1d` has built in autoregressive masking, so we don't need mask_attn_weights.
                    radius = self.params.get("local_attention_radius", 256)
                    if self.is_incremental_inference:
                        q *= one_hot
                    a = mtf_transformer.attention.local_attention_1d(
                        q,
                        k,
                        v,
                        length_dim=k.shape[1],
                        key_dim=self.dimensions["kv_dim"],
                        value_dim=self.dimensions["kv_dim"],
                        radius=radius,
                        length_dim_num_splits=1,
                        fully_autoregressive=True,
                        attention_kwargs={},
                    )
                    if self.is_incremental_inference:
                        a = mtf.gather(a, self.context.position - 1, seq_dim)

                elif attention_type == "global":
                    if exists(mask):
                        if not self.is_incremental_inference:
                            broadcasted_mask = mtf.broadcast(
                                mask, [
                                    batch_dim, self.dimensions["heads_dim"],
                                    mask.shape[-2], mask.shape[-1]
                                ])  # TODO: not sure this is correct
                        else:
                            # In the incremental case, a custom mask needs to be built that masks out all key/values that are greater than the current position
                            mask = mtf.gather(mask, self.context.position - 1,
                                              seq_dim)
                            broadcasted_mask = mtf.broadcast(
                                mask, [
                                    batch_dim, self.dimensions["heads_dim"],
                                    mask.shape[-1]
                                ])

                    k = mtf.replace_dimensions(
                        k, k.shape[1], self.dimensions["memory_len_dim"])
                    v = mtf.replace_dimensions(
                        v, v.shape[1], self.dimensions["memory_len_dim"])

                    attn_dropout_rate = self.params.get(
                        "attention_dropout", 0) if self.mode == "train" else 0

                    a = mtf_transformer.attention.attention(
                        q,
                        k,
                        v,
                        memory_length_dim=self.dimensions["memory_len_dim"],
                        key_dim=self.dimensions["kv_dim"],
                        value_dim=self.dimensions["kv_dim"],
                        bias=broadcasted_mask,
                        dropout_rate=attn_dropout_rate)
                else:
                    raise NotImplementedError(
                        "Unknown attention type {}!".format(attention_type))

            with tf.variable_scope("compute_output"):
                a = mtfparams.compute_output(a, x_shape)

            with tf.variable_scope("compute_output_bias"):
                b = mtf.get_variable(
                    x.mesh,
                    "o_b", [embd_dim],
                    initializer=tf.constant_initializer(0),
                    master_dtype=self.variable_dtype.master_dtype,
                    slice_dtype=self.variable_dtype.slice_dtype,
                    activation_dtype=self.variable_dtype.activation_dtype)
                a += b
            residual_dropout = self.params.get("residual_dropout", 0)
            if self.mode == "train" and residual_dropout > 0:
                a = mtf.dropout(a, rate=residual_dropout, name="res_dropout")
            return a