def _nnefdog_to_source_impl(nnefdog,
                            file_handle,
                            custom_fragments="",
                            enable_shape_of=False):
    # type: (NnefGraph, TextIO, str, bool)->TextIO
    """Print ``nnefdog`` as NNEF textual source into ``file_handle``.

    Output order: the version line, the extension line, the optional
    ``custom_fragments`` text followed by a blank line, the ``graph`` header
    built by ``nnef.format_graph``, and one invocation statement per entry of
    ``nnefdog.ops`` enclosed in braces.

    Args:
        nnefdog: graph whose ``name``, ``input_dn_names``, ``output_dn_names``
            and ``ops`` are printed.
        file_handle: text stream that receives the source.
        custom_fragments: text printed verbatim before the graph, if non-empty.
        enable_shape_of: when true, ``KHR_enable_operator_expressions`` is
            declared in the extension line.

    Returns:
        ``file_handle`` (the same object that was passed in).
    """
    to_native = utils.ensure_not_unicode_in_python2
    out = file_handle

    def emit(*parts):
        # Thin wrapper so every line goes to the same stream.
        print(*parts, file=out)

    emit(nnef.format_version((1, 0)))
    emit(nnef.format_extensions(
        ["KHR_enable_operator_expressions"] if enable_shape_of else []))
    emit()

    if custom_fragments:
        emit(custom_fragments)
        emit()

    header = nnef.format_graph(
        name=to_native(nnefdog.name),
        inputs=list(map(to_native, nnefdog.input_dn_names)),
        outputs=list(map(to_native, nnefdog.output_dn_names)))
    emit("graph {}".format(header))
    emit("{")

    # Only these declaring operations carry an explicit result dtype.
    declaring_ops = ("external", "constant", "variable")
    for op in nnefdog.ops:
        if op.name in declaring_ops:
            dtype = op.result.dtype
        else:
            dtype = None

        statement = nnef.format_invocation(
            name=to_native(op.name),
            args=[],
            kwargs=_preprocess_args(op.args),
            results=_results_to_result_names(op.results.values()),
            dtype=dtype)

        # Collect the comments attached to the op's result nodes, skipping
        # nodes that have none.
        notes = utils.without_nones([
            node.extra.get(dog.EXTRA_COMMENT)
            for node in op.get_result_nodes()
        ])
        suffix = "  # {}".format(", ".join(notes)) if notes else ""
        emit("    {};{}".format(statement, suffix))

    emit("}")

    return file_handle
Example #2
0
def _nnefdog_to_source_impl(nnefdog, file_handle, custom_fragments=""):
    # type: (NnefGraph, TextIO, str)->TextIO
    """Print ``nnefdog`` as NNEF textual source into ``file_handle``.

    Output order: the version line, an empty extension line, the optional
    ``custom_fragments`` text followed by a blank line, the header
    ``graph name(inputs) -> (outputs)``, and one invocation statement per
    entry of ``nnefdog.ops`` enclosed in braces.

    Args:
        nnefdog: graph whose ``name``, ``input_dn_names``, ``output_dn_names``
            and ``ops`` are printed.
        file_handle: text stream that receives the source.
        custom_fragments: text printed verbatim before the graph, if non-empty.

    Returns:
        ``file_handle`` (the same object that was passed in).
    """
    to_native = utils.ensure_not_unicode_in_python2
    out = file_handle

    def emit(*parts):
        # Thin wrapper so every line goes to the same stream.
        print(*parts, file=out)

    emit(nnef.format_version((1, 0)))
    emit(nnef.format_extensions([]))
    emit()

    if custom_fragments:
        emit(custom_fragments)
        emit()

    header_name = to_native(nnefdog.name)
    header_inputs = ', '.join([to_native(n) for n in nnefdog.input_dn_names])
    header_outputs = ', '.join(
        [to_native(n) for n in nnefdog.output_dn_names])
    emit("graph {}({}) -> ({})".format(header_name, header_inputs,
                                       header_outputs))
    emit("{")

    # Only these declaring operations carry an explicit result dtype.
    declaring_ops = ("external", "constant", "variable")
    for op in nnefdog.ops:
        if op.name in declaring_ops:
            dtype = op.result.dtype
        else:
            dtype = None

        statement = nnef.format_invocation(
            name=to_native(op.name),
            attribs=_preprocess_args(op.args),
            inputs=tuple(),
            outputs=_results_to_result_names(op.results.values()),
            dtype=dtype)

        # The op's own comment comes first, then one per result node;
        # missing comments (None) are dropped.
        candidates = [op.extra.get(dog.EXTRA_COMMENT)]
        candidates.extend(node.extra.get(dog.EXTRA_COMMENT)
                          for node in op.get_result_nodes())
        notes = utils.without_nones(candidates)
        suffix = "  # {}".format(", ".join(notes)) if notes else ""
        emit("    {};{}".format(statement, suffix))

    emit("}")

    return file_handle
