def get_opset_status():
    """Summarise which ONNX ops have a backend handler.

    Returns a ``(rows, counts_by_domain)`` pair:

    * ``rows`` is a list of ``[domain, op, versions, notes]`` lists, one per
      direct ``BackendHandler`` subclass, ordered by the concatenation of
      domain and op name and preceded by a header row.
    * ``counts_by_domain`` maps each handler domain to
      ``[number of handlers, number of distinct op names ONNX defines there]``.
    """
    # domain -> distinct op names known to ONNX (all versions collapse to one).
    names_by_domain = {}
    for op_schema in defs.get_all_schemas_with_history():
        names_by_domain.setdefault(op_schema.domain, set()).add(op_schema.name)
    onnx_totals = {dom: len(names) for dom, names in names_by_domain.items()}

    counts_by_domain = {}
    rows = []
    for handler in BackendHandler.__subclasses__():
        handler.check_cls()
        dom = handler.DOMAIN
        if dom not in counts_by_domain:
            # Domains ONNX does not know about get a total of 0.
            counts_by_domain[dom] = [0, onnx_totals.get(dom, 0)]
        counts_by_domain[dom][0] += 1
        rows.append([
            dom,
            handler.ONNX_OP,
            handler.get_versions(),
            handler.PS_DESCRIPTION,
        ])

    rows.sort(key=lambda row: row[0] + row[1])
    header = ['Domain', 'Op', 'Versions', 'Notes']
    return [header] + rows, counts_by_domain
Exemplo n.º 2
0
def build_operator_schemas():
    """Collect the renderable ONNX schemas into a sorted, nested list.

    The result has the shape
    ``[(domain, [(support_level, [(name, latest_schema, all_versions)])])]``
    where ``all_versions`` is ordered by ``since_version`` and
    ``latest_schema`` is its last element.  Domains rejected by
    ``should_render_domain`` are dropped, and an operator name that has
    already been emitted (in any domain or support level) is not emitted
    again.
    """
    # domain -> support level -> name -> [schema]
    grouped = defaultdict(lambda: defaultdict(lambda: defaultdict(
        list)))  # type: Dict[Text, Dict[int, Dict[Text, List[OpSchema]]]]
    for op_schema in defs.get_all_schemas_with_history():
        level = int(op_schema.support_level)
        grouped[op_schema.domain][level][op_schema.name].append(op_schema)

    # [(domain, [(support_level, [(schema name, current schema, all versions schemas)])])]
    result = []  # type: List[Tuple[Text, List[Tuple[int, List[Tuple[Text, OpSchema, List[OpSchema]]]]]]]
    seen_names = set()  # type: Set[Text]
    for domain in sorted(grouped):
        if not should_render_domain(domain):
            continue

        levels = grouped[domain]
        by_level = []
        for level in sorted(levels):
            names = levels[level]
            entries = []
            for name in sorted(names):
                history = sorted(names[name],
                                 key=lambda item: item.since_version)
                latest = history[-1]
                if latest.name in seen_names:
                    continue
                seen_names.add(latest.name)
                entries.append((name, latest, history))
            by_level.append((level, entries))
        result.append((domain, by_level))
    return result
Exemplo n.º 3
0
def build_operator_schemas():
    """Collect the renderable ONNX schemas, optionally pinned to known versions.

    The result has the shape
    ``[(domain, [(support_level, [(name, chosen_schema, all_versions)])])]``
    with ``all_versions`` ordered by ``since_version``.  Domains rejected by
    ``should_render_domain`` are dropped and each operator name is emitted at
    most once.

    Behaviour depends on two names defined outside this function:

    * ``check_operation_version`` truthy: the newest schema of every operator
      is chosen, and a message is printed for operators that are missing from
      ``version_dict`` or newer than the version recorded there.
    * ``check_operation_version`` falsy: only operators listed in
      ``version_dict`` are emitted, using the schema whose ``since_version``
      equals the recorded version.

    Raises:
        SystemExit: with a non-zero status when a version recorded in
            ``version_dict`` is not available in the installed onnx.
    """
    # domain -> support level -> name -> [schema]
    index = defaultdict(lambda: defaultdict(lambda: defaultdict(
        list)))  # type: Dict[Text, Dict[int, Dict[Text, List[OpSchema]]]]
    for schema in defs.get_all_schemas_with_history():
        index[schema.domain][int(
            schema.support_level)][schema.name].append(schema)

    # Preprocess the Operator Schemas
    # [(domain, [(support_level, [(schema name, current schema, all versions schemas)])])]
    operator_schemas = list(
    )  # type: List[Tuple[Text, List[Tuple[int, List[Tuple[Text, OpSchema, List[OpSchema]]]]]]]
    existing_ops = set()  # type: Set[Text]
    for domain, _supportmap in sorted(index.items()):
        if not should_render_domain(domain):
            continue
        processed_supportmap = list()
        for _support, _namemap in sorted(_supportmap.items()):
            processed_namemap = list()
            for n, unsorted_versions in sorted(_namemap.items()):
                versions = sorted(unsorted_versions,
                                  key=lambda s: s.since_version)
                latest = versions[-1]
                if latest.name in existing_ops:
                    continue

                if check_operation_version:
                    # Generate operation of the latest version of your onnx.
                    existing_ops.add(latest.name)
                    processed_namemap.append((n, latest, versions))

                    # Report differences against version_dict.
                    if latest.name not in version_dict:
                        print("Check-operation-version: Operation {} is new  with version {}"
                              .format(latest.name, latest.since_version))
                    elif latest.since_version > version_dict[latest.name]:
                        print("Check-operation-version: Operation {}"
                              .format(latest.name) +
                              " has a newer version {} over old version {}"
                              .format(latest.since_version,
                                      version_dict[latest.name]))
                    continue

                # Generate operation according to the version in version_dict.
                if latest.name not in version_dict:
                    continue
                wanted_version = version_dict[latest.name]
                # Search newest-first for the recorded version.
                pinned = next(
                    (candidate for candidate in reversed(versions)
                     if candidate.since_version == wanted_version),
                    None)
                if pinned is None:
                    print("Your onnx installation may be too old. "
                          "The desired version for operation {} is not found.".format(
                              latest.name))
                    # This is a failure: report it through a non-zero exit
                    # status instead of the default success status.
                    sys.exit(1)
                existing_ops.add(pinned.name)
                processed_namemap.append((n, pinned, versions))
            processed_supportmap.append((_support, processed_namemap))
        operator_schemas.append((domain, processed_supportmap))
    return operator_schemas
Exemplo n.º 4
0
def getSchemas():
    """Return the non-deprecated schemas that carry the newest version of their name.

    Deprecated schemas are discarded first; of the remainder, a schema is kept
    when its ``since_version`` equals the highest ``since_version`` seen for
    the same ``name``.
    """
    live = [s for s in defs.get_all_schemas_with_history() if not s.deprecated]

    # name -> highest since_version among the non-deprecated schemas
    newest = {}
    for s in live:
        newest[s.name] = max(newest.get(s.name, s.since_version),
                             s.since_version)

    return [s for s in live if s.since_version == newest[s.name]]
