Esempio n. 1
0
 def testDefault(self):
     """Checks the op names returned when GetOutputOpNames uses its defaults.

     Besides the two inference ops, the result contains each test variable
     ('testing/b/var' and 'testing/w/var') followed by its random_normal
     initializer ops. These variable ops are absent when
     preserve_colocation_nodes=False (see testNoPreserveColocationNodes), so
     presumably they are retained through colocation — confirm in the exporter.
     """
     graph, inference_graph = self._TestGraph()
     actual_op_names = inference_graph_exporter.GetOutputOpNames(
         graph, inference_graph)
     # Suffixes appended to '<var>/Initializer/random_normal', in expected order.
     random_normal_suffixes = (
         '', '/RandomStandardNormal', '/mean', '/mul', '/shape', '/stddev')
     expected_op_names = ['inference/add_2', 'inference/input']
     for var_name in ('b', 'w'):
       var_op = 'testing/%s/var' % var_name
       expected_op_names.append(var_op)
       expected_op_names.extend(
           var_op + '/Initializer/random_normal' + suffix
           for suffix in random_normal_suffixes)
     self.assertEqual(actual_op_names, expected_op_names)
 def testNoPreserveColocationNodes(self):
   graph, inference_graph = self._TestGraph()
   output_op_names = inference_graph_exporter.GetOutputOpNames(
       graph, inference_graph, preserve_colocation_nodes=False)
   self.assertEqual(output_op_names, [
       # pyformat: disable
       'inference/add_2',
       'inference/input',
       # pyformat: enable
   ])
 def testPreserveSaverRestoreNodes(self):
   graph, inference_graph = self._TestGraph()
   output_op_names = inference_graph_exporter.GetOutputOpNames(
       graph,
       inference_graph,
       preserve_colocation_nodes=False,
       preserve_saver_restore_nodes=True)
   self.assertEqual(output_op_names, [
       # pyformat: disable
       'inference/add_2',
       'inference/input',
       'save/Const',
       'save/restore_all',
       # pyformat: enable
   ])
 def testPreserveExtraOps(self):
   graph, inference_graph = self._TestGraph()
   output_op_names = inference_graph_exporter.GetOutputOpNames(
       graph,
       inference_graph,
       preserve_colocation_nodes=False,
       preserve_extra_ops=[
           'init_all_tables', 'init_all_variables', 'tpu_init_op'
       ])
   self.assertEqual(output_op_names, [
       # pyformat: disable
       'inference/add_2',
       'inference/input',
       'init_all_tables',
       'init_all_variables',
       # pyformat: enable
   ])