def get_node_map_in_graph(
    meta_graph_def: meta_graph_pb2.MetaGraphDef, prefix: str,
    node_suffixes: List[str], graph: tf.Graph
) -> Dict[types.FPLKeyType, Dict[str, types.TensorType]]:
  """Resolves the encoded node map from get_node_map into nodes of `graph`.

  The structure returned by get_node_map is kept as-is (outer key -> inner
  dict). Every encoded node in the inner dicts is replaced by the node that
  encoding.decode_tensor_node finds for it in `graph`.

  Args:
    meta_graph_def: MetaGraphDef whose CollectionDefs hold the structure to
      extract.
    prefix: Prefix shared by the CollectionDef names.
    node_suffixes: Suffixes appended to `prefix` to name the CollectionDefs
      that hold the nodes, e.g. ['suffix_a', 'suffix_b'].
    graph: TensorFlow graph in which the nodes are looked up.

  Returns:
    A dictionary of dictionaries shaped like the result of get_node_map.
    The innermost values are the decoded graph nodes instead of their
    encoded form.
  """
  encoded_map = get_node_map(meta_graph_def, prefix, node_suffixes)
  return {
      outer_key: {
          inner_key: encoding.decode_tensor_node(graph, encoded_node)
          for inner_key, encoded_node in encoded_nodes.items()
      }
      for outer_key, encoded_nodes in encoded_map.items()
  }
def testEncodeDecodeTensorNode(self):
  """Round-trips tensors through encode_tensor_node / decode_tensor_node.

  The cases cover:
  - a named string placeholder;
  - dense and sparse outputs of tf.parse_example, including a non-ASCII
    feature key;
  - a constant;
  - a SparseTensor assembled from placeholders.

  Decoding an encoded tensor against the graph must yield something that
  compares equal to the original. SparseTensors are compared one component
  at a time.
  """
  graph = tf.Graph()
  with graph.as_default():
    serialized = tf.placeholder(tf.string, name='example')
    feature_spec = {
        'age': tf.FixedLenFeature([], dtype=tf.int64, default_value=-1),
        'gender': tf.FixedLenFeature([], dtype=tf.string),
        'varstr': tf.VarLenFeature(tf.string),
        'varint': tf.VarLenFeature(tf.int64),
        'varfloat': tf.VarLenFeature(tf.float32),
        u'unicode\u1234': tf.FixedLenFeature([], dtype=tf.string),
    }
    parsed = tf.parse_example(serialized, feature_spec)
    scalar_constant = tf.constant(1.0)
    # The placeholders are created in the same order as the SparseTensor
    # keyword arguments (indices, values, dense_shape), so their
    # auto-generated op names are unchanged.
    sparse_indices = tf.placeholder(tf.int64)
    sparse_values = tf.placeholder(tf.int64)
    sparse_shape = tf.placeholder(tf.int64)
    sparse_input = tf.SparseTensor(
        indices=sparse_indices, values=sparse_values, dense_shape=sparse_shape)

  cases = [serialized]
  cases.extend(parsed[name] for name in (
      'age', 'gender', 'varstr', 'varint', 'varfloat', u'unicode\u1234'))
  cases.append(scalar_constant)
  cases.append(sparse_input)

  for original in cases:
    round_tripped = encoding.decode_tensor_node(
        graph, encoding.encode_tensor_node(original))
    if not isinstance(original, tf.SparseTensor):
      self.assertEqual(original, round_tripped)
      continue
    for component in ('indices', 'values', 'dense_shape'):
      self.assertEqual(
          getattr(original, component), getattr(round_tripped, component))
def get_node_in_graph(meta_graph_def, path, graph):
  """Returns the graph node whose encoding is stored under `path`.

  Reads the encoded node with get_node_wrapped_tensor_info, then resolves it
  against `graph` with encoding.decode_tensor_node.

  Args:
    meta_graph_def: MetaGraphDef whose CollectionDefs hold the encoded node
      name.
    path: Name of the collection that holds the node name.
    graph: TensorFlow graph in which the node is looked up.

  Returns:
    The node in `graph` that corresponds to the entry returned by
    get_node_wrapped_tensor_info.
  """
  encoded_node = get_node_wrapped_tensor_info(meta_graph_def, path)
  return encoding.decode_tensor_node(graph, encoded_node)