def testMinimizePeakMemoryList(self):
  """Checks the LIST scheduler's greedy most-memory-freed ordering."""
  mtf_graph = mtf.Graph()
  mesh = mtf.Mesh(mtf_graph, 'my_mesh')
  left = mtf.Constant(mesh, 0,
                      shape=mtf.convert_to_shape('a:3,b:4'),
                      dtype=tf.int32,
                      name='X').outputs[0]
  right = mtf.Constant(mesh, 0,
                       shape=mtf.convert_to_shape('b:4,c:5'),
                       dtype=tf.int32,
                       name='Y').outputs[0]
  mtf.EinsumOperation([left, right], mtf.convert_to_shape('a:3,b:4,c:5'),
                      name='Z')
  product = mtf.EinsumOperation([left, right],
                                mtf.convert_to_shape('a:3,c:5'),
                                name='W').outputs[0]
  mtf.BroadcastOperation(product, mtf.convert_to_shape('a:3,b:4,c:5'),
                         name='V')

  graph = graph_interface.GraphInterface(mtf_graph)
  graph.set_tensor_final('Z:0')
  graph.set_tensor_final('V:0')

  schedule = list(scheduler.minimize_peak_memory(graph, 'LIST'))

  # The list scheduler greedily runs whichever ready operation frees the
  # most memory (i.e. has the smallest net allocation).  Walking the greedy
  # choices:
  #   nothing scheduled:   X frees -12 entries, Y frees -20 entries
  #   after [X]:           Y frees -20 entries
  #   after [X, Y]:        Z frees -60 entries, W frees -15 entries
  #   after [X, Y, W]:     Z frees -28 entries, V frees -45 entries
  # Hence the expected schedule is [X, Y, W, Z, V].
  self.assertEqual(schedule, [0, 1, 3, 2, 4])
def testMinimizePeakMemoryList_SingleUseTensor(self):
  """Checks LIST scheduling when a final tensor has no consumers."""
  mtf_graph = mtf.Graph()
  mesh = mtf.Mesh(mtf_graph, 'my_mesh')
  mtf.Constant(mesh, 0,
               shape=mtf.convert_to_shape('a:4'),
               dtype=tf.int32,
               name='X')
  broadcast_input = mtf.Constant(mesh, 0,
                                 shape=mtf.convert_to_shape('b:3'),
                                 dtype=tf.int32,
                                 name='Y').outputs[0]
  mtf.BroadcastOperation(broadcast_input, mtf.convert_to_shape('b:3,c:2'),
                         name='Z')

  graph = graph_interface.GraphInterface(mtf_graph)
  graph.set_tensor_final('X:0')
  graph.set_tensor_final('Z:0')

  schedule = list(scheduler.minimize_peak_memory(graph, 'LIST'))

  # Greedy choices, picking whichever ready op frees the most memory:
  #   nothing scheduled:  X frees -4 entries, Y frees -3 entries
  #   after [Y]:          X frees -4 entries, Z frees -3 entries
  # Hence the expected schedule is [Y, Z, X].
  self.assertEqual(schedule, [1, 2, 0])
def _get_memory_contents(self):
  """Runs the scheduler to determine memory contents at every point in time.

  The result is computed once and cached on the instance; later calls
  return the cached value.

  Returns:
    a list of frozenset of strings, where the ith entry describes the
    tensors in memory when executing operation i (where schedule[i] is an
    index into GetAllOperationNames()).
  """
  if self._memory_contents is None:
    schedule = scheduler.minimize_peak_memory(self._graph,
                                              self._scheduler_alg)
    self._memory_contents = (
        self._graph.compute_memory_contents_under_schedule(schedule))
  return self._memory_contents
def testReturnsTopoSort(self, scheduler_alg):
  """Any scheduler must emit a topological order: inputs before consumers."""
  mtf_graph = mtf.Graph()
  mesh = mtf.Mesh(mtf_graph, 'my_mesh')
  operand_a = mtf.Constant(mesh, 0,
                           shape=mtf.convert_to_shape('a:3,b:4'),
                           dtype=tf.int32,
                           name='X').outputs[0]
  operand_b = mtf.Constant(mesh, 0,
                           shape=mtf.convert_to_shape('b:4,c:5'),
                           dtype=tf.int32,
                           name='Y').outputs[0]
  mtf.EinsumOperation([operand_a, operand_b], mtf.convert_to_shape('a:3,c:5'),
                      name='Z1')
  mtf.EinsumOperation([operand_a, operand_b], mtf.convert_to_shape('a:3,c:5'),
                      name='Z2')

  graph = graph_interface.GraphInterface(mtf_graph)
  graph.set_tensor_final('Z1:0')
  graph.set_tensor_final('Z2:0')

  schedule = list(scheduler.minimize_peak_memory(graph, scheduler_alg))

  # Both einsums consume both constants, so the two constants (ops 0 and 1)
  # must come first, in either order, followed by the two einsums (ops 2
  # and 3), in either order.
  self.assertCountEqual(schedule[0:2], [0, 1])
  self.assertCountEqual(schedule[2:4], [2, 3])