Ejemplo n.º 1
0
    def testFromStringHandle(self):
        """IteratorGetNext built from a string handle reports the given shape."""
        expected_shapes = [
            tensor_shape.TensorShape(dims)
            for dims in ([], [3], [1, 2], [1, 2, 3])
        ]

        for expected in expected_shapes:
            with ops.Graph().as_default() as graph:
                source = iterator_ops.Iterator.from_structure(dtypes.int64)
                string_handle = source.string_handle()
                restored = iterator_ops.Iterator.from_string_handle(
                    string_handle, dtypes.int64, output_shapes=expected)
                next_element = restored.get_next()
                ops.get_collection_ref(
                    ops.GraphKeys.TRAIN_OP).append(next_element)
                meta = meta_graph.create_meta_graph_def(graph=graph)
                grappler_item = item.Item(meta)
                properties = grappler_item.GetOpProperties()
                self.assertEqual(expected,
                                 properties['IteratorGetNext'][0].shape)
Ejemplo n.º 2
0
    def run(self):
        """Build the placement graph for the wrapped TensorFlow graph.

        Returns:
            A ``(op_graph, op_index)`` tuple: the placement graph and a
            dictionary that maps an operator name to a node index in the
            placement graph.
        """
        placement_graph = nx.DiGraph()
        name_to_index = self._add_nodes(placement_graph)
        self._add_edges(placement_graph, name_to_index)

        placement_graph, name_to_index = placer_utils.prune_dangling_ops(
            placement_graph)

        if self._only_important_ops:
            # Ask grappler which ops matter, then keep only those nodes.
            exported = tf.train.export_meta_graph(
                graph=self._tf_graph, clear_extraneous_savers=True)
            grappler_item = gitem.Item(exported, ignore_colocation=False)
            keep = grappler_item.IdentifyImportantOps()
            _LOGGER.info("Use only important ops. # important ops=%d",
                         len(keep))
            placement_graph, name_to_index = (
                placer_utils.prune_non_important_ops(placement_graph, keep))

        # Annotate the remaining nodes with their topological order.
        placer_utils.assign_topo_order(placement_graph)
        return placement_graph, name_to_index
Ejemplo n.º 3
0
    def testMap(self):
        """Shape inference sees through a transposing Dataset.map."""
        cases = (
            (0, tensor_shape.TensorShape([])),
            (np.array([1, 2, 3]), tensor_shape.TensorShape([3])),
            (np.array([[1, 2, 3]]), tensor_shape.TensorShape([3, 1])),
            (np.array([[[1, 2, 3], [4, 5, 6]]]),
             tensor_shape.TensorShape([3, 2, 1])),
        )

        for value, expected in cases:
            with ops.Graph().as_default() as graph:
                ds = dataset_ops.Dataset.from_tensors(value).map(
                    array_ops.transpose)
                next_element = ds.make_one_shot_iterator().get_next()
                ops.get_collection_ref(
                    ops.GraphKeys.TRAIN_OP).append(next_element)
                meta = meta_graph.create_meta_graph_def(graph=graph)
                grappler_item = item.Item(meta)
                properties = grappler_item.GetOpProperties()
                self.assertEqual(expected,
                                 properties['IteratorGetNext'][0].shape)
def GenerateMemoryReport(metagraph, detailed_report=True, cluster=None):
  """Analyze the peak memory usage for the provided metagraph.

  Args:
    metagraph: A TensorFlow MetaGraphDef.
    detailed_report: print the live tensors in addition to the peak memory
      usage.
    cluster: Analyze the memory using the specified cluster, or the local
      machine if no cluster was specified.

  Returns:
    A string with the formatted memory usage: one "Peak usage for device"
    line per device, each followed (when detailed_report is True) by one
    indented line per live tensor on that device.
  """
  if cluster is None:
    cluster = gcluster.Cluster(
        disable_detailed_stats=True, disable_timeline=True)

  item = gitem.Item(metagraph)
  # Keyed by device name; each snapshot holds the peak usage at index 0 and
  # the live tensors at index 1.
  usage_by_device = cluster.DeterminePeakMemoryUsage(item)
  lines = []
  for device, snapshot in usage_by_device.items():
    # Use a dedicated name for the per-device figure instead of rebinding
    # the dict that is being iterated.
    device_peak = snapshot[0]
    lines.append("Peak usage for device " + str(device) + ": " +
                 str(device_peak) + " bytes\n")
    if detailed_report:
      for tensor in snapshot[1]:
        # Each live-tensor entry is indexed as (op name, output id, bytes).
        op_name, output_id, mem_used = tensor[0], tensor[1], tensor[2]
        lines.append("  " + str(op_name) + ":" + str(output_id) + " uses " +
                     str(mem_used) + " bytes\n")

  # Join once at the end rather than growing a string with += in the loop.
  return "".join(lines)
