def _build_graph(graph, names):
    """Serialize an internal graph into a ``tf_pb.GraphDef``.

    Tensors with no producer become 'Placeholder' (no data) or 'Const'
    (data present) nodes; every operation is converted via ``_build_node``.
    ``names`` maps each tensor to the node name to emit for it.
    """
    graph_def = tf_pb.GraphDef()

    # Source tensors (no producing operation) become input nodes.
    for tensor in graph.tensors:
        if tensor.producer is not None:
            continue
        node = graph_def.node.add()
        node.name = names[tensor]
        if tensor.data is None:
            node.op = 'Placeholder'
            _build_shape(node.attr['shape'].shape, tensor.shape)
        else:
            node.op = 'Const'
            _build_attribute(node.attr['value'], tensor)
        # Both node kinds carry an explicit dtype attribute.
        node.attr['dtype'].type = _build_dtype(tensor.dtype)

    for op in graph.operations:
        _build_node(graph_def.node.add(), op, names)

    return graph_def
def read_tf_graph_from_protobuf(filename):
    """Read a serialized ``GraphDef`` from ``filename`` and build a TFGraph.

    Placeholder nodes become graph inputs; tensors that no operation
    consumes become graph outputs.
    """
    graph_def = tf_pb.GraphDef()
    with open(filename, 'rb') as file:
        graph_def.ParseFromString(file.read())

    graph = TFGraph()
    # Scratch graph that only holds tensors appearing inside attributes;
    # it is never returned.
    attrib_graph = TFGraph()

    attributes_by_node_name = {}
    outputs_by_node_name = {}
    detected_output_count = _detect_output_counts(graph_def)

    # Pass 1: create one TFTensor per declared output of every node.
    for node in graph_def.node:
        attributes = _get_attributes(node.attr, attrib_graph)
        # _OutputCount maps an op either to a fixed int or to the name of
        # the attribute that carries the count.
        count = _OutputCount.get(node.op, 1)
        if isinstance(count, str):
            count = attributes[count]
        # NOTE(review): detected counts are keyed by op name here — presumably
        # _detect_output_counts aggregates per op type; confirm.
        count = max(count, detected_output_count.get(node.op, 1))
        assert isinstance(count, int)
        base_name = utils.anystr_to_str(node.name)
        # Output 0 keeps the bare node name; later outputs get ':<i>'.
        node_outputs = [
            TFTensor(graph, base_name if idx == 0 else base_name + ':' + str(idx))
            for idx in range(count)
        ]
        outputs_by_node_name[node.name] = node_outputs
        attributes_by_node_name[node.name] = attributes

    tensor_by_name = {t.name: t
                      for outs in six.itervalues(outputs_by_node_name)
                      for t in outs}

    # Pass 2: wire up placeholders, constants and operations.
    placeholders = []
    for node in graph_def.node:
        attributes = attributes_by_node_name[node.name]
        node_outputs = outputs_by_node_name[node.name]
        if node.op == 'Placeholder':
            assert len(node_outputs) == 1
            tensor = node_outputs[0]
            tensor.shape = attributes['shape'] if 'shape' in attributes else None
            tensor.dtype = attributes['dtype'] if 'dtype' in attributes else None
            placeholders.append(tensor)
        elif node.op == 'Const':
            assert len(node_outputs) == 1
            tensor = node_outputs[0]
            value = attributes['value']
            if isinstance(value, TFTensor):
                tensor.shape = value.shape
                tensor.dtype = value.dtype
                tensor.data = value.data
            else:
                tensor.data = value
        else:
            # Drop control inputs ('^name') and normalize explicit ':0' refs.
            input_names = [n[:-2] if n.endswith(':0') else n
                           for n in node.input if not n.startswith('^')]
            for n in input_names:
                if n not in tensor_by_name:
                    print('Info: List of node types in graph: {}\n'.format(
                        sorted({node.op for node in graph_def.node})))
                    raise utils.NNEFToolsException(
                        "Tensor {} is used, but it is not clear which operation produced it. "
                        "Probably the graph has unsupported dynamic operations.".format(n))
            TFOperation(graph,
                        name=utils.anystr_to_str(node.op),
                        inputs=tuple(tensor_by_name[n] for n in input_names),
                        outputs=node_outputs,
                        attribs=attributes)

    # Canonicalize: every tensor name carries an explicit output index.
    for tensor in graph.tensors:
        if tensor.name is not None and ':' not in tensor.name:
            tensor.name += ':0'

    graph.inputs = OrderedDict((t.name.split(':')[0], t) for t in placeholders)

    # An operation whose outputs are all unconsumed contributes graph outputs.
    graph_outputs = []
    for op in graph.operations:
        if all(not out.consumers for out in op.outputs):
            graph_outputs.extend(op.outputs)

    single = len(graph_outputs) <= 1
    graph.outputs = OrderedDict(
        ('output' if single else 'output' + str(i), t)
        for i, t in enumerate(graph_outputs))
    return graph