Exemplo n.º 5
0
 def _build_op_version(self):
     """Fill ``self._op_versions`` from the full ONNX schema history.

     ``self._op_versions`` maps ``(domain, name)`` to the ascending list of
     distinct ``since_version`` values found for that operator.
     """
     seen = {}
     for schema in get_all_schemas_with_history():
         key = (schema.domain, schema.name)
         # A set collapses repeated versions of the same operator.
         seen.setdefault(key, set()).add(schema.since_version)
     self._op_versions = {
         key: sorted(versions) for key, versions in seen.items()
     }
Exemplo n.º 6
0
def _register_all_schemas_with_history():
    """Index every ONNX schema as name -> domain -> since_version -> schema.

    Each ONNX schema is first converted with ``OnnxOpSchema.from_onnx_schema``.
    The innermost mapping is an ``OrderedDict`` whose entries run from the
    highest ``since_version`` down to the lowest.
    """
    by_name = defaultdict(lambda: defaultdict(dict))
    for raw_schema in defs.get_all_schemas_with_history():
        wrapped = OnnxOpSchema.from_onnx_schema(raw_schema)
        by_name[wrapped.name][wrapped.domain][wrapped.since_version] = wrapped

    newest_first = defaultdict(lambda: defaultdict(OrderedDict))
    for name, by_domain in by_name.items():
        for domain, by_version in by_domain.items():
            # Negating the version number sorts the entries in descending order.
            descending = sorted(by_version.items(), key=lambda pair: -pair[0])
            newest_first[name][domain] = OrderedDict(descending)
    return newest_first
Exemplo n.º 7
0
def main(args):  # type: (Type[Args]) -> None
    """Write the operator changelog and the operator schema reference.

    Two Markdown files are produced from
    ``defs.get_all_schemas_with_history()``:

    * ``args.changelog`` -- every schema, grouped by domain and then by
      ``since_version``, each rendered with ``display_schema``.
    * ``args.output`` -- a table of contents followed by one section per
      operator showing its highest-``since_version`` schema, plus any examples
      registered for the operator in ``SNIPPETS``.

    Domains rejected by ``should_render_domain`` are left out of both files.
    Both files are opened with an explicit UTF-8 encoding so the generated
    text does not depend on the platform's default encoding.
    """
    with io.open(args.changelog, 'w', newline='', encoding='utf-8') as fout:
        fout.write('## Operator Changelog\n')
        fout.write(
            "*This file is automatically generated from the\n"
            "            [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).\n"
            "            Do not modify directly and instead edit operator definitions.*\n"
        )

        # domain -> version -> [schema]
        dv_index = defaultdict(lambda: defaultdict(list)
                               )  # type: Dict[Text, Dict[int, List[OpSchema]]]
        for schema in defs.get_all_schemas_with_history():
            dv_index[schema.domain][schema.since_version].append(schema)

        fout.write('\n')

        for domain, versionmap in sorted(dv_index.items()):
            if not should_render_domain(domain):
                continue

            # The default domain is the empty string: it gets a fixed heading
            # and its operator names carry no prefix.
            if domain:
                s = '# {}\n'.format(domain)
                domain_prefix = '{}.'.format(domain)
            else:
                s = '# ai.onnx (default)\n'
                domain_prefix = ''

            for version, unsorted_schemas in sorted(versionmap.items()):
                s += '## Version {} of {}\n'.format(version,
                                                    display_domain(domain))
                for schema in sorted(unsorted_schemas,
                                     key=lambda item: item.name):
                    name_with_ver = '{}-{}'.format(domain_prefix + schema.name,
                                                   schema.since_version)
                    s += '### <a name="{}"></a>**{}**</a>\n'.format(
                        name_with_ver, name_with_ver)
                    s += display_schema(schema, [schema])
                    s += '\n'

            # One write per domain.
            fout.write(s)

    with io.open(args.output, 'w', newline='', encoding='utf-8') as fout:
        fout.write('## Operator Schemas\n')
        fout.write(
            "*This file is automatically generated from the\n"
            "            [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).\n"
            "            Do not modify directly and instead edit operator definitions.*\n"
        )

        # domain -> support level -> name -> [schema]
        index = defaultdict(lambda: defaultdict(lambda: defaultdict(
            list)))  # type: Dict[Text, Dict[int, Dict[Text, List[OpSchema]]]]
        for schema in defs.get_all_schemas_with_history():
            index[schema.domain][int(
                schema.support_level)][schema.name].append(schema)

        fout.write('\n')

        # Table of contents
        for domain, supportmap in sorted(index.items()):
            if not should_render_domain(domain):
                continue

            if domain:
                s = '* {}\n'.format(domain)
                domain_prefix = '{}.'.format(domain)
            else:
                s = '* ai.onnx (default)\n'
                domain_prefix = ''
            fout.write(s)

            for _, namemap in sorted(supportmap.items()):
                for n, unsorted_versions in sorted(namemap.items()):
                    versions = sorted(unsorted_versions,
                                      key=lambda item: item.since_version)
                    # The entry reflects the newest version of the operator.
                    schema = versions[-1]
                    s = '  * {}<a href="#{}">{}</a>\n'.format(
                        support_level_str(schema.support_level),
                        domain_prefix + n, domain_prefix + n)
                    fout.write(s)

        fout.write('\n')

        # One section per operator.
        for domain, supportmap in sorted(index.items()):
            if not should_render_domain(domain):
                continue

            if domain:
                s = '## {}\n'.format(domain)
                domain_prefix = '{}.'.format(domain)
            else:
                s = '## ai.onnx (default)\n'
                domain_prefix = ''
            fout.write(s)

            for _support, namemap in sorted(supportmap.items()):
                for op_type, unsorted_versions in sorted(namemap.items()):
                    versions = sorted(unsorted_versions,
                                      key=lambda item: item.since_version)
                    schema = versions[-1]

                    # Two anchors are emitted: the exact name and a
                    # lower-cased variant.
                    s = '### {}<a name="{}"></a><a name="{}">**{}**</a>\n'.format(
                        support_level_str(schema.support_level),
                        domain_prefix + op_type,
                        domain_prefix + op_type.lower(),
                        domain_prefix + op_type)

                    s += display_schema(schema, versions)

                    s += '\n\n'

                    if op_type in SNIPPETS:
                        s += '#### Examples\n\n'
                        for summary, code in sorted(SNIPPETS[op_type]):
                            s += '<details>\n'
                            s += '<summary>{}</summary>\n\n'.format(summary)
                            s += '```python\n{}\n```\n\n'.format(code)
                            s += '</details>\n'
                            s += '\n\n'
                    fout.write(s)