Ejemplo n.º 5
0
    def testPaddedBatch(self):
        """padded_batch yields a batch dim followed by the padded shape."""
        cases = (
            (0, tensor_shape.TensorShape([None])),
            (np.array([1, 2, 3]), tensor_shape.TensorShape([None, 4])),
            (np.array([[1, 2, 3]]), tensor_shape.TensorShape([None, 2, 4])),
        )

        for value, expected in cases:
            with ops.Graph().as_default() as graph:
                ds = dataset_ops.Dataset.from_tensors(value).padded_batch(
                    42, padded_shapes=expected[1:])
                next_element = ds.make_one_shot_iterator().get_next()
                ops.get_collection_ref(
                    ops.GraphKeys.TRAIN_OP).append(next_element)
                meta = meta_graph.create_meta_graph_def(graph=graph)
                grappler_item = item.Item(meta)
                properties = grappler_item.GetOpProperties()
                actual = self.as_tensor_shape(
                    properties['IteratorGetNext'][0].shape)
                # Only compatibility is checked for the leading (batch) dim;
                # the trailing dims must match the padded shape exactly.
                self.assertTrue(
                    expected.dims[0].is_compatible_with(actual[0]))
                self.assertEqual(expected[1:], actual[1:])
  def testVirtualCluster(self):
    """A single-GPU virtual cluster returns fixed cost and perf estimates."""
    gpu_name = '/device:GPU:0'
    with ops.Graph().as_default() as graph:
      with ops.device(gpu_name):
        lhs = random_ops.random_uniform(shape=[1024, 1024])
        rhs = random_ops.random_uniform(shape=[1024, 1024])
        total = lhs + rhs
      ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(total)
      meta = meta_graph.create_meta_graph_def(graph=graph)
      grappler_item = item.Item(meta)
      gpu = device_properties_pb2.NamedDevice(
          name=gpu_name,
          properties=device_properties_pb2.DeviceProperties(
              type='GPU',
              frequency=1000,
              num_cores=60,
              environment={'architecture': '7'}))
      virtual_cluster = cluster.Cluster(
          disable_detailed_stats=False,
          disable_timeline=False,
          devices=[gpu])
      op_perfs, run_time, _ = virtual_cluster.MeasureCosts(grappler_item)
      self.assertEqual(run_time, 0.000545)
      self.assertEqual(len(op_perfs), 15)

      self.assertEqual(7680.0, virtual_cluster.EstimatePerformance(gpu))
Ejemplo n.º 7
0
    def testUpdates(self):
        """tf_item changes only when the metagraph placement changes."""
        with ops.Graph().as_default() as graph:
            lhs = constant_op.constant(10)
            rhs = constant_op.constant(20)
            ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(lhs + rhs)
            grappler_item = item.Item(
                meta_graph.create_meta_graph_def(graph=graph))

        def place_all_on_cpu():
            for node in grappler_item.metagraph.graph_def.node:
                node.device = '/cpu:0'

        # Reading tf_item again without edits must give an equal value.
        first = grappler_item.tf_item
        self.assertEqual(first, grappler_item.tf_item)

        # Modify the placement: tf_item must differ from the initial one.
        place_all_on_cpu()
        after_move = grappler_item.tf_item
        self.assertNotEqual(first, after_move)

        # Re-apply the same placement: tf_item must stay equal.
        place_all_on_cpu()
        self.assertEqual(after_move, grappler_item.tf_item)
Ejemplo n.º 8
0
    def testFromGenerator(self):
        """from_generator propagates its declared output shape."""

        def make_generator(value):
            def generator():
                yield value

            return generator

        cases = (
            (0, tensor_shape.TensorShape([])),
            (np.array([1, 2, 3]), tensor_shape.TensorShape([3])),
            (np.array([[1, 2, 3]]), tensor_shape.TensorShape([1, 3])),
        )

        for value, expected in cases:
            with ops.Graph().as_default() as graph:
                ds = dataset_ops.Dataset.from_generator(
                    make_generator(value),
                    dtypes.int64,
                    output_shapes=expected)
                next_element = ds.make_one_shot_iterator().get_next()
                ops.get_collection_ref(
                    ops.GraphKeys.TRAIN_OP).append(next_element)
                meta = meta_graph.create_meta_graph_def(graph=graph)
                grappler_item = item.Item(meta)
                properties = grappler_item.GetOpProperties()
                self.assertEqual(expected,
                                 properties['IteratorGetNext'][0].shape)
