예제 #1
0
    def __call__(self, parser, raw_node, node_scope=''):
        """Build an nndct ``Node`` from *raw_node* and register it on *parser*.

        Args:
            parser: Parser state. Its ``visited_blob_tensors`` and
                ``visited_param_tensors`` lookups (raw tensor name -> tensor)
                are used to resolve inputs. Newly converted output tensors are
                stored in ``visited_blob_tensors``. Per-node entries of
                ``node_params`` and ``node_input_args`` are appended to; these
                are assumed to auto-create empty lists (defaultdict-like) —
                TODO confirm.
            raw_node: Raw node to convert.
            node_scope: Scope prefix used to build the node's full name. It is
                also passed to the tensor convertor for every output.

        Returns:
            The newly created node.
        """
        node = Node(name=get_full_name(node_scope, raw_node.name),
                    dtype=convert_dtype(raw_node.dtype),
                    idx=raw_node.idx)

        # Carry the raw metadata across unchanged.
        node.raw_kind = raw_node.kind
        node.schema = raw_node.schema
        node.is_custom_extension = raw_node.is_custom_pyop
        node.caller = raw_node.pyobj

        # Convert each output and back-link it to its producing node.
        # Record it under its raw name so later nodes can resolve it as an input.
        to_tensor = TensorConvertor()
        for raw_out in raw_node.outputs:
            out_tensor = to_tensor(node_scope, raw_out)
            out_tensor.node = node
            node.out_tensors.append(out_tensor)
            parser.visited_blob_tensors[raw_out.name] = out_tensor

        # Resolve inputs against previously produced blobs first, then against
        # parameters. Names found in neither table are skipped.
        for raw_in in raw_node.flatten_inputs:
            in_name = raw_in.name
            if in_name in parser.visited_blob_tensors:
                node.in_tensors.append(parser.visited_blob_tensors[in_name])
                continue
            if in_name not in parser.visited_param_tensors:
                continue
            param = parser.visited_param_tensors[in_name]
            parser.node_params[node].append(param)
            node.in_tensors.append(param)

        # A node with no raw inputs uses its first output as its only argument.
        # NOTE(review): presumably raises IndexError if the raw node also has
        # no outputs — confirm whether that case can occur.
        if raw_node.inputs:
            parser.node_input_args[node].extend(
                [parser.get_nndct_value(value) for value in raw_node.inputs])
        else:
            parser.node_input_args[node].append(node.out_tensors[0])

        return node
예제 #2
0
    def _convert_node(self, raw_node, scope=None):
        """Convert *raw_node* into an nndct ``Node`` owned by the current graph.

        If ``self.cur_graph`` already contains a node with the same full name,
        that existing node is returned and nothing else is converted.

        Args:
            raw_node: Raw node to convert. The method reads its ``name``,
                ``dtype``, ``source_range``, ``scope_name``, ``schema``,
                ``is_custom_pyop``, ``pyobj``, ``kind``, ``outputs``,
                ``inputs`` and ``flatten_inputs``.
            scope: Scope prefix for node and tensor full names. When ``None``,
                the name of ``self.cur_graph`` is used, so a current graph
                must be set.

        Returns:
            The newly created node (with its ``op`` built by ``_create_op``),
            or the pre-existing node of the same name.
        """
        if scope is None:
            assert self.cur_graph
            node_scope = self.cur_graph.name
        else:
            node_scope = scope

        nndct_node = Node(
            name=get_full_name(node_scope, raw_node.name),
            dtype=self.convert_dtype(raw_node.dtype),
        )
        nndct_node.source_range = raw_node.source_range
        nndct_node.scope_name = raw_node.scope_name
        # The rest of this method tolerates a missing current graph. Guard the
        # de-duplication lookup the same way so that an explicit scope with no
        # graph does not raise TypeError from ``name in None``.
        if self.cur_graph is not None and nndct_node.name in self.cur_graph:
            return self.cur_graph.node(nndct_node.name)

        # raw_node.kind is not stored on the node; it is passed to _create_op
        # below instead.
        nndct_node.schema = raw_node.schema
        nndct_node.is_custom_extension = raw_node.is_custom_pyop
        nndct_node.caller = raw_node.pyobj
        nndct_node.owning_block = self.cur_block
        nndct_node.owning_graph = self.cur_graph

        # Reuse output tensors already registered in the graph; convert the rest.
        for out in raw_node.outputs:
            full_name = get_full_name(node_scope, out.name)
            if self.cur_graph and self.cur_graph.is_tensor_in_graph(full_name):
                nndct_node.add_out_tensor(self.cur_graph.tensor(full_name))
            else:
                nndct_tensor = self._convert_tensor(out, node_scope)
                nndct_node.add_out_tensor(nndct_tensor)

        for ip in raw_node.flatten_inputs:
            full_name = get_full_name(node_scope, ip.name)
            if self.cur_graph and self.cur_graph.is_tensor_in_graph(full_name):
                nndct_node.add_in_tensor(self.cur_graph.tensor(full_name))
            elif not raw_node.outputs:
                # For Return node: an input not yet in the graph is converted
                # here. Nodes that have outputs skip unknown inputs.
                nndct_tensor = self._convert_tensor(ip, node_scope)
                nndct_node.add_in_tensor(nndct_tensor)

            # Inputs that are graph parameters are also tracked per node.
            if self.cur_graph and full_name in self.cur_graph.param_names():
                self.node_params[nndct_node].append(
                    self.cur_graph.tensor(full_name))

        # The op's arguments come from the raw inputs. When there are no raw
        # inputs, the raw outputs are used instead.
        node_input_args = [
            self.get_nndct_value(i)
            for i in (raw_node.inputs or raw_node.outputs)
        ]

        nndct_node.op = self._create_op(raw_node.kind, nndct_node,
                                        node_input_args)

        return nndct_node