Пример #1
0
    def visit(self, root: ast.AST):
        """Entry point for visiting an AST node of any type.

        While ``self.parent_ok`` is false, a child -> parent table is first
        built for the whole tree below ``root`` and the flag is then set.
        Afterwards the node is dispatched through the parent class's ``visit``
        unless it already carries a ``visited`` attribute; the attribute is
        set once the dispatch has finished.
        """
        if not self.parent_ok:
            # Node kinds left out of the parent table (presumably because such
            # operator/context nodes can be shared by several parents).
            skipped_kinds = (ast.expr_context, ast.boolop, ast.unaryop,
                             ast.cmpop, ast.operator)
            for parent in ast.walk(root):
                for child in ast.iter_child_nodes(parent):
                    if isinstance(child, skipped_kinds):
                        continue
                    # Register the parent on first sight, otherwise fetch the
                    # one recorded earlier.
                    recorded = self.parent_table.setdefault(child, parent)
                    try:
                        # Every child is expected to have exactly one parent
                        assert recorded == parent
                    except AssertionError:
                        print("CHILD: {}".format(child))
                        print("PARENT: {}".format(recorded))
                        print("NODE: {}".format(parent))
                        raise
            logger.debug("Build parent table OK")
            self.parent_ok = True

        # Nodes that were already visited are skipped
        if hasattr(root, 'visited'):
            return
        super().visit(root)
        root.visited = True
Пример #2
0
    def visit_Call(self, node: ast.Call):
        """Entry point for AST ``Call`` nodes.

        Focus: whether the call involves a string constant that is related to
        one of the tracked inputs. When both a related string constant and a
        related input name are set after the children have been visited, the
        constant is added to that input's entry in ``self.str_input_maps[-1]``.
        """
        self.in_call = True

        # If the call has the form ``<name>.<something>(...)``, look up which
        # input the receiver name belongs to.
        self.related_input_name = None
        callee = node.func
        if isinstance(getattr(callee, 'value', None), ast.Name):
            receiver = callee.value.id
            for input_name, aliases in self.var_input_maps[-1].items():
                # No early exit: the last matching input wins
                if receiver in aliases:
                    self.related_input_name = input_name
        self.related_str_constant = None
        self.generic_visit(node)

        str_const = self.related_str_constant
        input_name = self.related_input_name
        if str_const and input_name:
            logger.debug(
                "   call: str constant '{}' related to input '{}'".format(
                    str_const, input_name))
            self.str_input_maps[-1][input_name].add(str_const)

        self.in_call = False
Пример #3
0
 def _visit(target_node: ast.AST):
     """Visit ``target_node`` and link the assignment targets to an input.

     ``self.related_input_name`` is reset before the visit. If it is set
     afterwards and ``assign_targets_ids`` (taken from the enclosing scope)
     is non-empty, those ids are merged into the variable set of that input
     in ``self.var_input_maps[-1]``.
     """
     self.related_input_name = None
     self.visit(target_node)

     input_name = self.related_input_name
     if not input_name or len(assign_targets_ids) == 0:
         return

     logger.debug(
         "Assign targets {} may related to input '{}'".format(
             assign_targets_ids, input_name))
     self.var_input_maps[-1][input_name] |= set(assign_targets_ids)
Пример #4
0
    def collect(self, times=300, corpus=None):
        """Run the target program repeatedly to collect initial test cases.

        The data pool is reset and seeded with ``corpus``; loading the entry
        file via ``load_code`` may add further strings. Each round picks a
        random pool entry, runs the program with it and removes it from the
        pool. Inputs that did not set ``self.quitting`` are moved to
        ``self.used_data_pool``; the others are reported as crashes.

        Coverage is stopped and saved, and the trace functions that were
        active on entry are restored, even if a run raises.

        :param times: maximum number of runs, 300 by default
        :param corpus: optional iterable of user supplied initial test cases
        """
        self.data_pool.clear()
        if corpus is not None:
            self.data_pool.extend(corpus)

        # Static analysis stage: load the entry program (see ``load_code``)
        self.curr_code, self.curr_ast, self.curr_sde_visitor = self.load_code(self.entry_path)

        # Always have at least one input to start with
        if len(self.data_pool) < 1:
            self.data_pool.append('a')
        logger.debug("Initial data_pool = {}".format(self.data_pool))

        # Remember the tracer installed before us, then the one that is
        # installed once coverage has been started.
        self.tracer_debugger = sys.gettrace()
        self.cov.start()
        self.tracer_cov = sys.gettrace()

        sys.settrace(self.trace_dispatch)
        threading.settrace(self.trace_dispatch)

        cnt = 0
        try:
            while cnt < times and len(self.data_pool) > 0:
                selected_idx = random.randrange(len(self.data_pool))
                selected_ipt = self.data_pool[selected_idx]
                logger.info("Using input [{}] to run the test program".format(selected_ipt))
                self.run_once(selected_ipt)

                # Run first, then drop the input that was just used so it is
                # not selected again.
                self.data_pool.pop(selected_idx)
                if not self.quitting:
                    self.used_data_pool.append(selected_ipt)
                else:
                    logger.error("Input data \033[1;33m{}\033[0m caused a crash!".format(selected_ipt))
                cnt += 1
        finally:
            # Always tear down coverage and restore the previous tracer, so a
            # failing run cannot leave our trace function installed.
            self.cov.stop()
            self.cov.save()

            sys.settrace(self.tracer_debugger)
            threading.settrace(self.tracer_debugger)
