Example #1
0
    def _parse_attribute(self, attrs):
        """
        Parse attribute value from protobuf to dict.

        Note:
            Not all types of attribute have been implemented.
            Different types have their own methods to parse value.

        Args:
            attrs (onnx.AttributeProto): onnx.AttributeProto instance.

        """
        if not attrs:
            return
        for attribute in attrs:
            self.attribute_name_list.append(attribute.name)
            type_num = attribute.type
            # get attribute value by determining its type
            # Can Convert to np.array if needed
            if type_num == ONNX_TYPE_INTS:
                self.attribute_dict[attribute.name] = attribute.ints
            elif type_num == ONNX_TYPE_FLOATS:
                self.attribute_dict[attribute.name] = attribute.floats
            elif type_num == ONNX_TYPE_STRING:
                self.attribute_dict[attribute.name] = str(attribute.s, 'utf-8')
            elif type_num == ONNX_TYPE_INT:
                self.attribute_dict[attribute.name] = attribute.i
            elif type_num == ONNX_TYPE_FLOAT:
                self.attribute_dict[attribute.name] = attribute.f
            else:
                log.warning("WARNING: Attribute %s in Node %s not parsed.",
                            attribute.name, self.node_name)
Example #2
0
    def _convert_call(self, node, matched_api_name):
        """
        Convert a call node whose API name has already been matched.

        When `matched_api_name` is in `ALL_MAPPING`, the call is rewritten with
        `self.mapping_api`. If the rewritten code differs from the original, it
        is parsed into a replacement node and the conversion is written to the
        process log. Otherwise the API is reported as unsupported.

        Args:
            node: The call node to convert; it must expose `func`, `lineno`
                and `col_offset`.
            matched_api_name (str): The API name matched for this call.

        Returns:
            The replacement node parsed from the converted code, or None when
            the API is unsupported or the converted code equals the original.
        """
        new_node = None
        code = pasta.dump(node)
        api_name = pasta.dump(node.func)
        warning_info = get_prompt_info(matched_api_name)
        if warning_info is None:
            warning_info = ''
        if matched_api_name in ALL_MAPPING:
            logger.info("Line %3d start converting API: %s", node.lineno,
                        api_name)
            new_code = self.mapping_api(node)
            if new_code != code:
                # Parse once; the first statement is needed on both paths below.
                statement = pasta.parse(new_code).body[0]
                try:
                    new_node = statement.value
                    # The first call name is everything before the first '('.
                    # partition() keeps the whole text when there is no '(',
                    # whereas slicing with find()'s -1 result would silently
                    # drop the last character of the name.
                    new_api_name = new_code.partition('(')[0]
                    detail_msg = self._get_detail_prompt_msg(node, new_node)
                    if detail_msg:
                        warning_info = detail_msg + ' ' + warning_info
                except AttributeError:
                    # Fall back to the statement itself, e.g. when it has no
                    # `value` attribute.
                    new_node = statement
                    new_api_name = new_code
                self._process_log.info(
                    node.lineno, node.col_offset, LOG_FMT_CONVERT_WITH_TIPS %
                    (api_name, new_api_name, warning_info))
        else:
            logger.warning("Line %3d: found unsupported API: %s%s",
                           node.lineno, api_name, warning_info)
            self._process_log.warning(
                node.lineno, node.col_offset,
                LOG_FMT_NOT_CONVERT % (api_name, warning_info))

        return new_node
Example #3
0
    def build_connection(self, src, tgt) -> None:
        """
        Build connection between source node and target node.

        A name may carry a ':<suffix>' part. Both nodes are looked up in
        `self._nodes_collection` by the text before the first ':', while the
        names are recorded in the successor/precursor lists exactly as given.
        An edge that is already recorded is not added again.

        The edge is skipped with a warning when it is a self-loop or when
        either end is not in `self._nodes_collection`.

        Args:
            src (str): Source node name.
            tgt (str): Target node name.

        """
        src_key = src.split(':')[0]
        tgt_key = tgt.split(':')[0]

        # Both ends resolve to the same node: a self-loop, which is ignored.
        if src_key == tgt_key:
            log.warning("Graph construct a self-loop node %s. Ignored.", src)
            return

        # An edge cannot be attached to a node the graph does not hold.
        if src_key not in self._nodes_collection or tgt_key not in self._nodes_collection:
            log.warning("Edge %s -> %s refers to a node not in graph. Ignored.",
                        src, tgt)
            return

        src_node = self._nodes_collection[src_key]
        tgt_node = self._nodes_collection[tgt_key]
        if tgt not in src_node.successor_nodes:
            src_node.successor_nodes.append(tgt)
        if src not in tgt_node.precursor_nodes:
            tgt_node.precursor_nodes.append(src)