Exemplo n.º 8
0
def main(args):  # type: (Type[Args]) -> None
    """Write the operator changelog and a one-table operator summary.

    Two Markdown files are produced from
    ``defs.get_all_schemas_with_history()``:

    * ``args.changelog`` -- every schema, grouped by domain and then by
      ``since_version``, each rendered with ``display_schema``.
    * ``args.output`` -- a single table with one row per operator (its
      highest-``since_version`` schema) listing inputs, outputs, type
      constraints and version.

    Domains rejected by ``should_render_domain`` are left out of both files,
    and an operator name is listed only once in the table.  Both files are
    written as UTF-8.
    """
    with io.open(args.changelog, 'w', newline='', encoding="utf-8") as fout:
        fout.write('## Operator Changelog\n')
        fout.write(
            "*This file is automatically generated from the\n"
            "            [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).\n"
            "            Do not modify directly and instead edit operator definitions.*\n"
        )

        # domain -> version -> [schema]
        dv_index = defaultdict(lambda: defaultdict(list)
                               )  # type: Dict[Text, Dict[int, List[OpSchema]]]
        for schema in defs.get_all_schemas_with_history():
            dv_index[schema.domain][schema.since_version].append(schema)

        fout.write('\n')

        for domain, versionmap in sorted(dv_index.items()):
            if not should_render_domain(domain):
                continue

            s = '# {}\n'.format(display_domain_short(domain))

            for version, unsorted_schemas in sorted(versionmap.items()):
                s += '## Version {} of {}\n'.format(version,
                                                    display_domain(domain))
                for schema in sorted(unsorted_schemas,
                                     key=lambda item: item.name):
                    name_with_ver = '{}-{}'.format(
                        format_name_with_domain(domain, schema.name),
                        schema.since_version)
                    s += ('### <a name="{}"></a>**{}**' +
                          (' (deprecated)' if schema.deprecated else '') +
                          '</a>\n').format(name_with_ver, name_with_ver)
                    s += display_schema(schema, [schema])
                    s += '\n'

            fout.write(s)

    with io.open(args.output, 'w', newline='', encoding="utf-8") as fout:
        fout.write('## Operator Schemas\n')
        fout.write('#Table of Contents \n')
        fout.write(
            "*This file is automatically generated from the\n"
            "            [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).\n"
            "            Do not modify directly and instead edit operator definitions.*\n"
        )
        # Table header.  The delimiter row must have exactly as many cells as
        # the header row (five), otherwise the table is not recognised.
        fout.write("|Operator |Input |Output |Type Constraint | Version |\n")
        fout.write("|-|-|-|-|-|\n")

        # domain -> support level -> name -> [schema]
        index = defaultdict(lambda: defaultdict(lambda: defaultdict(
            list)))  # type: Dict[Text, Dict[int, Dict[Text, List[OpSchema]]]]
        for schema in defs.get_all_schemas_with_history():
            index[schema.domain][int(
                schema.support_level)][schema.name].append(schema)

        # Preprocess the Operator Schemas
        # [(domain, [(support_level, [(schema name, current schema, all versions schemas)])])]
        operator_schemas = list(
        )  # type: List[Tuple[Text, List[Tuple[int, List[Tuple[Text, OpSchema, List[OpSchema]]]]]]]
        existing_ops = set()  # type: Set[Text]
        for domain, _supportmap in sorted(index.items()):
            if not should_render_domain(domain):
                continue

            processed_supportmap = list()
            for _support, _namemap in sorted(_supportmap.items()):
                processed_namemap = list()
                for n, unsorted_versions in sorted(_namemap.items()):
                    versions = sorted(unsorted_versions,
                                      key=lambda item: item.since_version)
                    schema = versions[-1]
                    if schema.name in existing_ops:
                        continue
                    existing_ops.add(schema.name)
                    processed_namemap.append((n, schema, versions))
                processed_supportmap.append((_support, processed_namemap))
            operator_schemas.append((domain, processed_supportmap))

        for domain, supportmap in operator_schemas:
            for _, namemap in supportmap:
                for op_type, schema, _versions in namemap:
                    # Input cell: optional arity range, then the descriptions.
                    input_str = ""
                    if schema.min_input < schema.max_input:
                        input_str = "({} - {}) ".format(
                            display_number(schema.min_input),
                            display_number(schema.max_input))
                    input_str += "<br>".join(
                        param.description for param in schema.inputs)

                    # Output cell, built the same way.  The output arity range
                    # belongs here, not in the input cell.
                    output_str = ""
                    if schema.min_output < schema.max_output:
                        output_str = "({} - {}) ".format(
                            display_number(schema.min_output),
                            display_number(schema.max_output))
                    output_str += "<br>".join(
                        param.description for param in schema.outputs)

                    # Type-constraint cell: allowed types directly followed by
                    # the constraint description, one entry per constraint.
                    tc_str = ""
                    if schema.type_constraints:
                        tc_str = "<br>".join(
                            ', '.join(type_constraint.allowed_type_strs) +
                            type_constraint.description
                            for type_constraint in schema.type_constraints)

                    if schema.support_level == OpSchema.SupportType.EXPERIMENTAL:
                        # No leading newline: a line break inside a cell would
                        # split the table row.
                        version_str = 'No versioning maintained for experimental ops.'
                    else:
                        version_str = str(schema.since_version)

                    op_link = '<a href="#{}">{}</a>'.format(op_type, op_type)
                    fout.write("|{}|{}|{}|{}|{}|\n".format(
                        format_name_with_domain(domain, op_link), input_str,
                        output_str, tc_str, version_str))
Exemplo n.º 9
0
 def __getattr__(self, item):
     """Forward the attribute lookup to the schema at index ``self.i``.

     The schema list is fetched again from
     ``defs.get_all_schemas_with_history()`` on every call.
     """
     schema = defs.get_all_schemas_with_history()[self.i]
     return getattr(schema, item)
