コード例 #1
0
def _build_graph(graph, names):
    """Serialize ``graph`` into a new ``tf_pb.GraphDef``.

    Every tensor without a producer is emitted first as a source node. It is a
    ``Placeholder`` (with ``shape`` and ``dtype`` attributes) when the tensor
    has no data, otherwise a ``Const`` (with ``value`` and ``dtype``
    attributes). Every operation is then appended as a node built by
    ``_build_node``.

    :param graph: graph exposing ``tensors`` and ``operations``
    :param names: mapping from tensor to the node name to emit for it
    :return: the populated ``tf_pb.GraphDef``
    """
    result = tf_pb.GraphDef()

    for src in graph.tensors:
        if src.producer is not None:
            continue  # produced by an operation; emitted with that operation

        node = result.node.add()
        node.name = names[src]
        if src.data is not None:
            node.op = 'Const'
            _build_attribute(node.attr['value'], src)
        else:
            node.op = 'Placeholder'
            _build_shape(node.attr['shape'].shape, src.shape)
        # Both source kinds carry their element type; set last, as before.
        node.attr['dtype'].type = _build_dtype(src.dtype)

    for op in graph.operations:
        _build_node(result.node.add(), op, names)

    return result
コード例 #2
0
def read_tf_graph_from_protobuf(filename):
    """Load a binary TensorFlow ``GraphDef`` file into a ``TFGraph``.

    Two passes are made over the nodes.

    The first pass creates each node's output tensors. Output 0 is named after
    the node and output ``i > 0`` is named ``"<node>:i"``. The count is the
    larger of two values: the ``_OutputCount`` entry for the op (a number, or
    the name of the attribute holding it; default 1), and the count that
    ``_detect_output_counts`` reports for the op.

    The second pass fills in ``Placeholder`` / ``Const`` tensors and turns
    every other node into a ``TFOperation``. Control inputs (``"^..."``) are
    ignored and a ``":0"`` suffix on an input name is dropped.

    Afterwards, ``":0"`` is appended to every tensor name lacking an output
    index. Placeholders become the graph inputs. The outputs of operations
    whose outputs are all unconsumed become the graph outputs.

    :param filename: path of the serialized ``GraphDef``
    :return: the constructed ``TFGraph``
    :raises utils.NNEFToolsException: when a node's data input refers to a
        tensor that no node produces (the node types present are printed first)
    """
    graph_def = tf_pb.GraphDef()
    with open(filename, 'rb') as pb_file:
        graph_def.ParseFromString(pb_file.read())

    graph = TFGraph()
    attrib_graph = TFGraph()  # holds tensors found in attributes; never returned

    attributes_by_node_name = {}
    outputs_by_node_name = {}
    detected_counts = _detect_output_counts(graph_def)

    # Pass 1: create the output tensors of every node.
    for node in graph_def.node:
        node_attribs = _get_attributes(node.attr, attrib_graph)
        count = _OutputCount.get(node.op, 1)
        if isinstance(count, str):
            count = node_attribs[count]  # the count is stored in an attribute
        # NOTE(review): detected counts are looked up by op type; presumably
        # _detect_output_counts returns op-keyed counts. Confirm.
        count = max(count, detected_counts.get(node.op, 1))
        assert isinstance(count, int)

        node_outputs = []
        if count >= 1:
            base_name = utils.anystr_to_str(node.name)
            node_outputs.append(TFTensor(graph, base_name))
            node_outputs.extend(TFTensor(graph, base_name + ':' + str(index))
                                for index in range(1, count))

        outputs_by_node_name[node.name] = node_outputs
        attributes_by_node_name[node.name] = node_attribs

    tensor_by_name = {}
    for node_outputs in six.itervalues(outputs_by_node_name):
        for tensor in node_outputs:
            tensor_by_name[tensor.name] = tensor

    # Pass 2: fill in source tensors and wire up operations.
    placeholders = []
    for node in graph_def.node:
        node_attribs = attributes_by_node_name[node.name]
        node_outputs = outputs_by_node_name[node.name]

        if node.op == 'Placeholder':
            assert len(node_outputs) == 1
            placeholder = node_outputs[0]
            placeholder.shape = node_attribs.get('shape')
            placeholder.dtype = node_attribs.get('dtype')
            placeholders.append(placeholder)
            continue

        if node.op == 'Const':
            assert len(node_outputs) == 1
            constant = node_outputs[0]
            value = node_attribs['value']
            if not isinstance(value, TFTensor):
                constant.data = value
            else:
                constant.shape = value.shape
                constant.dtype = value.dtype
                constant.data = value.data
            continue

        data_input_names = []
        for raw_name in node.input:
            if raw_name.startswith('^'):
                continue  # control dependency, carries no data
            data_input_names.append(raw_name[:-2] if raw_name.endswith(':0') else raw_name)

        unknown = [name for name in data_input_names if name not in tensor_by_name]
        if unknown:
            node_types = sorted({other.op for other in graph_def.node})
            print('Info: List of node types in graph: {}\n'.format(node_types))
            raise utils.NNEFToolsException(
                "Tensor {} is used, but it is not clear which operation produced it. "
                "Probably the graph has unsupported dynamic operations.".format(unknown[0]))

        op_inputs = tuple(tensor_by_name[name] for name in data_input_names)
        TFOperation(graph,
                    name=utils.anystr_to_str(node.op),
                    inputs=op_inputs,
                    outputs=node_outputs,
                    attribs=node_attribs)

    # Give every tensor an explicit output index.
    for tensor in graph.tensors:
        current = tensor.name
        if current is None or ':' in current:
            continue
        tensor.name = current + ':0'

    graph.inputs = OrderedDict((p.name.split(':')[0], p) for p in placeholders)

    graph_outputs = [output
                     for op in graph.operations
                     if all(len(out.consumers) == 0 for out in op.outputs)
                     for output in op.outputs]

    several = len(graph_outputs) > 1
    graph.outputs = OrderedDict(('output' + str(i) if several else 'output', tensor)
                                for i, tensor in enumerate(graph_outputs))

    return graph
