コード例 #1
0
 def function_to_wrap(*args):
   """Imports the serialized TF computation `comp`, feeding `args` as inputs.

   `comp`, `input_tensor_names` and `output_tensor_names` are taken from the
   enclosing scope.

   Args:
     *args: Positional values mapped, in order, onto the tensors named in
       `input_tensor_names` via `input_map`.

   Returns:
     Whatever `tf.import_graph_def` returns when `return_elements` is
     `output_tensor_names`.

   Raises:
     RuntimeError: If `len(args)` differs from `len(input_tensor_names)`.
   """
   num_expected = len(input_tensor_names)
   num_given = len(args)
   if num_given != num_expected:
     raise RuntimeError(
         'Expected {} arguments, found {}.'.format(num_expected, num_given))
   tf_proto = comp.tensorflow
   unpacked_graph = serialization_utils.unpack_graph_def(tf_proto.graph_def)
   init_op_name = tf_proto.initialize_op
   if init_op_name:
     # NOTE(review): presumably wires the init op in as a control dependency
     # so it runs before outputs are produced -- confirm in graph_utils.
     unpacked_graph = graph_utils.add_control_deps_for_init_op(
         unpacked_graph, init_op_name)
   # NOTE(review): presumably makes `shared_name` attributes unique so that
   # repeated imports do not share resources -- confirm in graph_merge.
   deduped_graph = graph_merge.uniquify_shared_names(unpacked_graph)
   return tf.import_graph_def(
       deduped_graph,
       input_map=dict(zip(input_tensor_names, args)),
       return_elements=output_tensor_names)
コード例 #2
0
 def function_to_wrap(*args):
     """Imports the serialized TF computation `comp`, feeding `args` as inputs.

     `comp`, `input_tensor_names` and `output_tensor_names` come from the
     enclosing scope. When `comp.tensorflow.initialize_op` is set, that op is
     requested from the import after the outputs. Each output is then re-emitted
     through `tf.identity` under a control dependency on it.

     Args:
       *args: Positional values mapped, in order, onto the tensors named in
         `input_tensor_names` via `input_map`.

     Returns:
       The imported elements for `output_tensor_names`. When an init op is
       present, a list of `tf.identity` copies of them is returned instead.

     Raises:
       RuntimeError: If `len(args)` differs from `len(input_tensor_names)`.
     """
     arity = len(input_tensor_names)
     if len(args) != arity:
         raise RuntimeError('Expected {} arguments, found {}.'.format(
             arity, len(args)))
     tf_proto = comp.tensorflow
     unpacked_graph = serialization_utils.unpack_graph_def(tf_proto.graph_def)
     init_op_name = tf_proto.initialize_op
     extra_names = [init_op_name] if init_op_name else []
     imported = tf.import_graph_def(
         graph_merge.uniquify_shared_names(unpacked_graph),
         input_map=dict(zip(input_tensor_names, args)),
         return_elements=output_tensor_names + extra_names)
     if not extra_names:
         return imported
     # The init op was appended last to `return_elements`, so it is the final
     # imported element; everything before it is a requested output.
     init_element = imported[-1]
     outputs = imported[:-1]
     with tf.control_dependencies([init_element]):
         return [tf.identity(t) for t in outputs]
コード例 #3
0
 def _import_fn():
     """Imports `graph_def` (from the enclosing scope) into the default graph.

     The graph is first passed through `graph_merge.uniquify_shared_names`
     (presumably to make shared resource names unique -- TODO confirm).
     Passing `name=''` means no name-scope prefix is prepended to the
     imported node names.
     """
     deduped_graph_def = graph_merge.uniquify_shared_names(graph_def)
     return tf.import_graph_def(deduped_graph_def, name='')
コード例 #4
0
 def _import_fn():
     """Imports `graph_def` with `args` bound to the named input tensors.

     `graph_def`, `args`, `input_tensor_names` and `output_tensor_names` all
     come from the enclosing scope. Before import, the graph is passed through
     `graph_merge.uniquify_shared_names` (presumably so shared resource names
     do not collide -- TODO confirm).

     NOTE(review): `zip` stops at the shorter of `input_tensor_names` and
     `args`, so a length mismatch is not detected here. The caller is
     presumably expected to validate the arity beforehand.

     Returns:
       Whatever `tf.import_graph_def` returns when `return_elements` is
       `output_tensor_names`.
     """
     # `dict` consumes the (name, value) pairs straight from `zip`; no
     # intermediate list is needed.
     return tf.import_graph_def(
         graph_merge.uniquify_shared_names(graph_def),
         input_map=dict(zip(input_tensor_names, args)),
         return_elements=output_tensor_names)