Example #1
def layer_norm(x, dim, epsilon=1e-6, name="layer_prepostprocess"):
    """Layer normalization over dimension dim.

  Args:
    x: a mtf.Tensor whose shape contains dim.
    dim: a mtf.Dimension
    epsilon: a floating point number
    name: a string. variable scope.

  Returns:
    a mtf.Tensor with same shape as x.
  """
    with tf.variable_scope(name + "/layer_norm"):
        scale = mtf.get_variable(x.mesh,
                                 "layer_norm_scale",
                                 mtf.TensorShape([dim]),
                                 initializer=tf.ones_initializer(),
                                 activation_dtype=x.dtype)
        bias = mtf.get_variable(x.mesh,
                                "layer_norm_bias",
                                mtf.TensorShape([dim]),
                                initializer=tf.zeros_initializer(),
                                activation_dtype=x.dtype)
        reduced_shape = x.shape - dim
        mean = mtf.reduce_mean(x, output_shape=reduced_shape)
        variance = mtf.reduce_mean(mtf.square(x - mean),
                                   output_shape=reduced_shape)
        norm_x = (x - mean) * mtf.rsqrt(variance + epsilon)
        return norm_x * scale + bias
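
A minimal usage sketch, assuming the mtf and tf imports used throughout these examples; the sizes, mesh name, and random input are illustrative, mirroring the test in Example #8:

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")
batch_dim = mtf.Dimension("batch", 2)        # hypothetical sizes
channels_dim = mtf.Dimension("channels", 3)
x = mtf.infeed(mesh, tf.random_normal([2, 3]),
               shape=mtf.TensorShape([batch_dim, channels_dim]))
y = layer_norm(x, dim=channels_dim)  # y has the same shape as x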
Example #2
def multihead_attention(query_antecedent,
                        memory_antecedent,
                        mask,
                        kv_channels,
                        heads,
                        dropout=0.0,
                        dropout_broadcast_dims=None,
                        name="multihead_attention"):
    """Multihead scaled-dot-product attention with input/output transformations.

  In order to use only one variable containing the four weight matrices
  packed together, we insist that the query and memory antecedents have the
  same dimensionality (io_channels) and that the keys and values have the
  same dimensionality (kv_channels).

  Args:
    query_antecedent: a mtf.Tensor with shape [batch, query_length, io_channels]
    memory_antecedent: a mtf.Tensor with shape
      [batch, memory_length, io_channels] (optional)
    mask: mask Tensor (see attention_mask())
    kv_channels: a mtf.Dimension (the size of the key and value vectors)
    heads: a mtf.Dimension (the number of heads)
    dropout: a floating point value
    dropout_broadcast_dims: an optional list of mtf.Dimension
    name: an optional string.

  Returns:
    A mtf.Tensor with shape [batch, query_length, io_channels]

  Raises:
    ValueError: if the dimensions do not match.
  """
    batch, query_length, io_channels = query_antecedent.shape.dims
    with tf.variable_scope(name,
                           default_name="multihead_attention",
                           values=[query_antecedent, memory_antecedent]):
        q_var, k_var, v_var, o_var = multihead_attention_vars(
            query_antecedent.mesh, heads, io_channels, kv_channels,
            query_antecedent.dtype)
        if memory_antecedent is None:
            memory_antecedent = rename_length_to_memory_length(
                query_antecedent, query_length.name)
        memory_batch, memory_length, memory_channels = memory_antecedent.shape.dims
        if memory_batch != batch:
            raise ValueError("memory batch must equal query batch")
        if memory_channels != io_channels:
            raise ValueError("memory channels must equal query channels")
        q = mtf.einsum([query_antecedent, q_var],
                       mtf.TensorShape(
                           [batch, heads, query_length, kv_channels]))
        k = mtf.einsum([memory_antecedent, k_var],
                       mtf.TensorShape(
                           [batch, heads, memory_length, kv_channels]))
        v = mtf.einsum([memory_antecedent, v_var],
                       mtf.TensorShape(
                           [batch, heads, memory_length, kv_channels]))
        o = dot_product_attention(q, k, v, mask, dropout,
                                  dropout_broadcast_dims)
        return mtf.einsum([o, o_var],
                          mtf.TensorShape([batch, query_length, io_channels]))
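
For self-attention, pass memory_antecedent=None and the query tensor is reused as memory via rename_length_to_memory_length. A hedged call sketch; the query tensor and the head/channel sizes are assumptions, and Example #22 shows the full test:

kv_channels_dim = mtf.Dimension("kv_channels", 4)  # illustrative
heads_dim = mtf.Dimension("heads", 2)              # illustrative
y = multihead_attention(query,            # [batch, length, io_channels]
                        memory_antecedent=None,
                        mask=None,
                        kv_channels=kv_channels_dim,
                        heads=heads_dim)
# y: [batch, length, io_channels], same shape as query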
Example #3
def _embedding_and_softmax_vars(self, mesh):
   hparams = self._hparams
   targets_embedding_var = mtf.get_variable(
       mesh, "targets_embedding",
       mtf.TensorShape([self.targets_vocab_dim, self.model_dim]),
       initializer=tf.random_normal_initializer(),
       activation_dtype=self.activation_dtype)
   if self.has_input:
     if hparams.shared_embedding:
       inputs_embedding_var = targets_embedding_var
     else:
       inputs_embedding_var = mtf.get_variable(
           mesh, "inputs_embedding",
           mtf.TensorShape([self.inputs_vocab_dim, self.model_dim]),
           initializer=tf.random_normal_initializer(),
           activation_dtype=self.activation_dtype)
   else:
     inputs_embedding_var = None
   if hparams.shared_embedding_and_softmax_weights:
     softmax_var = targets_embedding_var * (self.model_dim.size ** -0.5)
   else:
     softmax_var = mtf.get_variable(
         mesh,
         "softmax",
         mtf.TensorShape([self.targets_vocab_dim, self.model_dim]),
         initializer=tf.random_normal_initializer(
             stddev=self.model_dim.size**-0.5),
         activation_dtype=self.activation_dtype)
   positional_embedding_var = mtf.get_variable(
       mesh, "positional_embedding",
       mtf.TensorShape([self.max_length_dim, self.model_dim]),
       initializer=tf.random_normal_initializer(),
       activation_dtype=self.activation_dtype)
   return (inputs_embedding_var, targets_embedding_var,
           softmax_var, positional_embedding_var)
Example #4
def multihead_self_attention_incremental(query_antecedent,
                                         prev_k,
                                         prev_v,
                                         step_num,
                                         name="multihead_attention"):
    """Incremental self-attention (one decode step).

  In order to use only one variable containing the four weight matrices
  packed together, we insist that the query and memory antecedents have the
  same dimensionality (io_channels) and that the keys and values have the
  same dimensionality (kv_channels).

  Args:
    query_antecedent: a mtf.Tensor with shape [batch..., io_channels]
    prev_k: mtf.Tensor with shape [batch..., heads, memory_length, kv_channels]
    prev_v: mtf.Tensor with shape [batch..., heads, memory_length, kv_channels]
    step_num: mtf Scalar with dtype tf.int32
    name: an optional string.

  Returns:
    y: A mtf.Tensor with shape [batch..., io_channels]
    new_k: mtf.Tensor with shape [batch..., heads, memory_length, kv_channels]
    new_v: mtf.Tensor with shape [batch..., heads, memory_length, kv_channels]

  Raises:
    ValueError: if the dimensions do not match.
  """
    batch_dims = query_antecedent.shape.dims[:-1]
    io_channels = query_antecedent.shape.dims[-1]
    heads, memory_length, kv_channels = prev_k.shape.dims[-3:]
    with tf.variable_scope(name, default_name="multihead_attention"):
        q_var, k_var, v_var, o_var = multihead_attention_vars(
            query_antecedent.mesh, heads, io_channels, kv_channels,
            query_antecedent.dtype)
        memory_antecedent = query_antecedent
        q = mtf.einsum([query_antecedent, q_var],
                       mtf.TensorShape(batch_dims + [heads, kv_channels]))
        k = mtf.einsum([memory_antecedent, k_var],
                       mtf.TensorShape(batch_dims + [heads, kv_channels]))
        v = mtf.einsum([memory_antecedent, v_var],
                       mtf.TensorShape(batch_dims + [heads, kv_channels]))
        k = prev_k + mtf.multiply(
            k, mtf.one_hot(step_num, memory_length), output_shape=prev_k.shape)
        v = prev_v + mtf.multiply(
            v, mtf.one_hot(step_num, memory_length), output_shape=prev_v.shape)

        mask = mtf.to_float(
            mtf.greater(
                mtf.range(query_antecedent.mesh, memory_length,
                          dtype=tf.int32), step_num)) * -1e9
        o = dot_product_attention(q, k, v, mask)
        y = mtf.einsum([o, o_var], query_antecedent.shape)
        return y, k, v
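
A sketch of one decode step threading the key/value caches through this function; the surrounding tensors (x_step, the caches, step_num) are assumptions, not part of the library:

# prev_k, prev_v: [batch..., heads, memory_length, kv_channels] caches,
# initially zeros; step_num: scalar int32 position being decoded.
y, prev_k, prev_v = multihead_self_attention_incremental(
    x_step, prev_k, prev_v, step_num)
# y: [batch..., io_channels]; the caches now include this step's k and v.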
Example #5
def dense(x,
          output_dim,
          reduced_dims=None,
          expert_dims=None,
          use_bias=True,
          activation=None,
          name=None):
    """Dense layer doing (kernel*x + bias) computation.

  Args:
    x: a mtf.Tensor of shape [..., reduced_dims].
    output_dim: a mtf.Dimension
    reduced_dims: an optional list of mtf.Dimensions of x to be reduced. If
      omitted, we reduce the last dimension.
    expert_dims: an optional list of mtf.Dimension which represent different
      experts. Different experts get different weights.
    use_bias: a boolean, whether to add bias.
    activation: an optional function from mtf.Tensor to mtf.Tensor
    name: a string. variable scope.

  Returns:
    a mtf.Tensor of shape [..., output_dim].
  """
    if expert_dims is None:
        expert_dims = []
    if reduced_dims is None:
        reduced_dims = x.shape.dims[-1:]
    w_shape = mtf.TensorShape(expert_dims + reduced_dims + [output_dim])
    output_shape = mtf.TensorShape(
        [d for d in x.shape.dims if d not in reduced_dims] + [output_dim])
    with tf.variable_scope(name, default_name="dense"):
        stddev = mtf.list_product(d.size for d in reduced_dims)**-0.5
        w = mtf.get_variable(
            x.mesh,
            "kernel",
            w_shape,
            initializer=tf.random_normal_initializer(stddev=stddev),
            activation_dtype=x.dtype)
        y = mtf.matmul(x, w, output_shape=output_shape)
        if use_bias:
            b = mtf.get_variable(x.mesh,
                                 "bias",
                                 mtf.TensorShape(expert_dims + [output_dim]),
                                 initializer=tf.zeros_initializer(),
                                 activation_dtype=x.dtype)
            y += b
        if activation is not None:
            y = activation(y)
        return y
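
Two common call patterns, sketched with illustrative dimension names (hidden_dim, rows_dim, and cols_dim are assumptions; Example #18 uses the second form):

# Reduce the trailing dimension of x (the default).
h = dense(x, hidden_dim, activation=mtf.relu, name="hidden")
# Reduce several named dimensions at once, e.g. flattening an image.
h = dense(x, hidden_dim, reduced_dims=[rows_dim, cols_dim],
          activation=mtf.relu, name="hidden")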
Example #6
        def first_block_attention():
            """Compute attention for the first block."""
            first_q = mtf.slice(q, 0, 1, num_blocks.name)
            first_k = mtf.slice(k, 0, 1, num_blocks.name)
            first_v = mtf.slice(v, 0, 1, num_blocks.name)
            block = first_q.shape.dims[2]

            first_logits = mtf.einsum(
                [first_q, first_k],
                mtf.TensorShape([batch, heads, block, blength, mlength]))
            weights = mtf.softmax(first_logits, mlength)
            first_output = mtf.einsum(
                [weights, first_v],
                mtf.TensorShape([batch, heads, block, blength, kv_channels]))
            return first_output
Example #7
def multihead_encdec_attention_incremental(query_antecedent,
                                           q_var,
                                           o_var,
                                           k,
                                           v,
                                           mask,
                                           name="multihead_attention"):
    """Incremental attention over encoder (one decode step).

  In order to use only one variable containing the four weight matrices
  packed together, we insist that the query and memory antecedents have the
  same dimensionality (io_channels) and that the keys and values have the
  same dimensionality (kv_channels).

  memory_dims is a subset of query_dims

  Args:
    query_antecedent: a mtf.Tensor with shape query_dims + [io_channels]
    q_var: a mtf.Tensor with shape [heads, io_channels, kv_channels]
    o_var: a mtf.Tensor with shape [heads, io_channels, kv_channels]
    k: memory_dims + [heads, memory_length, kv_channels]
    v: memory_dims + [heads, memory_length, kv_channels]
    mask: mask Tensor (see attention_mask())
    name: an optional string.

  Returns:
    A mtf.Tensor with shape [batch, qlen, io_channels]
  """
    heads, _, kv_channels = k.shape.dims[-3:]
    query_dims = query_antecedent.shape.dims[:-1]
    with tf.variable_scope(name, default_name="multihead_attention"):
        q = mtf.einsum([query_antecedent, q_var],
                       mtf.TensorShape(query_dims + [heads, kv_channels]))
        o = dot_product_attention(q, k, v, mask)
        return mtf.einsum([o, o_var], query_antecedent.shape)
Example #8
  def testLayerNorm(self):
    batch = 2
    channels = 3
    inputs = tf.random_normal([batch, channels])

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    batch_dim = mtf.Dimension("batch", batch)
    channels_dim = mtf.Dimension("channels", channels)

    mtf_inputs = mtf.infeed(mesh, inputs,
                            shape=mtf.TensorShape([batch_dim, channels_dim]))
    mtf_outputs = mtf_layers.layer_norm(mtf_inputs,
                                        dim=channels_dim)
    mesh_impl = placement_mesh_impl.PlacementMeshImpl(
        shape=[1], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.outfeed(mtf_outputs)

    expected_outputs = common_layers.layer_norm(inputs)
    tf_group = lowering.copy_masters_to_slices()
    init = tf.global_variables_initializer()
    with self.test_session() as sess:
      sess.run(init)
      sess.run(tf_group)
      actual, expected = sess.run([actual_outputs, expected_outputs])

    self.assertEqual(actual.shape, expected.shape)
Example #9
  def testDense(self, units, use_bias):
    batch = 2
    channels = 3
    inputs = tf.random_normal([batch, channels])

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    batch_dim = mtf.Dimension("batch", batch)
    channels_dim = mtf.Dimension("channels", channels)
    depth_dim = mtf.Dimension("depth", units)

    mtf_inputs = mtf.infeed(mesh, inputs,
                            shape=mtf.TensorShape([batch_dim, channels_dim]))
    mtf_outputs = mtf_layers.dense(mtf_inputs,
                                   output_dim=depth_dim,
                                   reduced_dims=[channels_dim],
                                   activation=mtf.relu,
                                   use_bias=use_bias)
    mesh_impl = placement_mesh_impl.PlacementMeshImpl(
        shape=[1], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.outfeed(mtf_outputs)

    expected_outputs = tf.keras.layers.Dense(units=units,
                                             activation=tf.nn.relu,
                                             use_bias=use_bias)(inputs)
    tf_group = lowering.copy_masters_to_slices()
    init = tf.global_variables_initializer()
    with self.test_session() as sess:
      sess.run(init)
      sess.run(tf_group)
      actual, expected = sess.run([actual_outputs, expected_outputs])

    self.assertEqual(actual.shape, expected.shape)
Example #10
  def testDenseReluDense(self):
    batch = 2
    channels = 3
    hidden = 5
    inputs = tf.random_normal([batch, channels])

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    batch_dim = mtf.Dimension("batch", batch)
    channels_dim = mtf.Dimension("channels", channels)
    hidden_dim = mtf.Dimension("hidden", hidden)

    mtf_inputs = mtf.infeed(mesh, inputs,
                            shape=mtf.TensorShape([batch_dim, channels_dim]))
    mtf_outputs = mtf_layers.dense_relu_dense(mtf_inputs,
                                              hidden_channels=hidden_dim)
    mesh_impl = placement_mesh_impl.PlacementMeshImpl(
        shape=[1], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.outfeed(mtf_outputs)

    tf_group = lowering.copy_masters_to_slices()
    init = tf.global_variables_initializer()
    with self.test_session() as sess:
      sess.run(init)
      sess.run(tf_group)
      actual = sess.run(actual_outputs)

    self.assertEqual(actual.shape, inputs.shape)
Example #11
def _concat_equal_sizes(xs, dim, new_dim_name):
    axis = xs[0].shape.dims.index(dim)
    ret = mtf.stack(xs, "tmp_concat", axis)
    new_shape = mtf.TensorShape(
        xs[0].shape.dims[:axis] +
        [mtf.Dimension(new_dim_name, dim.size * len(xs))] +
        xs[0].shape.dims[axis + 1:])
    return mtf.reshape(ret, new_shape)
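
This helper emulates concatenation of equally-sized tensors by stacking them along a temporary dimension and reshaping it away into a single, newly named dimension. A hedged usage sketch; t0, t1, t2, and block_dim are assumptions:

# Concatenate three tensors that share dimension block_dim into one
# tensor whose new dimension "block3" has 3 * block_dim.size entries.
combined = _concat_equal_sizes([t0, t1, t2], block_dim, "block3")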
Example #12
def my_gather(tensor):
    return mtf.gather(tensor,
                      top_beam_index,
                      beam_dim,
                      output_shape=mtf.TensorShape([
                          double_beam if d == beam_dim else d
                          for d in tensor.shape.dims
                      ]))
Example #13
  def testDotProductAttention(
      self, batch, heads, length_q, length_kv, depth_k, depth_v):
    query = tf.random_normal([batch, heads, length_q, depth_k])
    key = tf.random_normal([batch, heads, length_kv, depth_k])
    value = tf.random_normal([batch, heads, length_kv, depth_v])

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    batch_dim = mtf.Dimension("batch", batch)
    heads_dim = mtf.Dimension("heads", heads)
    length_q_dim = mtf.Dimension("length_q", length_q)
    length_kv_dim = mtf.Dimension("length_kv", length_kv)
    depth_k_dim = mtf.Dimension("depth_k", depth_k)
    depth_v_dim = mtf.Dimension("depth_v", depth_v)

    mtf_query = mtf.infeed(
        mesh, query,
        shape=mtf.TensorShape(
            [batch_dim, heads_dim, length_q_dim, depth_k_dim]))
    mtf_key = mtf.infeed(
        mesh, key,
        shape=mtf.TensorShape(
            [batch_dim, heads_dim, length_kv_dim, depth_k_dim]))
    mtf_value = mtf.infeed(
        mesh, value,
        shape=mtf.TensorShape(
            [batch_dim, heads_dim, length_kv_dim, depth_v_dim]))
    mtf_outputs = mtf_layers.dot_product_attention(
        mtf_query,
        mtf_key,
        mtf_value,
        mask=None)
    mesh_impl = placement_mesh_impl.PlacementMeshImpl(
        shape=[1], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.outfeed(mtf_outputs)

    tf_group = lowering.copy_masters_to_slices()
    init = tf.global_variables_initializer()
    with self.test_session() as sess:
      sess.run(init)
      sess.run(tf_group)
      actual = sess.run(actual_outputs)

    self.assertEqual(actual.shape, (batch, heads, length_q, depth_v))
Example #14
def gather(tensor, name):
    with tf.name_scope(prefix + name):
        output_shape = mtf.TensorShape([
            beam_dim if d == old_beam_dim else d for d in tensor.shape.dims
        ])
        return mtf.gather(tensor,
                          topk_indices,
                          old_beam_dim,
                          output_shape=output_shape)
Example #15
  def testMaskedLocalAttention1D(self, kv_channels, heads):
    batch = 2
    length_q = 16
    length_m = 16
    channels = 3
    query = tf.random_normal([batch, length_q, channels])
    memory = tf.random_normal([batch, length_m, channels])

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    batch_dim = mtf.Dimension("batch", batch)
    length_q_dim = mtf.Dimension("length_q", length_q)
    length_m_dim = mtf.Dimension("length_m", length_m)
    channels_dim = mtf.Dimension("channels", channels)
    kv_channels_dim = mtf.Dimension("kv_channels", kv_channels)
    heads_dim = mtf.Dimension("heads", heads)

    mtf_query = mtf.infeed(
        mesh, query,
        shape=mtf.TensorShape([batch_dim, length_q_dim, channels_dim]))
    mtf_memory = mtf.infeed(
        mesh, memory,
        shape=mtf.TensorShape([batch_dim, length_m_dim, channels_dim]))
    mtf_outputs = mtf_layers.masked_local_attention_1d(
        mtf_query,
        mtf_memory,
        kv_channels=kv_channels_dim,
        heads=heads_dim,
        block_length=2)
    mesh_impl = placement_mesh_impl.PlacementMeshImpl(
        shape=[1], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.outfeed(mtf_outputs)

    tf_group = lowering.copy_masters_to_slices()
    init = tf.global_variables_initializer()
    with self.test_session() as sess:
      sess.run(init)
      sess.run(tf_group)
      actual = sess.run(actual_outputs)

    self.assertEqual(actual.shape, (batch, length_q, channels))
Example #16
def toy_model(features, mesh):
    """A toy model implemented by mesh tensorlfow."""
    batch_dim = mtf.Dimension('batch', FLAGS.batch_size)
    hidden_dim = mtf.Dimension('hidden', FLAGS.hidden_size)
    io_dim = mtf.Dimension('io', FLAGS.io_size)

    x = mtf.infeed(mesh, features, mtf.TensorShape([batch_dim, io_dim]))
    h = mtf_layers.dense(x, hidden_dim, name='layer1', use_bias=False)
    y = mtf_layers.dense(h, io_dim, name='layer2', use_bias=False)

    loss = mtf.reduce_sum(mtf.square(y - x))
    return y, loss
Example #17
def dot_product_attention(q,
                          k,
                          v,
                          mask,
                          dropout=0.0,
                          dropout_broadcast_dims=None):
    """Dot-product attention.

  Args:
    q: Tensor with shape [..., length_q, depth_k]. Typically leading dimensions
      are [batch, heads].
    k: Tensor with shape [..., length_kv, depth_k]. Leading dimensions must
      match with q.
    v: Tensor with shape [..., length_kv, depth_v]. Leading dimensions must
      match with q.
    mask: mask Tensor (see attention_mask())
    dropout: a float.
    dropout_broadcast_dims: an optional list of mtf.Dimension

  Returns:
    Tensor with shape [..., length_q, depth_v].
  """
    length_kv = k.shape.dims[-2]
    logits_shape = mtf.TensorShape(q.shape.dims[:-1] + [length_kv])
    logits = mtf.einsum([q, k], logits_shape)
    if mask is not None:
        logits += mask
    weights = mtf.softmax(logits, length_kv)
    if dropout != 0.0:
        weights = mtf.dropout(weights,
                              1.0 - dropout,
                              noise_shape=weights.shape -
                              dropout_broadcast_dims)
    depth_v = v.shape.dims[-1]
    outputs_shape = mtf.TensorShape(q.shape.dims[:-1] + [depth_v])
    outputs = mtf.einsum([weights, v], outputs_shape)
    return outputs
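
A direct-call sketch with the shapes from the docstring; the q/k/v tensors are assumptions, and Example #13 below exercises this end to end:

# q: [batch, heads, length_q, depth_k]
# k: [batch, heads, length_kv, depth_k]
# v: [batch, heads, length_kv, depth_v]
o = dot_product_attention(q, k, v, mask=None)  # mask=None disables masking
# o: [batch, heads, length_q, depth_v]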
Example #18
def mnist_model(image, labels, mesh):
    """The model.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32
    mesh: a mtf.Mesh

  Returns:
    logits: a tf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
  """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    rows_dim = mtf.Dimension("rows", 28)
    cols_dim = mtf.Dimension("cols", 28)
    classes_dim = mtf.Dimension("classes", 10)
    hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
    hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)

    x = mtf.infeed(mesh, tf.reshape(image, [-1, 28, 28]),
                   mtf.TensorShape([batch_dim, rows_dim, cols_dim]))
    h1 = mtf_layers.dense(x,
                          hidden_dim1,
                          reduced_dims=[rows_dim, cols_dim],
                          activation=mtf.relu,
                          name="hidden1")
    h2 = mtf_layers.dense(h1, hidden_dim2, activation=mtf.relu, name="hidden2")
    logits = mtf_layers.dense(h2, classes_dim, name="logits")
    if labels is None:
        loss = None
    else:
        labels = mtf.infeed(mesh, labels, mtf.TensorShape([batch_dim]))
        loss = mtf_layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    return logits, loss
Example #19
    def mtf_model_fn(self, features, mesh):
        hparams = self._hparams
        # tf_x = tf.random_uniform([hparams.batch_size, hparams.io_size])
        tf_x = tf.matmul(
            tf.reshape(tf.lin_space(0., 1.0, hparams.batch_size),
                       [hparams.batch_size, 1]),
            tf.reshape(tf.lin_space(0., 1.0, hparams.io_size),
                       [1, hparams.io_size]))
        batch_dim = mtf.Dimension("batch", hparams.batch_size)
        hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
        io_dim = mtf.Dimension("io", hparams.io_size)

        x = mtf.infeed_fully_replicated(mesh, tf_x,
                                        mtf.TensorShape([batch_dim, io_dim]))
        h = mtf_layers.dense(x, hidden_dim, name="layer1", use_bias=False)
        y = mtf_layers.dense(h, io_dim, name="layer2", use_bias=False)
        loss = mtf.reduce_sum(mtf.square(y - x))
        return None, loss
Example #20
def multihead_attention_vars(mesh, heads, io_channels, kv_channels,
                             activation_dtype):
    """Create Parameters for Multihead Attention.

  Args:
    mesh: a Mesh
    heads: a Dimension
    io_channels: a Dimension
    kv_channels: a Dimension
    activation_dtype: a tf.dtype

  Returns:
    q_var: a Tensor with shape [heads, io_channels, kv_channels]
    k_var: a Tensor with shape [heads, io_channels, kv_channels]
    v_var: a Tensor with shape [heads, io_channels, kv_channels]
    o_var: a Tensor with shape [heads, io_channels, kv_channels]
  """
    qkvo = mtf.Dimension("qkvo", 4)
    qk_stddev = (io_channels.size**-0.5) * (kv_channels.size**-0.25)
    v_stddev = io_channels.size**-0.5
    o_stddev = (io_channels.size * heads.size)**-0.5

    def qkvo_initializer(shape,
                         dtype=None,
                         partition_info=None,
                         verify_shape=None):
        del partition_info, verify_shape
        return tf.random_normal(shape, dtype=dtype) * tf.reshape(
            [qk_stddev, qk_stddev, v_stddev, o_stddev], [4, 1, 1, 1])

    var = mtf.get_variable(mesh,
                           "qkvo",
                           mtf.TensorShape(
                               [qkvo, heads, io_channels, kv_channels]),
                           initializer=qkvo_initializer,
                           activation_dtype=activation_dtype)
    q_var, k_var, v_var, o_var = mtf.unstack(var, qkvo)
    return q_var, k_var, v_var, o_var
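
The four projection matrices are stored packed in a single [qkvo=4, heads, io_channels, kv_channels] variable (with per-slice initializer scales) and split apart with mtf.unstack. A hedged sketch of consuming the result; the mesh and dimension objects are assumptions:

q_var, k_var, v_var, o_var = multihead_attention_vars(
    mesh, heads_dim, io_channels_dim, kv_channels_dim, tf.float32)
# Each var: [heads, io_channels, kv_channels]; used in einsums such as
# the q/k/v projections in Example #2.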
Example #21
  def testWeightsNonzero(self):
    inputs = tf.constant([[3, 1, 0], [1, 0, 0]])

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    batch_dim = mtf.Dimension("batch", inputs.shape.as_list()[0])
    channels_dim = mtf.Dimension("channels", inputs.shape.as_list()[1])

    mtf_inputs = mtf.infeed(mesh, inputs,
                            shape=mtf.TensorShape([batch_dim, channels_dim]))
    mtf_outputs = mtf_layers.weights_nonzero(mtf_inputs)
    mesh_impl = placement_mesh_impl.PlacementMeshImpl(
        shape=[1], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.outfeed(mtf_outputs)

    expected_outputs = common_layers.weights_nonzero(inputs)
    tf_group = lowering.copy_masters_to_slices()
    with self.test_session() as sess:
      sess.run(tf_group)
      actual, expected = sess.run([actual_outputs, expected_outputs])

    self.assertAllEqual(actual, expected)
Example #22
  def testMultiheadAttention(self, kv_channels, heads):
    batch = 2
    length = 8
    channels = 3
    query = tf.random_normal([batch, length, channels])

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    batch_dim = mtf.Dimension("batch", batch)
    length_dim = mtf.Dimension("length", length)
    channels_dim = mtf.Dimension("channels", channels)
    kv_channels_dim = mtf.Dimension("kv_channels", kv_channels)
    heads_dim = mtf.Dimension("heads", heads)

    mtf_query = mtf.infeed(
        mesh, query,
        shape=mtf.TensorShape([batch_dim, length_dim, channels_dim]))
    mtf_outputs = mtf_layers.multihead_attention(
        mtf_query,
        memory_antecedent=None,
        mask=None,
        kv_channels=kv_channels_dim,
        heads=heads_dim)
    mesh_impl = placement_mesh_impl.PlacementMeshImpl(
        shape=[1], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.outfeed(mtf_outputs)

    tf_group = lowering.copy_masters_to_slices()
    init = tf.global_variables_initializer()
    with self.test_session() as sess:
      sess.run(init)
      sess.run(tf_group)
      actual = sess.run(actual_outputs)

    self.assertEqual(actual.shape, query.shape)
Example #23
def dense_relu_dense(x,
                     hidden_channels,
                     dropout=0.0,
                     dropout_broadcast_dims=None,
                     name=None):
    """Hidden layer with ReLU activation followed by linear projection.

  The output has the same number of channels as the input.

  Args:
    x: a mtf.Tensor
    hidden_channels: a mtf.Dimension - channels in the hidden layer
    dropout: an optional float
    dropout_broadcast_dims: an optional list of mtf.Dimension
    name: an optional string

  Returns:
    a mtf.Tensor with the same shape as x.
  """
    with tf.variable_scope(name, default_name="dense_relu_dense"):
        io_channels = x.shape.dims[-1]
        stddev = (hidden_channels.size * io_channels.size)**-0.25
        io = mtf.Dimension("io", 2)
        w = mtf.get_variable(
            x.mesh,
            "kernel",
            mtf.TensorShape([io, io_channels, hidden_channels]),
            initializer=tf.random_normal_initializer(stddev=stddev),
            activation_dtype=x.dtype)
        wi, wo = mtf.unstack(w, io)
        h = mtf.relu(mtf.einsum([x, wi]))
        if dropout != 0.0:
            h = mtf.dropout(h,
                            1.0 - dropout,
                            noise_shape=h.shape - dropout_broadcast_dims)
        return mtf.einsum([h, wo])
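
A call sketch mirroring the test in Example #10; the hidden size and the broadcast dimension (length_dim) are illustrative assumptions:

hidden_dim = mtf.Dimension("hidden", 512)  # hypothetical size
y = dense_relu_dense(x, hidden_channels=hidden_dim,
                     dropout=0.1,
                     dropout_broadcast_dims=[length_dim])
# y has the same shape as x.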
Example #24
    def grow_topk(i, alive_seq, alive_log_probs, states=None):
        r"""Inner beam search loop.

    This function takes the current alive sequences, and grows them to topk
    sequences where k = 2*beam. We use 2*beam because, we could have beam_size
    number of sequences that might hit <EOS> and there will be no alive
    sequences to continue. With 2*beam_size, this will not happen. This relies
    on the assumption the vocab size is > beam size. If this is true, we'll
    have at least beam_size non <EOS> extensions if we extract the next top
    2*beam words.
    Length penalty is given by = (5+len(decode)/6) ^ -\alpha. Pls refer to
    https://arxiv.org/abs/1609.08144.

    Args:
      i: loop index
      alive_seq: Topk sequences decoded so far [batch, beam, length]
      alive_log_probs: probabilities of these sequences. [batch, beam]
      states: optional list of mtf.Tensor
    Returns:
      Tuple of
        (Topk sequences extended by the next word,
         The log probs of these sequences,
         The scores with length penalty of these sequences,
         Flags indicating which of these sequences have finished decoding,
         list of transformed decoding states)
    """
        logits, new_states = logits_fn(i, alive_seq, states)
        batch_dim, beam_dim, vocab_dim = logits.shape.dims

        # Convert logits to normalized log probs
        candidate_log_probs = mtf.log_softmax(logits, vocab_dim)

        # Multiply the probabilities by the current probabilities of the beam.
        # (batch_size, beam_size, vocab_size) + (batch_size, beam_size, 1)
        log_probs = candidate_log_probs + alive_log_probs

        length_penalty = mtf.pow(((5. + mtf.to_float(i + 1)) / 6.), alpha)

        curr_scores = log_probs / length_penalty

        # scores have shape [batch, beam, vocab]
        beam_and_vocab_dim = mtf.Dimension("beam_and_vocab",
                                           beam_dim.size * vocab_dim.size)
        flat_shape = mtf.TensorShape([batch_dim, beam_and_vocab_dim])
        double_beam = mtf.Dimension("double_beam", beam_dim.size * 2)
        # Flatten out (beam_size, vocab_size) probs in to a list of possibilities
        flat_curr_scores = mtf.reshape(curr_scores, flat_shape)

        top_ids, top_scores = mtf.top_k(flat_curr_scores,
                                        reduced_dim=beam_and_vocab_dim,
                                        new_dim=double_beam)

        # Recovering the log probs because we will need to send them back
        top_log_probs = top_scores * length_penalty

        # Work out what beam the top probs are in.
        top_beam_index = top_ids // vocab_dim.size
        top_ids %= vocab_dim.size  # Unflatten the ids

        def my_gather(tensor):
            return mtf.gather(tensor,
                              top_beam_index,
                              beam_dim,
                              output_shape=mtf.TensorShape([
                                  double_beam if d == beam_dim else d
                                  for d in tensor.shape.dims
                              ]))

        # Gather up the most probable 2*beams both for the ids and finished_in_alive
        # bools
        top_seq = my_gather(alive_seq)

        if states:
            states = [my_gather(state) for state in new_states]

        # Append the most probable alive
        top_seq += top_ids * mtf.one_hot(i, length_dim, dtype=tf.int32)
        top_finished = mtf.equal(top_ids, eos_id)

        return top_seq, top_log_probs, top_scores, top_finished, states
Example #25
def moe_v0(inputs,
           hidden_dim,
           output_dim,
           experts_dim,
           loss_coef=1e-3,
           overhead=1.0):
    """Local mixture of experts that works well on TPU.

  See https://arxiv.org/abs/1701.06538

  There are num_experts expert networks, each containing a relu-activated
  hidden layer of size hidden_size, followed by an output projection.

  The number of parameters is thus:
    num_experts * (input_size * hidden_size + hidden_size * output_size)

  The input is 3d: [batch, length, depth], consisting of the representations
  of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-2 experts.  The expert
  choices and the combination weights are determined by a learned gating
  function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding)
    that each sequence sends the same number of elements to each expert.

    It would make more sense to enforce this equality over the entire batch,
    as opposed to on individual sequences.  This would allow more freedom
    for individual sequences to be unbalanced.  Unfortunately, that would
    slow down our hacked-up gather-by-matmul implementation.

    TODO(noam): There is no real reason for a single sequence to be the unit
      of equal allocation.  Reshaping the inputs would allow us to pick a
      different unit of equal allocation.

  TODO(noam): Factor this code better.  We want to be able to substitute
  different code for the experts themselves.  We also want to integrate this
  gating/dispatching logic into multi-device mixtures-of-experts.

  Args:
    inputs: a mtf.Tensor with shape [batch_dim, length_dim, input_dim]
    hidden_dim: a mtf.Dimension
    output_dim: a mtf.Dimension
    experts_dim: a mtf.Dimension
    loss_coef: a float scalar
    overhead: multiplicative factor of how much spare capacity to assign

  Returns:
    outputs: a Tensor with shape [batch_dim, length_dim, output_dim]
    loss: a mtf scalar
  """
    batch_dim, length_dim, input_dim = inputs.shape.dims

    # Each sequence sends expert_capacity positions to each expert.
    expert_capacity = min(
        length_dim.size,
        int((length_dim.size * 2 * overhead) / experts_dim.size))
    expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)

    experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
    batch_dim_unsplit = mtf.Dimension("batch_unsplit", batch_dim.size)

    # This is the learned gating function.
    # shape = [batch_dim, length_dim, experts_dim_unsplit]
    gates = mtf.softmax(dense(inputs, experts_dim_unsplit),
                        experts_dim_unsplit)

    assignment_shape = mtf.TensorShape(
        [batch_dim, length_dim, experts_dim_unsplit, expert_capacity_dim])

    backward_assignment = mtf.slicewise(functools.partial(
        _truncated_top_2_gating, expert_capacity=expert_capacity), [gates],
                                        output_shape=assignment_shape,
                                        splittable_dims=[batch_dim],
                                        name="backward_assignment")

    forward_assignment = mtf.cast(mtf.cast(backward_assignment, tf.bool),
                                  inputs.dtype)

    # put num_experts dimension first to make split easier in alltoall
    expert_inputs = mtf.einsum([inputs, forward_assignment],
                               mtf.TensorShape([
                                   experts_dim_unsplit, batch_dim,
                                   expert_capacity_dim, input_dim
                               ]))

    expert_inputs = mtf.reshape(
        expert_inputs,
        mtf.TensorShape(
            [experts_dim, batch_dim_unsplit, expert_capacity_dim, input_dim]))

    # Now feed the expert inputs through the experts.
    h = dense(expert_inputs,
              hidden_dim,
              expert_dims=[experts_dim],
              activation=mtf.relu,
              name="x0")
    expert_output = dense(h, output_dim, expert_dims=[experts_dim], name="x1")

    expert_output = mtf.reshape(
        expert_output,
        mtf.TensorShape(
            [experts_dim_unsplit, batch_dim, expert_capacity_dim, output_dim]))

    output = mtf.einsum([expert_output, backward_assignment],
                        mtf.TensorShape([batch_dim, length_dim, output_dim]))

    importance = mtf.reduce_sum(backward_assignment,
                                output_shape=mtf.TensorShape(
                                    [batch_dim, experts_dim_unsplit]))

    loss = cv_squared(importance) * loss_coef
    return output, loss
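
A hedged sketch of wiring the auxiliary loss into training; the tensor and dimension names are illustrative, and the key point is that the returned loss must be added to the model's training loss:

outputs, aux_loss = moe_v0(inputs,            # [batch, length, input_dim]
                           hidden_dim=hidden_dim,
                           output_dim=output_dim,
                           experts_dim=experts_dim)
total_loss = task_loss + aux_loss  # aux_loss is already scaled by loss_coef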
Example #26
    def estimator_model_fn(cls,
                           hparams,
                           features,
                           labels,
                           mode,
                           config=None,
                           params=None,
                           decode_hparams=None,
                           use_tpu=False,
                           xla_compile=False):
        hparams = copy.deepcopy(hparams)
        hparams.use_tpu = use_tpu
        # merge decode_hparams into hparams if present
        if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
            for k, v in six.iteritems(decode_hparams.values()):
                if hasattr(hparams, k) and getattr(hparams, k) != v:
                    tf.logging.warning(
                        "Overriding hparams.%s with %s from decode_hparams" %
                        (k, v))
                setattr(hparams, k, v)

        # Instantiate model
        data_parallelism = None
        if not use_tpu and config:
            data_parallelism = config.data_parallelism
        model = cls(hparams,
                    mode,
                    data_parallelism=data_parallelism,
                    decode_hparams=decode_hparams)

        global_step = tf.train.get_global_step()
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")

        mesh_shape = mtf.parse_mesh_shape(hparams.mesh_shape)
        mesh_size = mtf.list_product(mesh_shape)
        if use_tpu:
            mesh_devices = [""] * mesh_size
            mesh_impl = simd_mesh_impl.SimdMeshImpl(
                mesh_shape, mtf.parse_layout(hparams.layout), mesh_devices,
                params["context"].device_assignment)
        else:
            if len(data_parallelism.ps_devices) == 1:
                mesh_devices = [""] * mesh_size
            else:
                assert len(data_parallelism.ps_devices) == mesh_size
                mesh_devices = data_parallelism.ps_devices
            mesh_impl = placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, mtf.parse_layout(hparams.layout), mesh_devices)

        # PREDICT mode
        if mode == tf.estimator.ModeKeys.PREDICT:
            return model.estimator_spec_predict(features, mesh, mesh_impl,
                                                use_tpu)

        logits, loss = model.mtf_model_fn(features, mesh)
        if use_tpu and logits is not None:
            logits = mtf.anonymize(logits)

        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])
            lr = learning_rate.learning_rate_schedule(hparams)
            mtf_lr = mtf.infeed(mesh, tf.convert_to_tensor(lr,
                                                           dtype=tf.float32),
                                mtf.TensorShape([]))
            optimizer = mtf_optimize.make_optimizer(hparams, mtf_lr)
            update_ops = []
            for grad, var in zip(var_grads, graph.trainable_variables):
                update_ops.extend(optimizer.apply_grad(grad, var))

        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        tf_loss = lowering.outfeed(loss)
        tf_loss = tf.to_float(tf_loss)
        if logits is not None and mode != tf.estimator.ModeKeys.TRAIN:
            tf_logits = lowering.outfeed(logits)

        if mode == tf.estimator.ModeKeys.TRAIN:
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
            train_op = tf.group(tf_update_ops)

        with mtf_utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                hparams.model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener])

        # EVAL mode
        if mode == tf.estimator.ModeKeys.EVAL:
            tf_logits = lowering.outfeed(logits)
            return model.estimator_spec_eval(features, tf_logits, labels,
                                             tf_loss, restore_hook, use_tpu)

        if use_tpu:
            _remove_summaries()
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])
        else:
            return tf.estimator.EstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_chief_hooks=[restore_hook, saver_hook])
Example #27
    def mtf_model_fn(self, features, mesh):
        features = copy.copy(features)
        tf.logging.info("features = %s" % features)
        hparams = self._hparams
        activation_dtype = self.set_activation_type()

        # We assume fixed vocab size for targets
        targets_vocab_size = self._problem_hparams.target_modality._vocab_size  # pylint: disable=protected-access
        targets = tf.to_int32(features["targets"])

        # Image preprocessing, reshape into a 1D sequence and shift right.
        length = hparams.img_len * hparams.img_len * hparams.num_channels
        targets = tf.reshape(targets, [hparams.batch_size, length])
        shifted_targets = common_layers.shift_right_2d(targets)

        # Declare all the dimensions
        model_dim = mtf.Dimension("model", hparams.hidden_size)
        batch_dim = mtf.Dimension("batch", hparams.batch_size)
        length_dim = mtf.Dimension("length", length)
        filter_dim = mtf.Dimension("filter_size", hparams.filter_size)
        kv_channels = mtf.Dimension("kv_channels", hparams.d_kv)
        heads = mtf.Dimension("heads", hparams.num_heads)

        def infeed_to_batch_by_length(x, name):
            return mtf.infeed(mesh,
                              x,
                              mtf.TensorShape([batch_dim, length_dim]),
                              name=name)

        def layer_prepostprocess_dropout(x):
            return mtf.dropout(
                x,
                keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
                noise_shape=mtf.TensorShape([batch_dim, model_dim]))

        targets = infeed_to_batch_by_length(targets, "targets")
        shifted_targets = infeed_to_batch_by_length(shifted_targets,
                                                    "shifted_targets")

        extra_losses = []

        # TODO(nikip): Verify conditional.
        if self.has_input and not hparams.unconditional:
            vocab_size = hparams.num_classes
            inputs_vocab_dim = mtf.Dimension("vocab", vocab_size)
            inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
            inputs = infeed_to_batch_by_length(inputs, "inputs")

            # Input embeddings
            inputs, _ = mtf_layers.embedding(inputs,
                                             inputs_vocab_dim,
                                             model_dim,
                                             activation_dtype=activation_dtype,
                                             name="inputs_embedding")

        # Create targets content and position embeddings.
        targets_position = mtf.range(mesh, length_dim, dtype=tf.int32)
        targets_vocab_size = 256 * hparams.num_channels
        targets_vocab_dim = mtf.Dimension("vocab", targets_vocab_size)
        outputs_vocab_dim = mtf.Dimension("output_vocab", 256)

        # Create embedding var for targets and positions and do a gather.
        targets_embedding_var = mtf.get_variable(
            mesh,
            "targets_embedding",
            mtf.TensorShape([targets_vocab_dim, model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=activation_dtype)

        positional_embedding_var = mtf.get_variable(
            mesh,
            "positional_embedding",
            mtf.TensorShape([targets_vocab_dim, model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=activation_dtype)
        x = (mtf.gather(targets_embedding_var, shifted_targets,
                        targets_vocab_dim) +
             mtf.gather(positional_embedding_var, targets_position,
                        targets_vocab_dim))

        # Image Transformer Decoder
        # [ self attention - ffn - residual + dropout] x n
        for layer in range(hparams.num_decoder_layers):
            layer_name = "decoder_layer_%d" % layer
            with tf.variable_scope(layer_name):
                # Self attention layer
                x += layer_prepostprocess_dropout(
                    mtf_layers.masked_local_attention_1d(mtf_layers.layer_norm(
                        x, model_dim, name="layer_norm_self_att"),
                                                         None,
                                                         kv_channels,
                                                         heads,
                                                         name="self_att"))
                # ffn layer
                x += layer_prepostprocess_dropout(
                    mtf_layers.dense_relu_dense(
                        mtf_layers.layer_norm(x,
                                              model_dim,
                                              name="layer_norm_ffn"),
                        filter_dim,
                        hparams.dropout,
                        dropout_broadcast_dims=[length_dim]))

        x = mtf_layers.layer_norm(x,
                                  model_dim,
                                  name="decoder_final_layer_norm")

        # Calculate the logits and loss.
        logits = mtf_layers.dense(x, outputs_vocab_dim, name="logits")
        soft_targets = mtf.one_hot(targets,
                                   outputs_vocab_dim,
                                   dtype=activation_dtype)
        loss = mtf_layers.softmax_cross_entropy_with_logits(
            logits, soft_targets, outputs_vocab_dim)

        loss = mtf.reduce_mean(loss)
        for l in extra_losses:
            loss += l
        return logits, loss
Example #28
def infeed_to_batch_by_length(x, name):
    return mtf.infeed(mesh,
                      x,
                      mtf.TensorShape([batch_dim, length_dim]),
                      name=name)
Example #29
def layer_prepostprocess_dropout(x):
    return mtf.dropout(
        x,
        keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
        noise_shape=mtf.TensorShape([batch_dim, model_dim]))
Example #30
def masked_local_attention_1d(query_antecedent,
                              memory_antecedent,
                              kv_channels,
                              heads,
                              block_length=128,
                              name=None):
    """Attention to the source position and a neighborhood to the left of it.

  The sequence is divided into blocks of length block_length.
  Attention for a given query position can only see memory positions
  less than or equal to the query position, in the corresponding block
  and the previous block.

  Args:
    query_antecedent: a mtf.Tensor with shape [batch, query_length, io_channels]
    memory_antecedent: a mtf.Tensor with shape
      [batch, memory_length, io_channels] (optional). Currently, memory_length
      must have the same size as query_length, but a different name.
    kv_channels: a mtf.Dimension (the size of the key and value vectors)
    heads: a mtf.Dimension (the number of heads)
    block_length: an integer, representing receptive fields for attention.
    name: an optional string.

  Returns:
    a Tensor of shape [batch, query_length, io_channels]

  Raises:
    ValueError: if channels or depth don't match.
  """
    with tf.variable_scope(name,
                           default_name="multihead_attention",
                           values=[query_antecedent, memory_antecedent]):

        batch, query_length, io_channels = query_antecedent.shape.dims
        q_var, k_var, v_var, o_var = multihead_attention_vars(
            query_antecedent.mesh, heads, io_channels, kv_channels,
            query_antecedent.dtype)

        if memory_antecedent is None:
            memory_antecedent = rename_length_to_memory_length(
                query_antecedent, query_length.name)
        memory_batch, memory_length, memory_channels = memory_antecedent.shape.dims
        if memory_batch != batch:
            raise ValueError("memory batch must equal query batch")
        if memory_channels != io_channels:
            raise ValueError("memory channels must equal query channels")

        # Get query q, keys k and values v.
        q = mtf.einsum([query_antecedent, q_var],
                       mtf.TensorShape(
                           [batch, heads, query_length, kv_channels]))
        k = mtf.einsum([memory_antecedent, k_var],
                       mtf.TensorShape(
                           [batch, heads, memory_length, kv_channels]))
        v = mtf.einsum([memory_antecedent, v_var],
                       mtf.TensorShape(
                           [batch, heads, memory_length, kv_channels]))

        # Let's assume for now we don't have padding and the block length equally
        # divides the memory length.
        block_length = (query_length.size if
                        query_length.size < block_length * 2 else block_length)
        blength = mtf.Dimension("block_length", block_length)
        mlength = mtf.Dimension("mem_block_length", block_length)
        num_blocks = mtf.Dimension("num_blocks",
                                   query_length.size // block_length)

        q = mtf.reshape(
            q,
            mtf.TensorShape([batch, heads, num_blocks, blength, kv_channels]))
        k = mtf.reshape(
            k,
            mtf.TensorShape([batch, heads, num_blocks, mlength, kv_channels]))
        v = mtf.reshape(
            v,
            mtf.TensorShape([batch, heads, num_blocks, mlength, kv_channels]))

        # compute attention for the first query block.
        def first_block_attention():
            """Compute attention for the first block."""
            first_q = mtf.slice(q, 0, 1, num_blocks.name)
            first_k = mtf.slice(k, 0, 1, num_blocks.name)
            first_v = mtf.slice(v, 0, 1, num_blocks.name)
            block = first_q.shape.dims[2]

            first_logits = mtf.einsum(
                [first_q, first_k],
                mtf.TensorShape([batch, heads, block, blength, mlength]))
            weights = mtf.softmax(first_logits, mlength)
            first_output = mtf.einsum(
                [weights, first_v],
                mtf.TensorShape([batch, heads, block, blength, kv_channels]))
            return first_output

        # Attention for first block, since query_length = key_length.
        first_output = first_block_attention()

        # Concatenate two adjacent blocks to compute the overlapping memory block.
        def local(x):
            """Helper function to get memory blocks."""
            prev_block = mtf.slice(x, 0, num_blocks.size - 1, num_blocks.name)
            cur_block = mtf.slice(x, 1, num_blocks.size - 1, num_blocks.name)
            local_block = mtf.concat([prev_block, cur_block], mlength.name)
            return local_block

        local_k = local(k)
        local_v = local(v)
        mblocks = local_k.shape.dims[2]
        mlength = local_k.shape.dims[3]
        # Calculate the causal mask to avoid peeking into the future. We compute
        # this once and reuse it for all blocks since the block_size is known.
        mask = attention_bias_local_block(query_antecedent.mesh, blength,
                                          mlength)

        # Remove the first block from q since we already computed that.
        tail_q = mtf.slice(q, 1, num_blocks.size - 1, num_blocks.name)

        # Compatibility between q and k for rest of the blocks.
        # Shape [batch, heads, num_blocks - 1, block_length, local_length]
        attention = mtf.einsum([tail_q, local_k],
                               mtf.TensorShape(
                                   [batch, heads, mblocks, blength, mlength]))
        attention += mask
        attention = mtf.softmax(attention, mlength)

        # Run attention for rest of the blocks.
        # Shape [batch, heads, num_blocks-1, block_length, kv_channels]
        output = mtf.einsum([attention, local_v],
                            mtf.TensorShape(
                                [batch, heads, mblocks, blength, kv_channels]))
        # Now concatenate the first and rest of the blocks.
        final_output = mtf.concat([first_output, output], num_blocks.name)
        final_output = mtf.reshape(
            final_output,
            mtf.TensorShape([batch, heads, query_length, kv_channels]))
        return mtf.einsum([final_output, o_var],
                          mtf.TensorShape([batch, query_length, io_channels]))
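
A call sketch mirroring the test in Example #15; the query/memory tensors and the dimension objects are assumptions:

y = masked_local_attention_1d(query,    # [batch, query_length, io_channels]
                              memory,   # same sizes, differently named length
                              kv_channels=kv_channels_dim,
                              heads=heads_dim,
                              block_length=4)
# y: [batch, query_length, io_channels]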