Example 1
def main() -> None:
    base_dir = os.path.dirname(
        os.path.dirname(
            os.path.dirname(os.path.dirname(os.path.realpath(__file__)))))
    docs_dir = os.path.join(base_dir, 'docs')
    schemas = defs.get_all_schemas()

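    # Generate the main coverage report; when onnx-ml schemas are present, an
    # additional TestCoverage-ml.md report is written below.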
    has_ml = is_ml(schemas)
    fname = os.path.join(docs_dir, 'TestCoverage.md')
    with io.open(fname, 'w+', newline='',
                 encoding="utf-8") as f:  # type: ignore
        gen_spdx(f)
        gen_outlines(f, False)
        gen_node_test_coverage(schemas, f, False)
        gen_model_test_coverage(schemas, f, False)
        gen_overall_test_coverage(schemas, f, False)

    if has_ml:
        fname = os.path.join(docs_dir, 'TestCoverage-ml.md')
        with io.open(fname, 'w+', newline='',
                     encoding="utf-8") as f:  # type: ignore
            gen_spdx(f)
            gen_outlines(f, True)
            gen_node_test_coverage(schemas, f, True)
            gen_model_test_coverage(schemas, f, True)
            gen_overall_test_coverage(schemas, f, True)
Example 2
def main():
    backend_opset_dict = {}
    frontend_opset_dict = {}

    for schema in defs.get_all_schemas():
        op_name = schema.name
        backend_opset_dict[op_name] = []
        frontend_opset_dict[op_name] = []

    backend_onnx_coverage = get_backend_coverage()
    backend_opset_dict.update(backend_onnx_coverage.get(defs.ONNX_DOMAIN, {}))
    frontend_onnx_coverage, frontend_tf_coverage, experimental_op = \
        get_frontend_coverage()
    frontend_opset_dict.update(frontend_onnx_coverage.get(
        defs.ONNX_DOMAIN, {}))

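    # Mark experimental ops in the frontend table by renaming their keys.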
    for exp_op in experimental_op:
        frontend_opset_dict["{} [EXPERIMENTAL]".format(
            exp_op)] = frontend_opset_dict.pop(exp_op)

    with open('opset_version.py', 'w') as version_file:
        pp = pprint.PrettyPrinter(indent=4)
        version_file.write("backend_opset_version = {\n " +
                           pp.pformat(backend_opset_dict)[1:-1] + "\n}\n\n")
        version_file.write("frontend_opset_version = {\n " +
                           pp.pformat(frontend_opset_dict)[1:-1] + "\n}\n\n")
        version_file.write(
            "frontend_tf_opset_version = {\n " +
            pp.pformat(frontend_tf_coverage.get(defs.ONNX_DOMAIN, {}))[1:-1] +
            "\n}\n")
Example 3
def main():
    backend_opset_dict = {}
    frontend_opset_dict = {}

    for schema in defs.get_all_schemas():
        op_name = schema.name
        backend_opset_dict[op_name] = []
        frontend_opset_dict[op_name] = []

    backend_onnx_coverage = get_backend_coverage()
    backend_opset_dict.update(backend_onnx_coverage.get(defs.ONNX_DOMAIN, {}))
    frontend_onnx_coverage, frontend_tf_coverage = get_frontend_coverage()
    frontend_opset_dict.update(frontend_onnx_coverage.get(
        defs.ONNX_DOMAIN, {}))

    with open('opset_version.py', 'w') as version_file:
        pp = pprint.PrettyPrinter(indent=4)
        version_file.write("backend_opset_version = {\n " +
                           pp.pformat(backend_opset_dict)[1:-1] + "\n}\n\n")
        version_file.write("frontend_opset_version = {\n " +
                           pp.pformat(frontend_opset_dict)[1:-1] + "\n}\n\n")
        version_file.write(
            "frontend_darknet_opset_version = {\n " +
            pp.pformat(frontend_tf_coverage.get(defs.ONNX_DOMAIN, {}))[1:-1] +
            "\n}\n")
Example 4
def main(args):
    args.output.write('## Operator Schemas\n')
    args.output.write(
        '*This file is automatically generated from the [def files](/onnx/defs)*\n'
    )

    for op_type, schema in sorted(defs.get_all_schemas().items()):
        # If support level is experimental, then don't generate documentation.
        if schema.support_level == OpSchema.SupportType.EXPERIMENTAL:
            continue

        # op_type
        s = '* <a name="{}"></a><a name="{}"></a>**{}**\n'.format(
            op_type, op_type.lower(), op_type)

        # doc
        if schema.doc:
            s += '\n'
            s += '\n'.join('  ' + line
                           for line in schema.doc.lstrip().splitlines())
            s += '\n'

        # attributes
        if schema.attributes:
            s += '  * **attribute**:\n'
            s += '    <dl>\n'
            for _, attr in sorted(schema.attributes.items()):
                s += '      <dt>{}</dt>\n'.format(attr.name)
                s += '      <dd>{}</dd>\n'.format(attr.description)
            s += '    </dl>\n'

        # inputs
        s += '  * **input**:'
        if schema.min_input != schema.max_input:
            s += '{} - {}'.format(display_number(schema.min_input),
                                  display_number(schema.max_input))
        s += '\n'
        if schema.input_desc:
            s += '    <dl>\n'
            for input_name, input_desc in schema.input_desc:
                s += '      <dt>{}</dt>\n'.format(input_name)
                s += '      <dd>{}</dd>\n'.format(input_desc)
            s += '    </dl>\n'

        # outputs
        s += '  * **output**:'
        if schema.min_output != schema.max_output:
            s += '{} - {}'.format(display_number(schema.min_output),
                                  display_number(schema.max_output))
        s += '\n'
        if schema.output_desc:
            s += '    <dl>\n'
            for output_name, output_desc in schema.output_desc:
                s += '      <dt>{}</dt>\n'.format(output_name)
                s += '      <dd>{}</dd>\n'.format(output_desc)
            s += '    </dl>\n'

        s += '\n\n'
        args.output.write(s)
Example 5
def get_onnx_supported_ops():
    onnx_ops_dict = {}
    for schema in defs.get_all_schemas():
        onnx_ops_dict[schema.name] = {
            'version': schema.since_version,
            'deprecated': schema.deprecated
        }
    return onnx_ops_dict
Example 6
def main():  # type: () -> None
    # domain -> support level -> name -> [schema]
    with_inference = []
    without_inference = []
    for schema in defs.get_all_schemas():
        domain, name, has_inference = schema.domain, schema.name, schema.has_type_and_shape_inference_function
        if has_inference:
            with_inference.append((domain, name))
        else:
            without_inference.append((domain, name))
    print(len(with_inference), 'operators have a type/shape inference function.')
    print(len(without_inference), 'do not. These are:')
    for domain, name in sorted(without_inference):
        print(domain, name)
Example 7
def main():
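    # Build a generated operator class for every ONNX schema via display_schema().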
    clazzes = []
    for schema in defs.get_all_schemas():
        clazz = display_schema(schema)
        clazzes.append(clazz)

    with open('generated_operators.py', 'w') as code:
        code.truncate()  # empty the file

        # write header
        code.write(
            '### THIS FILE IS AUTOMATICALLY GENERATED BY generate_operators.py, DO NOT MODIFY\n'
        )

        # write imports
        code.write('from deep500.onnx_parser.onnx_objects import *\n\n\n')

        # write classes
        for each_class in clazzes:
            code.write(str(each_class))

        # write hashtable
        code.write("ONNX_OPERATIONS = {\n")
        for each_class in clazzes:
            code.write(
                " " * 4 +
                "'{}': {},\n".format(each_class.name.lower(), each_class.name))
        code.write("}\n")

    def write_visitor(raise_exception: bool, prefix: str):
        inner_text = "raise Exception('implement this method')\n\n" if raise_exception else "pass\n\n"
        with open('{}operations_visitor.py'.format(prefix.lower()),
                  'w') as visitor_file:
            visitor_file.truncate()

            # write import
            s = "import abc\n"
            s += "from deep500.lv1.network import Network\n"
            s += "from deep500.utils.onnx_interop.generated_operators import *\n\n\n"
            # write class
            s += 'class {}OperationsVisitor(abc.ABC):\n'.format(prefix)
            for each_class in clazzes:
                t = " " * 4 + "def visit_{}(self, op: {}, network: Network):\n".format(
                    each_class.name.lower(), each_class.name)
                t += " " * 8 + inner_text
                s += t
            visitor_file.write(s)

    write_visitor(True, "")
    write_visitor(False, "Empty")
Example 8
def main():
    backend_opset_dict = {}

    for schema in defs.get_all_schemas():
        op_name = schema.name
        backend_opset_dict[op_name] = []

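    # Merge the real backend coverage for the default ONNX domain and collect
    # the partial-support notes written out below as backend_partial_support.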
    backend_onnx_coverage, backend_experimental_op = get_backend_coverage()
    backend_opset_dict.update(backend_onnx_coverage.get(defs.ONNX_DOMAIN, {}))
    backend_ps_dict = get_backend_partial_support_detail()

    with open('opset_version.py', 'w') as version_file:
        pp = pprint.PrettyPrinter(indent=4)
        version_file.write("backend_opset_version = {\n " +
                           pp.pformat(backend_opset_dict)[1:-1] + "\n}\n\n")
        version_file.write("backend_partial_support = {\n " +
                           pp.pformat(backend_ps_dict)[1:-1] + "\n}\n")
Example 9
def main():
    # type: () -> None
    base_dir = os.path.dirname(os.path.dirname(os.path.dirname(
        os.path.dirname(os.path.realpath(__file__)))))
    docs_dir = os.path.join(base_dir, 'docs')
    schemas = defs.get_all_schemas()

    ml = is_ml(schemas)
    if ml:
        fname = os.path.join(docs_dir, 'TestCoverage-ml.md')
    else:
        fname = os.path.join(docs_dir, 'TestCoverage.md')

    with open(fname, 'w+') as f:
        gen_outlines(f, ml)
        gen_node_test_coverage(schemas, f, ml)
        gen_model_test_coverage(schemas, f, ml)
        gen_overall_test_coverage(schemas, f, ml)
Example 10
def main():
    # type: () -> None
    base_dir = os.path.dirname(
        os.path.dirname(
            os.path.dirname(os.path.dirname(os.path.realpath(__file__)))))
    docs_dir = os.path.join(base_dir, 'docs')
    schemas = defs.get_all_schemas()

    ml = is_ml(schemas)
    if ml:
        fname = os.path.join(docs_dir, 'TestCoverage-ml.md')
    else:
        fname = os.path.join(docs_dir, 'TestCoverage.md')

    with open(fname, 'w+') as f:
        gen_outlines(f, ml)
        gen_node_test_coverage(schemas, f, ml)
        gen_model_test_coverage(schemas, f, ml)
        gen_overall_test_coverage(schemas, f, ml)
Example 11
def main(args):
    with io.open(args.output, 'w', newline='') as fout:
        fout.write('## Operator Schemas\n')
        fout.write(
            "*This file is automatically generated from the\n"
            "            [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).\n"
            "            Do not modify directly and instead edit operator definitions.*\n"
        )

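        # Sort by support level first, then by op name, so ops are grouped by
        # support level in both the table of contents and the doc body.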
        sorted_ops = sorted(
            (int(schema.support_level), op_type, schema)
            for (op_type, schema) in defs.get_all_schemas().items())

        fout.write('\n')

        # Table of contents
        for _, op_type, schema in sorted_ops:
            s = '* <a href="#{}">{}{}</a>\n'.format(
                op_type, support_level_str(schema.support_level), op_type)
            fout.write(s)

        fout.write('\n')

        for _, op_type, schema in sorted_ops:
            # op_type
            s = '### <a name="{}"></a><a name="{}">**{}{}**</a>\n'.format(
                op_type, op_type.lower(),
                support_level_str(schema.support_level), op_type)

            # doc
            if schema.doc:
                s += '\n'
                s += '\n'.join('  ' + line
                               for line in schema.doc.lstrip().splitlines())
                s += '\n'

            # attributes
            if schema.attributes:
                s += '\n#### Attributes\n\n'
                s += '<dl>\n'
                for _, attr in sorted(schema.attributes.items()):
                    s += '<dt><tt>{}</tt> : {}{}</dt>\n'.format(
                        attr.name, display_attr_type(attr.type),
                        ' (required)' if attr.required else '')
                    s += '<dd>{}</dd>\n'.format(attr.description)
                s += '</dl>\n'

            # inputs
            s += '\n#### Inputs'
            if schema.min_input != schema.max_input:
                s += ' ({} - {})'.format(display_number(schema.min_input),
                                         display_number(schema.max_input))
            s += '\n\n'
            if schema.inputs:
                s += '<dl>\n'
                for input in schema.inputs:
                    s += '<dt><tt>{}</tt>{} : {}</dt>\n'.format(
                        input.name, ' (optional)' if input.optional else '',
                        input.typeStr)
                    s += '<dd>{}</dd>\n'.format(input.description)
                s += '</dl>\n'

            # outputs
            s += '\n#### Outputs'
            if schema.min_output != schema.max_output:
                s += ' ({} - {})'.format(display_number(schema.min_output),
                                         display_number(schema.max_output))
            s += '\n\n'

            if schema.outputs:
                s += '<dl>\n'
                for output in schema.outputs:
                    s += '<dt><tt>{}</tt> : {}</dt>\n'.format(
                        output.name, output.typeStr)
                    s += '<dd>{}</dd>\n'.format(output.description)
                s += '</dl>\n'

            # type constraints
            s += '\n#### Type Constraints'
            s += '\n\n'
            if schema.type_constraints:
                s += '<dl>\n'
                for type_constraint in schema.type_constraints:
                    allowedTypes = type_constraint.allowed_type_strs
                    # Join the allowed type strings; an empty list yields an
                    # empty string instead of leaving allowedTypeStr undefined.
                    allowedTypeStr = ', '.join(allowedTypes)
                    s += '<dt><tt>{}</tt> : {}</dt>\n'.format(
                        type_constraint.type_param_str, allowedTypeStr)
                    s += '<dd>{}</dd>\n'.format(type_constraint.description)
                s += '</dl>\n'

            s += '\n\n'

            if op_type in SNIPPETS:
                s += '#### Examples\n\n'
                for summary, code in sorted(SNIPPETS[op_type]):
                    s += '<details>\n'
                    s += '<summary>{}</summary>\n\n'.format(summary)
                    s += '```python\n{}\n```\n\n'.format(code)
                    s += '</details>\n'
                    s += '\n\n'
            fout.write(s)
Example 12
def main(args):
    args.output.write('## Operator Schemas\n')
    args.output.write("""*This file is automatically generated from the
            [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).
            Do not modify directly and instead edit operator definitions.*\n"""
                      )

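    # Sort by support level first, then by op name, so ops are grouped by
    # support level in both the table of contents and the doc body.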
    sorted_ops = sorted(
        (int(schema.support_level), op_type, schema)
        for (op_type, schema) in defs.get_all_schemas().items())

    args.output.write('\n')

    # Table of contents
    for _, op_type, schema in sorted_ops:
        s = '* <a href="#{}">{}{}</a>\n'.format(
            op_type, support_level_str(schema.support_level), op_type)
        args.output.write(s)

    args.output.write('\n')

    for _, op_type, schema in sorted_ops:
        # op_type
        s = '### <a name="{}"></a><a name="{}">**{}{}**</a>\n'.format(
            op_type, op_type.lower(), support_level_str(schema.support_level),
            op_type)

        # doc
        if schema.doc:
            s += '\n'
            s += '\n'.join('  ' + line
                           for line in schema.doc.lstrip().splitlines())
            s += '\n'

        # attributes
        if schema.attributes:
            s += '  * **attribute**:\n'
            s += '    <dl>\n'
            for _, attr in sorted(schema.attributes.items()):
                s += '      <dt>{}</dt>\n'.format(attr.name)
                s += '      <dd>{}</dd>\n'.format(attr.description)
            s += '    </dl>\n'

        # inputs
        s += '  * **input**:'
        if schema.min_input != schema.max_input:
            s += '{} - {}'.format(display_number(schema.min_input),
                                  display_number(schema.max_input))
        s += '\n'
        if schema.input_desc:
            s += '    <dl>\n'
            for input_name, input_desc in schema.input_desc:
                s += '      <dt>{}</dt>\n'.format(input_name)
                s += '      <dd>{}</dd>\n'.format(input_desc)
            s += '    </dl>\n'

        # outputs
        s += '  * **output**:'
        if schema.min_output != schema.max_output:
            s += '{} - {}'.format(display_number(schema.min_output),
                                  display_number(schema.max_output))
        s += '\n'
        if schema.output_desc:
            s += '    <dl>\n'
            for output_name, output_desc in schema.output_desc:
                s += '      <dt>{}</dt>\n'.format(output_name)
                s += '      <dd>{}</dd>\n'.format(output_desc)
            s += '    </dl>\n'

        s += '\n\n'
        args.output.write(s)
Example 13
def get_onnx_supported_ops():
    onnx_opset_dict = {}
    for schema in defs.get_all_schemas():
        op = schema.name
        onnx_opset_dict[op] = schema.since_version
    return onnx_opset_dict
Example 14
from onnx.defs import get_all_schemas
from onnx import NodeProto, GraphProto
from google.protobuf import text_format
import onnx.helper

nodes = []
schemas = get_all_schemas()


def load_node(input_str):
    """
    Return a node
    :param input_str:
    :return:
    """
    node_proto = NodeProto()
    text_format.Parse(input_str, node_proto)
    return node_proto


# default values for each type for serialization


def convert_attr_type_to_enum(attr_value):
    """
    Pass in an attribute from OpDescriptor and
    get back out the equivalent enum value
    for conversion to an attribute proto.
    :param attr_value:  the attribute value
    :return:
    """
Example 15
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from collections import defaultdict
import os

from tabulate import tabulate  # type: ignore

import onnx
from onnx import defs, helper
from typing import Optional, Text, Set, Dict, IO

_all_schemas = defs.get_all_schemas()


class AttrCoverage(object):
    def __init__(self):  # type: () -> None
        self.name = None  # type: Optional[Text]
        self.values = set()  # type: Set[Text]

    def add(self, attr):  # type: (onnx.AttributeProto) -> None
        assert self.name in [None, attr.name]
        self.name = attr.name
        value = helper.get_attribute_value(attr)
        # Turn list into tuple so we can put it into set
        # As value can be string, don't blindly turn `collections.Iterable`
        # into tuple.
        if isinstance(value, list):
            value = tuple(value)
Example 17
def gen_support_status(docs_dir, onnx_version, onnx_path,
                       onnx_tf_release_build):
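    """Write the ONNX-Tensorflow support status table (ONNX ops vs. opset
    versions) as a Markdown file under docs_dir."""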

    # set filename
    if onnx_tf_release_build:
        # get onnx-tf version from VERSION_NUMBER file
        version_dir = os.path.dirname(
            os.path.dirname(os.path.realpath('VERSION_NUMBER')))
        version_file = os.path.join(version_dir, 'VERSION_NUMBER')
        onnx_tf_version = subprocess.check_output('cat ' + version_file,
                                                  shell=True)
        onnx_tf_version = 'v' + onnx_tf_version.decode().strip('\n')
        filename = 'support_status_' + onnx_tf_version.replace('.',
                                                               '_') + '.md'
    else:  # onnx-tf = master
        # get onnx-tf commit id
        onnx_tf_commit_id = subprocess.check_output('git rev-parse HEAD',
                                                    shell=True)
        onnx_tf_commit_id = onnx_tf_commit_id.decode().strip('\n')
        onnx_tf_version = 'Master ( commit id: {} )'.format(onnx_tf_commit_id)
        filename = 'support_status.md'

    with open(os.path.join(docs_dir, filename), 'w') as status_file:
        status_file.write('# ONNX-Tensorflow Support Status\n')
        status_file.write('|||\n')
        status_file.write('|-:|:-|\n')
        status_file.write(
            '|ONNX-Tensorflow Version|{}|\n'.format(onnx_tf_version))

        # get onnx commit id
        if onnx_version == 'master':
            onnx_commit_id = subprocess.check_output('cd ' + onnx_path +
                                                     '; git rev-parse HEAD',
                                                     shell=True)
            onnx_commit_id = onnx_commit_id.decode().strip('\n')
            status_file.write(
                '|ONNX Version|Master ( commit id: {} )|\n'.format(
                    onnx_commit_id))
        else:
            status_file.write('|ONNX Version|{}|\n'.format(onnx_version))

        # get tf_version
        status_file.write('|Tensorflow Version|v{}|\n\n'.format(
            tf.__version__))

        # display the table legend
        status_file.write('Notes:\n')
        status_file.write('* Values that are new or updated from a ')
        status_file.write('previous opset version are in bold.\n')
        status_file.write('* -: not defined in corresponding ONNX ')
        status_file.write('opset version\n')
        status_file.write('* \\*: the operator is deprecated\n')
        status_file.write('* :small_red_triangle:: not supported yet\n')
        status_file.write('* :small_orange_diamond:: partially supported\n')
        status_file.write('* the rest are all supported\n\n')

        # get all onnx ops
        onnx_ops = {}
        for schema in defs.get_all_schemas():
            if schema.domain == '':  # only get onnx ops
                onnx_ops[schema.name] = {
                    'versions': [],
                    'deprecated':
                    schema.since_version if schema.deprecated else -1
                }
        for schema in defs.get_all_schemas_with_history():
            if schema.domain == '':  # only get onnx ops
                op = onnx_ops[schema.name]
                versions = op['versions']
                versions.append(schema.since_version)

        # get all onnx-tf supported ops
        onnx_tf_ops = opset_version.backend_opset_version
        onnx_tf_ops_ps = opset_version.backend_partial_support

        # get the current opset version
        current_opset = defs.onnx_opset_version()

        # setup table header
        status_file.write('||')
        for i in range(current_opset):
            status_file.write('|')
        status_file.write('\n|:-:|')
        for i in range(current_opset):
            status_file.write(':-:|')
        status_file.write('\n|**ONNX Operator**|')
        for opset in range(1, current_opset + 1):
            status_file.write('**Opset {}**|'.format(opset))

        ops_count = len(onnx_ops)
        # fill in data for the table
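        # For each op, walk the opset columns: `lb` is the since_version that
        # introduced the current definition and `ub` the next one. A bold entry
        # marks a new/updated version, a plain entry carries it forward, and the
        # markers flag deprecated (\*), unsupported (:small_red_triangle:) and
        # partially supported (:small_orange_diamond:) ops.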
        for key, val in sorted(onnx_ops.items()):
            try:
                status_file.write('\n|{}|'.format(key))
                i = 0
                vers = val['versions']
                deprecated = val['deprecated']
                for opset in range(1, current_opset + 1):
                    if i <= len(vers) - 1:
                        lb = vers[i]
                        ub = vers[i + 1] if i < len(vers) - 1 else vers[i]
                        if opset < lb:
                            if i == 0:
                                status_file.write('-')
                        elif opset == lb:
                            status_file.write('**{}**'.format(lb))
                            if lb == deprecated:
                                status_file.write('\\*')
                            elif lb not in onnx_tf_ops[key]:
                                status_file.write(':small_red_triangle:')
                                if opset == current_opset:
                                    ops_count -= 1
                            elif key in onnx_tf_ops_ps:
                                status_file.write(':small_orange_diamond:')
                        else:  # opset > lb
                            if opset < ub:
                                status_file.write('{}'.format(lb))
                                if lb == deprecated:
                                    status_file.write('\\*')
                                elif lb not in onnx_tf_ops[key]:
                                    status_file.write(':small_red_triangle:')
                                    if opset == current_opset:
                                        ops_count -= 1
                                elif key in onnx_tf_ops_ps:
                                    status_file.write(':small_orange_diamond:')
                            elif opset == ub:
                                status_file.write('**{}**'.format(ub))
                                if ub == deprecated:
                                    status_file.write('\\*')
                                elif ub not in onnx_tf_ops[key]:
                                    status_file.write(':small_red_triangle:')
                                    if opset == current_opset:
                                        ops_count -= 1
                                elif key in onnx_tf_ops_ps:
                                    status_file.write(':small_orange_diamond:')
                                i += 1
                            else:  # opset > ub
                                status_file.write('{}'.format(ub))
                                if ub == deprecated:
                                    status_file.write('\\*')
                                elif ub not in onnx_tf_ops[key]:
                                    status_file.write(':small_red_triangle:')
                                    if opset == current_opset:
                                        ops_count -= 1
                                elif key in onnx_tf_ops_ps:
                                    status_file.write(':small_orange_diamond:')
                        status_file.write('|')
            except KeyError:
                # op defined in onnx but not in opset_version.backend_opset_version
                status_file.write(':small_red_triangle:|')

        status_file.write(
            '\n\nONNX-TF Supported Operators / ONNX Operators: {} / {}'.format(
                ops_count, len(onnx_ops)))

        # display partial support footnote
        status_file.write('\n\nNotes:\n')
        index = 1
        for key in onnx_tf_ops_ps:
            status_file.write(
                str(index) + '. ' + key + ': ' + onnx_tf_ops_ps[key] + '\n')
            index += 1
Example 18
def main():
    backend_opset_dict = {}
    frontend_opset_dict = {}
    frontend_tf_opset_dict = {}

    for schema in defs.get_all_schemas():
        op_name = schema.name
        backend_opset_dict[op_name] = []
        frontend_opset_dict[op_name] = []

    version = 1
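    # Probe backends.backend_v{N} / frontends.frontend_v{N} modules for
    # successive versions until one fails to import.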
    while True:
        try:
            backend = (importlib.import_module(
                'backends.backend_v{}'.format(version)).TensorflowBackend)
            frontend = (importlib.import_module(
                'frontends.frontend_v{}'.format(version)).TensorflowFrontend)
        except ImportError:
            # No backend/frontend module exists for this version, so stop.
            break

        # Register all tf ops in ONNX_TO_HANDLER
        tf_op_names = []
        onnx_to_handler = frontend.ONNX_TO_HANDLER.get(
            'frontend_v{}'.format(version), {})
        # for handler in frontend.ONNX_TO_HANDLER.values():
        for handler in onnx_to_handler.values():
            if isinstance(handler, list):
                tf_op_names.extend(list(map(op_name_to_lower, handler)))
            else:
                tf_op_names.append(op_name_to_lower(handler))

        for schema in defs.get_all_schemas():
            op_name = schema.name
            lower_op_name = op_name_to_lower(op_name)
            has_backend_handler = hasattr(backend, 'handle_' + lower_op_name)
            # Record only one version for trivial ops
            if has_backend_handler or (version == 1 and lower_op_name
                                       in ONNX_OP_TO_TF_OP.keys()):
                backend_opset_dict[op_name].append(version)

            # Register once if onnx op in ONNX_OP_TO_TF_OP_STR
            if version == 1 and schema.name in ONNX_OP_TO_TF_OP_STR and \
                ONNX_OP_TO_TF_OP_STR[schema.name] not in tf_op_names:
                tf_op_names.append(
                    op_name_to_lower(ONNX_OP_TO_TF_OP_STR[schema.name]))
                frontend_opset_dict[op_name].append(version)
            # Register if onnx op in ONNX_TO_HANDLER
            elif op_name in onnx_to_handler:
                frontend_opset_dict[op_name].append(version)
        for tf_op_name in tf_op_names:
            frontend_tf_opset_dict.setdefault(str(tf_op_name),
                                              []).append(version)

        version += 1

    with open('opset_version.py', 'w') as version_file:
        pp = pprint.PrettyPrinter(indent=4)
        version_file.write("backend_opset_version = {\n " +
                           pp.pformat(backend_opset_dict)[1:-1] + "\n}\n\n")
        version_file.write("frontend_opset_version = {\n " +
                           pp.pformat(frontend_opset_dict)[1:-1] + "\n}\n\n")
        version_file.write("frontend_tf_opset_version = {\n " +
                           pp.pformat(frontend_tf_opset_dict)[1:-1] + "\n}\n")