Exemple #1
0
  def testMinimizePeakMemoryList(self):
    """Checks the order chosen by the 'LIST' scheduler on a five-op graph."""
    mtf_graph = mtf.Graph()
    mesh = mtf.Mesh(mtf_graph, 'my_mesh')

    # Operations are created in the order X, Y, Z, W, V. The expected schedule
    # at the bottom refers to them by that creation position (0 through 4),
    # so the construction order below must not be changed.
    x_op = mtf.Constant(mesh, 0,
                        shape=mtf.convert_to_shape('a:3,b:4'),
                        dtype=tf.int32,
                        name='X')
    x = x_op.outputs[0]
    y_op = mtf.Constant(mesh, 0,
                        shape=mtf.convert_to_shape('b:4,c:5'),
                        dtype=tf.int32,
                        name='Y')
    y = y_op.outputs[0]
    mtf.EinsumOperation([x, y], mtf.convert_to_shape('a:3,b:4,c:5'), name='Z')
    w_op = mtf.EinsumOperation([x, y], mtf.convert_to_shape('a:3,c:5'),
                               name='W')
    w = w_op.outputs[0]
    mtf.BroadcastOperation(w, mtf.convert_to_shape('a:3,b:4,c:5'), name='V')

    graph = graph_interface.GraphInterface(mtf_graph)
    # Z and V are the graph's final outputs.
    for final_tensor_name in ('Z:0', 'V:0'):
      graph.set_tensor_final(final_tensor_name)
    schedule = list(scheduler.minimize_peak_memory(graph, 'LIST'))

    # The list scheduler greedily picks the operation that frees the most
    # memory; a negative number below means a net allocation.
    # Tensor sizes in entries: X = 3*4 = 12, Y = 4*5 = 20, Z = 3*4*5 = 60,
    # W = 3*5 = 15, V = 3*4*5 = 60.
    #
    # Nothing scheduled yet:
    #   X frees -12 entries.
    #   Y frees -20 entries.
    # After [X]:
    #   Y frees -20 entries.
    # After [X, Y]:
    #   Z frees -60 entries (X and Y are still needed by W).
    #   W frees -15 entries (X and Y are still needed by Z).
    # After [X, Y, W]:
    #   Z frees -28 entries (allocates 60, releases X and Y: 12 + 20).
    #   V frees -45 entries (allocates 60, releases W: 15).
    # So the expected order is [X, Y, W, Z, V].
    expected_schedule = [0, 1, 3, 2, 4]
    self.assertEqual(schedule, expected_schedule)
Exemple #2
0
    def testMinimizePeakMemoryList_SingleUseTensor(self):
        """Checks the 'LIST' schedule when one constant has no consumers."""
        mtf_graph = mtf.Graph()
        mesh = mtf.Mesh(mtf_graph, 'my_mesh')

        # Operations are created in the order X, Y, Z; the expected schedule
        # refers to them by that creation position (0, 1, 2).
        # X is never consumed by another operation; it is only marked final.
        mtf.Constant(mesh, 0,
                     shape=mtf.convert_to_shape('a:4'),
                     dtype=tf.int32,
                     name='X')
        y_op = mtf.Constant(mesh, 0,
                            shape=mtf.convert_to_shape('b:3'),
                            dtype=tf.int32,
                            name='Y')
        y = y_op.outputs[0]
        mtf.BroadcastOperation(y, mtf.convert_to_shape('b:3,c:2'), name='Z')

        graph = graph_interface.GraphInterface(mtf_graph)
        for final_tensor_name in ('X:0', 'Z:0'):
            graph.set_tensor_final(final_tensor_name)
        schedule = list(scheduler.minimize_peak_memory(graph, 'LIST'))

        # Tensor sizes in entries: X = 4, Y = 3, Z = 3*2 = 6.
        # A negative "frees" value means a net allocation.
        #
        # Nothing scheduled yet:
        #   X frees -4 entries
        #   Y frees -3 entries
        # After [Y]:
        #   X frees -4 entries
        #   Z frees -3 entries (allocates 6, releases Y: 3)
        # So the expected order is [Y, Z, X].
        expected_schedule = [1, 2, 0]
        self.assertEqual(schedule, expected_schedule)
Exemple #3
0
  def _get_memory_contents(self):
    """Returns the memory contents at each step, computing them on first use.

    On the first call the scheduler is run with self._scheduler_alg and the
    result is stored in self._memory_contents; later calls return that stored
    value without running the scheduler again.

    Returns:
      a list of frozenset of strings, where the ith entry describes the tensors
      in memory when executing operation i (where schedule[i] is an index into
      GetAllOperationNames()).
    """
    if self._memory_contents is None:
      op_order = scheduler.minimize_peak_memory(
          self._graph, self._scheduler_alg)
      self._memory_contents = (
          self._graph.compute_memory_contents_under_schedule(op_order))

    return self._memory_contents
Exemple #4
0
  def testReturnsTopoSort(self, scheduler_alg):
    """Checks that both constants are scheduled before both einsums.

    Args:
      scheduler_alg: the algorithm name passed to
        scheduler.minimize_peak_memory.
    """
    mtf_graph = mtf.Graph()
    mesh = mtf.Mesh(mtf_graph, 'my_mesh')

    # Operations are created in the order X, Y, Z1, Z2, i.e. positions 0..3.
    x_op = mtf.Constant(mesh, 0,
                        shape=mtf.convert_to_shape('a:3,b:4'),
                        dtype=tf.int32,
                        name='X')
    x = x_op.outputs[0]
    y_op = mtf.Constant(mesh, 0,
                        shape=mtf.convert_to_shape('b:4,c:5'),
                        dtype=tf.int32,
                        name='Y')
    y = y_op.outputs[0]
    # Both einsums consume X and Y, so neither may run before both constants.
    mtf.EinsumOperation([x, y], mtf.convert_to_shape('a:3,c:5'), name='Z1')
    mtf.EinsumOperation([x, y], mtf.convert_to_shape('a:3,c:5'), name='Z2')

    graph = graph_interface.GraphInterface(mtf_graph)
    for final_tensor_name in ('Z1:0', 'Z2:0'):
      graph.set_tensor_final(final_tensor_name)
    schedule = list(scheduler.minimize_peak_memory(graph, scheduler_alg))

    # The order within each pair is left unspecified; only the grouping
    # (constants first, then einsums) is checked.
    constants_part = schedule[:2]
    einsums_part = schedule[2:4]
    self.assertCountEqual(constants_part, [0, 1])
    self.assertCountEqual(einsums_part, [2, 3])