Ejemplo n.º 1
    def to_proto(self, export_scope=None):
        """Converts this `QueueRunner` to a `QueueRunnerDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `QueueRunnerDef` protocol buffer, or `None` if the queue is not in
      the specified name scope.
    """
        # If an export scope is given and the queue lives outside it, there
        # is nothing to export.
        if export_scope is not None and not self.queue.name.startswith(
                export_scope):
            return None

        def _strip(name):
            # Remove the export scope prefix from an op/queue name.
            return ops.strip_name_scope(name, export_scope)

        proto = queue_runner_pb2.QueueRunnerDef()
        proto.queue_name = _strip(self.queue.name)
        proto.enqueue_op_name.extend(
            _strip(op.name) for op in self.enqueue_ops)
        proto.close_op_name = _strip(self.close_op.name)
        proto.cancel_op_name = _strip(self.cancel_op.name)
        proto.queue_closed_exception_types.extend(
            errors.error_code_from_exception_type(exc_cls)
            for exc_cls in self._queue_closed_exception_types)
        return proto
Ejemplo n.º 2
    def to_proto(self):
        """Serializes this `QueueRunner` into a `QueueRunnerDef` protocol buffer.

    Returns:
      A `QueueRunnerDef` protocol buffer.
    """
        proto = queue_runner_pb2.QueueRunnerDef()
        proto.queue_name = self.queue.name
        # Repeated field: collect every enqueue op's name in order.
        proto.enqueue_op_name.extend(op.name for op in self.enqueue_ops)
        proto.close_op_name = self.close_op.name
        proto.cancel_op_name = self.cancel_op.name
        return proto
Ejemplo n.º 3
  def to_proto(self):
    """Serializes this `QueueRunner` into a `QueueRunnerDef` protocol buffer.

    Returns:
      A `QueueRunnerDef` protocol buffer.
    """
    qr_def = queue_runner_pb2.QueueRunnerDef()
    qr_def.queue_name = self.queue.name
    # Repeated field: record every enqueue op's name, preserving order.
    qr_def.enqueue_op_name.extend(op.name for op in self.enqueue_ops)
    qr_def.close_op_name = self.close_op.name
    qr_def.cancel_op_name = self.cancel_op.name
    # Map the Python exception classes to their proto error-code enums.
    qr_def.queue_closed_exception_types.extend(
        errors.error_code_from_exception_type(exc_cls)
        for exc_cls in self._queue_closed_exception_types)
    return qr_def
Ejemplo n.º 4
    def testAddCollectionDef(self):
        """Exports a MetaGraphDef with assorted collections and re-imports it."""
        ckpt_dir = self._TestDir("good_collection")
        meta_path = os.path.join(ckpt_dir, "metafile")
        with self.test_session():
            # Build a small graph: a variable, a counter, and a shared queue.
            var_v0 = tf.Variable(10.0, name="v0")
            counter_var = tf.Variable(tf.constant(0, dtype=tf.int64))
            count_op = counter_var.count_up_to(3)
            queue = tf.FIFOQueue(30,
                                 tf.float32,
                                 shared_name="collection_queue")
            runner = tf.train.QueueRunner(queue, [count_op])
            tf.initialize_all_variables()
            # Saver for the exported variable.
            saver = tf.train.Saver({"v0": var_v0})
            # Populate collections with several value kinds.
            tf.add_to_collection("int_collection", 3)
            tf.add_to_collection("float_collection", 3.5)
            tf.add_to_collection("string_collection", "hello")
            tf.add_to_collection("variable_collection", var_v0)
            # Register the queue runner so it is exported too.
            tf.train.add_queue_runner(runner)
            # User-defined protos in three encodings: string, bytes, and Any.
            qr_def = queue_runner_pb2.QueueRunnerDef(
                queue_name="test_queue")
            tf.add_to_collection("user_defined_string_collection",
                                 str(qr_def))
            tf.add_to_collection("user_defined_bytes_collection",
                                 qr_def.SerializeToString())
            packed_any = Any()
            packed_any.Pack(qr_def)
            tf.add_to_collection("user_defined_any_collection", packed_any)

            # Export and sanity-check the MetaGraphDef.
            meta_graph = saver.export_meta_graph(meta_path)
            self.assertTrue(meta_graph.HasField("saver_def"))
            self.assertTrue(meta_graph.HasField("graph_def"))
            exported_collections = meta_graph.collection_def
            self.assertEqual(len(exported_collections), 10)

        with tf.Graph().as_default():
            # Re-import into a fresh graph and re-export; the round trip
            # must reproduce the original MetaGraphDef.
            restored_saver = tf.train.import_meta_graph(meta_path)
            reexported = restored_saver.export_meta_graph()
            self.assertProtoEquals(meta_graph, reexported)
