def _compareOriginalAndReconstructedGraphDefs(self,
                                              sess,
                                              fetches,
                                              feed_dict=None,
                                              expected_output=None):
    """Compare non-debug partition graphs with ones reconstructed from a debug run.

    `fetches` is run twice on `sess` with partition-graph output enabled:
    once with the plain run options, and once after those same options have
    been passed through `debug_utils.watch_graph` with `self._debug_url`.
    The partition graphs of the first run are then compared against
      1. the graphs returned by
         `DebugDumpDir.reconstructed_non_debug_partition_graphs()`, and
      2. the result of `debug_graphs.reconstruct_non_debug_graph_def` applied
         to each partition graph of the second run.
    Both sides of every comparison are passed through
    `self._graphDefWithoutBlacklistedNodes` first.

    Args:
      sess: Session on which `fetches` is run.
      fetches: Forwarded to `sess.run`.
      feed_dict: Optional feed dict forwarded to `sess.run`.
      expected_output: If not None, the output of each of the two runs is
        checked against it with `assertAllClose`.
    """
    # First run: no watch_graph call has been applied to the options yet.
    run_options = config_pb2.RunOptions(output_partition_graphs=True)
    baseline_metadata = config_pb2.RunMetadata()
    baseline_output = sess.run(
        fetches,
        feed_dict=feed_dict,
        options=run_options,
        run_metadata=baseline_metadata)
    if expected_output is not None:
        self.assertAllClose(expected_output, baseline_output)
    baseline_graph_defs = baseline_metadata.partition_graphs

    # Second run: same options object, now modified by watch_graph, with a
    # fresh RunMetadata so the first run's partition graphs stay untouched.
    debug_utils.watch_graph(
        run_options, sess.graph, debug_urls=self._debug_url)
    debug_metadata = config_pb2.RunMetadata()
    debug_output = sess.run(
        fetches,
        feed_dict=feed_dict,
        options=run_options,
        run_metadata=debug_metadata)
    if expected_output is not None:
        self.assertAllClose(expected_output, debug_output)

    dump = debug_data.DebugDumpDir(
        self._dump_dir,
        partition_graphs=debug_metadata.partition_graphs,
        validate=True)
    reconstructed = dump.reconstructed_non_debug_partition_graphs()
    self.assertEqual(len(baseline_graph_defs), len(reconstructed))

    for index, baseline_graph_def in enumerate(baseline_graph_defs):
        # `reconstructed` is looked up by device name, not by position.
        device_name = debug_graphs._infer_device_name(baseline_graph_def)
        from_dump = self._graphDefWithoutBlacklistedNodes(
            reconstructed[device_name])
        from_baseline = self._graphDefWithoutBlacklistedNodes(
            baseline_graph_def)
        test_util.assert_equal_graph_def(from_dump, from_baseline)

        # Test debug_graphs.reconstruct_non_debug_graph_def.
        reconstructed_again = debug_graphs.reconstruct_non_debug_graph_def(
            debug_metadata.partition_graphs[index])
        test_util.assert_equal_graph_def(
            self._graphDefWithoutBlacklistedNodes(reconstructed_again),
            self._graphDefWithoutBlacklistedNodes(baseline_graph_def))
def test_assert_equal_graph_def(self):
    """Exercise `test_util.assert_equal_graph_def` on equal and unequal graphs.

    Builds three GraphDefs: an empty one, one with constants added in the
    order "five", "seven", and one with the same constants added in the
    order "seven", "five". The two non-empty GraphDefs must compare equal
    regardless of node order, while comparing against the empty GraphDef
    must raise an AssertionError naming the unexpected node.
    """
    with ops.Graph().as_default() as g:
        # Snapshot taken before any op is added to the graph.
        def_empty = g.as_graph_def()
        constant_op.constant(5, name="five")
        constant_op.constant(7, name="seven")
        def_57 = g.as_graph_def()
    with ops.Graph().as_default() as g:
        constant_op.constant(7, name="seven")
        constant_op.constant(5, name="five")
        def_75 = g.as_graph_def()
    # Comparing strings is order dependent
    self.assertNotEqual(str(def_57), str(def_75))
    # assert_equal_graph_def doesn't care about order
    test_util.assert_equal_graph_def(def_57, def_75)
    # Compare two unequal graphs
    # `assertRaisesRegex` is used rather than the deprecated
    # `assertRaisesRegexp` alias, which unittest removed in Python 3.12.
    with self.assertRaisesRegex(
        AssertionError, r"^Found unexpected node '{{node seven}}"):
        test_util.assert_equal_graph_def(def_57, def_empty)
