Ejemplo n.º 1
0
def _read(parser_graph, with_weights=True):
    # type: (typing.Any, bool)->NNEFGraph
    """Convert a graph produced by the nnef parser into an NNEFGraph.

    Args:
        parser_graph: the parser's graph object; this function reads its
            ``name``, ``tensors``, ``operations``, ``inputs`` and ``outputs``.
            It is not modified.
        with_weights: if True, the outputs of ``variable`` operations receive
            the data stored in ``parser_graph.tensors``; if False, they get an
            empty numpy array that only carries the matching dtype.

    Returns:
        A new NNEFGraph. ``external``, ``constant`` and ``variable``
        operations only define tensors; every other operation becomes an
        NNEFOperation. Graph inputs/outputs are OrderedDicts keyed by tensor
        name, in the parser's order.

    Raises:
        AssertionError: if a tensor is defined twice or used before being
            defined, if a graph input/output was never declared, or if a
            ``variable`` has no data while ``with_weights`` is True.
    """

    tensor_by_name = {}
    g = NNEFGraph(name=parser_graph.name)

    def add_to_tensor_by_name(tensor):
        assert tensor.name not in tensor_by_name, "Tensor {} defined multiple times".format(
            tensor.name)
        tensor_by_name[tensor.name] = tensor

    def transform_input(input_):
        # Identifiers refer to tensors defined by earlier operations; any other
        # value is a literal and is wrapped in an anonymous scalar tensor.
        if isinstance(input_, nnef.Identifier):
            name = str(input_)
            assert name in tensor_by_name, "Tensor {} not defined before use".format(name)
            return tensor_by_name[name]
        return NNEFTensor(
            graph=g,
            name=None,
            shape=[],
            dtype=NNEFDTypeByNumpyDType[np.array(input_).dtype.name],
            data=[input_])

    def transform_result(result_):
        # Identifiers become new, registered tensors; other values pass through.
        if not isinstance(result_, nnef.Identifier):
            return result_

        parser_tensor = parser_graph.tensors[str(result_)]
        quantization = None
        if parser_tensor.quantization:
            # Work on a copy: removing 'op-name' from the parser's own dict
            # would mutate the input graph (and break a second read of it).
            attribs = parser_tensor.quantization.copy()
            quantization = NNEFQuantization(name=attribs.pop('op-name'),
                                            attribs=attribs)

        tensor = NNEFTensor(graph=g,
                            name=str(result_),
                            shape=list(parser_tensor.shape),
                            dtype=parser_tensor.dtype,
                            quantization=quantization)
        add_to_tensor_by_name(tensor)
        return tensor

    def collect(values):
        # Collect the transformed argument values; if any top-level value is a
        # list the result stays a list, otherwise it is returned as a tuple.
        if any(isinstance(v, list) for v in six.itervalues(values)):
            return utils.recursive_collect(values)
        return tuple(utils.recursive_collect(values))

    def declared_tensors(identifiers, kind):
        # Resolve graph-level input/output identifiers to the tensors defined
        # while walking the operations.
        tensors = []
        for identifier in identifiers:
            name = str(identifier)
            assert name in tensor_by_name, "{} tensor {} was not declared".format(kind, name)
            tensors.append(tensor_by_name[name])
        return tensors

    for parser_op in parser_graph.operations:
        inputs = collect(utils.recursive_transform(parser_op.inputs, transform_input))
        outputs = collect(utils.recursive_transform(parser_op.outputs, transform_result))

        if parser_op.name == "variable":
            outputs[0].label = parser_op.attribs["label"]
            if with_weights:
                outputs[0].data = parser_graph.tensors[
                    str(parser_op.outputs["output"])].data
                assert outputs[0].data is not None
            else:
                # Without weights, store an empty array that only carries the dtype.
                outputs[0].data = np.array(
                    [], dtype=NumpyDTypeByNNEFDType[parser_op.dtype])
        elif parser_op.name == "constant":
            outputs[0].data = parser_op.attribs["value"]

        # These operations only define tensors (handled above) and are not
        # kept as operations in the resulting graph.
        if parser_op.name not in ("external", "constant", "variable"):
            NNEFOperation(graph=g,
                          name=parser_op.name,
                          attribs=dict(parser_op.attribs),
                          inputs=inputs,
                          outputs=outputs)

    input_tensors = declared_tensors(parser_graph.inputs, "Input")
    output_tensors = declared_tensors(parser_graph.outputs, "Output")

    g.inputs = OrderedDict((t.name, t) for t in input_tensors)
    g.outputs = OrderedDict((t.name, t) for t in output_tensors)

    g.generate_missing_names()
    return g
Ejemplo n.º 2
0
def _get_outputs(result):
    # type: (typing.Any)->typing.Union[typing.List, typing.Tuple]
    """Collect the values of ``result`` via ``utils.recursive_collect``.

    The container kind of the top level is preserved: a ``list`` input gives
    back whatever ``recursive_collect`` returned, any other input gives a
    ``tuple`` of those values.
    """
    collected = utils.recursive_collect(result)
    if isinstance(result, list):
        return collected
    return tuple(collected)