Code example #1
    def testWeightsNonzero(self):
        inputs = tf.constant([[3, 1, 0], [1, 0, 0]])

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", inputs.shape.as_list()[0])
        channels_dim = mtf.Dimension("channels", inputs.shape.as_list()[1])

        mtf_inputs = mtf.import_tf_tensor(mesh,
                                          inputs,
                                          shape=mtf.Shape(
                                              [batch_dim, channels_dim]))
        mtf_outputs = mtf_layers.weights_nonzero(mtf_inputs)
        mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                          layout={},
                                                          devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

        expected_outputs = common_layers.weights_nonzero(inputs)
        tf_group = lowering.copy_masters_to_slices()
        self.evaluate(tf_group)
        actual, expected = self.evaluate([actual_outputs, expected_outputs])

        self.assertAllEqual(actual, expected)
Code example #2
    def testLayerNorm(self):
        batch = 2
        channels = 3
        inputs = tf.random_normal([batch, channels])

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", batch)
        channels_dim = mtf.Dimension("channels", channels)

        mtf_inputs = mtf.import_tf_tensor(mesh,
                                          inputs,
                                          shape=mtf.Shape(
                                              [batch_dim, channels_dim]))
        mtf_outputs = mtf_layers.layer_norm(mtf_inputs, dim=channels_dim)
        mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                          layout={},
                                                          devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

        expected_outputs = common_layers.layer_norm(inputs)
        tf_group = lowering.copy_masters_to_slices()
        init = tf.global_variables_initializer()
        self.evaluate(init)
        self.evaluate(tf_group)
        actual, expected = self.evaluate([actual_outputs, expected_outputs])

        self.assertEqual(actual.shape, expected.shape)
Code example #3
    def create_positional_emb_2d(self, targets, max_length_dim, model_dim):
        """Learned 2d positional embedding for images."""
        mesh = targets.mesh
        hparams = self._hparams
        activation_dtype = self.set_activation_type()

        rows_dim = mtf.Dimension("rows", hparams.img_len)
        cols_dim = mtf.Dimension("cols",
                                 hparams.img_len * hparams.num_channels)

        positional_emb_rows_var = mtf.get_variable(
            mesh,
            "positional_emb_rows",
            mtf.Shape([max_length_dim, model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=activation_dtype)
        positional_emb_cols_var = mtf.get_variable(
            mesh,
            "positional_emb_cols",
            mtf.Shape([max_length_dim, model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=activation_dtype)

        targets_position_x = mtf.range(mesh, rows_dim, dtype=tf.int32)
        targets_position_y = mtf.range(mesh, cols_dim, dtype=tf.int32)
        position_x = mtf.broadcast(
            mtf.gather(positional_emb_rows_var, targets_position_x,
                       max_length_dim),
            mtf.Shape([rows_dim, cols_dim, model_dim]))

        position_y = mtf.broadcast(
            mtf.gather(positional_emb_cols_var, targets_position_y,
                       max_length_dim),
            mtf.Shape([rows_dim, cols_dim, model_dim]))
        return position_x + position_y
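The row and column tables are broadcast to a common [rows, cols, model] shape and summed, so each image position gets a row component plus a column component. A minimal NumPy sketch of that broadcast-add (not from the project; the sizes are made up for illustration):

import numpy as np

rows, cols, model = 4, 12, 8  # e.g. img_len and img_len * num_channels
emb_rows = np.random.randn(rows, model)  # stands in for the gathered rows table
emb_cols = np.random.randn(cols, model)  # stands in for the gathered cols table

# Broadcast each table across the other spatial axis, then add.
position = emb_rows[:, None, :] + emb_cols[None, :, :]
assert position.shape == (rows, cols, model)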
Code example #4
  def testDense(self, units, use_bias):
    batch = 2
    channels = 3
    inputs = tf.random_normal([batch, channels])

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    batch_dim = mtf.Dimension("batch", batch)
    channels_dim = mtf.Dimension("channels", channels)
    depth_dim = mtf.Dimension("depth", units)

    mtf_inputs = mtf.infeed(mesh, inputs,
                            shape=mtf.TensorShape([batch_dim, channels_dim]))
    mtf_outputs = mtf_layers.dense(mtf_inputs,
                                   output_dim=depth_dim,
                                   reduced_dims=[channels_dim],
                                   activation=mtf.relu,
                                   use_bias=use_bias)
    mesh_impl = placement_mesh_impl.PlacementMeshImpl(
        shape=[1], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.outfeed(mtf_outputs)

    expected_outputs = tf.keras.layers.Dense(units=units,
                                             activation=tf.nn.relu,
                                             use_bias=use_bias)(inputs)
    tf_group = lowering.copy_masters_to_slices()
    init = tf.global_variables_initializer()
    with self.test_session() as sess:
      sess.run(init)
      sess.run(tf_group)
      actual, expected = sess.run([actual_outputs, expected_outputs])

    self.assertEqual(actual.shape, expected.shape)
Code example #5
  def testDenseReluDense(self):
    batch = 2
    channels = 3
    hidden = 5
    inputs = tf.random_normal([batch, channels])

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    batch_dim = mtf.Dimension("batch", batch)
    channels_dim = mtf.Dimension("channels", channels)
    hidden_dim = mtf.Dimension("hidden", hidden)

    mtf_inputs = mtf.infeed(mesh, inputs,
                            shape=mtf.TensorShape([batch_dim, channels_dim]))
    mtf_outputs = mtf_layers.dense_relu_dense(mtf_inputs,
                                              hidden_channels=hidden_dim)
    mesh_impl = placement_mesh_impl.PlacementMeshImpl(
        shape=[1], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.outfeed(mtf_outputs)

    tf_group = lowering.copy_masters_to_slices()
    init = tf.global_variables_initializer()
    with self.test_session() as sess:
      sess.run(init)
      sess.run(tf_group)
      actual = sess.run(actual_outputs)

    self.assertEqual(actual.shape, inputs.shape)
Code example #6
class MeshTensorFlowTest(parameterized.TestCase, tf.test.TestCase):

  @parameterized.parameters(
      (mtf.Dimension(name="x", size=5),),
      (("x", 5),),
  )
  def testConvertToDimension(self, inputs):
    dimension = mtf.convert_to_dimension(inputs)
    self.assertEqual(dimension.name, "x")
    self.assertEqual(dimension.size, 5)

  def testConvertToDimensionGenericInputs(self):
    dimension = mtf.convert_to_dimension(None)
    self.assertEqual(dimension, None)
    with self.assertRaises(TypeError):
      mtf.convert_to_dimension(5)

  @parameterized.parameters(
      (mtf.Shape([mtf.Dimension(name="x", size=4),
                  mtf.Dimension(name="y", size=8)]),),
      ("x:4;y:8",),
      ("x:4.y:8",),
      ("x:4 y:8",),
      ("x:4,y:8",),
  )
  def testConvertToShape(self, inputs):
    shape = mtf.convert_to_shape(inputs)
    self.assertEqual(shape, mtf.Shape([mtf.Dimension(name="x", size=4),
                                       mtf.Dimension(name="y", size=8)]))

  def testConvertToShapeGenericInputs(self):
    shape = mtf.convert_to_shape(None)
    self.assertEqual(shape, None)
    with self.assertRaises(ValueError):
      mtf.convert_to_shape("x;4")

  @parameterized.parameters(
      (mtf.LayoutRules([("d_ff", "model"), ("heads", "model")]),),
      ("d_ff:model;heads:model",),
      ("d_ff:model.heads:model",),
      ("d_ff:model heads:model",),
      ("d_ff:model,heads:model",),
      ([("d_ff", "model"), ("heads", "model")],),
  )
  def testConvertToLayoutRules(self, inputs):
    layout_rules = mtf.convert_to_layout_rules(inputs)
    self.assertEqual(
        layout_rules._pairs,
        mtf.LayoutRules([("d_ff", "model"), ("heads", "model")])._pairs)

  def testConvertToLayoutRulesGenericInputs(self):
    with self.assertRaises(ValueError):
      mtf.convert_to_layout_rules("d_ff;heads")
Code example #7
def toy_model(features, mesh):
  """A toy model implemented by mesh tensorlfow."""
  batch_dim = mtf.Dimension('batch', FLAGS.batch_size)
  hidden_dim = mtf.Dimension('hidden', FLAGS.hidden_size)
  io_dim = mtf.Dimension('io', FLAGS.io_size)

  x = mtf.import_tf_tensor(mesh, features, mtf.Shape([batch_dim, io_dim]))
  h = mtf_layers.dense(x, hidden_dim, name='layer1', use_bias=False)
  y = mtf_layers.dense(h, io_dim, name='layer2', use_bias=False)

  loss = mtf.reduce_sum(mtf.square(y - x))
  return y, loss
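A hedged sketch of how toy_model could be lowered and evaluated on a single device, following the PlacementMeshImpl pattern used in the test examples on this page (the imports and FLAGS values are assumed to match those examples):

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")
features = tf.random_normal([FLAGS.batch_size, FLAGS.io_size])
y, loss = toy_model(features, mesh)

# Lower the mesh graph onto a single unsplit device, then export the loss.
mesh_impl = placement_mesh_impl.PlacementMeshImpl(
    shape=[], layout={}, devices=[""])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
tf_loss = lowering.export_to_tf_tensor(loss)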
Code example #8
 def batch_dims(self):
     hparams = self._hparams
     if hparams.outer_batch_size == 0:
         return [mtf.Dimension("batch", hparams.batch_size)]
     else:
         if hparams.batch_size % hparams.outer_batch_size != 0:
             raise ValueError(
                 "hparams.outer_batch_size must divide hparams.batch_size")
         return [
             mtf.Dimension("outer_batch", hparams.outer_batch_size),
             mtf.Dimension("inner_batch",
                           hparams.batch_size // hparams.outer_batch_size)
         ]
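As a concrete illustration (the hparams values here are invented): with batch_size = 64 and outer_batch_size = 8, the batch is factored into two dimensions of size 8 each.

batch_size, outer_batch_size = 64, 8
assert batch_size % outer_batch_size == 0
dims = [mtf.Dimension("outer_batch", outer_batch_size),
        mtf.Dimension("inner_batch", batch_size // outer_batch_size)]
# dims == [Dimension("outer_batch", 8), Dimension("inner_batch", 8)]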
Code example #9
File: mtf_layers.py Project: yangliuy/tensor2tensor
def dense_relu_dense(x,
                     hidden_channels,
                     dropout=0.0,
                     dropout_broadcast_dims=None,
                     name=None):
  """Hidden layer with ReLU activation followed by linear projection.

  The output has the same number of channels as the input.

  Args:
    x: a mtf.Tensor
    hidden_channels: a mtf.Dimension - channels in the hidden layer
    dropout: an optional float
    dropout_broadcast_dims: an optional list of mtf.Dimension
    name: an optional string

  Returns:
    a mtf.Tensor with the same shape as x.
  """
  with tf.variable_scope(name, default_name="dense_relu_dense"):
    io_channels = x.shape.dims[-1]
    stddev = (hidden_channels.size * io_channels.size) ** -0.25
    io = mtf.Dimension("io", 2)
    w = mtf.get_variable(
        x.mesh,
        "kernel",
        mtf.Shape([io, io_channels, hidden_channels]),
        initializer=tf.random_normal_initializer(stddev=stddev),
        activation_dtype=x.dtype)
    wi, wo = mtf.unstack(w, io)
    h = mtf.relu(mtf.einsum([x, wi]))
    if dropout != 0.0:
      h = mtf.dropout(h, 1.0 - dropout,
                      noise_shape=h.shape - dropout_broadcast_dims)
    return mtf.einsum([h, wo])
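Ignoring the mesh machinery, the two einsums compute an ordinary two-matmul feed-forward block. A NumPy sketch of the same computation (dropout omitted; in mtf both kernel slices share shape [io_channels, hidden_channels] and einsum contracts the named dimensions, so the second matrix below is written transposed for plain matmul):

import numpy as np

batch, io_ch, hidden_ch = 2, 3, 5
x = np.random.randn(batch, io_ch)
wi = np.random.randn(io_ch, hidden_ch)  # stands in for the first kernel slice
wo = np.random.randn(hidden_ch, io_ch)  # second slice, transposed

h = np.maximum(x @ wi, 0.0)  # mtf.relu(mtf.einsum([x, wi]))
out = h @ wo                 # mtf.einsum([h, wo])
assert out.shape == x.shape  # same number of channels in and out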
Code example #10
File: mtf_layers.py Project: yangliuy/tensor2tensor
def multihead_attention_vars(
    mesh, heads, io_channels, kv_channels, activation_dtype):
  """Create Parameters for Multihead Attention.

  Args:
    mesh: a Mesh
    heads: a Dimension
    io_channels: a Dimension
    kv_channels: a Dimension
    activation_dtype: a tf.dtype

  Returns:
    q_var: a Tensor with shape [heads, io_channels, kv_channels]
    k_var: a Tensor with shape [heads, io_channels, kv_channels]
    v_var: a Tensor with shape [heads, io_channels, kv_channels]
    o_var: a Tensor with shape [heads, io_channels, kv_channels]
  """
  qkvo = mtf.Dimension("qkvo", 4)
  qk_stddev = (io_channels.size ** -0.5) * (kv_channels.size ** -0.25)
  v_stddev = io_channels.size ** -0.5
  o_stddev = (io_channels.size * heads.size) ** -0.5
  def qkvo_initializer(shape,
                       dtype=None,
                       partition_info=None,
                       verify_shape=None):
    del partition_info, verify_shape
    return tf.random_normal(shape, dtype=dtype) * tf.reshape(
        [qk_stddev, qk_stddev, v_stddev, o_stddev], [4, 1, 1, 1])
  var = mtf.get_variable(
      mesh, "qkvo", mtf.Shape([qkvo, heads, io_channels, kv_channels]),
      initializer=qkvo_initializer, activation_dtype=activation_dtype)
  q_var, k_var, v_var, o_var = mtf.unstack(var, qkvo)
  return q_var, k_var, v_var, o_var
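The four parameter matrices live in one stacked variable, and the initializer gives each slice its own standard deviation by broadcasting a [4, 1, 1, 1] scale vector over the stacked noise. A NumPy sketch of that per-slice scaling:

import numpy as np

heads, io_ch, kv_ch = 2, 8, 4
qk_stddev = (io_ch ** -0.5) * (kv_ch ** -0.25)
v_stddev = io_ch ** -0.5
o_stddev = (io_ch * heads) ** -0.5

noise = np.random.randn(4, heads, io_ch, kv_ch)
scales = np.array([qk_stddev, qk_stddev, v_stddev, o_stddev]).reshape(4, 1, 1, 1)
var = noise * scales  # q, k, v, o slices each scaled by their own stddev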
Code example #11
File: mtf_layers.py Project: sougata09/tensor2tensor
def attention_bias_local_block(mesh,
                               block_length,
                               memory_length,
                               dtype=tf.int32):
    """Bias for attention for local blocks where attention to right is disallowed.

  Create the bias matrix by using two separate masks, one for the memory part
  which doesn't overlap with the query and second which interacts with the query
  and should be disallowed to look to the right of the current query position.

  Args:
    mesh: a MeshTensorflow object
    block_length: a mtf.Dimension
    memory_length: a mtf.Dimension
    dtype: a tf.dtype

  Returns:
    a mtf.Tensor with shape [block_length, memory_length]
  """
    memory_length = mtf.Dimension(memory_length.name, block_length.size)
    memory_mask = mtf.zeros(mesh, [block_length, memory_length], dtype=dtype)

    mask = mtf.cast(mtf.less(mtf.range(mesh, block_length, dtype=dtype),
                             mtf.range(mesh, memory_length, dtype=dtype)),
                    dtype=dtype)
    mask = mtf.cast(mtf.concat([memory_mask, mask], memory_length.name),
                    dtype=tf.float32) * -1e9
    return mask
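For a small block the returned bias is easy to visualize: the memory half is all zeros (always attendable) and the query half is -1e9 strictly above the diagonal (future positions masked). A NumPy analogue for block_length = 3:

import numpy as np

B = 3  # block_length
memory_mask = np.zeros((B, B), dtype=np.float32)  # previous block: fully visible
i = np.arange(B)[:, None]  # query position within the block
j = np.arange(B)[None, :]  # memory position within the block
causal = (i < j).astype(np.float32)  # 1 strictly above the diagonal
bias = np.concatenate([memory_mask, causal], axis=1) * -1e9
# bias == [[0, 0, 0,  0, -1e9, -1e9],
#          [0, 0, 0,  0,  0,   -1e9],
#          [0, 0, 0,  0,  0,    0  ]]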
Code example #12
def _concat_equal_sizes(xs, dim, new_dim_name):
    axis = xs[0].shape.dims.index(dim)
    ret = mtf.stack(xs, "tmp_concat", axis)
    new_shape = mtf.TensorShape(
        xs[0].shape.dims[:axis] +
        [mtf.Dimension(new_dim_name, dim.size * len(xs))] +
        xs[0].shape.dims[axis + 1:])
    return mtf.reshape(ret, new_shape)
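Because every piece has the same size along dim, stacking on a temporary axis and then merging it into the neighboring axis is equivalent to a concatenation. A NumPy sketch of the identity this helper relies on:

import numpy as np

a = np.arange(6).reshape(2, 3)
b = np.arange(6, 12).reshape(2, 3)

stacked = np.stack([a, b], axis=1)  # temporary axis at the position of dim
merged = stacked.reshape(2, 2 * 3)  # merge it into the concatenated dimension
assert (merged == np.concatenate([a, b], axis=1)).all()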
Code example #13
    def mtf_model_fn(self, features, mesh):
        hparams = self._hparams
        # tf_x = tf.random_uniform([hparams.batch_size, hparams.io_size])
        tf_x = tf.matmul(
            tf.reshape(tf.lin_space(0., 1.0, hparams.batch_size),
                       [hparams.batch_size, 1]),
            tf.reshape(tf.lin_space(0., 1.0, hparams.io_size),
                       [1, hparams.io_size]))
        batch_dim = mtf.Dimension("batch", hparams.batch_size)
        hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
        io_dim = mtf.Dimension("io", hparams.io_size)

        x = mtf.infeed_fully_replicated(mesh, tf_x,
                                        mtf.TensorShape([batch_dim, io_dim]))
        h = mtf_layers.dense(x, hidden_dim, name="layer1", use_bias=False)
        y = mtf_layers.dense(h, io_dim, name="layer2", use_bias=False)
        loss = mtf.reduce_sum(mtf.square(y - x))
        return None, loss
Code example #14
    def test_variable_placer(self):
        sizes = [100, 0, 0, 0]
        device_list = ['cpu:0', 'cpu:1', 'cpu:2', 'cpu:3']

        with tf.Graph().as_default() as g:
            var_placer = mtf_utils.BalancedVariablePlacer(device_list, sizes)
            graph = mtf.Graph()
            mesh = mtf.Mesh(graph, 'my_mesh', var_placer)

            hidden_dim = mtf.Dimension('hidden', 10)
            output_dim = mtf.Dimension('output_feature', 10)

            for i in xrange(5):
                # Each variable takes 400 Bytes, and will be placed from cpu:1.
                mtf.get_variable(mesh, 'w{}'.format(i),
                                 [hidden_dim, output_dim])

            for i in xrange(5):
                var = g.get_tensor_by_name('w{}:0'.format(i))
                device = (i + 1) % len(device_list)
                self.assertEqual('cpu:{}'.format(device), var.device)
Code example #15
File: mnist.py Project: sumehta/tensor2tensor
def mnist_model(image, labels, mesh):
  """The model.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32
    mesh: a mtf.Mesh

  Returns:
    logits: a tf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
  """
  batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
  rows_dim = mtf.Dimension("rows", 28)
  cols_dim = mtf.Dimension("cols", 28)
  classes_dim = mtf.Dimension("classes", 10)
  hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
  hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)

  x = mtf.import_tf_tensor(mesh, tf.reshape(image, [-1, 28, 28]),
                           mtf.Shape([batch_dim, rows_dim, cols_dim]))
  h1 = mtf_layers.dense(
      x, hidden_dim1, reduced_dims=[rows_dim, cols_dim],
      activation=mtf.relu, name="hidden1")
  h2 = mtf_layers.dense(
      h1, hidden_dim2, activation=mtf.relu, name="hidden2")
  logits = mtf_layers.dense(h2, classes_dim, name="logits")
  if labels is None:
    loss = None
  else:
    labels = mtf.import_tf_tensor(mesh, labels, mtf.Shape([batch_dim]))
    loss = mtf_layers.softmax_cross_entropy_with_logits(
        logits, mtf.one_hot(labels, classes_dim), classes_dim)
    loss = mtf.reduce_mean(loss)
  return logits, loss
Code example #16
    def testMultiheadAttention(self, kv_channels, heads):
        batch = 2
        length = 8
        channels = 3
        query = tf.random_normal([batch, length, channels])

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", batch)
        length_dim = mtf.Dimension("length", length)
        channels_dim = mtf.Dimension("channels", channels)
        kv_channels_dim = mtf.Dimension("kv_channels", kv_channels)
        heads_dim = mtf.Dimension("heads", heads)

        mtf_query = mtf.import_tf_tensor(
            mesh,
            query,
            shape=mtf.Shape([batch_dim, length_dim, channels_dim]))
        mtf_outputs = mtf_layers.multihead_attention(
            mtf_query,
            memory_antecedent=None,
            mask=None,
            kv_channels=kv_channels_dim,
            heads=heads_dim)
        mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                          layout={},
                                                          devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

        tf_group = lowering.copy_masters_to_slices()
        init = tf.global_variables_initializer()
        with self.test_session() as sess:
            sess.run(init)
            sess.run(tf_group)
            actual = sess.run(actual_outputs)

        self.assertEqual(actual.shape, query.shape)
Code example #17
    def testDotProductAttention(self, batch, heads, length_q, length_kv,
                                depth_k, depth_v):
        query = tf.random_normal([batch, heads, length_q, depth_k])
        key = tf.random_normal([batch, heads, length_kv, depth_k])
        value = tf.random_normal([batch, heads, length_kv, depth_v])

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", batch)
        heads_dim = mtf.Dimension("heads", heads)
        length_q_dim = mtf.Dimension("length_q", length_q)
        length_kv_dim = mtf.Dimension("length_kv", length_kv)
        depth_k_dim = mtf.Dimension("depth_k", depth_k)
        depth_v_dim = mtf.Dimension("depth_v", depth_v)

        mtf_query = mtf.import_tf_tensor(
            mesh,
            query,
            shape=mtf.Shape([batch_dim, heads_dim, length_q_dim, depth_k_dim]))
        mtf_key = mtf.import_tf_tensor(
            mesh,
            key,
            shape=mtf.Shape([batch_dim, heads_dim, length_kv_dim,
                             depth_k_dim]))
        mtf_value = mtf.import_tf_tensor(
            mesh,
            value,
            shape=mtf.Shape([batch_dim, heads_dim, length_kv_dim,
                             depth_v_dim]))
        mtf_outputs = mtf_layers.dot_product_attention(mtf_query,
                                                       mtf_key,
                                                       mtf_value,
                                                       mask=None)
        mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                          layout={},
                                                          devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

        tf_group = lowering.copy_masters_to_slices()
        init = tf.global_variables_initializer()
        with self.test_session() as sess:
            sess.run(init)
            sess.run(tf_group)
            actual = sess.run(actual_outputs)

        self.assertEqual(actual.shape, (batch, heads, length_q, depth_v))
Code example #18
    def testMaskedLocalAttention1D(self, batch, length, io_channels,
                                   kv_channels, heads, block_length):
        length_q = length
        length_m = length
        query = tf.random_normal([batch, length_q, io_channels])
        memory = tf.random_normal([batch, length_m, io_channels])

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", batch)
        length_q_dim = mtf.Dimension("length_q", length_q)
        length_m_dim = mtf.Dimension("length_m", length_m)
        io_channels_dim = mtf.Dimension("io_channels", io_channels)
        kv_channels_dim = mtf.Dimension("kv_channels", kv_channels)
        heads_dim = mtf.Dimension("heads", heads)

        mtf_query = mtf.import_tf_tensor(
            mesh,
            query,
            shape=mtf.Shape([batch_dim, length_q_dim, io_channels_dim]))
        mtf_memory = mtf.import_tf_tensor(
            mesh,
            memory,
            shape=mtf.Shape([batch_dim, length_m_dim, io_channels_dim]))
        mtf_outputs = mtf_layers.masked_local_attention_1d(
            mtf_query,
            mtf_memory,
            kv_channels=kv_channels_dim,
            heads=heads_dim,
            block_length=block_length)
        mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                          layout={},
                                                          devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

        tf_group = lowering.copy_masters_to_slices()
        init = tf.global_variables_initializer()
        with self.test_session() as sess:
            sess.run(init)
            sess.run(tf_group)
            actual = sess.run(actual_outputs)

        self.assertEqual(actual.shape, (batch, length_q, io_channels))
Code example #19
  def testMeshImpl(self):
    shape = mtf.Shape([mtf.Dimension("batch", 4),
                       mtf.Dimension("model", 8)])
    layout_rules = mtf.LayoutRules([("batch", "batch"),
                                    ("d_ff", "model"),
                                    ("heads", "model")])
    mesh_impl = mtf.MeshImpl(shape=shape, layout_rules=layout_rules)
    self.assertEqual(mesh_impl.shape, shape)
    self.assertEqual(mesh_impl.ndims, len(shape))
    self.assertEqual(mesh_impl.layout_rules, layout_rules)
    self.assertEqual(mesh_impl.size, shape.size)
    self.assertTrue(mesh_impl.supports_control_dependencies)

    batch = mtf.Dimension("batch", 128)
    length = mtf.Dimension("length", 500)
    d_ff = mtf.Dimension("d_ff", 2048)
    heads = mtf.Dimension("heads", 8)
    self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(batch), 0)
    self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(d_ff), 1)
    self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(heads), 1)
    self.assertEqual(mesh_impl.tensor_layout(mtf.Shape([batch, length, d_ff])),
                     mtf.TensorLayout([0, None, 1]))
Code example #20
File: mtf_layers.py Project: repoloper/tensor2tensor
def moe_v0(inputs,
           hidden_dim,
           output_dim,
           experts_dim,
           loss_coef=1e-3,
           overhead=1.0):
    """Local mixture of experts that works well on TPU.

  See https://arxiv.org/abs/1701.06538

  There are num_experts expert networks, each containing a relu-activated
  hidden layer of size hidden_size, followed by an output projection.

  The number of parameters is thus:
    num_experts * (input_size * hidden_size + hidden_size * output_size)

  The input is 3d: [batch, length, depth], consisting of the representations
  of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-2 experts.  The expert
  choices and the combination weights are determined by a learned gating
  function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding)
    that each sequence send the same number of elements to each expert.

    It would make more sense to enforce this equality over the entire batch,
    as opposed to on individual sequences.  This would allow more freedom
    for individual sequences to be unbalanced.  Unfortunately, that would
    slow down our hacked-up gather-by-matmul implementation.

    TODO(noam): There is no real reason for a single sequence to be the unit
      of equal allocation.  Reshaping the inputs would allow us to pick a
      different unit of equal allocation.

  TODO(noam): Factor this code better.  We want to be able to substitute
  different code for the experts themselves.  We also want to integrate this
  gating/dispatching logic into multi-device mixtures-of-experts.

  Args:
    inputs: a mtf.Tensor with shape [batch_dim, length_dim, input_dim]
    hidden_dim: a mtf.Dimension
    output_dim: a mtf.Dimension
    experts_dim: a mtf.Dimension
    loss_coef: a float scalar
    overhead: multiplicative factor of how much spare capacity to assign

  Returns:
    outputs: a Tensor with shape [batch_dim, length_dim, output_dim]
    loss: a mtf scalar
  """
    batch_dim, length_dim, input_dim = inputs.shape.dims

    # Each sequence sends expert_capacity positions to each expert.
    expert_capacity = min(
        length_dim.size,
        int((length_dim.size * 2 * overhead) / experts_dim.size))
    expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)

    experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
    batch_dim_unsplit = mtf.Dimension("batch_unsplit", batch_dim.size)

    # This is the learned gating function.
    # shape = [batch_dim, length_dim, experts_dim_unsplit]
    gates = mtf.softmax(dense(inputs, experts_dim_unsplit),
                        experts_dim_unsplit)

    assignment_shape = mtf.TensorShape(
        [batch_dim, length_dim, experts_dim_unsplit, expert_capacity_dim])

    backward_assignment = mtf.slicewise(functools.partial(
        _truncated_top_2_gating, expert_capacity=expert_capacity), [gates],
                                        output_shape=assignment_shape,
                                        splittable_dims=[batch_dim],
                                        name="backward_assignment")

    forward_assignment = mtf.cast(mtf.cast(backward_assignment, tf.bool),
                                  inputs.dtype)

    # put num_experts dimension first to make split easier in alltoall
    expert_inputs = mtf.einsum([inputs, forward_assignment],
                               mtf.TensorShape([
                                   experts_dim_unsplit, batch_dim,
                                   expert_capacity_dim, input_dim
                               ]))

    expert_inputs = mtf.reshape(
        expert_inputs,
        mtf.TensorShape(
            [experts_dim, batch_dim_unsplit, expert_capacity_dim, input_dim]))

    # Now feed the expert inputs through the experts.
    h = dense(expert_inputs,
              hidden_dim,
              expert_dims=[experts_dim],
              activation=mtf.relu,
              name="x0")
    expert_output = dense(h, output_dim, expert_dims=[experts_dim], name="x1")

    expert_output = mtf.reshape(
        expert_output,
        mtf.TensorShape(
            [experts_dim_unsplit, batch_dim, expert_capacity_dim, output_dim]))

    output = mtf.einsum([expert_output, backward_assignment],
                        mtf.TensorShape([batch_dim, length_dim, output_dim]))

    importance = mtf.reduce_sum(backward_assignment,
                                output_shape=mtf.TensorShape(
                                    [batch_dim, experts_dim_unsplit]))

    loss = cv_squared(importance) * loss_coef
    return output, loss
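The expert_capacity formula budgets roughly twice the average per-expert load, scaled by overhead, and never more than the sequence length. A worked example with invented sizes:

length, overhead, num_experts = 1024, 1.0, 16
expert_capacity = min(length, int((length * 2 * overhead) / num_experts))
# min(1024, int(2048 / 16)) == 128 positions per sequence, per expert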
Code example #21
File: mtf_layers.py Project: repoloper/tensor2tensor
def masked_local_attention_1d(query_antecedent,
                              memory_antecedent,
                              kv_channels,
                              heads,
                              block_length=128,
                              name=None):
    """Attention to the source position and a neighborhood to the left of it.

  The sequence is divided into blocks of length block_size.
  Attention for a given query position can only see memory positions
  less than or equal to the query position, in the corresponding block
  and the previous block.

  Args:
    query_antecedent: a mtf.Tensor with shape [batch, query_length, io_channels]
    memory_antecedent: a mtf.Tensor with shape
      [batch, memory_length, io_channels] (optional). Currently, memory_length
      must have the same size as query_length, but a different name.
    kv_channels: a mtf.Dimension (the size of the key and value vectors)
    heads: a mtf.Dimension (the number of heads)
    block_length: an integer, representing receptive fields for attention.
    name: an optional string.

  Returns:
    a Tensor of shape [batch, query_length, io_channels]

  Raises:
    ValueError: if channels or depth don't match.
  """
    with tf.variable_scope(name,
                           default_name="multihead_attention",
                           values=[query_antecedent, memory_antecedent]):

        batch, query_length, io_channels = query_antecedent.shape.dims
        q_var, k_var, v_var, o_var = multihead_attention_vars(
            query_antecedent.mesh, heads, io_channels, kv_channels,
            query_antecedent.dtype)

        if memory_antecedent is None:
            memory_antecedent = rename_length_to_memory_length(
                query_antecedent, query_length.name)
        memory_batch, memory_length, memory_channels = memory_antecedent.shape.dims
        if memory_batch != batch:
            raise ValueError("memory batch must equal query batch")
        if memory_channels != io_channels:
            raise ValueError("memory channels must equal query channels")

        # Get query q, keys k and values v.
        q = mtf.einsum([query_antecedent, q_var],
                       mtf.TensorShape(
                           [batch, heads, query_length, kv_channels]))
        k = mtf.einsum([memory_antecedent, k_var],
                       mtf.TensorShape(
                           [batch, heads, memory_length, kv_channels]))
        v = mtf.einsum([memory_antecedent, v_var],
                       mtf.TensorShape(
                           [batch, heads, memory_length, kv_channels]))

        # Let's assume for now we don't have padding and the block length equally
        # divides the memory length.
        block_length = (query_length.size if
                        query_length.size < block_length * 2 else block_length)
        blength = mtf.Dimension("block_length", block_length)
        mlength = mtf.Dimension("mem_block_length", block_length)
        num_blocks = mtf.Dimension("num_blocks",
                                   query_length.size // block_length)

        q = mtf.reshape(
            q,
            mtf.TensorShape([batch, heads, num_blocks, blength, kv_channels]))
        k = mtf.reshape(
            k,
            mtf.TensorShape([batch, heads, num_blocks, mlength, kv_channels]))
        v = mtf.reshape(
            v,
            mtf.TensorShape([batch, heads, num_blocks, mlength, kv_channels]))

        # compute attention for the first query block.
        def first_block_attention():
            """Compute attention for the first block."""
            first_q = mtf.slice(q, 0, 1, num_blocks.name)
            first_k = mtf.slice(k, 0, 1, num_blocks.name)
            first_v = mtf.slice(v, 0, 1, num_blocks.name)
            block = first_q.shape.dims[2]

            first_logits = mtf.einsum(
                [first_q, first_k],
                mtf.TensorShape([batch, heads, block, blength, mlength]))
            weights = mtf.softmax(first_logits, mlength)
            first_output = mtf.einsum(
                [weights, first_v],
                mtf.TensorShape([batch, heads, block, blength, kv_channels]))
            return first_output

        # Attention for first block, since query_length = key_length.
        first_output = first_block_attention()

        # Concatenate two adjacent blocks to compute the overlapping memory block.
        def local(x):
            """Helper function to get memory blocks."""
            prev_block = mtf.slice(x, 0, num_blocks.size - 1, num_blocks.name)
            cur_block = mtf.slice(x, 1, num_blocks.size - 1, num_blocks.name)
            local_block = mtf.concat([prev_block, cur_block], mlength.name)
            return local_block

        local_k = local(k)
        local_v = local(v)
        mblocks = local_k.shape.dims[2]
        mlength = local_k.shape.dims[3]
        # Calculate the causal mask to avoid peeking into the future. We compute
        # this once and reuse it for all blocks since the block_size is known.
        mask = attention_bias_local_block(query_antecedent.mesh, blength,
                                          mlength)

        # Remove the first block from q since we already computed that.
        tail_q = mtf.slice(q, 1, num_blocks.size - 1, num_blocks.name)

        # Compatibility between q and k for rest of the blocks.
        # Shape [batch, heads, num_blocks - 1, block_length, local_length]
        attention = mtf.einsum([tail_q, local_k],
                               mtf.TensorShape(
                                   [batch, heads, mblocks, blength, mlength]))
        attention += mask
        attention = mtf.softmax(attention, mlength)

        # Run attention for rest of the blocks.
        # Shape [batch, heads, num_blocks-1, block_length, kv_channels]
        output = mtf.einsum([attention, local_v],
                            mtf.TensorShape(
                                [batch, heads, mblocks, blength, kv_channels]))
        # Now concatenate the first and rest of the blocks.
        final_output = mtf.concat([first_output, output], num_blocks.name)
        final_output = mtf.reshape(
            final_output,
            mtf.TensorShape([batch, heads, query_length, kv_channels]))
        return mtf.einsum([final_output, o_var],
                          mtf.TensorShape([batch, query_length, io_channels]))
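The local() helper gives each tail query block a memory window of its own block plus the previous one, by slicing off the last and first blocks and concatenating along the memory axis. A NumPy sketch for four blocks of length 2:

import numpy as np

k = np.arange(8).reshape(4, 2)  # blocks b0..b3 along axis 0

prev_block = k[:-1]  # b0, b1, b2
cur_block = k[1:]    # b1, b2, b3
local_k = np.concatenate([prev_block, cur_block], axis=1)
# Row i pairs block i with block i+1, i.e. exactly the memory window
# (length 2 * block_length) for tail query block i+1.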
Code example #22
def mnist_model(image, labels, mesh):
    """The model.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32
    mesh: a mtf.Mesh

  Returns:
    logits: a tf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
  """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    rows_dim = mtf.Dimension("rows", 28)
    cols_dim = mtf.Dimension("cols", 28)
    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    x = mtf.import_tf_tensor(mesh, tf.reshape(image, [-1, 28, 28]),
                             mtf.Shape([batch_dim, rows_dim, cols_dim]))
    x = mtf.reshape(x, [batch_dim, rows_dim, cols_dim, one_channel_dim])

    # add some convolutional layers to demonstrate that convolution works.
    # TODO(noam): get spatially-partitioned convolution working.
    fh_dim = mtf.Dimension("fh", 3)
    fw_dim = mtf.Dimension("fw", 3)
    filters1_dim = mtf.Dimension("filters1", 32)
    filters2_dim = mtf.Dimension("filters2", 32)
    kernel1 = mtf.get_variable(mesh, "kernel1",
                               [fh_dim, fw_dim, one_channel_dim, filters1_dim])
    kernel2 = mtf.get_variable(mesh, "kernel2",
                               [fh_dim, fw_dim, filters1_dim, filters2_dim])

    f1 = mtf.relu(mtf.conv2d(x, kernel1))
    f2 = mtf.relu(mtf.conv2d(f1, kernel2))
    x = mtf.reduce_mean(f2, reduced_dim=filters2_dim)

    # add some fully-connected dense layers.
    hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
    hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)

    h1 = mtf_layers.dense(x,
                          hidden_dim1,
                          reduced_dims=[rows_dim, cols_dim],
                          activation=mtf.relu,
                          name="hidden1")
    h2 = mtf_layers.dense(h1, hidden_dim2, activation=mtf.relu, name="hidden2")
    logits = mtf_layers.dense(h2, classes_dim, name="logits")
    if labels is None:
        loss = None
    else:
        labels = mtf.import_tf_tensor(mesh, labels, mtf.Shape([batch_dim]))
        loss = mtf_layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    return logits, loss
Code example #23
 def experts_dim(self):
   return mtf.Dimension("experts", self._hparams.moe_num_experts)
Code example #24
 def kv_dim(self):
   return mtf.Dimension("d_kv", self._hparams.d_kv)
Code example #25
 def heads_dim(self):
   return mtf.Dimension("heads", self._hparams.num_heads)
Code example #26
 def memory_length_dim(self):
   return mtf.Dimension("memory_length", self._hparams.max_length)
Code example #27
  def _decoder_layer_stack_incremental(self,
                                       x,
                                       step_num,
                                       encdec_tensors,
                                       self_attention_k,
                                       self_attention_v,
                                       encdec_attention_mask=None):
    """Decoder layer stack during inference.

    We are processing only one position at a time.

    The self-attention keys and values have already been computed for
    previous positions.  In addition to the decoder output, we need to
    produce the updated self-attention keys and values.

    If there is an encoder, then additional Tensors are supplied in
    encdec_tensors, which give us the keys and values for encoder-decoder
    attention as well as the weight matrices q_var and o_var.

    Args:
      x: a mtf.Tensor with shape [batch_dim, model_dim]
      step_num: an mtf integer Scalar
      encdec_tensors: an optional list of num_layers tuples, each of the form
        (q_var, o_var, k, v)
      self_attention_k: an optional list of num_layers Tensors each with shape
        [batch, heads, memory_length, kv_channels]
      self_attention_v: an optional list of num_layers Tensors each with shape
        [batch, heads, memory_length, kv_channels]
      encdec_attention_mask: an optional mtf.Tensor with shape
        [batch, length_dim, encoder_length_dim] containing values 0 or -inf.

    Returns:
      y: a mtf.Tensor with shape [batch_dim, model_dim]
      new_self_attention_k: a list of num_layers mtf.Tensors, with the same
        shapes as the elements of self_attention_k
      new_self_attention_v: a list of num_layers mtf.Tensors, with the same
        shapes as the elements of self_attention_v

    Raises:
      ValueError: if hparams make no sense
    """
    hparams = self._hparams
    num_layers = hparams.num_decoder_layers
    num_layer_norms = num_layers * (2 if encdec_tensors is None else 3) + 1
    layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms)
    layer_norm_combined_var = mtf.get_variable(
        x.mesh,
        "layer_norm_scale",
        mtf.Shape([layer_norms_dim, self.model_dim]),
        initializer=tf.ones_initializer(),
        activation_dtype=x.dtype)
    layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim)
    def normalize(x):
      scale = layer_norm_vars.pop(0)
      variance = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim)
      return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale

    new_self_attention_k = []
    new_self_attention_v = []
    for layer in range(num_layers):
      with tf.variable_scope("layer_%d" % layer):
        # Self attention layer
        y, new_k, new_v = mtf_layers.multihead_self_attention_incremental(
            normalize(x),
            prev_k=self_attention_k[layer],
            prev_v=self_attention_v[layer],
            step_num=step_num,
            name="self_attention")
        new_self_attention_k.append(new_k)
        new_self_attention_v.append(new_v)
        x += y
        if encdec_tensors is not None:
          # Encoder-Decoder attention layer
          q_var, o_var, k, v = encdec_tensors[layer]
          x += mtf_layers.multihead_encdec_attention_incremental(
              normalize(x),
              q_var, o_var, k, v,
              encdec_attention_mask,
              name="encdec_attention")
        # ffn layer
        x += self._feedforward_layer(normalize(x), hparams)
    x = normalize(x)
    assert not layer_norm_vars
    return x, new_self_attention_k, new_self_attention_v
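The normalize closure is a scale-only (RMS-style) layer norm: it divides by the root mean square of the activations and multiplies by a learned scale, with no mean subtraction or bias. A NumPy analogue:

import numpy as np

def rms_normalize(x, scale, norm_epsilon=1e-6):
    # "variance" here is the mean square, not the centered variance
    variance = np.mean(np.square(x), axis=-1, keepdims=True)
    return x / np.sqrt(variance + norm_epsilon) * scale

x = np.random.randn(2, 8)  # [batch_dim, model_dim]
y = rms_normalize(x, np.ones(8))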
Code example #28
class MeshTensorFlowTest(parameterized.TestCase, tf.test.TestCase):
    @parameterized.parameters(
        (mtf.Dimension("x", 5), ),
        (("x", 5), ),
    )
    def testConvertToDimension(self, inputs):
        dimension = mtf.convert_to_dimension(inputs)
        self.assertEqual(dimension.name, "x")
        self.assertEqual(dimension.size, 5)

    def testConvertToDimensionGenericInputs(self):
        dimension = mtf.convert_to_dimension(None)
        self.assertEqual(dimension, None)
        with self.assertRaises(TypeError):
            mtf.convert_to_dimension(5)

    @parameterized.parameters(
        (mtf.Shape([mtf.Dimension("x", 4),
                    mtf.Dimension("y", 8)]), ),
        ("x:4;y:8", ),
        ("x:4.y:8", ),
        ("x:4 y:8", ),
        ("x:4,y:8", ),
    )
    def testConvertToShape(self, inputs):
        shape = mtf.convert_to_shape(inputs)
        self.assertEqual(
            shape, mtf.Shape([mtf.Dimension("x", 4),
                              mtf.Dimension("y", 8)]))

    def testConvertToShapeGenericInputs(self):
        shape = mtf.convert_to_shape([])
        self.assertEqual(shape.dims, [])
        shape = mtf.convert_to_shape(None)
        self.assertEqual(shape, None)
        with self.assertRaises(ValueError):
            mtf.convert_to_shape("x;4")

    @parameterized.parameters(
        (mtf.LayoutRules([("d_ff", "model"), ("heads", "model")]), ),
        ("d_ff:model;heads:model", ),
        ("d_ff:model.heads:model", ),
        ("d_ff:model heads:model", ),
        ("d_ff:model,heads:model", ),
        ([("d_ff", "model"), ("heads", "model")], ),
    )
    def testConvertToLayoutRules(self, inputs):
        layout_rules = mtf.convert_to_layout_rules(inputs)
        self.assertEqual(
            layout_rules._pairs,
            mtf.LayoutRules([("d_ff", "model"), ("heads", "model")])._pairs)

    def testConvertToLayoutRulesGenericInputs(self):
        with self.assertRaises(ValueError):
            mtf.convert_to_layout_rules("d_ff;heads")

    def testTensorLayout(self):
        tensor_layout = mtf.TensorLayout([0, 2, 1])
        self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(0), ())
        self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(1), (0, ))
        self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(2), (0, 2))
        tensor_layout = mtf.TensorLayout([None, 0])
        self.assertFalse(tensor_layout.is_fully_replicated)
        tensor_layout = mtf.TensorLayout([None, None, None])
        self.assertTrue(tensor_layout.is_fully_replicated)

    def testGraph(self):
        graph = mtf.Graph()
        self.assertLen(graph.operations, 0)
        self.assertLen(graph.tensors, 0)
        self.assertLen(graph.trainable_variables, 0)
        self.assertLen(graph.all_variables, 0)
        mesh = mtf.Mesh(graph, "mesh_test")
        _ = mtf.import_tf_tensor(mesh,
                                 tf_tensor=tf.constant(0.),
                                 shape=mtf.Shape([]))
        self.assertLen(graph.operations, 1)
        self.assertLen(graph.tensors, 1)
        self.assertLen(graph.trainable_variables, 0)
        self.assertLen(graph.all_variables, 0)
        _ = mtf.get_variable(mesh, "variable_0", mtf.Shape([]), trainable=True)
        self.assertLen(graph.operations, 2)
        self.assertLen(graph.tensors, 2)
        self.assertLen(graph.trainable_variables, 1)
        self.assertLen(graph.all_variables, 1)
        _ = mtf.get_variable(mesh,
                             "variable_1",
                             mtf.Shape([]),
                             trainable=False)
        self.assertLen(graph.operations, 3)
        self.assertLen(graph.tensors, 3)
        self.assertLen(graph.trainable_variables, 1)
        self.assertLen(graph.all_variables, 2)

    def testLowering(self):
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        inputs = tf.constant(0.)
        mtf_inputs = mtf.import_tf_tensor(mesh,
                                          tf_tensor=inputs,
                                          shape=mtf.Shape([]))
        mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                          layout={},
                                                          devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        outputs = lowering.export_to_tf_tensor(mtf_inputs)
        with self.test_session() as sess:
            inputs_value, outputs_value = sess.run([inputs, outputs])
        self.assertEqual(inputs_value, outputs_value)

        # Check that methods run without error.
        _ = lowering.copy_masters_to_slices()
        _ = lowering.copy_slices_to_masters()

    def testMesh(self):
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        self.assertEqual(mesh.graph, graph)

    def testMeshImpl(self):
        shape = mtf.Shape(
            [mtf.Dimension("batch", 4),
             mtf.Dimension("model", 8)])
        layout_rules = mtf.LayoutRules([("batch", "batch"), ("d_ff", "model"),
                                        ("heads", "model")])
        mesh_impl = mtf.MeshImpl(shape=shape, layout_rules=layout_rules)
        self.assertEqual(mesh_impl.shape, shape)
        self.assertEqual(mesh_impl.ndims, len(shape))
        self.assertEqual(mesh_impl.layout_rules, layout_rules)
        self.assertEqual(mesh_impl.size, shape.size)
        self.assertTrue(mesh_impl.supports_control_dependencies)

        batch = mtf.Dimension("batch", 128)
        length = mtf.Dimension("length", 500)
        d_ff = mtf.Dimension("d_ff", 2048)
        heads = mtf.Dimension("heads", 8)
        self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(batch), 0)
        self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(d_ff), 1)
        self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(heads), 1)
        self.assertEqual(
            mesh_impl.tensor_layout(mtf.Shape([batch, length, d_ff])),
            mtf.TensorLayout([0, None, 1]))
Code example #29
 def testConvertToShape(self, inputs):
     shape = mtf.convert_to_shape(inputs)
     self.assertEqual(
         shape, mtf.Shape([mtf.Dimension("x", 4),
                           mtf.Dimension("y", 8)]))
Code example #30
 def feedforward_dim(self):
   return mtf.Dimension("d_ff", self._hparams.d_ff)