Пример #1
0
  def testMinimizePeakMemoryList(self):
    """Checks the operation order produced by the 'LIST' scheduler.

    The graph holds two constants (X, Y), two einsums over them (Z, W) and a
    broadcast of W (V). Z:0 and V:0 are marked final before scheduling.
    """
    mtf_graph = mtf.Graph()
    mesh = mtf.Mesh(mtf_graph, 'my_mesh')

    ab_shape = mtf.convert_to_shape('a:3,b:4')
    bc_shape = mtf.convert_to_shape('b:4,c:5')
    ac_shape = mtf.convert_to_shape('a:3,c:5')
    abc_shape = mtf.convert_to_shape('a:3,b:4,c:5')

    # Operations are created in the order X, Y, Z, W, V; the indices in the
    # expected schedule below follow that creation order.
    x_op = mtf.Constant(mesh, 0, shape=ab_shape, dtype=tf.int32, name='X')
    x = x_op.outputs[0]
    y_op = mtf.Constant(mesh, 0, shape=bc_shape, dtype=tf.int32, name='Y')
    y = y_op.outputs[0]
    mtf.EinsumOperation([x, y], abc_shape, name='Z')
    w_op = mtf.EinsumOperation([x, y], ac_shape, name='W')
    w = w_op.outputs[0]
    mtf.BroadcastOperation(w, abc_shape, name='V')

    graph = graph_interface.GraphInterface(mtf_graph)
    for final_tensor in ('Z:0', 'V:0'):
      graph.set_tensor_final(final_tensor)
    schedule = list(scheduler.minimize_peak_memory(graph, 'LIST'))

    # The list scheduler greedily picks the operation that frees the most
    # memory (a negative number means memory is consumed):
    #   nothing scheduled   -> X: -12, Y: -20
    #   [X] scheduled       -> Y: -20
    #   [X, Y] scheduled    -> Z: -60, W: -15
    #   [X, Y, W] scheduled -> Z: -28, V: -45
    # which gives the order [X, Y, W, Z, V].
    self.assertEqual(schedule, [0, 1, 3, 2, 4])
    def testMeshTensorFlowGraph(self):
        """Exercises GraphInterface on a small Mesh TensorFlow graph.

        The graph has two constants, X (a:3,b:4) and Y (b:4,c:5), and two
        einsums over them, Z1 and Z2, each with output shape a:3,c:5. The
        test checks dimension names, device queries and the serialized cost
        graph for all four operations and their outputs.
        """
        mtf_graph = mtf.Graph()
        mesh = mtf.Mesh(mtf_graph, "my_mesh")
        x = mtf.Constant(mesh,
                         0,
                         shape=mtf.convert_to_shape("a:3,b:4"),
                         dtype=tf.int32,
                         name="X").outputs[0]
        y = mtf.Constant(mesh,
                         0,
                         shape=mtf.convert_to_shape("b:4,c:5"),
                         dtype=tf.int32,
                         name="Y").outputs[0]
        mtf.EinsumOperation([x, y], mtf.convert_to_shape("a:3,c:5"), name="Z1")
        mtf.EinsumOperation([x, y], mtf.convert_to_shape("a:3,c:5"), name="Z2")

        graph = graph_interface.GraphInterface(mtf_graph)
        self.VerifyGraphInterface(graph)

        # Operation-level dimension names: the einsums are expected to report
        # "b" as well, even though it does not appear in their output shape.
        self.assertCountEqual(graph.get_operation_mtf_dimension_names("X"),
                              ["a", "b"])
        self.assertCountEqual(graph.get_operation_mtf_dimension_names("Y"),
                              ["b", "c"])
        self.assertCountEqual(graph.get_operation_mtf_dimension_names("Z1"),
                              ["a", "b", "c"])
        self.assertCountEqual(graph.get_operation_mtf_dimension_names("Z2"),
                              ["a", "b", "c"])

        # Tensor-level dimension names follow each output's declared shape.
        self.assertCountEqual(graph.get_tensor_mtf_dimension_names("X:0"),
                              ["a", "b"])
        self.assertCountEqual(graph.get_tensor_mtf_dimension_names("Y:0"),
                              ["b", "c"])
        self.assertCountEqual(graph.get_tensor_mtf_dimension_names("Z1:0"),
                              ["a", "c"])
        # Previously this assertion repeated "Z1:0", leaving "Z2:0" untested.
        self.assertCountEqual(graph.get_tensor_mtf_dimension_names("Z2:0"),
                              ["a", "c"])

        # No tensor is expected to report a device.
        self.assertIsNone(graph.get_tensor_device("X:0"))
        self.assertIsNone(graph.get_tensor_device("Y:0"))
        self.assertIsNone(graph.get_tensor_device("Z1:0"))
        self.assertIsNone(graph.get_tensor_device("Z2:0"))

        self.assertTrue(graph.is_tensor_on_canonical_device("X:0"))
        self.assertTrue(graph.is_tensor_on_canonical_device("Y:0"))
        self.assertTrue(graph.is_tensor_on_canonical_device("Z1:0"))
        self.assertTrue(graph.is_tensor_on_canonical_device("Z2:0"))

        # The cost graph must serialize identically whether the device list
        # is omitted or passed as an empty list.
        self.assertEqual(graph.compute_cost_graph().SerializeToString(),
                         self._deviceless_cost_graph_string)
        self.assertEqual(
            graph.compute_cost_graph(devices=[]).SerializeToString(),
            self._deviceless_cost_graph_string)
