def _compareOriginalAndReconstructedGraphDefs(self,
                                              sess,
                                              fetches,
                                              feed_dict=None,
                                              expected_output=None):
        """Checks that non-debug graphs reconstructed from a debug run match.

        `fetches` is run twice with partition-graph output enabled. The first
        run is the baseline. Before the second run, the same RunOptions
        object is passed to `debug_utils.watch_graph` together with
        `self._debug_url`. If `expected_output` is not None, both runs'
        outputs are checked against it with `assertAllClose`.

        Two reconstructions of the non-debug partition graphs are then
        compared with the baseline partition graphs. One is read back from
        the dump in `self._dump_dir`. The other is produced by
        `debug_graphs.reconstruct_non_debug_graph_def`. Both sides of every
        comparison are first filtered through
        `self._graphDefWithoutBlacklistedNodes`.

        Args:
          sess: Session used to run `fetches`.
          fetches: Fetches passed to `sess.run`.
          feed_dict: Optional feed dict passed to `sess.run`.
          expected_output: If not None, the value each run must produce.
        """
        run_options = config_pb2.RunOptions(output_partition_graphs=True)

        def _run_fetches():
            # Runs `fetches` once with the shared `run_options`. Checks the
            # output when an expectation was given, and returns the
            # RunMetadata filled in by this run.
            metadata = config_pb2.RunMetadata()
            result = sess.run(fetches,
                              feed_dict=feed_dict,
                              options=run_options,
                              run_metadata=metadata)
            if expected_output is not None:
                self.assertAllClose(expected_output, result)
            return metadata

        def _stripped(graph_def):
            # Applies the test's node filter before a graph comparison.
            return self._graphDefWithoutBlacklistedNodes(graph_def)

        # Baseline run, before watch_graph has seen `run_options`.
        baseline_metadata = _run_fetches()
        original_graph_defs = baseline_metadata.partition_graphs

        # The same RunOptions object is handed to watch_graph and then reused,
        # so the second run is the one executed with the debug configuration.
        debug_utils.watch_graph(run_options,
                                sess.graph,
                                debug_urls=self._debug_url)
        debug_metadata = _run_fetches()

        dump = debug_data.DebugDumpDir(
            self._dump_dir,
            partition_graphs=debug_metadata.partition_graphs,
            validate=True)
        reconstructed_by_device = dump.reconstructed_non_debug_partition_graphs()

        self.assertEqual(len(original_graph_defs), len(reconstructed_by_device))
        for position, original_def in enumerate(original_graph_defs):
            # Reconstruction from the dump directory, looked up by the device
            # inferred from the baseline partition graph.
            device = debug_graphs._infer_device_name(original_def)
            test_util.assert_equal_graph_def(
                _stripped(reconstructed_by_device[device]),
                _stripped(original_def))

            # Reconstruction straight from the debug-run partition graph at the
            # same position. NOTE(review): this pairs graphs by index, so it
            # assumes both runs report partition graphs in the same order.
            rebuilt_def = debug_graphs.reconstruct_non_debug_graph_def(
                debug_metadata.partition_graphs[position])
            test_util.assert_equal_graph_def(_stripped(rebuilt_def),
                                             _stripped(original_def))
 def test_assert_equal_graph_def(self):
     """Checks that assert_equal_graph_def ignores node order.

     Two graphs hold the same constants, created in opposite orders. Their
     string forms differ, but they must compare equal. Comparing the
     non-empty graph with an empty one must raise an AssertionError whose
     message starts with the unexpected node.
     """
     with ops.Graph().as_default() as g:
         def_empty = g.as_graph_def()
         constant_op.constant(5, name="five")
         constant_op.constant(7, name="seven")
         def_57 = g.as_graph_def()
     with ops.Graph().as_default() as g:
         constant_op.constant(7, name="seven")
         constant_op.constant(5, name="five")
         def_75 = g.as_graph_def()
     # Comparing strings is order dependent
     self.assertNotEqual(str(def_57), str(def_75))
     # assert_equal_graph_def doesn't care about order
     test_util.assert_equal_graph_def(def_57, def_75)
     # Compare two unequal graphs. Use assertRaisesRegex: the deprecated
     # assertRaisesRegexp alias was removed from unittest in Python 3.12.
     with self.assertRaisesRegex(
             AssertionError, r"^Found unexpected node '{{node seven}}"):
         test_util.assert_equal_graph_def(def_57, def_empty)
 def test_assert_equal_graph_def(self):
   """Checks that assert_equal_graph_def ignores node order.

   Two graphs hold the same constants, created in opposite orders. Their
   string forms differ, but they must compare equal. Comparing the non-empty
   graph with an empty one must raise an AssertionError whose message
   starts with the unexpected node.
   """
   with ops.Graph().as_default() as g:
     def_empty = g.as_graph_def()
     constant_op.constant(5, name="five")
     constant_op.constant(7, name="seven")
     def_57 = g.as_graph_def()
   with ops.Graph().as_default() as g:
     constant_op.constant(7, name="seven")
     constant_op.constant(5, name="five")
     def_75 = g.as_graph_def()
   # Comparing strings is order dependent
   self.assertNotEqual(str(def_57), str(def_75))
   # assert_equal_graph_def doesn't care about order
   test_util.assert_equal_graph_def(def_57, def_75)
   # Compare two unequal graphs. Use assertRaisesRegex: the deprecated
   # assertRaisesRegexp alias was removed from unittest in Python 3.12.
   with self.assertRaisesRegex(AssertionError,
                               r"^Found unexpected node 'seven"):
     test_util.assert_equal_graph_def(def_57, def_empty)
