def _save_scan_config(self):
    """
    Persist the configuration of the current scan target.

    Serializes ``self.scan_config`` to JSON and writes it through a
    ConfigModel under the key ``"<target_host>_<target_port>"``.
    """
    model = ConfigModel(
        table_prefix="",
        use_async=True,
        create_table=False,
        multiplexing_conn=True,
    )
    target_key = "_".join([self.target_host, str(self.target_port)])
    model.update(target_key, json.dumps(self.scan_config))
def _incremental_update_config(self, host_port, config):
    """
    Incrementally update the runtime scan configuration of one target.

    The current config is taken from the in-process cache, else from the
    DB, else from the default config. Only the keys present in ``config``
    are applied. The version is then bumped, and the result is written to
    the DB and the cache. If a scanner is currently assigned to this target,
    the new rate limits and version are pushed to it.

    Parameters:
        host_port - str, target identifier "<host>_<port>"
        config - dict, partial config. Recognised keys:
                 "scan_plugin_status" - {plugin_name: {"enable": ...}};
                                        every plugin_name must already exist
                                        in the stored config (else KeyError)
                 "scan_rate" - {key: number}; negative values are ignored
                 "white_url_reg" - replaces the URL whitelist regex
                 "scan_proxy" - replaces the scan proxy setting
    """
    import copy

    config_model = ConfigModel(table_prefix="", use_async=True,
                               create_table=True, multiplexing_conn=True)
    if host_port in self._config_cache:
        base_config = self._config_cache[host_port]
    else:
        origin_config_json = config_model.get(host_port)
        if origin_config_json is None:
            base_config = self._default_config
        else:
            base_config = json.loads(origin_config_json)

    # Edit a private deep copy. Mutating base_config directly would corrupt
    # the shared default config (and make every target that starts from the
    # defaults share one cached dict). It would also leave the cache ahead of
    # the DB if the DB write below fails.
    origin_config = copy.deepcopy(base_config)
    version = origin_config["version"]

    if "scan_plugin_status" in config:
        for plugin_name, status in config["scan_plugin_status"].items():
            origin_config["scan_plugin_status"][plugin_name]["enable"] = status["enable"]

    if "scan_rate" in config:
        scan_rate = origin_config["scan_rate"]
        for key, value in config["scan_rate"].items():
            if value >= 0:
                scan_rate[key] = value
        # Normalize once, after all keys are applied, so the result does not
        # depend on key order: the max interval must not be below the min.
        if scan_rate["min_request_interval"] > scan_rate["max_request_interval"]:
            scan_rate["max_request_interval"] = scan_rate["min_request_interval"]

    if "white_url_reg" in config:
        origin_config["white_url_reg"] = config["white_url_reg"]

    if "scan_proxy" in config:
        origin_config["scan_proxy"] = config["scan_proxy"]

    # Bump the version, then update the DB and the cache.
    origin_config["version"] = version + 1
    config_model.update(host_port, json.dumps(origin_config))
    self._config_cache[host_port] = origin_config

    # Push rate limits and the new version to the scanner running this
    # target, if there is one.
    scanner_id = self._scannner_info.get_scanner_id(host_port)
    if scanner_id is not None:
        self._set_boundary_value(scanner_id, origin_config["scan_rate"])
        Communicator().set_value("config_version", origin_config["version"],
                                 "Scanner_" + str(scanner_id))
def _incremental_update_config(self, host_port, config):
    """
    Incrementally update the runtime scan configuration of one target.

    The target's stored config (or the stored "default" config if the
    target has none) is loaded from the DB. Only the keys present in
    ``config`` are applied. The version is then bumped and the result
    written back. If a running scanner in ``self.scanner_list`` targets
    the same host_port, its rate limits and config version are updated.

    Parameters:
        host_port - str, target identifier "<host>_<port>"
        config - dict, partial config. Recognised keys:
                 "scan_plugin_status" - {plugin_name: {"enable": ...}}
                 "scan_rate" - {key: number}; negative values are ignored
                 "white_url_reg" - replaces the URL whitelist regex
                 "scan_proxy" - replaces the scan proxy setting
    """
    config_model = ConfigModel(
        table_prefix="", use_async=True, create_table=True, multiplexing_conn=True)
    origin_config_json = config_model.get(host_port)
    if origin_config_json is None:
        origin_config_json = config_model.get("default")
    origin_config = json.loads(origin_config_json)
    version = origin_config["version"]

    if "scan_plugin_status" in config:
        for plugin_name, status in config["scan_plugin_status"].items():
            origin_config["scan_plugin_status"][plugin_name]["enable"] = status["enable"]

    if "scan_rate" in config:
        scan_rate = origin_config["scan_rate"]
        for key, value in config["scan_rate"].items():
            if value >= 0:
                scan_rate[key] = value
        # Normalize once after all keys are applied so the result does not
        # depend on key order: the max interval must not be below the min.
        if scan_rate["min_request_interval"] > scan_rate["max_request_interval"]:
            scan_rate["max_request_interval"] = scan_rate["min_request_interval"]

    if "white_url_reg" in config:
        origin_config["white_url_reg"] = config["white_url_reg"]

    if "scan_proxy" in config:
        origin_config["scan_proxy"] = config["scan_proxy"]

    origin_config["version"] = version + 1
    config_model.update(host_port, json.dumps(origin_config))

    # Notify the scanner currently running this target, if any.
    for scanner_id, scanner in enumerate(self.scanner_list):
        if scanner is None:
            continue
        running_host_port = scanner["host"] + "_" + str(scanner["port"])
        if host_port == running_host_port:
            self.set_boundary_value(scanner_id, origin_config["scan_rate"])
            Communicator().set_value("config_version", origin_config["version"],
                                     "Scanner_" + str(scanner_id))
            break
