Ejemplo n.º 1
0
def _nnefdog_to_source_impl(nnefdog,
                            file_handle,
                            custom_fragments="",
                            enable_shape_of=False):
    # type: (NnefGraph, TextIO, str, bool)->TextIO
    """Print *nnefdog* to *file_handle* as NNEF source text.

    Output order: the version line, the extensions line, a blank line, the
    optional custom fragment text followed by a blank line, and finally the
    ``graph`` declaration whose braces enclose one invocation per operation.

    :param nnefdog: graph whose name, input/output names and ops are written.
    :param file_handle: writable text stream that receives the output.
    :param custom_fragments: text printed verbatim before the graph
        declaration; nothing is printed for it when it is empty.
    :param enable_shape_of: when true, the extensions line lists
        ``KHR_enable_operator_expressions``.
    :return: the same *file_handle* object that was passed in.
    """
    to_native = utils.ensure_not_unicode_in_python2
    # Only these operations get an explicit result dtype in their invocation.
    typed_op_names = ["external", "constant", "variable"]
    body_indent = " " * 4

    def emit(*parts):
        # With no parts this prints just a newline, i.e. a blank line.
        print(*parts, file=file_handle)

    # Header: version, extensions, separator.
    emit(nnef.format_version((1, 0)))
    if enable_shape_of:
        extension_names = ["KHR_enable_operator_expressions"]
    else:
        extension_names = []
    emit(nnef.format_extensions(extension_names))
    emit()

    if custom_fragments:
        emit(custom_fragments)
        emit()

    # Graph declaration line and opening brace.
    header = nnef.format_graph(
        name=to_native(nnefdog.name),
        inputs=list(map(to_native, nnefdog.input_dn_names)),
        outputs=list(map(to_native, nnefdog.output_dn_names)))
    emit("graph {}".format(header))
    emit("{")

    # One indented "invocation;" line per op, with any result-node comments
    # appended as a trailing "# ..." comment.
    for operation in nnefdog.ops:
        if operation.name in typed_op_names:
            result_dtype = operation.result.dtype
        else:
            result_dtype = None

        invocation = nnef.format_invocation(
            name=to_native(operation.name),
            args=[],
            kwargs=_preprocess_args(operation.args),
            results=_results_to_result_names(operation.results.values()),
            dtype=result_dtype)

        comment_texts = utils.without_nones(
            [node.extra.get(dog.EXTRA_COMMENT)
             for node in operation.get_result_nodes()])
        if comment_texts:
            suffix = "  # {}".format(", ".join(comment_texts))
        else:
            suffix = ""

        emit("{}{};{}".format(body_indent, invocation, suffix))

    emit("}")

    return file_handle
Ejemplo n.º 2
0
    if args.input_shape:
        input_shape = eval(args.input_shape)
        if not isinstance(input_shape, (list, dict)):
            print("input-shape must be Python list or dict expression")
            exit(-1)

        for op in graph.operations:
            if op.name == 'external':
                if isinstance(input_shape, dict):
                    name = op.outputs['output']
                    if name in input_shape:
                        op.attribs['shape'] = input_shape[name]
                else:
                    op.attribs['shape'] = input_shape

    if args.shapes:
        try:
            nnef.infer_shapes(graph)
        except nnef.Error as err:
            print('Shape error: ' + str(err))
            exit(-1)

    print(
        nnef.format_graph(graph.name,
                          graph.inputs,
                          graph.outputs,
                          graph.operations,
                          graph.tensors,
                          annotate_shapes=args.shapes))
    print('Validation succeeded')
Ejemplo n.º 3
0
# See the License for the specific language governing permissions and
# limitations under the License.

import nnef


def shuffle_shape(op, args, shapes):
    """Shape callback for the custom ``shuffle`` operation.

    The output tensor is given exactly the shape recorded for the input
    tensor. No other argument (such as ``groups``) is read. *shapes* is
    updated in place and nothing is returned.

    :param op: not used by this callback.
    :param args: mapping whose ``'input'`` and ``'output'`` entries are keys
        into *shapes*.
    :param shapes: mapping from tensor name to shape, mutated in place.
    """
    source_shape = shapes[args['input']]
    shapes[args['output']] = source_shape


# The parser does not know "shuffle" by default. Register its fragment
# declaration and its shape callback before parsing any source that uses it.
nnef._register_custom_ops(
    "shuffle",
    "fragment shuffle<?>( input: tensor<?>, groups: integer ) -> ( output: tensor<?> );",
)
nnef._register_custom_shapes(dict(shuffle=shuffle_shape))

# Inline graph whose final output is produced by the custom operation.
graph = nnef.parse_string(
    """
    version 1.0;
    graph Net( input ) -> ( output )
    {
        input = external(shape = [1,3,224,224]);
        filter = variable(shape = [32,3,5,5], label = 'conv/filter');
        conv = conv(input, filter);
        output = shuffle(conv, groups = 4);
    }
    """
)

# Print the formatted graph (name, inputs, outputs and operations only).
print(
    nnef.format_graph(
        graph.name,
        graph.inputs,
        graph.outputs,
        graph.operations,
    )
)
Ejemplo n.º 4
0
# Copyright (c) 2017 The Khronos Group Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import nnef


# Inline graph: one external input convolved with a labelled variable filter.
graph = nnef.parse_string(
    """
    version 1.0;
    graph Net( input ) -> ( output )
    {
        input = external(shape = [1,3,224,224]);
        filter = variable(shape = [32,3,5,5], label = 'conv/filter');
        output = conv(input, filter);
    }
    """
)

# Print the formatted graph, passing its tensors along with the operations.
print(
    nnef.format_graph(
        graph.name,
        graph.inputs,
        graph.outputs,
        graph.operations,
        graph.tensors,
    )
)