Code example #1
def call(self, context, x, losses=None):
     """Call the layer."""
     m = self._get_memory_antecedent(context)
     memory_input_dim = m.shape[-1]
     if memory_input_dim != context.model.model_dim:
         raise NotImplementedError(
             "TODO(noam): support different model_dim in encoder and decoder."
         )
     q = self.compute_q(context, x)
     if context.mode == "incremental":
         k, v, memory_length = context.get_constant_state()
     else:
         k = self.compute_k(context, m)
         v = self.compute_v(context, m)
         memory_length, = [
             d for d in m.shape.dims if d.name == "memory_length"
         ]
         if context.mode == "first_part":
             context.record_constant_state((k, v, memory_length))
     if context.encoder_sequence_id and context.sequence_id:
         visible = mtf.equal(context.sequence_id,
                             context.encoder_sequence_id)
         bias = attention.visibility_mask_to_attention_bias(
             visible, context.activation_dtype)
     else:
         bias = None
     return self.attention_internal(context, x, m, q, k, v, memory_length,
                                    bias)
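
The caching pattern in example #1 is easy to miss in the Mesh TensorFlow plumbing: the encoder keys and values do not depend on the decoder position, so they are computed once in "first_part" mode, recorded as constant state, and read back unchanged on every "incremental" step. Below is a minimal plain-Python/NumPy sketch of that pattern; ToyContext, encoder_decoder_kv, w_k and w_v are illustrative names and not part of the library.

import numpy as np

class ToyContext(object):
    """Minimal stand-in for the decoding context used above (illustrative only)."""

    def __init__(self, mode):
        self.mode = mode              # "first_part" or "incremental"
        self._constant_state = None   # written once, read on every later step

    def record_constant_state(self, state):
        self._constant_state = state

    def get_constant_state(self):
        return self._constant_state

def encoder_decoder_kv(context, encoder_output, w_k, w_v):
    """Compute encoder keys/values once, then reuse them while decoding."""
    if context.mode == "incremental":
        # The encoder output never changes during decoding, so the cached
        # projections can be reused step after step.
        return context.get_constant_state()
    k = encoder_output @ w_k
    v = encoder_output @ w_v
    if context.mode == "first_part":
        context.record_constant_state((k, v))
    return k, v

ctx = ToyContext("first_part")
enc = np.ones((6, 8))                               # [memory_length, model_dim]
k, v = encoder_decoder_kv(ctx, enc, np.eye(8), np.eye(8))
ctx.mode = "incremental"
k2, v2 = encoder_decoder_kv(ctx, enc, None, None)   # served from the cache
assert (k2 == k).all()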
Code example #2
def enc_dec_attention(self_attention_layer, memory_antecedent, context, x,
                      losses):
  """Multi-head attention over the encoder outputs."""
  memory_input_dim = memory_antecedent.shape[-1]
  if memory_input_dim != context.model.model_dim:
    raise NotImplementedError(
        "TODO(noam): support different model_dim in encoder and decoder.")
  params = self_attention_layer.make_params(context)
  q = params.compute_q(x)
  if context.mode == "incremental":
    k, v, memory_length = context.get_constant_state()
  else:
    m = memory_antecedent
    if self_attention_layer.shared_kv:
      kv = params.compute_kv(m)
      k = kv
      v = kv
    else:
      k = params.compute_k(m)
      v = params.compute_v(m)
    memory_length, = [d for d in m.shape.dims if d.name == "memory_length"]
    if context.mode == "first_part":
      context.record_constant_state((k, v, memory_length))
  if context.encoder_sequence_id and context.sequence_id:
    visible = mtf.equal(context.sequence_id, context.encoder_sequence_id)
    bias = attention.visibility_mask_to_attention_bias(visible,
                                                       context.activation_dtype)
  else:
    bias = None
  a = attention.attention(
      q, k, v, memory_length, self_attention_layer.kv_dim,
      self_attention_layer.kv_dim, bias,
      **self_attention_layer.attention_kwargs_from_context(context))
  attention_output_shape = self_attention_layer.expected_attention_output_shape(
      x, params)
  attention_output = params.compute_output(
      a, output_shape=attention_output_shape)
  return self_attention_layer.layer_output_from_attention_output(
      context, attention_output, losses)
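
Example #2 adds the shared_kv option: a single projection of the memory antecedent serves as both keys and values, which halves the projection parameters and the amount of constant state that has to be cached for incremental decoding. A small sketch of the idea follows; project_kv and the weight names are made up for illustration and are not the library API.

import numpy as np

def project_kv(memory, weights, shared_kv=False):
    """Project the memory antecedent into (k, v).

    With shared_kv=True one matrix is applied and the result is used for both
    keys and values; otherwise separate k and v projections are computed.
    """
    if shared_kv:
        kv = memory @ weights["w_kv"]
        return kv, kv
    return memory @ weights["w_k"], memory @ weights["w_v"]

memory = np.random.randn(6, 8)
weights = {"w_kv": np.random.randn(8, 4),
           "w_k": np.random.randn(8, 4),
           "w_v": np.random.randn(8, 4)}
k, v = project_kv(memory, weights, shared_kv=True)
assert k is v            # one tensor to cache instead of two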
Code example #3
def call(self, context, x, losses=None):
     """Call the layer."""
     memory_antecedent = self._get_memory_antecedent(context)
     memory_input_dim = memory_antecedent.shape[-1]
     if memory_input_dim != context.model.model_dim:
         raise NotImplementedError(
             "TODO(noam): support different model_dim in encoder and decoder."
         )
     params = self.make_params(context)
     q = params.compute_q(x)
     if context.mode == "incremental":
         k, v, memory_length = context.get_constant_state()
     else:
         m = memory_antecedent
         if self.shared_kv:
             kv = params.compute_kv(m)
             k = kv
             v = kv
         else:
             k = params.compute_k(m)
             v = params.compute_v(m)
         memory_length, = [
             d for d in m.shape.dims if d.name == "memory_length"
         ]
         if context.mode == "first_part":
             context.record_constant_state((k, v, memory_length))
     if context.encoder_sequence_id and context.sequence_id:
         visible = mtf.equal(context.sequence_id,
                             context.encoder_sequence_id)
         bias = attention.visibility_mask_to_attention_bias(
             visible, context.activation_dtype)
     else:
         bias = None
     o = attention.attention(q, k, v, memory_length, self.kv_dim,
                             self.kv_dim, bias,
                             **self.attention_kwargs_from_context(context))
     return params.compute_output(o, output_shape=x.shape)
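
All three encoder-decoder variants build the same bias when sequences are packed: a decoder position may only attend to encoder positions that belong to the same example, i.e. where context.sequence_id equals context.encoder_sequence_id. A NumPy sketch of that mask-to-bias step follows; the -1e9 constant merely stands in for whatever value visibility_mask_to_attention_bias actually produces.

import numpy as np

def enc_dec_attention_bias(decoder_seq_id, encoder_seq_id, neg=-1e9):
    """Additive bias that keeps attention inside each packed example.

    decoder_seq_id: [decoder_length] int array of example ids
    encoder_seq_id: [encoder_length] int array of example ids
    Returns a [decoder_length, encoder_length] bias: 0 where the ids match,
    a large negative value (effectively -inf after softmax) where they do not.
    """
    visible = decoder_seq_id[:, None] == encoder_seq_id[None, :]
    return np.where(visible, 0.0, neg)

bias = enc_dec_attention_bias(np.array([1, 1, 2, 2]),
                              np.array([1, 1, 1, 2, 2]))
print(bias.shape)        # (4, 5)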
Code example #4
def call(self, context, x, losses=None):
     """Call the layer."""
     params = self.make_params(context)
     q = params.compute_q(x)
     if self.shared_kv:
         kv = params.compute_kv(x)
         k = kv
         v = kv
     else:
         k = params.compute_k(x)
         v = params.compute_v(x)
     if context.mode == "incremental":
         if self.shared_kv:
             prev_kv, = context.get_states(1)
         else:
             prev_k, prev_v = context.get_states(2)
         current_position = mtf.equal(
             mtf.range(context.mesh, self.window_dim, dtype=tf.int32),
             mtf.mod(context.position, self.radius))
         if self.shared_kv:
             kv = mtf.where(current_position,
                            kv,
                            prev_kv,
                            output_shape=prev_kv.shape)
             k = kv
             v = kv
             context.record_new_states([kv])
         else:
             k = mtf.where(current_position,
                           params.compute_k(x),
                           prev_k,
                           output_shape=prev_k.shape)
             v = mtf.where(current_position,
                           params.compute_v(x),
                           prev_v,
                           output_shape=prev_v.shape)
             context.record_new_states([k, v])
         window_pos = mtf.range(context.mesh, self.window_dim, tf.int32)
         visible = mtf.greater_equal(context.position, window_pos)
         bias = attention.visibility_mask_to_attention_bias(
             visible, context.activation_dtype)
         o = attention.attention(
             q, k, v, self.window_dim, self.kv_dim, self.kv_dim, bias,
             **self.attention_kwargs_from_context(context))
     elif context.length_dim.size <= max(256, self.radius * 4):
         # nothing fancy - just do full attention and mask
         memory_length = self.rename_length_to_memory_length(
             context.position, context)
         o = attention.attention(
             q, self.rename_length_to_memory_length(k, context),
             self.rename_length_to_memory_length(v, context),
             self.memory_length(context), self.kv_dim, self.kv_dim,
             self.compute_bias(context, memory_length, x),
             **self.attention_kwargs_from_context(context))
     else:
         # fancy local attention algorithm
         o = attention.local_attention_1d(
             q=q,
             k=k,
             v=None if self.shared_kv else v,
             length_dim=context.length_dim,
             key_dim=self.kv_dim,
             value_dim=self.kv_dim,
             length_dim_num_splits=1,  # TODO(noam): look at the layout
             autoregressive=context.model.fully_autoregressive,
             radius=self.radius,
             sequence_id=context.sequence_id,
             write_priority=context.write_priority,
             read_priority=context.read_priority,
             attention_kwargs=self.attention_kwargs_from_context(context))
     if context.mode == "first_part":
         window_pos = mtf.range(context.mesh, self.window_dim, tf.int32)
         pos = mtf.range(context.mesh, context.length_dim, tf.int32)
         select_recent = mtf.cast(
             mtf.equal(mtf.mod(pos, self.radius), window_pos), x.dtype)
         select_recent *= mtf.cast(mtf.less(pos, context.initial_position),
                                   x.dtype)
         select_recent *= mtf.cast(
             mtf.greater_equal(pos, context.initial_position - self.radius),
             x.dtype)
         state_shape = (k.shape - [context.length_dim, self.kv_dim] +
                        [self.window_dim, self.kv_dim])
         k_state = mtf.einsum([k, select_recent],
                              output_shape=state_shape,
                              reduced_dims=[context.length_dim])
         context.new_states.append(k_state)
         if not self.shared_kv:
             v_state = mtf.einsum([v, select_recent],
                                  output_shape=state_shape,
                                  reduced_dims=[context.length_dim])
             context.new_states.append(v_state)
     return params.compute_output(o, output_shape=x.shape)
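
In example #4 the incremental branch keeps only the most recent `radius` keys and values in a circular buffer: the slot indexed by position mod radius is overwritten at each step, which is what the mtf.equal/mtf.mod/mtf.where combination expresses. A NumPy sketch of that buffer update follows; update_window_state is an illustrative name, not the library API.

import numpy as np

def update_window_state(prev_k, prev_v, new_k, new_v, position, radius):
    """Overwrite one slot of a circular K/V buffer holding the last `radius` steps.

    prev_k, prev_v: [radius, d_kv] cached keys/values from earlier positions
    new_k, new_v:   [d_kv] projections of the token at `position`
    """
    slot = position % radius          # mtf.mod(context.position, self.radius)
    k = prev_k.copy()
    v = prev_v.copy()
    k[slot] = new_k                   # mtf.where(current_position, new, prev)
    v[slot] = new_v
    return k, v

radius, d_kv = 4, 3
k = np.zeros((radius, d_kv))
v = np.zeros((radius, d_kv))
for position in range(6):
    k, v = update_window_state(k, v, np.full(d_kv, position),
                               np.full(d_kv, position), position, radius)
# positions 4 and 5 have wrapped around and replaced positions 0 and 1
assert k[0, 0] == 4 and k[1, 0] == 5 and k[2, 0] == 2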
Code example #5
def compute_bias(self, context, memory_position, x):
        """Compute attention bias.

        Args:
          context: a transformer.Context
          memory_position: an int32 Tensor containing the memory_length dimension.
          x: a Tensor - the query antecedent - required for relative attention

        Returns:
          a Tensor or None
        """
        min_relative_position = self.min_relative_position(context)
        max_relative_position = self.max_relative_position(context)
        # we can often cache the result of this function between similar layers
        can_cache = (self.relative_attention_type is None
                     or self.relative_attention_type == "bias_shared")
        if can_cache:
            cache_key = ("self_attention_mask", min_relative_position,
                         max_relative_position, self.relative_attention_type,
                         self.num_heads)
            if cache_key in context.cache:
                return context.cache[cache_key]
        biases = []
        relative_position = memory_position - context.position
        if min_relative_position is not None:
            visible = mtf.greater_equal(relative_position,
                                        min_relative_position)
            biases.append(
                attention.visibility_mask_to_attention_bias(
                    visible, context.activation_dtype))
        if max_relative_position is not None:
            visible = mtf.less_equal(relative_position, max_relative_position)
            biases.append(
                attention.visibility_mask_to_attention_bias(
                    visible, context.activation_dtype))
        if context.read_priority is not None:
            visible = mtf.greater_equal(
                context.read_priority,
                mtf.layers.rename_length_to_memory_length(
                    context.write_priority))
            biases.append(
                attention.visibility_mask_to_attention_bias(
                    visible, context.activation_dtype))

        sequence_id = None
        # Subsequence id should only be set if we are in the decoder and have
        # multiple targets per input. This will allow each sub-target to only attend
        # to itself.
        if isinstance(context.subsequence_id, mtf.Tensor):
            sequence_id = context.subsequence_id
        elif isinstance(context.sequence_id, mtf.Tensor):
            sequence_id = context.sequence_id
        if (sequence_id is not None
                and context.length_dim in sequence_id.shape):
            visible = mtf.equal(
                sequence_id,
                self.rename_length_to_memory_length(sequence_id, context))
            biases.append(
                attention.visibility_mask_to_attention_bias(
                    visible, context.activation_dtype))
        if self.relative_attention_type is not None:
            buckets_dim = mtf.Dimension("buckets",
                                        self.relative_attention_num_buckets)
            heads_dim = mtf.Dimension("heads", self.num_heads)
            bidirectional = not context.model.fully_autoregressive
            rp_bucket = _relative_position_bucket(relative_position,
                                                  bidirectional=bidirectional,
                                                  num_buckets=buckets_dim.size)
            if (self.relative_attention_type == "bias"
                    or self.relative_attention_type == "bias_shared"):
                values = mtf.get_variable(context.mesh,
                                          "relative_attention_bias",
                                          [heads_dim, buckets_dim],
                                          dtype=context.variable_dtype)
            elif self.relative_attention_type == "contextual":
                values = layers.dense(x, [buckets_dim, heads_dim],
                                      variable_dtype=context.variable_dtype,
                                      name="relative_attention_contextual")
            else:
                raise ValueError(
                    "unrecognized relative_attention_type \"%s\"" %
                    self.relative_attention_type)
            biases.append(mtf.gather(values, rp_bucket, buckets_dim))
        ret = mtf.add_n(biases) if biases else None
        if can_cache:
            context.cache[cache_key] = ret
        return ret
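
compute_bias in example #5 treats every constraint the same way: each one (relative-position limits, read/write priority, sequence or subsequence id, and the learned relative-attention term) is expressed as an additive bias, and the pieces are summed with mtf.add_n. A NumPy sketch of that composition follows; compose_bias and the -1e9 constant are illustrative, and the relative-attention bucketing itself is omitted.

import numpy as np

NEG = -1e9   # stand-in for the value visibility_mask_to_attention_bias produces

def mask_to_bias(visible):
    """Boolean visibility mask -> additive attention bias."""
    return np.where(visible, 0.0, NEG)

def compose_bias(query_pos, memory_pos, max_relative_position=None,
                 sequence_id=None):
    """Sum independent additive biases, mirroring the structure of compute_bias."""
    relative_position = memory_pos[None, :] - query_pos[:, None]
    biases = []
    if max_relative_position is not None:
        biases.append(mask_to_bias(relative_position <= max_relative_position))
    if sequence_id is not None:
        # block attention across packed examples
        biases.append(mask_to_bias(
            sequence_id[:, None] == sequence_id[None, :]))
    return sum(biases) if biases else None

pos = np.arange(6)
bias = compose_bias(pos, pos, max_relative_position=0,
                    sequence_id=np.array([1, 1, 1, 2, 2, 2]))
# each position can now only see earlier (or equal) positions in its own example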
Code example #6
def local_attention_1d(q,
                       k,
                       v,
                       length_dim,
                       key_dim,
                       value_dim,
                       fully_autoregressive=True,
                       length_dim_num_splits=1,
                       radius=128,
                       sequence_id=1,
                       write_priority=None,
                       read_priority=None,
                       attention_kwargs=None,
                       context=None):
    """Attention to the a neighborood around the source.

  If fully_autoregressive, then query position p can only see memory positions
  in the range (p - radius, p].

  If not fully_autoregressive, then query position p can only see memory
  positions in the range (p - window_size, p + radius].

  In addition, if write_priority and read_priority are provided, then attention
  is limited to position pairs where
  read_priority[query position] >= write_priority[memory position]

  Args:
    q: a Tensor containing length_dim
    k: a Tensor containing length_dim
    v: an optional Tensor containing length_dim.  If none then uses v=k.
    length_dim: a Dimension
    key_dim: a Dimension (the channels dimension of q and k)
    value_dim: a Dimension (the channels dimension of v)
    fully_autoregressive: a boolean
    length_dim_num_splits: an optional integer indicating how many ways the
      length dimension is split
    radius: an integer
    sequence_id: a Tensor or an integer
    write_priority: an optional Tensor containing length_dim
    read_priority: an optional Tensor containing length_dim
    attention_kwargs: optional keyword arguments for attention()
    context: optional context.

  Returns:
    a Tensor with the shape x.shape - key_dim + value_dim

  Raises:
    ValueError: if channels or depth don't match.
  """
    # Choose a suitable block size.
    # We choose the greatest divisor of length_per_split less than or equal
    # to max(radius, 128).
    tf.logging.info(attention_kwargs)
    length_per_split = length_dim.size // length_dim_num_splits
    block_length = max(radius, 128)
    while length_per_split % block_length != 0:
        block_length -= 1
    query_block_length = mtf.Dimension("query_block_length", block_length)
    memory_block_length = mtf.Dimension("memory_block_length", block_length)
    # The num_blocks dimension gets the same name as the length dimension,
    # so it will be split in the same way.
    num_blocks = mtf.Dimension(length_dim.name,
                               length_dim.size // block_length)

    def _reshape_query(x):
        return mtf.replace_dimensions(x, length_dim,
                                      [num_blocks, query_block_length])

    def _reshape_memory(x):
        x = mtf.replace_dimensions(x, length_dim,
                                   [num_blocks, memory_block_length])
        return (mtf.left_halo_exchange if fully_autoregressive else
                mtf.halo_exchange)(x, num_blocks, memory_block_length, radius)

    q = _reshape_query(q)
    k = _reshape_memory(k)
    if v:
        v = _reshape_memory(v)
    else:
        v = k
    if sequence_id is None:
        sequence_id = 1
    if (not isinstance(sequence_id, mtf.Tensor)
            or length_dim not in sequence_id.shape.dims):
        sequence_id += mtf.zeros(q.mesh, [length_dim], tf.int32)
    q_sequence_id = _reshape_query(sequence_id)
    m_sequence_id = _reshape_memory(sequence_id)
    pos = mtf.range(q.mesh, length_dim, dtype=tf.int32)
    q_pos = _reshape_query(pos)
    m_pos = _reshape_memory(pos)

    padded_memory_block_length = mtf.Dimension(
        "memory_block_length",
        (1 if fully_autoregressive else 2) * radius + block_length)

    relative_position = m_pos - q_pos
    visible = mtf.equal(q_sequence_id, m_sequence_id)
    visible = mtf.logical_and(visible, mtf.greater(relative_position, -radius))
    visible = mtf.logical_and(
        visible,
        mtf.less_equal(relative_position,
                       0 if fully_autoregressive else radius))
    if read_priority is not None:
        write_priority = _reshape_memory(write_priority)
        read_priority = _reshape_query(read_priority)
        visible = mtf.logical_and(
            visible, mtf.greater_equal(read_priority, write_priority))

    bias = attention.visibility_mask_to_attention_bias(visible, q.dtype)
    o = attention.attention(q,
                            k,
                            v,
                            padded_memory_block_length,
                            key_dim,
                            value_dim,
                            bias,
                            context=context,
                            **attention_kwargs)
    return mtf.replace_dimensions(o, [num_blocks, query_block_length],
                                  length_dim)
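
Two small pieces of example #6 are worth isolating. The block length is chosen as the greatest value not larger than max(radius, 128) that evenly divides the per-split length, and a memory position m is visible from a query position q only when -radius < m - q <= 0 (fully autoregressive) or <= radius (otherwise). A plain-Python sketch of both rules follows; choose_block_length and locally_visible are illustrative names.

def choose_block_length(length_per_split, radius):
    """Greatest value <= max(radius, 128) that evenly divides length_per_split."""
    block_length = max(radius, 128)
    while length_per_split % block_length != 0:
        block_length -= 1
    return block_length

def locally_visible(query_pos, memory_pos, radius, fully_autoregressive=True):
    """Visibility rule enforced by the bias built in local_attention_1d above."""
    relative_position = memory_pos - query_pos
    upper = 0 if fully_autoregressive else radius
    return -radius < relative_position <= upper

assert choose_block_length(1024, 100) == 128   # 1024 is divisible by 128
assert choose_block_length(480, 200) == 160    # largest divisor of 480 not above 200
assert locally_visible(10, 8, radius=4)        # inside the causal window
assert not locally_visible(10, 12, radius=4)   # future position is masked
assert locally_visible(10, 12, radius=4, fully_autoregressive=False)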