Example #4
0
 def test_assert_equal_graph_def_hash_table(self):
   """Checks the `hash_table_shared_name` option of assert_equal_graph_def.

   Two independently built graphs each contain a StaticHashTable lookup.
   With `hash_table_shared_name=False` the comparison must fail with a
   message matching "hash_table_". With `hash_table_shared_name=True` it
   must succeed.
   """

   def build():
     # Builds a fresh graph holding a static hash table lookup and returns
     # its GraphDef.
     with ops.Graph().as_default() as graph:
       lookup_input = constant_op.constant([2, 9], name="x")
       table_keys = constant_op.constant([1, 2], name="keys")
       table_values = constant_op.constant([3, 4], name="values")
       fallback = constant_op.constant(-1, name="default")
       initializer = lookup_ops.KeyValueTensorInitializer(table_keys,
                                                          table_values)
       table = lookup_ops.StaticHashTable(initializer, fallback)
       table.lookup(lookup_input)
     return graph.as_graph_def()

   first = build()
   second = build()
   # Each table gets a unique shared_name, so a strict comparison must fail.
   with self.assertRaisesRegex(AssertionError, "hash_table_"):
     test_util.assert_equal_graph_def(first, second,
                                      hash_table_shared_name=False)
   # Ignoring the shared_name makes the graphs equal. (NOTE: this comparison
   # modifies the GraphDefs in-place.)
   test_util.assert_equal_graph_def(first, second,
                                    hash_table_shared_name=True)
  def _compareOriginalAndReconstructedGraphDefs(self,
                                                sess,
                                                fetches,
                                                feed_dict=None,
                                                expected_output=None):
    """Checks that non-debug graphs reconstructed from a debug run match.

    `fetches` is run twice with partition-graph output enabled: a baseline
    run, then a run after the shared RunOptions object has been passed to
    `debug_utils.watch_graph` with `self._debug_url`. When `expected_output`
    is not None, each run's output is checked with `assertAllClose`.

    The partition graphs reconstructed from the dump in `self._dump_dir`, and
    those rebuilt by `debug_graphs.reconstruct_non_debug_graph_def`, must
    equal the baseline partition graphs. Both sides of every comparison are
    first filtered through `self._graphDefWithoutBlacklistedNodes`.

    Args:
      sess: Session used to run `fetches`.
      fetches: Fetches passed to `sess.run`.
      feed_dict: Optional feed dict passed to `sess.run`.
      expected_output: If not None, the value each run must produce.
    """
    run_options = config_pb2.RunOptions(output_partition_graphs=True)

    def execute():
      # One sess.run with the shared options. Optionally checks the output
      # and returns the RunMetadata filled in by the run.
      metadata = config_pb2.RunMetadata()
      value = sess.run(fetches, feed_dict=feed_dict, options=run_options,
                       run_metadata=metadata)
      if expected_output is not None:
        self.assertAllClose(expected_output, value)
      return metadata

    def filtered(graph_def):
      # Applies the test's node filter before a graph comparison.
      return self._graphDefWithoutBlacklistedNodes(graph_def)

    # Baseline run, before watch_graph has seen `run_options`.
    plain_metadata = execute()
    plain_graphs = plain_metadata.partition_graphs

    # watch_graph receives the same RunOptions object that the next run uses.
    debug_utils.watch_graph(
        run_options, sess.graph, debug_urls=self._debug_url)
    watched_metadata = execute()

    dump = debug_data.DebugDumpDir(
        self._dump_dir, partition_graphs=watched_metadata.partition_graphs,
        validate=True)
    graphs_by_device = dump.reconstructed_non_debug_partition_graphs()

    self.assertEqual(len(plain_graphs), len(graphs_by_device))
    for idx, plain_graph in enumerate(plain_graphs):
      # Dump-based reconstruction, keyed by the baseline graph's device.
      device = debug_graphs._infer_device_name(plain_graph)
      test_util.assert_equal_graph_def(
          filtered(graphs_by_device[device]), filtered(plain_graph))

      # Direct reconstruction from the watched run's partition graph at the
      # same index. NOTE(review): this assumes both runs list partition
      # graphs in the same order.
      rebuilt = debug_graphs.reconstruct_non_debug_graph_def(
          watched_metadata.partition_graphs[idx])
      test_util.assert_equal_graph_def(
          filtered(rebuilt), filtered(plain_graph))