Ejemplo n.º 9
0
    def testInterleave(self):
        """interleave over repeated from_tensors datasets keeps the shape."""

        def make_dataset(value):
            def dataset_fn(n):
                return dataset_ops.Dataset.from_tensors(value).repeat(n)

            return dataset_fn

        cases = (
            (0, tensor_shape.TensorShape([])),
            (np.array([1, 2, 3]), tensor_shape.TensorShape([3])),
            (np.array([[1, 2, 3]]), tensor_shape.TensorShape([1, 3])),
        )

        for value, expected in cases:
            with ops.Graph().as_default() as graph:
                ds = dataset_ops.Dataset.range(42).interleave(
                    make_dataset(value), cycle_length=42)
                next_element = ds.make_one_shot_iterator().get_next()
                ops.get_collection_ref(
                    ops.GraphKeys.TRAIN_OP).append(next_element)
                meta = meta_graph.create_meta_graph_def(graph=graph)
                grappler_item = item.Item(meta)
                properties = grappler_item.GetOpProperties()
                self.assertEqual(expected,
                                 properties['IteratorGetNext'][0].shape)
Ejemplo n.º 10
0
    def testInvalidItem(self):
        """Item() rejects a metagraph whose train op was never registered."""
        with ops.Graph().as_default() as graph:
            lhs = constant_op.constant(10)
            rhs = constant_op.constant(20)
            _ = lhs + rhs
            meta = meta_graph.create_meta_graph_def(graph=graph)

        # Nothing was appended to the TRAIN_OP collection, so building the
        # item is expected to raise InvalidArgumentError.
        with self.assertRaises(errors_impl.InvalidArgumentError):
            item.Item(meta)
Ejemplo n.º 11
0
 def testImportantOps(self):
     """IdentifyImportantOps returns both constants and the add op."""
     with ops.Graph().as_default() as graph:
         lhs = constant_op.constant(10)
         rhs = constant_op.constant(20)
         ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(lhs + rhs)
         meta = meta_graph.create_meta_graph_def(graph=graph)
         grappler_item = item.Item(meta)
         important = grappler_item.IdentifyImportantOps()
         self.assertItemsEqual(['Const', 'Const_1', 'add'], important)
Ejemplo n.º 12
0
 def testRange(self):
     """IteratorGetNext over Dataset.range produces a scalar output."""
     with ops.Graph().as_default() as graph:
         ds = dataset_ops.Dataset.range(42)
         next_element = ds.make_one_shot_iterator().get_next()
         ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(next_element)
         meta = meta_graph.create_meta_graph_def(graph=graph)
         grappler_item = item.Item(meta)
         properties = grappler_item.GetOpProperties()
         self.assertEqual(tensor_shape.scalar(),
                          properties['IteratorGetNext'][0].shape)
 def testColocationContraints(self):
   """Assigning through a ref to a variable yields one colocation group."""
   with ops.Graph().as_default() as graph:
     value = constant_op.constant([10])
     var = variables.VariableV1([3], dtype=dtypes.int32)
     ref = gen_array_ops.ref_identity(var)
     assign = state_ops.assign(ref, value)
     ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(assign)
     meta = meta_graph.create_meta_graph_def(graph=graph)
     grappler_item = item.Item(meta)
     groups = grappler_item.GetColocationGroups()
     self.assertEqual(len(groups), 1)
     self.assertItemsEqual(
         groups[0], ['Assign', 'RefIdentity', 'Variable', 'Variable/Assign'])
  def testContext(self):
    """cluster.Provision is usable as a context manager to measure costs."""
    with ops.Graph().as_default() as graph:
      lhs = random_ops.random_uniform(shape=())
      rhs = random_ops.random_uniform(shape=())
      ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(lhs + rhs)
      meta = meta_graph.create_meta_graph_def(graph=graph)
      grappler_item = item.Item(meta)

    with cluster.Provision(
        disable_detailed_stats=False, disable_timeline=False) as provisioned:
      op_perfs, run_time, step_stats = provisioned.MeasureCosts(grappler_item)
      self.assertTrue(run_time > 0)
      self.assertEqual(len(op_perfs), 4)
      self.assertTrue(step_stats.dev_stats)