Пример #5
0
    def visit_Compare(self, node: ast.Compare):
        """Entry point for AST ``Compare`` nodes.

        Focus: whether the comparison involves a string constant that is
        related to one of the tracked inputs. When both a related string
        constant and a related input name are set after the children have been
        visited, the constant is added to that input's entry in
        ``self.str_input_maps[-1]``.
        """
        self.in_compare = True

        # Start from a clean state before walking the operands
        self.related_input_name = None
        self.related_str_constant = None
        self.generic_visit(node)

        str_const = self.related_str_constant
        input_name = self.related_input_name
        if str_const and input_name:
            logger.debug(
                "compare: str constant '{}' related to input '{}'".format(
                    str_const, input_name))
            self.str_input_maps[-1][input_name].add(str_const)

        self.in_compare = False
Пример #6
0
    def load_code(self, path):
        """Read, parse and statically analyse a source file, with caching.

        On a cache hit the stored results are returned unchanged. Otherwise
        the file is read, parsed into an AST, annotated with the line/node
        data produced by ``LineVisitor`` and analysed by ``SDEVisitor``.
        String constants reported by the analysis that are not yet in
        ``self.data_pool`` are appended to it.

        The three caches are only written after reading, parsing and both
        visitors have succeeded, so a failure (e.g. a ``SyntaxError``) cannot
        leave an entry in ``code_cache`` without its ``ast_cache`` and
        ``visitor_cache`` counterparts.

        TODO: bound the caches with an eviction policy such as LRU.

        :param path: path of the source file
        :return: tuple of (list of source lines, parsed AST, SDEVisitor)
        """
        if path in self.code_cache:
            return (self.code_cache[path], self.ast_cache[path],
                    self.visitor_cache[path])

        logger.info("Detected new source file [{}]".format(path))
        with open(path, 'r') as f:
            code_src = f.read()
        code_lis = code_src.replace('\r\n', '\n').split('\n')

        logger.info("Building AST...")
        curr_ast = ast.parse(code_src, mode='exec')

        # Attach the line_node data computed by LineVisitor to the AST object
        visitor = LineVisitor()
        visitor.visit(curr_ast)
        curr_ast.line_node = visitor.get_line_node()

        # Static analysis: extract sensitive data
        visitor2 = SDEVisitor()
        visitor2.visit(curr_ast)

        # Everything succeeded: commit all caches together
        self.code_cache[path] = code_lis
        self.ast_cache[path] = curr_ast
        self.visitor_cache[path] = visitor2

        # Feed string constants that are new to the data pool
        new_data_pool: set = set(visitor2.get_all_const_str()) - set(self.data_pool)
        logger.debug("New data detected ({}): {}".format(len(new_data_pool), new_data_pool))
        self.data_pool.extend(new_data_pool)

        logger.info("AST line_node length = {}".format(len(curr_ast.line_node)))
        logger.info("Finished building AST")
        logger.debug(visitor2.analyzed_functions)
        return code_lis, curr_ast, visitor2
Пример #7
0
    def visit_ClassDef(self, node: ast.ClassDef):
        """Entry point for AST ``ClassDef`` nodes.

        Keeps ``self.curr_class`` up to date while the class body is visited:
        a top-level class uses its own name, a nested class uses the dotted
        path ``Outer.Inner``. The previous value is put back afterwards.
        """
        outer = self.curr_class

        if outer is None:
            # Top-level class
            self.curr_class = node.name
            self.generic_visit(node)
            self.curr_class = None
            return

        # Nested class: extend the dotted class path by one level
        self.curr_class = '.'.join((outer, node.name))
        logger.debug("Get into class [{}]".format(self.curr_class))

        self.generic_visit(node)

        # Back to the enclosing class
        self.curr_class = outer
