def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict):
    '''
    Record the type triple used by a node whose kernel registration has three type constraints.
    :param node: Node from ORT format model
    :param value_name_to_typeinfo: Map of value names to TypeInfo instances
    '''
    # Resolve the type strings of the first three inputs, in input order.
    input0, input1, input2 = (value_name_to_typestr(node.Inputs(idx), value_name_to_typeinfo)
                              for idx in range(3))
    # The kernel registration orders its types as input (T1), output (T3), depth (T2).
    # The type of input 2 fills the output (T3) slot, so the triple is (input 0, input 2, input 1).
    self._triples.add((input0, input2, input1))
    def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict):
        '''
        Process a Node and record info on the types used.

        Looks up the processor for the node's (domain, op type) key and forwards the node to it.
        Nodes with no processor are ignored.
        :param node: Node from ORT format model
        :param value_name_to_typeinfo: Map of value names to TypeInfo instances
        '''
        optype = node.OpType().decode()
        # FlatBuffers accessors return None for an absent string field, so guard before decoding.
        # An empty or missing domain defaults to ai.onnx.
        domain = (node.Domain() or b'').decode() or 'ai.onnx'

        key = _create_op_key(domain, optype)
        op_processor = self._get_op_processor(key)
        if op_processor:
            op_processor.process_node(node, value_name_to_typeinfo)
    def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict):
        '''
        Record the types used by the tracked inputs and outputs of a Node.
        :param node: Node from ORT format model
        :param value_name_to_typeinfo: Map of value names to TypeInfo instances
        :raises RuntimeError: if a tracked output index is not present on the node
        '''
        num_inputs = node.InputsLength()
        for i, seen_types in self._input_types.items():
            if i >= num_inputs:
                # Some operators have fewer inputs in earlier versions: data that was an attribute
                # became an input in later versions so it could be provided dynamically. Allow for that.
                # e.g. Slice-1 had attributes for the indices, and Slice-10 moved those to be inputs.
                continue

            seen_types.add(value_name_to_typestr(node.Inputs(i), value_name_to_typeinfo))

        num_outputs = node.OutputsLength()
        for o, seen_types in self._output_types.items():
            # Don't know of any ops where the number of outputs changed across versions, so require a valid length
            if o >= num_outputs:
                # o is a 0-based index, so the tracker needs at least o + 1 outputs.
                raise RuntimeError(
                    'Node has {} outputs. Tracker for {} incorrectly configured as it requires {}.'
                    .format(num_outputs, self.name, o + 1))

            seen_types.add(value_name_to_typestr(node.Outputs(o), value_name_to_typeinfo))
Example #4
0
    def _dump_node(self, node: fbs.Node):
        '''
        Print a one-line summary of a Node.

        The summary shows the node's index, name, domain:op type, and input and output value names.
        :param node: Node from ORT format model
        '''
        optype = node.OpType().decode()
        # FlatBuffers accessors return None for an absent string field; treat that like an empty string.
        domain = (node.Domain() or b'').decode() or 'ai.onnx'  # empty domain defaults to ai.onnx
        name = (node.Name() or b'').decode()

        inputs = [node.Inputs(i).decode() for i in range(node.InputsLength())]
        outputs = [node.Outputs(i).decode() for i in range(node.OutputsLength())]
        # Both lists are bracketed so the printed summary has balanced delimiters.
        print(f'{node.Index()}:{name}({domain}:{optype}) '
              f'inputs=[{",".join(inputs)}] outputs=[{",".join(outputs)}]')