def testDefault(self):
  """Checks GetOutputOpNames when called with only the two graphs.

  No optional keyword arguments are passed. The expected result lists the
  two `inference/` ops followed by the `testing/b/var` and `testing/w/var`
  ops together with their `Initializer/random_normal` sub-ops.
  """
  graph, inference_graph = self._TestGraph()
  expected_op_names = [
      # pyformat: disable
      'inference/add_2',
      'inference/input',
      'testing/b/var',
      'testing/b/var/Initializer/random_normal',
      'testing/b/var/Initializer/random_normal/RandomStandardNormal',
      'testing/b/var/Initializer/random_normal/mean',
      'testing/b/var/Initializer/random_normal/mul',
      'testing/b/var/Initializer/random_normal/shape',
      'testing/b/var/Initializer/random_normal/stddev',
      'testing/w/var',
      'testing/w/var/Initializer/random_normal',
      'testing/w/var/Initializer/random_normal/RandomStandardNormal',
      'testing/w/var/Initializer/random_normal/mean',
      'testing/w/var/Initializer/random_normal/mul',
      'testing/w/var/Initializer/random_normal/shape',
      'testing/w/var/Initializer/random_normal/stddev',
      # pyformat: enable
  ]
  actual_op_names = inference_graph_exporter.GetOutputOpNames(
      graph, inference_graph)
  self.assertEqual(actual_op_names, expected_op_names)
def testNoPreserveColocationNodes(self):
  """Checks GetOutputOpNames with `preserve_colocation_nodes=False`.

  With that flag turned off, the expected result contains only the two
  `inference/` ops.
  """
  graph, inference_graph = self._TestGraph()
  expected_op_names = [
      # pyformat: disable
      'inference/add_2',
      'inference/input',
      # pyformat: enable
  ]
  actual_op_names = inference_graph_exporter.GetOutputOpNames(
      graph,
      inference_graph,
      preserve_colocation_nodes=False)
  self.assertEqual(actual_op_names, expected_op_names)
def testPreserveSaverRestoreNodes(self):
  """Checks GetOutputOpNames with `preserve_saver_restore_nodes=True`.

  Colocation preservation is disabled for this call. The expected result is
  the two `inference/` ops followed by `save/Const` and `save/restore_all`.
  """
  graph, inference_graph = self._TestGraph()
  expected_op_names = [
      # pyformat: disable
      'inference/add_2',
      'inference/input',
      'save/Const',
      'save/restore_all',
      # pyformat: enable
  ]
  actual_op_names = inference_graph_exporter.GetOutputOpNames(
      graph,
      inference_graph,
      preserve_colocation_nodes=False,
      preserve_saver_restore_nodes=True)
  self.assertEqual(actual_op_names, expected_op_names)
def testPreserveExtraOps(self):
  """Checks GetOutputOpNames when `preserve_extra_ops` is supplied.

  Three extra op names are requested with colocation preservation disabled.
  The expected result holds the two `inference/` ops plus `init_all_tables`
  and `init_all_variables`; `tpu_init_op` is requested but not expected in
  the output (presumably because the test graph has no such op -- confirm
  against `_TestGraph`).
  """
  graph, inference_graph = self._TestGraph()
  requested_extra_ops = [
      'init_all_tables', 'init_all_variables', 'tpu_init_op'
  ]
  expected_op_names = [
      # pyformat: disable
      'inference/add_2',
      'inference/input',
      'init_all_tables',
      'init_all_variables',
      # pyformat: enable
  ]
  actual_op_names = inference_graph_exporter.GetOutputOpNames(
      graph,
      inference_graph,
      preserve_colocation_nodes=False,
      preserve_extra_ops=requested_extra_ops)
  self.assertEqual(actual_op_names, expected_op_names)