コード例 #1
0
    def cond_fn(position, ids, *unused_states):
        """Loop predicate: keep iterating while some entry is not yet done.

        Reads `length_dim`, `max_steps`, `initial_position`, `stop_at_token`
        and `partial_sequences_eos_count` from the enclosing scope.

        Args:
          position: the position tensor; compared against length_dim.size and,
            when max_steps is set, against initial_position.
          ids: token ids so far; must contain length_dim (it is reduced over).
          *unused_states: remaining loop-carried values, ignored here.

        Returns:
          logical_not(reduce_all(done)): True while at least one entry still
          needs another iteration.
        """
        finished = mtf.greater_equal(position, length_dim.size)
        # Optional cap on decode steps taken since initial_position.  Any
        # falsy max_steps (None or 0) disables the cap.
        if max_steps:
            steps_taken = position - initial_position
            finished = mtf.logical_or(
                finished, mtf.greater_equal(steps_taken, max_steps))

        if stop_at_token is not None:
            # Only stop tokens beyond partial_sequences_eos_count count as
            # "done".  Presumably that is the number already present before
            # decoding started (e.g. in a prompt) -- confirm against caller.
            stop_mask = mtf.to_int32(mtf.equal(ids, stop_at_token))
            num_stops = mtf.reduce_sum(stop_mask, reduced_dim=length_dim)
            finished = mtf.logical_or(
                finished, mtf.greater(num_stops, partial_sequences_eos_count))

        return mtf.logical_not(mtf.reduce_all(finished))
コード例 #2
0
 def cond_fn(position, ids, *unused_states):
   """Loop predicate: True while at least one entry is not yet done.

   An entry counts as done once `position` reaches length_dim.size, or, when
   `stop_at_token` is set, once `ids` contains that token anywhere along
   length_dim.  `length_dim` and `stop_at_token` come from the enclosing
   scope; the extra loop states are ignored.
   """
   finished = mtf.greater_equal(position, length_dim.size)
   if stop_at_token is not None:
     # NOTE(review): any occurrence counts, including stop tokens that were
     # already in `ids` before decoding began (presumably prompt tokens).
     saw_stop = mtf.reduce_any(
         mtf.equal(ids, stop_at_token), reduced_dim=length_dim)
     finished = mtf.logical_or(finished, saw_stop)
   return mtf.logical_not(mtf.reduce_all(finished))
 def call_simple(self,
                 inputs,
                 targets,
                 compute_loss,
                 attributes=None,
                 mode=tf.estimator.ModeKeys.TRAIN,
                 variable_dtype=mtf.VariableDType(tf.float32),
                 sequence_id=None,
                 subsequence_id=None,
                 position=None,
                 encoder_output=None,
                 encoder_sequence_id=None,
                 encoder_inputs=None,
                 shared_params=None,
                 layer_outputs=None,
                 encoder_layer_outputs=None,
                 z=None):
     """Compute logits for every position of `inputs` in parallel.

     Used during training and evaluation.

     Args:
       inputs: an int32 Tensor with shape [<batch_dims>, length_dim].  When
         training an autoregressive model this should equal
         mtf.shift(targets, offset=1, dim=length_dim, wrap=False).
       targets: an optional int32 Tensor with shape [<batch_dims>, length_dim].
         Needed when self.input_full_attention is set, because the
         full-attention region is derived from it.
       compute_loss: a boolean; when True the Context collects losses and
         their sum is returned.
       attributes: an optional int32 Tensor (shape [<batch_dims>, length_dim]
         or [<batch_dims>]), passed through to self._call_internal.
       mode: a tf.estimator.ModeKeys value.
       variable_dtype: a mtf.VariableDType.
       sequence_id: an optional Tensor.
       subsequence_id: an optional Tensor.
       position: an optional Tensor.  Discarded (replaced by the index along
         length_dim) when self.positional_embedding is falsy.
       encoder_output: an optional Tensor.
       encoder_sequence_id: an optional Tensor.
       encoder_inputs: an optional Tensor.
       shared_params: an optional dictionary.
       layer_outputs: an optional list to append Tensor layer activations to.
       encoder_layer_outputs: optional read-only list of tensor activations
         when decoding, one per input layer plus the embedding layer.
       z: optional value forwarded unchanged to self._call_internal as `z`;
         its meaning is defined there.

     Returns:
       logits: the Tensor produced by self._call_internal (presumably
         [<batch_dims>, length_dim, output_vocab_dim] -- TODO confirm).
       loss: mtf.add_n over the Context losses when compute_loss, else None.
     """
     dims = inputs.shape.dims
     batch_dims, length_dim = dims[:-1], dims[-1]
     length_range = mtf.range(inputs.mesh, length_dim, dtype=tf.int32)

     # Without positional embeddings the caller's positions are dropped to
     # make relative attention faster.  Relative attention then treats the
     # index in the tensor as the position, which still gives the correct
     # relative positions.
     if not self.positional_embedding:
         position = None
     position_is_default = position is None
     if position_is_default:
         position = length_range

     # One shared value is used for both read_priority and write_priority.
     if self.input_full_attention:
         # The inputs part of each sequence fully attends within itself.  The
         # region is widened by one position to the right: that is where the
         # final EOS of the inputs is read and the first target is predicted.
         region = delimited_lm_inputs_mask(targets)
         region = mtf.logical_or(
             region,
             mtf.shift(region, offset=1, dim=length_dim, wrap=False))
         # Priority is 0 inside the full-attention region, position elsewhere.
         priority = length_range * mtf.cast(mtf.logical_not(region), tf.int32)
     elif self.autoregressive:
         # Plain autoregressive model: each position sees previous positions.
         priority = length_range
     else:
         priority = None
     read_priority = write_priority = priority

     context = Context(
         model=self,
         mesh=inputs.mesh,
         batch_dims=batch_dims,
         length_dim=length_dim,
         variable_dtype=variable_dtype,
         mode=mode,
         losses=[] if compute_loss else None,
         sequence_id=sequence_id,
         subsequence_id=subsequence_id,
         position=position,
         position_is_default=position_is_default,
         encoder_output=encoder_output,
         encoder_sequence_id=encoder_sequence_id,
         shared_params=shared_params,
         layer_outputs=layer_outputs,
         encoder_layer_outputs=encoder_layer_outputs,
         write_priority=write_priority,
         read_priority=read_priority,
         inputs=inputs,
         encoder_inputs=encoder_inputs)
     with tf.variable_scope(self.name):
         logits = self._call_internal(
             context, inputs, targets, attributes, z=z)
     loss = mtf.add_n(context.losses) if compute_loss else None
     return logits, loss