Exemplo n.º 10
0
def main(args: Args) -> None:
    """Write the operator changelog and the operator schema reference.

    Two Markdown files are produced from
    ``defs.get_all_schemas_with_history()``:

    * ``args.changelog`` -- every schema, grouped by domain and then by
      ``since_version``, each rendered with ``display_schema``.
    * ``args.output`` -- per domain, a table of contents (plain operators
      first, then operators that have a function body), followed by one
      section per operator with its newest schema, any ``SNIPPETS`` examples
      and any ``SAMPLE_IMPLEMENTATIONS`` entry.

    Domains rejected by ``should_render_domain`` are left out of both files,
    and an operator name is emitted only once in ``args.output``.  Both files
    are written as UTF-8.
    """
    with io.open(args.changelog, 'w', newline='', encoding="utf-8") as fout:
        fout.write('<!--- SPDX-License-Identifier: Apache-2.0 -->\n')
        fout.write('## Operator Changelog\n')
        fout.write(
            "*This file is automatically generated from the\n"
            "            [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).\n"
            "            Do not modify directly and instead edit operator definitions.*\n"
            "\n"
            "For an operator input/output's differentiability, it can be differentiable,\n"
            "            non-differentiable, or undefined. If a variable's differentiability\n"
            "            is not specified, that variable has undefined differentiability.\n"
        )

        # domain -> version -> [schema]
        dv_index: Dict[Text, Dict[int, List[OpSchema]]] = defaultdict(
            lambda: defaultdict(list))
        for schema in defs.get_all_schemas_with_history():
            dv_index[schema.domain][schema.since_version].append(schema)

        fout.write('\n')

        for domain, versionmap in sorted(dv_index.items()):
            if not should_render_domain(domain):
                continue

            s = '# {}\n'.format(display_domain_short(domain))

            for version, unsorted_schemas in sorted(versionmap.items()):
                s += '## Version {} of {}\n'.format(version,
                                                    display_domain(domain))
                for schema in sorted(unsorted_schemas,
                                     key=lambda item: item.name):
                    name_with_ver = '{}-{}'.format(
                        format_name_with_domain(domain, schema.name),
                        schema.since_version)
                    s += ('### <a name="{}"></a>**{}**' +
                          (' (deprecated)' if schema.deprecated else '') +
                          '</a>\n').format(name_with_ver, name_with_ver)
                    s += display_schema(schema, [schema])
                    s += '\n'

            # One write per domain.
            fout.write(s)

    with io.open(args.output, 'w', newline='', encoding="utf-8") as fout:
        fout.write('<!--- SPDX-License-Identifier: Apache-2.0 -->\n')
        fout.write('## Operator Schemas\n')
        fout.write(
            "*This file is automatically generated from the\n"
            "            [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).\n"
            "            Do not modify directly and instead edit operator definitions.*\n"
            "\n"
            "For an operator input/output's differentiability, it can be differentiable,\n"
            "            non-differentiable, or undefined. If a variable's differentiability\n"
            "            is not specified, that variable has undefined differentiability.\n"
        )

        # domain -> support level -> name -> [schema]
        index: Dict[Text, Dict[int, Dict[Text, List[OpSchema]]]] = defaultdict(
            lambda: defaultdict(lambda: defaultdict(list)))
        for schema in defs.get_all_schemas_with_history():
            index[schema.domain][int(
                schema.support_level)][schema.name].append(schema)

        fout.write('\n')

        # Preprocess the Operator Schemas
        # [(domain, [(support_level, [(schema name, current schema, all versions schemas)])])]
        operator_schemas: List[Tuple[Text, List[Tuple[int, List[Tuple[
            Text, OpSchema, List[OpSchema]]]]]]] = list()
        existing_ops: Set[Text] = set()
        for domain, _supportmap in sorted(index.items()):
            if not should_render_domain(domain):
                continue

            processed_supportmap = list()
            for _support, _namemap in sorted(_supportmap.items()):
                processed_namemap = list()
                for n, unsorted_versions in sorted(_namemap.items()):
                    versions = sorted(unsorted_versions,
                                      key=lambda item: item.since_version)
                    schema = versions[-1]
                    if schema.name in existing_ops:
                        continue
                    existing_ops.add(schema.name)
                    processed_namemap.append((n, schema, versions))
                processed_supportmap.append((_support, processed_namemap))
            operator_schemas.append((domain, processed_supportmap))

        # Table of contents
        for domain, supportmap in operator_schemas:
            s = '### {}\n'.format(display_domain_short(domain))
            fout.write(s)

            fout.write('|**Operator**|**Since version**|\n')
            fout.write('|-|-|\n')

            # Operators with a function body are collected here and listed
            # after the plain operators of the domain.
            function_ops = list()
            for _, namemap in supportmap:
                for n, schema, versions in namemap:
                    if schema.has_function or schema.has_context_dependent_function:  # type: ignore
                        function_ops.append((n, schema, versions))
                        continue
                    s = '|{}<a href="#{}">{}</a>{}|{}|\n'.format(
                        support_level_str(schema.support_level),
                        format_name_with_domain(domain, n),
                        format_name_with_domain(domain, n),
                        ' (deprecated)' if schema.deprecated else '',
                        format_versions(versions))
                    fout.write(s)
            if function_ops:
                fout.write('|**Function**|**Since version**|\n')
                for n, schema, versions in function_ops:
                    s = '|{}<a href="#{}">{}</a>|{}|\n'.format(
                        support_level_str(schema.support_level),
                        format_name_with_domain(domain, n),
                        format_name_with_domain(domain, n),
                        format_versions(versions))
                    fout.write(s)

            fout.write('\n')

        fout.write('\n')

        # One section per operator.
        for domain, supportmap in operator_schemas:
            s = '## {}\n'.format(display_domain_short(domain))
            fout.write(s)

            for _, namemap in supportmap:
                for op_type, schema, versions in namemap:
                    # Two anchors are emitted: the exact name and a
                    # lower-cased variant.
                    s = ('### {}<a name="{}"></a><a name="{}">**{}**' +
                         (' (deprecated)' if schema.deprecated else '') +
                         '</a>\n').format(
                             support_level_str(schema.support_level),
                             format_name_with_domain(domain, op_type),
                             format_name_with_domain(domain, op_type.lower()),
                             format_name_with_domain(domain, op_type))

                    s += display_schema(schema, versions)

                    s += '\n\n'

                    if op_type in SNIPPETS:
                        s += '#### Examples\n\n'
                        for summary, code in sorted(SNIPPETS[op_type]):
                            s += '<details>\n'
                            s += '<summary>{}</summary>\n\n'.format(summary)
                            s += '```python\n{}\n```\n\n'.format(code)
                            s += '</details>\n'
                            s += '\n\n'
                    # Sample implementations are looked up by the lower-cased
                    # operator name.
                    if op_type.lower() in SAMPLE_IMPLEMENTATIONS:
                        s += '#### Sample Implementation\n\n'
                        s += '<details>\n'
                        s += '<summary>{}</summary>\n\n'.format(op_type)
                        s += '```python\n{}\n```\n\n'.format(
                            SAMPLE_IMPLEMENTATIONS[op_type.lower()])
                        s += '</details>\n'
                        s += '\n\n'

                    fout.write(s)
