def testMinimizePeakMemoryList(self):
  """LIST scheduling greedily picks the op with the best net memory change.

  Graph under test: constants X (a:3,b:4) and Y (b:4,c:5) feed two einsums,
  Z (a:3,b:4,c:5) and W (a:3,c:5); W is then broadcast to V (a:3,b:4,c:5).
  Z and V are marked final.
  """
  graph_builder = mtf.Graph()
  my_mesh = mtf.Mesh(graph_builder, 'my_mesh')
  x_out = mtf.Constant(
      my_mesh, 0, shape=mtf.convert_to_shape('a:3,b:4'),
      dtype=tf.int32, name='X').outputs[0]
  y_out = mtf.Constant(
      my_mesh, 0, shape=mtf.convert_to_shape('b:4,c:5'),
      dtype=tf.int32, name='Y').outputs[0]
  mtf.EinsumOperation(
      [x_out, y_out], mtf.convert_to_shape('a:3,b:4,c:5'), name='Z')
  w_out = mtf.EinsumOperation(
      [x_out, y_out], mtf.convert_to_shape('a:3,c:5'), name='W').outputs[0]
  mtf.BroadcastOperation(w_out, mtf.convert_to_shape('a:3,b:4,c:5'), name='V')

  interface = graph_interface.GraphInterface(graph_builder)
  for final_tensor in ('Z:0', 'V:0'):
    interface.set_tensor_final(final_tensor)

  order = list(scheduler.minimize_peak_memory(interface, 'LIST'))

  # Net entries freed by each ready op at every step (negative = allocates);
  # the scheduler takes the largest value:
  #   nothing scheduled:  X -12, Y -20          -> X
  #   after X:            Y -20                 -> Y
  #   after X, Y:         Z -60, W -15          -> W
  #   after X, Y, W:      Z -28 (60 allocated, X's 12 + Y's 20 released),
  #                       V -45 (60 allocated, W's 15 released)  -> Z
  #   finally:            V
  # Expected order is therefore X, Y, W, Z, V. Op indices follow creation
  # order (X=0, Y=1, Z=2, W=3, V=4).
  self.assertEqual(order, [0, 1, 3, 2, 4])
def testMeshTensorFlowGraph(self):
  """Checks GraphInterface queries on a small MTF graph with no devices.

  X (a:3,b:4) and Y (b:4,c:5) are constants. Z1 and Z2 are two identical
  einsums of X and Y that produce an (a:3,c:5) output.
  """
  mtf_graph = mtf.Graph()
  mesh = mtf.Mesh(mtf_graph, "my_mesh")
  x = mtf.Constant(mesh, 0, shape=mtf.convert_to_shape("a:3,b:4"),
                   dtype=tf.int32, name="X").outputs[0]
  y = mtf.Constant(mesh, 0, shape=mtf.convert_to_shape("b:4,c:5"),
                   dtype=tf.int32, name="Y").outputs[0]
  mtf.EinsumOperation([x, y], mtf.convert_to_shape("a:3,c:5"), name="Z1")
  mtf.EinsumOperation([x, y], mtf.convert_to_shape("a:3,c:5"), name="Z2")
  graph = graph_interface.GraphInterface(mtf_graph)

  self.VerifyGraphInterface(graph)

  # Operation-level dimensions include the contracted "b" for the einsums.
  expected_operation_dims = (
      ("X", ["a", "b"]),
      ("Y", ["b", "c"]),
      ("Z1", ["a", "b", "c"]),
      ("Z2", ["a", "b", "c"]),
  )
  for op_name, dims in expected_operation_dims:
    self.assertCountEqual(
        graph.get_operation_mtf_dimension_names(op_name), dims)

  # Tensor-level dimensions are those of each output tensor's shape only.
  # Previously "Z1:0" was asserted twice and "Z2:0" never checked.
  expected_tensor_dims = (
      ("X:0", ["a", "b"]),
      ("Y:0", ["b", "c"]),
      ("Z1:0", ["a", "c"]),
      ("Z2:0", ["a", "c"]),
  )
  for tensor_name, dims in expected_tensor_dims:
    self.assertCountEqual(
        graph.get_tensor_mtf_dimension_names(tensor_name), dims)

  # No device was assigned, so every tensor should report no device and be
  # considered on the canonical device.
  for tensor_name, _ in expected_tensor_dims:
    self.assertIsNone(graph.get_tensor_device(tensor_name))
    self.assertTrue(graph.is_tensor_on_canonical_device(tensor_name))

  # The cost graph with default devices and with an explicit empty device
  # list should both match self._deviceless_cost_graph_string (set elsewhere
  # in this test class).
  self.assertEqual(graph.compute_cost_graph().SerializeToString(),
                   self._deviceless_cost_graph_string)
  self.assertEqual(
      graph.compute_cost_graph(devices=[]).SerializeToString(),
      self._deviceless_cost_graph_string)
