Пример #1
0
def main(args):  # type: (Type[Args]) -> None
    """Write a markdown document listing the kernels registered per execution provider.

    The file at ``args.output`` is opened for writing (UTF-8, newline
    translation disabled). It contains:

    * a heading and a note that the file is generated;
    * for every provider returned by ``rtpy.get_all_opkernel_def()`` (sorted
      by name), a table grouped by operator domain and operator name. Each
      row gives the op's parameter signature(s) taken from
      ``rtpy.get_all_operator_schema()``, one registered opset version range
      (newest first) and the type constraints supported for that range.

    Args:
        args: parsed command-line options; only ``args.output`` (the
            destination path) is read here.
    """
    # Schemas and kernel defs spell the ONNX default domain as ''. That
    # domain is 'ai.onnx'; ML operators register 'ai.onnx.ml' explicitly.
    # Mapping '' to 'ai.onnx.ml' would mix both domains together.
    default_domain = 'ai.onnx'

    with io.open(args.output, 'w', newline='', encoding="utf-8") as fout:
        fout.write('## Supported Operators and Data Types\n')
        fout.write(
            "*This file is automatically generated from the\n"
            "            [def files](/onnxruntime/core/providers/cpu/cpu_execution_provider.cc) via "
            "[this script](/tools/python/gen_opkernel_doc.py).\n"
            "            Do not modify directly and instead edit operator definitions.*\n"
        )

        # "<domain>.<op name>" -> set of formatted signatures such as
        # "(*in* X:**T**, *out* Y:**T**)".
        paramdict = {}
        for schema in rtpy.get_all_operator_schema():
            domain = schema.domain or default_domain
            fullname = domain + '.' + schema.name
            parts = ['*in* {}:**{}**'.format(inp.name, inp.typeStr)
                     for inp in (schema.inputs or [])]
            parts += ['*out* {}:**{}**'.format(outp.name, outp.typeStr)
                      for outp in (schema.outputs or [])]
            paramdict.setdefault(fullname, set()).add(
                '(' + ', '.join(parts) + ')')

        # provider -> domain -> op name -> list of registered kernel defs.
        index = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        for op in rtpy.get_all_opkernel_def():
            domain = op.domain or default_domain
            index[op.provider][domain][op.op_name].append(op)

        fout.write('\n')
        for provider, domainmap in sorted(index.items()):
            fout.write('\n\n## Operators implemented by ' + provider + '\n\n')
            fout.write(
                '| Op Name | Parameters | OpSet Version | Types Supported |\n')
            fout.write(
                '|---------|------------|---------------|-----------------|\n')
            for domain, namemap in sorted(domainmap.items()):
                fout.write('|**Operator Domain:** *' + domain + '*||||\n')
                for name, ops in sorted(namemap.items()):
                    # version range -> type-constraint name -> allowed types,
                    # merged over every kernel registered for this op.
                    version_type_index = defaultdict(lambda: defaultdict(set))
                    for op in ops:
                        for tname, tclist in op.type_constraints.items():
                            version_type_index[
                                op.version_range][tname].update(tclist)

                    # Newest version range first. Only the first row carries
                    # the op name and parameters; later rows leave them blank.
                    rows = sorted(version_type_index.items(),
                                  key=lambda x: x[0],
                                  reverse=True)
                    for rowno, (version_range, typemap) in enumerate(rows):
                        if rowno == 0:
                            params = paramdict.get(domain + '.' + name, None)
                            fout.write('|' + name + '|' +
                                       format_param_strings(params) + '|')
                        else:
                            fout.write('|||')
                        fout.write(format_version_range(version_range) + '|')
                        fout.write('<br/> '.join(
                            '**' + tname + '** = ' +
                            format_type_constraints(sorted(tcset))
                            for tname, tcset in sorted(typemap.items())))
                        fout.write('|\n')

                fout.write('| |\n| |\n')