コード例 #3
0
ファイル: tf_pb_io.py プロジェクト: hansely/NNEF-Tools
def read_tf_graph_from_protobuf(filename):
    """Read a binary TensorFlow ``GraphDef`` file and convert it to a ``TFGraph``.

    The nodes are processed in two passes.

    The first pass creates the output tensors of every node. The first output
    is named after the node and output ``i > 0`` is named ``"<node>:i"``. The
    output count comes from ``_OutputCount`` (either a number or the name of
    an attribute holding it) and defaults to 1.

    The second pass fills in ``Placeholder`` and ``Const`` tensors and creates
    a ``TFOperation`` for every other node. Control inputs (``"^node"``) are
    skipped, and ``"node:0"`` is treated the same as ``"node"``.

    Finally, tensor names without an output index get ``":0"`` appended, and
    placeholders become the graph inputs. The outputs of operations whose
    outputs are all unconsumed become the graph outputs: ``"output"`` if there
    is just one, otherwise ``"output0"``, ``"output1"``, ...

    :param filename: path of the binary (serialized) ``GraphDef`` file
    :return: the converted ``TFGraph``
    :raises KeyError: if a node has a data input that no node in the graph produces
    """
    graph_def = tf_pb.GraphDef()
    with open(filename, 'rb') as file:
        graph_def.ParseFromString(file.read())

    graph = TFGraph()
    # Just a graph to contain the tensors that are in attributes;
    # no need to return this.
    attrib_graph = TFGraph()

    # Per-node data is stored in lists aligned with graph_def.node rather than
    # in dicts keyed by id(node). id() is only unique among objects alive at
    # the same time, and protobuf may return a new wrapper object each time
    # the repeated field is iterated. id-based keys could then differ between
    # the two passes or collide within one.
    attributes_per_node = []
    outputs_per_node = []

    for node in graph_def.node:
        attributes = _get_attributes(node.attr, attrib_graph)
        output_count = _OutputCount.get(node.op, 1)
        if isinstance(output_count, str):
            # The count is stored in one of the node's attributes.
            output_count = attributes[output_count]
        assert isinstance(output_count, int)

        outputs = []
        if output_count >= 1:
            node_name = utils.anystr_to_str(node.name)
            outputs.append(TFTensor(graph, node_name))
            for i in range(1, output_count):
                outputs.append(TFTensor(graph, node_name + ':' + str(i)))

        attributes_per_node.append(attributes)
        outputs_per_node.append(outputs)

    tensor_by_name = {
        tensor.name: tensor
        for outputs in outputs_per_node for tensor in outputs
    }

    placeholders = []
    for node, attributes, outputs in zip(graph_def.node, attributes_per_node, outputs_per_node):
        if node.op == 'Placeholder':
            assert len(outputs) == 1
            tensor = outputs[0]
            tensor.shape = attributes.get('shape')
            tensor.dtype = attributes.get('dtype')
            placeholders.append(tensor)
        elif node.op == 'Const':
            assert len(outputs) == 1
            tensor = outputs[0]
            value = attributes['value']
            if isinstance(value, TFTensor):
                tensor.shape = value.shape
                tensor.dtype = value.dtype
                tensor.data = value.data
            else:
                tensor.data = value
        else:
            inputs = []
            for input_name in node.input:
                if input_name.startswith('^'):
                    # Control dependency: orders execution, carries no data.
                    continue
                if input_name.endswith(':0'):
                    # First outputs are registered without an index suffix.
                    input_name = input_name[:-2]
                if input_name not in tensor_by_name:
                    raise KeyError(
                        "Tensor {} is used by node {} ({}), but no node in the graph "
                        "produces it".format(input_name, node.name, node.op))
                inputs.append(tensor_by_name[input_name])
            TFOperation(graph,
                        name=utils.anystr_to_str(node.op),
                        inputs=tuple(inputs),
                        outputs=outputs,
                        attribs=attributes)

    for tensor in graph.tensors:
        if tensor.name is not None and ':' not in tensor.name:
            tensor.name += ':0'

    graph.inputs = OrderedDict([(tensor.name.split(':')[0], tensor)
                                for tensor in placeholders])

    graph_outputs = []
    for op in graph.operations:
        # Only operations whose outputs are all unconsumed contribute outputs.
        if all(len(output.consumers) == 0 for output in op.outputs):
            graph_outputs.extend(op.outputs)

    graph.outputs = OrderedDict([
        ('output' + str(i) if len(graph_outputs) > 1 else 'output', tensor)
        for i, tensor in enumerate(graph_outputs)
    ])

    return graph