def __call__(self, parser, raw_node, node_scope=''):
  """Build an NNDCT ``Node`` from ``raw_node`` and record it on ``parser``.

  Each raw output is converted with a fresh ``TensorConvertor``, attached to
  the new node, and published in ``parser.visited_blob_tensors`` under the
  raw output's name. Each flattened raw input is resolved against tensors
  the parser has already visited:

  * a match in ``parser.visited_blob_tensors`` takes precedence;
  * otherwise a match in ``parser.visited_param_tensors`` is used and also
    appended to ``parser.node_params[node]``;
  * inputs found in neither map are skipped.

  Finally, ``parser.node_input_args[node]`` is extended with
  ``parser.get_nndct_value(...)`` of every entry in ``raw_node.inputs``. If
  ``raw_node.inputs`` is empty, the node's first output tensor is appended
  instead.

  Args:
    parser: Parser state holding the visited-tensor maps and the per-node
      ``node_params`` / ``node_input_args`` bookkeeping.
    raw_node: The raw node to convert.
    node_scope: Scope prefix combined with ``raw_node.name`` (via
      ``get_full_name``) to form the node's name; also passed to the tensor
      convertor.

  Returns:
    The newly created ``Node``.
  """
  node = Node(
      name=get_full_name(node_scope, raw_node.name),
      dtype=convert_dtype(raw_node.dtype),
      idx=raw_node.idx)
  node.raw_kind = raw_node.kind
  node.schema = raw_node.schema
  node.is_custom_extension = raw_node.is_custom_pyop
  node.caller = raw_node.pyobj

  # Outputs: convert each one and publish it so later nodes can consume it.
  to_tensor = TensorConvertor()
  for raw_out in raw_node.outputs:
    tensor = to_tensor(node_scope, raw_out)
    tensor.node = node
    node.out_tensors.append(tensor)
    parser.visited_blob_tensors[raw_out.name] = tensor

  # Inputs: blob tensors win over parameter tensors; unknown names are ignored.
  for raw_in in raw_node.flatten_inputs:
    key = raw_in.name
    if key in parser.visited_blob_tensors:
      node.in_tensors.append(parser.visited_blob_tensors[key])
      continue
    if key in parser.visited_param_tensors:
      param = parser.visited_param_tensors[key]
      parser.node_params[node].append(param)
      node.in_tensors.append(param)

  if raw_node.inputs:
    parser.node_input_args[node].extend(
        [parser.get_nndct_value(arg) for arg in raw_node.inputs])
  else:
    # No explicit inputs: fall back to the node's first output tensor.
    parser.node_input_args[node].append(node.out_tensors[0])
  return node
def _convert_node(self, raw_node, scope=None):
  """Convert ``raw_node`` into an NNDCT ``Node`` owned by the current graph.

  If a node with the same full name is already in ``self.cur_graph``, that
  existing node is returned and nothing else is built. Otherwise a new node
  is created:

  * its output tensors are reused from the graph when already present, or
    converted with ``self._convert_tensor`` otherwise;
  * its input tensors are taken from the graph when present. For a node
    with no outputs (the Return-node case), inputs missing from the graph
    are converted with ``self._convert_tensor``;
  * inputs whose full name is in ``self.cur_graph.param_names()`` are also
    appended to ``self.node_params[node]``;
  * its op is built by ``self._create_op`` from ``raw_node.kind`` and the
    values of ``raw_node.inputs``, or of ``raw_node.outputs`` when the raw
    node has no inputs.

  Args:
    raw_node: The raw node to convert.
    scope: Name scope used to build full node and tensor names. Defaults to
      ``self.cur_graph.name``, which requires a current graph.

  Returns:
    The existing graph node with the same name, or the newly built node.

  Raises:
    AssertionError: If ``scope`` is None and there is no current graph.
  """
  if scope is None:
    assert self.cur_graph
    node_scope = self.cur_graph.name
  else:
    node_scope = scope

  nndct_node = Node(
      name=get_full_name(node_scope, raw_node.name),
      dtype=self.convert_dtype(raw_node.dtype),
  )
  nndct_node.source_range = raw_node.source_range
  nndct_node.scope_name = raw_node.scope_name

  # Reuse an already-converted node of the same name. Guard on cur_graph like
  # the tensor lookups below, so an explicit ``scope`` with no current graph
  # does not raise TypeError from ``name in None``.
  if self.cur_graph and nndct_node.name in self.cur_graph:
    return self.cur_graph.node(nndct_node.name)

  nndct_node.schema = raw_node.schema
  nndct_node.is_custom_extension = raw_node.is_custom_pyop
  nndct_node.caller = raw_node.pyobj
  nndct_node.owning_block = self.cur_block
  nndct_node.owning_graph = self.cur_graph

  for out in raw_node.outputs:
    full_name = get_full_name(node_scope, out.name)
    if self.cur_graph and self.cur_graph.is_tensor_in_graph(full_name):
      nndct_node.add_out_tensor(self.cur_graph.tensor(full_name))
    else:
      nndct_node.add_out_tensor(self._convert_tensor(out, node_scope))

  for ip in raw_node.flatten_inputs:
    full_name = get_full_name(node_scope, ip.name)
    if self.cur_graph and self.cur_graph.is_tensor_in_graph(full_name):
      nndct_node.add_in_tensor(self.cur_graph.tensor(full_name))
    elif not raw_node.outputs:
      # For Return node: its inputs may not be in the graph yet, so convert them.
      nndct_node.add_in_tensor(self._convert_tensor(ip, node_scope))
    # presumably self.node_params is a defaultdict(list) — TODO confirm
    if self.cur_graph and full_name in self.cur_graph.param_names():
      self.node_params[nndct_node].append(self.cur_graph.tensor(full_name))

  # Op construction arguments: the node's inputs, or its outputs if it has none.
  arg_source = raw_node.inputs if raw_node.inputs else raw_node.outputs
  node_input_args = [self.get_nndct_value(value) for value in arg_source]
  nndct_node.op = self._create_op(raw_node.kind, nndct_node, node_input_args)
  return nndct_node