Example 1
def cond_fn(position, ids, *unused_states):
  """Should we run another loop iteration?"""
  past_end = mtf.greater_equal(position, length_dim.size)
  is_done = past_end
  if stop_at_token is not None:
    has_eos = mtf.reduce_any(
        mtf.equal(ids, stop_at_token), reduced_dim=length_dim)
    is_done = mtf.logical_or(is_done, has_eos)
  all_done = mtf.reduce_all(is_done)
  return mtf.logical_not(all_done)
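
The loop condition above is easiest to follow outside of Mesh TensorFlow. Below is a minimal NumPy sketch of the same logic (the function name and plain-array arguments are illustrative, not part of the mtf API): decoding continues as long as at least one sequence in the batch has neither run past the length dimension nor emitted the stop token.

import numpy as np

def should_continue(position, ids, length, stop_at_token=None):
  """NumPy sketch of cond_fn: loop while any sequence is still unfinished."""
  past_end = position >= length                 # scalar: decode step past the end?
  is_done = np.full(ids.shape[0], past_end)     # one "done" flag per sequence
  if stop_at_token is not None:
    # A sequence is also done once the stop token appears anywhere in it.
    has_eos = (ids == stop_at_token).any(axis=-1)
    is_done = is_done | has_eos
  # Keep looping unless every sequence in the batch is done.
  return not is_done.all()

# Batch of two sequences, length 6, EOS id 1: one already finished, one not.
ids = np.array([[5, 3, 1, 0, 0, 0],
                [5, 3, 4, 2, 0, 0]])
print(should_continue(position=4, ids=ids, length=6, stop_at_token=1))  # True
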
Example 2
def visibility_mask_to_attention_bias(visible, dtype):
  """Convert a boolean visibility mask to an attention bias.

  The returned Tensor has large negative values in positions where
  visible=False.

  Args:
    visible: a boolean Tensor
    dtype: a dtype
  Returns:
    a Tensor with the given dtype and the same shape as "visible"
  """
  return mtf.cast(mtf.logical_not(visible), dtype) * -1e9
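
To see why multiplying by -1e9 works, here is a minimal NumPy sketch of the same conversion (plain arrays instead of mtf Tensors, so not the mtf API): positions where visible is False receive a large negative additive bias, and after the softmax their attention weight is numerically zero.

import numpy as np

def visibility_mask_to_attention_bias_np(visible, dtype=np.float32):
  """NumPy analogue: 0 where visible, a large negative value where not."""
  return np.logical_not(visible).astype(dtype) * -1e9

visible = np.array([True, True, False])
logits = np.array([2.0, 1.0, 3.0], dtype=np.float32)
biased = logits + visibility_mask_to_attention_bias_np(visible)
weights = np.exp(biased) / np.exp(biased).sum()
print(weights.round(3))  # e.g. [0.731 0.269 0.   ] -- the masked position gets no weight
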
Example 3
    def cond_fn(position, ids, *unused_states):
        """Should we run another loop iteration?"""
        past_end = mtf.greater_equal(position, length_dim.size)
        if max_steps:
            past_end = mtf.logical_or(
                past_end, mtf.greater_equal(position - initial_position, max_steps))

        is_done = past_end
        if stop_at_token is not None:
            eos_count = mtf.reduce_sum(
                mtf.to_int32(mtf.equal(ids, stop_at_token)),
                reduced_dim=length_dim)
            has_additional_eos = mtf.greater(eos_count, partial_sequences_eos_count)
            is_done = mtf.logical_or(is_done, has_additional_eos)
        all_done = mtf.reduce_all(is_done)
        return mtf.logical_not(all_done)
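
This variant differs from the first cond_fn in two ways: an optional cap of max_steps decode steps past the initial position, and an EOS count that ignores EOS tokens already present in the partial (prompt) sequences. A minimal NumPy sketch of that logic follows, with scalar position/initial_position and illustrative names rather than the mtf API.

import numpy as np

def should_continue(position, ids, length, initial_position,
                    partial_eos_count, stop_at_token=1, max_steps=None):
  """NumPy sketch of the extended cond_fn (max_steps + partial prompts)."""
  past_end = position >= length
  if max_steps:
    # Also stop once max_steps tokens have been decoded beyond the prompt.
    past_end = past_end or (position - initial_position >= max_steps)
  is_done = np.full(ids.shape[0], past_end)
  if stop_at_token is not None:
    # Count EOS per sequence; a sequence is done once it holds one more EOS
    # than its prompt already contained.
    eos_count = (ids == stop_at_token).sum(axis=-1)
    is_done = is_done | (eos_count > partial_eos_count)
  return not is_done.all()

# Both prompts already ended with EOS (count 1 each). The first sequence has
# since emitted a new EOS; the second has not, so decoding continues.
ids = np.array([[7, 1, 9, 1, 0, 0],
                [6, 1, 8, 5, 0, 0]])
print(should_continue(position=4, ids=ids, length=6, initial_position=2,
                      partial_eos_count=np.array([1, 1]), max_steps=8))  # True
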
Example 4
 def call_simple(self,
                 inputs,
                 targets,
                 compute_loss,
                 attributes=None,
                 mode=tf.estimator.ModeKeys.TRAIN,
                 variable_dtype=mtf.VariableDType(tf.float32),
                 sequence_id=None,
                 subsequence_id=None,
                 position=None,
                 encoder_output=None,
                 encoder_sequence_id=None,
                 encoder_inputs=None,
                 shared_params=None,
                 layer_outputs=None,
                 encoder_layer_outputs=None,
                 z=None):
     """Compute logits based on inputs (all positions in parallel).

      This is called during training and evaluation.

      Args:
        inputs: an int32 Tensor with shape [<batch_dims>, length_dim]. For
          training autoregressive models this should be equal to
          mtf.shift(targets, offset=1, dim=length_dim, wrap=False).
       targets: an optional int32 Tensor with shape [<batch_dims>, length_dim]
       compute_loss: a boolean
        attributes: an optional int32 Tensor with shape
          [<batch_dims>, length_dim] or [<batch_dims>]
       mode: a tf.estimator.ModeKeys
       variable_dtype: a mtf.VariableDType
       sequence_id: an optional Tensor
       subsequence_id: an optional Tensor
       position: an optional Tensor
       encoder_output: an optional Tensor
       encoder_sequence_id: an optional Tensor
       encoder_inputs: an optional Tensor
       shared_params: an optional dictionary
       layer_outputs: an optional list to append Tensor layer activations to
        encoder_layer_outputs: an optional read-only list of Tensor activations
          used when decoding, one per input layer plus the embedding layer
     Returns:
       logits: a Tensor with shape [<batch_dims>, output_vocab_dim]
       loss: an optional Scalar (if compute_loss=True)
     """
     batch_dims = inputs.shape.dims[:-1]
     length_dim = inputs.shape.dims[-1]
     length_range = mtf.range(inputs.mesh, length_dim, dtype=tf.int32)
     if not self.positional_embedding:
         # To make relative attention faster, we drop the information about the
         #   position in the subsequence.  The relative attention code then
         #   assumes that the positions are given by index in the tensor,
         #   which still leads to the correct computation of relative position.
         position = None
     if position is None:
         position_is_default = True
         position = length_range
     else:
         position_is_default = False
     if self.input_full_attention:
         # The inputs part of each sequence can fully attend within itself.
         full_attention_region = delimited_lm_inputs_mask(targets)
         # We can include one additional position to the right - the position
         #   where the final EOS of the inputs is read and the first target token
         #   is predicted.
         full_attention_region = mtf.logical_or(
             full_attention_region,
             mtf.shift(full_attention_region,
                       offset=1,
                       dim=length_dim,
                       wrap=False))
         # We set read_priority and write_priority to 0 in the full-attention
         #   region and equal to the position elsewhere.
         read_priority = write_priority = length_range * mtf.cast(
             mtf.logical_not(full_attention_region), tf.int32)
     elif self.autoregressive:
         # Vanilla autoregressive model - each position can see previous positions.
         read_priority = write_priority = length_range
     else:
         read_priority = write_priority = None
     context = Context(model=self,
                       mesh=inputs.mesh,
                       batch_dims=batch_dims,
                       length_dim=length_dim,
                       variable_dtype=variable_dtype,
                       mode=mode,
                       losses=[] if compute_loss else None,
                       sequence_id=sequence_id,
                       subsequence_id=subsequence_id,
                       position=position,
                       position_is_default=position_is_default,
                       encoder_output=encoder_output,
                       encoder_sequence_id=encoder_sequence_id,
                       shared_params=shared_params,
                       layer_outputs=layer_outputs,
                       encoder_layer_outputs=encoder_layer_outputs,
                       write_priority=write_priority,
                       read_priority=read_priority,
                       inputs=inputs,
                       encoder_inputs=encoder_inputs)
     with tf.variable_scope(self.name):
         logits = self._call_internal(context,
                                      inputs,
                                      targets,
                                      attributes,
                                      z=z)
     if compute_loss:
         loss = mtf.add_n(context.losses)
     else:
         loss = None
     return logits, loss
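
One detail of the docstring is worth unpacking: for autoregressive training, inputs is expected to equal mtf.shift(targets, offset=1, dim=length_dim, wrap=False), i.e. ordinary teacher forcing where each input position holds the previous target token and the first position is padded with zero. A minimal NumPy sketch of that relationship (not the mtf API):

import numpy as np

def shift_right(targets):
  """NumPy analogue of mtf.shift(targets, offset=1, dim=length_dim, wrap=False)."""
  # Shift by one along the length dimension, padding the first position with 0.
  return np.pad(targets, [(0, 0), (1, 0)])[:, :-1]

targets = np.array([[11, 12, 13, 1, 0]])  # 1 = EOS, 0 = padding
inputs = shift_right(targets)
print(inputs)  # [[ 0 11 12 13  1]]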