def summarize_graph(graph_def):
    """Collect the candidate inputs and outputs of a TensorFlow GraphDef.

    The GraphDef is imported into a fresh ``tf_v1.Graph``, and every node of
    the re-exported GraphDef is inspected.

    Args:
        graph_def: the GraphDef to analyze.

    Returns:
        A dict with two keys:
          * ``'inputs'``: maps each ``Placeholder`` node name to a dict with
            ``'type'`` (the name of the node's ``dtype`` attribute) and
            ``'shape'`` (the ``shape`` attribute rendered as a string with
            spaces removed and unknown dimensions written as ``-1``).
          * ``'outputs'``: names of the nodes for which ``children`` returns
            nothing, excluding nodes whose op type or last ``/``-separated
            name component is listed in ``unlikely_output_types``.
    """
    placeholders = {}
    outputs = []
    graph = tf_v1.Graph()
    with graph.as_default():  # pylint: disable=not-context-manager
        tf_v1.import_graph_def(graph_def, name='')
    for node in graph.as_graph_def().node:  # pylint: disable=no-member
        if node.op == 'Placeholder':
            dtype_name = tf_v1.DType(node.attr['dtype'].type).name
            shape_str = str(tf_v1.TensorShape(node.attr['shape'].shape))
            # Unknown dimensions print as '?' under TF1 TensorShape semantics
            # but as 'None' under TF2 semantics; normalise both to '-1'.
            shape_str = (shape_str.replace(' ', '')
                         .replace('?', '-1')
                         .replace('None', '-1'))
            placeholders[node.name] = {'type': dtype_name, 'shape': shape_str}
        # NOTE(review): `children` is defined elsewhere; presumably it returns
        # the consumers of `node.name` in `graph`, so an empty result marks a
        # graph sink, i.e. an output candidate.
        short_name = node.name.split('/')[-1]
        if (not children(node.name, graph)
                and node.op not in unlikely_output_types
                and short_name not in unlikely_output_types):
            outputs.append(node.name)
    return {'inputs': placeholders, 'outputs': outputs}
def attr_value_to_python_type(
    attr_value,  # type: tf.AttrValue
    attr_name  # type: str
):
    # type: (...) -> Any
    """Convert a node attribute protobuf into a Python value.

    Inverse of python_type_to_attr_value().

    Args:
        attr_value: Protocol buffer version of a node's attribute value.
        attr_name: Name of the attribute. Used only in the error message.

    Returns:
        A Python object or built-in type corresponding to the field of the
        ``value`` oneof of `attr_value` that is in use:
          * ``s`` becomes a ``str`` (via ``tf.compat.as_str``).
          * ``i``, ``f`` and ``b`` are returned as plain values.
          * ``type`` becomes a ``tf.DType``.
          * ``shape`` becomes a ``tf.TensorShape``.
          * ``tensor`` becomes a numpy array (via ``tf.make_ndarray``).
          * ``list`` and ``func`` are returned as the raw protocol buffer
            sub-messages.

    Raises:
        ValueError: if no field of the oneof is set, or the set field
            (e.g. ``placeholder``) has no conversion.
    """
    # At most one member of the AttrValue `value` oneof is set. Ask for its
    # name once instead of probing every field with HasField().
    which = attr_value.WhichOneof("value")
    if which == "s":  # str
        # TODO(frreiss): Should we return the binary value here?
        return tf.compat.as_str(attr_value.s)
    if which in ("i", "f", "b"):  # int, float, bool
        return getattr(attr_value, which)
    if which == "type":  # DType
        return tf.DType(attr_value.type)
    if which == "shape":  # TensorShape
        # Undocumented behavior of public API: tf.TensorShape constructor
        # accepts a TensorShapeProto.
        return tf.TensorShape(attr_value.shape)
    if which == "tensor":  # TensorProto
        return tf.make_ndarray(attr_value.tensor)
    # TODO(frreiss): Handle AttrValues that are lists
    if which in ("list", "func"):  # returned unconverted
        return getattr(attr_value, which)
    # TODO(frreiss): Convert the "placeholder" fields of the union here
    raise ValueError("Don't know how to convert AttrValue {} to "
                     "a Python object for attribute {}".format(
                         attr_value, attr_name))