Ejemplo n.º 15
0
    def testNoDetailedStats(self):
        """With detailed stats disabled, no per-op or per-device stats come back."""
        with ops.Graph().as_default() as graph:
            lhs = random_ops.random_uniform(shape=())
            rhs = random_ops.random_uniform(shape=())
            ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(lhs + rhs)
            meta = meta_graph.create_meta_graph_def(graph=graph)
            grappler_item = item.Item(meta)
            no_stats_cluster = cluster.Cluster(disable_detailed_stats=True)

            op_perfs, run_time, step_stats = no_stats_cluster.MeasureCosts(
                grappler_item)
            self.assertTrue(run_time > 0)
            self.assertEqual(len(op_perfs), 0)
            self.assertEqual(len(step_stats.dev_stats), 0)
Ejemplo n.º 16
0
    def testSupportDevices(self):
        """GetSupportedDevices on a virtual CPU+GPU cluster and a real one."""
        gpu_type = test_util.gpu_device_type()
        gpu_name = test_util.gpu_device_name()
        with ops.Graph().as_default() as graph:
            lhs = random_ops.random_uniform(shape=(2, 3))
            rhs = random_ops.random_uniform(shape=(2, 3))
            total = lhs + rhs
            dims = math_ops.range(0, array_ops.rank(total), 1)
            reduced = math_ops.reduce_sum(lhs, axis=dims)
            ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(reduced)
            meta = meta_graph.create_meta_graph_def(graph=graph)
            grappler_item = item.Item(meta)

            def named(name, dev_type, frequency, num_cores):
                props = device_properties_pb2.DeviceProperties(
                    type=dev_type, frequency=frequency, num_cores=num_cores)
                return device_properties_pb2.NamedDevice(
                    properties=props, name=name)

            gpu_device = named(gpu_name, gpu_type, 1000, 60)
            cpu_device = named('/CPU:0', 'CPU', 3000, 6)
            virtual_cluster = cluster.Cluster(devices=[cpu_device, gpu_device])
            supported = virtual_cluster.GetSupportedDevices(grappler_item)
            for op_name in ('add', 'Sum', 'range'):
                self.assertEqual(supported[op_name], ['/CPU:0', gpu_name])

            real_cluster = cluster.Cluster()
            supported = real_cluster.GetSupportedDevices(grappler_item)
            task = '/job:localhost/replica:0/task:0'
            host_cpu = task + '/device:CPU:0'
            if test.is_gpu_available():
                for op_name in ('add', 'Sum'):
                    self.assertEqual(supported[op_name],
                                     [host_cpu, task + gpu_name])
                # The axis tensor must reside on the host
                self.assertEqual(supported['range'], [host_cpu])
            else:
                self.assertEqual(supported['add'], [host_cpu])
Ejemplo n.º 17
0
  def testOpProperties(self):
    """Each node of a const + const graph has a single scalar int32 output."""
    with ops.Graph().as_default() as graph:
      lhs = constant_op.constant(10)
      rhs = constant_op.constant(20)
      ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(lhs + rhs)
      meta = meta_graph.create_meta_graph_def(graph=graph)
      grappler_item = item.Item(meta)
      properties = grappler_item.GetOpProperties()

      for node in grappler_item.metagraph.graph_def.node:
        outputs = properties[node.name]
        self.assertEqual(1, len(outputs))
        first_output = outputs[0]
        self.assertEqual(dtypes.int32, first_output.dtype)
        self.assertEqual(tensor_shape.scalar(), first_output.shape)