Exemplo n.º 11
0
def main(args):  # type: (Type[Args]) -> None
    with io.open(args.changelog, 'w', newline='') as fout:
        fout.write('## Operator Changelog\n')
        fout.write(
            "*This file is automatically generated from the\n"
            "            [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).\n"
            "            Do not modify directly and instead edit operator definitions.*\n"
        )

        # domain -> version -> [schema]
        dv_index = defaultdict(lambda: defaultdict(list)
                               )  # type: Dict[Text, Dict[int, List[OpSchema]]]
        for schema in defs.get_all_schemas_with_history():
            dv_index[schema.domain][schema.since_version].append(schema)

        fout.write('\n')

        for domain, versionmap in sorted(dv_index.items()):
            if not should_render_domain(domain):
                continue

            s = '# {}\n'.format(display_domain_short(domain))

            for version, unsorted_schemas in sorted(versionmap.items()):
                s += '## Version {} of {}\n'.format(version,
                                                    display_domain(domain))
                for schema in sorted(unsorted_schemas, key=lambda s: s.name):
                    name_with_ver = '{}-{}'.format(
                        format_name_with_domain(domain, schema.name),
                        schema.since_version)
                    s += ('### <a name="{}"></a>**{}**' +
                          (' (deprecated)' if schema.deprecated else '') +
                          '</a>\n').format(name_with_ver, name_with_ver)
                    s += display_schema(schema, [schema])
                    s += '\n'

            fout.write(s)

    with io.open(args.fn_changelog, 'w', newline='') as fout:
        fout.write('## Function Changelog\n')
        fout.write(
            "*This file is automatically generated from the\n"
            "            [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).\n"
            "            Do not modify directly and instead edit function definitions.*\n"
        )

        if os.getenv('ONNX_ML'):
            all_functions = defs.get_functions(ONNX_ML_DOMAIN)
        else:
            all_functions = defs.get_functions('')

        changelog_versionmap = defaultdict(
            list)  # type: Dict[int, List[FunctionProto]]
        for fn_name, functions in sorted(all_functions.items()):
            for func in functions:
                changelog_versionmap[func.since_version].append(func)

        if os.getenv('ONNX_ML'):
            s = '## {}\n'.format(ONNX_ML_DOMAIN)
            domain_display_name = ONNX_ML_DOMAIN
            domain_prefix = '{}.'.format(ONNX_ML_DOMAIN)
        else:
            s = '# ai.onnx (default)\n'
            domain_display_name = 'ai.onnx (default)'
            domain_prefix = ''
        fout.write(s)

        for version, function_list in sorted(changelog_versionmap.items()):
            s = ""
            for function in function_list:
                s += '## Version {} of domain {}\n'.format(
                    version, domain_display_name)
                name_with_ver = '{}-{}'.format(domain_prefix + fn_name,
                                               function.since_version)
                s += '### <a name="{}"></a>**{}**</a>\n'.format(
                    name_with_ver, name_with_ver)
                available_versions = [
                    func.since_version for func in all_functions[function.name]
                ]
                s += display_function(function, available_versions,
                                      domain_prefix)
                s += '\n'
            fout.write(s)

    with io.open(args.output, 'w', newline='') as fout:
        fout.write('## Operator Schemas\n')
        fout.write(
            "*This file is automatically generated from the\n"
            "            [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).\n"
            "            Do not modify directly and instead edit operator definitions.*\n"
        )

        # domain -> support level -> name -> [schema]
        index = defaultdict(lambda: defaultdict(lambda: defaultdict(
            list)))  # type: Dict[Text, Dict[int, Dict[Text, List[OpSchema]]]]
        for schema in defs.get_all_schemas_with_history():
            index[schema.domain][int(
                schema.support_level)][schema.name].append(schema)

        fout.write('\n')

        # Preprocess the Operator Schemas
        # [(domain, [(support_level, [(schema name, current schema, all versions schemas)])])]
        operator_schemas = list(
        )  # type: List[Tuple[Text, List[Tuple[int, List[Tuple[Text, OpSchema, List[OpSchema]]]]]]]
        exsting_ops = set()  # type: Set[Text]
        for domain, _supportmap in sorted(index.items()):
            if not should_render_domain(domain):
                continue

            processed_supportmap = list()
            for _support, _namemap in sorted(_supportmap.items()):
                processed_namemap = list()
                for n, unsorted_versions in sorted(_namemap.items()):
                    versions = sorted(unsorted_versions,
                                      key=lambda s: s.since_version)
                    schema = versions[-1]
                    if schema.name in exsting_ops:
                        continue
                    exsting_ops.add(schema.name)
                    processed_namemap.append((n, schema, versions))
                processed_supportmap.append((_support, processed_namemap))
            operator_schemas.append((domain, processed_supportmap))

        # Table of contents
        for domain, supportmap in operator_schemas:
            s = '* {}\n'.format(display_domain_short(domain))
            fout.write(s)

            for _, namemap in supportmap:
                for n, schema, versions in namemap:
                    s = '  * {}<a href="#{}">{}</a>\n'.format(
                        support_level_str(schema.support_level),
                        format_name_with_domain(domain, n),
                        format_name_with_domain(domain, n))
                    fout.write(s)

        fout.write('\n')

        for domain, supportmap in operator_schemas:
            s = '## {}\n'.format(display_domain_short(domain))
            fout.write(s)

            for _, namemap in supportmap:
                for op_type, schema, versions in namemap:
                    # op_type
                    s = ('### {}<a name="{}"></a><a name="{}">**{}**' +
                         (' (deprecated)' if schema.deprecated else '') +
                         '</a>\n').format(
                             support_level_str(schema.support_level),
                             format_name_with_domain(domain, op_type),
                             format_name_with_domain(domain, op_type.lower()),
                             format_name_with_domain(domain, op_type))

                    s += display_schema(schema, versions)

                    s += '\n\n'

                    if op_type in SNIPPETS:
                        s += '#### Examples\n\n'
                        for summary, code in sorted(SNIPPETS[op_type]):
                            s += '<details>\n'
                            s += '<summary>{}</summary>\n\n'.format(summary)
                            s += '```python\n{}\n```\n\n'.format(code)
                            s += '</details>\n'
                            s += '\n\n'
                    fout.write(s)

    with io.open(args.function_output, 'w', newline='') as fout:
        fout.write('## Functions\n')
        fout.write(
            "*This file is automatically generated from the\n"
            "            [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).\n"
            "            Do not modify directly and instead edit function definitions.*\n"
        )

        if os.getenv('ONNX_ML'):
            all_functions = defs.get_functions(ONNX_ML_DOMAIN)
        else:
            all_functions = defs.get_functions('')

        if all_functions:
            if os.getenv('ONNX_ML'):
                s = '## {}\n'.format(ONNX_ML_DOMAIN)
                domain_prefix = '{}.'.format(ONNX_ML_DOMAIN)
            else:
                s = '## ai.onnx (default)\n'
                domain_prefix = ''
            fout.write(s)

            existing_functions = set()  # type: Set[Text]
            for function_name, functions in sorted(all_functions.items()):
                for function in sorted(functions,
                                       key=lambda s: s.since_version,
                                       reverse=True):
                    if function.name in existing_functions:
                        continue
                    existing_functions.add(function.name)
                    s = '  * {}<a href="#{}">{}</a>\n'.format(
                        function_status_str(function.status),
                        domain_prefix + function.name,
                        domain_prefix + function.name)
                    fout.write(s)

                fout.write('\n')

            fout.write('\n\n')

            for function_name, functions in sorted(all_functions.items()):
                available_versions = [func.since_version for func in functions]
                function = sorted(functions,
                                  key=lambda s: s.since_version,
                                  reverse=True)[0]
                s = '### {}<a name="{}"></a><a name="{}">**{}**</a>\n'.format(
                    function_status_str(function.status),
                    domain_prefix + function.name,
                    domain_prefix + function.name.lower(),
                    domain_prefix + function.name)

                s += display_function(function, available_versions,
                                      domain_prefix)
                s += '\n\n'
                fout.write(s)
Exemplo n.º 12
0
def _write_status_cell(status_file, version, bold, op_name, deprecated,
                       supported_ops, partial_ops):
    """Write one table cell body: the op version followed by its marker.

    The version number is written first (bold when ``bold`` is true), then
    exactly one of: the deprecated marker, the "not supported" marker, the
    "partially supported" marker, or nothing.

    Returns True when the version was flagged as not supported, else False.
    Raises KeyError when ``op_name`` is missing from ``supported_ops``; the
    version text has already been written at that point.
    """
    status_file.write(('**{}**' if bold else '{}').format(version))
    if version == deprecated:
        status_file.write('\\*')
        return False
    if version not in supported_ops[op_name]:
        status_file.write(':small_red_triangle:')
        return True
    if op_name in partial_ops:
        status_file.write(':small_orange_diamond:')
    return False


