def to_proto(self, export_scope=None):
  """Converts this `QueueRunner` to a `QueueRunnerDef` protocol buffer.

  Args:
    export_scope: Optional `string`. Name scope to remove.

  Returns:
    A `QueueRunnerDef` protocol buffer, or `None` if the `QueueRunner` is not
    in the specified name scope.
  """
  # NOTE(review): this is a plain string-prefix test, so an export_scope of
  # "a" also accepts a queue named "ab/queue". Confirm whether a "/"-boundary
  # check is wanted before tightening it.
  if export_scope is not None and not self.queue.name.startswith(export_scope):
    return None

  def _strip(name):
    # Every recorded name is passed through ops.strip_name_scope with the
    # same export_scope.
    return ops.strip_name_scope(name, export_scope)

  queue_runner_def = queue_runner_pb2.QueueRunnerDef()
  queue_runner_def.queue_name = _strip(self.queue.name)
  queue_runner_def.enqueue_op_name.extend(
      [_strip(enqueue_op.name) for enqueue_op in self.enqueue_ops])
  queue_runner_def.close_op_name = _strip(self.close_op.name)
  queue_runner_def.cancel_op_name = _strip(self.cancel_op.name)
  # Exception classes are stored as error codes so the proto stays
  # language-neutral.
  queue_runner_def.queue_closed_exception_types.extend([
      errors.error_code_from_exception_type(cls)
      for cls in self._queue_closed_exception_types
  ])
  return queue_runner_def
def to_proto(self):
  """Serializes this `QueueRunner` into a `QueueRunnerDef` message.

  Returns:
    A `QueueRunnerDef` protocol buffer holding the name of the queue and the
    names of its enqueue, close and cancel ops.
  """
  return queue_runner_pb2.QueueRunnerDef(
      queue_name=self.queue.name,
      enqueue_op_name=[op.name for op in self.enqueue_ops],
      close_op_name=self.close_op.name,
      cancel_op_name=self.cancel_op.name,
  )
def to_proto(self):
  """Serializes this `QueueRunner` into a `QueueRunnerDef` message.

  Returns:
    A `QueueRunnerDef` protocol buffer with the queue name, the enqueue /
    close / cancel op names, and the error codes of the exception types that
    signal a closed queue.
  """
  queue_name = self.queue.name
  enqueue_names = [op.name for op in self.enqueue_ops]
  close_name = self.close_op.name
  cancel_name = self.cancel_op.name
  # Each exception class is recorded as its error code.
  closed_codes = [
      errors.error_code_from_exception_type(exc_type)
      for exc_type in self._queue_closed_exception_types
  ]
  return queue_runner_pb2.QueueRunnerDef(
      queue_name=queue_name,
      enqueue_op_name=enqueue_names,
      close_op_name=close_name,
      cancel_op_name=cancel_name,
      queue_closed_exception_types=closed_codes,
  )
def testAddCollectionDef(self):
  """Round-trips a MetaGraphDef that carries assorted collections.

  Builds a graph with variables and a queue runner, and fills several
  collections. It exports the MetaGraphDef, checks that it has a saver_def, a
  graph_def and 10 collections, and then re-imports it into a fresh graph.
  Finally it asserts that re-exporting gives an identical proto.
  """
  out_dir = self._TestDir("good_collection")
  meta_file = os.path.join(out_dir, "metafile")
  with self.test_session():
    # Graph: two variables, a count_up_to op and a queue runner around it.
    v0 = tf.Variable(10.0, name="v0")
    counter = tf.Variable(tf.constant(0, dtype=tf.int64))
    count_up_to = counter.count_up_to(3)
    fifo = tf.FIFOQueue(30, tf.float32, shared_name="collection_queue")
    runner = tf.train.QueueRunner(fifo, [count_up_to])
    tf.initialize_all_variables()
    saver = tf.train.Saver({"v0": v0})

    # Simple-value collections plus one that holds a Variable.
    basic_entries = (
        ("int_collection", 3),
        ("float_collection", 3.5),
        ("string_collection", "hello"),
        ("variable_collection", v0),
    )
    for collection_name, value in basic_entries:
      tf.add_to_collection(collection_name, value)

    tf.train.add_queue_runner(runner)

    # The same user-defined proto stored as text, serialized bytes and Any.
    user_proto = queue_runner_pb2.QueueRunnerDef(queue_name="test_queue")
    packed = Any()
    packed.Pack(user_proto)
    proto_entries = (
        ("user_defined_string_collection", str(user_proto)),
        ("user_defined_bytes_collection", user_proto.SerializeToString()),
        ("user_defined_any_collection", packed),
    )
    for collection_name, value in proto_entries:
      tf.add_to_collection(collection_name, value)

    exported = saver.export_meta_graph(meta_file)
    self.assertTrue(exported.HasField("saver_def"))
    self.assertTrue(exported.HasField("graph_def"))
    self.assertEqual(len(exported.collection_def), 10)

  with tf.Graph().as_default():
    # Import into an empty graph, export again and compare.
    restored_saver = tf.train.import_meta_graph(meta_file)
    re_exported = restored_saver.export_meta_graph()
    self.assertProtoEquals(exported, re_exported)