Ejemplo n.º 18
0
 def testVirtualCluster(self):
   """MeasureCosts on a virtual GPU cluster reports a positive run time."""
   with ops.Graph().as_default() as graph:
     lhs = random_ops.random_uniform(shape=())
     rhs = random_ops.random_uniform(shape=())
     ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(lhs + rhs)
     meta = meta_graph.create_meta_graph_def(graph=graph)
     grappler_item = item.Item(meta)
     gpu = device_properties_pb2.NamedDevice(
         name='/GPU:0',
         properties=device_properties_pb2.DeviceProperties(
             type='GPU', environment={'architecture': '7'}))
     virtual_cluster = cluster.Cluster(devices=[gpu])
     op_perfs, run_time, _ = virtual_cluster.MeasureCosts(grappler_item)
     self.assertGreater(run_time, 0)
     self.assertEqual(len(op_perfs), 15)
    def testLoops(self):
        """Nodes in the train_op collection survive optimizing a while loop."""
        graph = ops.Graph()
        with graph.as_default():
            start = array_ops.placeholder(shape=[], dtype=dtypes.int32)
            end = array_ops.placeholder(shape=[], dtype=dtypes.int32)

            def keep_going(_, counter):
                return counter < end

            def step(accum, counter):
                accum = array_ops.concat([accum, [counter]], 0)
                return [accum, counter + 1]

            init_accum = array_ops.zeros(shape=[0], dtype=dtypes.int32)
            invariants = [
                tensor_shape.TensorShape([None]),
                tensor_shape.TensorShape([]),
            ]
            buf, _ = control_flow_ops.while_loop(
                keep_going, step, [init_accum, start], invariants)

            f = -array_ops.ones_like(buf, optimize=False)  # pylint: disable=invalid-unary-operand-type
            for shape_tensor in (array_ops.shape(buf), array_ops.shape(f)):
                ops.add_to_collection('train_op', shape_tensor)

        # Optimize the graph. min_graph_nodes = -1 presumably disables
        # grappler's small-graph cutoff -- confirm against RewriterConfig.
        meta = meta_graph.create_meta_graph_def(graph=graph)
        config = config_pb2.ConfigProto()
        config.graph_options.rewrite_options.min_graph_nodes = -1
        meta.graph_def.CopyFrom(tf_optimizer.OptimizeGraph(config, meta))

        # Check that the nodes referenced in various collections have been preserved
        properties = gitem.Item(meta).GetOpProperties()
        self.assertEqual(properties[buf.op.name], properties[f.op.name])
 def testMemoryEstimates(self):
   """Peak memory estimate for a tiny graph pinned to the local CPU."""
   cpu = '/job:localhost/replica:0/task:0/device:CPU:0'
   with ops.Graph().as_default() as graph:
     with ops.device(cpu):
       lhs = random_ops.random_uniform(shape=())
       rhs = random_ops.random_uniform(shape=())
       ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(lhs + rhs)
       meta = meta_graph.create_meta_graph_def(graph=graph)
       grappler_item = item.Item(meta)
       grappler_cluster = cluster.Cluster(
           disable_detailed_stats=True, disable_timeline=True)
       peak_mem = grappler_cluster.DeterminePeakMemoryUsage(grappler_item)
       self.assertLessEqual(1, len(peak_mem))
       snapshot = peak_mem[cpu]
       # snapshot[0] is the peak usage, snapshot[1] the live tensors.
       self.assertEqual(52, snapshot[0])
       self.assertEqual(15, len(snapshot[1]))
Ejemplo n.º 21
0
    def testSupportDevices(self):
        """GetSupportedDevices on virtual and real clusters, ignoring XLA."""
        with ops.Graph().as_default() as graph:
            lhs = random_ops.random_uniform(shape=(2, 3))
            rhs = random_ops.random_uniform(shape=(2, 3))
            total = lhs + rhs
            dims = math_ops.range(0, array_ops.rank(total), 1)
            reduced = math_ops.reduce_sum(lhs, axis=dims)
            ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(reduced)
            meta = meta_graph.create_meta_graph_def(graph=graph)
            grappler_item = item.Item(meta)

            def named(name, dev_type, frequency, num_cores):
                props = device_properties_pb2.DeviceProperties(
                    type=dev_type, frequency=frequency, num_cores=num_cores)
                return device_properties_pb2.NamedDevice(
                    properties=props, name=name)

            gpu_device = named('/GPU:0', 'GPU', 1000, 60)
            cpu_device = named('/CPU:0', 'CPU', 3000, 6)
            virtual_cluster = cluster.Cluster(devices=[cpu_device, gpu_device])
            supported = virtual_cluster.GetSupportedDevices(grappler_item)
            for op_name in ('add', 'Sum', 'range'):
                self.assertEqual(supported[op_name], ['/CPU:0', '/GPU:0'])

            real_cluster = cluster.Cluster()
            supported = real_cluster.GetSupportedDevices(grappler_item)

            # NCL 18.04 -- Hack to account for possible XLA devices: drop any
            # device whose type component (second-to-last ':' field) is XLA*.
            def without_xla(op_name):
                return [
                    name for name in supported[op_name]
                    if not name.split(':')[-2].startswith('XLA')
                ]

            device_prefix = '/job:localhost/replica:0/task:0/device:'
            host_cpu = device_prefix + 'CPU:0'
            if test.is_gpu_available():
                cpu_and_gpu = [host_cpu, device_prefix + 'GPU:0']
                self.assertEqual(without_xla('add'), cpu_and_gpu)
                self.assertEqual(without_xla('Sum'), cpu_and_gpu)
                # The axis tensor must reside on the host
                self.assertEqual(without_xla('range'), [host_cpu])
            else:
                self.assertEqual(without_xla('add'), [host_cpu])