Example #4
0
    def convert_function(self, fun_name, fun, is_forward):
        """
        Rewrite the source of one function, converting every mapped API in it.

        The source text is scanned position by position with `self.find_api`.
        An API found in `ALL_MAPPING` is rewritten through `self.convert_api`
        and scanning resumes at the position it returns. An API found in
        `ALL_UNSUPPORTED` is only reported. Both outcomes are appended to
        `self.convert_info`.

        Args:
            fun_name (str): The str of function name, used for logging.
            fun (func): The function to convert.
            is_forward (bool): If the function is defined in forward function in nn.Module in torch.

        Returns:
            dict, old code and converted code map if convert happens, else {}.
        """
        _, first_line = inspect.getsourcelines(fun)
        logger.info("Line %3d: start converting function %s()", first_line,
                    fun_name)

        code = inspect.getsource(fun)
        # Strings are immutable, so keeping a reference preserves the original.
        original_code = code

        pos = 0
        while pos < len(code):
            api_name = self.find_api(code, pos, is_forward)
            if not api_name:
                pos += 1
                continue

            # Line of the API = first line of the function + newlines before it.
            api_line = first_line + code.count('\n', 0, pos)

            if api_name in ALL_MAPPING:
                logger.info("Line %3d start converting API: %s", api_line,
                            api_name)
                # convert_api returns the new code and the position to resume at.
                code, pos = self.convert_api(code, pos, api_name)
                self.convert_info += "[Convert][Line{:3d}] {} is converted.\n".format(
                    api_line, api_name)
                continue

            if api_name in ALL_UNSUPPORTED:
                if api_name in UNSUPPORTED_WARN_INFOS:
                    warn_info = ". " + UNSUPPORTED_WARN_INFOS[api_name]
                else:
                    warn_info = ""
                logger.warning("Line %3d: found unsupported API: %s%s",
                               api_line, api_name, warn_info)
                self.convert_info += "[Unconvert][Line{:3d}] {} didn't convert{}\n".format(
                    api_line, api_name, warn_info)

            pos += 1

        if original_code != code:
            return {original_code: code}
        return {}
Example #5
0
 def decode(self, onnx_attribute_proto):
     """
     Fill `self.ms_attribute_proto` from an ONNX attribute proto.

     The name is copied, the ONNX type is mapped through
     `CONVERT_ATTRIBUTE_TYPE`, and the value is copied for the int64,
     float64, tensor and int64-list types. Any other mapped type is
     reported with a warning and only the name and dtype are set.

     Args:
         onnx_attribute_proto: ONNX attribute proto exposing `name`, `type`
             and the value fields `i`, `f`, `t` and `ints`.
     """
     self.ms_attribute_proto.name = onnx_attribute_proto.name
     ms_type = CONVERT_ATTRIBUTE_TYPE[onnx_attribute_proto.type]

     attr_value = self.ms_attribute_proto.value
     attr_value.dtype = ms_type

     if ms_type == MSDataType.DT_INT64.value:
         attr_value.int_val = onnx_attribute_proto.i
     elif ms_type == MSDataType.DT_FLOAT64.value:
         attr_value.float_val = onnx_attribute_proto.f
     elif ms_type == MSDataType.DT_TENSOR.value:
         # Decode the nested tensor first, then copy its fields across.
         tensor_decoder = TensorProto()
         tensor_decoder.decode(onnx_attribute_proto.t)
         decoded = tensor_decoder.ms_tensor_proto
         tensor_val = attr_value.tensor_val
         tensor_val.node_name = decoded.node.name
         tensor_val.tensor_content = decoded.node.tensor_content
         tensor_val.data_type = decoded.node.data_type
         for dim in decoded.dims:
             tensor_val.dims.append(dim)
     elif ms_type == MSDataType.DT_INTS64.value:
         for int_item in onnx_attribute_proto.ints:
             attr_value.int_vals.append(int_item)
     else:
         log.warning(
             "MSGraph can not supply this data type when DataType equals to %s.",
             ms_type)