class ScannerManager(object):
    """
    Singleton that manages scanner processes: their bookkeeping entries in
    ``scanner_list``, their shared-memory flags (via Communicator), and the
    per-target scan configuration stored through ConfigModel.
    """

    def __new__(cls):
        # Singleton: every ScannerManager() call returns the same object.
        if not hasattr(cls, "instance"):
            cls.instance = super(ScannerManager, cls).__new__(cls)
        return cls.instance

    def init_manager(self, scanner_schedulers):
        """
        Initialize the manager.

        Parameters:
            scanner_schedulers - dict of all scan-task scheduler objects, keyed
                                 by the scan task's module name ("Scanner_<id>")
        """
        self.max_scanner = Config().get_config("scanner.max_module_instance")
        self.scanner_schedulers = scanner_schedulers
        self.scanner_list = [None] * self.max_scanner
        self._init_config()

    def _init_config(self):
        """
        Initialize the scan configuration.

        Preloads every plugin module in <main_path>/plugin/scanner and builds
        the default config from the plugins' plugin_info and the scanner.*
        settings. If the set of plugin names differs from the one in the
        stored "default" config, all stored configs are deleted. The new
        default config is then stored.

        Exits the process with status 1 if a plugin cannot be imported or
        does not inherit ScanPluginBase.
        """
        self.config_model = ConfigModel(table_prefix="", use_async=True,
                                        create_table=True, multiplexing_conn=False)
        self.plugin_loaded = {}
        plugin_path = Communicator().get_main_path() + "/plugin/scanner"
        plugin_import_path = "plugin.scanner"

        # Plugins are loaded here, so provide dummy shared objects.
        Communicator().set_internal_shared("report_model", None)
        Communicator().set_internal_shared("failed_task_set", None)

        plugin_names = [
            file_name[:-3]
            for file_name in os.listdir(plugin_path)
            if os.path.isfile(plugin_path + os.sep + file_name) and file_name.endswith(".py")
        ]
        for plugin_name in plugin_names:
            try:
                plugin_module = __import__(plugin_import_path, fromlist=[plugin_name])
            except Exception as e:
                Logger().critical(
                    "Error in load plugin: {}".format(plugin_name), exc_info=e)
                # SystemExit instead of the site-provided exit() builtin,
                # which is not guaranteed to exist and also closes stdin.
                raise SystemExit(1)
            else:
                plugin_instance = getattr(plugin_module, plugin_name).ScanPlugin()
                if isinstance(plugin_instance, scan_plugin_base.ScanPluginBase):
                    self.plugin_loaded[plugin_name] = plugin_instance
                    Logger().debug(
                        "scanner plugin: {} preload success!".format(plugin_name))
                else:
                    Logger().critical(
                        "Detect scanner plugin {} not inherit class ScanPluginBase!"
                        .format(plugin_name))
                    raise SystemExit(1)

        plugin_status = {
            plugin_name: {
                "enable": True,
                "show_name": plugin.plugin_info["show_name"],
                "description": plugin.plugin_info["description"],
            }
            for plugin_name, plugin in self.plugin_loaded.items()
        }

        default_config = {
            "scan_plugin_status": plugin_status,
            "scan_rate": {
                "max_concurrent_request": Config().get_config("scanner.max_concurrent_request"),
                "max_request_interval": Config().get_config("scanner.max_request_interval"),
                "min_request_interval": Config().get_config("scanner.min_request_interval")
            },
            "white_url_reg": "",
            "version": 0
        }

        # When the plugin set has changed, delete all stored configs: their
        # scan_plugin_status entries no longer match the loaded plugins.
        origin_default_config = self.config_model.get("default")
        if origin_default_config is not None:
            stored_plugins = json.loads(origin_default_config)["scan_plugin_status"]
            if set(stored_plugins) != set(plugin_status):
                self.config_model.delete("all")

        self.config_model.update("default", json.dumps(default_config))

    def _check_alive(self):
        """
        Refresh liveness: clear every scanner_list slot whose shared-memory
        "pid" value has become 0.
        """
        for scanner_id, scanner in enumerate(self.scanner_list):
            if scanner is None:
                continue
            pid = Communicator().get_value("pid", "Scanner_" + str(scanner_id))
            if pid == 0:
                self.scanner_list[scanner_id] = None

    def _get_running_scanner(self, scanner_id):
        """
        Return the bookkeeping dict of a live scanner.

        Parameters:
            scanner_id - int, index into scanner_list

        Raises:
            exceptions.InvalidScannerId - scanner_id is not a valid
                non-negative index, the slot is empty, or its pid is 0
        """
        try:
            # Negative ids would silently wrap around to the end of the list.
            if scanner_id < 0:
                raise IndexError(scanner_id)
            scanner = self.scanner_list[scanner_id]
        except (IndexError, TypeError):
            raise exceptions.InvalidScannerId from None
        if scanner is None or scanner.get("pid", 0) == 0:
            raise exceptions.InvalidScannerId
        return scanner

    @staticmethod
    def _split_host_port(host_port):
        """
        Split "<host>_<port>" at the LAST underscore so hosts that contain
        underscores stay intact. Returns (host, port), both str.
        """
        host, _, port = host_port.rpartition("_")
        return host, port

    def _load_target_config(self, host_port):
        """
        Return the stored config dict of host_port, falling back to the
        stored "default" config.
        """
        config_json = self.config_model.get(host_port)
        if config_json is None:
            config_json = self.config_model.get("default")
        return json.loads(config_json)

    def _incremental_update_config(self, host_port, config):
        """
        Incrementally update the runtime scan configuration of one target.

        Only the keys present in ``config`` are applied. The version is then
        bumped and the result stored. If a running scanner targets host_port,
        its rate limits and config version are updated.

        Parameters:
            host_port - str, target identifier "<host>_<port>"
            config - dict, partial config; recognised keys are
                     "scan_plugin_status", "scan_rate" and "white_url_reg"
        """
        origin_config = self._load_target_config(host_port)
        version = origin_config["version"]

        if "scan_plugin_status" in config:
            for plugin_name, status in config["scan_plugin_status"].items():
                origin_config["scan_plugin_status"][plugin_name]["enable"] = status["enable"]

        if "scan_rate" in config:
            for key, value in config["scan_rate"].items():
                origin_config["scan_rate"][key] = value

        if "white_url_reg" in config:
            origin_config["white_url_reg"] = config["white_url_reg"]

        origin_config["version"] = version + 1
        self.config_model.update(host_port, json.dumps(origin_config))

        for scanner_id, scanner in enumerate(self.scanner_list):
            if scanner is None:
                continue
            running_host_port = scanner["host"] + "_" + str(scanner["port"])
            if host_port == running_host_port:
                self.set_boundary_value(scanner_id, origin_config["scan_rate"])
                Communicator().set_value("config_version", origin_config["version"],
                                         "Scanner_" + str(scanner_id))
                break

    def new_scanner(self, module_params):
        """
        Start a new scan task in the first free scanner slot.

        Parameters:
            module_params - dict, {
                "host": str, target host,
                "port": int, target port,
                "config": dict, config info
            }
            NOTE(review): "config" is not read here; the target's config is
            only touched via an empty incremental update (version bump).

        Raises:
            exceptions.MaximumScannerExceede - no free scanner slot
            exceptions.TargetIsScanning - the target is already being scanned
        """
        self._check_alive()
        idle_scanner = next(
            (scanner_id for scanner_id, scanner in enumerate(self.scanner_list)
             if scanner is None),
            None)
        if idle_scanner is None:
            raise exceptions.MaximumScannerExceede

        if any(item is not None
               and item["host"] == module_params["host"]
               and item["port"] == module_params["port"]
               for item in self.scanner_list):
            raise exceptions.TargetIsScanning

        host_port = module_params["host"] + "_" + str(module_params["port"])
        self._incremental_update_config(host_port, {})

        scanner_process_kwargs = {
            "module_cls": modules.Scanner,
            "instance_id": idle_scanner,
            "module_params": {
                "host": module_params["host"],
                "port": module_params["port"]
            }
        }
        module_name = "Scanner_" + str(idle_scanner)
        Communicator().reset_all_value(module_name)
        pid = ForkProxy().fork(scanner_process_kwargs)

        new_scanner_info = {
            "pid": pid,
            "host": module_params["host"],
            "port": module_params["port"],
            "cancel": 0,
            "pause": 0
        }
        Communicator().set_value("pid", pid, module_name)
        self.scanner_list[idle_scanner] = new_scanner_info

    def get_config(self, module_params):
        """
        Get the config of a scan target (falls back to the default config).

        Parameters:
            module_params - dict, {"host": str, "port": int}
        """
        host_port = module_params["host"] + "_" + str(module_params["port"])
        return self._load_target_config(host_port)

    def mod_config(self, module_params):
        """
        Modify the config of a scan target.

        Parameters:
            module_params - dict, {"host": str, "port": int, "config": dict}
        """
        host_port = module_params["host"] + "_" + str(module_params["port"])
        self._incremental_update_config(host_port, module_params["config"])

    def pause_scanner(self, scanner_id):
        """
        Set the shared-memory "pause" flag of a scanner to 1.

        Parameters:
            scanner_id - int, id of the target scanner

        Raises:
            exceptions.InvalidScannerId - no live scanner with this id
        """
        self._check_alive()
        scanner = self._get_running_scanner(scanner_id)
        scanner["pause"] = 1
        Communicator().set_value("pause", 1, "Scanner_" + str(scanner_id))

    def resume_scanner(self, scanner_id):
        """
        Set the shared-memory "pause" flag of a scanner to 0.

        Parameters:
            scanner_id - int, id of the target scanner

        Raises:
            exceptions.InvalidScannerId - no live scanner with this id
        """
        self._check_alive()
        scanner = self._get_running_scanner(scanner_id)
        scanner["pause"] = 0
        Communicator().set_value("pause", 0, "Scanner_" + str(scanner_id))

    def cancel_scanner(self, scanner_id):
        """
        Set the shared-memory "cancel" flag of a scanner to 1.

        Parameters:
            scanner_id - int, id of the target scanner

        Raises:
            exceptions.InvalidScannerId - no live scanner with this id
        """
        self._check_alive()
        scanner = self._get_running_scanner(scanner_id)
        scanner["cancel"] = 1
        Communicator().set_value("cancel", 1, "Scanner_" + str(scanner_id))

    def kill_scanner(self, scanner_id):
        """
        Forcefully stop a scanner process: SIGTERM first, then kill if it
        has not exited within 5 seconds.

        Parameters:
            scanner_id - int, id of the scanner to stop

        Returns:
            True if the process is gone (its slot and pid are reset),
            otherwise False

        Raises:
            exceptions.InvalidScannerId - no live scanner with this id
        """
        self._check_alive()
        pid = self._get_running_scanner(scanner_id)["pid"]
        try:
            proc = psutil.Process(pid)
        except psutil.NoSuchProcess:
            Logger().warning("Try to kill not running scanner!")
            raise exceptions.InvalidScannerId

        proc.terminate()
        try:
            proc.wait(timeout=5)
        except psutil.TimeoutExpired:
            proc.kill()
            proc.wait(timeout=5)

        if proc.is_running():
            return False
        self.scanner_list[scanner_id] = None
        Communicator().set_value("pid", 0, "Scanner_" + str(scanner_id))
        return True

    def is_scanning(self, host, port):
        """Return True if a live scanner is assigned to host:port."""
        self._check_alive()
        return any(scanner is not None
                   and scanner["host"] == host
                   and scanner["port"] == port
                   for scanner in self.scanner_list)

    async def get_running_info(self):
        """
        Get information about the running scan tasks.

        Returns:
            dict keyed by int scanner id. Each value is a copy of the
            scanner's bookkeeping entry (without "pause"/"cancel"), merged
            with that module's RuntimeInfo entry, plus the fields below, e.g.
            {
                0: {
                    "pid": 64067,                 // scanner process pid
                    "host": "127.0.0.1",          // target host
                    "port": 8005,                 // target port
                    "auth_plugin": "default",     // auth plugin in use (not used yet)
                    "scan_plugin_list": [...],    // empty means all plugins
                    "rasp_result_timeout": 0,     // rasp-agent results that timed out
                    "waiting_rasp_request": 0,    // rasp-agent results still pending
                    "dropped_rasp_result": 0,     // invalid rasp-agent results received
                    "send_request": 0,            // test requests sent
                    "failed_request": 0,          // test requests that failed
                    "cpu": "0.0%",                // cpu usage
                    "mem": "10.51 M",             // memory usage
                    "total": 5,                   // current url count
                    "scanned": 2,                 // scanned url count
                    "concurrent_request": 10,     // current concurrency
                    "request_interval": 0,        // current request interval
                },
                ...
            }
            The fields between "port" and "mem" come from RuntimeInfo and
            are not set in this method.

        Raises:
            exceptions.InvalidScannerId - a running scanner has no scheduler
            exceptions.DatabaseError - database error
        """
        self._check_alive()
        result = {
            scanner_id: copy.deepcopy(scanner)
            for scanner_id, scanner in enumerate(self.scanner_list)
            if scanner is not None
        }

        for module_id, info in result.items():
            module_name = "Scanner_" + str(module_id)
            info.update(RuntimeInfo().get_latest_info()[module_name])

            if module_name not in self.scanner_schedulers:
                raise exceptions.InvalidScannerId

            info["concurrent_request"] = Communicator().get_value(
                "max_concurrent_request", module_name)
            # Bug fix: this previously read "max_concurrent_request" as well.
            info["request_interval"] = Communicator().get_value(
                "request_interval", module_name)

            table_prefix = info["host"] + "_" + str(info["port"])
            total, scanned = await NewRequestModel(
                table_prefix, multiplexing_conn=True).get_scan_count()
            info["total"] = total
            info["scanned"] = scanned

            info.pop("pause", None)
            info.pop("cancel", None)
        return result

    def clean_target(self, host, port, url_only=False):
        """
        Clear the database tables of a target and reset its preprocessing LRU.

        Parameters:
            host - str, target host
            port - int, target port
            url_only - bool, only truncate the url table (keep reports/config)

        Raises:
            exceptions.DatabaseError - database error
        """
        table_prefix = host + "_" + str(port)
        if url_only:
            NewRequestModel(table_prefix, multiplexing_conn=True).truncate_table()
        else:
            NewRequestModel(table_prefix, multiplexing_conn=True).drop_table()
            ReportModel(table_prefix, multiplexing_conn=True).drop_table()
            self.config_model.delete(table_prefix)
        Communicator().set_clean_lru([table_prefix])

    async def get_all_target(self):
        """
        List every target that has a result table in the database.

        Returns:
            list of dicts, sorted by "last_time" (newest first).
            Targets being scanned have the get_running_info() fields plus
            "id", "last_time" and "config". Idle targets have:
            {
                "host": "127.0.0.1",     // target host
                "port": "8005",          // target port (str, parsed from the table name)
                "total": 5,              // current url count
                "scanned": 2,            // scanned url count
                "config": {...},         // config
                "last_time": 1563182956  // time a new url was last received
            }
            NOTE(review): running targets carry "port" with the type it was
            started with (int in new_scanner's docs), idle ones carry str.

        Raises:
            exceptions.DatabaseError - database error
        """
        suffix = "_ResultList"
        result = {}
        for table_name in BaseModel().get_tables():
            if not table_name.endswith(suffix):
                continue
            host_port = table_name[:-len(suffix)]
            host, port = self._split_host_port(host_port)
            result[host_port] = {"host": host, "port": port}

        running_info = await self.get_running_info()
        for scanner_id, info in running_info.items():
            host_port = info["host"] + "_" + str(info["port"])
            result[host_port] = info
            result[host_port]["id"] = scanner_id

        result_list = []
        for host_port, target in result.items():
            new_request_model = NewRequestModel(host_port, multiplexing_conn=True)
            target["last_time"] = await new_request_model.get_last_time()
            # Running targets already got total/scanned from get_running_info.
            if target.get("id", None) is None:
                total, scanned = await new_request_model.get_scan_count()
                target["total"] = total
                target["scanned"] = scanned
            target["config"] = self._load_target_config(host_port)
            result_list.append(target)

        result_list.sort(key=lambda k: k["last_time"], reverse=True)
        return result_list

    async def get_report(self, host_port, page, perpage):
        """
        Get scan results.

        Parameters:
            host_port - str, host + "_" + str(port) of the target
            page - int, page number
            perpage - int, items per page

        Returns:
            {"total": total item count, "data": [RaspResult json string, ...]}
            ({"total": 0, "data": []} if the report table does not exist)

        Raises:
            exceptions.DatabaseError - database error
        """
        try:
            model = ReportModel(host_port, create_table=False, multiplexing_conn=True)
        except exceptions.TableNotExist:
            data = {"total": 0, "data": []}
        else:
            data = await model.get(page, perpage)
        return data

    def get_plugin_info(self, plugin_path, class_prefix):
        """
        Collect the plugin_info of every plugin module in a directory.

        Parameters:
            plugin_path - str, plugin directory
            class_prefix - str, plugin class is looked up as class_prefix + "Plugin"

        Returns:
            list of plugin_info dicts; modules that fail to import are
            logged and skipped
        """
        result = []
        # NOTE(review): the import path is plugin_path with os.sep turned into
        # ".". get_plugins passes an absolute-looking path built from
        # get_main_path(); if that path is absolute the dotted name starts
        # with "." and every import here fails (logged). Confirm.
        plugin_import_path = plugin_path.replace(os.sep, ".")
        plugin_names = [
            file_name[:-3]
            for file_name in os.listdir(plugin_path)
            if os.path.isfile(plugin_path + os.sep + file_name) and file_name.endswith(".py")
        ]
        for plugin_name in plugin_names:
            try:
                plugin_module = __import__(plugin_import_path, fromlist=[plugin_name])
            except Exception as e:
                Logger().warning(
                    "Error in import plugin: {}".format(plugin_name), exc_info=e)
            else:
                plugin_module = getattr(plugin_module, plugin_name)
                plugin_info = getattr(plugin_module, class_prefix + "Plugin").plugin_info
                result.append(plugin_info)
        return result

    def get_plugins(self):
        """
        Get the plugin lists.

        Returns:
            {
                "Dedup": [plugin_info, ...],
                "Scan": [plugin_info, ...]
            }
        """
        main_path = Communicator().get_main_path()
        plugin_dirs = {
            # "Auth": main_path + "/plugin/authorizer",
            "Dedup": main_path + "/plugin/deduplicate",
            "Scan": main_path + "/plugin/scanner",
        }
        return {key: self.get_plugin_info(path, key) for key, path in plugin_dirs.items()}

    def set_boundary_value(self, scanner_id, boundary):
        """
        Configure the scan rate range of a scanner.

        Parameters:
            scanner_id - int, id of the scanner
            boundary - dict, {
                "max_concurrent_request": 10,
                "max_request_interval": 1000,
                "min_request_interval": 0
            }

        Raises:
            exceptions.InvalidScannerId - no scheduler for this id
        """
        module_name = "Scanner_" + str(scanner_id)
        try:
            scheduler = self.scanner_schedulers[module_name]
        except KeyError:
            raise exceptions.InvalidScannerId
        cr_max = boundary["max_concurrent_request"]
        ri_max = boundary["max_request_interval"]
        ri_min = boundary["min_request_interval"]
        scheduler.set_boundary_value(cr_max, ri_max, ri_min)

    def set_auto_start(self, auto_start):
        """
        Toggle auto-start (start scanning when a request is first received).

        Parameters:
            auto_start - bool, only the value True enables it
        """
        Communicator().set_value("auto_start", 1 if auto_start is True else 0, "Monitor")