def gen_support_status(docs_dir, onnx_version, onnx_path,
                       onnx_tf_release_build):
    """Write the ONNX-Tensorflow support status markdown file.

    The file is created in ``docs_dir``. It contains a version summary, a
    legend, one table row per default-domain ONNX operator with one column per
    opset version, a supported/total operator count, and the partial-support
    footnotes taken from ``opset_version.backend_partial_support``.

    Args:
        docs_dir: directory the markdown file is written to.
        onnx_version: ONNX version label; the literal ``'master'`` makes the
            function query the git commit id inside ``onnx_path`` instead.
        onnx_path: path of the ONNX checkout, only used when
            ``onnx_version == 'master'``.
        onnx_tf_release_build: when truthy, the onnx-tf version is read from
            the VERSION_NUMBER file and embedded in the output file name;
            otherwise the current git commit id is reported and the file is
            named ``support_status.md``.
    """

    # set filename
    if onnx_tf_release_build:
        # get onnx-tf version from VERSION_NUMBER file
        version_dir = os.path.dirname(
            os.path.dirname(os.path.realpath('VERSION_NUMBER')))
        version_file = os.path.join(version_dir, 'VERSION_NUMBER')
        onnx_tf_version = subprocess.check_output('cat ' + version_file,
                                                  shell=True)
        onnx_tf_version = 'v' + onnx_tf_version.decode().strip('\n')
        filename = 'support_status_' + onnx_tf_version.replace('.',
                                                               '_') + '.md'
    else:  # onnx-tf = master
        # get onnx-tf commit id
        onnx_tf_commit_id = subprocess.check_output('git rev-parse HEAD',
                                                    shell=True)
        onnx_tf_commit_id = onnx_tf_commit_id.decode().strip('\n')
        onnx_tf_version = 'Master ( commit id: {} )'.format(onnx_tf_commit_id)
        filename = 'support_status.md'

    with open(os.path.join(docs_dir, filename), 'w') as status_file:
        status_file.write('# ONNX-Tensorflow Support Status\n')
        status_file.write('|||\n')
        status_file.write('|-:|:-|\n')
        status_file.write(
            '|ONNX-Tensorflow Version|{}|\n'.format(onnx_tf_version))

        # get onnx commit id
        if onnx_version == 'master':
            onnx_commit_id = subprocess.check_output('cd ' + onnx_path +
                                                     '; git rev-parse HEAD',
                                                     shell=True)
            onnx_commit_id = onnx_commit_id.decode().strip('\n')
            status_file.write(
                '|ONNX Version|Master ( commit id: {} )|\n'.format(
                    onnx_commit_id))
        else:
            status_file.write('|ONNX Version|{}|\n'.format(onnx_version))

        # get tf_version
        status_file.write('|Tensorflow Version|v{}|\n\n'.format(
            tf.__version__))

        # display the table legend
        status_file.write('Notes:\n')
        status_file.write('* Values that are new or updated from a ')
        status_file.write('previous opset version are in bold.\n')
        status_file.write('* -: not defined in corresponding ONNX ')
        status_file.write('opset version\n')
        status_file.write('* \\*: the operator is deprecated\n')
        status_file.write('* :small_red_triangle:: not supported yet\n')
        status_file.write('* :small_orange_diamond:: partially supported\n')
        status_file.write('* the rest are all supported\n\n')

        # get all onnx ops: name -> {'versions': [...], 'deprecated': int}
        # 'deprecated' holds the since_version of the current schema when it
        # is deprecated, and -1 otherwise.
        onnx_ops = {}
        for schema in defs.get_all_schemas():
            if schema.domain == '':  # only get onnx ops
                onnx_ops[schema.name] = {
                    'versions': [],
                    'deprecated':
                    schema.since_version if schema.deprecated else -1
                }
        for schema in defs.get_all_schemas_with_history():
            if schema.domain == '':  # only get onnx ops
                onnx_ops[schema.name]['versions'].append(schema.since_version)

        # get all onnx-tf supported ops
        onnx_tf_ops = opset_version.backend_opset_version
        onnx_tf_ops_ps = opset_version.backend_partial_support

        # get the current opset version
        current_opset = defs.onnx_opset_version()

        # setup table header
        status_file.write('||')
        status_file.write('|' * current_opset)
        status_file.write('\n|:-:|')
        status_file.write(':-:|' * current_opset)
        status_file.write('\n|**ONNX Operator**|')
        for opset in range(1, current_opset + 1):
            status_file.write('**Opset {}**|'.format(opset))

        ops_count = len(onnx_ops)
        # fill in data for the table
        for key, val in sorted(onnx_ops.items()):
            try:
                status_file.write('\n|{}|'.format(key))
                # i indexes the op version currently in effect; it advances
                # each time the opset column reaches the next version.
                i = 0
                vers = val['versions']
                deprecated = val['deprecated']
                for opset in range(1, current_opset + 1):
                    if i <= len(vers) - 1:
                        lb = vers[i]
                        ub = vers[i + 1] if i < len(vers) - 1 else vers[i]
                        unsupported = False
                        if opset < lb:
                            if i == 0:
                                status_file.write('-')
                        elif opset == lb:
                            unsupported = _write_status_cell(
                                status_file, lb, True, key, deprecated,
                                onnx_tf_ops, onnx_tf_ops_ps)
                        elif opset < ub:
                            unsupported = _write_status_cell(
                                status_file, lb, False, key, deprecated,
                                onnx_tf_ops, onnx_tf_ops_ps)
                        elif opset == ub:
                            unsupported = _write_status_cell(
                                status_file, ub, True, key, deprecated,
                                onnx_tf_ops, onnx_tf_ops_ps)
                            i += 1
                        else:  # opset > ub
                            unsupported = _write_status_cell(
                                status_file, ub, False, key, deprecated,
                                onnx_tf_ops, onnx_tf_ops_ps)
                        # only the newest opset decides whether the op counts
                        # as supported
                        if unsupported and opset == current_opset:
                            ops_count -= 1
                        status_file.write('|')
            except KeyError:
                # ops defined in onnx but not in
                # opset_version.backend_opset_version; the rest of the row is
                # abandoned, as before.
                status_file.write(':small_red_triangle:|')

        status_file.write(
            '\n\nONNX-TF Supported Operators / ONNX Operators: {} / {}'.format(
                ops_count, len(onnx_ops)))

        # display partial support footnote
        status_file.write('\n\nNotes:\n')
        for index, key in enumerate(onnx_tf_ops_ps, 1):
            status_file.write(
                str(index) + '. ' + key + ': ' + onnx_tf_ops_ps[key] + '\n')
