Exemplo n.º 1
0
def get_node_map_in_graph(
        meta_graph_def: meta_graph_pb2.MetaGraphDef,
        prefix: str,
        node_suffixes: List[str],
        graph: tf.Graph,
) -> Dict[types.FPLKeyType, Dict[str, types.TensorType]]:
    """Resolve the node map described by a MetaGraphDef against a graph.

    Behaves like ``get_node_map`` with the same ``meta_graph_def``,
    ``prefix`` and ``node_suffixes`` arguments. The difference is that every
    encoded node reference in the resulting map is passed through
    ``encoding.decode_tensor_node`` to look it up in ``graph``.

    Args:
      meta_graph_def: MetaGraphDef holding the CollectionDefs from which the
        map structure is read.
      prefix: Prefix shared by the CollectionDef names.
      node_suffixes: Suffixes appended to ``prefix`` to form the names of the
        CollectionDefs whose nodes are extracted, e.g.
        ['suffix_a', 'suffix_b'].
      graph: TensorFlow graph in which the nodes are looked up.

    Returns:
      A dictionary with the same outer and inner keys as the output of
      ``get_node_map``. Each inner value is the node decoded from ``graph``
      rather than its encoded reference.
    """
    encoded_map = get_node_map(meta_graph_def, prefix, node_suffixes)
    # Decode outer entries in iteration order, and within each entry decode
    # its nodes in iteration order.
    return {
        outer_key: {
            inner_key: encoding.decode_tensor_node(graph, encoded_node)
            for inner_key, encoded_node in encoded_nodes.items()
        }
        for outer_key, encoded_nodes in encoded_map.items()
    }
Exemplo n.º 2
0
    def testEncodeDecodeTensorNode(self):
        """Round-trips several tensors through encode/decode_tensor_node.

        The cases are:
          - a named string placeholder;
          - every dense and sparse output of ``tf.parse_example``, including
            one with a non-ASCII feature key;
          - a constant;
          - a SparseTensor assembled from placeholders.

        Each decoded object must equal the original. SparseTensors are
        compared component by component.
        """
        # Feature specs are plain value objects; building them does not add
        # ops to any graph, so they can be created before entering `graph`.
        feature_spec = {
            'age': tf.FixedLenFeature([], dtype=tf.int64, default_value=-1),
            'gender': tf.FixedLenFeature([], dtype=tf.string),
            'varstr': tf.VarLenFeature(tf.string),
            'varint': tf.VarLenFeature(tf.int64),
            'varfloat': tf.VarLenFeature(tf.float32),
            u'unicode\u1234': tf.FixedLenFeature([], dtype=tf.string),
        }

        graph = tf.Graph()
        with graph.as_default():
            serialized = tf.placeholder(tf.string, name='example')
            parsed = tf.parse_example(serialized, feature_spec)
            const_tensor = tf.constant(1.0)
            sparse_tensor = tf.SparseTensor(
                indices=tf.placeholder(tf.int64),
                values=tf.placeholder(tf.int64),
                dense_shape=tf.placeholder(tf.int64))

        feature_names = [
            'age', 'gender', 'varstr', 'varint', 'varfloat', u'unicode\u1234'
        ]
        candidates = ([serialized] +
                      [parsed[name] for name in feature_names] +
                      [const_tensor, sparse_tensor])
        sparse_parts = ('indices', 'values', 'dense_shape')

        for original in candidates:
            round_tripped = encoding.decode_tensor_node(
                graph, encoding.encode_tensor_node(original))
            if not isinstance(original, tf.SparseTensor):
                self.assertEqual(original, round_tripped)
                continue
            # A SparseTensor is a composite; check each component tensor.
            for part in sparse_parts:
                self.assertEqual(
                    getattr(original, part), getattr(round_tripped, part))
Exemplo n.º 3
0
def get_node_in_graph(meta_graph_def, path, graph):
    """Look up, in ``graph``, the node referenced by a MetaGraphDef collection.

    This is the graph-resolving counterpart of
    ``get_node_wrapped_tensor_info``. The encoded node reference that function
    returns for ``path`` is decoded against ``graph`` with
    ``encoding.decode_tensor_node``.

    Args:
      meta_graph_def: MetaGraphDef containing the CollectionDefs from which
        the node name is extracted.
      path: Name of the collection that holds the node name.
      graph: TensorFlow graph in which the node is looked up.

    Returns:
      The node in ``graph`` corresponding to the reference returned by
      ``get_node_wrapped_tensor_info``.
    """
    encoded_node = get_node_wrapped_tensor_info(meta_graph_def, path)
    return encoding.decode_tensor_node(graph, encoded_node)