def __init__(self, scanner_info, scanner_schedulers):
    """
    Initialize the scan configuration.

    Preloads every plugin module in <main_path>/plugin/scanner and builds
    the default scan config from the plugins' plugin_info and the scanner.*
    settings. If the set of plugin names differs from the one in the stored
    "default" config, all stored configs are deleted. The new default config
    is then stored. Calls sys.exit(1) if a plugin cannot be imported or does
    not inherit ScanPluginBase.

    Parameters:
        scanner_info - stored as self._scannner_info
        scanner_schedulers - stored as self._scanner_schedulers
    """
    self._scanner_schedulers = scanner_schedulers
    self._scannner_info = scanner_info
    config_model = ConfigModel(table_prefix="", use_async=True,
                               create_table=True, multiplexing_conn=True)

    self.plugin_loaded = {}
    plugin_path = Communicator().get_main_path() + "/plugin/scanner"
    plugin_import_path = "plugin.scanner"

    # Plugins are loaded here, so provide some dummy shared objects.
    Communicator().set_internal_shared("report_model", None)
    Communicator().set_internal_shared("failed_task_set", None)

    plugin_names = [
        file_name[:-3]
        for file_name in os.listdir(plugin_path)
        if os.path.isfile(plugin_path + os.sep + file_name) and file_name.endswith(".py")
    ]
    for plugin_name in plugin_names:
        try:
            plugin_module = __import__(plugin_import_path, fromlist=[plugin_name])
        except Exception as e:
            Logger().critical(
                "Error in load plugin: {}".format(plugin_name), exc_info=e)
            sys.exit(1)
        else:
            plugin_instance = getattr(plugin_module, plugin_name).ScanPlugin()
            if isinstance(plugin_instance, scan_plugin_base.ScanPluginBase):
                self.plugin_loaded[plugin_name] = plugin_instance
                Logger().debug(
                    "scanner plugin: {} preload success!".format(plugin_name))
            else:
                Logger().critical(
                    "Detect scanner plugin {} not inherit class ScanPluginBase!"
                    .format(plugin_name))
                sys.exit(1)

    plugin_status = {
        plugin_name: {
            "enable": True,
            "show_name": plugin.plugin_info["show_name"],
            "description": plugin.plugin_info["description"],
        }
        for plugin_name, plugin in self.plugin_loaded.items()
    }

    default_config = {
        "scan_plugin_status": plugin_status,
        "scan_rate": {
            "max_concurrent_request": Config().get_config("scanner.max_concurrent_request"),
            "max_request_interval": Config().get_config("scanner.max_request_interval"),
            "min_request_interval": Config().get_config("scanner.min_request_interval")
        },
        "white_url_reg": "",
        "scan_proxy": "",
        "version": 0
    }

    # When the plugin set has changed, delete every stored config: their
    # scan_plugin_status entries no longer match the loaded plugins.
    # (Same result as the former length check + membership loop, since
    # both collections are dict keys and hence duplicate-free.)
    origin_default_config = config_model.get("default")
    if origin_default_config is not None:
        stored_plugins = json.loads(origin_default_config)["scan_plugin_status"]
        if set(stored_plugins) != set(plugin_status):
            config_model.delete("all")

    config_model.update("default", json.dumps(default_config))
    self._default_config = default_config
    # Per-target config cache, keyed by host_port.
    self._config_cache = {}