Пример #2
0
def main(output_path: pathlib.Path, provider_filter: "list[str]"):
    """Write a markdown document listing the kernels registered per execution provider.

    The output file (UTF-8, newline translation disabled) contains:

    * a heading and a note that the file is generated;
    * a table of contents linking to each selected provider;
    * per selected provider (sorted by name), a table grouped by operator
      domain and operator name. Each row gives the op's parameter
      signature(s) from ``rtpy.get_all_operator_schema()``, one registered
      opset version range (newest first) and the supported type constraints.

    Args:
        output_path: destination file.
        provider_filter: passed to ``expand_providers``. Its result is
            tested with ``provider.lower() in providers``, so it is presumably
            a collection of lower-cased provider names (TODO confirm). An
            empty result selects every provider.
    """
    providers = expand_providers(provider_filter)

    def _wanted(provider):
        # An empty filter result means "no filtering".
        return not providers or provider.lower() in providers

    with io.open(output_path, "w", newline="", encoding="utf-8") as fout:
        fout.write("## Supported Operators and Data Types\n")
        fout.write(
            "*This file is automatically generated from the registered kernels by "
            "[this script](https://github.com/microsoft/onnxruntime/blob/master/tools/python/gen_opkernel_doc.py).\n"
            "Do not modify directly.*\n\n")

        # "<domain>.<op name>" -> set of signatures, one "<br> "-separated
        # entry per input/output.
        paramdict = {}
        for schema in rtpy.get_all_operator_schema():
            # '' is the ONNX default domain, i.e. 'ai.onnx'.
            domain = schema.domain or "ai.onnx"
            fullname = domain + "." + schema.name
            parts = ["*in* {}:**{}**".format(inp.name, inp.typeStr)
                     for inp in (schema.inputs or [])]
            parts += ["*out* {}:**{}**".format(outp.name, outp.typeStr)
                      for outp in (schema.outputs or [])]
            paramdict.setdefault(fullname, set()).add("<br> ".join(parts))

        # provider -> domain -> op name -> list of registered kernel defs.
        index = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        for op in rtpy.get_all_opkernel_def():
            domain = op.domain or "ai.onnx"
            index[op.provider][domain][op.op_name].append(op)

        # TOC
        fout.write("## Execution Providers\n\n")
        for provider in sorted(index):
            if _wanted(provider):
                fout.write("- [{}](#{})\n".format(provider, provider.lower()))
        fout.write("\n---------------")

        for provider, domainmap in sorted(index.items()):
            if not _wanted(provider):
                continue

            # Anchor target for the TOC link above.
            fout.write('\n\n<a name="{}"/>\n\n'.format(provider.lower()))
            fout.write("## Operators implemented by {}\n\n".format(provider))
            fout.write(
                "| Op Name | Parameters | OpSet Version | Types Supported |\n")
            fout.write(
                "|---------|------------|---------------|-----------------|\n")
            for domain, namemap in sorted(domainmap.items()):
                fout.write("|**Operator Domain:** *" + domain + "*||||\n")
                for name, ops in sorted(namemap.items()):
                    # version range -> type-constraint name -> allowed types,
                    # merged over every kernel registered for this op.
                    version_type_index = defaultdict(lambda: defaultdict(set))
                    for op in ops:
                        for tname, tclist in op.type_constraints.items():
                            version_type_index[
                                op.version_range][tname].update(tclist)

                    # Newest version range first. Only the first row carries
                    # the op name and parameters; later rows leave them blank.
                    rows = sorted(version_type_index.items(),
                                  key=lambda x: x[0],
                                  reverse=True)
                    for rowno, (version_range, typemap) in enumerate(rows):
                        if rowno == 0:
                            params = paramdict.get(domain + "." + name, None)
                            fout.write("|" + name + "|" +
                                       format_param_strings(params) + "|")
                        else:
                            fout.write("|||")
                        fout.write(format_version_range(version_range) + "|")
                        fout.write("<br/> ".join(
                            "**" + tname + "** = " +
                            format_type_constraints(sorted(tcset))
                            for tname, tcset in sorted(typemap.items())))
                        fout.write("|\n")

                fout.write("| |\n| |\n")