def _parse_attribute(self, attrs): """ Parse attribute value from protobuf to dict. Note: Not all types of attribute have been implemented. Different types have their own methods to parse value. Args: attrs (onnx.AttributeProto): onnx.AttributeProto instance. """ if not attrs: return for attribute in attrs: self.attribute_name_list.append(attribute.name) type_num = attribute.type # get attribute value by determining its type # Can Convert to np.array if needed if type_num == ONNX_TYPE_INTS: self.attribute_dict[attribute.name] = attribute.ints elif type_num == ONNX_TYPE_FLOATS: self.attribute_dict[attribute.name] = attribute.floats elif type_num == ONNX_TYPE_STRING: self.attribute_dict[attribute.name] = str(attribute.s, 'utf-8') elif type_num == ONNX_TYPE_INT: self.attribute_dict[attribute.name] = attribute.i elif type_num == ONNX_TYPE_FLOAT: self.attribute_dict[attribute.name] = attribute.f else: log.warning("WARNING: Attribute %s in Node %s not parsed.", attribute.name, self.node_name)
def _convert_call(self, node, matched_api_name):
    """
    Convert the call node.

    Args:
        node: Call node to convert; it is dumped with pasta and handed to
            `self.mapping_api`.
        matched_api_name: Matched API name, checked against ALL_MAPPING and
            used to fetch the prompt info.

    Returns:
        The node parsed from the converted code, or None when the API is not
        in ALL_MAPPING or the mapping left the code unchanged.
    """
    original_code = pasta.dump(node)
    api_name = pasta.dump(node.func)

    prompt = get_prompt_info(matched_api_name)
    if prompt is None:
        prompt = ''

    # Unsupported API: report it and leave the node untouched.
    if matched_api_name not in ALL_MAPPING:
        logger.warning("Line %3d: found unsupported API: %s%s",
                       node.lineno, api_name, prompt)
        self._process_log.warning(
            node.lineno, node.col_offset,
            LOG_FMT_NOT_CONVERT % (api_name, prompt))
        return None

    logger.info("Line %3d start converting API: %s", node.lineno, api_name)
    converted_node = None
    converted_code = self.mapping_api(node)
    if converted_code != original_code:
        try:
            converted_node = pasta.parse(converted_code).body[0].value
            # find the first call name
            converted_api_name = converted_code[:converted_code.find('(')]
            detail = self._get_detail_prompt_msg(node, converted_node)
            if detail:
                prompt = detail + ' ' + prompt
        except AttributeError:
            # Fall back to the parsed statement itself when the steps above
            # hit a missing attribute.
            converted_node = pasta.parse(converted_code).body[0]
            converted_api_name = converted_code
        self._process_log.info(
            node.lineno, node.col_offset,
            LOG_FMT_CONVERT_WITH_TIPS % (api_name, converted_api_name, prompt))

    return converted_node
def build_connection(self, src, tgt) -> None:
    """
    Build connection between source node and target node.

    A name may carry a ':'-separated suffix (e.g. "node:0"). Only the part
    before the first ':' is used to look the node up in the node collection,
    while the full name is what gets recorded in the successor / precursor
    lists. An edge is skipped (with a warning) when either endpoint's base
    name is not a known node.

    Args:
        src (str): Source node name.
        tgt (str): Target node name.
    """
    nodes = self._nodes_collection
    src_key = src.split(':')[0]
    tgt_key = tgt.split(':')[0]

    # An edge whose endpoint is not a known node cannot be recorded; skip it
    # before touching any list so the graph is never left half-updated.
    if src_key not in nodes:
        log.warning("Graph construct a self-loop node %s. Ignored.", src)
        return
    if tgt_key not in nodes:
        log.warning("Target node %s is not in the graph. Ignored.", tgt)
        return

    successors = nodes[src_key].successor_nodes
    if tgt not in successors:
        successors.append(tgt)

    # Use the base name for both the membership test and the append; indexing
    # with the raw (possibly suffixed) target name raised KeyError.
    precursors = nodes[tgt_key].precursor_nodes
    if src not in precursors:
        precursors.append(src)
def convert_function(self, fun_name, fun, is_forward):
    """
    Convert a PyTorch function into MindSpore function.

    The function source is scanned position by position. Each API found in
    ALL_MAPPING is rewritten in place through `self.convert_api`; each API
    found in ALL_UNSUPPORTED is only reported. Both outcomes are appended to
    `self.convert_info`.

    Args:
        fun_name (str): The str of function name.
        fun (func): The function to convert.
        is_forward (bool): If the function is defined in forward function in
            nn.Module in torch.

    Returns:
        dict, old code and converted code map if convert happens, else {}.
    """
    _, first_line = inspect.getsourcelines(fun)
    logger.info("Line %3d: start converting function %s()", first_line, fun_name)

    code = inspect.getsource(fun)
    code_saved = copy.copy(code)

    pos = 0
    while pos < len(code):
        api_name = self.find_api(code, pos, is_forward)
        if not api_name:
            pos += 1
            continue

        # Line of the API = first line of the function + newlines before pos.
        api_line = first_line + code.count('\n', 0, pos)

        if api_name in ALL_MAPPING:
            logger.info("Line %3d start converting API: %s", api_line, api_name)
            # convert_api returns the rewritten code and the position to
            # resume from, so the cursor is not advanced here.
            code, pos = self.convert_api(code, pos, api_name)
            self.convert_info += "[Convert][Line{:3d}] {} is converted.\n".format(
                api_line, api_name)
            continue

        if api_name in ALL_UNSUPPORTED:
            if api_name in UNSUPPORTED_WARN_INFOS:
                warn_info = ". " + UNSUPPORTED_WARN_INFOS[api_name]
            else:
                warn_info = ""
            logger.warning("Line %3d: found unsupported API: %s%s",
                           api_line, api_name, warn_info)
            self.convert_info += "[Unconvert][Line{:3d}] {} didn't convert{}\n".format(
                api_line, api_name, warn_info)

        pos += 1

    if code_saved != code:
        return {code_saved: code}
    return {}