コード例 #4
0
ファイル: attention.py プロジェクト: qixiuai/mesh
def local_attention_1d(q,
                       k,
                       v,
                       length_dim,
                       key_dim,
                       value_dim,
                       autoregressive=True,
                       length_dim_num_splits=1,
                       radius=128,
                       sequence_id=1,
                       attention_kwargs=None):
  """Attention to a neighborhood around each query position.

  If autoregressive, then query position p can only see memory positions
  in the range (p - radius, p].

  If not autoregressive, then query position p can only see memory positions
  in the range (p - radius, p + radius].

  In both cases, memory positions whose sequence_id differs from the query's
  sequence_id are masked out.

  Args:
    q: a Tensor containing length_dim
    k: a Tensor containing length_dim
    v: an optional Tensor containing length_dim.  If None then uses v=k.
    length_dim: a Dimension
    key_dim: a Dimension (the channels dimension of q and k)
    value_dim: a Dimension (the channels dimension of v)
    autoregressive: a boolean selecting the one-sided window (see above)
    length_dim_num_splits: an optional integer indicating how many ways the
      length dimension is split
    radius: an integer
    sequence_id: a Tensor or an integer; None is treated as 1
    attention_kwargs: an optional dict of keyword arguments for attention();
      None means no extra arguments

  Returns:
    a Tensor with the shape q.shape - key_dim + value_dim

  Raises:
    ValueError: if channels or depth don't match (not raised directly here;
      presumably from attention() or the mtf ops it uses).
  """
  # The default is None, and `**None` raises TypeError, so normalize it
  # before it is unpacked into attention() below.
  if attention_kwargs is None:
    attention_kwargs = {}
  # Choose a suitable block size.
  # We choose the greatest divisor of length_per_split less than or equal
  # to max(radius, 128).
  length_per_split = length_dim.size // length_dim_num_splits
  block_length = max(radius, 128)
  while length_per_split % block_length != 0:
    block_length -= 1
  query_block_length = mtf.Dimension("query_block_length", block_length)
  memory_block_length = mtf.Dimension("memory_block_length", block_length)
  # The num_blocks dimension gets the same name as the length dimension,
  # so it will be split in the same way.
  num_blocks = mtf.Dimension(length_dim.name, length_dim.size // block_length)

  def _reshape_query(x):
    # Replace length_dim with [num_blocks, query_block_length].
    return mtf.replace_dimensions(
        x, length_dim, [num_blocks, query_block_length])

  def _reshape_memory(x):
    # Block like the queries, then widen each block by `radius` positions via
    # halo exchange: one-sided (left) when autoregressive, two-sided
    # otherwise, matching padded_memory_block_length below.
    x = mtf.replace_dimensions(
        x, length_dim, [num_blocks, memory_block_length])
    return (mtf.left_halo_exchange if autoregressive else mtf.halo_exchange)(
        x, num_blocks, memory_block_length, radius)

  q = _reshape_query(q)
  k = _reshape_memory(k)
  # v is optional; without it the (already blocked) keys double as values.
  if v is not None:
    v = _reshape_memory(v)
  else:
    v = k
  if sequence_id is None:
    sequence_id = 1
  # Make sure sequence_id carries length_dim so it can be blocked exactly
  # like q and k.
  if (not isinstance(sequence_id, mtf.Tensor) or
      length_dim not in sequence_id.shape.dims):
    sequence_id += mtf.zeros(q.mesh, [length_dim], tf.int32)
  q_sequence_id = _reshape_query(sequence_id)
  m_sequence_id = _reshape_memory(sequence_id)
  pos = mtf.range(q.mesh, length_dim, dtype=tf.int32)
  q_pos = _reshape_query(pos)
  m_pos = _reshape_memory(pos)

  # Memory block length after the halo exchange: block_length plus radius
  # on one side (autoregressive) or on both sides.
  padded_memory_block_length = mtf.Dimension(
      "memory_block_length",
      (1 if autoregressive else 2) * radius + block_length)

  # Mask out other sequences and anything outside the window described in
  # the docstring.
  relative_position = m_pos - q_pos
  illegal = mtf.not_equal(q_sequence_id, m_sequence_id)
  illegal = mtf.logical_or(illegal, mtf.less_equal(relative_position, -radius))
  illegal = mtf.logical_or(illegal, mtf.greater(
      relative_position, 0 if autoregressive else radius))
  mask = mtf.cast(illegal, q.dtype) * -1e9
  o = attention(q, k, v, padded_memory_block_length,
                key_dim, value_dim, mask, **attention_kwargs)
  return mtf.replace_dimensions(o, [num_blocks, query_block_length], length_dim)