def prepare_functions(self):
    """Parse the registration declarations into ``self._sigs``.

    Each line of ``self._reg_dec_file_path`` matching
    ``<cpp signature>; // <aten schema>`` is a candidate entry:

    * ``at::`` and ``c10::`` qualifiers are stripped from the C++ signature;
    * if the comment text contains both ``schema`` and ``compound`` it is
      decoded as JSON and its ``"schema"`` field is used as the ATen schema;
    * entries rejected by ``self.is_tensor_api`` or ``self.is_bypass_func``
      are skipped;
    * the remaining entries are parsed into ``CPPSig`` / ``AtenSig``,
      reconciled with ``self.cross_correct_sig`` and appended to
      ``self._sigs`` as ``(cpp_sig, aten_sig, cpp_func_sig, aten_func_sig)``.

    Parse failures are appended to ``self._err_info`` as
    ``(cpp_func_sig, message)`` and echoed to stderr. The full text of
    ``self._func_file_path`` is then stored in ``self._func_data``.

    Raises:
        AssertionError: if any entry failed to parse. Raised explicitly
            rather than via ``assert`` so the check also runs under
            ``python -O``.
    """
    decl_re = re.compile(r'\s*([^\s].*); //\s+(.*)')

    # ``with`` guarantees the header is closed even if json.loads or one of
    # the callbacks raises.
    with open(self._reg_dec_file_path, 'r') as reg_file:
        for line in reg_file:
            m = decl_re.match(line)
            if not m:
                continue
            cpp_func_sig = m.group(1).replace('at::', '').replace('c10::', '')
            aten_func_sig_literal = m.group(2)
            aten_func_sig = aten_func_sig_literal
            if "schema" in aten_func_sig_literal and "compound" in aten_func_sig_literal:
                aten_func_sig = json.loads(aten_func_sig_literal)["schema"]

            if not self.is_tensor_api(cpp_func_sig):
                continue

            try:
                cpp_sig = CPPSig(cpp_func_sig)
                if self.is_bypass_func(cpp_sig):
                    continue
                aten_sig = AtenSig(aten_func_sig)
                self.cross_correct_sig(cpp_sig, aten_sig)
                self._sigs.append((cpp_sig, aten_sig, cpp_func_sig, aten_func_sig))
            except Exception as e:
                # Keep going so every bad entry is reported in one run; the
                # check at the end turns any failure into a hard error.
                self._err_info.append((cpp_func_sig, str(e)))
                print('Error parsing "{}": {}'.format(cpp_func_sig, e), file=sys.stderr)

    with open(self._func_file_path, 'r') as ff:
        self._func_data = ff.read()

    print('Extracted {} functions ({} errors) from {}'.format(
        len(self._sigs), len(self._err_info), self._reg_dec_file_path),
        file=sys.stderr)
    if self._err_info:
        raise AssertionError(
            '{} registration entries failed to parse'.format(len(self._err_info)))
def prepare_functions(self):
    """Collect sparse-kernel signatures matched against the registrations.

    Inputs are read in this order:

    * ``self._sparse_dec_file_path`` (SparseCPUType.h): every line ending in
      ``);`` is taken as a C++ signature string;
    * ``self._sparse_attr_file_path`` (SparseAttrType.h): stored verbatim in
      ``self._sparse_attr_data``;
    * ``self._func_file_path`` (Functions.h): stored verbatim in
      ``self._func_data``;
    * ``self._reg_dec_file_path`` (registration declarations): lines of the
      form ``<cpp signature>; // <aten schema>``.

    For each registration entry, ``at::`` and ``c10::`` qualifiers are
    stripped from its C++ signature. If the comment text contains both
    ``schema`` and ``compound``, it is decoded as JSON and its ``"schema"``
    field is used. Entries rejected by ``self.is_tensor_api`` or
    ``self.is_bypass_func`` are skipped. Every sparse signature whose text
    equals the registration's C++ signature, ignoring spaces, is then parsed
    and appended to ``self._sigs`` as ``(sparse_sig, aten_sig,
    sparse_cpp_sig_str, aten_func_sig)``. Sparse signatures are compared
    as written; only the registration side has qualifiers stripped.

    Parse failures are appended to ``self._err_info`` as
    ``(cpp_func_sig, message)`` and echoed to stderr.

    Raises:
        AssertionError: if any registration entry failed to parse. Raised
            explicitly rather than via ``assert`` so the check also runs
            under ``python -O``.
    """
    sparse_decl_re = re.compile(r'\s*([^\s].*\));')
    reg_decl_re = re.compile(r'\s*([^\s].*); //\s+(.*)')

    # Parse SparseCPUType.h. Index the signatures by their space-free text so
    # each registration entry needs a single dict lookup instead of a scan
    # over every sparse signature. A list per key keeps duplicates and file
    # order, so matches are appended in the same order as a linear scan.
    sparse_sigs_by_key = {}
    with open(self._sparse_dec_file_path, 'r') as sparse_file:
        for line in sparse_file:
            m = sparse_decl_re.match(line)
            if m:
                sig_str = m.group(1)
                sparse_sigs_by_key.setdefault(sig_str.replace(' ', ''), []).append(sig_str)

    # Parse SparseAttrType.h
    with open(self._sparse_attr_file_path, 'r') as ff:
        self._sparse_attr_data = ff.read()

    # Parse Functions.h
    with open(self._func_file_path, 'r') as ff:
        self._func_data = ff.read()

    # Parse the registration declarations header.
    with open(self._reg_dec_file_path, 'r') as reg_file:
        for line in reg_file:
            m = reg_decl_re.match(line)
            if not m:
                continue
            cpp_func_sig = m.group(1).replace('at::', '').replace('c10::', '')
            aten_func_sig_literal = m.group(2)
            aten_func_sig = aten_func_sig_literal
            if "schema" in aten_func_sig_literal and "compound" in aten_func_sig_literal:
                aten_func_sig = json.loads(aten_func_sig_literal)["schema"]

            if not self.is_tensor_api(cpp_func_sig):
                continue

            try:
                cpp_sig = CPPSig(cpp_func_sig)
                if self.is_bypass_func(cpp_sig):
                    continue
                matches = sparse_sigs_by_key.get(cpp_func_sig.replace(' ', ''), ())
                for sparse_cpp_sig_str in matches:
                    sparse_sig = CPPSig(sparse_cpp_sig_str)
                    sparse_sig.is_tensor_member_func = self.is_tensor_member_function(sparse_sig.def_name)
                    aten_sig = AtenSig(aten_func_sig)
                    self.cross_correct_sig(sparse_sig, aten_sig)
                    self._sigs.append((sparse_sig, aten_sig, sparse_cpp_sig_str, aten_func_sig))
            except Exception as e:
                # Keep going so every bad entry is reported in one run; the
                # check at the end turns any failure into a hard error.
                self._err_info.append((cpp_func_sig, str(e)))
                print('Error parsing "{}": {}'.format(cpp_func_sig, e), file=sys.stderr)

    print('Extracted {} functions ({} errors) from {}'.format(
        len(self._sigs), len(self._err_info), self._reg_dec_file_path),
        file=sys.stderr)
    if self._err_info:
        raise AssertionError(
            '{} registration entries failed to parse'.format(len(self._err_info)))