Code Example #1
File: transformer_layers.py  Project: minhtcai/mesh
  def compute_mask(self, context, memory_position):
    """Compute attention mask.

    Args:
      context: a transformer.Context
      memory_position: an int32 Tensor containing the memory_length dimension.
    Returns:
      a Tensor or None
    """
    masks = []
    min_relative_position = self.min_relative_position(context)
    max_relative_position = self.max_relative_position(context)
    if max_relative_position is not None or min_relative_position is not None:
      relative_position = memory_position - context.position
      if min_relative_position is not None:
        illegal = mtf.less(relative_position, min_relative_position)
        masks.append(mtf.cast(illegal, context.activation_dtype) * -1e9)
      if max_relative_position is not None:
        illegal = mtf.greater(relative_position, max_relative_position)
        masks.append(mtf.cast(illegal, context.activation_dtype) * -1e9)
    if (context.sequence_id is not None and
        isinstance(context.sequence_id, mtf.Tensor) and
        context.length_dim in context.sequence_id.shape):
      masks.append(mtf.cast(
          mtf.not_equal(
              context.sequence_id,
              self.rename_length_to_memory_length(
                  context.sequence_id, context)),
          context.activation_dtype) * -1e9)
    return mtf.add_n(masks) if masks else None
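
The mask built above is additive: disallowed positions receive a large negative value (-1e9) so they contribute essentially zero weight after the softmax. Below is a minimal NumPy sketch of the same idea for a relative-position window; the names are illustrative and this is not part of the mtf API.

import numpy as np

def window_mask(length, min_rel, max_rel):
  # relative_position[i, j] = j - i  (memory position minus query position)
  query = np.arange(length)[:, None]
  memory = np.arange(length)[None, :]
  relative_position = memory - query
  illegal = (relative_position < min_rel) | (relative_position > max_rel)
  return illegal.astype(np.float32) * -1e9

logits = np.zeros((4, 4))
weights = np.exp(logits + window_mask(4, min_rel=-1, max_rel=1))
weights /= weights.sum(axis=-1, keepdims=True)  # masked positions get ~0 weight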
Code Example #2
def clip_by_global_norm(grads, clip_norm):
    """Clip the grads by global norm."""
    global_norm = mtf.sqrt(
        mtf.add_n(
            [mtf.reduce_sum(mtf.square(t)) for t in grads if t is not None]))
    multiplier = clip_norm / mtf.maximum(global_norm, clip_norm)
    clipped_grads = [None if t is None else t * multiplier for t in grads]
    return clipped_grads, global_norm
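
For reference, here is the same scaling rule in plain NumPy; this is a concept sketch rather than the mtf version. The multiplier is 1.0 when the global norm is already below clip_norm and clip_norm / global_norm otherwise, and None gradients pass through untouched.

import numpy as np

def clip_by_global_norm_np(grads, clip_norm):
  """Clip a list of (possibly None) gradient arrays by their global norm."""
  global_norm = np.sqrt(
      sum(np.sum(np.square(g)) for g in grads if g is not None))
  multiplier = clip_norm / max(global_norm, clip_norm)
  return [None if g is None else g * multiplier for g in grads], global_norm

grads = [np.array([3.0, 4.0]), None, np.array([12.0])]
clipped, norm = clip_by_global_norm_np(grads, clip_norm=1.0)  # norm == 13.0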
Code Example #3
File: transformer.py  Project: minhtcai/mesh
    def call_simple(self,
                    inputs,
                    targets,
                    compute_loss,
                    mode=tf.estimator.ModeKeys.TRAIN,
                    variable_dtype=mtf.VariableDType(tf.float32),
                    sequence_id=None,
                    position=None,
                    encoder_output=None,
                    encoder_sequence_id=None,
                    shared_params=None):
        """Compute logits based on inputs (all positions in parallel).

    This is called during training and evaluation.

    Args:
      inputs: an int32 Tensor with shape [<batch_dims>, length_dim]
        For training autoregressive models this should be equal to
        mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
      targets: an optional int32 Tensor with shape [<batch_dims>, length_dim]
      compute_loss: a boolean
      mode: a tf.estimator.ModeKeys
      variable_dtype: a mtf.VariableDType
      sequence_id: an optional Tensor
      position: an optional Tensor
      encoder_output: an optional Tensor
      encoder_sequence_id: an optional Tensor
      shared_params: an optional dictionary

    Returns:
      logits: a Tensor with shape [<batch_dims>, output_vocab_dim]
      loss: an optional Scalar (if compute_loss=True)
    """
        context = Context(mesh=inputs.mesh,
                          batch_dims=inputs.shape.dims[:-1],
                          length_dim=inputs.shape.dims[-1],
                          model_dim=self.model_dim,
                          variable_dtype=variable_dtype,
                          mode=mode,
                          autoregressive=self.autoregressive,
                          losses=[] if compute_loss else None,
                          sequence_id=sequence_id,
                          position=position,
                          encoder_output=encoder_output,
                          encoder_sequence_id=encoder_sequence_id,
                          shared_params=shared_params,
                          layout=self.layout,
                          mesh_shape=self.mesh_shape)
        with tf.variable_scope(self.name):
            logits = self._call_internal(context, inputs, targets)
        if compute_loss:
            loss = mtf.add_n(context.losses)
        else:
            loss = None
        return logits, loss
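
The docstring's requirement that inputs equal mtf.shift(targets, offset=1, dim=length_dim, wrap=False) is the usual teacher-forcing shift. A small NumPy sketch of that preprocessing step (illustrative only; the mtf call works on named dimensions rather than axis numbers):

import numpy as np

targets = np.array([[17, 23, 5, 1, 0, 0]])   # [batch, length]; 1 = EOS, 0 = padding
inputs = np.roll(targets, shift=1, axis=1)
inputs[:, 0] = 0                             # wrap=False: position 0 sees padding
# inputs is now [[0, 17, 23, 5, 1, 0]]; position t is used to predict targets[:, t]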
Code Example #4
 def talking_heads(
     self, context, inp, name, input_heads_dims, output_heads_dims,
     dynamic_projections_from=None):
   shared_dims = [d for d in input_heads_dims if d in output_heads_dims]
   reduced_dims = [d for d in input_heads_dims if d not in output_heads_dims]
   new_dims = [d for d in output_heads_dims if d not in input_heads_dims]
   if not (reduced_dims or new_dims):
     # Output dimensions are the same as the input dimensions.  Return the input.
     return inp
   elif dynamic_projections_from:
     # There are one or more dynamic talking-heads-projections
     with tf.variable_scope(name):
       # static projection - this is the same as the static projection in the
       # "else" case below.  We create the weight matrix with get_variable
       # instead of calling mtf.layers.dense() so that we can fold the
       # static projection into one of the dynamic projections.
       static_p_initializer = mtf.layers.VarianceScalingInitializer()(
           reduced_dims, new_dims)
       static_p_shape = (
           context.model.ensemble_dims + shared_dims + reduced_dims + new_dims)
       static_p = mtf.get_variable(inp.mesh,
                                   "kernel",
                                   static_p_shape,
                                   initializer=static_p_initializer,
                                   dtype=context.variable_dtype)
       ps = []
       for i, dp_from in enumerate(dynamic_projections_from):
         kernel_initializer = mtf.layers.VarianceScalingInitializer(
             self.dynamic_projections_init_scale
             / mtf.Shape(reduced_dims).size)
         ps.append(
             mtf.layers.dense(
                 dp_from, reduced_dims=[context.model.model_dim],
                 new_dims=shared_dims + reduced_dims + new_dims,
                 use_bias=False, activation=None,
                 variable_dtype=context.variable_dtype,
                 name="%s_dynamic_%d" % (name, i),
                 expert_dims=context.model.ensemble_dims,
                 kernel_initializer=kernel_initializer))
        # Fold the static projection into one of the dynamic projections.
       # Mathematically, we could add all the dynamic projections together
       #   here, but it would create a very large tensor which contained
       #   both the query-length and memory-length dimensions, and would
       #   probably be slower in practice.
       ps[0] += static_p
       return mtf.add_n(
           [mtf.einsum([inp, p], reduced_dims=reduced_dims) for p in ps])
   else:
     # No dynamic projections.  Static talking-heads projection only
     return mtf.layers.dense(
         inp, reduced_dims=reduced_dims,
         new_dims=new_dims,
         use_bias=False, activation=None,
         variable_dtype=context.variable_dtype,
         name=name, expert_dims=context.model.ensemble_dims + shared_dims)
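
In the purely static case the projection above is just a learned linear map across head dimensions, applied independently at every (query, memory) pair. A NumPy sketch of that contraction; the shapes and names are illustrative, and the real layer additionally supports dynamic, input-dependent projections folded into the static kernel as shown above.

import numpy as np

batch, heads_in, heads_out, q_len, m_len = 2, 8, 8, 5, 5
logits = np.random.randn(batch, heads_in, q_len, m_len)
kernel = np.random.randn(heads_in, heads_out) / np.sqrt(heads_in)

# Talking-heads: mix information across heads by contracting the input-heads axis.
mixed_logits = np.einsum("bhqm,hk->bkqm", logits, kernel)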
Code Example #5
File: transformer_layers.py  Project: mzj14/mesh
    def call(self, context, x, losses=None):
        """Call the layer."""
        wq, wk, wv, wo = mtf.layers.multihead_attention_params(
            context.mesh, self.heads_dim, context.model_dim, self.kv_dim,
            context.variable_dtype)
        memory_length = mtf.Dimension("memory_length", context.length_dim.size)
        q = mtf.einsum([x, wq], reduced_dims=[context.model_dim])
        if context.mode == "incremental":
            m = x
        else:
            m = mtf.rename_dimension(x, context.length_dim.name,
                                     "memory_length")
        k = mtf.einsum([m, wk], reduced_dims=[context.model_dim])
        v = mtf.einsum([m, wv], reduced_dims=[context.model_dim])
        if context.mode == "incremental":
            old_k, old_v = context.get_states(2)
            one_hot = mtf.one_hot(context.position,
                                  memory_length,
                                  dtype=context.activation_dtype)
            inv_one_hot = 1.0 - one_hot
            k = old_k * inv_one_hot + k * one_hot
            v = old_v * inv_one_hot + v * one_hot
        if context.mode == "incremental" or context.mode == "first_part":
            context.record_new_states([k, v])
        masks = []
        if context.autoregressive:
            masks.append(
                mtf.cast(
                    mtf.less(
                        context.position,
                        mtf.range(context.mesh, memory_length,
                                  dtype=tf.int32)), context.activation_dtype) *
                -1e9)
        if (context.sequence_id is not None
                and isinstance(context.sequence_id, mtf.Tensor)
                and context.length_dim in context.sequence_id.shape):
            masks.append(
                mtf.cast(
                    mtf.not_equal(
                        context.sequence_id,
                        mtf.layers.rename_length_to_memory_length(
                            context.sequence_id)), context.activation_dtype) *
                -1e9)
        mask = mtf.add_n(masks) if masks else None

        o = mtf.layers.dot_product_attention_v2(
            q, k, v, memory_length, self.kv_dim, self.kv_dim, mask,
            self.dropout_rate if context.train else 0.0, [context.length_dim])
        return mtf.einsum([o, wo],
                          x.shape,
                          reduced_dims=[self.heads_dim, self.kv_dim])
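
During incremental decoding the layer writes the key/value for the current step into a fixed-length cache by blending with a one-hot vector at the current position. A NumPy sketch of that update rule (illustrative shapes, without batch or heads dimensions):

import numpy as np

memory_length, kv_dim = 6, 4
old_k = np.zeros((memory_length, kv_dim))        # cached keys from earlier steps
new_k = np.ones((1, kv_dim))                     # key computed for the current step
position = 2

one_hot = np.eye(memory_length)[position][:, None]    # [memory_length, 1]
k = old_k * (1.0 - one_hot) + new_k * one_hot          # row `position` now holds new_k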
Code Example #6
    def compute_mask(self, context, memory_position):
        """Compute attention mask.

    Args:
      context: a transformer.Context
      memory_position: an int32 Tensor containing the memory_length dimension.
    Returns:
      a Tensor or None
    """
        masks = []
        min_relative_position = self.min_relative_position(context)
        max_relative_position = self.max_relative_position(context)
        if max_relative_position is not None or min_relative_position is not None:
            relative_position = memory_position - context.position
            if min_relative_position is not None:
                illegal = mtf.less(relative_position, min_relative_position)
                masks.append(
                    mtf.cast(illegal, context.activation_dtype) * -1e9)
            if max_relative_position is not None:
                illegal = mtf.greater(relative_position, max_relative_position)
                masks.append(
                    mtf.cast(illegal, context.activation_dtype) * -1e9)
        sequence_id = None
        # Subsequence id should only be set if we are in the decoder and have
        # multiple targets per input. This will allow each sub-target to only attend
        # to itself.
        if isinstance(context.subsequence_id, mtf.Tensor):
            sequence_id = context.subsequence_id
        elif isinstance(context.sequence_id, mtf.Tensor):
            sequence_id = context.sequence_id
        if (sequence_id is not None
                and context.length_dim in sequence_id.shape):
            masks.append(
                mtf.cast(
                    mtf.not_equal(
                        sequence_id,
                        self.rename_length_to_memory_length(
                            sequence_id, context)), context.activation_dtype) *
                -1e9)
        return mtf.add_n(masks) if masks else None
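
The sequence_id (or subsequence_id) mask handles packed examples: a query may only attend to memory positions that carry the same id. A NumPy sketch of that comparison; mtf performs it over named length and memory_length dimensions, but the arithmetic is the same.

import numpy as np

sequence_id = np.array([1, 1, 1, 2, 2, 0])              # two packed segments plus padding
illegal = sequence_id[:, None] != sequence_id[None, :]
mask = illegal.astype(np.float32) * -1e9                # added to the attention logits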
Code Example #7
 def get_extra_loss(self):
   return mtf.add_n(self._extra_losses)
Code Example #8
  def __init__(self,
               config,
               is_training,
               input_ids,
               input_mask=None,
               token_type_ids=None,
               scope=None,
               mesh_shape="",
               layout=""):
    self.config = copy.deepcopy(config)
    del config
    if not is_training:
      self.config.layer_output_dropout_prob = 0.0
      self.config.attention_probs_dropout_prob = 0.0
      self.config.feedforward_intermediate_dropout_prob = 0.0
    input_shape = input_ids.shape
    assert input_shape.ndims == 2

    self._seq_dim = input_shape.dims[1]
    self._memory_seq_dim = mtf.Dimension("memory_seq", self.seq_dim.size)
    self._extra_losses = []
    mesh = input_ids.mesh

    if token_type_ids is None:
      token_type_ids = mtf.zeros(mesh, input_shape, dtype=tf.int32)

    with tf.variable_scope(scope, default_name="bert"):
      with tf.variable_scope("embeddings"):
        # Perform embedding lookup on the word ids.
        self.embedding_table = mtf.get_variable(
            mesh, "word_embeddings",
            mtf.Shape([self.vocab_dim, self.model_dim]),
            initializer=self.embedding_initializer)
        self.word_embedding_output = mtf.gather(
            self.embedding_table, input_ids, self.vocab_dim)

        # Add positional embeddings and token type embeddings, then layer
        # normalize and perform dropout.
        self.embedding_output = self.word_embedding_output

        token_type_table = mtf.get_variable(
            mesh, "token_type_embeddings",
            mtf.Shape([self.token_type_vocab_dim, self.model_dim]),
            initializer=self.embedding_initializer)
        if token_type_ids is not None:
          self.embedding_output += mtf.gather(
              token_type_table, token_type_ids, self.token_type_vocab_dim)
        if self.config.position_signal == "embedding":
          full_position_table = mtf.get_variable(
              mesh, "position_embeddings",
              mtf.Shape([self.max_position_embeddings_dim, self.model_dim]),
              initializer=self.embedding_initializer)
          short_position_table = mtf.rename_dimension(
              mtf.slice(full_position_table, 0, self.seq_dim.size,
                        self.max_position_embeddings_dim.name),
              self.max_position_embeddings_dim.name, self.seq_dim.name)
          self.embedding_output += short_position_table
        self.embedding_output = self.normalize(self.embedding_output)
        self.embedding_output = mtf.dropout(
            self.embedding_output, is_training,
            keep_prob=1.0 - self.config.layer_output_dropout_prob)

      with tf.variable_scope("encoder"):
        attention_biases = []
        if input_mask:
          # [batch_dim, memory_seq_dim]
          attention_biases.append(
              (1.0 - mtf.to_float(mtf.replace_dimensions(
                  input_mask, self.seq_dim, self.memory_seq_dim))) * -10000.0)
        if self.config.position_signal == "relative_attention_bias":
          buckets_dim = mtf.Dimension("buckets", 32)
          rp_bucket = _relative_position_bucket(
              mtf.range(mesh, self.memory_seq_dim, tf.int32)
              - mtf.range(mesh, self.seq_dim, tf.int32),
              num_buckets=buckets_dim.size)
          bias_var = mtf.get_variable(
              mesh, "relative_attention_bias",
              [self.num_heads_dim, buckets_dim],
              initializer=tf.zeros_initializer())
          attention_biases.append(mtf.gather(bias_var, rp_bucket, buckets_dim))
        attention_bias = mtf.add_n(attention_biases)
        prev_layer_output = self.embedding_output
        self.all_encoder_layers = []
        for block_num in range(self.config.num_blocks):
          with tf.variable_scope("block_%d" % block_num):
            for layer_idx, layer_type in enumerate(self.config.block_layers):
              layer_name = layer_type
              count = self.config.block_layers[:layer_idx].count(layer_type)
              if count:
                layer_name += "_%d" % count
              with tf.variable_scope(layer_name):
                x = prev_layer_output
                if self.config.residual_structure == "direct":
                  x = self.normalize(x)
                if layer_type == "attention":
                  x = self.self_attention(x, attention_bias)
                elif layer_type == "feedforward":
                  x = self.feedforward(x)
                elif layer_type == "moe":
                  x = self.moe(x, layout, mesh_shape, input_mask, is_training)
                else:
                  raise ValueError("unknown layer type " + layer_type)
                x = mtf.dropout(
                    x, is_training,
                    keep_prob=1.0 - self.config.layer_output_dropout_prob)
                layer_output = prev_layer_output + x
                if self.config.residual_structure == "original":
                  layer_output = self.normalize(layer_output)
                prev_layer_output = layer_output
          self.all_encoder_layers.append(layer_output)

      self.sequence_output = prev_layer_output
      if self.config.residual_structure == "direct":
        self.sequence_output = self.normalize(self.sequence_output)

      # The "pooler" converts the encoded sequence tensor of shape
      # [batch_dim, seq_dim, hidden_size] to a tensor of shape
      # [batch_dim, hidden_size]. This is necessary for segment-level
      # (or segment-pair-level) classification tasks where we need a fixed
      # dimensional representation of the segment.
      with tf.variable_scope("pooler"):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token. We assume that this has been pre-trained.
        first_token_tensor = mtf.gather(self.sequence_output, 0, self.seq_dim)
        self.pooled_output = mtf.layers.dense(
            first_token_tensor,
            reduced_dims=[self.model_dim],
            new_dims=[self.model_dim],
            activation=mtf.tanh,
            kernel_initializer=self.dense_initializer,
            use_bias=self.config.use_bias)
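
The encoder above turns the padding mask into an additive attention bias: positions where input_mask is 0 receive -10000.0 before the softmax. A NumPy sketch of that step (array names are illustrative):

import numpy as np

input_mask = np.array([[1, 1, 1, 0, 0]], dtype=np.float32)   # [batch, memory_seq]
attention_bias = (1.0 - input_mask) * -10000.0
# Broadcast over heads and query positions, then add to the attention logits.
logits = np.zeros((1, 2, 5, 5)) + attention_bias[:, None, None, :]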
Code Example #9
    def compute_bias(self, context, memory_position, x):
        """Compute attention bias.

    Args:
      context: a transformer.Context
      memory_position: an int32 Tensor containing the memory_length dimension.
      x: a Tensor - the query antecedent - required for relative attention
    Returns:
      a Tensor or None
    """
        min_relative_position = self.min_relative_position(context)
        max_relative_position = self.max_relative_position(context)
        # we can often cache the result of this function between similar layers
        can_cache = (self.relative_attention_type is None
                     or self.relative_attention_type == "bias_shared")
        if can_cache:
            cache_key = ("self_attention_mask", min_relative_position,
                         max_relative_position, self.relative_attention_type,
                         self.num_heads)
            if cache_key in context.cache:
                return context.cache[cache_key]
        biases = []
        relative_position = memory_position - context.position
        if min_relative_position is not None:
            visible = mtf.greater_equal(relative_position,
                                        min_relative_position)
            biases.append(
                attention.visibility_mask_to_attention_bias(
                    visible, context.activation_dtype))
        if max_relative_position is not None:
            visible = mtf.less_equal(relative_position, max_relative_position)
            biases.append(
                attention.visibility_mask_to_attention_bias(
                    visible, context.activation_dtype))
        if context.read_priority is not None:
            visible = mtf.greater_equal(
                context.read_priority,
                mtf.layers.rename_length_to_memory_length(
                    context.write_priority))
            biases.append(
                attention.visibility_mask_to_attention_bias(
                    visible, context.activation_dtype))

        sequence_id = None
        # Subsequence id should only be set if we are in the decoder and have
        # multiple targets per input. This will allow each sub-target to only attend
        # to itself.
        if isinstance(context.subsequence_id, mtf.Tensor):
            sequence_id = context.subsequence_id
        elif isinstance(context.sequence_id, mtf.Tensor):
            sequence_id = context.sequence_id
        if (sequence_id is not None
                and context.length_dim in sequence_id.shape):
            visible = mtf.equal(
                sequence_id,
                self.rename_length_to_memory_length(sequence_id, context))
            biases.append(
                attention.visibility_mask_to_attention_bias(
                    visible, context.activation_dtype))
        if self.relative_attention_type is not None:
            buckets_dim = mtf.Dimension("buckets",
                                        self.relative_attention_num_buckets)
            heads_dim = mtf.Dimension("heads", self.num_heads)
            bidirectional = not context.model.fully_autoregressive
            rp_bucket = _relative_position_bucket(relative_position,
                                                  bidirectional=bidirectional,
                                                  num_buckets=buckets_dim.size)
            if (self.relative_attention_type == "bias"
                    or self.relative_attention_type == "bias_shared"):
                values = mtf.get_variable(context.mesh,
                                          "relative_attention_bias",
                                          [heads_dim, buckets_dim],
                                          dtype=context.variable_dtype)
            elif self.relative_attention_type == "contextual":
                values = layers.dense(x, [buckets_dim, heads_dim],
                                      variable_dtype=context.variable_dtype,
                                      name="relative_attention_contextual")
            else:
                raise ValueError(
                    "unrecognized relative_attention_type \"%s\"" %
                    self.relative_attention_type)
            biases.append(mtf.gather(values, rp_bucket, buckets_dim))
        ret = mtf.add_n(biases) if biases else None
        if can_cache:
            context.cache[cache_key] = ret
        return ret
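
For the learned relative-attention bias, every (query, memory) offset is first mapped to a bucket, and the bucket index is then used to gather a per-head scalar that is added to the logits. A NumPy sketch of the gather step; random bucket indices stand in for the output of _relative_position_bucket, which is omitted here.

import numpy as np

num_heads, num_buckets, q_len, m_len = 4, 32, 6, 6
values = np.random.randn(num_heads, num_buckets)            # learned bias table
rp_bucket = np.random.randint(0, num_buckets, size=(q_len, m_len))

bias = values[:, rp_bucket]                                  # [num_heads, q_len, m_len]
# bias is broadcast over the batch dimension and added to the attention logits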
 def call_simple(self,
                 inputs,
                 targets,
                 compute_loss,
                 attributes=None,
                 mode=tf.estimator.ModeKeys.TRAIN,
                 variable_dtype=mtf.VariableDType(tf.float32),
                 sequence_id=None,
                 subsequence_id=None,
                 position=None,
                 encoder_output=None,
                 encoder_sequence_id=None,
                 encoder_inputs=None,
                 shared_params=None,
                 layer_outputs=None,
                 encoder_layer_outputs=None,
                 z=None):
     """Compute logits based on inputs (all positions in parallel).
     This is called during training and evaluation.
     Args:
       inputs: an int32 Tensor with shape [<batch_dims>, length_dim] For training
         autoregressive models this should be equal to mtf.shift(targets,
         offset=1, dim=length_dim, wrap=False)
       targets: an optional int32 Tensor with shape [<batch_dims>, length_dim]
       compute_loss: a boolean
       attributes: an optional int32 Tensor with shape [<batch_dims>, length_dim] or [<batch_dims>]
       mode: a tf.estimator.ModeKeys
       variable_dtype: a mtf.VariableDType
       sequence_id: an optional Tensor
       subsequence_id: an optional Tensor
       position: an optional Tensor
       encoder_output: an optional Tensor
       encoder_sequence_id: an optional Tensor
       encoder_inputs: an optional Tensor
       shared_params: an optional dictionary
       layer_outputs: an optional list to append Tensor layer activations to
       encoder_layer_outputs: optional - readonly list of tensor activations when
         decoding, one per each input layer + the embedding layer
     Returns:
       logits: a Tensor with shape [<batch_dims>, output_vocab_dim]
       loss: an optional Scalar (if compute_loss=True)
     """
     batch_dims = inputs.shape.dims[:-1]
     length_dim = inputs.shape.dims[-1]
     length_range = mtf.range(inputs.mesh, length_dim, dtype=tf.int32)
     if not self.positional_embedding:
         # To make relative attention faster, we drop the information about the
         #   position in the subsequence.  The relative attention code then
         #   assumes that the positions are given by index in the tensor,
         #   which still leads to the correct computation of relative position.
         position = None
     if position is None:
         position_is_default = True
         position = length_range
     else:
         position_is_default = False
     if self.input_full_attention:
         # The inputs part of each sequence can fully attend within itself.
         full_attention_region = delimited_lm_inputs_mask(targets)
         # We can include one additional position to the right - the position
         #   where the final EOS of the inputs is read and the first target token
         #   is predicted.
         full_attention_region = mtf.logical_or(
             full_attention_region,
             mtf.shift(full_attention_region,
                       offset=1,
                       dim=length_dim,
                       wrap=False))
         # We set read_priority and write_priority to 0 in the full-attention
         #   region and equal to the position elsewhere.
         read_priority = write_priority = length_range * mtf.cast(
             mtf.logical_not(full_attention_region), tf.int32)
     elif self.autoregressive:
         # Vanilla autoregressive model - each position can see previous positions.
         read_priority = write_priority = length_range
     else:
         read_priority = write_priority = None
     context = Context(model=self,
                       mesh=inputs.mesh,
                       batch_dims=batch_dims,
                       length_dim=length_dim,
                       variable_dtype=variable_dtype,
                       mode=mode,
                       losses=[] if compute_loss else None,
                       sequence_id=sequence_id,
                       subsequence_id=subsequence_id,
                       position=position,
                       position_is_default=position_is_default,
                       encoder_output=encoder_output,
                       encoder_sequence_id=encoder_sequence_id,
                       shared_params=shared_params,
                       layer_outputs=layer_outputs,
                       encoder_layer_outputs=encoder_layer_outputs,
                       write_priority=write_priority,
                       read_priority=read_priority,
                       inputs=inputs,
                       encoder_inputs=encoder_inputs)
     with tf.variable_scope(self.name):
         logits = self._call_internal(context,
                                      inputs,
                                      targets,
                                      attributes,
                                      z=z)
     if compute_loss:
         loss = mtf.add_n(context.losses)
     else:
         loss = None
     return logits, loss
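
The read_priority / write_priority mechanism used above generalizes the causal mask: a query position can see a memory position when its read priority is at least that position's write priority. With both equal to the position index this is ordinary causal attention; zeroing them inside the inputs region (input_full_attention) makes that region fully visible to itself. A NumPy sketch with illustrative names:

import numpy as np

length = 6
position = np.arange(length)
full_attention_region = np.array([1, 1, 1, 0, 0, 0], dtype=bool)   # inputs part of a prefix LM

read_priority = write_priority = position * (~full_attention_region)
visible = read_priority[:, None] >= write_priority[None, :]
# Rows 0-2 (the inputs) can see every inputs position; rows 3-5 stay causal.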