Example #3
0
def _print(graph, file, extensions, fragments, version_custom_ops,
           annotate_shapes):
    """Print ``graph`` as NNEF textual source into ``file``.

    Output order: the version line, the extension line (only when
    ``extensions`` is non-empty), the optional ``fragments`` text, and a
    ``graph name(inputs) -> (outputs) { ... }`` body with one invocation per
    operation of ``graph.operations``.

    Args:
        graph: graph to print; asserted to be topologically sorted, to have
            named tensors (unnamed ones must be producer-less with data), and
            to have fully defined shapes on its ``external`` ops.
        file: text stream that receives the source.
        extensions: sequence of extension names for ``nnef.format_extensions``.
        fragments: text printed verbatim before the graph, if truthy.
        version_custom_ops: when true, op types not listed in
            ``nnef.StandardOperations`` are renamed through ``_next_version``.
        annotate_shapes: when true, each statement is followed by a comment
            listing ``dtype`` + ``shape`` of every output tensor.
    """
    def emit(*parts):
        # Thin wrapper so every line goes to the same stream.
        print(*parts, file=file)

    def has_name_or_inline_data(tensor):
        return tensor.name is not None or (tensor.producer is None
                                           and tensor.data is not None)

    def has_defined_shape(op):
        return all(dim is not None for dim in op.attribs['shape'])

    assert graph.is_sorted(), "graph must be topologically sorted"
    assert all(has_name_or_inline_data(t) for t in graph.tensors), \
        "all tensors must have names"
    assert all(has_defined_shape(op)
               for op in graph.operations if op.type == 'external'), \
        "external ops must not contain undefined shapes"

    emit(nnef.format_version((1, 0)))
    if len(extensions) > 0:
        emit()
        emit(nnef.format_extensions(extensions))
    if fragments:
        emit()
        emit(fragments)
    emit()

    graph_label = "G" if graph.name is None else as_str(graph.name)
    inputs_text = ', '.join([as_str(t.name) for t in graph.inputs])
    outputs_text = ', '.join([as_str(t.name) for t in graph.outputs])
    emit("graph {}({}) -> ({})".format(graph_label, inputs_text,
                                       outputs_text))
    emit("{")

    def to_argument(item):
        # Non-tensor inputs are passed through unchanged.
        if not isinstance(item, Tensor):
            return item
        # Tensors without a producer are emitted inline from their data.
        if item.producer is None:
            return from_numpy(item.data)
        return nnef.Identifier(as_str(item.name))

    counters = {}
    for op in graph.operations:
        assert all(isinstance(t, Tensor) for t in op.outputs)

        # A tuple of inputs/outputs maps to separate arguments; any other
        # sequence is passed as a single list argument.
        arg_values = [to_argument(item) for item in op.inputs]
        if isinstance(op.inputs, tuple):
            inputs = tuple(arg_values)
        else:
            inputs = (arg_values, )

        result_ids = [nnef.Identifier(as_str(t.name)) for t in op.outputs]
        if isinstance(op.outputs, tuple):
            outputs = tuple(result_ids)
        else:
            outputs = (result_ids, )

        attribs = dict((as_str(k), v) for k, v in six.iteritems(op.attribs))

        if op.type not in nnef.StandardOperations and version_custom_ops:
            op_name = _next_version(op.type, counters)
        else:
            op_name = op.type

        # A non-None 'dtype' attribute is moved out of the attribute list and
        # passed separately to format_invocation.
        dtype = attribs.get('dtype')
        if dtype is not None:
            dtype = _nnef_dtype(dtype)
            del attribs['dtype']

        # Python/NumPy type objects among the attributes are converted in place.
        for key, value in list(attribs.items()):
            if isinstance(value, (type, np.dtype)):
                attribs[key] = _nnef_dtype(value)

        invocation = nnef.format_invocation(name=op_name,
                                            dtype=dtype,
                                            attribs=attribs,
                                            inputs=inputs,
                                            outputs=outputs)
        if annotate_shapes:
            annotation = "    # " + ", ".join(
                _nnef_dtype(t.dtype) + str(t.shape) for t in op.outputs)
        else:
            annotation = ''

        emit("    {};{}".format(invocation, annotation))

    emit("}")
