class Asteria(): ''' Functions: 1. calculate the similarity between functions 2. calculate the similarity between asts 3. calculate the similarity between ast encodings ''' def __init__(self, checkpoint_path, model_selector, cuda=False): #cuda = True self.dh = DataHelper() # 数据库读取 self.checkpoint_path = checkpoint_path self.model_selector = model_selector self.cuda = cuda l.info("[I] Model Loading....") self.compute_app = Application(checkpoint_path, cuda=cuda, model_name=model_selector) l.info("[I] Model loaded...") def ast_encode_similarity(self, sources=[], targets=[], threshold=0): ''' :param sources:list: source asts :param targets:list: target asts :return: dict: key is function_name, value is a dict :{'rank':[], 'info':(function_name, elf_path, elf_file_name, caller, callee, ast_encode)} ''' result = defaultdict(dict) for (function_name, elf_path, elf_name, scaller, scallee, ast_encode), _ in tqdm(sources): res = [] pool = Pool(processes=cpu_count() - 2) for (tfunction_name, telf_path, telf_name, tcaller, tcallee, tast_encode), _ in targets: if tast_encode is None: print("%s encode not exits" % tfunction_name) res.append((pool.apply_async( self.compute_app.similarity_treeencoding_with_correction, (json.loads(ast_encode), json.loads(tast_encode), (scaller, scallee), (tcaller, tcallee))), tfunction_name, telf_path, telf_name)) pool.close() pool.join() similarity_list = [] for r in res: sim = r[0].get() if sim >= threshold: similarity_list.append(((r[1], r[2]), sim)) similarity_list.sort(key=lambda x: x[1], reverse=True) # 排序 result[function_name]['rank'] = similarity_list result[function_name]['info'] = (function_name, elf_path, telf_name) return result def prefilter(self, ast1, ast2): ''' :param ast1: :param ast2: :return: if ast1 and ast2 are too different , return 1. ''' c1 = ast1.num_children c2 = ast2.num_children if abs(c1 - c2) > 30: return 1 if c1 / c2 > 3 or c2 / c1 > 3: return 1 return 0 def ast_similarity(self, sources=[], targets=[], astfilter=None, threshold=0): ''' :param sources: list: source asts :param targets: list: target asts func_info:[function_name, elf_path, elf_file_name, caller, callee, ast_encode] :param astfilter: a filter function to filter out ast pairs which are too different. :return: dict: key {'rank':[], 'info':(function_name, elf_path, elf_file_name, caller, callee, ast_encode)} ''' result = {} N = len(sources) i = 0 TN = len(targets) astfilter = self.prefilter if astfilter: l.error("Filter Function is applied.") for s_func_info, s_ast in sources: i += 1 result[s_func_info[0]] = {'rank': '', 'info': ''} res = [] with tqdm(targets, desc="[%d] of %d" % (i, N), dynamic_ncols=True) as t: for func_info, t_ast in t: if astfilter and astfilter(s_ast, t_ast): res.append([func_info, 0]) else: res.append([ func_info, self.compute_app.similarity_tree_with_correction( s_ast, t_ast, [s_func_info[-3], s_func_info[-2]], [func_info[-3], func_info[-2]]) ]) res = list(filter(lambda x: x[1] > threshold, res)) res.sort(key=lambda x: x[1], reverse=True) # 排序 result[s_func_info[0]]['rank'] = res result[s_func_info[0]]['info'] = s_func_info return result def db_similarity(self, source_db, target_db, ast, threshold, start=-1, end=-1): ''' :param source_db: aught to be vulnerability database path :param target_db: firmware function database :param ast: True:直接使用ast进行计算相似度;False,使用ast的编码之后的向量进行相似度计算 :param threshold: float: 0~1 :param start/end: the position for select in sql limit :return: ''' source_asts = [] target_asts = [] elf_names = set() where_suffix = " limit 0,20" # the number of vulnerability functions does not exceeds 100 for func in list( self.dh.get_functions(source_db, where_suffix=where_suffix)): # limit vul function number source_asts.append(func) elf_names.add("'" + func[0][2].split('.')[0] + "%'") elf_files = " or ".join(elf_names) # where_suffix = " where elf_file_name like %s" % elf_files #l.info("[DB] the firmware select filter is %s" % where_suffix) where_suffix = "" for func in self.dh.get_functions(target_db, start=start, end=end, where_suffix=where_suffix): target_asts.append(func) if ast: return self.ast_similarity(source_asts, target_asts, threshold=threshold) else: return self.ast_encode_similarity(source_asts, target_asts, threshold=threshold)