Пример #8
0
    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Entry point for AST ``FunctionDef`` nodes.

        Pushes fresh per-function maps, registers every entry of
        ``node.args.args`` as a tracked input, visits the body and finally
        stores the collected result in ``self.analyzed_functions`` under the
        name returned by ``get_path_name_by_node``. The per-function maps are
        popped again before returning.
        """
        curr_func = self.get_path_name_by_node([False], node)

        # Fresh result maps for this function
        self.var_input_maps.append({})
        self.str_input_maps.append({})

        # Scan the parameters
        for param in node.args.args:
            name = param.arg
            # Ignore the ``self`` parameter of methods
            if self.curr_class is not None and name == 'self':
                continue

            # Track the parameter; its own name is the first related variable
            # and the first sensitive string value.
            self.var_input_maps[-1][name] = {name}
            self.str_input_maps[-1][name] = {name}

        self.generic_visit(node)

        var_map = self.var_input_maps[-1]
        str_map = self.str_input_maps[-1]

        logger.debug("Current function is <{}>, with params {}".format(
            curr_func, list(var_map.keys())))
        logger.debug("Current function var_input_map: {}".format(var_map))
        logger.debug("Current function str_input_map: {}".format(str_map))

        # Record the scan result of this function (sets converted to lists)
        self.analyzed_functions[curr_func] = {
            "input_names": list(var_map.keys()),
            "var_input": {key: list(val) for key, val in var_map.items()},
            "str_input": {key: list(val) for key, val in str_map.items()},
        }
        self.var_input_maps.pop()
        self.str_input_maps.pop()
Пример #9
0
    def dispatch_line(self, frame: FrameType):
        """Handle a ``line`` trace event.

        Logs the current line, inspects the watched variables in the frame's
        locals and appends every new string value (not yet in either pool and
        shorter than ``self.max_len``) to ``self.data_pool``. The set of
        watched variables is then rebuilt from the static analysis result of
        the current function, if ``get_curr_func_name`` returns one.

        :param frame: frame in which the line is about to be executed
        :return: ``self.trace_dispatch``
        """
        logger.debug('{}:{}\t{}\t{}'.format(
            os.path.basename(self.curr_filename),
            frame.f_lineno,
            self.curr_ast.line_node[frame.f_lineno] if len(self.curr_ast.line_node) > 0 else '',
            self.curr_code[frame.f_lineno-1])
        )

        # Inspect the values of the watched variables
        for var_name in self.curr_watch_vars:
            name_split = var_name.split('.')
            try:
                # Resolve ``a.b.c`` by attribute access on the local ``a``
                # rather than evaluating a dynamically built code string.
                val = frame.f_locals[name_split[0]]
                for attr in name_split[1:]:
                    val = getattr(val, attr)
                if isinstance(val, str) and \
                        val not in self.data_pool and \
                        val not in self.used_data_pool and \
                        len(val) < self.max_len:
                    logger.info("New str DETECTED during running: ({}:{}) {} = {}".format(
                        self.curr_filename, self.curr_lineno, var_name, val
                    ))
                    self.data_pool.append(val)
            except (KeyError, NameError, AttributeError):
                # Variable not bound (yet) or attribute missing: skip it
                pass

        # Look up the full name of the function currently being executed
        func_name = self.get_curr_func_name()
        if func_name:
            logger.debug("FULL_FUNC_NAME: {}".format(func_name))
            var_inputs: dict = self.curr_sde_visitor.analyzed_functions.get(func_name).get('var_input')
            assert var_inputs is not None

            # Replace the watched variables with those of this function
            self.curr_watch_vars.clear()
            for watched in var_inputs.values():
                self.curr_watch_vars.update(watched)
            logger.debug("NEXT_WATCH_VARS: {}".format(self.curr_watch_vars))

        return self.trace_dispatch
Пример #10
0
    def trace_dispatch(self, frame: FrameType, event: str, arg):
        """Dispatch a trace function based on the event.

        Possible events:
            - line: A new line of code is going to be executed.
            - call: A function is about to be called or another code block is entered.
            - return: A function or other code block is about to return.
            - exception: An exception has occurred.
            - c_call: A C function is about to be called.
            - c_return: A C function has returned.
            - c_exception: A C function has raised an exception.

        Returns ``None`` once ``self.quitting`` is set, the result of
        ``dispatch_line`` for ``line`` events and ``self.trace_dispatch``
        otherwise.
        """
        if self.quitting:
            return

        filename = frame.f_code.co_filename

        # Files outside self.entry_dir are not analysed
        if self.entry_dir not in filename:
            return self.trace_dispatch

        # A blacklist match is ignored unless the file is also whitelisted
        whitelisted = any(fn in filename for fn in self.white_fns)
        if not whitelisted and any(fn in filename for fn in self.black_fns):
            return self.trace_dispatch

        if self.tracer_cov:
            # Forward the event to the coverage tracer. Whatever it returns is
            # called again until it yields None or the coverage tracer itself.
            # Our own dispatcher is re-installed after every call (presumably
            # because the coverage tracer may replace the global trace
            # function - TODO confirm).
            res = self.tracer_cov(frame, event, arg)
            depth = 0
            max_depth = 100
            sys.settrace(self.trace_dispatch)
            while res is not None and res != self.tracer_cov:
                depth += 1
                assert depth < max_depth
                res = res(frame, event, arg)
                sys.settrace(self.trace_dispatch)

        # Keep track of the current line number
        self.curr_lineno = frame.f_lineno

        # Entering a different source file: switch the per-file state
        if filename != self.curr_filename:
            self.curr_filename = filename
            self.curr_code, self.curr_ast, self.curr_sde_visitor = self.load_code(self.curr_filename)
            logger.debug("Switch to source code file [{}]".format(self.curr_filename))

        # Dispatch on the event type
        if event == 'line':
            return self.dispatch_line(frame)

        if event in ('return', 'call'):
            # Watched variables are reset on function call and return
            self.curr_watch_vars.clear()

        return self.trace_dispatch