Example #4
0
def _print(
        nnef_graph,  # type: NNEFGraph
        file_handle,  # type: typing.TextIO
        extensions=None,  # type: typing.Optional[typing.List[str]]
        fragments=None,  # type: typing.Optional[str]
        only_print_used_fragments=False,  # type: bool
):
    # type: (...)->None
    """Print ``nnef_graph`` as NNEF textual source into ``file_handle``.

    Source operations are generated on the graph for the duration of the call
    and are removed again in a ``finally`` clause, so the graph is restored
    even if sorting or printing raises.

    Args:
        nnef_graph: graph to print; ``nnef_graph.sort()`` is called on it first.
        file_handle: text stream that receives the source.
        extensions: extension names to declare. The caller's sequence is never
            modified; required extensions are added to a private copy.
        fragments: fragment-definition text printed before the graph. A TFLite
            quantization fragment may be added via
            ``add_tflite_quantization_fragment_if_needed``.
        only_print_used_fragments: when true, ``fragments`` is filtered through
            ``get_used_fragments`` before printing.
    """
    generate_source_operations(nnef_graph)
    try:
        # Sorting is inside the try so that a failure here still runs the
        # cleanup in ``finally`` (previously it leaked source operations).
        nnef_graph.sort()

        # Work on a copy: the appends below must not leak into a list owned by
        # the caller (and a tuple argument must not crash on append).
        extensions = list(extensions) if extensions is not None else []

        if fragments is None:
            fragments = ""

        fragments = add_tflite_quantization_fragment_if_needed(
            nnef_graph, fragments)

        if only_print_used_fragments:
            fragments = get_used_fragments(nnef_graph, fragments)

        # Fragment definitions need these extensions declared in the header.
        if fragments:
            if "KHR_enable_fragment_definitions" not in extensions:
                extensions.append("KHR_enable_fragment_definitions")
            if "KHR_enable_operator_expressions" not in extensions:
                extensions.append("KHR_enable_operator_expressions")

        f = file_handle
        indent = 4 * " "

        print(nnef.format_version((1, 0)), file=f)
        if extensions:
            print(nnef.format_extensions(extensions), file=f)
        if fragments:
            print(file=f)
            print(fragments, file=f)
        print(file=f)

        graph_name = _recursive_check_str(
            nnef_graph.name) if nnef_graph.name is not None else "network"
        graph_inputs = _recursive_check_str(
            [input_.name for input_ in nnef_graph.inputs])
        graph_outputs = _recursive_check_str(
            [output_.name for output_ in nnef_graph.outputs])

        print("graph {}({}) -> ({})".format(graph_name,
                                            ', '.join(graph_inputs),
                                            ', '.join(graph_outputs)),
              file=f)
        print("{", file=f)

        for op in nnef_graph.operations:
            # Non-tuple input containers are normalized to a list first.
            inputs = _transform_inputs_before_print(
                list(op.inputs) if not isinstance(op.inputs, tuple) else op.
                inputs)
            # Only these declaring operations carry an explicit result dtype.
            dtype = op.output.dtype if op.name in [
                "external", "constant", "variable"
            ] else None
            invocation = nnef.format_invocation(
                name=_recursive_check_str(op.name),
                attribs=_recursive_check_str(_sorted_ordered_dict(op.attribs)),
                # A list of inputs is passed as one list argument; a tuple is
                # spread into separate arguments.
                inputs=_recursive_check_str(
                    [inputs] if isinstance(inputs, list) else list(inputs)),
                outputs=_recursive_check_str(
                    _result_to_identifiers(
                        list(op.outputs)
                        if not isinstance(op.outputs, tuple) else op.outputs)),
                dtype=_recursive_check_str(dtype))

            comment = "  # {}".format(_recursive_check_str(
                op.comment)) if op.comment else ""
            print("{}{};{}".format(indent, invocation, comment), file=f)

        print("}", file=f)
    finally:
        remove_source_operations(nnef_graph)