Exemplo n.º 13
0
def main(args):  # type: (Type[Args]) -> None
    """Render the operator changelog and the operator schema documentation.

    Two markdown files are produced from ``defs.get_all_schemas_with_history()``:
    ``args.changelog`` lists every schema grouped by domain and version, and
    ``args.output`` holds a table of contents followed by the newest schema of
    each operator, with its examples from ``SNIPPETS`` when present.
    """
    generated_notice = (
        "*This file is automatically generated from the\n"
        "            [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).\n"
        "            Do not modify directly and instead edit operator definitions.*\n"
    )

    with io.open(args.changelog, 'w', newline='') as changelog:
        changelog.write('## Operator Changelog\n')
        changelog.write(generated_notice)

        # domain -> version -> [schema]
        by_version = defaultdict(
            lambda: defaultdict(list)
        )  # type: Dict[Text, Dict[int, List[OpSchema]]]
        for op in defs.get_all_schemas_with_history():
            by_version[op.domain][op.since_version].append(op)

        changelog.write('\n')

        for domain, version_table in sorted(by_version.items()):
            if not should_render_domain(domain):
                continue

            # One write per domain: collect the pieces, then join them.
            pieces = ['# {}\n'.format(display_domain_short(domain))]
            for version, members in sorted(version_table.items()):
                pieces.append('## Version {} of {}\n'.format(
                    version, display_domain(domain)))
                for op in sorted(members, key=lambda item: item.name):
                    anchor = '{}-{}'.format(
                        format_name_with_domain(domain, op.name),
                        op.since_version)
                    pieces.append('### <a name="{}"></a>**{}**</a>\n'.format(
                        anchor, anchor))
                    pieces.append(display_schema(op, [op]))
                    pieces.append('\n')

            changelog.write(''.join(pieces))

    with io.open(args.output, 'w', newline='') as doc:
        doc.write('## Operator Schemas\n')
        doc.write(generated_notice)

        # domain -> support level -> name -> [schema]
        by_level = defaultdict(lambda: defaultdict(lambda: defaultdict(
            list)))  # type: Dict[Text, Dict[int, Dict[Text, List[OpSchema]]]]
        for op in defs.get_all_schemas_with_history():
            by_level[op.domain][int(op.support_level)][op.name].append(op)

        doc.write('\n')

        # Flatten the index into sorted nested lists:
        # [(domain, [(support_level, [(name, newest schema, all versions)])])]
        rendered = [
        ]  # type: List[Tuple[Text, List[Tuple[int, List[Tuple[Text, OpSchema, List[OpSchema]]]]]]]
        seen_names = set()  # type: Set[Text]
        for domain, level_table in sorted(by_level.items()):
            if not should_render_domain(domain):
                continue

            level_entries = []
            for level, name_table in sorted(level_table.items()):
                name_entries = []
                for op_name, history in sorted(name_table.items()):
                    ordered = sorted(history,
                                     key=lambda item: item.since_version)
                    newest = ordered[-1]
                    # A name is listed once, the first time it is met.
                    if newest.name not in seen_names:
                        seen_names.add(newest.name)
                        name_entries.append((op_name, newest, ordered))
                level_entries.append((level, name_entries))
            rendered.append((domain, level_entries))

        # Table of contents
        for domain, level_entries in rendered:
            doc.write('* {}\n'.format(display_domain_short(domain)))

            for _, name_entries in level_entries:
                for op_name, newest, _history in name_entries:
                    level_text = support_level_str(newest.support_level)
                    qualified = format_name_with_domain(domain, op_name)
                    doc.write('  * {}<a href="#{}">{}</a>\n'.format(
                        level_text, qualified, qualified))

        doc.write('\n')

        # One section per domain, one sub-section per operator
        for domain, level_entries in rendered:
            doc.write('## {}\n'.format(display_domain_short(domain)))

            for _, name_entries in level_entries:
                for op_type, newest, history in name_entries:
                    pieces = [
                        '### {}<a name="{}"></a><a name="{}">**{}**</a>\n'.format(
                            support_level_str(newest.support_level),
                            format_name_with_domain(domain, op_type),
                            format_name_with_domain(domain, op_type.lower()),
                            format_name_with_domain(domain, op_type)),
                        display_schema(newest, history),
                        '\n\n',
                    ]

                    if op_type in SNIPPETS:
                        pieces.append('#### Examples\n\n')
                        for summary, code in sorted(SNIPPETS[op_type]):
                            pieces.append('<details>\n')
                            pieces.append(
                                '<summary>{}</summary>\n\n'.format(summary))
                            pieces.append(
                                '```python\n{}\n```\n\n'.format(code))
                            pieces.append('</details>\n')
                            pieces.append('\n\n')
                    doc.write(''.join(pieces))
Exemplo n.º 14
0
def metadata():
    """Write every ONNX schema (history included), sorted by name, as JSON."""
    ordered = sorted(defs.get_all_schemas_with_history(),
                     key=lambda schema: schema.name)
    generate_json(ordered, '../src/onnx-metadata.json')
Exemplo n.º 15
0
                json_schema['type_constraints'].append({
                    'description': type_constraint.description,
                    'type_param_str': type_constraint.type_param_str,
                    'allowed_type_strs': type_constraint.allowed_type_strs
                })
        if schema.name in snippets:
            json_schema['snippets'] = []
            for summary, code in sorted(snippets[schema.name]):
                json_schema['snippets'].append({
                    'summary': summary,
                    'code': code
                })
        if schema.name in categories:
            json_schema['category'] = categories[schema.name]
        json_root.append({
            'name': schema.name,
            'schema': json_schema 
        })
    with io.open(json_file, 'w', newline='') as fout:
        json_root = json.dumps(json_root, sort_keys=True, indent=2)
        for line in json_root.splitlines():
            line = line.rstrip()
            if sys.version_info[0] < 3:
                line = unicode(line)
            fout.write(line)
            fout.write('\n')

if __name__ == '__main__':
    # Script entry point: dump all schemas, ordered by operator name.
    schemas = defs.get_all_schemas_with_history()
    schemas = sorted(schemas, key=lambda schema: schema.name)
    generate_json(schemas, '../src/onnx-operator.json')