Ejemplo n.º 5
    def load_graph_session(self, read_graph_fn):
        """Restores the checkpointed graph at `self._graph_path` and hands the
        live session to `read_graph_fn`.

        Side effects: populates `self._meta_graph_def`,
        `self._output_node_names` (from the meta graph, falling back to the
        TRAIN_OP collection), `self._output_tensor_names`,
        `self._input_tensor_names` (placeholders the outputs depend on,
        falling back to queue-runner dequeue outputs), and
        `self._required_nodes`.

        Args:
          read_graph_fn: Callable invoked as `read_graph_fn(sess)` with the
            restored `tf.Session`.

        Raises:
          ValueError: If no `.meta` file exists in `self._graph_path`, or if
            no output nodes can be determined.
        """
        meta_path = None
        for f in os.listdir(self._graph_path):
            if f.endswith(".meta"):
                # NOTE(review): if several .meta files exist, the last one
                # listed wins — presumably there is only one; confirm.
                meta_path = os.path.join(self._graph_path, f)
        if not meta_path:
            raise ValueError("Could not find meta file in: {0}".format(
                self._graph_path))

        with tf.Session(graph=tf.Graph()) as sess:
            # Load the checkpoint
            saver = tf.train.import_meta_graph(meta_path)
            saver.restore(sess, tf.train.latest_checkpoint(self._graph_path))
            self._meta_graph_def = saver.export_meta_graph()

            # Get the output node names; fall back to the TRAIN_OP collection
            # when the meta graph itself does not name any outputs.
            self._output_node_names = self.get_output_node_names_from_meta_graph_def(
                self._meta_graph_def)
            if not self._output_node_names:
                self._output_node_names.update(
                    self._meta_graph_def.collection_def[
                        tf.GraphKeys.TRAIN_OP].node_list.value)

            # Validate explicitly: a bare `assert` is stripped under
            # `python -O`, silently skipping this check.
            if not self._output_node_names:
                raise ValueError(
                    "Could not determine any output nodes for the graph "
                    "in: {0}".format(self._graph_path))
            self._output_tensor_names = []

            # Assume inputs are placeholders that the output nodes depend on
            subgraph = tf.graph_util.extract_sub_graph(
                sess.graph_def, list(self._output_node_names))
            self._input_tensor_names = self._get_placeholder_tensors_from_subgraph(
                subgraph)

            # If we did not find any placeholders, try to get placeholders that the
            # queue runner depends on
            if not self._input_tensor_names:
                for serialized_proto in self._meta_graph_def.collection_def[
                        tf.GraphKeys.QUEUE_RUNNERS].bytes_list.value:
                    # Deserialize the queue runner proto
                    qr_proto = queue_runner_pb2.QueueRunnerDef()
                    qr_proto.ParseFromString(serialized_proto)

                    # Find ops that read from this runner's queue and look
                    # like dequeue ops (heuristic: "dequeue" in the op name).
                    dequeue_ops = []
                    for node in sess.graph_def.node:
                        for inp_name in node.input:
                            inp_name, _, _ = self.get_node_name_and_output_index(
                                inp_name)
                            if (inp_name == qr_proto.queue_name
                                    and "dequeue" in node.name.lower()):
                                dequeue_ops.append(
                                    sess.graph.get_operation_by_name(
                                        node.name))

                    # Input tensor names for the graphs will be output tensors of the
                    # dequeue ops
                    for dequeue_op in dequeue_ops:
                        for out_ten in dequeue_op.outputs:
                            self._input_tensor_names.append(out_ten.name)
                            self._required_nodes.add(dequeue_op.node_def.name)

            # Just print a warning if we could not find any inputs
            if not self._input_tensor_names:
                logging.warning("Did not find any input tensors in the graph")

            read_graph_fn(sess)