Example #6
0
    def visit_Call(self, node):
        """
        Callback function when visiting a call node of the AST tree.

        The call's API is matched with `self.match_api`. A matched API found in
        `ALL_MAPPING` is rewritten through `self.mapping_api`; when the code
        changes, the new node replaces the old one in its parent and is
        remembered in `self._new_call_nodes`. Unsupported or unconverted APIs
        are reported to the process log. The children of the (possibly
        replaced) node are visited afterwards.

        Args:
            node: The call node being visited; it must expose `func`, `lineno`
                and `col_offset`.

        Raises:
            Exception: Any error raised while visiting the children is logged
                together with the original and new code, then re-raised.
        """
        code = pasta.dump(node)
        api_name = pasta.dump(node.func)

        # The parent node first call is equal to this node, skip when parent node is replaced.
        # This scenario occurs, for example, when out.view(out.size(0), -1) is first converted to
        # P.Reshape()(out, (out.size(0), -1)), will skip P.Reshape() in following visiting.
        # Access from the penultimate element in reverse order.
        for parent_node in self._stack[-2::-1]:
            if parent_node in self._new_call_nodes and pasta.dump(
                    parent_node).startswith(api_name):
                return
        # A one-element stack has no parent; indexing [-2] would raise
        # IndexError, so use None and let the replacement below be skipped.
        parent = self._stack[-2] if len(self._stack) > 1 else None
        new_node = None
        new_code = code
        matched_api_name, match_case = self.match_api(
            node.func, self._is_forward_function)
        if match_case in [
                ApiMatchingEnum.API_INFER, ApiMatchingEnum.API_MATCHED
        ]:
            warning_info = get_prompt_info(matched_api_name)
            if warning_info is None:
                warning_info = ''
            if matched_api_name in ALL_MAPPING:
                logger.info("Line %3d start converting API: %s", node.lineno,
                            api_name)
                new_code = self.mapping_api(node)
                if new_code != code:
                    # Parse once; the statement is needed on both paths below.
                    statement = pasta.parse(new_code).body[0]
                    try:
                        new_node = statement.value
                        # The first call name is everything before the first
                        # '('. partition() keeps the whole text when there is
                        # no '(', whereas slicing with find()'s -1 result
                        # would drop the last character of the name.
                        new_api_name = new_code.partition('(')[0]
                    except AttributeError:
                        # The statement has no `value`; use the statement itself.
                        new_node = statement
                        new_api_name = new_code
                    self._process_log.info(
                        node.lineno, node.col_offset,
                        LOG_FMT_CONVERT_WITH_TIPS %
                        (api_name, new_api_name, warning_info))
            else:
                logger.warning("Line %3d: found unsupported API: %s%s",
                               node.lineno, api_name, warning_info)
                self._process_log.warning(
                    node.lineno, node.col_offset,
                    LOG_FMT_NOT_CONVERT % (api_name, warning_info))

        elif match_case in [
                ApiMatchingEnum.API_STANDARD, ApiMatchingEnum.API_FOUND
        ]:
            self._process_log.warning(node.lineno, node.col_offset,
                                      LOG_FMT_NOT_CONVERT % (api_name, ''))

        if parent and new_node:
            # Carry the original position over to the new node before swapping
            # it into the tree.
            update_line_col = _LineColEditVisitor()
            update_line_col.update(new_node, node)
            pasta.ast_utils.replace_child(parent, node, new_node)
            self._new_call_nodes.append(new_node)

            # Continue the traversal on the replacement node.
            node = new_node
            self._stack[-1] = node
        try:
            self.generic_visit(node)
        except Exception:
            logger.error('original code:%s, new code:%s',
                         code,
                         new_code,
                         exc_info=True)
            raise