Exemplo n.º 1
0
    def _dump_node(self, node: fbs.Node):
        """Print a one-line, human-readable summary of ``node`` for debugging.

        The printed line has the form::

            <index>:<name>(<domain>:<op_type>) inputs=[<in0>,<in1>,...] outputs=[<out0>,...]

        :param node: Node to describe. Its string accessors are decoded with ``.decode()``.
        """
        optype = node.OpType().decode()
        # An empty domain string means the default ONNX domain.
        domain = node.Domain().decode() or 'ai.onnx'

        inputs = ','.join(node.Inputs(i).decode() for i in range(node.InputsLength()))
        outputs = ','.join(node.Outputs(i).decode() for i in range(node.OutputsLength()))
        # Both lists are bracketed; previously the inputs list was missing its closing ']'.
        print(f'{node.Index()}:{node.Name().decode()}({domain}:{optype}) '
              f'inputs=[{inputs}] outputs=[{outputs}]')

    def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict):
        """Record the type of each tracked input and output of ``node``.

        For every index this tracker is configured for (the keys of ``self._input_types`` and
        ``self._output_types``), the type string of the value at that position is obtained from
        ``value_name_to_typestr`` and added to the set stored under that index.

        :param node: Node whose input/output types are recorded.
        :param value_name_to_typeinfo: Passed through to ``value_name_to_typestr`` to resolve a
            value name to its type string.
        :raises RuntimeError: if a tracked output index is not present on ``node``.
        """
        num_inputs = node.InputsLength()
        for i, input_types in self._input_types.items():
            # Some operators have fewer inputs in earlier versions: data that was an attribute
            # became an input in a later version so it could be provided dynamically
            # (e.g. Slice-1 had attributes for the indices, Slice-10 moved them to inputs).
            # A tracked input index the node does not have is therefore skipped, not an error.
            if i < num_inputs:
                input_types.add(value_name_to_typestr(node.Inputs(i), value_name_to_typeinfo))

        num_outputs = node.OutputsLength()
        for o, output_types in self._output_types.items():
            # No known op changed its number of outputs across versions, so a missing output
            # means the tracker itself is misconfigured.
            if o >= num_outputs:
                raise RuntimeError(
                    'Node has {} outputs. Tracker for {} incorrectly configured as it requires output index {}.'
                    .format(num_outputs, self.name, o))

            output_types.add(value_name_to_typestr(node.Outputs(o), value_name_to_typeinfo))