Example #1
    def start(self):
        if not self.db_oper.is_enabled():
            return

        repo_list = self.db_oper.get_repo_list()
        if repo_list is None:
            self.db_oper.close_db()
            return

        thread_pool = ThreadPool(self.scan_virus, self.settings.threads)
        thread_pool.start()

        for row in repo_list:
            repo_id, head_commit_id, scan_commit_id = row

            if head_commit_id == scan_commit_id:
                logger.debug('No changes for repo %.8s, skipping virus scan.',
                             repo_id)
                continue

            thread_pool.put_task(
                ScanTask(repo_id, head_commit_id, scan_commit_id))

        thread_pool.join()

        self.db_oper.close_db()
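
Both examples on this page drive their work through a ThreadPool with the same small interface: a constructor taking a worker callable and a thread count, plus start(), put_task() and join(). The sketch below only illustrates that interface under those assumptions; it is not the implementation these examples actually import, which also supports join(stop=True) (used in Example #2) and may pass extra per-thread arguments, such as the client argument of diff_and_scan_content, to the worker.

import logging
import threading
from queue import Queue


class ThreadPool(object):
    """Minimal sketch of a pool matching the calls used above (assumed API)."""

    _STOP = object()  # sentinel that tells a worker thread to exit

    def __init__(self, do_work, nworker=20):
        self.do_work = do_work          # worker callable, e.g. scan_virus
        self.nworker = nworker
        self.tasks = Queue()
        self.threads = []

    def start(self):
        for _ in range(self.nworker):
            t = threading.Thread(target=self._worker)
            t.daemon = True
            t.start()
            self.threads.append(t)

    def _worker(self):
        while True:
            task = self.tasks.get()
            try:
                if task is self._STOP:
                    break
                # The real pool may also pass per-thread state (e.g. a client).
                self.do_work(task)
            except Exception as e:
                logging.warning('Error when processing task: %s', e)
            finally:
                self.tasks.task_done()

    def put_task(self, task):
        self.tasks.put(task)

    def join(self, stop=False):
        # Block until every queued task has been processed; optionally shut
        # the worker threads down afterwards.
        self.tasks.join()
        if stop:
            for _ in self.threads:
                self.tasks.put(self._STOP)
            for t in self.threads:
                t.join()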
Example #2
class ContentScan(object):
    def __init__(self):
        self.thread_pool = ThreadPool(self.diff_and_scan_content, appconfig.thread_num)
        self.thread_pool.start()

    def start(self):
        try:
            self.do_scan_task()
        except Exception as e:
            logging.warning('Error: %s', e)

    def do_scan_task(self):
        logging.info("Start scan task..")
        time_start = time.time()

        # Truncate the current UTC time to second precision; this timestamp
        # marks every record touched by this scan run.
        dt = datetime.utcnow()
        dt_str = dt.strftime('%Y-%m-%d %H:%M:%S')
        self.dt = datetime.strptime(dt_str, '%Y-%m-%d %H:%M:%S')

        edb_session = appconfig.session_cls()
        seafdb_session = appconfig.seaf_session_cls()

        # Get repo list from seafile-db
        Branch = SeafBase.classes.Branch
        VirtualRepo = SeafBase.classes.VirtualRepo
        q = seafdb_session.query(Branch.repo_id, Branch.commit_id)
        q = q.outerjoin(VirtualRepo, Branch.repo_id==VirtualRepo.repo_id)
        q = q.filter(VirtualRepo.repo_id == None)
        results = q.all()
        for row in results:
            repo_id = row.repo_id
            new_commit_id = row.commit_id
            last_commit_id = None
            q = edb_session.query(ContentScanRecord.commit_id)
            q = q.filter(ContentScanRecord.repo_id==repo_id)
            result = q.first()
            if result:
                last_commit_id = result[0]

            self.put_task(repo_id, last_commit_id, new_commit_id)

        # Remove deleted repo's record after all threads finished
        self.thread_pool.join()
        q = edb_session.query(ContentScanRecord)
        q = q.filter(ContentScanRecord.timestamp != self.dt)
        q.delete()
        q = edb_session.query(ContentScanResult)
        subquery = edb_session.query(ContentScanRecord.repo_id)
        q = q.filter(ContentScanResult.repo_id.notin_(subquery))
        # 'fetch' is required because the delete references a subquery
        q.delete(synchronize_session='fetch')
        edb_session.commit()

        edb_session.close()
        seafdb_session.close()
        logging.info('Finished scan task, total time: %s seconds\n', str(time.time() - time_start))

        self.thread_pool.join(stop=True)

    def diff_and_scan_content(self, task, client):
        repo_id = task.repo_id
        last_commit_id = task.last_commit_id
        new_commit_id = task.new_commit_id
        edb_session = appconfig.session_cls()

        # repo not changed, update timestamp
        if last_commit_id == new_commit_id:
            q = edb_session.query(ContentScanRecord)
            q = q.filter(ContentScanRecord.repo_id==repo_id,
                         ContentScanRecord.commit_id==last_commit_id)
            q.update({"timestamp": self.dt})
            edb_session.commit()
            edb_session.close()
            return

        # diff
        version = 1
        new_commit = commit_mgr.load_commit(repo_id, version, new_commit_id)
        if new_commit is None:
            version = 0
            new_commit = commit_mgr.load_commit(repo_id, version, new_commit_id)
        if not new_commit:
            logging.warning('Failed to load commit %s/%s', repo_id, new_commit_id)
            edb_session.close()
            return
        last_commit = None
        if last_commit_id:
            last_commit = commit_mgr.load_commit(repo_id, version, last_commit_id)
            if not last_commit:
                logging.warning('Failed to load commit %s/%s', repo_id, last_commit_id)
                edb_session.close()
                return
        new_root_id = new_commit.root_id
        last_root_id = last_commit.root_id if last_commit else ZERO_OBJ_ID

        differ = CommitDiffer(repo_id, version, last_root_id, new_root_id,
                              True, False)
        (added_files, deleted_files, added_dirs, deleted_dirs, modified_files,
         renamed_files, moved_files, renamed_dirs, moved_dirs) = differ.diff_to_unicode()

        # Handle renamed, moved and deleted files.
        q = edb_session.query(ContentScanResult).filter(ContentScanResult.repo_id==repo_id)
        results = q.all()
        if results:
            path_pairs_to_rename = []
            paths_to_delete = []
            # renamed dirs
            for r_dir in renamed_dirs:
                r_path = r_dir.path + '/'
                prefix_len = len(r_path)
                for row in results:
                    if r_path == row.path[:prefix_len]:
                        new_path = r_dir.new_path + '/' + row.path[prefix_len:]
                        path_pairs_to_rename.append((row.path, new_path))
            # moved dirs
            for m_dir in moved_dirs:
                m_path = m_dir.path + '/'
                prefix_len = len(m_path)
                for row in results:
                    if m_path == row.path[:prefix_len]:
                        new_path = m_dir.new_path + '/' + row.path[prefix_len:]
                        path_pairs_to_rename.append((row.path, new_path))
            # renamed files
            for r_file in renamed_files:
                r_path = r_file.path
                for row in results:
                    if r_path == row.path:
                        new_path = r_file.new_path
                        path_pairs_to_rename.append((row.path, new_path))
            # moved files
            for m_file in moved_files:
                m_path = m_file.path
                for row in results:
                    if m_path == row.path:
                        new_path = m_file.new_path
                        path_pairs_to_rename.append((row.path, new_path))

            for old_path, new_path in path_pairs_to_rename:
                q = edb_session.query(ContentScanResult)
                q = q.filter(ContentScanResult.repo_id==repo_id, ContentScanResult.path==old_path)
                q = q.update({"path": new_path})

            # deleted files
            for d_file in deleted_files:
                d_path = d_file.path
                for row in results:
                    if d_path == row.path:
                        paths_to_delete.append(row.path)
            # We will scan modified_files and re-record later,
            # so delete previous records now
            for m_file in modified_files:
                m_path = m_file.path
                for row in results:
                    if m_path == row.path:
                        paths_to_delete.append(row.path)

            for path in paths_to_delete:
                q = edb_session.query(ContentScanResult)
                q = q.filter(ContentScanResult.repo_id==repo_id, ContentScanResult.path==path)
                q.delete()

            edb_session.commit()

        # scan added_files and modified_files by third-party API.
        files_to_scan = []
        files_to_scan.extend(added_files)
        files_to_scan.extend(modified_files)
        a_count = 0
        scan_results = []
        for f in files_to_scan:
            if not self.should_scan_file(f.path, f.size):
                continue
            seafile_obj = fs_mgr.load_seafile(repo_id, 1, f.obj_id)
            content = seafile_obj.get_content()
            if not content:
                continue
            result = client.scan(content)
            if result and isinstance(result, dict):
                item = {"path": f.path, "detail": result}
                scan_results.append(item)
            else:
                logging.warning('Failed to scan %s:%s', repo_id, f.path)

        for item in scan_results:
            detail = json.dumps(item["detail"])
            new_record = ContentScanResult(repo_id, item["path"], appconfig.platform, detail)
            edb_session.add(new_record)
            a_count += 1
        if a_count >= 1:
            logging.info('Found %d new illegal files.', a_count)

        # Update ContentScanRecord
        if last_commit_id:
            q = edb_session.query(ContentScanRecord).filter(ContentScanRecord.repo_id==repo_id)
            q.update({"commit_id": new_commit_id, "timestamp": self.dt})
        else:
            new_record = ContentScanRecord(repo_id, new_commit_id, self.dt)
            edb_session.add(new_record)

        edb_session.commit()
        edb_session.close()

    def put_task(self, repo_id, last_commit_id, new_commit_id):
        task = ScanTask(repo_id, last_commit_id, new_commit_id)
        self.thread_pool.put_task(task)

    def should_scan_file(self, fpath, fsize):
        if fsize > appconfig.size_limit:
            return False

        filename, suffix = splitext(fpath)
        if suffix[1:] not in appconfig.suffix_list:
            return False

        return True
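
Both examples enqueue ScanTask objects. Judging only from the attributes that diff_and_scan_content reads (task.repo_id, task.last_commit_id, task.new_commit_id), a minimal container could look like the sketch below; this is an assumption, and the ScanTask used in Example #1 is built analogously from head_commit_id and scan_commit_id.

class ScanTask(object):
    """Hypothetical minimal task object, inferred from the attribute access
    in diff_and_scan_content above; the real class may carry more state."""

    def __init__(self, repo_id, last_commit_id, new_commit_id):
        self.repo_id = repo_id
        self.last_commit_id = last_commit_id
        self.new_commit_id = new_commit_id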
Example #3
class Crawler(object):

    def __init__(self, args):
        self.thread_num = args.thread_num
        self.output = args.output
        if not os.path.exists(self.output):
            os.mkdir(self.output)

        self.domain_pattern = re.compile(
            r"^([0-9a-zA-Z][0-9a-zA-Z-]{0,62}\.)+([0-9a-zA-Z][0-9a-zA-Z-]{0,62})\.?$")

    def _init(self):
        # Thread pool with the configured number of worker threads
        self.thread_pool = ThreadPool(self.thread_num)
        self.depth = 2
        # Current crawl depth, starting from 1
        self.current_depth = 1
        # Links that have already been visited
        self.visited_hrefs = set()
        # Links waiting to be visited
        self.unvisited_hrefs = deque()
        # Whether the crawler is currently running
        self.is_crawling = False
        self.resource_details = ResourceDetailCollection()

    def _format_url(self, raw_value):
        raw_value_str = raw_value.strip()
        if not raw_value_str:
            return ''
        if not self.domain_pattern.match(raw_value_str):
            return ''
        if not raw_value_str.startswith('http'):
            value = 'http://' + raw_value_str
        else:
            value = raw_value_str
        return value

    def crawl(self, url):
        self._init()
        formatted_url = self._format_url(url)
        self.resource_details.set_main_frame_url(formatted_url)
        self.unvisited_hrefs.append(formatted_url)
        print '\nStart crawling url %s\n' % formatted_url
        self.is_crawling = True
        self.thread_pool.start_threads()
        while self.current_depth < self.depth + 1:
            # Dispatch tasks: the thread pool downloads every page at the
            # current depth concurrently (this call does not block).
            self._assign_current_depth_tasks()
            # Wait for the pool to finish all queued tasks; once the pool is
            # drained, one crawl depth is complete.
            # self.thread_pool.task_join() would work as well, but then the
            # wait could not be interrupted with Ctrl-C.
            while self.thread_pool.get_task_left():
                time.sleep(8)
            print 'Depth %d finished. Visited %d links in total.\n' % (
                self.current_depth, len(self.visited_hrefs))
            log.info('Depth %d finished. Total visited links: %d\n' % (
                self.current_depth, len(self.visited_hrefs)))
            self.current_depth += 1
        # After finishing all the tasks, stop crawling.
        print 'All tasks have finished'
        self._on_all_tasks_finished()
        self.stop()

    def stop(self):
        self.is_crawling = False
        self.thread_pool.stop_threads()

    def get_already_visited_num(self):
        # visited_hrefs holds links that have already been handed to the task
        # queue, so some of them may still be in flight. The real number of
        # visited links is therefore len(visited_hrefs) minus the number of
        # tasks still pending.
        return len(self.visited_hrefs) - self.thread_pool.get_task_left()

    def _on_all_tasks_finished(self):
        resource_detail_data = unicode(json.dumps(
            self.resource_details.to_json_data(), indent=4))
        hashed_file_name = hashlib.new(
            "md5", self.resource_details.main_frame_url).hexdigest() + ".json"
        resource_detail_data_path = os.path.join(self.output, hashed_file_name)
        with io.open(resource_detail_data_path, 'w') as f:
            f.write(resource_detail_data)

    def _assign_current_depth_tasks(self):
        mylock.acquire()
        copied_unvisited_hrefs = deque()
        while self.unvisited_hrefs:
            copied_unvisited_hrefs.append(self.unvisited_hrefs.popleft())
        mylock.release()
        while copied_unvisited_hrefs:
            url = copied_unvisited_hrefs.popleft()
            # Mark the link as visited (or about to be visited) so the same
            # link is never queued twice.
            self.visited_hrefs.add(url)
            # Hand the task over to the thread pool's task queue.
            self.thread_pool.put_task(self._task_handler, url)

    def _task_handler(self, url):
        # Fetching the page source and saving the results are both blocking
        # operations, so they are done inside a worker thread.
        url_fetcher = URLFetcher(url)
        retry = 1
        if url_fetcher.fetch(retry):
            self._save_task_results(url, url_fetcher)
            self._add_unvisited_hrefs(url_fetcher)

    def _save_task_results(self, url, url_fetcher):
        print 'Visited URL: %s\n' % url
        response_headers = url_fetcher.get_response_headers()
        response_detail = ResourceDetail(url,
                                         url_fetcher.request_time,
                                         url_fetcher.response_time,
                                         response_headers)
        mylock.acquire()
        self.resource_details.add_detail(response_detail)
        mylock.release()

    def _add_unvisited_hrefs(self, url_fetcher):
        '''Add unvisited links: put every valid URL into unvisited_hrefs.'''
        # Filter the links: keep only http/https URLs and make sure each
        # link is visited at most once.
        url, page_source = url_fetcher.get_data()
        hrefs = self.get_all_resource_hrefs(url, page_source)
        mylock.acquire()
        for href in hrefs:
            if self._is_http_or_https_protocol(href):
                if not self._is_href_repeated(href):
                    self.unvisited_hrefs.append(href)
        mylock.release()

    def get_all_resource_hrefs(self, url, page_source):
        '''Parse the HTML source and return a list of all resource links on the page.'''
        hrefs = []
        soup = BeautifulSoup(page_source)
        results = soup.find_all(True)

        for tag in results:
            href = None
            if tag.name == 'a':
                continue
            # The link has to be encoded as utf8: Chinese file links such as
            # http://aa.com/文件.pdf are not URL-encoded automatically by bs4,
            # which would otherwise raise an encoding exception.
            if tag.has_attr('href'):
                href = tag.get('href').encode('utf8')
            elif tag.has_attr('src'):
                href = tag.get('src').encode('utf8')
            if href is not None:
                if not href.startswith('http'):
                    href = urljoin(url, href)  # resolve relative links
                hrefs.append(href)
        return hrefs

    def _is_http_or_https_protocol(self, href):
        protocol = urlparse(href).scheme
        return protocol in ('http', 'https')

    def _is_href_repeated(self, href):
        if href in self.visited_hrefs or href in self.unvisited_hrefs:
            return True
        return False
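
The Crawler constructor only reads args.thread_num and args.output, so it can be driven with any argparse-style namespace. The snippet below is a hypothetical driver; the option names and defaults are assumptions, not part of the original code.

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Crawl a site and dump resource details')
    parser.add_argument('--thread-num', type=int, default=10)   # worker thread count
    parser.add_argument('--output', default='output')           # directory for the JSON reports
    parser.add_argument('url')                                   # start URL or bare domain
    args = parser.parse_args()

    crawler = Crawler(args)
    crawler.crawl(args.url)   # crawl() calls stop() itself when it is done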