Example #1
  def create_graph_mesh_and_mesh_impl(self):
    """Creates mtf graph, mesh, and mesh impl.

    This function can be called inside model_fn, which might be tpu_rewritten.

    Returns:
      graph, mesh, mesh_impl
    """

    if self._use_tpu:
      assert self._d_assignment
      graph = mtf.Graph()

      # Worker 0 caches all the TPU binaries.
      replica_cache_size = 300 * 1024 * 1024  # 300M per replica.
      worker0_mem = replica_cache_size * 8 * self._num_hosts
      devices_memory_usage = [worker0_mem] + [0] * (self._num_hosts - 1)
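      # Seed worker 0 with the cache usage above so the balanced placer
      # biases new variables toward the other hosts.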
      var_placer = mtf.utils.BalancedVariablePlacer(self._cpu_devices,
                                                    devices_memory_usage)
      mesh = mtf.Mesh(graph, 'my_mesh', var_placer)
      mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
          self._mesh_shape, self._layout_rules, None, self._d_assignment)
      return graph, mesh, mesh_impl

    else:
      graph = mtf.Graph()
      mesh = mtf.Mesh(graph, 'my_mesh', None)
      mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
          self._mesh_shape, self._layout_rules, self._gpu_devices)
      return graph, mesh, mesh_impl
Example #2
def Replication2(in_tsr):
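    """Duplicate a tensor from a 1x4 mesh onto a full 2x4 mesh."""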
    graph = mtf.Graph()
    mesh0 = mtf.Mesh(graph, 'mesh0')
    mesh1 = mtf.Mesh(graph, 'mesh1')
    mesh_to_impl = {mesh0: GetMeshImpl([1, 4]),
                    mesh1: GetMeshImpl([2, 4])}

    shape = in_tsr.get_shape().as_list()
    mtf_shape = GetShape(shape)
    mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
    mtf_out_tsr = mt.ReplaceMeshWithDuplicates(mtf_in_tsr, mesh1)
    Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)
Example #3
def Split3(in_tsr):
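    """Move a tensor sharded along 'axis1' from a 2x2 mesh (devices 0, 2, 4, 6)
    to a 2x4 mesh by concatenating and re-splitting slices."""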
    graph = mtf.Graph()
    mesh0 = mtf.Mesh(graph, 'mesh0')
    mesh1 = mtf.Mesh(graph, 'mesh1')
    mesh_to_impl = {mesh0: GetMeshImpl([2, 2], [0, 2, 4, 6]),
                    mesh1: GetMeshImpl([2, 4])}

    shape = in_tsr.get_shape().as_list()
    mtf_shape = GetShape(shape[:2] + [('axis1', shape[2])] + shape[3:])
    mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
    mtf_out_tsr = mt.ReplaceMeshWithConcatSplit(mtf_in_tsr, mesh1,
                                                mtf_shape.dimension_names)
    Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)
Example #4
def NoConcatSplit(in_tsr):
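    """Transfer between two identical 4x2 meshes; the layouts already match,
    so no concat/split should be required."""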
    graph = mtf.Graph()
    mesh0 = mtf.Mesh(graph, 'mesh0')
    mesh1 = mtf.Mesh(graph, 'mesh1')
    mesh_to_impl = {mesh0: GetMeshImpl([4, 2]),
                    mesh1: GetMeshImpl([4, 2])}

    shape = in_tsr.get_shape().as_list()
    mtf_shape = GetShape([shape[0], ('axis0', shape[1])] + shape[2:])
    mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
    mtf_out_tsr = mt.ReplaceMeshWithConcatSplit(mtf_in_tsr, mesh1,
                                                mtf_shape.dimension_names)
    Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)
Example #5
def Replication5(in_tsr):
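    """Replicate a tensor sharded along 'axis0' from a 2x1 mesh (devices 0
    and 4) onto a full 2x4 mesh."""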
    graph = mtf.Graph()
    mesh0 = mtf.Mesh(graph, 'mesh0')
    mesh1 = mtf.Mesh(graph, 'mesh1')
    mesh_to_impl = {mesh0: GetMeshImpl([2, 1], [0, 4]),
                    mesh1: GetMeshImpl([2, 4])}

    shape = in_tsr.get_shape().as_list()
    mtf_shape = GetShape([('axis0', shape[0])] + shape[1:])
    mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
    mtf_out_tsr = mt.ReplaceMeshWithDuplicates(mtf_in_tsr, mesh1,
                                               mtf_shape.dimension_names)
    Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)
Example #6
def Transpose1(in_tsr):
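    """Move to an identical 4x2 mesh while transposing the sharded axes from
    the first two dimensions to the last two."""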
    graph = mtf.Graph()
    mesh0 = mtf.Mesh(graph, 'mesh0')
    mesh1 = mtf.Mesh(graph, 'mesh1')
    mesh_to_impl = {mesh0: GetMeshImpl([4, 2]), mesh1: GetMeshImpl([4, 2])}

    shape = in_tsr.get_shape().as_list()
    mtf_shape = GetShape([('axis0', shape[0]), ('axis1', shape[1]),
                          *shape[2:]])
    mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
    mtf_out_tsr = mt.ReplaceMeshWithIndependentAxes(
        mtf_in_tsr, mesh1,
        [RandName(), RandName(), 'axis0', 'axis1'])
    Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)
Example #7
def Contract2(in_tsr):
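    """Move from a 2x4 mesh to a 4x2 mesh, re-sharding the two middle
    dimensions on the new mesh."""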
    graph = mtf.Graph()
    mesh0 = mtf.Mesh(graph, 'mesh0')
    mesh1 = mtf.Mesh(graph, 'mesh1')
    mesh_to_impl = {mesh0: GetMeshImpl([2, 4]),
                    mesh1: GetMeshImpl([4, 2])}

    shape = in_tsr.get_shape().as_list()
    mtf_shape = GetShape(shape)
    mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
    mtf_out_tsr = mt.ReplaceMeshWithIndependentAxes(
        mtf_in_tsr, mesh1, [RandName(), 'axis0', 'axis1',
                            RandName()])
    Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)
Example #8
def MoreDevices(in_tsr):
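    """Move a tensor sharded along 'axis0' from a 2-device mesh to a larger
    8-device mesh."""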
    graph = mtf.Graph()
    mesh0 = mtf.Mesh(graph, 'mesh0')
    mesh1 = mtf.Mesh(graph, 'mesh1')
    mesh_to_impl = {mesh0: GetMeshImpl([2]),
                    mesh1: GetMeshImpl([8])}

    shape = in_tsr.get_shape().as_list()
    mtf_shape = GetShape(shape[:-1] + [('axis0', shape[-1])])
    mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
    mtf_out_tsr = mt.ReplaceMeshWithIndependentAxes(
        mtf_in_tsr, mesh1,
        [RandName(), 'axis0', RandName(),
         RandName()])
    Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)
Example #9
 def computation_fn():
     graph = mtf.Graph()
     mesh = mtf.Mesh(graph, 'my_mesh')
     mesh_shape = mtf.convert_to_shape('all:2')
     layout = 'none:all'
     mesh_devices = [''] * mesh_shape.size
     mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
         mesh_shape, mtf.convert_to_layout_rules(layout),
         mesh_devices, device_assignment)
     hidden_dim = mtf.Dimension('hidden', 3)
     w = mtf.get_variable(mesh,
                          'w',
                          shape=[hidden_dim],
                          initializer=tf.constant_initializer(
                              [0.1, -0.2, -0.1]))
     x = mtf.constant(mesh, [0.4, 0.2, -0.5], [hidden_dim],
                      dtype=tf.float32)
     loss = mtf.reduce_mean(mtf.square(x - w))
     var_grads = mtf.gradients(
         [loss], [v.outputs[0] for v in graph.trainable_variables])
     optimizer = mtf_optimize.AdamWeightDecayOptimizer(
         learning_rate=0.2)
     update_ops = optimizer.apply_grads(var_grads,
                                        graph.trainable_variables)
     self.lowering = mtf.Lowering(graph, {mesh: mesh_impl})
     tf_update_ops = [
         self.lowering.lowered_operation(op) for op in update_ops
     ]
     return tf.group(tf_update_ops)
Example #10
    def testLayout(self):
        # Construct a Mesh TensorFlow graph and mesh.
        mtf_graph = mtf.Graph()
        mesh = mtf.Mesh(mtf_graph, "my_mesh")
        x = mtf.zeros(mesh, "a:10,b:5")
        y = mtf.zeros(mesh, "b:5,c:20")
        z = mtf.einsum([x, y], "a:10,c:20")

        # Decide on a mesh shape.
        mesh_shape = mtf.convert_to_shape("m1:4,m2:2")

        # Compute a layout based on the graph and mesh.
        # Note that knowing the identity of the outputs is important to the
        # optimization since they cannot be freed.
        layout = mtf.auto_mtf.layout(mtf_graph, mesh_shape, [z])

        a_dim = mtf.convert_to_dimension(("a", 10))
        b_dim = mtf.convert_to_dimension(("b", 5))
        c_dim = mtf.convert_to_dimension(("c", 20))

        self.assertEqual(
            layout.tensor_dimension_to_mesh_axis(a_dim, mesh_shape), 1)
        self.assertIsNone(
            layout.tensor_dimension_to_mesh_axis(b_dim, mesh_shape))
        self.assertEqual(
            layout.tensor_dimension_to_mesh_axis(c_dim, mesh_shape), 0)
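
A short, hedged follow-up sketch (not part of the original test): the layout
computed by mtf.auto_mtf.layout above can be handed directly to a placement
mesh impl and lowered. The graph, mesh, mesh_shape, layout, and z are reused
from the test, and the empty device strings are placeholders.

        # Lower the same graph with the auto-computed layout on 8 placeholder devices.
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout, devices=[""] * mesh_shape.size)
        lowering = mtf.Lowering(mtf_graph, {mesh: mesh_impl})
        tf_z = lowering.export_to_tf_tensor(z)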
Example #11
    def testMinimizePeakMemoryList_SingleUseTensor(self):
        mtf_graph = mtf.Graph()
        mesh = mtf.Mesh(mtf_graph, 'my_mesh')
        mtf.Constant(mesh,
                     0,
                     shape=mtf.convert_to_shape('a:4'),
                     dtype=tf.int32,
                     name='X')
        y = mtf.Constant(mesh,
                         0,
                         shape=mtf.convert_to_shape('b:3'),
                         dtype=tf.int32,
                         name='Y').outputs[0]
        mtf.BroadcastOperation(y, mtf.convert_to_shape('b:3,c:2'), name='Z')

        graph = graph_interface.GraphInterface(mtf_graph)
        graph.set_tensor_final('X:0')
        graph.set_tensor_final('Z:0')
        schedule = list(scheduler.minimize_peak_memory(graph, 'LIST'))
        # When nothing is scheduled:
        #   X frees -4 entries
        #   Y frees -3 entries
        # After [Y] scheduled:
        #   X frees -4 entries
        #   Z frees -3 entries
        # Hence the schedule should be [Y, Z, X].
        self.assertEqual(schedule, [1, 2, 0])
Example #12
            def computation_fn():
                graph = mtf.Graph()
                mesh = mtf.Mesh(graph, 'my_mesh')
                mesh_shape = mtf.convert_to_shape('all:2')
                layout = 'num_heads:all'
                mesh_devices = [''] * mesh_shape.size
                mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                    mesh_shape, mtf.convert_to_layout_rules(layout),
                    mesh_devices, device_assignment)
                batch_dim = mtf.Dimension('batch', batch_size)
                seq_dim = mtf.Dimension('seq', seq_length)

                input_ids = tf.random.uniform((batch_size, seq_length),
                                              minval=0,
                                              maxval=vocab_size,
                                              dtype=tf.int32)
                mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids,
                                                     [batch_dim, seq_dim])

                model = bert_lib.BertModel(config=bert_config,
                                           is_training=True,
                                           input_ids=mtf_input_ids,
                                           input_mask=None,
                                           token_type_ids=None)
                pooled = model.get_pooled_output()
                lowering = mtf.Lowering(graph, {mesh: mesh_impl})
                return lowering.export_to_tf_tensor(pooled)
Example #13
  def testNthSmallestReduceSecondDim(self):
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    a_dim = mtf.Dimension("a", 6)
    b_dim = mtf.Dimension("b", 2)
    inputs = tf.constant([[1, 10],
                          [2, 9],
                          [3, 8],
                          [4, 7],
                          [5, 6],
                          [6, 5]])
    n = 0  # find smallest element (n is zero-indexed)
    reduced_dim = b_dim
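    # Reducing over b keeps, for each row, the nth-smallest (here the smallest) element.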
    expected_outputs = tf.constant([1, 2, 3, 4, 5, 5])

    mtf_inputs = mtf.import_tf_tensor(
        mesh, inputs, shape=mtf.Shape([a_dim, b_dim]))
    mtf_outputs = mtf.nth_smallest_element(
        mtf_inputs, n, reduced_dim, "test_nth_smallest")
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape="all:2", layout="a:all", devices=["", ""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
    self.assertAllEqual(self.evaluate(actual_outputs),
                        self.evaluate(expected_outputs))
Example #14
            def computation_fn():
                graph = mtf.Graph()
                mesh = mtf.Mesh(graph, 'my_mesh')
                mesh_shape = mtf.convert_to_shape('all:2')
                layout = 'none:all'
                mesh_devices = [''] * mesh_shape.size
                mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                    mesh_shape, mtf.convert_to_layout_rules(layout),
                    mesh_devices, device_assignment)
                hidden_dim = mtf.Dimension('hidden', 3)
                w = mtf.get_variable(mesh,
                                     'w',
                                     shape=[hidden_dim],
                                     initializer=tf.constant_initializer(
                                         [0.1, -0.2, -0.1]))
                x = mtf.constant(mesh, [0.4, 0.2, -0.5], [hidden_dim],
                                 dtype=tf.float32)
                loss = mtf.reduce_mean(mtf.square(x - w))

                lr, update_ops = optimization_lib.create_optimizer(
                    loss, 0.2, 100, 10)
                self.lowering = mtf.Lowering(graph, {mesh: mesh_impl})

                tf_update_ops = [
                    self.lowering.lowered_operation(op) for op in update_ops
                ]
                tf_update_ops.append(
                    tf.assign_add(tf.train.get_or_create_global_step(), 1))
                train_op = tf.group(tf_update_ops)

                return lr, train_op
Example #15
  def testLayoutAndMeshShape(self):
    # Same as previous test, but don't specify a 4x2 mesh.
    mtf_graph = mtf.Graph()
    mesh = mtf.Mesh(mtf_graph, "my_mesh")
    x = mtf.zeros(mesh, "a:10,b:5")
    y = mtf.zeros(mesh, "b:5,c:20")
    z = mtf.einsum([x, y], "a:10,c:20")

    layout, mesh_shape = mtf.auto_mtf.layout_and_mesh_shape(mtf_graph, 8, [z])

    a_dim = mtf.convert_to_dimension(("a", 10))
    b_dim = mtf.convert_to_dimension(("b", 5))
    c_dim = mtf.convert_to_dimension(("c", 20))

    self.assertEqual(layout.tensor_dimension_to_mesh_axis(a_dim, mesh_shape), 1)
    self.assertIsNone(layout.tensor_dimension_to_mesh_axis(b_dim, mesh_shape))
    self.assertEqual(layout.tensor_dimension_to_mesh_axis(c_dim, mesh_shape), 0)

    self.assertCountEqual(mesh_shape.dims,
                          [mtf.Dimension("mesh_0", 4),
                           mtf.Dimension("mesh_1", 2)])

    layout, mesh_shape = mtf.auto_mtf.layout_and_mesh_shape(
        mtf_graph, 8, [z], 1)

    self.assertIsNone(layout.tensor_dimension_to_mesh_axis(a_dim, mesh_shape))
    self.assertIsNone(layout.tensor_dimension_to_mesh_axis(b_dim, mesh_shape))
    self.assertIsNone(layout.tensor_dimension_to_mesh_axis(c_dim, mesh_shape))

    self.assertCountEqual(mesh_shape.dims, [mtf.Dimension("mesh_0", 8)])
Example #16
    def testDenseReluDense(self):
        batch = 2
        channels = 3
        hidden = 5
        inputs = tf.random_normal([batch, channels])

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", batch)
        channels_dim = mtf.Dimension("channels", channels)
        hidden_dim = mtf.Dimension("hidden", hidden)

        mtf_inputs = mtf.import_tf_tensor(mesh,
                                          inputs,
                                          shape=mtf.Shape(
                                              [batch_dim, channels_dim]))
        mtf_outputs = mtf.layers.dense_relu_dense(mtf_inputs,
                                                  hidden_channels=hidden_dim,
                                                  is_training=False)
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

        tf_group = lowering.copy_masters_to_slices()
        init = tf.global_variables_initializer()
        self.evaluate(init)
        self.evaluate(tf_group)
        actual = self.evaluate(actual_outputs)

        self.assertEqual(actual.shape, inputs.shape)
Example #17
    def testCorr2DInput(self):
        batch = 4
        channels = 3
        inputs = tf.random_normal([batch, channels])

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", batch)
        channels_dim = mtf.Dimension("channels", channels)

        mtf_inputs = mtf.import_tf_tensor(mesh,
                                          inputs,
                                          shape=mtf.Shape(
                                              [batch_dim, channels_dim]))
        mtf_outputs = mtf.layers.corr(mtf_inputs, dim=channels_dim)
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

        expected_outputs = tfp.stats.correlation(inputs,
                                                 sample_axis=0,
                                                 event_axis=1)
        tf_group = lowering.copy_masters_to_slices()
        init = tf.global_variables_initializer()
        self.evaluate(init)
        self.evaluate(tf_group)
        actual, expected = self.evaluate([actual_outputs, expected_outputs])

        self.assertEqual(actual.shape, expected.shape)
        self.assertAllClose(actual, expected)
Example #18
 def testGraph(self):
     graph = mtf.Graph()
     self.assertLen(graph.operations, 0)
     self.assertLen(graph.tensors, 0)
     self.assertLen(graph.trainable_variables, 0)
     self.assertLen(graph.all_variables, 0)
     mesh = mtf.Mesh(graph, "mesh_test")
     _ = mtf.import_tf_tensor(mesh,
                              tf_tensor=tf.constant(0.),
                              shape=mtf.Shape([]))
     self.assertLen(graph.operations, 1)
     self.assertLen(graph.tensors, 1)
     self.assertLen(graph.trainable_variables, 0)
     self.assertLen(graph.all_variables, 0)
     _ = mtf.get_variable(mesh, "variable_0", mtf.Shape([]), trainable=True)
     self.assertLen(graph.operations, 2)
     self.assertLen(graph.tensors, 2)
     self.assertLen(graph.trainable_variables, 1)
     self.assertLen(graph.all_variables, 1)
     _ = mtf.get_variable(mesh,
                          "variable_1",
                          mtf.Shape([]),
                          trainable=False)
     self.assertLen(graph.operations, 3)
     self.assertLen(graph.tensors, 3)
     self.assertLen(graph.trainable_variables, 1)
     self.assertLen(graph.all_variables, 2)
Example #19
    def testWeightsNonzero(self):
        inputs = tf.constant([[3, 1, 0], [1, 0, 0]])

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", inputs.shape.as_list()[0])
        channels_dim = mtf.Dimension("channels", inputs.shape.as_list()[1])

        mtf_inputs = mtf.import_tf_tensor(mesh,
                                          inputs,
                                          shape=mtf.Shape(
                                              [batch_dim, channels_dim]))
        mtf_outputs = mtf.layers.weights_nonzero(mtf_inputs)
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

        expected_outputs = tf.cast(tf.not_equal(inputs, 0), tf.float32)
        tf_group = lowering.copy_masters_to_slices()
        self.evaluate(tf_group)
        actual, expected = self.evaluate([actual_outputs, expected_outputs])

        self.assertAllEqual(actual, expected)
Example #20
def model_fn(features, labels, mode, params):
  """Model function called by TPUEstimator."""
  del labels
  del features

  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

  ctx = params['context']
  num_hosts = ctx.num_hosts
  host_placement_fn = ctx.tpu_host_placement_function
  device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
  tf.logging.info('device_list = %s' % device_list)

  mesh_devices = [''] * mesh_shape.size
  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules,
                                              mesh_devices,
                                              ctx.device_assignment)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "fft_mesh")

  with mtf.utils.outside_all_rewrites():
    fsum = benchmark_model(mesh)
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  tf_err = tf.to_float(lowering.export_to_tf_tensor(fsum))

  with mtf.utils.outside_all_rewrites():
    return tpu_estimator.TPUEstimatorSpec(mode, loss=tf_err)
Example #21
    def testMaskedLocalAttention1D(self, batch, length, io_channels,
                                   kv_channels, heads, window_size):
        length_q = length
        query = tf.random_normal([batch, length_q, io_channels])

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", batch)
        length_q_dim = mtf.Dimension("length_q", length_q)
        io_channels_dim = mtf.Dimension("io_channels", io_channels)
        kv_channels_dim = mtf.Dimension("kv_channels", kv_channels)
        heads_dim = mtf.Dimension("heads", heads)

        mtf_query = mtf.import_tf_tensor(
            mesh,
            query,
            shape=mtf.Shape([batch_dim, length_q_dim, io_channels_dim]))
        mtf_outputs = mtf.layers.masked_local_attention_1d(
            mtf_query,
            kv_channels=kv_channels_dim,
            heads=heads_dim,
            window_size=window_size)
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

        tf_group = lowering.copy_masters_to_slices()
        init = tf.global_variables_initializer()
        self.evaluate(init)
        self.evaluate(tf_group)
        actual = self.evaluate(actual_outputs)

        self.assertEqual(actual.shape, (batch, length_q, io_channels))
Example #22
    def testDynamicText2self_unpacked(self):
        batch = 2
        length = 5
        input_tensors = {
            "inputs": [[3, 1, 4, 1, 0], [1, 4, 3, 2, 1]],
            "targets": [[1, 1, 0, 0, 0], [9, 8, 1, 2, 1]],
        }
        expected_output_tensors = {
            "targets": [[3, 1, 4, 1, 1, 1, 0, 0, 0, 0],
                        [1, 4, 3, 2, 1, 9, 8, 1, 2, 1]],
        }
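        # _dynamic_text2self concatenates each example's non-padding "inputs"
        # and "targets" into a single "targets" sequence of length 2 * length.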
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", batch)
        length_dim = mtf.Dimension("length", length)

        input_shape = mtf.Shape([batch_dim, length_dim])
        mtf_features = {
            k: mtf.import_tf_tensor(mesh, v, input_shape)
            for k, v in input_tensors.items()
        }
        mtf_outputs = utils._dynamic_text2self(mtf_features)
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        for k, v in expected_output_tensors.items():
            out = lowering.export_to_tf_tensor(mtf_outputs[k])
            actual = self.evaluate(out)
            self.assertAllEqual(actual, v)
Example #23
  def testMinimizePeakMemoryList(self):
    mtf_graph = mtf.Graph()
    mesh = mtf.Mesh(mtf_graph, 'my_mesh')
    x = mtf.Constant(mesh, 0,
                     shape=mtf.convert_to_shape('a:3,b:4'),
                     dtype=tf.int32,
                     name='X').outputs[0]
    y = mtf.Constant(mesh, 0,
                     shape=mtf.convert_to_shape('b:4,c:5'),
                     dtype=tf.int32,
                     name='Y').outputs[0]
    mtf.EinsumOperation([x, y], mtf.convert_to_shape('a:3,b:4,c:5'), name='Z')
    w = mtf.EinsumOperation([x, y], mtf.convert_to_shape('a:3,c:5'),
                            name='W').outputs[0]
    mtf.BroadcastOperation(w, mtf.convert_to_shape('a:3,b:4,c:5'), name='V')

    graph = graph_interface.GraphInterface(mtf_graph)
    graph.set_tensor_final('Z:0')
    graph.set_tensor_final('V:0')
    schedule = list(scheduler.minimize_peak_memory(graph, 'LIST'))

    # List Scheduler prefers to schedule things that free the most memory.
    # When nothing is scheduled:
    #   X frees -12 entries.
    #   Y frees -20 entries.
    # After [X] scheduled:
    #   Y frees -20 entries.
    # After [X, Y] scheduled:
    #   Z frees -60 entries.
    #   W frees -15 entries.
    # After [X, Y, W] scheduled:
    #   Z frees -28 entries.
    #   V frees -45 entries.
    # Hence the schedule should be [X, Y, W, Z, V].
    self.assertEqual(schedule, [0, 1, 3, 2, 4])
Example #24
  def testDense(self, units, use_bias):
    batch = 2
    channels = 3
    inputs = tf.random_normal([batch, channels])

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    batch_dim = mtf.Dimension("batch", batch)
    channels_dim = mtf.Dimension("channels", channels)
    depth_dim = mtf.Dimension("depth", units)

    mtf_inputs = mtf.import_tf_tensor(
        mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
    mtf_outputs = mtf.layers.dense(mtf_inputs,
                                   output_dim=depth_dim,
                                   reduced_dims=[channels_dim],
                                   activation=mtf.relu,
                                   use_bias=use_bias)
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

    expected_outputs = tf.keras.layers.Dense(units=units,
                                             activation=tf.nn.relu,
                                             use_bias=use_bias)(inputs)
    tf_group = lowering.copy_masters_to_slices()
    init = tf.global_variables_initializer()
    self.evaluate(init)
    self.evaluate(tf_group)
    actual, expected = self.evaluate([actual_outputs, expected_outputs])

    self.assertEqual(actual.shape, expected.shape)
Example #25
  def testMultiheadAttention(self, kv_channels, heads):
    batch = 2
    length = 8
    channels = 3
    query = tf.random_normal([batch, length, channels])

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    batch_dim = mtf.Dimension("batch", batch)
    length_dim = mtf.Dimension("length", length)
    channels_dim = mtf.Dimension("channels", channels)
    kv_channels_dim = mtf.Dimension("kv_channels", kv_channels)
    heads_dim = mtf.Dimension("heads", heads)

    mtf_query = mtf.import_tf_tensor(
        mesh, query,
        shape=mtf.Shape([batch_dim, length_dim, channels_dim]))
    mtf_outputs = mtf.layers.multihead_attention(
        mtf_query,
        memory_antecedent=None,
        mask=None,
        kv_channels=kv_channels_dim,
        heads=heads_dim)
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

    tf_group = lowering.copy_masters_to_slices()
    init = tf.global_variables_initializer()
    self.evaluate(init)
    self.evaluate(tf_group)
    actual = self.evaluate(actual_outputs)

    self.assertEqual(actual.shape, query.shape)
Example #26
    def testRecomputeGrad(self):
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")

        # let's differentiate x^2 + x
        # dy/dx = 2x+1
        def x_squared_plus_x(x):
            return x * x + x

        x = tf.constant([5, 10], dtype=tf.float32)
        dy = tf.constant([2, 3], dtype=tf.float32)
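        # Expected: y = [30, 110]; dx = dy * (2x + 1) = [2*11, 3*21] = [22, 63].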
        two = mtf.Dimension("two", 2)
        expected_y = tf.constant([30, 110], dtype=tf.float32)
        expected_dx = tf.constant([22, 63], dtype=tf.float32)
        mtf_x = mtf.import_fully_replicated(mesh, x, shape=mtf.Shape([two]))
        mtf_dy = mtf.import_tf_tensor(mesh, dy, shape=mtf.Shape([two]))
        mtf_y = mtf.recompute_grad(x_squared_plus_x, [mtf_x])
        [mtf_dx] = mtf.gradients([mtf_y], [mtf_x], [mtf_dy])
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            shape="processors:2", layout="two:processors", devices=["", ""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_y = lowering.export_to_tf_tensor(mtf_y)
        actual_dx = lowering.export_to_tf_tensor(mtf_dx)
        self.assertAllEqual(self.evaluate(actual_y), self.evaluate(expected_y))
        self.assertAllEqual(self.evaluate(actual_dx),
                            self.evaluate(expected_dx))
Example #27
  def testConv1d(self):
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    filter_size = 3
    depth_dim = mtf.Dimension("depth", 2)
    length_dim = mtf.Dimension("length", 4)
    output_dim = mtf.Dimension("output", 2)

    x = tf.constant([[1, 0], [0, 1], [1, 1], [2, 1]], dtype=tf.float32)
    mtf_x = mtf.import_tf_tensor(
        mesh, x, shape=mtf.Shape([length_dim, depth_dim]))

    initializer_mock = mock.MagicMock()
    initializer_mock.side_effect = initialize_by_shape({
        (1, 3, 2, 2): [[[[1, -1], [0, 0]], [[2, -2], [-1, 1]], [[3, -3],
                                                                [-2, 2]]]],
    })

    mtf_output = mtf.layers.conv1d(
        mtf_x,
        output_dim=output_dim,
        filter_size=filter_size,
        filter_initializer=initializer_mock)

    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_output = lowering.export_to_tf_tensor(mtf_output)

    self.evaluate(tf.global_variables_initializer())
    self.evaluate(lowering.copy_masters_to_slices())
    actual = self.evaluate(actual_output)

    self.assertAllClose(actual, [[0, 0], [1, -1], [5, -5], [4, -4]])
Example #28
def main(_):

    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

    # Resolve the cluster from SLURM environment
    cluster = tf.distribute.cluster_resolver.SlurmClusterResolver(
        {"mesh": mesh_shape.size // FLAGS.gpus_per_task},
        port_base=8822,
        gpus_per_node=FLAGS.gpus_per_node,
        gpus_per_task=FLAGS.gpus_per_task,
        tasks_per_node=FLAGS.tasks_per_node)

    cluster_spec = cluster.cluster_spec()
    # Create a server for all mesh members
    server = tf.distribute.Server(cluster_spec, "mesh", cluster.task_id)

    # Only the master job builds the graph;
    # all other tasks simply join the server and wait.
    if cluster.task_id > 0:
        server.join()

    # Otherwise we are the main task; define the mesh devices.
    mesh_devices = [
        "/job:mesh/task:%d/device:GPU:%d" % (i, j)
        for i in range(cluster_spec.num_tasks("mesh"))
        for j in range(FLAGS.gpus_per_node)
    ]
    print("List of devices", mesh_devices)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "fft_mesh")

    # Build the model
    fft_err = benchmark_model(mesh)

    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    # Retrieve output of computation
    result = lowering.export_to_tf_tensor(fft_err)

    with tf.Session(server.target) as sess:
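        # The first run includes graph setup and warm-up; the reported timing
        # comes from the second run below.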
        start = time.time()
        err = sess.run(result)
        end = time.time()

        time.sleep(1)
        start = time.time()
        err = sess.run(result)
        end = time.time()

    print("Max absolute FFT error %f, with wall time %f" % (err,
                                                            (end - start)))
    time.sleep(1)
    exit(0)
Example #29
def model_fn(features, labels, mode, params):
	"""The model_fn argument for creating an Estimator."""
	global_step = tf.train.get_global_step()
	graph = mtf.Graph()
	mesh = mtf.Mesh(graph, "my_mesh")
	logits, loss = model_backbone(features, labels, mesh)
	
	variables = graph._all_variables
	for v in variables:
		logger.debug("[parameter] (name,shape,dtype): ({},{},{})".format(
			v.name, v.shape, v.dtype.master_dtype))
	mesh_shape = mtf.convert_to_shape(args_opt.mesh_shape)
	# layout_rules = mtf.auto_mtf.layout(graph, mesh_shape, [logits, loss])
	estimator = memory_estimator.MemoryEstimator(graph, mesh_shape, [logits, loss])
	layout_opt = layout_optimizer.LayoutOptimizer(estimator, scheduler_alg="NAIVE")
	layout_rules = mtf.convert_to_layout_rules(layout_opt.solve())

	# layout_rules=[('batch', 'b1')]

	logger.info("[auto mtf search] strategy: {}".format(layout_rules))
	mesh_devices = ["gpu:{}".format(i) for i in range(int(args_opt.num_gpus))]
	mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(mesh_shape, layout_rules, mesh_devices)

	if mode == tf.estimator.ModeKeys.TRAIN:
		var_grads = mtf.gradients(
			[loss], [v.outputs[0] for v in graph.trainable_variables])
		optimizer = mtf.optimize.SgdOptimizer(0.01)
		# optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
		update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

	lowering = mtf.Lowering(graph, {mesh: mesh_impl})
	restore_hook = mtf.MtfRestoreHook(lowering)

	tf_logits = lowering.export_to_tf_tensor(logits)
	if mode != tf.estimator.ModeKeys.PREDICT:
		tf_loss = lowering.export_to_tf_tensor(loss)

	if mode == tf.estimator.ModeKeys.TRAIN:
		tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
		tf_update_ops.append(tf.assign_add(global_step, 1))
		train_op = tf.group(tf_update_ops)

		accuracy = tf.metrics.accuracy(
			labels=labels, predictions=tf.argmax(tf_logits, axis=1))

		# Name tensors to be logged with LoggingTensorHook.
		tf.identity(tf_loss, "cross_entropy")
		tf.identity(accuracy[1], name="train_accuracy")

		logging_hook = tf.train.LoggingTensorHook(
			every_n_iter=100,
			tensors={'loss': 'cross_entropy', 'acc': 'train_accuracy'})
		# profiling_hook = tf.estimator.ProfilerHook(save_steps=20, output_dir='./profiling/')
		# restore_hook must come before saver_hook
		return tf.estimator.EstimatorSpec(
			tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op,
			training_chief_hooks=[restore_hook, logging_hook])
Example #30
def get_placement_mesh(hparams):
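    """Build a placement mesh and mesh impl from hparams."""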
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)

    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, hparams.layout, mesh_devices)
    return mesh, mesh_impl