def load_graph_session(self, read_graph_fn):
    """Restore the graph in ``self._graph_path`` and pass its session on.

    Imports a ``.meta`` file from ``self._graph_path`` into a fresh graph and
    restores the latest checkpoint. It then records on ``self``:
    ``_meta_graph_def``, ``_output_node_names`` and ``_input_tensor_names``.
    It also resets ``_output_tensor_names`` and may add entries to
    ``_required_nodes``.

    Inputs are the placeholders the output nodes depend on. If there are none,
    inputs are taken from the outputs of nodes whose name contains "dequeue"
    and that read from a queue listed in the QUEUE_RUNNERS collection.

    Args:
        read_graph_fn: Callable invoked as ``read_graph_fn(sess)`` while the
            restored session is still open.

    Raises:
        ValueError: If no ``.meta`` file or no checkpoint is found in
            ``self._graph_path``, or no output node names can be determined.
    """
    checkpoint_path = tf.train.latest_checkpoint(self._graph_path)

    # Prefer the meta graph saved alongside the checkpoint being restored, so
    # graph and weights come from the same save. Otherwise fall back to any
    # .meta file, sorted so the choice does not depend on os.listdir order.
    meta_path = None
    if checkpoint_path and os.path.isfile(checkpoint_path + ".meta"):
        meta_path = checkpoint_path + ".meta"
    else:
        meta_files = sorted(
            f for f in os.listdir(self._graph_path) if f.endswith(".meta"))
        if meta_files:
            meta_path = os.path.join(self._graph_path, meta_files[-1])
    if not meta_path:
        raise ValueError("Could not find meta file in: {0}".format(
            self._graph_path))
    if checkpoint_path is None:
        raise ValueError("Could not find checkpoint in: {0}".format(
            self._graph_path))

    with tf.Session(graph=tf.Graph()) as sess:
        # Load the checkpoint
        saver = tf.train.import_meta_graph(meta_path)
        saver.restore(sess, checkpoint_path)
        self._meta_graph_def = saver.export_meta_graph()
        # Membership is checked before indexing below: indexing a protobuf
        # message map with a missing key inserts an empty entry, which would
        # silently mutate self._meta_graph_def.
        collection_def = self._meta_graph_def.collection_def

        # Get the output node names, falling back to the TRAIN_OP collection.
        self._output_node_names = self.get_output_node_names_from_meta_graph_def(
            self._meta_graph_def)
        if (not self._output_node_names
                and tf.GraphKeys.TRAIN_OP in collection_def):
            self._output_node_names.update(
                collection_def[tf.GraphKeys.TRAIN_OP].node_list.value)
        # An explicit raise, not an assert, so the check survives python -O.
        if not self._output_node_names:
            raise ValueError(
                "Could not determine output nodes of the graph in: {0}".format(
                    self._graph_path))

        self._output_tensor_names = []

        # sess.graph_def builds a fresh GraphDef on each access; fetch it once.
        graph_def = sess.graph_def

        # Assume inputs are placeholders that the output nodes depend on
        subgraph = tf.graph_util.extract_sub_graph(
            graph_def, list(self._output_node_names))
        self._input_tensor_names = self._get_placeholder_tensors_from_subgraph(
            subgraph)

        # If we did not find any placeholders, try to get placeholders that
        # the queue runner depends on
        if (not self._input_tensor_names
                and tf.GraphKeys.QUEUE_RUNNERS in collection_def):
            for serialized_proto in collection_def[
                    tf.GraphKeys.QUEUE_RUNNERS].bytes_list.value:
                # Deserialize the queue runner proto
                qr_proto = queue_runner_pb2.QueueRunnerDef()
                qr_proto.ParseFromString(serialized_proto)

                # Heuristic: a dequeue op is any node whose name contains
                # "dequeue" and that takes this queue as an input.
                dequeue_ops = []
                for node in graph_def.node:
                    if "dequeue" not in node.name.lower():
                        continue
                    for inp_name in node.input:
                        inp_name, _, _ = self.get_node_name_and_output_index(
                            inp_name)
                        if inp_name == qr_proto.queue_name:
                            dequeue_ops.append(
                                sess.graph.get_operation_by_name(node.name))
                            # One matching input is enough; avoid listing the
                            # same op (and its outputs) twice.
                            break

                # Input tensor names for the graph are the output tensors of
                # the dequeue ops.
                for dequeue_op in dequeue_ops:
                    for out_ten in dequeue_op.outputs:
                        self._input_tensor_names.append(out_ten.name)
                    self._required_nodes.add(dequeue_op.node_def.name)

        # Just print a warning if we could not find any inputs
        if not self._input_tensor_names:
            logging.warning("Did not find any input tensors in the graph")

        read_graph_fn(sess)