Ejemplo n.º 22
0
def PlaceGraph(metagraph,
               cluster=None,
               allotted_time=3600,
               hparams=None,
               verbose=False):
    """Place the provided metagraph.

    Args:
      metagraph: the metagraph to place. Whenever a placement faster than the
        original one is found, `model.export_placement` is called on it.
      cluster: an optional set of hardware resource to optimize the placement
        for. If none is specified, we'll optimize the placement for the
        hardware available on the local machine.
      allotted_time: the maximum amount of time in seconds to spend optimizing
        the placement.
      hparams: hyperparameters used to fine tune the placer. Defaults to
        `hierarchical_controller.hierarchical_controller_hparams()`. Its
        `num_children` field is overwritten with 1.
      verbose: prints debug information if True.

    Returns:
      The placed metagraph (the same object that was passed in).
    """
    if cluster is None:
        cluster = gcluster.Cluster()

    # Resolve hparams before anything reads it: the infeasible-baseline branch
    # below needs hparams.failing_signal, which raised AttributeError when the
    # caller left hparams as None and the original placement failed.
    if hparams is None:
        hparams = hierarchical_controller.hierarchical_controller_hparams()
    # We run with a single child
    hparams.num_children = 1

    # Optimize the metagraph to speedup the placement. Pruning runs both first
    # and last, as in the original optimizer list.
    rewriter_config = rewriter_config_pb2.RewriterConfig()
    rewriter_config.optimizers.extend(
        ["pruning", "constfold", "arithmetic", "dependency", "pruning"])
    optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config,
                                                 metagraph,
                                                 verbose=verbose,
                                                 cluster=cluster)
    optimized_metagraph = meta_graph_pb2.MetaGraphDef()
    optimized_metagraph.CopyFrom(metagraph)
    optimized_metagraph.graph_def.CopyFrom(optimized_graph)

    item = gitem.Item(optimized_metagraph)

    # Measure the runtime achievable with the original placement; an
    # infeasible original placement is scored with hparams.failing_signal.
    try:
        _, original_run_time, _ = cluster.MeasureCosts(item)
        if verbose:
            print("Runtime for original placement: " + str(original_run_time))
    except errors.OpError as e:
        if verbose:
            print("Original placement isn't feasible: " + str(e))
        original_run_time = hparams.failing_signal

    with tf_ops.Graph().as_default():
        # Place all the nodes of the controller on the CPU. We don't want them to
        # fight for accelerator memory with the model to optimize.
        with tf_ops.device("/device:CPU:0"):
            model = hierarchical_controller.HierarchicalController(
                hparams, item, cluster)
            # The return value was never used; the call is kept because it
            # presumably builds the controller graph the session runs.
            model.build_controller()
            session_creator = training.ChiefSessionCreator()
            with training.MonitoredSession(
                    session_creator=session_creator) as sess:
                start_time = time.time()
                current_time = start_time
                while current_time - start_time < allotted_time:
                    grouping_actions = model.generate_grouping(sess)
                    input_to_seq2seq = model.create_group_embeddings(
                        grouping_actions, verbose=verbose)
                    model.generate_placement(input_to_seq2seq, sess)
                    try:
                        run_time = model.eval_placement(sess, verbose=verbose)
                    except errors.OpError as e:
                        if verbose:
                            print("Failed to run graph:" + str(e))
                        run_time = hparams.failing_signal
                    updated = model.update_reward(sess,
                                                  run_time,
                                                  verbose=verbose)
                    if updated and run_time < original_run_time:
                        if verbose:
                            print("Found better placement, with runtime " +
                                  str(run_time))
                        model.export_placement(metagraph)

                    model.process_reward(sess)

                    current_time = time.time()

    return metagraph