Example #1
  def testNthSmallestReduceSecondDim(self):
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    a_dim = mtf.Dimension("a", 6)
    b_dim = mtf.Dimension("b", 2)
    inputs = tf.constant([[1, 10],
                          [2, 9],
                          [3, 8],
                          [4, 7],
                          [5, 6],
                          [6, 5]])
    n = 0  # find smallest element (n is zero-indexed)
    reduced_dim = b_dim
    expected_outputs = tf.constant([1, 2, 3, 4, 5, 5])

    mtf_inputs = mtf.import_tf_tensor(
        mesh, inputs, shape=mtf.Shape([a_dim, b_dim]))
    mtf_outputs = mtf.nth_smallest_element(
        mtf_inputs, n, reduced_dim, "test_nth_smallest")
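    # Lower onto a two-device placement mesh; layout "a:all" splits the "a" dimension across the mesh.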
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape="all:2", layout="a:all", devices=["", ""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
    self.assertAllEqual(self.evaluate(actual_outputs),
                        self.evaluate(expected_outputs))
Example #2
            def computation_fn():
                graph = mtf.Graph()
                mesh = mtf.Mesh(graph, 'my_mesh')
                mesh_shape = mtf.convert_to_shape('all:2')
                layout = 'none:all'
                mesh_devices = [''] * mesh_shape.size
                mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                    mesh_shape, mtf.convert_to_layout_rules(layout),
                    mesh_devices, device_assignment)
                hidden_dim = mtf.Dimension('hidden', 3)
                w = mtf.get_variable(mesh,
                                     'w',
                                     shape=[hidden_dim],
                                     initializer=tf.constant_initializer(
                                         [0.1, -0.2, -0.1]))
                x = mtf.constant(mesh, [0.4, 0.2, -0.5], [hidden_dim],
                                 dtype=tf.float32)
                loss = mtf.reduce_mean(mtf.square(x - w))

                lr, update_ops = optimization_lib.create_optimizer(
                    loss, 0.2, 100, 10)
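                # Lower the Mesh TF graph, then translate each mtf update op into its TF counterpart.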
                self.lowering = mtf.Lowering(graph, {mesh: mesh_impl})

                tf_update_ops = [
                    self.lowering.lowered_operation(op) for op in update_ops
                ]
                tf_update_ops.append(
                    tf.assign_add(tf.train.get_or_create_global_step(), 1))
                train_op = tf.group(tf_update_ops)

                return lr, train_op
Example #3
 def estimator_spec_predict(self, features, mesh, mesh_impl, use_tpu):
     mtf_samples = self.sample(features, mesh)
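     # Lower the graph and export the sampled mtf.Tensor back to an ordinary tf.Tensor.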
     lowering = mtf.Lowering(mesh.graph, {mesh: mesh_impl})
     outputs = lowering.export_to_tf_tensor(mtf_samples)
     if self.has_input:
         ndims = len(outputs.shape.as_list())
         actual_batch_size = tf.shape(features["inputs"])[0]
         outputs = tf.slice(outputs, [0] * ndims,
                            [actual_batch_size] + [-1] * (ndims - 1))
     predictions = {
         "outputs": outputs,
         "targets": features.get("infer_targets", features.get("inputs")),
         "inputs": features.get("inputs"),
     }
     if use_tpu:
         t2t_model.remove_summaries()
         return tpu_estimator.TPUEstimatorSpec(
             mode=tf.estimator.ModeKeys.PREDICT,
             predictions=predictions,
             prediction_hooks=[mtf.MtfRestoreHook(lowering)])
     else:
         return tf.estimator.EstimatorSpec(
             tf.estimator.ModeKeys.PREDICT,
             predictions=predictions,
             prediction_hooks=[mtf.MtfRestoreHook(lowering)])
Example #4
File: bert_test.py Project: lynex/mesh
            def computation_fn():
                graph = mtf.Graph()
                mesh = mtf.Mesh(graph, 'my_mesh')
                mesh_shape = mtf.convert_to_shape('all:2')
                layout = 'num_heads:all'
                mesh_devices = [''] * mesh_shape.size
                mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                    mesh_shape, mtf.convert_to_layout_rules(layout),
                    mesh_devices, device_assignment)
                batch_dim = mtf.Dimension('batch', batch_size)
                seq_dim = mtf.Dimension('seq', seq_length)

                input_ids = tf.random.uniform((batch_size, seq_length),
                                              minval=0,
                                              maxval=vocab_size,
                                              dtype=tf.int32)
                mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids,
                                                     [batch_dim, seq_dim])

                model = bert_lib.BertModel(config=bert_config,
                                           is_training=True,
                                           input_ids=mtf_input_ids,
                                           input_mask=None,
                                           token_type_ids=None)
                pooled = model.get_pooled_output()
                lowering = mtf.Lowering(graph, {mesh: mesh_impl})
                return lowering.export_to_tf_tensor(pooled)
Example #5
    def testRecomputeGrad(self):
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")

        # let's differentiate x^2 + x
        # dy/dx = 2x+1
        def x_squared_plus_x(x):
            return x * x + x

        x = tf.constant([5, 10], dtype=tf.float32)
        dy = tf.constant([2, 3], dtype=tf.float32)
        two = mtf.Dimension("two", 2)
        expected_y = tf.constant([30, 110], dtype=tf.float32)
        expected_dx = tf.constant([22, 63], dtype=tf.float32)
        mtf_x = mtf.import_fully_replicated(mesh, x, shape=mtf.Shape([two]))
        mtf_dy = mtf.import_tf_tensor(mesh, dy, shape=mtf.Shape([two]))
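        # recompute_grad recomputes the forward pass during backprop instead of storing activations.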
        mtf_y = mtf.recompute_grad(x_squared_plus_x, [mtf_x])
        [mtf_dx] = mtf.gradients([mtf_y], [mtf_x], [mtf_dy])
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            shape="processors:2", layout="two:processors", devices=["", ""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_y = lowering.export_to_tf_tensor(mtf_y)
        actual_dx = lowering.export_to_tf_tensor(mtf_dx)
        self.assertAllEqual(self.evaluate(actual_y), self.evaluate(expected_y))
        self.assertAllEqual(self.evaluate(actual_dx),
                            self.evaluate(expected_dx))
Example #6
 def computation_fn():
     graph = mtf.Graph()
     mesh = mtf.Mesh(graph, 'my_mesh')
     mesh_shape = mtf.convert_to_shape('all:2')
     layout = 'none:all'
     mesh_devices = [''] * mesh_shape.size
     mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
         mesh_shape, mtf.convert_to_layout_rules(layout),
         mesh_devices, device_assignment)
     hidden_dim = mtf.Dimension('hidden', 3)
     w = mtf.get_variable(mesh,
                          'w',
                          shape=[hidden_dim],
                          initializer=tf.constant_initializer(
                              [0.1, -0.2, -0.1]))
     x = mtf.constant(mesh, [0.4, 0.2, -0.5], [hidden_dim],
                      dtype=tf.float32)
     loss = mtf.reduce_mean(mtf.square(x - w))
     var_grads = mtf.gradients(
         [loss], [v.outputs[0] for v in graph.trainable_variables])
     optimizer = mtf_optimize.AdamWeightDecayOptimizer(
         learning_rate=0.2)
     update_ops = optimizer.apply_grads(var_grads,
                                        graph.trainable_variables)
     self.lowering = mtf.Lowering(graph, {mesh: mesh_impl})
     tf_update_ops = [
         self.lowering.lowered_operation(op) for op in update_ops
     ]
     return tf.group(tf_update_ops)
Example #7
    def testDynamicText2self_unpacked(self):
        batch = 2
        length = 5
        input_tensors = {
            "inputs": [[3, 1, 4, 1, 0], [1, 4, 3, 2, 1]],
            "targets": [[1, 1, 0, 0, 0], [9, 8, 1, 2, 1]],
        }
        expected_output_tensors = {
            "targets": [[3, 1, 4, 1, 1, 1, 0, 0, 0, 0],
                        [1, 4, 3, 2, 1, 9, 8, 1, 2, 1]],
        }
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", batch)
        length_dim = mtf.Dimension("length", length)

        input_shape = mtf.Shape([batch_dim, length_dim])
        mtf_features = {
            k: mtf.import_tf_tensor(mesh, v, input_shape)
            for k, v in input_tensors.items()
        }
        mtf_outputs = utils._dynamic_text2self(mtf_features)
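        # An empty mesh shape means no splitting: everything runs on a single device.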
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        for k, v in expected_output_tensors.items():
            out = lowering.export_to_tf_tensor(mtf_outputs[k])
            actual = self.evaluate(out)
            self.assertAllEqual(actual, v)
Example #8
  def testDense(self, units, use_bias):
    batch = 2
    channels = 3
    inputs = tf.random_normal([batch, channels])

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    batch_dim = mtf.Dimension("batch", batch)
    channels_dim = mtf.Dimension("channels", channels)
    depth_dim = mtf.Dimension("depth", units)

    mtf_inputs = mtf.import_tf_tensor(
        mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
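    # mtf.layers.dense reduces over channels_dim and emits depth_dim, mirroring the Keras Dense layer below.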
    mtf_outputs = mtf.layers.dense(mtf_inputs,
                                   output_dim=depth_dim,
                                   reduced_dims=[channels_dim],
                                   activation=mtf.relu,
                                   use_bias=use_bias)
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

    expected_outputs = tf.keras.layers.Dense(units=units,
                                             activation=tf.nn.relu,
                                             use_bias=use_bias)(inputs)
    tf_group = lowering.copy_masters_to_slices()
    init = tf.global_variables_initializer()
    self.evaluate(init)
    self.evaluate(tf_group)
    actual, expected = self.evaluate([actual_outputs, expected_outputs])

    self.assertEqual(actual.shape, expected.shape)
Example #9
    def testMaskedLocalAttention1D(self, batch, length, io_channels,
                                   kv_channels, heads, window_size):
        length_q = length
        query = tf.random_normal([batch, length_q, io_channels])

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", batch)
        length_q_dim = mtf.Dimension("length_q", length_q)
        io_channels_dim = mtf.Dimension("io_channels", io_channels)
        kv_channels_dim = mtf.Dimension("kv_channels", kv_channels)
        heads_dim = mtf.Dimension("heads", heads)

        mtf_query = mtf.import_tf_tensor(
            mesh,
            query,
            shape=mtf.Shape([batch_dim, length_q_dim, io_channels_dim]))
        mtf_outputs = mtf.layers.masked_local_attention_1d(
            mtf_query,
            kv_channels=kv_channels_dim,
            heads=heads_dim,
            window_size=window_size)
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

        tf_group = lowering.copy_masters_to_slices()
        init = tf.global_variables_initializer()
        self.evaluate(init)
        self.evaluate(tf_group)
        actual = self.evaluate(actual_outputs)

        self.assertEqual(actual.shape, (batch, length_q, io_channels))
Example #10
def model_fn(features, labels, mode, params):
  """A model is called by TpuEstimator."""
  del labels
  del features

  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

  ctx = params['context']
  num_hosts = ctx.num_hosts
  host_placement_fn = ctx.tpu_host_placement_function
  device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
  tf.logging.info('device_list = %s', device_list)

  mesh_devices = [''] * mesh_shape.size
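  # SimdMeshImpl partitions the computation across TPU cores using the estimator's device assignment.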
  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules,
                                              mesh_devices,
                                              ctx.device_assignment)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "fft_mesh")

  with mtf.utils.outside_all_rewrites():
    fsum = benchmark_model(mesh)
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  tf_err = tf.to_float(lowering.export_to_tf_tensor(fsum))

  with mtf.utils.outside_all_rewrites():
    return tpu_estimator.TPUEstimatorSpec(mode, loss=tf_err)
Example #11
    def test_hidden_to_logits_computesLogitsCorrectly(self):
        seq_len = 1
        vocab_size = 4
        model_size = 3
        num_softmaxes = 2

        vocab_dim = mtf.Dimension('vocab', vocab_size)
        model_dim = mtf.Dimension('model', model_size)
        length_dim = mtf.Dimension('length', seq_len)

        embeddings = tf.constant(np.array([[1.0, 1.0, 2.0]]) /
                                 model_size**-0.5,
                                 dtype=tf.float32)
        mtf_embeddings = mtf.import_tf_tensor(self.mesh,
                                              embeddings,
                                              shape=mtf.Shape(
                                                  [length_dim, model_dim]))

        self.initializer_mock.side_effect = initialize_by_shape({
            # Embedding weights.
            (4, 3): [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]],
            # Mixture weights.
            (2, 3): [[1, 0, 0], [0, 1, 1]],
            # Context weights
            (2, 3, 3): [
                [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
                [[0, 0, 1], [0, 1, 0], [1, 0, 0]],
            ],
        })

        vocab_embedding = vocab_embeddings.MixtureOfSoftmaxes(
            self.mesh,
            vocab_dim,
            output_dim=model_dim,
            variable_dtype=self.variable_dtype,
            name='embedding',
            ensemble_dim=None,
            num_softmaxes=num_softmaxes)

        mtf_logits = vocab_embedding.hidden_to_logits(mtf_embeddings,
                                                      context=None)

        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[''])
        lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl})
        actual_logits = lowering.export_to_tf_tensor(mtf_logits)

        self.evaluate(tf.global_variables_initializer())
        self.evaluate(lowering.copy_masters_to_slices())
        actual, = self.evaluate([actual_logits])

        expected_priors = scipy.special.softmax([1, 3])
        expected_probs_1 = scipy.special.softmax(np.tanh([1, 1, 2, 2]))
        expected_probs_2 = scipy.special.softmax(np.tanh([2, 1, 1, 1]))
        expected_probs = (expected_priors[0] * expected_probs_1 +
                          expected_priors[1] * expected_probs_2)
        expected_logits = np.log(expected_probs)

        self.assertAllClose(actual, [expected_logits])
Example #12
 def convert_mtf_tensor_to_tf_tensor(self, mtf_tensor):
     """Convert an mtf.Tensor to a tf.Tensor."""
     mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                           layout={},
                                                           devices=[""])
     lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl})
     return lowering, lowering.export_to_tf_tensor(mtf_tensor)
Example #13
    def testWeightsNonzero(self):
        inputs = tf.constant([[3, 1, 0], [1, 0, 0]])

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", inputs.shape.as_list()[0])
        channels_dim = mtf.Dimension("channels", inputs.shape.as_list()[1])

        mtf_inputs = mtf.import_tf_tensor(mesh,
                                          inputs,
                                          shape=mtf.Shape(
                                              [batch_dim, channels_dim]))
        mtf_outputs = mtf.layers.weights_nonzero(mtf_inputs)
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

        expected_outputs = tf.cast(tf.not_equal(inputs, 0), tf.float32)
        tf_group = lowering.copy_masters_to_slices()
        self.evaluate(tf_group)
        actual, expected = self.evaluate([actual_outputs, expected_outputs])

        self.assertAllEqual(actual, expected)
Example #14
    def testDenseReluDense(self):
        batch = 2
        channels = 3
        hidden = 5
        inputs = tf.random_normal([batch, channels])

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", batch)
        channels_dim = mtf.Dimension("channels", channels)
        hidden_dim = mtf.Dimension("hidden", hidden)

        mtf_inputs = mtf.import_tf_tensor(mesh,
                                          inputs,
                                          shape=mtf.Shape(
                                              [batch_dim, channels_dim]))
        mtf_outputs = mtf.layers.dense_relu_dense(mtf_inputs,
                                                  hidden_channels=hidden_dim,
                                                  is_training=False)
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

        tf_group = lowering.copy_masters_to_slices()
        init = tf.global_variables_initializer()
        self.evaluate(init)
        self.evaluate(tf_group)
        actual = self.evaluate(actual_outputs)

        self.assertEqual(actual.shape, inputs.shape)
Example #15
    def testCorr2DInput(self):
        batch = 4
        channels = 3
        inputs = tf.random_normal([batch, channels])

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", batch)
        channels_dim = mtf.Dimension("channels", channels)

        mtf_inputs = mtf.import_tf_tensor(mesh,
                                          inputs,
                                          shape=mtf.Shape(
                                              [batch_dim, channels_dim]))
        mtf_outputs = mtf.layers.corr(mtf_inputs, dim=channels_dim)
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

        expected_outputs = tfp.stats.correlation(inputs,
                                                 sample_axis=0,
                                                 event_axis=1)
        tf_group = lowering.copy_masters_to_slices()
        init = tf.global_variables_initializer()
        self.evaluate(init)
        self.evaluate(tf_group)
        actual, expected = self.evaluate([actual_outputs, expected_outputs])

        self.assertEqual(actual.shape, expected.shape)
        self.assertAllClose(actual, expected)
Example #16
  def testMultiheadAttention(self, kv_channels, heads):
    batch = 2
    length = 8
    channels = 3
    query = tf.random_normal([batch, length, channels])

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    batch_dim = mtf.Dimension("batch", batch)
    length_dim = mtf.Dimension("length", length)
    channels_dim = mtf.Dimension("channels", channels)
    kv_channels_dim = mtf.Dimension("kv_channels", kv_channels)
    heads_dim = mtf.Dimension("heads", heads)

    mtf_query = mtf.import_tf_tensor(
        mesh, query,
        shape=mtf.Shape([batch_dim, length_dim, channels_dim]))
    mtf_outputs = mtf.layers.multihead_attention(
        mtf_query,
        memory_antecedent=None,
        mask=None,
        kv_channels=kv_channels_dim,
        heads=heads_dim)
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

    tf_group = lowering.copy_masters_to_slices()
    init = tf.global_variables_initializer()
    self.evaluate(init)
    self.evaluate(tf_group)
    actual = self.evaluate(actual_outputs)

    self.assertEqual(actual.shape, query.shape)
Example #17
  def testConv1d(self):
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    filter_size = 3
    depth_dim = mtf.Dimension("depth", 2)
    length_dim = mtf.Dimension("length", 4)
    output_dim = mtf.Dimension("output", 2)

    x = tf.constant([[1, 0], [0, 1], [1, 1], [2, 1]], dtype=tf.float32)
    mtf_x = mtf.import_tf_tensor(
        mesh, x, shape=mtf.Shape([length_dim, depth_dim]))

    initializer_mock = mock.MagicMock()
    initializer_mock.side_effect = initialize_by_shape({
        (1, 3, 2, 2): [[[[1, -1], [0, 0]], [[2, -2], [-1, 1]], [[3, -3],
                                                                [-2, 2]]]],
    })

    mtf_output = mtf.layers.conv1d(
        mtf_x,
        output_dim=output_dim,
        filter_size=filter_size,
        filter_initializer=initializer_mock)
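    # The mocked initializer supplies fixed filter weights; conv1d uses SAME padding, so length is preserved.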

    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_output = lowering.export_to_tf_tensor(mtf_output)

    self.evaluate(tf.global_variables_initializer())
    self.evaluate(lowering.copy_masters_to_slices())
    actual = self.evaluate(actual_output)

    self.assertAllClose(actual, [[0, 0], [1, -1], [5, -5], [4, -4]])
Example #18
def main(_):

    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

    # Resolve the cluster from SLURM environment
    cluster = tf.distribute.cluster_resolver.SlurmClusterResolver(
        {"mesh": mesh_shape.size // FLAGS.gpus_per_task},
        port_base=8822,
        gpus_per_node=FLAGS.gpus_per_node,
        gpus_per_task=FLAGS.gpus_per_task,
        tasks_per_node=FLAGS.tasks_per_node)

    cluster_spec = cluster.cluster_spec()
    # Create a server for all mesh members
    server = tf.distribute.Server(cluster_spec, "mesh", cluster.task_id)

    # Only the master job takes care of the graph building;
    # everyone else can just chill for now
    if cluster.task_id > 0:
        server.join()

    # Otherwise we are the main task, let's define the devices
    mesh_devices = [
        "/job:mesh/task:%d/device:GPU:%d" % (i, j)
        for i in range(cluster_spec.num_tasks("mesh"))
        for j in range(FLAGS.gpus_per_node)
    ]
    print("List of devices", mesh_devices)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "fft_mesh")

    # Build the model
    fft_err = benchmark_model(mesh)

    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    # Retrieve output of computation
    result = lowering.export_to_tf_tensor(fft_err)

    with tf.Session(server.target) as sess:
        start = time.time()
        err = sess.run(result)
        end = time.time()

        time.sleep(1)
        start = time.time()
        err = sess.run(result)
        end = time.time()

    print("Max absolute FFT error %f, with wall time %f" % (err,
                                                            (end - start)))
    time.sleep(1)
    exit(0)
Example #19
def model_fn(features, labels, mode, params):
	"""The model_fn argument for creating an Estimator."""
	global_step = tf.train.get_global_step()
	graph = mtf.Graph()
	mesh = mtf.Mesh(graph, "my_mesh")
	logits, loss = model_backbone(features, labels, mesh)
	
	variables = graph._all_variables
	for v in variables:
		logger.debug("[parameter] (name,shape,dtype): ({},{},{})".format(
			v.name, v.shape, v.dtype.master_dtype))
	mesh_shape = mtf.convert_to_shape(args_opt.mesh_shape)
	# layout_rules = mtf.auto_mtf.layout(graph, mesh_shape, [logits, loss])
	estimator = memory_estimator.MemoryEstimator(graph, mesh_shape, [logits, loss])
	optimizer = layout_optimizer.LayoutOptimizer(estimator, scheduler_alg="NAIVE")
	layout_rules = mtf.convert_to_layout_rules(optimizer.solve())

	# layout_rules=[('batch', 'b1')]

	logger.info("[auto mtf search] strategy: {}".format(layout_rules))
	mesh_devices = ["gpu:{}".format(i) for i in range(int(args_opt.num_gpus))]
	mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(mesh_shape, layout_rules, mesh_devices)

	if mode == tf.estimator.ModeKeys.TRAIN:
		var_grads = mtf.gradients(
			[loss], [v.outputs[0] for v in graph.trainable_variables])
		optimizer = mtf.optimize.SgdOptimizer(0.01)
		# optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
		update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

	lowering = mtf.Lowering(graph, {mesh: mesh_impl})
	restore_hook = mtf.MtfRestoreHook(lowering)

	tf_logits = lowering.export_to_tf_tensor(logits)
	if mode != tf.estimator.ModeKeys.PREDICT:
		tf_loss = lowering.export_to_tf_tensor(loss)

	if mode == tf.estimator.ModeKeys.TRAIN:
		tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
		tf_update_ops.append(tf.assign_add(global_step, 1))
		train_op = tf.group(tf_update_ops)

		accuracy = tf.metrics.accuracy(
			labels=labels, predictions=tf.argmax(tf_logits, axis=1))

		# Name tensors to be logged with LoggingTensorHook.
		tf.identity(tf_loss, "cross_entropy")
		tf.identity(accuracy[1], name="train_accuracy")

		logging_hook = tf.train.LoggingTensorHook(
			every_n_iter=100,
			tensors={'loss': 'cross_entropy', 'acc': 'train_accuracy'})
		# profiling_hook = tf.estimator.ProfilerHook(save_steps=20, output_dir='./profiling/')
		# restore_hook must come before saver_hook
		return tf.estimator.EstimatorSpec(
			tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op,
			training_chief_hooks=[restore_hook, logging_hook])
Example #20
def Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr):
    lowering = mtf.Lowering(graph, mesh_to_impl)
    out_tsr = lowering.export_to_tf_tensor(mtf_out_tsr)
    assert_op = tf.assert_equal(in_tsr, out_tsr)

    func_name = inspect.stack()[1].function
    print(f'Running test {func_name}')
    with tf.Session() as sess:
        sess.run(assert_op)
        print(f'Test {func_name} successful\n')
Example #21
  def test_ids_to_embedding_correctlyEmbeds(self):
    seq_len = 5
    vocab_size = 5
    model_size = 2
    gate_embedding_size = 1
    frequent_token_fraction = 0.4

    vocab_dim = mtf.Dimension('vocab', vocab_size)
    model_dim = mtf.Dimension('model', model_size)
    length_dim = mtf.Dimension('length', seq_len)

    context = mock.MagicMock()
    context.train = False

    ids = tf.constant([0, 1, 2, 3, 4], dtype=tf.int32)
    mtf_ids = mtf.import_tf_tensor(
        self.mesh, ids, shape=mtf.Shape([length_dim]))

    self.initializer_mock.side_effect = initialize_by_shape({
        # Embedding weights.
        (5, 2): list(range(10)),
        # Context weights.
        (4, 2, 2): list(range(16)),
        # Prior weights.
        (3, 1, 2): list(range(6)),
        # Prior vocab vector.
        (2, 1): list(range(2)),
        # Prior gates vector.
        (3, 2): list(range(6)),
        # Prior bias.
        (2, 3): list(range(6)),
    })

    vocab_embedding = vocab_embeddings.Mixtape(
        self.mesh,
        vocab_dim,
        output_dim=model_dim,
        variable_dtype=self.variable_dtype,
        name='embedding',
        ensemble_dim=None,
        gate_embedding_size=gate_embedding_size,
        frequent_token_fraction=frequent_token_fraction)

    mtf_embedding = vocab_embedding.ids_to_embedding(mtf_ids, context=None)

    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[''])
    lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl})
    actual_embedding = lowering.export_to_tf_tensor(mtf_embedding)

    self.evaluate(tf.global_variables_initializer())
    self.evaluate(lowering.copy_masters_to_slices())
    actual = self.evaluate([actual_embedding])[0]

    self.assertAllClose(actual, np.reshape(list(range(10)), (5, 2)))
Example #22
    def test_hidden_to_logits_computesLogitsCorrectly(self):
        seq_len = 4
        vocab_size = 5
        model_size = 2

        vocab_dim = mtf.Dimension('vocab', vocab_size)
        model_dim = mtf.Dimension('model', model_size)
        length_dim = mtf.Dimension('length', seq_len)

        embeddings = tf.constant([[1, 0], [0, 1], [1, 1], [2, 1]],
                                 dtype=tf.float32)
        mtf_embeddings = mtf.import_tf_tensor(self.mesh,
                                              embeddings,
                                              shape=mtf.Shape(
                                                  [length_dim, model_dim]))

        self.initializer_mock.side_effect = initialize_by_shape({
            (2, 2): [[0, 1], [2, 0]],
            (3, 1): [[1], [2], [3]],
            (1, 2): [[1], [2]],
        })

        vocab_embedding = vocab_embeddings.AdaptiveVocabEmbedding(
            self.mesh,
            vocab_dim,
            output_dim=model_dim,
            variable_dtype=self.variable_dtype,
            name='embedding',
            ensemble_dim=None,
            clusters=[{
                'token_count': 2,
                'embedding_size': 2
            }, {
                'token_count': 3,
                'embedding_size': 1
            }])

        mtf_logits = vocab_embedding.hidden_to_logits(mtf_embeddings,
                                                      context=None)

        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[''])
        lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl})
        actual_logits = lowering.export_to_tf_tensor(mtf_logits)

        self.evaluate(tf.global_variables_initializer())
        self.evaluate(lowering.copy_masters_to_slices())
        actual = self.evaluate([actual_logits])[0]

        self.assertAllClose(
            actual,
            model_size**-0.5 * np.array([[0, 2, 1, 2, 3], [1, 0, 2, 4, 6],
                                         [1, 2, 3, 6, 9], [1, 4, 4, 8, 12]]))
Example #23
def test_entmax():
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    length = mtf.Dimension("tensor_length", 8)
    tensor = mtf.range(mesh, length, tf.float32)
    output = entmax(tensor)
    grad = mtf.gradients([output], [tensor])[0]
    sample = sample_categorical(output, length)

    mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    sample = lowering.export_to_tf_tensor(sample)
    grad = lowering.export_to_tf_tensor(grad)
Example #24
File: demo.py Project: EiffL/mesh_demo
def main(_):
    num_tasks = int(os.environ['SLURM_NTASKS'])
    print('num_tasks : ', num_tasks)

    # Resolve the cluster from SLURM environment
    cluster = tf.distribute.cluster_resolver.SlurmClusterResolver(
        {"mesh": num_tasks},
        port_base=8822,
        gpus_per_node=FLAGS.gpus_per_node,
        gpus_per_task=FLAGS.gpus_per_task,
        tasks_per_node=FLAGS.tasks_per_node)
    cluster_spec = cluster.cluster_spec()
    print(cluster_spec)

    # Create a server for all mesh members
    server = tf.distribute.Server(cluster_spec, "mesh", cluster.task_id)
    print(server)

    if cluster.task_id > 0:
        server.join()

    # Otherwise we are the main task, let's define the devices
    devices = [
        "/job:mesh/task:%d/device:GPU:%d" % (i, j)
        for i in range(cluster_spec.num_tasks("mesh"))
        for j in range(FLAGS.gpus_per_task)
    ]
    print("List of devices", devices)

    # Defines the mesh structure
    mesh_shape = [("row", 4), ("col", 2)]
    layout_rules = [("nx_block","row"), ("ny_block","col")]

    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(mesh_shape, layout_rules, devices)

    # Create computational graphs
    net = model_fn(nc=FLAGS.nc, batch_size=FLAGS.batch_size)

    # Lower mesh computation
    graph = net.graph
    mesh = net.mesh
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    # Retrieve output of computation
    result = lowering.export_to_tf_tensor(net)

    # Perform some last processing in normal tensorflow
    out = tf.reduce_mean(result)

    with tf.Session(server.target) as sess:
        r = sess.run(out)

    print("output of computation", r)

    exit(0)
Example #25
  def testTopK(self):
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    a_dim = mtf.Dimension("a", 6)
    b_dim = mtf.Dimension("b", 2)
    inputs = tf.constant([[1, 10],
                          [2, 9],
                          [3, 8],
                          [4, 7],
                          [5, 6],
                          [6, 5]],
                         dtype=tf.float32)
    k_dim = mtf.Dimension("k", 2)
    d_values = tf.constant([[11, 12], [13, 14]], dtype=tf.float32)
    reduced_dim = a_dim
    expected_values = tf.constant([[6, 5], [10, 9]], dtype=tf.float32)
    expected_indices = tf.constant([[5, 4], [0, 1]])
    expected_d_inputs = tf.constant([[0, 13],
                                     [0, 14],
                                     [0, 0],
                                     [0, 0],
                                     [12, 0],
                                     [11, 0]],
                                    dtype=tf.float32)

    mtf_inputs = mtf.import_fully_replicated(
        mesh, inputs, shape=mtf.Shape([a_dim, b_dim]))
    mtf_d_values = mtf.import_tf_tensor(
        mesh, d_values, shape=mtf.Shape([b_dim, k_dim]))
    mtf_values, mtf_indices = mtf.top_k(mtf_inputs,
                                        reduced_dim=reduced_dim,
                                        k_dim=k_dim,
                                        name="test_nth_smallest")
    [mtf_d_inputs] = mtf.gradients([mtf_values], [mtf_inputs], [mtf_d_values])
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape="rows:2,cols:2", layout="a:rows,b:cols", devices=["", "", "", ""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_values = lowering.export_to_tf_tensor(mtf_values)
    actual_indices = lowering.export_to_tf_tensor(mtf_indices)
    actual_d_inputs = lowering.export_to_tf_tensor(mtf_d_inputs)
    actual_inputs = lowering.export_to_tf_tensor(mtf_inputs)
    self.assertAllEqual(self.evaluate(actual_inputs),
                        self.evaluate(inputs))
    self.assertAllEqual(self.evaluate(actual_values),
                        self.evaluate(expected_values))
    self.assertAllEqual(self.evaluate(actual_indices),
                        self.evaluate(expected_indices))
    self.assertAllEqual(self.evaluate(actual_d_inputs),
                        self.evaluate(expected_d_inputs))
Example #26
    def test_ids_to_embedding_correctlyEmbeds(self):
        seq_len = 6
        vocab_size = 5
        model_size = 2

        vocab_dim = mtf.Dimension('vocab', vocab_size)
        model_dim = mtf.Dimension('model', model_size)
        length_dim = mtf.Dimension('length', seq_len)

        ids = tf.constant([0, 1, 2, 3, 4, 0], dtype=tf.int32)
        mtf_ids = mtf.import_tf_tensor(self.mesh,
                                       ids,
                                       shape=mtf.Shape([length_dim]))

        self.initializer_mock.side_effect = initialize_by_shape({
            (3, 2): [[0, 1], [2, 0], [-1000, -4000]],
            (3, 1): [[1], [2], [3]],
            (1, 2): [[1], [2]],
        })

        vocab_embedding = adaptive_softmax.AdaptiveSoftmaxVocabEmbedding(
            self.mesh,
            vocab_dim,
            output_dim=model_dim,
            variable_dtype=self.variable_dtype,
            name='embedding',
            ensemble_dim=None,
            clusters=[{
                'token_count': 2,
                'embedding_size': 2
            }, {
                'token_count': 3,
                'embedding_size': 1
            }])

        mtf_embedding = vocab_embedding.ids_to_embedding(mtf_ids, context=None)

        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[''])
        lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl})
        actual_embedding = lowering.export_to_tf_tensor(mtf_embedding)

        self.evaluate(tf.global_variables_initializer())
        self.evaluate(lowering.copy_masters_to_slices())
        actual, = self.evaluate([actual_embedding])

        self.assertAllClose(actual,
                            [[0, 1], [2, 0], [1, 2], [2, 4], [3, 6], [0, 1]])
Example #27
    def test_ids_to_embedding_correctlyEmbeds(self):
        seq_len = 4
        vocab_size = 4
        model_size = 3
        num_softmaxes = 1

        vocab_dim = mtf.Dimension('vocab', vocab_size)
        model_dim = mtf.Dimension('model', model_size)
        length_dim = mtf.Dimension('length', seq_len)

        ids = tf.constant([0, 1, 2, 3], dtype=tf.int32)
        mtf_ids = mtf.import_tf_tensor(self.mesh,
                                       ids,
                                       shape=mtf.Shape([length_dim]))

        self.initializer_mock.side_effect = initialize_by_shape({
            # Embedding weights.
            (4, 3): [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 2]],
            # Mixture weights.
            (1, 3): [[1, 0, 0]],
            # Context weights
            (1, 3, 3): [
                [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
            ],
        })

        vocab_embedding = vocab_embeddings.MixtureOfSoftmaxes(
            self.mesh,
            vocab_dim,
            output_dim=model_dim,
            variable_dtype=self.variable_dtype,
            name='embedding',
            ensemble_dim=None,
            num_softmaxes=num_softmaxes)

        mtf_embedding = vocab_embedding.ids_to_embedding(mtf_ids, context=None)

        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[''])
        lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl})
        actual_embedding = lowering.export_to_tf_tensor(mtf_embedding)

        self.evaluate(tf.global_variables_initializer())
        self.evaluate(lowering.copy_masters_to_slices())
        actual = self.evaluate([actual_embedding])[0]

        self.assertAllClose(actual,
                            [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 2]])
Example #28
def main(_):

    #layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
    mesh_shape = [("row", FLAGS.nx), ("col", FLAGS.ny)]
    layout_rules = [("nx_lr", "row"), ("ny_lr", "col"), ("nx", "row"),
                    ("ny", "col"), ("ty", "row"), ("tz", "col"),
                    ("ty_lr", "row"), ("tz_lr", "col"), ("nx_block", "row"),
                    ("ny_block", "col")]

    mesh_impl = HvdSimdMeshImpl(mtf.convert_to_shape(mesh_shape),
                                mtf.convert_to_layout_rules(layout_rules))
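    # HvdSimdMeshImpl is a Horovod-backed SIMD mesh implementation that spans multiple processes.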

    # Build the model
    # Create computational graphs and some initializations
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "nbody_mesh")

    initial_conditions, mesh_final_field = lpt_prototype(
        mesh, bs=FLAGS.box_size, nc=FLAGS.nc, batch_size=FLAGS.batch_size)

    # Lower mesh computation
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    # Retrieve output of computation
    initc = lowering.export_to_tf_tensor(initial_conditions)
    result = lowering.export_to_tf_tensor(mesh_final_field)

    with tf.Session() as sess:
        start = time.time()
        a, c = sess.run([initc, result])
        end = time.time()
        ttime = (end - start)
        print('Time for ', mesh_shape, ' is : ', ttime)

    if comm.rank == 0:
        plt.figure(figsize=(9, 3))
        plt.subplot(121)
        plt.imshow(a[0].sum(axis=2))
        plt.title('Initial Conditions')

        plt.subplot(122)
        plt.imshow(c[0].sum(axis=2))
        plt.title('Mesh TensorFlow')
        plt.colorbar()
        plt.savefig("mesh_nbody_%d-row:%d-col:%d.png" %
                    (FLAGS.nc, FLAGS.nx, FLAGS.ny))
        plt.close()

    exit(0)
Example #29
    def testBatchNorm(self):
        batch = 2
        channels = 3
        inputs = tf.constant([[0, 1, 2], [4, 5, 6]], dtype=tf.float32)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", batch)
        channels_dim = mtf.Dimension("channels", channels)

        mtf_inputs = mtf.import_tf_tensor(mesh,
                                          inputs,
                                          shape=mtf.Shape(
                                              [batch_dim, channels_dim]))

        mtf_outputs_0, _ = mtf.layers.batch_norm(mtf_inputs,
                                                 is_training=True,
                                                 momentum=0.95,
                                                 epsilon=1e-6,
                                                 dims_idx_start=0,
                                                 dims_idx_end=1,
                                                 name="bn0")
        mtf_outputs_1, _ = mtf.layers.batch_norm(mtf_outputs_0 * 2 + 1,
                                                 is_training=True,
                                                 momentum=0.95,
                                                 epsilon=1e-6,
                                                 dims_idx_start=0,
                                                 dims_idx_end=1,
                                                 name="bn1")

        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        actual_outputs_0 = lowering.export_to_tf_tensor(mtf_outputs_0)
        actual_outputs_1 = lowering.export_to_tf_tensor(mtf_outputs_1)

        tf_group = lowering.copy_masters_to_slices()
        init = tf.global_variables_initializer()
        self.evaluate(init)
        self.evaluate(tf_group)
        [actual_0,
         actual_1] = self.evaluate([actual_outputs_0, actual_outputs_1])

        expected = np.array([[-1, -1, -1], [1, 1, 1]])
        self.assertAllClose(actual_0, expected)
        self.assertAllClose(actual_1, expected)
Example #30
    def testDotProductAttention(self, batch, heads, length_q, length_kv,
                                depth_k, depth_v):
        query = tf.random_normal([batch, heads, length_q, depth_k])
        key = tf.random_normal([batch, heads, length_kv, depth_k])
        value = tf.random_normal([batch, heads, length_kv, depth_v])

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", batch)
        heads_dim = mtf.Dimension("heads", heads)
        length_q_dim = mtf.Dimension("length_q", length_q)
        length_kv_dim = mtf.Dimension("length_kv", length_kv)
        depth_k_dim = mtf.Dimension("depth_k", depth_k)
        depth_v_dim = mtf.Dimension("depth_v", depth_v)

        mtf_query = mtf.import_tf_tensor(
            mesh,
            query,
            shape=mtf.Shape([batch_dim, heads_dim, length_q_dim, depth_k_dim]))
        mtf_key = mtf.import_tf_tensor(
            mesh,
            key,
            shape=mtf.Shape([batch_dim, heads_dim, length_kv_dim,
                             depth_k_dim]))
        mtf_value = mtf.import_tf_tensor(
            mesh,
            value,
            shape=mtf.Shape([batch_dim, heads_dim, length_kv_dim,
                             depth_v_dim]))
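        # Plain dot-product attention over the length_kv dimension; mask=None, not training.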
        mtf_outputs = mtf.layers.dot_product_attention(mtf_query,
                                                       mtf_key,
                                                       mtf_value,
                                                       mask=None,
                                                       is_training=False)
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

        tf_group = lowering.copy_masters_to_slices()
        init = tf.global_variables_initializer()
        self.evaluate(init)
        self.evaluate(tf_group)
        actual = self.evaluate(actual_outputs)

        self.assertEqual(actual.shape, (batch, heads, length_q, depth_v))