Exemplo n.º 16
0
def main(args):  # type: (Type[Args]) -> None
    """Generate the changelog, the operator docs and the op build tables.

    Outputs, all derived from ``defs.get_all_schemas_with_history()``:

    * ``args.changelog``: markdown changelog grouped by domain and version.
    * ``args.output``: markdown operator documentation (table of contents,
      then one section per operator with examples and sample implementation
      when available).
    * ``args.tdfile``: the concatenated ``gen_schema`` result of every
      documented operator.
    * ``op_build_table.inc`` (current directory): whatever ``gen_code`` emits
      for every documented operator, wrapped in an
      ``if (OpName == "DUMMY") { ... }`` block.
    """
    with io.open(args.changelog, 'w', newline='') as fout:
        fout.write('## Operator Changelog\n')
        fout.write(
            "*This file is automatically generated from the\n"
            "            [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).\n"
            "            Do not modify directly and instead edit operator definitions.*\n"
        )

        # domain -> version -> [schema]
        dv_index = defaultdict(lambda: defaultdict(list)
                               )  # type: Dict[Text, Dict[int, List[OpSchema]]]
        for schema in defs.get_all_schemas_with_history():
            dv_index[schema.domain][schema.since_version].append(schema)

        fout.write('\n')

        for domain, versionmap in sorted(dv_index.items()):
            # progress output on stdout, printed even for skipped domains
            print("domain", domain)
            if not should_render_domain(domain):
                continue

            s = '# {}\n'.format(display_domain_short(domain))

            for version, unsorted_schemas in sorted(versionmap.items()):
                s += '## Version {} of {}\n'.format(version,
                                                    display_domain(domain))
                for schema in sorted(unsorted_schemas, key=lambda s: s.name):
                    name_with_ver = '{}-{}'.format(
                        format_name_with_domain(domain, schema.name),
                        schema.since_version)
                    s += ('### <a name="{}"></a>**{}**' +
                          (' (deprecated)' if schema.deprecated else '') +
                          '</a>\n').format(name_with_ver, name_with_ver)
                    s += display_schema(schema, [schema])
                    s += '\n'

            fout.write(s)

    with io.open(args.output, 'w', newline='', encoding="utf-8") as fout:
        fout.write('## Operator Schemas\n')
        fout.write(
            "*This file is automatically generated from the\n"
            "            [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).\n"
            "            Do not modify directly and instead edit operator definitions.*\n"
        )

        # domain -> support level -> name -> [schema]
        index = defaultdict(lambda: defaultdict(lambda: defaultdict(
            list)))  # type: Dict[Text, Dict[int, Dict[Text, List[OpSchema]]]]
        for schema in defs.get_all_schemas_with_history():
            index[schema.domain][int(
                schema.support_level)][schema.name].append(schema)

        fout.write('\n')

        # Preprocess the Operator Schemas
        # [(domain, [(support_level, [(schema name, current schema, all versions schemas)])])]
        operator_schemas = list(
        )  # type: List[Tuple[Text, List[Tuple[int, List[Tuple[Text, OpSchema, List[OpSchema]]]]]]]
        exsting_ops = set()  # type: Set[Text]
        for domain, _supportmap in sorted(index.items()):
            if not should_render_domain(domain):
                continue

            processed_supportmap = list()
            for _support, _namemap in sorted(_supportmap.items()):
                processed_namemap = list()
                for n, unsorted_versions in sorted(_namemap.items()):
                    versions = sorted(unsorted_versions,
                                      key=lambda s: s.since_version)
                    # newest version is the one that gets documented
                    schema = versions[-1]
                    if schema.name in exsting_ops:
                        continue
                    exsting_ops.add(schema.name)
                    processed_namemap.append((n, schema, versions))
                processed_supportmap.append((_support, processed_namemap))
            operator_schemas.append((domain, processed_supportmap))

        # Table of contents; operators with a registered function are
        # collected and listed in their own group after the plain ones.
        for domain, supportmap in operator_schemas:
            s = '* {}\n'.format(display_domain_short(domain))
            fout.write(s)
            function_ops = list()
            for _, namemap in supportmap:
                for n, schema, versions in namemap:
                    if schema.has_function:  # type: ignore
                        function_ops.append((n, schema, versions))
                        continue
                    s = '  * {}<a href="#{}">{}</a>\n'.format(
                        support_level_str(schema.support_level),
                        format_name_with_domain(domain, n),
                        format_name_with_domain(domain, n))
                    fout.write(s)
            if function_ops:
                fout.write('\n')
                fout.write('  **Operators with function registered:**\n')
                for n, schema, versions in function_ops:
                    s = '  * {}<a href="#{}">{}</a>\n'.format(
                        support_level_str(schema.support_level),
                        format_name_with_domain(domain, n),
                        format_name_with_domain(domain, n))
                    fout.write(s)

        fout.write('\n')

        # Both side files are context-managed so they are flushed and closed
        # even when generation fails part-way.
        with io.open(args.tdfile, 'w', newline='') as tdfile, \
                io.open('op_build_table.inc', 'w', newline='') as fefile:
            fefile.write('    ' + 'if (OpName == "DUMMY") {\n')
            for domain, supportmap in operator_schemas:
                s = '## {}\n'.format(display_domain_short(domain))
                fout.write(s)

                for _, namemap in supportmap:
                    for op_type, schema, versions in namemap:
                        # op_type
                        gen_code(schema, fefile)

                        r = gen_schema(schema)
                        tdfile.write(r)
                        s = ('### {}<a name="{}"></a><a name="{}">**{}**' +
                             (' (deprecated)' if schema.deprecated else '') +
                             '</a>\n').format(
                                 support_level_str(schema.support_level),
                                 format_name_with_domain(domain, op_type),
                                 format_name_with_domain(domain,
                                                         op_type.lower()),
                                 format_name_with_domain(domain, op_type))

                        s += display_schema(schema, versions)

                        s += '\n\n'

                        if op_type in SNIPPETS:
                            s += '#### Examples\n\n'
                            for summary, code in sorted(SNIPPETS[op_type]):
                                s += '<details>\n'
                                s += '<summary>{}</summary>\n\n'.format(
                                    summary)
                                s += '```python\n{}\n```\n\n'.format(code)
                                s += '</details>\n'
                                s += '\n\n'
                        if op_type.lower() in SAMPLE_IMPLEMENTATIONS:
                            s += '#### Sample Implementation\n\n'
                            s += '<details>\n'
                            s += '<summary>{}</summary>\n\n'.format(op_type)
                            s += '```python\n{}\n```\n\n'.format(
                                SAMPLE_IMPLEMENTATIONS[op_type.lower()])
                            s += '</details>\n'
                            s += '\n\n'

                        fout.write(s)
            fefile.write('    }')
Exemplo n.º 17
0
                    type_constraint.allowed_type_strs
                })
        if schema.name in SNIPPETS:
            json_schema['snippets'] = []
            for summary, code in sorted(SNIPPETS[schema.name]):
                json_schema['snippets'].append({
                    'summary': summary,
                    'code': code
                })
        json_root.append({"name": schema.name, "schema": json_schema})
    with io.open(file, 'w', newline='') as fout:
        json_root = json.dumps(json_root, sort_keys=True, indent=2)
        for line in json_root.splitlines():
            line = line.rstrip()
            if sys.version_info[0] < 3:
                line = unicode(line)
            fout.write(line)
            fout.write('\n')


if __name__ == '__main__':
    # Script entry point: dump all schemas, ordered by operator name.
    schemas = defs.get_all_schemas_with_history()
    schemas = sorted(schemas, key=lambda schema: schema.name)
    generate_json(schemas, '../src/onnx-operator.json')

#        print(schema.name + "|" + schema.domain + "|" + str(schema.since_version))
#    sorted_ops = sorted(
#        (int(schema.support_level), op_type, schema)
#        for (op_type, schema) in defs.get_all_schemas().items())
Exemplo n.º 18
0
                    type_constraint.description,
                    'type_param_str':
                    type_constraint.type_param_str,
                    'allowed_type_strs':
                    type_constraint.allowed_type_strs
                })
        if schema.name in snippets:
            json_schema['examples'] = []
            for summary, code in sorted(snippets[schema.name]):
                json_schema['examples'].append({
                    'summary': summary,
                    'code': code
                })
        if schema.name in categories:
            json_schema['category'] = categories[schema.name]
        json_root.append({'name': schema.name, 'schema': json_schema})
    with io.open(json_file, 'w', newline='') as fout:
        json_root = json.dumps(json_root, sort_keys=True, indent=2)
        for line in json_root.splitlines():
            line = line.rstrip()
            if sys.version_info[0] < 3:
                line = unicode(line)
            fout.write(line)
            fout.write('\n')


if __name__ == '__main__':
    # Script entry point: dump all schemas, ordered by operator name.
    schemas = sorted(defs.get_all_schemas_with_history(),
                     key=lambda schema: schema.name)
    generate_json(schemas, '../src/onnx-metadata.json')
Exemplo n.º 19
0
#!/usr/bin/env python3
#
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os.path as path
from sys import argv
from onnx import defs

# The generated module is written to opgen/onnxops.py next to this script.
out_file = path.join(
    path.dirname(path.realpath(__file__)), 'opgen', 'onnxops.py')

# Lower-cased op name -> non-deprecated schema with the highest since_version.
onnx_ops = {}
for schema in defs.get_all_schemas_with_history():
    if schema.deprecated:
        continue
    key = schema.name.lower()
    current = onnx_ops.get(key)
    if current is None or current.since_version < schema.since_version:
        onnx_ops[key] = schema

with open(out_file, 'wt') as fp:

    def write(s):
        """Emit ``s`` to the generated file without a line terminator."""
        fp.write(s)

    def writeline(s=''):
        """Emit ``s`` to the generated file followed by a newline."""
        fp.write(s + '\n')

    # Header: generated-file warning plus the command line that produced it.
    writeline('# AUTO-GENERATED CODE! - DO NOT EDIT!')
    writeline('# $ python {}'.format(' '.join(argv)))
Exemplo n.º 20
0
def metadata():
    """Write every ONNX schema (history included), sorted by name, as JSON.

    The output path is ``../src/onnx-metadata.json`` joined onto the directory
    of this file.
    """
    target = os.path.join(os.path.dirname(__file__),
                          '../src/onnx-metadata.json')
    ordered = sorted(defs.get_all_schemas_with_history(),
                     key=lambda schema: schema.name)
    generate_json(ordered, target)