Code example #1
File: resnet_auto.py Project: 1926627357/adsl4mtf
def mnist_model(image, labels, mesh):
	"""The model.
	Args:
		image: tf.Tensor with shape [batch, 28*28]
		labels: a tf.Tensor with shape [batch] and dtype tf.int32
		mesh: a mtf.Mesh
	Returns:
		logits: a mtf.Tensor with shape [batch, 10]
		loss: a mtf.Tensor with shape []
	"""
	batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
	rows_dim = mtf.Dimension("rows_size", image_height)
	cols_dim = mtf.Dimension("cols_size", image_width)
	channel_dim = mtf.Dimension("image_channel", num_channels)
	classes_dim = mtf.Dimension(name='classesnum',size=classesnum)
	x = mtf.import_tf_tensor(
		mesh, tf.reshape(image, [FLAGS.batch_size, image_height, image_width, num_channels]),
		mtf.Shape(
			[batch_dim, rows_dim, cols_dim, channel_dim]))
	# x = mtf.transpose(x, [batch_dim, rows_dim, cols_dim, channel_dim])
	# print(x.shape)
	logits = VGG(x, classes_dim=classes_dim,depth=depth)
	logits = mtf.cast(logits,dtype=tf.float32)

	if labels is None:
		loss = None
	else:
		labels = mtf.import_tf_tensor(
			mesh, tf.reshape(labels, [FLAGS.batch_size]), mtf.Shape([batch_dim]))
		loss = mtf.layers.softmax_cross_entropy_with_logits(
			logits, mtf.one_hot(labels, classes_dim), classes_dim)
		loss = mtf.reduce_mean(loss)
	return logits, loss
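Note: a minimal sketch of how a model function like mnist_model above is typically wired into a graph, lowered, and exported back to plain tf.Tensors. The mesh name "my_mesh", the single empty device string, and the image/labels placeholders (the tf.Tensors described in the docstring) are illustrative assumptions, not part of the example.

# Hedged sketch: build the graph, call the model, lower it on a single device,
# and export the mtf.Tensors back to ordinary tf.Tensors.
graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")
logits, loss = mnist_model(image, labels, mesh)

mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    shape=[], layout={}, devices=[""])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
tf_logits = lowering.export_to_tf_tensor(logits)
tf_loss = lowering.export_to_tf_tensor(loss)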
Code example #2
def get_bspline_kernel(x, channels, transpose=False, dtype=tf.float32, order=4):
  """Creates a 5x5x5 b-spline kernel.
  Args:
    num_channels: The number of channels of the image to filter.
    dtype: The type of an element in the kernel.
  Returns:
    A tensor of shape `[5, 5, 5, num_channels, num_channels]`.
  """
  mesh = x.mesh
  in_dim = x.shape[-1]
  num_channels = channels.size
  if order == 8:
    kernel = np.array(( 1., 8., 28., 56., 70., 56., 28., 8., 1.), dtype=dtype.as_numpy_dtype())
  elif order == 6:
    kernel = np.array(( 1., 6., 15., 20., 15., 6., 1.), dtype=dtype.as_numpy_dtype())
  elif order == 2:
    kernel = np.array(( 1., 2., 1.), dtype=dtype.as_numpy_dtype())
  else:
    kernel = np.array(( 1., 4., 6., 4., 1.), dtype=dtype.as_numpy_dtype())
  size = len(kernel)
  kernel = np.einsum('ij,k->ijk', np.outer(kernel, kernel), kernel)
  kernel /= np.sum(kernel)
  kernel = kernel[:, :, :, np.newaxis, np.newaxis]
  kernel = tf.constant(kernel, dtype=dtype) * tf.eye(num_channels, dtype=dtype)

  fd_dim = mtf.Dimension("fd", size)
  fh_dim = mtf.Dimension("fh", size)
  fw_dim = mtf.Dimension("fw", size)
  if transpose:
    return mtf.import_tf_tensor(mesh, kernel, shape=[fd_dim, fh_dim, fw_dim, channels, in_dim])
  else:
    return mtf.import_tf_tensor(mesh, kernel, shape=[fd_dim, fh_dim, fw_dim, in_dim, channels])
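For reference, the separable kernel construction above can be checked in plain NumPy; this standalone sketch reproduces the default order-4 case using the same binomial weights as the function.

# Standalone NumPy check of the order-4 (5-tap) separable b-spline kernel:
# outer products expand the 1-D kernel to 3-D, then it is normalized to sum to 1.
import numpy as np

kernel_1d = np.array([1., 4., 6., 4., 1.])
kernel_3d = np.einsum('ij,k->ijk', np.outer(kernel_1d, kernel_1d), kernel_1d)
kernel_3d /= np.sum(kernel_3d)
print(kernel_3d.shape)              # (5, 5, 5)
print(round(float(kernel_3d.sum()), 6))  # 1.0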
Code example #3
def mnist_model(image, labels, mesh):
    """The model.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
  """

    # tf_images is a tf.Tensor with shape [batch, 28, 28] and dtype tf.float32
    # tf_labels is a tf.Tensor with shape [batch] and dtype tf.int32
    batch_dim = mtf.Dimension("batch", 100)
    rows_dim = mtf.Dimension("rows", 28)
    cols_dim = mtf.Dimension("cols", 28)
    hidden_dim = mtf.Dimension("hidden", 1024)
    classes_dim = mtf.Dimension("classes", 10)
    images = mtf.import_tf_tensor(mesh,
                                  image,
                                  shape=[batch_dim, rows_dim, cols_dim])
    labels = mtf.import_tf_tensor(mesh, labels, [batch_dim])
    w1 = mtf.get_variable(mesh, "w1", [rows_dim, cols_dim, hidden_dim])
    w2 = mtf.get_variable(mesh, "w2", [hidden_dim, classes_dim])
    # einsum is a generalization of matrix multiplication (see numpy.einsum)
    hidden = mtf.relu(
        mtf.einsum([images, w1], output_shape=[batch_dim, hidden_dim]))
    logits = mtf.einsum([hidden, w2], output_shape=[batch_dim, classes_dim])
    loss = mtf.reduce_mean(
        mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim))

    return logits, loss
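The same model can be run data-parallel by naming the "batch" dimension in the layout. A brief sketch follows (the mesh-axis name "all" and the two empty device strings are assumptions); graph construction and export are as in the single-device sketch after code example #1.

# Split the "batch" dimension of the model above across two devices.
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    shape="all:2", layout="batch:all", devices=["", ""])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
tf_loss = lowering.export_to_tf_tensor(loss)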
Code example #4
def mnist_model(image, labels, mesh):
    """The model.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
  """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    rows_dim = mtf.Dimension("rows", 28)
    cols_dim = mtf.Dimension("cols", 28)
    classes_dim = mtf.Dimension("classes", 10)

    x = mtf.import_tf_tensor(mesh, tf.reshape(image,
                                              [FLAGS.batch_size, 28, 28]),
                             [batch_dim, rows_dim, cols_dim])
    y = mtf.import_tf_tensor(mesh, tf.reshape(labels, [FLAGS.batch_size]),
                             [batch_dim])

    w1 = mtf.get_variable(mesh, "w1", [rows_dim, cols_dim, classes_dim])
    b1 = mtf.get_variable(mesh, "b1", [classes_dim])

    logits = mtf.relu(mtf.einsum([x, w1], [batch_dim, classes_dim]) + b1)

    if labels is None:
        loss = None
    else:
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(y, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    return logits, loss
Code example #5
def model_backbone(features, labels, mesh):
    """The model.
	Args:
		image: tf.Tensor with shape [batch, 32*32]
		labels: a tf.Tensor with shape [batch] and dtype tf.int32
		mesh: a mtf.Mesh
	Returns:
		logits: a mtf.Tensor with shape [batch, 10]
		loss: a mtf.Tensor with shape []
	"""
    id_hldr, wt_hldr = features

    batch_dim = mtf.Dimension("batch", args_opt.batch_size)
    field_dim = mtf.Dimension("field", size=39)
    vocab_dim = mtf.Dimension("vocab_size", 200000)
    embed_dim = mtf.Dimension("embed_size", 80)
    outdim = mtf.Dimension("outdim", 1)
    id_hldr = mtf.import_tf_tensor(
        mesh, tf.reshape(id_hldr, [args_opt.batch_size, field_dim.size]),
        mtf.Shape([batch_dim, field_dim]))
    wt_hldr = mtf.import_tf_tensor(
        mesh, tf.reshape(wt_hldr, [args_opt.batch_size, field_dim.size]),
        mtf.Shape([batch_dim, field_dim]))
    if args_opt.fp16:
        float16 = mtf.VariableDType(tf.float16, tf.float16, tf.float16)
        # id_hldr=mtf.cast(id_hldr,dtype=tf.int32)
        wt_hldr = mtf.cast(wt_hldr, dtype=tf.float16)
    else:
        float16 = None

    logits, embedding_table = network[args_opt.model](id_hldr,
                                                      wt_hldr,
                                                      vocab_dim,
                                                      embed_dim,
                                                      outdim,
                                                      float16=float16)
    logits = mtf.cast(logits, dtype=tf.float32)
    embedding_table = mtf.cast(embedding_table, dtype=tf.float32)
    if labels is None:
        wide_loss = None
        deep_loss = None
    else:
        labels = mtf.import_tf_tensor(
            mesh, tf.reshape(labels, [args_opt.batch_size]),
            mtf.Shape([batch_dim]))
        wide_loss = mtf.layers.sigmoid_cross_entropy_with_logits(
            logits, labels)
        deep_loss = mtf.reduce_mean(mtf.square(embedding_table)) / 2
        deep_loss = mtf.reduce_mean(wide_loss) + 8e-5 * deep_loss
        wide_loss = mtf.reduce_mean(wide_loss)

    if wide_loss is None:
        return logits, None
    return logits, wide_loss + deep_loss
Code example #6
def mnist_model(image, labels, mesh, hs_t):
    """The model.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32
    mesh: a mtf.Mesh
    hs_t: a mtf.Tensor with shape [batch, hidden_1]
  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
    hs_t: an updated mtf.Tensor
  """
    input_num = 28
    timesteps_num = 28
    classes_num = 10

    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    input_dim = mtf.Dimension("input", input_num)
    timesteps_dim = mtf.Dimension("timesteps", timesteps_num)
    classes_dim = mtf.Dimension("classes", classes_num)
    hidden_dim_1 = mtf.Dimension("hidden_1", FLAGS.hidden_size)
    hidden_dim_2 = mtf.Dimension("hidden_2", FLAGS.hidden_size)

    x = mtf.import_tf_tensor(mesh, tf.reshape(image,
                                              [FLAGS.batch_size, 28, 28]),
                             [batch_dim, timesteps_dim, input_dim])
    y = mtf.import_tf_tensor(mesh, tf.reshape(labels, [FLAGS.batch_size]),
                             [batch_dim])
    hs_t = mtf.import_tf_tensor(mesh, hs_t, [batch_dim, hidden_dim_1])

    Wxh = mtf.get_variable(mesh, "Wxh", [input_dim, hidden_dim_2])
    Whh = mtf.get_variable(mesh, "Whh", [hidden_dim_1, hidden_dim_2])
    Why = mtf.get_variable(mesh, "Why", [hidden_dim_2, classes_dim])
    bh = mtf.get_variable(mesh, "bh", [hidden_dim_2])
    by = mtf.get_variable(mesh, "by", [classes_dim])

    x_list = mtf.unstack(x, timesteps_dim)

    for xs_t in x_list:
        hs_t = mtf.tanh(
            mtf.einsum([xs_t, Wxh], [batch_dim, hidden_dim_2]) +
            mtf.einsum([hs_t, Whh], [batch_dim, hidden_dim_2]) + bh)
        logits = mtf.einsum([hs_t, Why], [batch_dim, classes_dim]) + by

    if labels is None:
        loss = None
    else:
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(y, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    return logits, loss, hs_t
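Since the hidden state is imported from a tf.Tensor, the first call to the RNN model above needs an initial state. A minimal sketch, assuming a zero initial state and the same FLAGS used in the example:

# Hedged sketch: zero initial hidden state for the RNN model above.
hs_init = tf.zeros([FLAGS.batch_size, FLAGS.hidden_size], dtype=tf.float32)
logits, loss, hs_t = mnist_model(image, labels, mesh, hs_init)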
Code example #7
File: layers_test.py Project: qixiuai/mesh
  def testMultiheadAttention(self, kv_channels, heads):
    batch = 2
    length = 8
    channels = 3
    query = tf.random_normal([batch, length, channels])

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    batch_dim = mtf.Dimension("batch", batch)
    length_dim = mtf.Dimension("length", length)
    channels_dim = mtf.Dimension("channels", channels)
    kv_channels_dim = mtf.Dimension("kv_channels", kv_channels)
    heads_dim = mtf.Dimension("heads", heads)

    mtf_query = mtf.import_tf_tensor(
        mesh, query,
        shape=mtf.Shape([batch_dim, length_dim, channels_dim]))
    mtf_outputs = mtf.layers.multihead_attention(
        mtf_query,
        memory_antecedent=None,
        mask=None,
        kv_channels=kv_channels_dim,
        heads=heads_dim)
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

    tf_group = lowering.copy_masters_to_slices()
    init = tf.global_variables_initializer()
    self.evaluate(init)
    self.evaluate(tf_group)
    actual = self.evaluate(actual_outputs)

    self.assertEqual(actual.shape, query.shape)
Code example #8
  def testConv1d(self):
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    filter_size = 3
    depth_dim = mtf.Dimension("depth", 2)
    length_dim = mtf.Dimension("length", 4)
    output_dim = mtf.Dimension("output", 2)

    x = tf.constant([[1, 0], [0, 1], [1, 1], [2, 1]], dtype=tf.float32)
    mtf_x = mtf.import_tf_tensor(
        mesh, x, shape=mtf.Shape([length_dim, depth_dim]))

    initializer_mock = mock.MagicMock()
    initializer_mock.side_effect = initialize_by_shape({
        (1, 3, 2, 2): [[[[1, -1], [0, 0]], [[2, -2], [-1, 1]], [[3, -3],
                                                                [-2, 2]]]],
    })

    mtf_output = mtf.layers.conv1d(
        mtf_x,
        output_dim=output_dim,
        filter_size=filter_size,
        filter_initializer=initializer_mock)

    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_output = lowering.export_to_tf_tensor(mtf_output)

    self.evaluate(tf.global_variables_initializer())
    self.evaluate(lowering.copy_masters_to_slices())
    actual = self.evaluate(actual_output)

    self.assertAllClose(actual, [[0, 0], [1, -1], [5, -5], [4, -4]])
Code example #9
File: layers_test.py Project: qixiuai/mesh
  def testDense(self, units, use_bias):
    batch = 2
    channels = 3
    inputs = tf.random_normal([batch, channels])

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    batch_dim = mtf.Dimension("batch", batch)
    channels_dim = mtf.Dimension("channels", channels)
    depth_dim = mtf.Dimension("depth", units)

    mtf_inputs = mtf.import_tf_tensor(
        mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
    mtf_outputs = mtf.layers.dense(mtf_inputs,
                                   output_dim=depth_dim,
                                   reduced_dims=[channels_dim],
                                   activation=mtf.relu,
                                   use_bias=use_bias)
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

    expected_outputs = tf.keras.layers.Dense(units=units,
                                             activation=tf.nn.relu,
                                             use_bias=use_bias)(inputs)
    tf_group = lowering.copy_masters_to_slices()
    init = tf.global_variables_initializer()
    self.evaluate(init)
    self.evaluate(tf_group)
    actual, expected = self.evaluate([actual_outputs, expected_outputs])

    self.assertEqual(actual.shape, expected.shape)
Code example #10
    def test_hidden_to_logits_computesLogitsCorrectly(self):
        seq_len = 1
        vocab_size = 4
        model_size = 3
        num_softmaxes = 2

        vocab_dim = mtf.Dimension('vocab', vocab_size)
        model_dim = mtf.Dimension('model', model_size)
        length_dim = mtf.Dimension('length', seq_len)

        embeddings = tf.constant(np.array([[1.0, 1.0, 2.0]]) /
                                 model_size**-0.5,
                                 dtype=tf.float32)
        mtf_embeddings = mtf.import_tf_tensor(self.mesh,
                                              embeddings,
                                              shape=mtf.Shape(
                                                  [length_dim, model_dim]))

        self.initializer_mock.side_effect = initialize_by_shape({
            # Embedding weights.
            (4, 3): [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]],
            # Mixture weights.
            (2, 3): [[1, 0, 0], [0, 1, 1]],
            # Context weights
            (2, 3, 3): [
                [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
                [[0, 0, 1], [0, 1, 0], [1, 0, 0]],
            ],
        })

        vocab_embedding = vocab_embeddings.MixtureOfSoftmaxes(
            self.mesh,
            vocab_dim,
            output_dim=model_dim,
            variable_dtype=self.variable_dtype,
            name='embedding',
            ensemble_dim=None,
            num_softmaxes=num_softmaxes)

        mtf_logits = vocab_embedding.hidden_to_logits(mtf_embeddings,
                                                      context=None)

        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[''])
        lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl})
        actual_logits = lowering.export_to_tf_tensor(mtf_logits)

        self.evaluate(tf.global_variables_initializer())
        self.evaluate(lowering.copy_masters_to_slices())
        actual, = self.evaluate([actual_logits])

        expected_priors = scipy.special.softmax([1, 3])
        expected_probs_1 = scipy.special.softmax(np.tanh([1, 1, 2, 2]))
        expected_probs_2 = scipy.special.softmax(np.tanh([2, 1, 1, 1]))
        expected_probs = (expected_priors[0] * expected_probs_1 +
                          expected_priors[1] * expected_probs_2)
        expected_logits = np.log(expected_probs)

        self.assertAllClose(actual, [expected_logits])
Code example #11
File: layers_test.py Project: stjordanis/mesh
    def testDenseReluDense(self):
        batch = 2
        channels = 3
        hidden = 5
        inputs = tf.random_normal([batch, channels])

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", batch)
        channels_dim = mtf.Dimension("channels", channels)
        hidden_dim = mtf.Dimension("hidden", hidden)

        mtf_inputs = mtf.import_tf_tensor(mesh,
                                          inputs,
                                          shape=mtf.Shape(
                                              [batch_dim, channels_dim]))
        mtf_outputs = mtf.layers.dense_relu_dense(mtf_inputs,
                                                  hidden_channels=hidden_dim,
                                                  is_training=False)
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

        tf_group = lowering.copy_masters_to_slices()
        init = tf.global_variables_initializer()
        self.evaluate(init)
        self.evaluate(tf_group)
        actual = self.evaluate(actual_outputs)

        self.assertEqual(actual.shape, inputs.shape)
Code example #12
    def testMaskedLocalAttention1D(self, batch, length, io_channels,
                                   kv_channels, heads, window_size):
        length_q = length
        query = tf.random_normal([batch, length_q, io_channels])

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", batch)
        length_q_dim = mtf.Dimension("length_q", length_q)
        io_channels_dim = mtf.Dimension("io_channels", io_channels)
        kv_channels_dim = mtf.Dimension("kv_channels", kv_channels)
        heads_dim = mtf.Dimension("heads", heads)

        mtf_query = mtf.import_tf_tensor(
            mesh,
            query,
            shape=mtf.Shape([batch_dim, length_q_dim, io_channels_dim]))
        mtf_outputs = mtf.layers.masked_local_attention_1d(
            mtf_query,
            kv_channels=kv_channels_dim,
            heads=heads_dim,
            window_size=window_size)
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

        tf_group = lowering.copy_masters_to_slices()
        init = tf.global_variables_initializer()
        self.evaluate(init)
        self.evaluate(tf_group)
        actual = self.evaluate(actual_outputs)

        self.assertEqual(actual.shape, (batch, length_q, io_channels))
Code example #13
File: ops_test.py Project: tspannhw/mesh
 def testGraph(self):
     graph = mtf.Graph()
     self.assertLen(graph.operations, 0)
     self.assertLen(graph.tensors, 0)
     self.assertLen(graph.trainable_variables, 0)
     self.assertLen(graph.all_variables, 0)
     mesh = mtf.Mesh(graph, "mesh_test")
     _ = mtf.import_tf_tensor(mesh,
                              tf_tensor=tf.constant(0.),
                              shape=mtf.Shape([]))
     self.assertLen(graph.operations, 1)
     self.assertLen(graph.tensors, 1)
     self.assertLen(graph.trainable_variables, 0)
     self.assertLen(graph.all_variables, 0)
     _ = mtf.get_variable(mesh, "variable_0", mtf.Shape([]), trainable=True)
     self.assertLen(graph.operations, 2)
     self.assertLen(graph.tensors, 2)
     self.assertLen(graph.trainable_variables, 1)
     self.assertLen(graph.all_variables, 1)
     _ = mtf.get_variable(mesh,
                          "variable_1",
                          mtf.Shape([]),
                          trainable=False)
     self.assertLen(graph.operations, 3)
     self.assertLen(graph.tensors, 3)
     self.assertLen(graph.trainable_variables, 1)
     self.assertLen(graph.all_variables, 2)
Code example #14
File: bert_test.py Project: lynex/mesh
            def computation_fn():
                graph = mtf.Graph()
                mesh = mtf.Mesh(graph, 'my_mesh')
                mesh_shape = mtf.convert_to_shape('all:2')
                layout = 'num_heads:all'
                mesh_devices = [''] * mesh_shape.size
                mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                    mesh_shape, mtf.convert_to_layout_rules(layout),
                    mesh_devices, device_assignment)
                batch_dim = mtf.Dimension('batch', batch_size)
                seq_dim = mtf.Dimension('seq', seq_length)

                input_ids = tf.random.uniform((batch_size, seq_length),
                                              minval=0,
                                              maxval=vocab_size,
                                              dtype=tf.int32)
                mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids,
                                                     [batch_dim, seq_dim])

                model = bert_lib.BertModel(config=bert_config,
                                           is_training=True,
                                           input_ids=mtf_input_ids,
                                           input_mask=None,
                                           token_type_ids=None)
                pooled = model.get_pooled_output()
                lowering = mtf.Lowering(graph, {mesh: mesh_impl})
                return lowering.export_to_tf_tensor(pooled)
Code example #15
File: layers_test.py Project: stjordanis/mesh
    def testCorr2DInput(self):
        batch = 4
        channels = 3
        inputs = tf.random_normal([batch, channels])

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", batch)
        channels_dim = mtf.Dimension("channels", channels)

        mtf_inputs = mtf.import_tf_tensor(mesh,
                                          inputs,
                                          shape=mtf.Shape(
                                              [batch_dim, channels_dim]))
        mtf_outputs = mtf.layers.corr(mtf_inputs, dim=channels_dim)
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

        expected_outputs = tfp.stats.correlation(inputs,
                                                 sample_axis=0,
                                                 event_axis=1)
        tf_group = lowering.copy_masters_to_slices()
        init = tf.global_variables_initializer()
        self.evaluate(init)
        self.evaluate(tf_group)
        actual, expected = self.evaluate([actual_outputs, expected_outputs])

        self.assertEqual(actual.shape, expected.shape)
        self.assertAllClose(actual, expected)
Code example #16
File: toy_model_tpu.py Project: minhtcai/mesh
def toy_model(features, mesh):
  """A toy model implemented by mesh tensorlfow."""
  batch_dim = mtf.Dimension('batch', FLAGS.batch_size)
  io_dim = mtf.Dimension('io', FLAGS.io_size)

  master_dtype = tf.as_dtype(FLAGS.master_dtype)
  slice_dtype = tf.as_dtype(FLAGS.slice_dtype)
  activation_dtype = tf.as_dtype(FLAGS.activation_dtype)

  x = mtf.import_tf_tensor(mesh, features, mtf.Shape([batch_dim, io_dim]))
  x = mtf.cast(x, activation_dtype)
  h = x
  for lnum in range(1, FLAGS.num_hidden_layers + 2):
    if lnum + 1 == FLAGS.num_hidden_layers + 2:
      # output layer
      dim = io_dim
    elif lnum % 2 == 0:
      dim = mtf.Dimension('hidden_even', FLAGS.hidden_size)
    else:
      dim = mtf.Dimension('hidden_odd', FLAGS.hidden_size)
    h = mtf.layers.dense(
        h, dim,
        use_bias=False,
        master_dtype=master_dtype,
        slice_dtype=slice_dtype,
        name='layer_%d' % lnum)
  y = h

  loss = mtf.reduce_mean(mtf.square(y - x))
  return y, loss
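The loss returned by toy models like this one is usually differentiated with mtf.gradients and turned into per-variable update ops, as in code example #33 further below. A minimal sketch, where the plain SGD optimizer and its 0.01 learning rate are illustrative assumptions:

# Hedged sketch: differentiate the toy model's loss and build update ops,
# mirroring the optimizer loop used in code example #33.
y, loss = toy_model(features, mesh)
var_grads = mtf.gradients(
    [loss], [v.outputs[0] for v in graph.trainable_variables])
optimizer = mtf.optimize.SgdOptimizer(0.01)  # assumed optimizer and learning rate
update_ops = []
for grad, var in zip(var_grads, graph.trainable_variables):
  update_ops.extend(optimizer.apply_grad(grad, var))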
Code example #17
    def testRecomputeGrad(self):
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")

        # let's differentiate x^2 + x
        # dy/dx = 2x+1
        def x_squared_plus_x(x):
            return x * x + x

        x = tf.constant([5, 10], dtype=tf.float32)
        dy = tf.constant([2, 3], dtype=tf.float32)
        two = mtf.Dimension("two", 2)
        expected_y = tf.constant([30, 110], dtype=tf.float32)
        expected_dx = tf.constant([22, 63], dtype=tf.float32)
        mtf_x = mtf.import_fully_replicated(mesh, x, shape=mtf.Shape([two]))
        mtf_dy = mtf.import_tf_tensor(mesh, dy, shape=mtf.Shape([two]))
        mtf_y = mtf.recompute_grad(x_squared_plus_x, [mtf_x])
        [mtf_dx] = mtf.gradients([mtf_y], [mtf_x], [mtf_dy])
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            shape="processors:2", layout="two:processors", devices=["", ""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_y = lowering.export_to_tf_tensor(mtf_y)
        actual_dx = lowering.export_to_tf_tensor(mtf_dx)
        self.assertAllEqual(self.evaluate(actual_y), self.evaluate(expected_y))
        self.assertAllEqual(self.evaluate(actual_dx),
                            self.evaluate(expected_dx))
Code example #18
File: ops_test.py Project: bruinxiong/mesh-1
  def testNthSmallestReduceSecondDim(self):
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    a_dim = mtf.Dimension("a", 6)
    b_dim = mtf.Dimension("b", 2)
    inputs = tf.constant([[1, 10],
                          [2, 9],
                          [3, 8],
                          [4, 7],
                          [5, 6],
                          [6, 5]])
    n = 0  # find smallest element (n is zero-indexed)
    reduced_dim = b_dim
    expected_outputs = tf.constant([1, 2, 3, 4, 5, 5])

    mtf_inputs = mtf.import_tf_tensor(
        mesh, inputs, shape=mtf.Shape([a_dim, b_dim]))
    mtf_outputs = mtf.nth_smallest_element(
        mtf_inputs, n, reduced_dim, "test_nth_smallest")
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape="all:2", layout="a:all", devices=["", ""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
    self.assertAllEqual(self.evaluate(actual_outputs),
                        self.evaluate(expected_outputs))
Code example #19
File: layers_test.py Project: stjordanis/mesh
    def testWeightsNonzero(self):
        inputs = tf.constant([[3, 1, 0], [1, 0, 0]])

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", inputs.shape.as_list()[0])
        channels_dim = mtf.Dimension("channels", inputs.shape.as_list()[1])

        mtf_inputs = mtf.import_tf_tensor(mesh,
                                          inputs,
                                          shape=mtf.Shape(
                                              [batch_dim, channels_dim]))
        mtf_outputs = mtf.layers.weights_nonzero(mtf_inputs)
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

        expected_outputs = tf.cast(tf.not_equal(inputs, 0), tf.float32)
        tf_group = lowering.copy_masters_to_slices()
        self.evaluate(tf_group)
        actual, expected = self.evaluate([actual_outputs, expected_outputs])

        self.assertAllEqual(actual, expected)
Code example #20
    def testDynamicText2self_unpacked(self):
        batch = 2
        length = 5
        input_tensors = {
            "inputs": [[3, 1, 4, 1, 0], [1, 4, 3, 2, 1]],
            "targets": [[1, 1, 0, 0, 0], [9, 8, 1, 2, 1]],
        }
        expected_output_tensors = {
            "targets": [[3, 1, 4, 1, 1, 1, 0, 0, 0, 0],
                        [1, 4, 3, 2, 1, 9, 8, 1, 2, 1]],
        }
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", batch)
        length_dim = mtf.Dimension("length", length)

        input_shape = mtf.Shape([batch_dim, length_dim])
        mtf_features = {
            k: mtf.import_tf_tensor(mesh, v, input_shape)
            for k, v in input_tensors.items()
        }
        mtf_outputs = utils._dynamic_text2self(mtf_features)
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        for k, v in expected_output_tensors.items():
            out = lowering.export_to_tf_tensor(mtf_outputs[k])
            actual = self.evaluate(out)
            self.assertAllEqual(actual, v)
Code example #21
File: layers_test.py Project: stjordanis/mesh
    def testDotProductAttention(self, batch, heads, length_q, length_kv,
                                depth_k, depth_v):
        query = tf.random_normal([batch, heads, length_q, depth_k])
        key = tf.random_normal([batch, heads, length_kv, depth_k])
        value = tf.random_normal([batch, heads, length_kv, depth_v])

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", batch)
        heads_dim = mtf.Dimension("heads", heads)
        length_q_dim = mtf.Dimension("length_q", length_q)
        length_kv_dim = mtf.Dimension("length_kv", length_kv)
        depth_k_dim = mtf.Dimension("depth_k", depth_k)
        depth_v_dim = mtf.Dimension("depth_v", depth_v)

        mtf_query = mtf.import_tf_tensor(
            mesh,
            query,
            shape=mtf.Shape([batch_dim, heads_dim, length_q_dim, depth_k_dim]))
        mtf_key = mtf.import_tf_tensor(
            mesh,
            key,
            shape=mtf.Shape([batch_dim, heads_dim, length_kv_dim,
                             depth_k_dim]))
        mtf_value = mtf.import_tf_tensor(
            mesh,
            value,
            shape=mtf.Shape([batch_dim, heads_dim, length_kv_dim,
                             depth_v_dim]))
        mtf_outputs = mtf.layers.dot_product_attention(mtf_query,
                                                       mtf_key,
                                                       mtf_value,
                                                       mask=None,
                                                       is_training=False)
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

        tf_group = lowering.copy_masters_to_slices()
        init = tf.global_variables_initializer()
        self.evaluate(init)
        self.evaluate(tf_group)
        actual = self.evaluate(actual_outputs)

        self.assertEqual(actual.shape, (batch, heads, length_q, depth_v))
Code example #22
def generate_heterogeneous_expert_masks(mask_info, num_experts, experts_dim,
                                        mesh, expert_width):
    """Generates the heterogeous expert masks.

  Example mask_info format:
    mask_info = [{'percent_number': .5, 'layers': 1, 'width':1},
                 {'percent_number': .5, 'layers': 2, 'width':2}]

  Args:
    mask_info: list of dicts.
    num_experts: number of experts in the model
    experts_dim: mtf dimension for experts (partitioned)
    mesh: mesh object
    expert_width: int, default expert width which will be modified by the mask

  Returns:
    mask of shape [expert_hidden (max width), num_expert_layers (max layers), num_experts].
  """
    # Get max num layers
    max_layers = max([m["layers"] for m in mask_info])
    # Get max width
    max_width = max([m["width"] for m in mask_info])
    # Will be shape [max_width, max_layers, num_experts]
    expert_mask = np.zeros([max_width, max_layers, 0])
    for idx, mask_i in enumerate(mask_info):
        if mask_i["percent_number"] < 1.0:
            num_experts_in_mask = int(num_experts * mask_i["percent_number"])
        else:
            num_experts_in_mask = int(mask_i["percent_number"])
        # if percent_number=1 either because homogeneous experts or just 1 expert
        # in which case num_experts_in_mask will be reset to num_experts
        # creating one homogeneous group
        if idx == (len(mask_info) - 1):  # Last position
            num_experts_in_mask_tmp = num_experts - expert_mask.shape[2]
            if num_experts_in_mask_tmp != num_experts_in_mask:
                tf.logging.info(
                    "Expert layer probabilities do not evenly divide "
                    "the number of experts: {} {}".format(
                        num_experts_in_mask, num_experts_in_mask_tmp))
                num_experts_in_mask = num_experts_in_mask_tmp
        mask = np.zeros([int(max_width), int(max_layers), num_experts_in_mask])
        # Zero out the last layers of the experts.
        mask[:(mask_i["width"] * expert_width), :mask_i["layers"], :] = 1
        expert_mask = np.concatenate([expert_mask, mask], axis=2)  # expert dim
    assert expert_mask.shape[2] == num_experts
    tf.logging.info("heterogeneous mask: {}".format(expert_mask))

    # Now import the numpy mask into Mesh TF.
    layers_dim = mtf.Dimension("num_expert_layers", max_layers)
    width_dim = mtf.Dimension("expert_hidden", max_width)
    expert_mask_tf = tf.convert_to_tensor(expert_mask)
    expert_mask_mtf = mtf.import_tf_tensor(
        mesh,
        tf_tensor=expert_mask_tf,
        shape=[width_dim, layers_dim, experts_dim])
    return expert_mask_mtf
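To make the returned layout concrete, here is a NumPy-only illustration of the mask this function builds for the example mask_info from the docstring, with num_experts=4 and expert_width=1 (all of these numbers are assumptions for illustration):

# NumPy-only illustration of the [expert_hidden, num_expert_layers, num_experts]
# mask for mask_info = [{'percent_number': .5, 'layers': 1, 'width': 1},
#                       {'percent_number': .5, 'layers': 2, 'width': 2}].
import numpy as np

max_width, max_layers, expert_width = 2, 2, 1
expert_mask = np.zeros([max_width, max_layers, 0])
for layers, width, n_experts in [(1, 1, 2), (2, 2, 2)]:
    mask = np.zeros([max_width, max_layers, n_experts])
    mask[:width * expert_width, :layers, :] = 1
    expert_mask = np.concatenate([expert_mask, mask], axis=2)
print(expert_mask.shape)  # (2, 2, 4)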
Code example #23
def CreateMeshes(inputs, labels, num_nodes, num_gpus, batch_size):
    graph = mtf.Graph()
    meshes = []
    mesh_to_impl = {}

    mesh = mtf.Mesh(graph, 'mesh0')
    meshes.append(mesh)
    mesh_to_impl[mesh] = utils.GetMeshImpl([num_gpus], 
            gpus_per_node=num_gpus // num_nodes)

    assert len(inputs.shape) == 2
    assert inputs.shape == labels.shape

    shape = utils.ConvertToShape([('axis0', batch_size),
        inputs.shape.as_list()[1]])
    mtf_inputs = mtf.import_tf_tensor(mesh, inputs, shape)
    mtf_labels = mtf.import_tf_tensor(mesh, labels, shape)

    return graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels
Code example #24
  def test_ids_to_embedding_correctlyEmbeds(self):
    seq_len = 5
    vocab_size = 5
    model_size = 2
    gate_embedding_size = 1
    frequent_token_fraction = 0.4

    vocab_dim = mtf.Dimension('vocab', vocab_size)
    model_dim = mtf.Dimension('model', model_size)
    length_dim = mtf.Dimension('length', seq_len)

    context = mock.MagicMock()
    context.train = False

    ids = tf.constant([0, 1, 2, 3, 4], dtype=tf.int32)
    mtf_ids = mtf.import_tf_tensor(
        self.mesh, ids, shape=mtf.Shape([length_dim]))

    self.initializer_mock.side_effect = initialize_by_shape({
        # Embedding weights.
        (5, 2): list(range(10)),
        # Context weights.
        (4, 2, 2): list(range(16)),
        # Prior weights.
        (3, 1, 2): list(range(6)),
        # Prior vocab vector.
        (2, 1): list(range(2)),
        # Prior gates vector.
        (3, 2): list(range(6)),
        # Prior bias.
        (2, 3): list(range(6)),
    })

    vocab_embedding = vocab_embeddings.Mixtape(
        self.mesh,
        vocab_dim,
        output_dim=model_dim,
        variable_dtype=self.variable_dtype,
        name='embedding',
        ensemble_dim=None,
        gate_embedding_size=gate_embedding_size,
        frequent_token_fraction=frequent_token_fraction)

    mtf_embedding = vocab_embedding.ids_to_embedding(mtf_ids, context=None)

    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[''])
    lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl})
    actual_embedding = lowering.export_to_tf_tensor(mtf_embedding)

    self.evaluate(tf.global_variables_initializer())
    self.evaluate(lowering.copy_masters_to_slices())
    actual = self.evaluate([actual_embedding])[0]

    self.assertAllClose(actual, np.reshape(list(range(10)), (5, 2)))
Code example #25
    def test_hidden_to_logits_computesLogitsCorrectly(self):
        seq_len = 4
        vocab_size = 5
        model_size = 2

        vocab_dim = mtf.Dimension('vocab', vocab_size)
        model_dim = mtf.Dimension('model', model_size)
        length_dim = mtf.Dimension('length', seq_len)

        embeddings = tf.constant([[1, 0], [0, 1], [1, 1], [2, 1]],
                                 dtype=tf.float32)
        mtf_embeddings = mtf.import_tf_tensor(self.mesh,
                                              embeddings,
                                              shape=mtf.Shape(
                                                  [length_dim, model_dim]))

        self.initializer_mock.side_effect = initialize_by_shape({
            (2, 2): [[0, 1], [2, 0]],
            (3, 1): [[1], [2], [3]],
            (1, 2): [[1], [2]],
        })

        vocab_embedding = vocab_embeddings.AdaptiveVocabEmbedding(
            self.mesh,
            vocab_dim,
            output_dim=model_dim,
            variable_dtype=self.variable_dtype,
            name='embedding',
            ensemble_dim=None,
            clusters=[{
                'token_count': 2,
                'embedding_size': 2
            }, {
                'token_count': 3,
                'embedding_size': 1
            }])

        mtf_logits = vocab_embedding.hidden_to_logits(mtf_embeddings,
                                                      context=None)

        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[''])
        lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl})
        actual_logits = lowering.export_to_tf_tensor(mtf_logits)

        self.evaluate(tf.global_variables_initializer())
        self.evaluate(lowering.copy_masters_to_slices())
        actual = self.evaluate([actual_logits])[0]

        self.assertAllClose(
            actual,
            model_size**-0.5 * np.array([[0, 2, 1, 2, 3], [1, 0, 2, 4, 6],
                                         [1, 2, 3, 6, 9], [1, 4, 4, 8, 12]]))
Code example #26
File: toy_model_tpu.py Project: tspannhw/mesh
def toy_model(features, mesh):
    """A toy model implemented by mesh tensorlfow."""
    batch_dim = mtf.Dimension('batch', FLAGS.batch_size)
    hidden_dim = mtf.Dimension('hidden', FLAGS.hidden_size)
    io_dim = mtf.Dimension('io', FLAGS.io_size)

    x = mtf.import_tf_tensor(mesh, features, mtf.Shape([batch_dim, io_dim]))
    h = mtf.layers.dense(x, hidden_dim, name='layer1', use_bias=False)
    y = mtf.layers.dense(h, io_dim, name='layer2', use_bias=False)

    loss = mtf.reduce_sum(mtf.square(y - x))
    return y, loss
Code example #27
def Replication2(in_tsr):
    graph = mtf.Graph()
    mesh0 = mtf.Mesh(graph, 'mesh0')
    mesh1 = mtf.Mesh(graph, 'mesh1')
    mesh_to_impl = {mesh0:GetMeshImpl([1, 4]), \
            mesh1:GetMeshImpl([2, 4])}

    shape = in_tsr.get_shape().as_list()
    mtf_shape = GetShape(shape)
    mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
    mtf_out_tsr = mt.ReplaceMeshWithDuplicates(mtf_in_tsr, mesh1)
    Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)
Code example #28
def toy_model(features, mesh):
    """A toy model implemented by mesh tensorlfow."""
    batch_dim = mtf.Dimension('batch', FLAGS.batch_size)
    io_dim = mtf.Dimension('io', FLAGS.io_size)

    master_dtype = tf.as_dtype(FLAGS.master_dtype)
    slice_dtype = tf.as_dtype(FLAGS.slice_dtype)
    activation_dtype = tf.as_dtype(FLAGS.activation_dtype)

    x = mtf.import_tf_tensor(mesh, features, mtf.Shape([batch_dim, io_dim]))
    x = mtf.cast(x, activation_dtype)
    h = x
    for lnum in range(1, FLAGS.num_hidden_layers + 2):
        if lnum + 1 == FLAGS.num_hidden_layers + 2:
            # output layer
            dim = io_dim
        elif lnum % 2 == 0:
            dim = mtf.Dimension('hidden_even', FLAGS.hidden_size)
        else:
            dim = mtf.Dimension('hidden_odd', FLAGS.hidden_size)
        h = mtf.layers.dense(h,
                             dim,
                             use_bias=False,
                             master_dtype=master_dtype,
                             slice_dtype=slice_dtype,
                             name='layer_%d' % lnum)
    y = h
    g = tf.train.get_global_step()
    if FLAGS.step_with_nan >= 0:
        # Trigger NaN in the forward pass, this is used for testing whether
        # MeshTensorFlow can handle occasional NaN value.
        y += mtf.import_tf_tensor(
            mesh,
            tf.divide(
                0.0,
                tf.cond(tf.equal(g, FLAGS.step_with_nan), lambda: 0.,
                        lambda: 1.)), mtf.Shape([]))

    loss = mtf.reduce_mean(mtf.square(y - x))
    return y, loss
Code example #29
def Split3(in_tsr):
    graph = mtf.Graph()
    mesh0 = mtf.Mesh(graph, 'mesh0')
    mesh1 = mtf.Mesh(graph, 'mesh1')
    mesh_to_impl = {mesh0:GetMeshImpl([2, 2], [0, 2, 4, 6]), \
            mesh1:GetMeshImpl([2, 4])}

    shape = in_tsr.get_shape().as_list()
    mtf_shape = GetShape(shape[:2] + [('axis1', shape[2])] + shape[3:])
    mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
    mtf_out_tsr = mt.ReplaceMeshWithConcatSplit(mtf_in_tsr, mesh1,
                                                mtf_shape.dimension_names)
    Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)
Code example #30
def NoConcatSplit(in_tsr):
    graph = mtf.Graph()
    mesh0 = mtf.Mesh(graph, 'mesh0')
    mesh1 = mtf.Mesh(graph, 'mesh1')
    mesh_to_impl = {mesh0:GetMeshImpl([4, 2]), \
            mesh1:GetMeshImpl([4, 2])}

    shape = in_tsr.get_shape().as_list()
    mtf_shape = GetShape([shape[0], ('axis0', shape[1])] + shape[2:])
    mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
    mtf_out_tsr = mt.ReplaceMeshWithConcatSplit(mtf_in_tsr, mesh1,
                                                mtf_shape.dimension_names)
    Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)
Code example #31
 def import_to_batch_by_length(x, name):
   # mesh, batch_dim and self.length_dim come from the enclosing method's
   # scope; each tf.Tensor x is imported with the same [batch, length] shape.
   return mtf.import_tf_tensor(
       mesh, x, mtf.Shape([batch_dim, self.length_dim]), name=name)
Code example #32
File: mtf_resnet.py Project: qixiuai/tensor2tensor
  def mtf_model_fn(self, features, mesh):
    features = copy.copy(features)
    tf.logging.info("features = %s" % features)
    hparams = self._hparams
    activation_dtype = self.set_activation_type()
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

    # Declare all the dimensions
    batch_dim = mtf.Dimension("batch", hparams.batch_size)
    hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
    filter_h_dim = mtf.Dimension("filter_height", 7)
    filter_w_dim = mtf.Dimension("filter_width", 7)
    filters = mtf.Dimension("filters", hparams.filter_sizes[0])
    rows_dim = mtf.Dimension("rows_size", hparams.rows_size)
    cols_dim = mtf.Dimension("cols_size", hparams.cols_size)
    row_blocks_dim = mtf.Dimension("row_blocks", hparams.row_blocks)
    col_blocks_dim = mtf.Dimension("col_blocks", hparams.col_blocks)
    classes_dim = mtf.Dimension("classes", 10)
    channels_dim = mtf.Dimension("channels", 3)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    inputs = features["inputs"]
    x = mtf.import_tf_tensor(
        mesh, tf.reshape(inputs, [
            hparams.batch_size,
            hparams.row_blocks,
            hparams.rows_size // hparams.row_blocks,
            hparams.col_blocks,
            hparams.num_channels*hparams.cols_size // hparams.col_blocks,
            hparams.num_channels]),
        mtf.Shape(
            [batch_dim, row_blocks_dim, rows_dim,
             col_blocks_dim, cols_dim, channels_dim]))
    x = mtf.transpose(x, [batch_dim, row_blocks_dim, col_blocks_dim,
                          rows_dim, cols_dim, channels_dim])

    x = mtf.to_float(x)
    initial_filters = mtf.get_variable(
        mesh, "init_filters",
        mtf.Shape([filter_h_dim, filter_w_dim, channels_dim, filters]))
    x = mtf.conv2d_with_blocks(
        x,
        initial_filters,
        strides=[1, 1, 1, 1],
        padding="SAME",
        h_blocks_dim=None, w_blocks_dim=col_blocks_dim)

    x = batch_norm_relu(x, is_training)

    # Conv blocks
    # [block - strided block layer - strided block layer] x n
    for layer in range(hparams.num_layers):
      layer_name = "block_layer_%d" % layer
      with tf.variable_scope(layer_name):
        # Residual block layer
        x = block_layer(
            inputs=x,
            filters=hparams.filter_sizes[0],
            blocks=hparams.layer_sizes[0],
            strides=[1, 1, 1, 1],
            is_training=is_training,
            name="block_layer1",
            row_blocks_dim=None,
            col_blocks_dim=None)
        x = block_layer(
            inputs=x,
            filters=hparams.filter_sizes[1],
            blocks=hparams.layer_sizes[1],
            strides=[1, 1, 1, 1],
            is_training=is_training,
            name="block_layer2",
            row_blocks_dim=None,
            col_blocks_dim=None)
        x = block_layer(
            inputs=x,
            filters=hparams.filter_sizes[2],
            blocks=hparams.layer_sizes[2],
            strides=[1, 1, 1, 1],
            is_training=is_training,
            name="block_layer3",
            row_blocks_dim=None,
            col_blocks_dim=None)

    # Calculate the logits and loss.
    out = x
    outputs = mtf.layers.dense(
        out, hidden_dim,
        reduced_dims=out.shape.dims[-5:],
        activation=mtf.relu, name="dense")

    # We assume fixed vocab size for targets
    labels = tf.squeeze(tf.to_int32(features["targets"]), [2, 3])
    labels = mtf.import_tf_tensor(
        mesh, tf.reshape(labels, [hparams.batch_size]), mtf.Shape([batch_dim]))

    logits = mtf.layers.dense(outputs, classes_dim, name="logits")
    soft_targets = mtf.one_hot(labels, classes_dim, dtype=activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, classes_dim)

    # Reshape logits so it doesn't break inside t2t.
    logits = mtf.reshape(
        logits,
        mtf.Shape([batch_dim, one_channel_dim, classes_dim]))
    loss = mtf.reduce_mean(loss)
    return logits, loss
Code example #33
File: mtf_model.py Project: qixiuai/tensor2tensor
  def estimator_model_fn(cls,
                         hparams,
                         features,
                         labels,
                         mode,
                         config=None,
                         params=None,
                         decode_hparams=None,
                         use_tpu=False):
    hparams = copy.deepcopy(hparams)
    hparams.use_tpu = use_tpu
    # merge decode_hparams into hparams if present
    if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
      for k, v in six.iteritems(decode_hparams.values()):
        if hasattr(hparams, k) and getattr(hparams, k) != v:
          tf.logging.warning("Overriding hparams.%s with %s from decode_hparams"
                             % (k, v))
        setattr(hparams, k, v)

    # Instantiate model
    data_parallelism = None
    if not use_tpu and config:
      data_parallelism = config.data_parallelism
    model = cls(
        hparams,
        mode,
        data_parallelism=data_parallelism,
        decode_hparams=decode_hparams)

    global_step = tf.train.get_global_step()

    mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(hparams.layout)
    if use_tpu:
      ctx = params["context"]
      num_hosts = ctx.num_hosts
      host_placement_fn = ctx.tpu_host_placement_function
      device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
      # TODO(ylc): Better estimation of replica cache size?
      replica_cache_size = 300 * 1000000  # 300M per replica
      # Worker 0 caches all the TPU binaries.
      worker0_mem = replica_cache_size * ctx.num_replicas
      devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
      var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                    devices_memory_usage)
      mesh_devices = [""] * mesh_shape.size
      mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
          mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
    else:
      var_placer = None
      if data_parallelism is None or len(data_parallelism.ps_devices) == 1:
        mesh_devices = [""] * mesh_shape.size
      else:
        assert len(data_parallelism.ps_devices) == mesh_shape.size
        mesh_devices = data_parallelism.ps_devices
      mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
          mesh_shape, layout_rules, mesh_devices)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)
    # PREDICT mode
    if mode == tf.estimator.ModeKeys.PREDICT:
      return model.estimator_spec_predict(features, mesh, mesh_impl, use_tpu)

    logits, loss = model.mtf_model_fn(features, mesh)
    if use_tpu and logits is not None:
      logits = mtf.anonymize(logits)

    # TRAIN mode
    if mode == tf.estimator.ModeKeys.TRAIN:
      var_grads = mtf.gradients(
          [loss], [v.outputs[0] for v in graph.trainable_variables])
      lr = learning_rate.learning_rate_schedule(hparams)
      tf.summary.scalar("learning_rate", lr)
      mtf_lr = mtf.import_tf_tensor(
          mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
      optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
      update_ops = []
      for grad, var in zip(var_grads, graph.trainable_variables):
        update_ops.extend(optimizer.apply_grad(grad, var))

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.to_float(tf_loss)
    if logits and mode != tf.estimator.ModeKeys.TRAIN:
      tf_logits = lowering.export_to_tf_tensor(logits)

    if mode == tf.estimator.ModeKeys.TRAIN:
      tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
      tf_update_ops.append(tf.assign_add(global_step, 1))
      # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
      train_op = tf.group(tf_update_ops)

    with mtf.utils.outside_all_rewrites():
      # Copy master variables to slices. Must be called first.
      restore_hook = mtf.MtfRestoreHook(lowering)
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          hparams.model_dir,
          save_steps=1000,
          saver=saver,
          listeners=[saver_listener])

    # EVAL mode
    if mode == tf.estimator.ModeKeys.EVAL:
      tf_logits = lowering.export_to_tf_tensor(logits)
      return model.estimator_spec_eval(features, tf_logits, labels, tf_loss,
                                       restore_hook, use_tpu)

    if use_tpu:
      # TPU host call. Important: need to be called before remove_summaries()
      if hparams.tpu_enable_host_call:
        host_call = t2t_model.create_host_call(hparams.model_dir)
      else:
        host_call = None

      t2t_model.remove_summaries()
      return tpu_estimator.TPUEstimatorSpec(
          mode=tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          host_call=host_call,
          training_hooks=[restore_hook, saver_hook])
    else:
      return tf.estimator.EstimatorSpec(
          tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op,
          training_chief_hooks=[restore_hook, saver_hook])