Example #1
    def testMultiheadAttention(self, kv_channels, heads):
        # Plain TensorFlow input: a random query of shape [batch, length, channels].
        batch = 2
        length = 8
        channels = 3
        query = tf.random_normal([batch, length, channels])

        # Build the Mesh TensorFlow graph and the named dimensions it operates on.
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", batch)
        length_dim = mtf.Dimension("length", length)
        channels_dim = mtf.Dimension("channels", channels)
        kv_channels_dim = mtf.Dimension("kv_channels", kv_channels)
        heads_dim = mtf.Dimension("heads", heads)

        # Import the TF tensor into the mesh and apply multihead self-attention
        # (no memory antecedent, no mask).
        mtf_query = mtf.import_tf_tensor(
            mesh,
            query,
            shape=mtf.Shape([batch_dim, length_dim, channels_dim]))
        mtf_outputs = mtf_layers.multihead_attention(
            mtf_query,
            memory_antecedent=None,
            mask=None,
            kv_channels=kv_channels_dim,
            heads=heads_dim)

        # Lower onto a single-device placement mesh and export back to a TF tensor.
        mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                          layout={},
                                                          devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

        tf_group = lowering.copy_masters_to_slices()
        init = tf.global_variables_initializer()
        with self.test_session() as sess:
            sess.run(init)
            sess.run(tf_group)
            actual = sess.run(actual_outputs)

        # The attention output keeps the query's [batch, length, channels] shape.
        self.assertEqual(actual.shape, query.shape)
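The method above comes from a Mesh TensorFlow unit test and receives kv_channels and heads as test parameters (typically supplied by a parameterized-test decorator). Running it standalone requires TensorFlow 1.x plus the mesh-tensorflow modules; a minimal import preamble along these lines is assumed (the exact module paths are an assumption, not part of the example):

import tensorflow as tf

from tensor2tensor.mesh_tensorflow import mesh_tensorflow as mtf
from tensor2tensor.mesh_tensorflow import mtf_layers
from tensor2tensor.mesh_tensorflow import placement_mesh_impl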
Example #2
  def _layer_stack(self,
                   x,
                   num_layers,
                   encoder_output=None,
                   self_attention_mask=None,
                   encdec_attention_mask=None,
                   losses=None):
    """Encoder or decoder stack.

    Args:
      x: an mtf.Tensor with shape [batch_dim, length_dim, model_dim]
      num_layers: an integer
      encoder_output: an optional mtf.Tensor with shape
        [batch_dim, encoder_length_dim, model_dim]
      self_attention_mask: an optional mtf.Tensor with shape
        [batch_dim, length_dim, memory_length_dim] containing values 0 or -inf.
      encdec_attention_mask: an optional mtf.Tensor with shape
        [batch_dim, length_dim, encoder_length_dim] containing values 0 or -inf.
      losses: a list to which losses are appended
    Returns:
      an mtf.Tensor with shape [batch_dim, length_dim, model_dim]
    Raises:
      ValueError: if hparams make no sense
    """
    hparams = self._hparams

    def layer_prepostprocess_dropout(x):
      return mtf.dropout(
          x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
          noise_shape=mtf.Shape([self.batch_dim, self.model_dim]))
    num_layer_norms = num_layers * (2 if encoder_output is None else 3) + 1
    layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms)
    layer_norm_combined_var = mtf.get_variable(
        x.mesh,
        "layer_norm_scale",
        mtf.Shape([layer_norms_dim, self.model_dim]),
        initializer=tf.ones_initializer(),
        activation_dtype=x.dtype)
    layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim)
    def normalize(x):
      scale = layer_norm_vars.pop(0)
      variance = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim)
      return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale

    for layer in range(num_layers):
      with tf.variable_scope("layer_%d" % layer):
        # Self-attention layer
        x += layer_prepostprocess_dropout(
            mtf_layers.multihead_attention(
                normalize(x), None,
                self_attention_mask, self.kv_dim, self.heads_dim,
                dropout=hparams.attention_dropout,
                dropout_broadcast_dims=[self.length_dim],
                name="self_attention"))
        if encoder_output is not None:
          # Encoder-Decoder attention layer
          x += layer_prepostprocess_dropout(
              mtf_layers.multihead_attention(
                  normalize(x), encoder_output,
                  encdec_attention_mask, self.kv_dim, self.heads_dim,
                  dropout=hparams.attention_dropout,
                  dropout_broadcast_dims=[self.length_dim],
                  name="encdec_attention"))
        # Feed-forward (ffn) layer
        x += layer_prepostprocess_dropout(
            self._feedforward_layer(normalize(x), losses=losses))
    x = layer_prepostprocess_dropout(normalize(x))
    assert not layer_norm_vars  # every layer-norm scale should have been consumed
    return x
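Each layer in the stack above applies the same pre-norm residual pattern: normalize the input, run a sublayer (self-attention, optional encoder-decoder attention, then a feed-forward layer), apply dropout, and add the result back to the input. This is also why num_layer_norms is num_layers * 2 (or * 3 when encoder_output is present) plus one final normalization. A minimal sketch of that pattern, reusing the mtf import from above with a hypothetical sublayer callable (illustrative only, not the library's API):

def residual_block(x, sublayer, normalize, keep_prob):
  # Pre-norm residual: normalize, apply the sublayer, drop out, add back.
  y = sublayer(normalize(x))
  y = mtf.dropout(y, keep_prob=keep_prob)
  return x + y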