def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict):
    '''
    Record the combination of type strings used by the node's first three inputs.
    The triple is stored in the order used by the kernel registration rather than in input order.
    :param node: Node from ORT format model
    :param value_name_to_typeinfo: Map of value names to TypeInfo instances
    '''
    # Inputs 0, 1 and 2 carry the T1, T2 (depth) and T3 type constraints respectively.
    t1, t2, t3 = (value_name_to_typestr(node.Inputs(idx), value_name_to_typeinfo) for idx in (0, 1, 2))

    # types in kernel registration are ordered this way: input (T1), output (T3), depth (T2)
    self._triples.add((t1, t3, t2))
def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict):
    '''
    Look up the processor registered for the node's operator and let it record the types used.
    Nodes whose operator has no processor are ignored.
    :param node: Node from ORT format model
    :param value_name_to_typeinfo: Map of value names to TypeInfo instances
    '''
    op_type = node.OpType().decode()
    # An empty domain string means the default 'ai.onnx' domain.
    op_domain = node.Domain().decode() or 'ai.onnx'

    processor = self._get_op_processor(_create_op_key(op_domain, op_type))
    if not processor:
        return

    processor.process_node(node, value_name_to_typeinfo)
def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict):
    '''
    Record the type strings used for each tracked input and output index of a node.
    :param node: Node from ORT format model
    :param value_name_to_typeinfo: Map of value names to TypeInfo instances
    :raises RuntimeError: if a tracked output index does not exist on the node
    '''
    for i, input_types in self._input_types.items():
        if i >= node.InputsLength():
            # Some operators have fewer inputs in earlier versions, where data that was an attribute
            # became an input in later versions to allow it to be dynamically provided.
            # e.g. Slice-1 had attributes for the indices, and Slice-10 moved those to be inputs.
            # A missing tracked input is therefore skipped rather than treated as an error.
            continue
        input_types.add(value_name_to_typestr(node.Inputs(i), value_name_to_typeinfo))

    for o, output_types in self._output_types.items():
        # Don't know of any ops where the number of outputs changed across versions, so require a valid length
        if o >= node.OutputsLength():
            # Output index o requires at least o + 1 outputs on the node.
            raise RuntimeError(
                'Node has {} outputs. Tracker for {} incorrectly configured as it requires {}.'
                .format(node.OutputsLength(), self.name, o + 1))
        output_types.add(value_name_to_typestr(node.Outputs(o), value_name_to_typeinfo))
def _dump_node(self, node: fbs.Node):
    '''
    Print a one-line summary of a node: index, name, domain:op type, input names and output names.
    :param node: Node from ORT format model
    '''
    optype = node.OpType().decode()
    domain = node.Domain().decode() or 'ai.onnx'  # empty domain defaults to ai.onnx

    inputs = [node.Inputs(i).decode() for i in range(node.InputsLength())]
    outputs = [node.Outputs(i).decode() for i in range(node.OutputsLength())]

    # Both name lists are bracketed; the inputs list previously lacked its closing ']'.
    print(f'{node.Index()}:{node.Name().decode()}({domain}:{optype}) '
          f'inputs=[{",".join(inputs)}] outputs=[{",".join(outputs)}]')