class Scanner(base.BaseModule):
    """
    Scanner module for one target (host, port).

    Fetches unscanned requests from the target's NewRequestModel, hands
    them to every loaded scan plugin, forwards RaspResults from its queue
    to RaspResultReceiver, and marks finished tasks in the DB.
    """

    def __init__(self, **kwargs):
        """
        Initialize the scanner.

        Parameters (keyword):
            host - target host
            port - target port
        """
        self.target_host = kwargs["host"]
        self.target_port = kwargs["port"]
        self._init_scan_config()

        # Records failed requests so they can be marked; published through
        # Communicator's internal shared storage.
        self.failed_task_set = set()
        Communicator().set_internal_shared("failed_task_set", self.failed_task_set)

        self.module_id = Communicator().get_module_name().split("_")[-1]
        Communicator().set_value("max_concurrent_request", 1)
        Communicator().set_value("request_interval",
                                 Config().get_config("scanner.min_request_interval"))

        self._init_db()
        self._init_plugin()
        # Apply the runtime config to the loaded plugins.
        self._update_scan_config()

    def _init_scan_config(self):
        """
        Load the stored scan config of this target.

        Raises:
            exceptions.GetRuntimeConfigFail - no config stored for the target
        """
        self.config_model = ConfigModel(table_prefix="", use_async=True,
                                        create_table=False, multiplexing_conn=True)
        host_port = self.target_host + "_" + str(self.target_port)
        config = self.config_model.get(host_port)
        if config is None:
            raise exceptions.GetRuntimeConfigFail
        self.scan_config = json.loads(config)

    def _save_scan_config(self):
        """Store the current scan config of this target."""
        host_port = self.target_host + "_" + str(self.target_port)
        self.config_model.update(host_port, json.dumps(self.scan_config))

    def _update_scan_config(self):
        """
        Reload this target's config and apply it to the loaded plugins
        (enable flag and URL whitelist regex).

        Plugins listed in the config that failed to load in _init_plugin are
        skipped. Previously they raised KeyError and killed the scanner,
        even though _init_plugin tolerates such failures.

        Raises:
            exceptions.GetRuntimeConfigFail - the target's config is gone
        """
        host_port = self.target_host + "_" + str(self.target_port)
        config_json = self.config_model.get(host_port)
        if config_json is None:
            raise exceptions.GetRuntimeConfigFail
        self.scan_config = json.loads(config_json)

        for plugin_name, status in self.scan_config["scan_plugin_status"].items():
            plugin = self.plugin_loaded.get(plugin_name)
            if plugin is None:
                continue
            plugin.set_enable(status["enable"])
            plugin.set_white_url_reg(self.scan_config["white_url_reg"])

        Logger().debug("Update scanner config to version {}, new config json is {}".format(
            self.scan_config["version"], json.dumps(self.scan_config)))

    def _init_plugin(self):
        """
        Import and instantiate every scan plugin listed in the config.

        Plugins that fail to import or do not inherit ScanPluginBase are
        logged and skipped.

        Raises:
            exceptions.NoPluginError - no plugin could be loaded
        """
        self.plugin_loaded = {}
        plugin_import_path = "plugin.scanner"
        for plugin_name in self.scan_config["scan_plugin_status"].keys():
            try:
                plugin_module = __import__(plugin_import_path, fromlist=[plugin_name])
            except Exception as e:
                Logger().error("Error in load plugin: {}".format(plugin_name), exc_info=e)
            else:
                plugin_instance = getattr(plugin_module, plugin_name).ScanPlugin()
                if isinstance(plugin_instance, scan_plugin_base.ScanPluginBase):
                    self.plugin_loaded[plugin_name] = plugin_instance
                    Logger().debug(
                        "scanner plugin: {} load success!".format(plugin_name))
                else:
                    Logger().warning(
                        "scanner plugin {} not inherit class ScanPluginBase!".format(plugin_name))

        if len(self.plugin_loaded) == 0:
            Logger().error("No scanner plugin detected, scanner exit!")
            raise exceptions.NoPluginError

    def _init_db(self):
        """
        Initialize the target's models. Resets unscanned items and shares
        the ReportModel under "report_model".
        """
        model_prefix = self.target_host + "_" + str(self.target_port)
        self.new_scan_model = NewRequestModel(model_prefix)
        self.new_scan_model.reset_unscanned_item()
        report_model = ReportModel(model_prefix)
        Communicator().set_internal_shared("report_model", report_model)

    def _exit(self, signame, loop):
        # Signal handler: stop the event loop. asyncio.run() then raises
        # RuntimeError, which run() logs as the process being killed.
        loop.stop()

    def run(self):
        """Module main function: run async_run() on a new event loop."""
        try:
            asyncio.run(self.async_run())
        except RuntimeError:
            Logger().info("Scanner process has been killed!")
        except Exception as e:
            Logger().error("Scanner exit with unknow error!", exc_info=e)

    async def async_run(self):
        """
        Coroutine entry point: register signal handlers, start the plugin and
        queue-fetch tasks, and fetch scan tasks until done. Then cancel the
        tasks and reset shared memory.
        """
        # Register signal handlers
        loop = asyncio.get_running_loop()
        for signame in {'SIGINT', 'SIGTERM'}:
            loop.add_signal_handler(
                getattr(signal, signame),
                functools.partial(self._exit, signame, loop))

        # Initialize the context
        await audit_tools.context.Context().async_init()

        # Start the plugins
        plugin_tasks = []
        for plugin_name in self.plugin_loaded:
            plugin_tasks.append(asyncio.create_task(
                self.plugin_loaded[plugin_name].async_run()))

        # Start the coroutine that consumes the scan-result queue
        task_fetch_rasp_result = asyncio.create_task(self._fetch_from_queue())

        # Fetch new scan tasks (returns when cancelled and drained)
        await self._fetch_new_scan()

        # Stop all tasks and reset shared memory
        task_fetch_rasp_result.cancel()
        await asyncio.wait({task_fetch_rasp_result})
        for task in plugin_tasks:
            task.cancel()
        await asyncio.wait(set(plugin_tasks), return_when=asyncio.ALL_COMPLETED)
        Communicator().reset_all_value()

    async def _fetch_from_queue(self):
        """
        Forward RaspResults of scan requests from this scanner's queue to
        RaspResultReceiver. The runtime config is reloaded whenever the
        shared "config_version" is newer than the current one.
        """
        queue_name = "rasp_result_queue_" + self.module_id
        sleep_interval = 0.1
        continuously_sleep = 0
        Logger().debug("Fetch task is running, use queue: " + queue_name)

        while True:
            if Communicator().get_value("config_version") > self.scan_config["version"]:
                self._update_scan_config()
            try:
                data = Communicator().get_data_nowait(queue_name)
                Logger().debug("From rasp_result_queue got data: " + str(data))
                result_receiver.RaspResultReceiver().add_result(data)
                Logger().debug("Send data to rasp_result receiver: {}".format(
                    data.get_request_id()))
                continuously_sleep = 0
            except exceptions.QueueEmpty:
                # Linear back-off while the queue is empty, capped at 10 steps.
                if continuously_sleep < 10:
                    continuously_sleep += 1
                await asyncio.sleep(sleep_interval * continuously_sleep)

    async def _fetch_new_scan(self):
        """
        Fetch non-scan requests (new scan tasks) and dispatch them to the
        plugins, adapting the batch size to scan progress. Returns once
        "cancel" is set and no queued tasks remain.
        """
        # Maximum size of the plugins' task queue
        scan_queue_max = 300
        # Number of tasks scanned so far
        self.scan_num = 0
        # Number of tasks still queued
        self.scan_queue_remaining = 0
        # Number of tasks to fetch next time
        self.fetch_count = 20
        # Largest scanned id still to be marked
        self.mark_id = 0

        while True:
            if Communicator().get_value("cancel") == 0:
                try:
                    await self._fetch_task_from_db()
                except exceptions.DatabaseError as e:
                    Logger().error("Database error occured when fetch scan task.", exc_info=e)
                except asyncio.CancelledError:
                    raise
                except Exception as e:
                    Logger().error("Unexpected error occured when fetch scan task.", exc_info=e)
                if self.scan_queue_remaining == 0:
                    continue
            elif self.scan_queue_remaining == 0:
                break

            await self._check_scan_progress()

            # Adjust the number of tasks fetched per round
            if self.scan_queue_remaining + self.fetch_count > scan_queue_max:
                self.fetch_count = scan_queue_max - self.scan_queue_remaining
            elif self.fetch_count < 5:
                self.fetch_count = 5

    async def _fetch_task_from_db(self):
        """
        Mark finished/failed tasks, then wait (with back-off) until new scan
        tasks for this target are available and dispatch them to every
        plugin. Returns early if "cancel" is set.
        """
        await self.new_scan_model.mark_result(self.mark_id, list(self.failed_task_set))
        self.failed_task_set.clear()

        sleep_interval = 1
        continuously_sleep = 0
        while True:
            if Communicator().get_value("cancel") != 0:
                break
            data_list = await self.new_scan_model.get_new_scan(self.fetch_count)
            data_count = len(data_list)
            if data_count > 0:
                for item in data_list:
                    for plugin_name in self.plugin_loaded:
                        # item format: {"id": id, "data": rasp_result_json}
                        self.plugin_loaded[plugin_name].add_task(item)
                    Logger().debug("Send task with id: {} to plugins.".format(item["id"]))
                self.scan_queue_remaining += data_count
                return
            else:
                if continuously_sleep < 10:
                    continuously_sleep += 1
                await asyncio.sleep(sleep_interval * continuously_sleep)

    async def _check_scan_progress(self):
        """
        Poll the plugins' progress once per second and decide the next fetch
        size. Progress is the minimum count/id reported across plugins.
        On exit, scan_queue_remaining, scan_num and mark_id are updated.
        """
        sleep_interval = 1
        sleep_count = 0
        while True:
            await asyncio.sleep(sleep_interval)
            sleep_count += 1
            scan_num_list = []
            scan_id_list = []
            for plugin_name in self.plugin_loaded:
                plugin_ins = self.plugin_loaded[plugin_name]
                plugin_scan_num, plugin_last_id = plugin_ins.get_scan_progress()
                scan_num_list.append(plugin_scan_num)
                scan_id_list.append(plugin_last_id)

            plugin_scan_min_num = min(scan_num_list)
            plugin_scan_min_id = min(scan_id_list)
            finish_count = plugin_scan_min_num - self.scan_num

            if sleep_count > 20:
                # Not finished within 20 polls: account for progress so far
                # and restart the poll count.
                # NOTE(review): the original comment said the max fetch size
                # is halved here, but fetch_count is left unchanged.
                self.scan_queue_remaining -= finish_count
                self.scan_num = plugin_scan_min_num
                sleep_count = 0
            elif sleep_count > 10:
                # 11-20 polls: once more than half of the queue is done, the
                # next fetch size becomes the number just finished.
                # NOTE(review): the original comment said "fetch size
                # unchanged", but fetch_count is set to finish_count.
                if self.scan_queue_remaining < finish_count * 2:
                    self.fetch_count = finish_count
                    break
            elif self.scan_queue_remaining == finish_count:
                # Everything finished within 10 polls: double the fetch size.
                self.fetch_count = self.scan_queue_remaining * 2
                break

        self.scan_queue_remaining -= finish_count
        self.scan_num = plugin_scan_min_num
        self.mark_id = plugin_scan_min_id