def testEinsumOperation(self):
  """Checks the splittable/unsplittable dimension sets of an einsum.

  Combines self.x with an (a, c) zeros tensor into a (b, c) output and
  expects all of a, b and c to be splittable and none to be unsplittable.
  """
  second_input = mtf.zeros(self.mesh, mtf.Shape([self.a_dim, self.c_dim]))
  output_shape = mtf.Shape([self.b_dim, self.c_dim])
  op = mtf.EinsumOperation([self.x, second_input], output_shape)

  expected_splittable = frozenset(["a", "b", "c"])
  self.assertEqual(op.splittable_dims, expected_splittable)
  self.assertEqual(op.unsplittable_dims, frozenset())
def testReturnsTopoSort(self, scheduler_alg):
  """The schedule places both constants before both einsums that consume them.

  Args:
    scheduler_alg: algorithm name forwarded to
      scheduler.minimize_peak_memory.
  """
  graph_builder = mtf.Graph()
  my_mesh = mtf.Mesh(graph_builder, 'my_mesh')
  x_out = mtf.Constant(
      my_mesh, 0, shape=mtf.convert_to_shape('a:3,b:4'),
      dtype=tf.int32, name='X').outputs[0]
  y_out = mtf.Constant(
      my_mesh, 0, shape=mtf.convert_to_shape('b:4,c:5'),
      dtype=tf.int32, name='Y').outputs[0]
  for einsum_name in ('Z1', 'Z2'):
    mtf.EinsumOperation(
        [x_out, y_out], mtf.convert_to_shape('a:3,c:5'), name=einsum_name)

  interface = graph_interface.GraphInterface(graph_builder)
  for final_tensor in ('Z1:0', 'Z2:0'):
    interface.set_tensor_final(final_tensor)

  order = list(scheduler.minimize_peak_memory(interface, scheduler_alg))

  # Ops are indexed in creation order: X=0, Y=1, Z1=2, Z2=3. Within each
  # pair the relative order is unconstrained.
  producers, consumers = order[:2], order[2:4]
  self.assertCountEqual(producers, [0, 1])
  self.assertCountEqual(consumers, [2, 3])
def setUp(self):
  """Builds Z = einsum(X, Y) and a MemoryEstimator over it.

  X is an (a:3, b:4) constant and Y a (b:4, c:5) constant, and Z has shape
  (a:3, c:5). The estimator is constructed on a mesh of shape m1:4 x m2:3
  with [Z] as its third argument and is stored in self.estimator.
  """
  super(MemoryEstimatorTest, self).setUp()
  graph_builder = mtf.Graph()
  ctx_mesh = mtf.Mesh(graph_builder, 'lowering_context_mesh')
  dim_a, dim_b, dim_c = (
      mtf.Dimension('a', 3),
      mtf.Dimension('b', 4),
      mtf.Dimension('c', 5),
  )

  x_tensor = mtf.Constant(
      ctx_mesh, 0, mtf.Shape([dim_a, dim_b]), tf.int32, 'X').outputs[0]
  y_tensor = mtf.Constant(
      ctx_mesh, 0, mtf.Shape([dim_b, dim_c]), tf.int32, 'Y').outputs[0]
  z_op = mtf.EinsumOperation(
      [x_tensor, y_tensor], mtf.Shape([dim_a, dim_c]), name='Z')
  z_tensor = z_op.outputs[0]

  layout_mesh_shape = mtf.Shape([('m1', 4), ('m2', 3)])
  self.estimator = memory_estimator.MemoryEstimator(
      graph_builder, layout_mesh_shape, [z_tensor])