Example #1
 def call(self, context, x, losses=None):
   """Call the layer."""
   params = self.make_params(context)
   q = params.compute_q(x)
   memory_length = self.memory_length(context)
   if context.mode == "incremental":
     m = x
   else:
     m = mtf.replace_dimensions(x, context.length_dim, memory_length)
   if self.shared_kv:
     kv = params.compute_kv(m)
   else:
     k = params.compute_k(m)
     v = params.compute_v(m)
   if context.mode == "incremental":
     one_hot = mtf.one_hot(
         context.position, memory_length, dtype=context.activation_dtype)
     inv_one_hot = 1.0 - one_hot
     if self.shared_kv:
       old_kv = context.get_states(1)
       kv = old_kv * inv_one_hot + kv * one_hot
     else:
       old_k, old_v = context.get_states(2)
       k = old_k * inv_one_hot + k * one_hot
       v = old_v * inv_one_hot + v * one_hot
     memory_position = mtf.range(context.mesh, memory_length, tf.int32)
   else:
     memory_position = self.rename_length_to_memory_length(
         context.position, context)
   if context.mode == "incremental" or context.mode == "first_part":
     context.record_new_states([kv] if self.shared_kv else [k, v])
   if self.shared_kv:
     k = kv
     v = kv
   if self.attention_func == "hybrid":
     o = attention.hybrid_attention(
         q, k, v, context,
         memory_length,
         self.kv_dim,
         self.kv_dim,
         self.compute_bias(
             context, memory_position, x, params.query_heads_dims),
         **self.attention_kwargs_from_context(context))
   else:
     o = attention.attention(
         q, k, v,
         memory_length,
         self.kv_dim,
         self.kv_dim,
         self.compute_bias(
             context, memory_position, x, params.query_heads_dims),
         **self.attention_kwargs_from_context(context))
   return params.compute_output(o, output_shape=x.shape)
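Example #1 caches keys and values during incremental decoding by blending the previous cache with the current projection through a one-hot mask over memory positions. The NumPy sketch below illustrates just that update pattern; the array names and sizes are made up for illustration and are not part of the library.

import numpy as np

# One-hot cache update, as in the incremental branch above:
# new_cache = old_cache * (1 - one_hot) + new_value * one_hot
# writes the current step's key into slot `position` and leaves the rest intact.
memory_length, d_k = 8, 4
old_k = np.zeros((memory_length, d_k))      # cached keys from earlier steps
new_k = np.ones((1, d_k))                   # key projected from the current token
position = 3                                # decoding step being written

one_hot = np.zeros((memory_length, 1))
one_hot[position] = 1.0
updated_k = old_k * (1.0 - one_hot) + new_k * one_hot
assert np.allclose(updated_k[position], new_k[0])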
Example #2
def enc_dec_attention(self_attention_layer, memory_antecedent, context, x,
                      losses):
  """Multi-head attention over the encoder outputs."""
  memory_input_dim = memory_antecedent.shape[-1]
  if memory_input_dim != context.model.model_dim:
    raise NotImplementedError(
        "TODO(noam): support different model_dim in encoder and decoder.")
  params = self_attention_layer.make_params(context)
  q = params.compute_q(x)
  if context.mode == "incremental":
    k, v, memory_length = context.get_constant_state()
  else:
    m = memory_antecedent
    if self_attention_layer.shared_kv:
      kv = params.compute_kv(m)
      k = kv
      v = kv
    else:
      k = params.compute_k(m)
      v = params.compute_v(m)
    memory_length, = [d for d in m.shape.dims if d.name == "memory_length"]
    if context.mode == "first_part":
      context.record_constant_state((k, v, memory_length))
  if context.encoder_sequence_id and context.sequence_id:
    visible = mtf.equal(context.sequence_id, context.encoder_sequence_id)
    bias = attention.visibility_mask_to_attention_bias(visible,
                                                       context.activation_dtype)
  else:
    bias = None
  a = attention.attention(
      q, k, v, memory_length, self_attention_layer.kv_dim,
      self_attention_layer.kv_dim, bias,
      **self_attention_layer.attention_kwargs_from_context(context))
  attention_output_shape = self_attention_layer.expected_attention_output_shape(
      x, params)
  attention_output = params.compute_output(
      a, output_shape=attention_output_shape)
  return self_attention_layer.layer_output_from_attention_output(
      context, attention_output, losses)
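When inputs are packed, Example #2 only lets a decoder position attend to encoder positions from the same packed example, via mtf.equal(context.sequence_id, context.encoder_sequence_id) turned into an additive bias. A small NumPy sketch of that mask, with made-up sequence ids:

import numpy as np

decoder_sequence_id = np.array([1, 1, 2, 2, 2])   # packed example id per decoder position
encoder_sequence_id = np.array([1, 1, 1, 2, 2])   # packed example id per encoder position

# visible[i, j] is True when decoder position i and encoder position j
# belong to the same packed example.
visible = decoder_sequence_id[:, None] == encoder_sequence_id[None, :]

# visibility_mask_to_attention_bias: 0 where visible, a large negative
# number otherwise, so masked logits vanish after the softmax.
bias = np.where(visible, 0.0, -1e9)
print(bias)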
Example #3
 def call(self, context, x, losses=None):
     """Call the layer."""
     memory_antecedent = self._get_memory_antecedent(context)
     memory_input_dim = memory_antecedent.shape[-1]
     if memory_input_dim != context.model.model_dim:
         raise NotImplementedError(
             "TODO(noam): support different model_dim in encoder and decoder."
         )
     params = self.make_params(context)
     q = params.compute_q(x)
     if context.mode == "incremental":
         k, v, memory_length = context.get_constant_state()
     else:
         m = memory_antecedent
         if self.shared_kv:
             kv = params.compute_kv(m)
             k = kv
             v = kv
         else:
             k = params.compute_k(m)
             v = params.compute_v(m)
         memory_length, = [
             d for d in m.shape.dims if d.name == "memory_length"
         ]
         if context.mode == "first_part":
             context.record_constant_state((k, v, memory_length))
     if context.encoder_sequence_id and context.sequence_id:
         visible = mtf.equal(context.sequence_id,
                             context.encoder_sequence_id)
         bias = attention.visibility_mask_to_attention_bias(
             visible, context.activation_dtype)
     else:
         bias = None
     o = attention.attention(q, k, v, memory_length, self.kv_dim,
                             self.kv_dim, bias,
                             **self.attention_kwargs_from_context(context))
     return params.compute_output(o, output_shape=x.shape)
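Examples #2 and #3 compute the encoder keys and values once, record them as constant state in "first_part" mode, and read them back unchanged at every incremental decoding step. The sketch below mimics that flow with a plain dict; encode_kv and the cache are hypothetical stand-ins for params.compute_k/compute_v and Context.record_constant_state/get_constant_state.

import numpy as np

def encode_kv(encoder_outputs, w_k, w_v):
    # Project the encoder outputs once; they do not change during decoding.
    return encoder_outputs @ w_k, encoder_outputs @ w_v

cache = {}
encoder_outputs = np.random.randn(6, 8)   # [memory_length, model_dim]
w_k = np.random.randn(8, 4)
w_v = np.random.randn(8, 4)

# "first_part" mode: compute and record the constant state.
cache["enc_kv"] = encode_kv(encoder_outputs, w_k, w_v)

# "incremental" mode: every decoding step reuses the same cached tensors
# instead of re-projecting the encoder outputs.
for step in range(3):
    k, v = cache["enc_kv"]
    # ... attend from the current decoder query to k and v ...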
Example #4
 def call(self, context, x, losses=None):
     """Call the layer."""
     params = self.make_params(context)
     q = params.compute_q(x)
     if self.shared_kv:
         kv = params.compute_kv(x)
         k = kv
         v = kv
     else:
         k = params.compute_k(x)
         v = params.compute_v(x)
     if context.mode == "incremental":
         if self.shared_kv:
             prev_kv, = context.get_states(1)
         else:
             prev_k, prev_v = context.get_states(2)
         current_position = mtf.equal(
             mtf.range(context.mesh, self.window_dim, dtype=tf.int32),
             mtf.mod(context.position, self.radius))
         if self.shared_kv:
             kv = mtf.where(current_position,
                            kv,
                            prev_kv,
                            output_shape=prev_kv.shape)
             k = kv
             v = kv
             context.record_new_states([kv])
         else:
             k = mtf.where(current_position,
                           params.compute_k(x),
                           prev_k,
                           output_shape=prev_k.shape)
             v = mtf.where(current_position,
                           params.compute_v(x),
                           prev_v,
                           output_shape=prev_v.shape)
             context.record_new_states([k, v])
         window_pos = mtf.range(context.mesh, self.window_dim, tf.int32)
         visible = mtf.greater_equal(context.position, window_pos)
         bias = attention.visibility_mask_to_attention_bias(
             visible, context.activation_dtype)
         o = attention.attention(
             q, k, v, self.window_dim, self.kv_dim, self.kv_dim, bias,
             **self.attention_kwargs_from_context(context))
     elif context.length_dim.size <= max(256, self.radius * 4):
         # nothing fancy - just do full attention and mask
         memory_length = self.rename_length_to_memory_length(
             context.position, context)
         o = attention.attention(
             q, self.rename_length_to_memory_length(k, context),
             self.rename_length_to_memory_length(v, context),
             self.memory_length(context), self.kv_dim, self.kv_dim,
             self.compute_bias(context, memory_length, x),
             **self.attention_kwargs_from_context(context))
     else:
         # fancy local attention algorithm
         o = attention.local_attention_1d(
             q=q,
             k=k,
             v=None if self.shared_kv else v,
             length_dim=context.length_dim,
             key_dim=self.kv_dim,
             value_dim=self.kv_dim,
             length_dim_num_splits=1,  # TODO(noam): look at the layout
             autoregressive=context.model.fully_autoregressive,
             radius=self.radius,
             sequence_id=context.sequence_id,
             write_priority=context.write_priority,
             read_priority=context.read_priority,
             attention_kwargs=self.attention_kwargs_from_context(context))
     if context.mode == "first_part":
         window_pos = mtf.range(context.mesh, self.window_dim, tf.int32)
         pos = mtf.range(context.mesh, context.length_dim, tf.int32)
         select_recent = mtf.cast(
             mtf.equal(mtf.mod(pos, self.radius), window_pos), x.dtype)
         select_recent *= mtf.cast(mtf.less(pos, context.initial_position),
                                   x.dtype)
         select_recent *= mtf.cast(
             mtf.greater_equal(pos, context.initial_position - self.radius),
             x.dtype)
         state_shape = (k.shape - [context.length_dim, self.kv_dim] +
                        [self.window_dim, self.kv_dim])
         k_state = mtf.einsum([k, select_recent],
                              output_shape=state_shape,
                              reduced_dims=[context.length_dim])
         context.new_states.append(k_state)
         if not self.shared_kv:
             v_state = mtf.einsum([v, select_recent],
                                  output_shape=state_shape,
                                  reduced_dims=[context.length_dim])
             context.new_states.append(v_state)
     return params.compute_output(o, output_shape=x.shape)
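In the incremental branch, Example #4 keeps a rolling cache of the last radius positions: each new key/value is written into slot position mod radius of a window_dim-sized buffer, and the visibility mask hides slots that have not been written yet. A NumPy sketch of the circular buffer, with illustrative sizes:

import numpy as np

radius, d_k = 4, 2
window_k = np.zeros((radius, d_k))            # rolling key cache of size `radius`

for position in range(10):
    new_k = np.full((d_k,), float(position))  # key for the current token
    slot = position % radius
    window_k[slot] = new_k                    # overwrite the oldest entry

# After 10 steps the buffer holds the keys of positions 8, 9, 6, 7 in slots 0..3.
print(window_k)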
Example #5
def local_attention_1d(q,
                       k,
                       v,
                       length_dim,
                       key_dim,
                       value_dim,
                       fully_autoregressive=True,
                       length_dim_num_splits=1,
                       radius=128,
                       sequence_id=1,
                       write_priority=None,
                       read_priority=None,
                       attention_kwargs=None,
                       context=None):
    """Attention to the a neighborood around the source.

  If fully_autoregressive, then query position p can only see memory positions
  in the range (p - radius, p].

  If not fully_autoregressive, then query position p can only see memory
  positions in the range (p - radius, p + radius].

  In addition, if write_priority and read_priority are provided, then attention
  is limited to position pairs where
  read_priority[query position] >= write_priority[memory position]

  Args:
    q: a Tensor containing length_dim
    k: a Tensor containing length_dim
    v: an optional Tensor containing length_dim.  If none then uses v=k.
    length_dim: a Dimension
    key_dim: a Dimension (the channels dimension of q and k)
    value_dim: a Dimension (the channels dimension of v)
    fully_autoregressive: a boolean
    length_dim_num_splits: an optional integer indicating how many ways the
      length dimension is split
    radius: an integer
    sequence_id: a Tensor or an integer
    write_priority: an optional Tensor containing length_dim
    read_priority: an optional Tensor containing length_dim
    attention_kwargs: optional keyword arguments for attention()
    context: optional context.

  Returns:
    a Tensor with the shape q.shape - key_dim + value_dim

  Raises:
    ValueError: if channels or depth don't match.
  """
    # Choose a suitable block size.
    # We choose the greatest divisor of length_per_split less than or equal
    # to max(radius, 128)
    tf.logging.info(attention_kwargs)
    length_per_split = length_dim.size // length_dim_num_splits
    block_length = max(radius, 128)
    while length_per_split % block_length != 0:
        block_length -= 1
    query_block_length = mtf.Dimension("query_block_length", block_length)
    memory_block_length = mtf.Dimension("memory_block_length", block_length)
    # The num_blocks dimension gets the same name as the length dimension,
    # so it will be split in the same way.
    num_blocks = mtf.Dimension(length_dim.name,
                               length_dim.size // block_length)

    def _reshape_query(x):
        return mtf.replace_dimensions(x, length_dim,
                                      [num_blocks, query_block_length])

    def _reshape_memory(x):
        x = mtf.replace_dimensions(x, length_dim,
                                   [num_blocks, memory_block_length])
        return (mtf.left_halo_exchange if fully_autoregressive else
                mtf.halo_exchange)(x, num_blocks, memory_block_length, radius)

    q = _reshape_query(q)
    k = _reshape_memory(k)
    if v:
        v = _reshape_memory(v)
    else:
        v = k
    if sequence_id is None:
        sequence_id = 1
    if (not isinstance(sequence_id, mtf.Tensor)
            or length_dim not in sequence_id.shape.dims):
        sequence_id += mtf.zeros(q.mesh, [length_dim], tf.int32)
    q_sequence_id = _reshape_query(sequence_id)
    m_sequence_id = _reshape_memory(sequence_id)
    pos = mtf.range(q.mesh, length_dim, dtype=tf.int32)
    q_pos = _reshape_query(pos)
    m_pos = _reshape_memory(pos)

    padded_memory_block_length = mtf.Dimension(
        "memory_block_length",
        (1 if fully_autoregressive else 2) * radius + block_length)

    relative_position = m_pos - q_pos
    visible = mtf.equal(q_sequence_id, m_sequence_id)
    visible = mtf.logical_and(visible, mtf.greater(relative_position, -radius))
    visible = mtf.logical_and(
        visible,
        mtf.less_equal(relative_position,
                       0 if fully_autoregressive else radius))
    if read_priority is not None:
        write_priority = _reshape_memory(write_priority)
        read_priority = _reshape_query(read_priority)
        visible = mtf.logical_and(
            visible, mtf.greater_equal(read_priority, write_priority))

    bias = attention.visibility_mask_to_attention_bias(visible, q.dtype)
    o = attention.attention(q,
                            k,
                            v,
                            padded_memory_block_length,
                            key_dim,
                            value_dim,
                            bias,
                            context=context,
                            **attention_kwargs)
    return mtf.replace_dimensions(o, [num_blocks, query_block_length],
                                  length_dim)
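local_attention_1d picks block_length as the largest divisor of length_per_split that does not exceed max(radius, 128), reshapes the length dimension into (num_blocks, block_length), and then extends each memory block with a halo of radius positions from its neighbor(s) so queries near a block edge still see their full neighborhood. A standalone sketch of just the block-length selection (choose_block_length is a hypothetical helper; the numbers are illustrative):

def choose_block_length(length_dim_size, length_dim_num_splits, radius):
    # Largest divisor of length_per_split that is <= max(radius, 128),
    # mirroring the loop at the top of local_attention_1d.
    length_per_split = length_dim_size // length_dim_num_splits
    block_length = max(radius, 128)
    while length_per_split % block_length != 0:
        block_length -= 1
    return block_length

print(choose_block_length(1024, 1, 128))  # 128 already divides 1024
print(choose_block_length(200, 1, 128))   # falls back to 100, the largest divisor of 200 <= 128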