def decode(self, onnx_attribute_proto):
    """
    Convert the onnx attribute proto into the ms attribute proto.

    The name and the mapped data type are always copied. The value is copied
    for the int64, float64, tensor and int64-list data types; any other data
    type is only reported with a warning.

    Args:
        onnx_attribute_proto: The onnx attribute proto to read from.
    """
    target = self.ms_attribute_proto
    target.name = onnx_attribute_proto.name

    ms_type = CONVERT_ATTRIBUTE_TYPE[onnx_attribute_proto.type]
    value = target.value
    value.dtype = ms_type

    if ms_type == MSDataType.DT_INT64.value:
        value.int_val = onnx_attribute_proto.i
        return

    if ms_type == MSDataType.DT_FLOAT64.value:
        value.float_val = onnx_attribute_proto.f
        return

    if ms_type == MSDataType.DT_TENSOR.value:
        # Decode the embedded tensor first, then copy its fields over.
        tensor_decoder = TensorProto()
        tensor_decoder.decode(onnx_attribute_proto.t)
        decoded = tensor_decoder.ms_tensor_proto
        tensor_val = value.tensor_val
        tensor_val.node_name = decoded.node.name
        tensor_val.tensor_content = decoded.node.tensor_content
        tensor_val.data_type = decoded.node.data_type
        for dim in decoded.dims:
            tensor_val.dims.append(dim)
        return

    if ms_type == MSDataType.DT_INTS64.value:
        for int_item in onnx_attribute_proto.ints:
            value.int_vals.append(int_item)
        return

    log.warning(
        "MSGraph can not supply this data type when DataType equals to %s.",
        ms_type)
def visit_Call(self, node):
    """
    Callback function when visiting a call node of the AST tree.

    The called API is matched through `self.match_api`. When the match is
    inferred or exact and the API is in ALL_MAPPING, the call is rewritten via
    `self.mapping_api` and the new node replaces the old one under its parent.
    Unsupported or merely found APIs are only reported to the process log.
    Afterwards the (possibly replaced) node's children are visited.

    Args:
        node: The call node being visited.
    """
    code = pasta.dump(node)
    api_name = pasta.dump(node.func)

    # The parent node first call is equal to this node, skip when parent node is replaced.
    # This scenario occurs, for example, when out.view(out.size(0), -1) is first converted to
    # P.Reshape()(out, (out.size(0), -1)), will skip P.Reshape() in following visiting.
    # Access from the penultimate element in reverse order.
    for parent_node in self._stack[-2::-1]:
        if parent_node in self._new_call_nodes and pasta.dump(
                parent_node).startswith(api_name):
            return
    # The element just below the top of the stack is treated as the parent.
    parent = self._stack[-2]
    new_node = None
    # Stays equal to the original code unless a mapping rewrites it; it is
    # also reported if visiting the children fails below.
    new_code = code
    matched_api_name, match_case = self.match_api(
        node.func, self._is_forward_function)
    if match_case in [
            ApiMatchingEnum.API_INFER, ApiMatchingEnum.API_MATCHED
    ]:
        warning_info = get_prompt_info(matched_api_name)
        if warning_info is None:
            warning_info = ''
        if matched_api_name in ALL_MAPPING:
            logger.info("Line %3d start converting API: %s", node.lineno,
                        api_name)
            new_code = self.mapping_api(node)
            if new_code != code:
                try:
                    new_node = pasta.parse(new_code).body[0].value
                    # find the first call name
                    new_api_name = new_code[:new_code.find('(')]
                except AttributeError:
                    # The parsed statement has no `.value`; use the statement
                    # itself and report the whole new code as the API name.
                    new_node = pasta.parse(new_code).body[0]
                    new_api_name = new_code
                self._process_log.info(
                    node.lineno, node.col_offset,
                    LOG_FMT_CONVERT_WITH_TIPS %
                    (api_name, new_api_name, warning_info))
        else:
            logger.warning("Line %3d: found unsupported API: %s%s",
                           node.lineno, api_name, warning_info)
            self._process_log.warning(
                node.lineno, node.col_offset,
                LOG_FMT_NOT_CONVERT % (api_name, warning_info))
    elif match_case in [
            ApiMatchingEnum.API_STANDARD, ApiMatchingEnum.API_FOUND
    ]:
        # Recognised but not convertible: only record it in the process log.
        self._process_log.warning(node.lineno, node.col_offset,
                                  LOG_FMT_NOT_CONVERT % (api_name, ''))
    else:
        pass

    if parent and new_node:
        # Carry line/column info over to the new node before swapping it in.
        update_line_col = _LineColEditVisitor()
        update_line_col.update(new_node, node)
        pasta.ast_utils.replace_child(parent, node, new_node)
        # Remember replaced nodes so the skip check at the top can find them.
        self._new_call_nodes.append(new_node)

        node = new_node
        # Keep the top of the traversal stack pointing at the replacement.
        self._stack[-1] = node
    try:
        self.generic_visit(node)
    except Exception:
        # Log both code versions for diagnosis, then let the error propagate.
        logger.error('original code:%s, new code:%s',
                     code,
                     new_code,
                     exc_info=True)
        raise