def read_tf_graph_from_protobuf(filename):
    """Read a serialized ``GraphDef`` from ``filename`` and build a TFGraph.

    NOTE(review): this file defines ``read_tf_graph_from_protobuf`` twice;
    this later definition shadows the earlier one at import time — confirm
    which version is intended and delete the other.

    Placeholder nodes become graph inputs; tensors that no operation
    consumes become graph outputs.
    """
    graph_def = tf_pb.GraphDef()
    with open(filename, 'rb') as file:
        graph_def.ParseFromString(file.read())

    graph = TFGraph()
    # Scratch graph that only holds tensors appearing inside attributes;
    # it is never returned.
    attrib_graph = TFGraph()

    attributes_by_node_id = {}
    outputs_by_node_id = {}
    for node in graph_def.node:
        attributes = _get_attributes(node.attr, attrib_graph)
        # _OutputCount maps an op either to a fixed int or to the name of
        # the attribute that carries the count.
        output_count = _OutputCount.get(node.op, 1)
        if isinstance(output_count, str):
            output_count = attributes[output_count]
        assert isinstance(output_count, int)
        outputs = []
        if output_count >= 1:
            # Output 0 keeps the bare node name; later outputs get ':<i>'.
            outputs.append(TFTensor(graph, utils.anystr_to_str(node.name)))
        for i in range(1, output_count):
            tensor_name = utils.anystr_to_str(node.name) + ':' + str(i)
            outputs.append(TFTensor(graph, tensor_name))
        outputs_by_node_id[id(node)] = outputs
        attributes_by_node_id[id(node)] = attributes

    tensor_by_name = {
        tensor.name: tensor
        for outputs in six.itervalues(outputs_by_node_id)
        for tensor in outputs
    }

    placeholders = []
    for node in graph_def.node:
        attributes = attributes_by_node_id[id(node)]
        outputs = outputs_by_node_id[id(node)]
        if node.op == 'Placeholder':
            assert len(outputs) == 1
            tensor = outputs[0]
            tensor.shape = attributes.get('shape')
            tensor.dtype = attributes.get('dtype')
            placeholders.append(tensor)
        elif node.op == 'Const':
            assert len(outputs) == 1
            tensor = outputs[0]
            value = attributes['value']
            if isinstance(value, TFTensor):
                tensor.shape = value.shape
                tensor.dtype = value.dtype
                tensor.data = value.data
            else:
                tensor.data = value
        else:
            # Fix: skip control inputs ('^name') and normalize explicit ':0'
            # output references before the lookup, matching the sibling
            # reader in this file. Previously raw names were used directly,
            # so any control dependency raised a bare KeyError here.
            input_names = [name[:-2] if name.endswith(':0') else name
                           for name in node.input
                           if not name.startswith('^')]
            for name in input_names:
                if name not in tensor_by_name:
                    raise utils.NNEFToolsException(
                        "Tensor {} is used, but it is not clear which operation produced it. "
                        "Probably the graph has unsupported dynamic operations.".format(name))
            inputs = tuple(tensor_by_name[name] for name in input_names)
            TFOperation(graph,
                        name=utils.anystr_to_str(node.op),
                        inputs=inputs,
                        outputs=outputs,
                        attribs=attributes)

    # Canonicalize: every tensor name carries an explicit output index.
    for tensor in graph.tensors:
        if tensor.name is not None and ':' not in tensor.name:
            tensor.name += ':0'

    graph.inputs = OrderedDict([(tensor.name.split(':')[0], tensor)
                                for tensor in placeholders])

    # An operation whose outputs are all unconsumed contributes graph outputs.
    graph_outputs = []
    for op in graph.operations:
        if all(len(output.consumers) == 0 for output in op.outputs):
            for output in op.outputs:
                graph_outputs.append(output)

    graph.outputs = OrderedDict([
        ('output' + str(i) if len(graph_outputs) > 1 else 'output', tensor)
        for i, tensor in enumerate(graph_outputs)
    ])
    return graph