Пример #3
0
 def testEinsumOperation(self):
   """Checks splittable/unsplittable dims reported by an EinsumOperation.

   The einsum takes self.x and a zeros tensor over (a_dim, c_dim) and
   produces an output over (b_dim, c_dim).
   """
   zeros_ac = mtf.zeros(self.mesh, mtf.Shape([self.a_dim, self.c_dim]))
   output_shape = mtf.Shape([self.b_dim, self.c_dim])
   op = mtf.EinsumOperation([self.x, zeros_ac], output_shape)

   # Every dimension is expected to be splittable; none is unsplittable.
   self.assertEqual(op.splittable_dims, frozenset(["a", "b", "c"]))
   self.assertEqual(op.unsplittable_dims, frozenset())
Пример #4
0
  def testReturnsTopoSort(self, scheduler_alg):
    """Checks that the schedule places the constants before the einsums.

    Args:
      scheduler_alg: scheduler algorithm name passed through to
        scheduler.minimize_peak_memory.
    """
    mtf_graph = mtf.Graph()
    mesh = mtf.Mesh(mtf_graph, 'my_mesh')

    ab_shape = mtf.convert_to_shape('a:3,b:4')
    bc_shape = mtf.convert_to_shape('b:4,c:5')

    x_op = mtf.Constant(mesh, 0, shape=ab_shape, dtype=tf.int32, name='X')
    x = x_op.outputs[0]
    y_op = mtf.Constant(mesh, 0, shape=bc_shape, dtype=tf.int32, name='Y')
    y = y_op.outputs[0]
    for einsum_name in ('Z1', 'Z2'):
      mtf.EinsumOperation([x, y], mtf.convert_to_shape('a:3,c:5'),
                          name=einsum_name)

    graph = graph_interface.GraphInterface(mtf_graph)
    for final_tensor in ('Z1:0', 'Z2:0'):
      graph.set_tensor_final(final_tensor)
    schedule = list(scheduler.minimize_peak_memory(graph, scheduler_alg))

    # Indices 0 and 1 (the constants, by creation order) may come in either
    # order, but both must precede indices 2 and 3 (the einsums).
    first_pair = schedule[:2]
    second_pair = schedule[2:4]
    self.assertCountEqual(first_pair, [0, 1])
    self.assertCountEqual(second_pair, [2, 3])
Пример #5
0
    def setUp(self):
        """Builds a small graph and a MemoryEstimator over it.

        The graph contains constants X (a, b) and Y (b, c) and an einsum Z
        with output shape (a, c). The estimator is stored on self.estimator,
        built with a mesh shape of m1=4, m2=3 and Z's output as the only
        entry of its third argument.
        """
        super(MemoryEstimatorTest, self).setUp()
        mtf_graph = mtf.Graph()
        mesh = mtf.Mesh(mtf_graph, 'lowering_context_mesh')

        a_dim = mtf.Dimension('a', 3)
        b_dim = mtf.Dimension('b', 4)
        c_dim = mtf.Dimension('c', 5)

        x_shape = mtf.Shape([a_dim, b_dim])
        x_const = mtf.Constant(mesh, 0, x_shape, tf.int32, 'X')
        x = x_const.outputs[0]

        y_shape = mtf.Shape([b_dim, c_dim])
        y_const = mtf.Constant(mesh, 0, y_shape, tf.int32, 'Y')
        y = y_const.outputs[0]

        z_shape = mtf.Shape([a_dim, c_dim])
        z_einsum = mtf.EinsumOperation([x, y], z_shape, name='Z')
        z = z_einsum.outputs[0]

        mesh_shape = mtf.Shape([('m1', 4), ('m2', 3)])

        self.estimator = memory_estimator.MemoryEstimator(
            mtf_graph, mesh_shape, [z])