def test_assert_equal_graph_def(self):
    """Exercise `test_util.assert_equal_graph_def` on equal and unequal graphs.

    Builds three GraphDefs: an empty one, one with constants added in the
    order "five", "seven", and one with the same constants added in the
    order "seven", "five". The two non-empty GraphDefs must compare equal
    regardless of node order, while comparing against the empty GraphDef
    must raise an AssertionError naming the unexpected node.
    """
    with ops.Graph().as_default() as g:
        # Snapshot taken before any op is added to the graph.
        def_empty = g.as_graph_def()
        constant_op.constant(5, name="five")
        constant_op.constant(7, name="seven")
        def_57 = g.as_graph_def()
    with ops.Graph().as_default() as g:
        constant_op.constant(7, name="seven")
        constant_op.constant(5, name="five")
        def_75 = g.as_graph_def()
    # Comparing strings is order dependent
    self.assertNotEqual(str(def_57), str(def_75))
    # assert_equal_graph_def doesn't care about order
    test_util.assert_equal_graph_def(def_57, def_75)
    # Compare two unequal graphs
    # `assertRaisesRegex` is used rather than the deprecated
    # `assertRaisesRegexp` alias, which unittest removed in Python 3.12.
    with self.assertRaisesRegex(AssertionError,
                                r"^Found unexpected node 'seven"):
        test_util.assert_equal_graph_def(def_57, def_empty)
def test_assert_equal_graph_def_hash_table(self):
    """Check the `hash_table_shared_name` option of `assert_equal_graph_def`.

    The same graph, containing a `StaticHashTable` lookup, is built twice in
    separate `ops.Graph` instances. With `hash_table_shared_name=False` the
    comparison must fail with a message matching "hash_table_"; with
    `hash_table_shared_name=True` the two GraphDefs must compare equal.
    """

    def build_graph_def():
        # Ops are created in a fixed order so both graphs are built alike.
        with ops.Graph().as_default() as graph:
            x = constant_op.constant([2, 9], name="x")
            keys = constant_op.constant([1, 2], name="keys")
            values = constant_op.constant([3, 4], name="values")
            default = constant_op.constant(-1, name="default")
            initializer = lookup_ops.KeyValueTensorInitializer(keys, values)
            table = lookup_ops.StaticHashTable(initializer, default)
            # The lookup result is unused; the call only adds the op.
            _ = table.lookup(x)
        return graph.as_graph_def()

    first_def = build_graph_def()
    second_def = build_graph_def()

    # The unique shared_name of each table makes the graph unequal.
    with self.assertRaisesRegex(AssertionError, "hash_table_"):
        test_util.assert_equal_graph_def(
            first_def, second_def, hash_table_shared_name=False)

    # That can be ignored. (NOTE: modifies GraphDefs in-place.)
    test_util.assert_equal_graph_def(
        first_def, second_def, hash_table_shared_name=True)
def _compareOriginalAndReconstructedGraphDefs(self,
                                              sess,
                                              fetches,
                                              feed_dict=None,
                                              expected_output=None):
    """Verify that debug-run partition graphs reconstruct to the plain ones.

    Runs `fetches` on `sess` once with partition-graph output only, then
    again after `debug_utils.watch_graph` has been applied to the same run
    options with `self._debug_url`. Each partition graph of the first run
    must equal both the matching entry of
    `DebugDumpDir.reconstructed_non_debug_partition_graphs()` and the output
    of `debug_graphs.reconstruct_non_debug_graph_def` for the second run's
    partition graph at the same index. All GraphDefs are passed through
    `self._graphDefWithoutBlacklistedNodes` before being compared.

    Args:
      sess: Session on which `fetches` is run.
      fetches: Forwarded to `sess.run`.
      feed_dict: Optional feed dict forwarded to `sess.run`.
      expected_output: If not None, each run's output is checked against it
        with `assertAllClose`.
    """
    run_options = config_pb2.RunOptions(output_partition_graphs=True)

    def run_and_check(metadata):
        # Runs the fetches with the current state of `run_options` and
        # records partition graphs into `metadata`.
        result = sess.run(
            fetches,
            feed_dict=feed_dict,
            options=run_options,
            run_metadata=metadata)
        if expected_output is not None:
            self.assertAllClose(expected_output, result)

    # Run before watch_graph is applied to the options.
    plain_metadata = config_pb2.RunMetadata()
    run_and_check(plain_metadata)
    plain_graph_defs = plain_metadata.partition_graphs

    # Run again after watch_graph has been applied to the same options.
    debug_utils.watch_graph(
        run_options, sess.graph, debug_urls=self._debug_url)
    watched_metadata = config_pb2.RunMetadata()
    run_and_check(watched_metadata)

    dump = debug_data.DebugDumpDir(
        self._dump_dir,
        partition_graphs=watched_metadata.partition_graphs,
        validate=True)
    reconstructed = dump.reconstructed_non_debug_partition_graphs()
    self.assertEqual(len(plain_graph_defs), len(reconstructed))

    for position, plain_graph_def in enumerate(plain_graph_defs):
        # Entries of `reconstructed` are keyed by device name.
        device_name = debug_graphs._infer_device_name(plain_graph_def)
        actual = self._graphDefWithoutBlacklistedNodes(
            reconstructed[device_name])
        expected = self._graphDefWithoutBlacklistedNodes(plain_graph_def)
        test_util.assert_equal_graph_def(actual, expected)

        # Test debug_graphs.reconstruct_non_debug_graph_def.
        watched_graph_def = watched_metadata.partition_graphs[position]
        reconstructed_again = debug_graphs.reconstruct_non_debug_graph_def(
            watched_graph_def)
        actual_again = self._graphDefWithoutBlacklistedNodes(
            reconstructed_again)
        expected_again = self._graphDefWithoutBlacklistedNodes(
            plain_graph_def)
        test_util.assert_equal_graph_def(actual_again, expected_again)