예제 #1
0
    def test_commit(self):
        """Database.commit() must forward exactly one commit to its connection."""
        conn = self.connection_mock
        conn.commit = MagicMock()

        database = Database('my_db')
        database.conn = conn
        database.commit()

        conn.commit.assert_called_once()
예제 #2
0
    def test_close(self):
        """Database.close() must close its connection exactly once."""
        self.connection_mock.close = MagicMock()
        close_spy = self.connection_mock.close

        database = Database('my_db')
        database.conn = self.connection_mock
        database.close()

        close_spy.assert_called_once()
예제 #3
0
    def __init__(self, db_file, num_threads, request_types, process_timeout,
                 scanner_exe, display_progress, scanner_argv):
        """Set up the scanner, run it to completion and exit the process.

        Truthy arguments override the defaults returned by get_settings().
        An assessment is created in the database, requests whose
        RequestPattern was already seen are recorded in
        ``_duplicated_requests``, and ``num_threads`` Executor threads are
        started and awaited (on ^C kill_threads() is called). Finally the
        assessment is saved and sys.exit() is called with the scanner's exit
        code, so this constructor never returns normally.

        :param db_file: path of the crawl database
        :param num_threads: number of executor threads (falsy = default)
        :param request_types: request types to scan (falsy = default)
        :param process_timeout: external process timeout (falsy = default)
        :param scanner_exe: scanner command line (falsy = default)
        :param display_progress: show a progress bar while waiting
        :param scanner_argv: extra arguments passed to self.init()
        """
        self.scan_start_time = int(time.time())
        self.threads = []
        self._th_lock = threading.Lock()
        self._th_lock_db = threading.Lock()
        self.performed_requests = 0
        self._urlpatterns = []
        self._exitcode = 0
        self.scanner_name = self.__class__.__name__.lower()
        self._running = False
        self.settings = self.get_settings()

        # override default settings with the explicitly given values
        if num_threads: self.settings['num_threads'] = num_threads
        if request_types: self.settings['request_types'] = request_types
        if process_timeout: self.settings['process_timeout'] = process_timeout
        if scanner_exe: self.settings['scanner_exe'] = scanner_exe
        # scanner_exe is a command line string; split it into an argv list
        self.settings['scanner_exe'] = self.settings['scanner_exe'].split(" ")

        self.db = Database(db_file)
        self.id_assessment = self.db.create_assessment(self.scanner_name,
                                                       int(time.time()))
        self.pending_requests = self.db.get_requests(
            self.settings['request_types'])
        self.tot_requests = len(self.pending_requests)
        self._duplicated_requests = []

        # the first request with a given pattern is kept, later ones are
        # recorded as duplicated (a list is used because the hashability of
        # RequestPattern.pattern is not guaranteed)
        urlpatterns = []
        for req in self.pending_requests:
            patt = RequestPattern(req).pattern
            if patt in urlpatterns:
                self._duplicated_requests.append(req.db_id)
            else:
                urlpatterns.append(patt)

        self.init(scanner_argv if scanner_argv else [])

        self._running = True
        print("Scanner %s started with %d threads" % (
            self.scanner_name, self.settings['num_threads']))

        for _ in range(self.settings['num_threads']):
            thread = self.Executor(self)
            self.threads.append(thread)
            thread.start()

        try:
            self.wait_executor(self.threads, display_progress)
        except KeyboardInterrupt:
            print("\nTerminated by user")
            self.kill_threads()

        self.save_assessment()
        sys.exit(self._exitcode)
예제 #4
0
    def init_db(self, dbname, report_name):
        """Create the report database for this crawl and return it.

        The stored scan info contains the target URL and the command line;
        scan_date, urls_scanned and scan_time are set to -1.
        """
        infos = dict(
            target=Shared.starturl,
            scan_date=-1,
            urls_scanned=-1,
            scan_time=-1,
            command_line=" ".join(sys.argv),
        )
        db = Database(dbname, report_name, infos)
        db.create()
        return db
예제 #5
0
파일: crawler.py 프로젝트: Vietworm/htcap
	def init_db(self, dbname, report_name):
		"""Create the report database for this crawl and return it.

		Only the target URL and the command line are filled in here;
		scan_date, urls_scanned and scan_time are stored as -1.
		"""
		command_line = " ".join(sys.argv)
		infos = {
			"target": Shared.starturl,
			"scan_date": -1,
			"urls_scanned": -1,
			"scan_time": -1,
			"command_line": command_line,
		}
		database = Database(dbname, report_name, infos)
		database.create()
		return database
예제 #6
0
    def test_connect(self):
        """connect() must open the database through sqlite3.connect(dbname)."""
        sqlite3_mock = MagicMock()
        row_factory_mock = PropertyMock(return_value=None)
        type(sqlite3_mock).row_factory = row_factory_mock

        # sqlite3.connect is replaced module-wide; keep the real function and
        # restore it afterwards so later tests are not affected by the mock.
        real_connect = sqlite3.connect
        sqlite3.connect = sqlite3_mock
        try:
            db = Database('my_db')
            db.connect()

            sqlite3_mock.assert_called_with('my_db')
            self.assertIsInstance(db.conn, MagicMock)
        finally:
            sqlite3.connect = real_connect
예제 #7
0
	def __init__(self, db_file, num_threads, request_types, process_timeout, scanner_exe, display_progress, scanner_argv):
		"""Set up the scanner, run it to completion and exit the process.

		Truthy arguments override the defaults returned by get_settings().
		An assessment is created in the database and, among GET requests,
		every request whose URL pattern (self.get_url_pattern) was already
		seen is recorded in _duplicated_requests. num_threads Executor
		threads are then started and awaited (on ^C kill_threads() is
		called). Finally the assessment is saved and sys.exit() is called
		with the scanner's exit code, so this constructor never returns
		normally.

		:param db_file: path of the crawl database
		:param num_threads: number of executor threads (falsy = default)
		:param request_types: request types to scan (falsy = default)
		:param process_timeout: external process timeout (falsy = default)
		:param scanner_exe: scanner command line (falsy = default)
		:param display_progress: show a progress bar while waiting
		:param scanner_argv: extra arguments passed to self.init()
		"""
		self.scan_start_time = int(time.time())
		self.threads = []
		self._th_lock = threading.Lock()
		self._th_lock_db = threading.Lock()
		self.performed_requests = 0
		self._urlpatterns = []
		self._exitcode = 0
		self.scanner_name = self.__class__.__name__.lower()
		self._running = False
		self.settings = self.get_settings()

		# override default settings with the explicitly given values
		if num_threads: self.settings['num_threads'] = num_threads
		if request_types: self.settings['request_types'] = request_types
		if process_timeout: self.settings['process_timeout'] = process_timeout
		if scanner_exe: self.settings['scanner_exe'] = scanner_exe
		# scanner_exe is a command line string; split it into an argv list
		self.settings['scanner_exe'] = self.settings['scanner_exe'].split(" ")

		self.db = Database(db_file)
		self.id_assessment = self.db.create_assessment(self.scanner_name, int(time.time()))
		self.pending_requests = self.db.get_requests(self.settings['request_types'])
		self.tot_requests = len(self.pending_requests)
		self._duplicated_requests = []

		# only GET requests are de-duplicated: the first one with a given URL
		# pattern is kept, later ones are recorded as duplicated
		urlpatterns = []
		for req in self.pending_requests:
			if req.method != "GET":
				continue
			pat = self.get_url_pattern(req.url)
			if pat in urlpatterns:
				self._duplicated_requests.append(req.db_id)
			else:
				urlpatterns.append(pat)

		self.init(scanner_argv if scanner_argv else [])

		self._running = True
		print("Scanner %s started with %d threads" % (self.scanner_name, self.settings['num_threads']))

		for _ in range(self.settings['num_threads']):
			thread = self.Executor(self)
			self.threads.append(thread)
			thread.start()

		try:
			self.wait_executor(self.threads, display_progress)
		except KeyboardInterrupt:
			print("\nTerminated by user")
			self.kill_threads()

		self.save_assessment()
		sys.exit(self._exitcode)
예제 #8
0
    def _get_database(outfile_name, output_mode):
        """
        Return either an existing database or a new one, depending on the
        given output mode.

        - CRAWLOUTPUT_RENAME: a fresh file name is derived from outfile_name
          with generate_filename(..., out_file_overwrite=False).
        - CRAWLOUTPUT_OVERWRITE: an existing outfile_name is deleted first.
        - any other mode: outfile_name is used as-is.

        The database is initialized only when the file is missing or empty,
        so an existing non-empty database is reused untouched.

        :param outfile_name: requested output file path
        :param output_mode: output mode, presumably one of the CRAWLOUTPUT_*
            constants
        :return: a Database bound to the chosen file
        """
        file_name = outfile_name
        if output_mode == CRAWLOUTPUT_RENAME:
            file_name = generate_filename(outfile_name,
                                          out_file_overwrite=False)
        elif output_mode == CRAWLOUTPUT_OVERWRITE and os.path.exists(
                file_name):
            os.remove(file_name)

        database = Database(file_name)

        # `or` short-circuits, so getsize() only runs when the file exists
        if not os.path.exists(file_name) or os.path.getsize(file_name) <= 0:
            database.initialize()

        return database
예제 #9
0
    def setUp(self):
        """Create a Database whose connection and lifecycle methods are mocks."""
        cursor = MagicMock()
        cursor.execute = MagicMock()
        cursor.fetchone = MagicMock()
        cursor.fetchall = MagicMock(return_value=[])
        self.cursor_mock = cursor

        connection = MagicMock()
        connection.cursor = MagicMock(return_value=cursor)
        self.connection_mock = connection

        self.connect_method_mock = MagicMock()
        self.commit_method_mock = MagicMock()
        self.close_method_mock = MagicMock()

        self.db = Database('my_db')
        # swap in the mocks so tests never reach a real connection
        self.db.conn = connection
        self.db.connect = self.connect_method_mock
        self.db.commit = self.commit_method_mock
        self.db.close = self.close_method_mock
예제 #10
0
파일: scanner.py 프로젝트: zumb08/htcap
	def __init__(self, argv, db_file=None):
		"""Parse the scan command line and run the selected scanner.

		argv is parsed as ``[options] scanner db_file [scanner_args...]``.
		When db_file is passed explicitly it takes precedence over the one in
		argv, which may then be omitted.

		Proxy, cookies, user agent and extra headers are first loaded from the
		crawl info stored in the database and can be overridden with the -p,
		-c, -U and -E options. Usage errors terminate the process.
		"""
		self.scanners = self.get_modules_file(os.path.join(getrealdir(__file__), "scanners"))

		num_threads = None
		request_types = None
		display_progress = True
		modules_path = None
		proxy = None
		cookies = None
		user_agent = None
		extra_headers = None

		try:
			opts, args = getopt.getopt(argv, 'hn:r:vm:p:U:c:E:')
		except getopt.GetoptError as err:
			print(str(err))
			sys.exit(1)

		# -h and -m are processed first: -m registers extra scanner modules,
		# which must be known before the scanner name is validated
		for o, v in opts:
			if o == '-h':
				self.usage()
				sys.exit(0)
			elif o == '-m':
				modules_path = v
				self.scanners.extend(self.get_modules_file(modules_path))
				sys.path.append(modules_path)

		if len(args) < 2:
			if not db_file or len(args) == 0:
				self.usage()
				sys.exit(1)
			args.append(db_file)

		self.scanner = args[0]
		self.db_file = args[1] if not db_file else db_file
		scanner_argv = args[2:]

		# Validate before touching the database: Database presumably opens the
		# file with sqlite3.connect, which silently creates a missing file, so
		# an existence check done after reading crawl info could never fail.
		if self.scanner not in self.scanners:
			print("Available scanners are:\n  %s" % "\n  ".join(sorted(self.scanners)))
			sys.exit(1)

		if not os.path.exists(self.db_file):
			print("No such file %s" % self.db_file)
			sys.exit(1)

		db = Database(self.db_file)
		crawl_info = db.get_crawl_info()
		try:
			proxy = json.loads(crawl_info['proxy'])
			cookies = json.loads(crawl_info['cookies'])
			extra_headers = json.loads(crawl_info['extra_headers'])
			if not extra_headers:
				extra_headers = {}
			user_agent = crawl_info['user_agent']
		except KeyError:
			print("Unable to read proxy, cookies and user_agent from db.. maybe db created vith an old version . . .")

		# command line options override the values stored in the database
		for o, v in opts:
			if o == '-n':
				num_threads = int(v)
			elif o == '-v':
				display_progress = False
			elif o == '-r':
				request_types = v
			elif o == '-p':
				if v == "0":
					proxy = None
				else:
					try:
						proxy = parse_proxy_string(v)
					except Exception as e:
						print(e)
						sys.exit(1)
			elif o == '-c':
				try:
					cookies = parse_cookie_string(v)
				except Exception:
					print("Unable to decode cookies")
					sys.exit(1)
			elif o == '-U':
				user_agent = v
			elif o == '-E':
				# a value without "=" cannot be split into name and value
				if "=" not in v:
					print("Invalid header %s (expected name=value)" % v)
					sys.exit(1)
				if not extra_headers:
					extra_headers = {}
				hn, hv = v.split("=", 1)
				extra_headers[hn] = hv

		try:
			mod = importlib.import_module("core.scan.scanners.%s" % self.scanner)
		except Exception:
			# fall back to a module found in the -m path, if one was given
			if not modules_path:
				raise
			mod = importlib.import_module(self.scanner)

		try:
			run = getattr(mod, self.scanner.title())
			run(self.db_file, num_threads, request_types, display_progress, scanner_argv, proxy, cookies, user_agent, extra_headers)
		except Exception as e:
			print("Error : %s" % e)
			return

		print("Scan finished")
예제 #11
0
    def __init__(self, db_file, num_threads, request_types, display_progress,
                 scanner_argv, proxy, cookies, user_agent, extra_headers):
        """Run the whole scan described by the arguments.

        Requests are loaded from the crawl database, requests sharing an
        already seen RequestPattern are recorded in ``_duplicated_requests``,
        ``num_threads`` Executor threads are started and awaited with
        wait_executor(). If wait_threads_exit() reports a forced exit, the
        commands in ``_commands`` are killed and the process ends with
        os._exit(1); otherwise end() and save_assessment() are called.

        :param db_file: path of the crawl database
        :param num_threads: number of executor threads (falsy = default)
        :param request_types: request types to scan (falsy = default)
        :param display_progress: show a progress bar instead of verbose output
        :param scanner_argv: extra arguments passed to self.init()
        :param proxy: stored as ``self.proxy``
        :param cookies: stored as ``self.cookies``
        :param user_agent: stored as ``self.user_agent``
        :param extra_headers: stored as ``self.extra_headers``
        """
        self.scan_start_time = int(time.time())
        self.threads = []
        self.lock = threading.Lock()
        self._th_lock = threading.Lock()
        self._th_lock_db = threading.Lock()
        self._th_lock_stdout = threading.Lock()
        self.performed_requests = 0
        self._urlpatterns = []
        self._exitcode = 0
        self._commands = []
        self.scanner_name = self.__class__.__name__.lower()
        self._running = False
        self.settings = self.get_settings()
        self.exit_requested = False
        self.pause_requested = False
        self._print_queue = {}
        self.display_progress = display_progress

        # override default settings with the explicitly given values
        if num_threads:
            self.settings['num_threads'] = num_threads
        if request_types:
            self.settings['request_types'] = request_types

        self._db = Database(db_file)
        self.id_assessment = self._db.create_assessment(
            self.scanner_name, int(time.time()))
        self.pending_requests = self._db.get_requests(
            self.settings['request_types'])
        self.tot_requests = len(self.pending_requests)
        self._duplicated_requests = []

        self.proxy = proxy
        self.cookies = cookies
        self.user_agent = user_agent
        self.extra_headers = extra_headers

        self.utils = ScannerUtils(self)

        # the first request with a given pattern is kept, later ones are
        # recorded as duplicated
        urlpatterns = []
        for req in self.pending_requests:
            patt = RequestPattern(req).pattern
            if patt in urlpatterns:
                self._duplicated_requests.append(req.db_id)
            else:
                urlpatterns.append(patt)

        self.init(scanner_argv if scanner_argv else [])

        self._running = True
        print(
            "Scanner %s started with %d threads (^C to pause or change verbosity)"
            % (self.scanner_name, self.settings['num_threads']))

        for _ in range(self.settings['num_threads']):
            thread = self.Executor(self)
            self.threads.append(thread)
            thread.start()

        self.wait_executor(self.threads)

        if not self.wait_threads_exit():
            # forced exit: kill the commands still running and quit at once
            with self._th_lock:
                for cmd in self._commands:
                    if cmd:
                        cmd.kill()
            os._exit(1)

        self.end()

        self.save_assessment()
예제 #12
0
class BaseScanner:
    """Base class for native (in-process) scanners.

    Instantiating a subclass runs the whole scan: the requests to scan are
    loaded from the crawl database, requests sharing an already seen
    RequestPattern are recorded as duplicated, ``num_threads`` Executor
    threads are started and the constructor waits for them (^C opens a
    pause menu).

    Subclasses must provide ``init(argv)`` and a nested ``Scan`` class whose
    instances are built as ``Scan(executor, request)`` and expose ``run()``.
    """

    def __init__(self, db_file, num_threads, request_types, display_progress,
                 scanner_argv, proxy, cookies, user_agent, extra_headers):
        """Run the whole scan described by the arguments.

        :param db_file: path of the crawl database
        :param num_threads: number of executor threads (falsy = default)
        :param request_types: request types to scan (falsy = default)
        :param display_progress: show a progress bar instead of verbose output
        :param scanner_argv: extra arguments passed to self.init()
        :param proxy: stored as ``self.proxy``
        :param cookies: stored as ``self.cookies``
        :param user_agent: stored as ``self.user_agent``
        :param extra_headers: stored as ``self.extra_headers``
        """
        self.scan_start_time = int(time.time())
        self.threads = []
        self.lock = threading.Lock()
        # guards pending_requests, performed_requests and the thread flags
        self._th_lock = threading.Lock()
        # serializes every access to self._db
        self._th_lock_db = threading.Lock()
        # guards self._print_queue
        self._th_lock_stdout = threading.Lock()
        self.performed_requests = 0
        self._urlpatterns = []
        self._exitcode = 0
        self._commands = []
        self.scanner_name = self.__class__.__name__.lower()
        self._running = False
        self.settings = self.get_settings()
        self.exit_requested = False
        self.pause_requested = False
        self._print_queue = {}
        self.display_progress = display_progress

        # override default settings with the explicitly given values
        if num_threads:
            self.settings['num_threads'] = num_threads
        if request_types:
            self.settings['request_types'] = request_types

        self._db = Database(db_file)
        self.id_assessment = self._db.create_assessment(
            self.scanner_name, int(time.time()))
        self.pending_requests = self._db.get_requests(
            self.settings['request_types'])
        self.tot_requests = len(self.pending_requests)
        self._duplicated_requests = []

        self.proxy = proxy
        self.cookies = cookies
        self.user_agent = user_agent
        self.extra_headers = extra_headers

        self.utils = ScannerUtils(self)

        # the first request with a given pattern is kept, later ones are
        # recorded as duplicated
        urlpatterns = []
        for req in self.pending_requests:
            patt = RequestPattern(req).pattern
            if patt in urlpatterns:
                self._duplicated_requests.append(req.db_id)
            else:
                urlpatterns.append(patt)

        self.init(scanner_argv if scanner_argv else [])

        self._running = True
        print(
            "Scanner %s started with %d threads (^C to pause or change verbosity)"
            % (self.scanner_name, self.settings['num_threads']))

        for _ in range(self.settings['num_threads']):
            thread = self.Executor(self)
            self.threads.append(thread)
            thread.start()

        self.wait_executor(self.threads)

        if not self.wait_threads_exit():
            # forced exit: kill the commands still running and quit at once
            with self._th_lock:
                for cmd in self._commands:
                    if cmd:
                        cmd.kill()
            os._exit(1)

        self.end()

        self.save_assessment()

    def end(self):
        """Hook called after all threads have exited; does nothing by default."""
        pass

    def get_settings(self):
        """Return the default settings (request types and thread count)."""
        return dict(
            request_types="xhr,fetch,link,redirect,form,json",
            num_threads=10,
        )

    def wait_executor(self, threads):
        """Wait for the executor threads while refreshing the output.

        Between one-second joins either the progress bar is redrawn
        (display_progress true) or the queued verbose lines are printed.
        ^C pauses all threads and shows the runtime menu; if the user chooses
        to exit, request_exit() is called and this method returns.
        """
        executor_done = False
        pb = Progressbar(self.scan_start_time, "requests scanned")

        while not executor_done:
            try:
                executor_done = True
                for th in threads:
                    if self.display_progress:
                        with self._th_lock:
                            scanned = self.performed_requests
                            tot = self.tot_requests
                        pb.out(tot, scanned)
                    else:
                        with self._th_lock_stdout:
                            for lines in self._print_queue.values():
                                for out in lines:
                                    print(out)
                            self._print_queue = {}
                    if th.is_alive():
                        executor_done = False
                    th.join(1)
            except KeyboardInterrupt:
                # `with` already released any lock held when ^C arrived; the
                # locks must not be released blindly here, since they may be
                # owned by an executor thread at this point.
                self.pause_threads(True)
                if not self.get_runtime_command():
                    print("Exiting . . .")
                    self.request_exit()
                    return
                print("Scan is running")
                self.pause_threads(False)
        if self.display_progress:
            print("")

    def get_runtime_command(self):
        """Show the pause menu and apply the user's choice.

        :return: True after 'r' (resume), 'v' (verbose) or 'p' (progress bar);
            False if ^C is pressed at the prompt (exit requested)
        """
        while True:
            print("\nScan is paused. Choose what to do:\n"
                  "   r    resume scan\n"
                  "   v    verbose mode\n"
                  "   p    show progress bar\n"
                  "Hit ctrl-c again to exit\n")
            try:
                ui = input("> ").strip()
            except KeyboardInterrupt:
                print("")
                return False
            if ui == "r":
                break
            elif ui == "v":
                self.display_progress = False
                break
            elif ui == "p":
                self.display_progress = True
                break

            print(" ")

        return True

    def request_exit(self):
        """Ask every running executor thread to stop and terminate the commands.

        Paused threads are un-paused so they can notice the exit flag.
        """
        self.exit_requested = True
        self.pause_requested = False
        with self._th_lock:
            alive = [th for th in self.threads if th.is_alive()]
            for th in alive:
                th.exit = True
                th.pause = False
            # terminate each running command once, not once per live thread
            if alive:
                for cmd in self._commands:
                    if cmd:
                        cmd.terminate()

    def kill_threads(self):
        """Stop the executor threads; used by exit(). Delegates to request_exit()."""
        self.request_exit()

    def wait_threads_exit(self):
        """Wait until every executor thread has terminated.

        After about two seconds the number of threads still running is shown.
        On ^C the user is asked whether to force the exit.

        :return: False if the user answered 'y' to the force-exit prompt;
            True once all threads have ended, or if ^C is pressed again at
            the prompt
        """
        waittime = 0.0
        msg = ""
        while True:
            try:
                alive = sum(1 for th in self.threads if th.is_alive())
                if alive == 0:
                    break
                if waittime > 2:
                    stdoutw("\b" * len(msg))
                    msg = "Waiting %d requests to be completed" % alive
                    stdoutw(msg)
                waittime += 0.1
                time.sleep(0.1)
            except KeyboardInterrupt:
                try:
                    die = input("\nForce exit? [y/N] ").strip()
                    if die == "y":
                        return False
                except KeyboardInterrupt:
                    return True
        print("")
        return True

    def pause_threads(self, pause):
        """Set the pause flag of every live executor thread to ``pause``."""
        self.pause_requested = pause
        with self._th_lock:
            for th in self.threads:
                if th.is_alive():
                    th.pause = pause

    def _sprint(self, id, str):
        """Queue a verbose output line under key ``id`` (printed by wait_executor)."""
        with self._th_lock_stdout:
            if not self._print_queue.get(id):
                self._print_queue[id] = []
            self._print_queue[id].append(str)

    def exit(self, code):
        """Stop the scan with exit code ``code``.

        While running, the code is recorded and the threads are asked to
        stop; otherwise sys.exit(code) is called directly.
        """
        if self._running:
            with self._th_lock:
                self._exitcode = code
            self.kill_threads()
            print("kill thread")
            print("")
        else:
            sys.exit(code)

    def db(self, method, params):
        """Call ``self._db.<method>(*params)`` under the database lock."""
        with self._th_lock_db:
            return getattr(self._db, method)(*params)

    def save_vulnerability(self, request, type, description):
        """Store one vulnerability found on ``request`` for this assessment."""
        with self._th_lock_db:
            self._db.insert_vulnerability(self.id_assessment, request.db_id,
                                          type, description)

    def save_vulnerabilities(self, request, vulnerabilities):
        """Store several vulnerabilities found on ``request`` for this assessment."""
        with self._th_lock_db:
            self._db.insert_vulnerabilities(self.id_assessment,
                                            request.db_id, vulnerabilities)

    def save_assessment(self):
        """Record the current time as the end of this assessment."""
        with self._th_lock_db:
            self._db.save_assessment(self.id_assessment, int(time.time()))

    def is_request_duplicated(self, request):
        """Return True if ``request`` shares its pattern with an earlier request."""
        return request.db_id in self._duplicated_requests

    class Executor(threading.Thread):
        """Worker thread that pops pending requests and scans them."""

        def __init__(self, scanner):
            threading.Thread.__init__(self)
            self.scanner = scanner
            self.exit = False
            self.pause = False
            self.thread_uuid = uuid.uuid4()
            # private per-thread scratch directory, removed when run() ends
            self.tmp_dir = "%s%shtcap_tempdir-%s" % (tempfile.gettempdir(),
                                                     os.sep, self.thread_uuid)
            os.makedirs(self.tmp_dir, 0o700)

        def inc_counter(self):
            """Increment the scanner's performed_requests counter."""
            with self.scanner._th_lock:
                self.scanner.performed_requests += 1

        def run(self):
            """Scan pending requests until none are left or exit is requested.

            While paused the thread polls every second. The temporary
            directory is removed when the thread stops, even if a scan raises.
            """
            try:
                while True:
                    with self.scanner._th_lock:
                        if self.exit or not self.scanner.pending_requests:
                            return
                        paused = self.pause
                        if not paused:
                            req = self.scanner.pending_requests.pop()
                    if paused:
                        time.sleep(1)
                        continue

                    self.scanner.Scan(self, req).run()
                    self.inc_counter()
            finally:
                # best-effort cleanup; must not mask an exception from a scan
                shutil.rmtree(self.tmp_dir, ignore_errors=True)
예제 #13
0
class BaseScanner:
	"""Base class for scanners that drive an external command-line tool.

	Instantiating a subclass runs the whole scan: the requests to scan are
	loaded from the crawl database, requests sharing an already seen
	RequestPattern are recorded as duplicated and num_threads Executor
	threads are started. Each executor pops a request, runs
	settings['scanner_exe'] + get_cmd(request, tmp_dir) through
	CommandExecutor and hands the output to scanner_executed().

	Subclasses must provide init(argv). The default get_settings(),
	get_cmd() and scanner_executed() are placeholders (empty scanner_exe,
	no options, no-op hook).
	"""

	def __init__(self, db_file, num_threads, request_types, process_timeout, scanner_exe, display_progress, scanner_argv):
		"""Run the scan to completion and exit with the scanner's exit code.

		Truthy arguments override the defaults returned by get_settings().
		On ^C the threads are asked to stop via kill_threads(). The
		assessment is saved and sys.exit() is called, so this constructor
		never returns normally.

		:param db_file: path of the crawl database
		:param num_threads: number of executor threads (falsy = default)
		:param request_types: request types to scan (falsy = default)
		:param process_timeout: external process timeout (falsy = default)
		:param scanner_exe: scanner command line (falsy = default)
		:param display_progress: show a progress bar while waiting
		:param scanner_argv: extra arguments passed to self.init()
		"""
		self.scan_start_time = int(time.time())
		self.threads = []
		# guards pending_requests, performed_requests and the thread flags
		self._th_lock = threading.Lock()
		# serializes every access to self.db
		self._th_lock_db = threading.Lock()
		self.performed_requests = 0
		self._urlpatterns = []
		self._exitcode = 0
		self.scanner_name = self.__class__.__name__.lower()
		self._running = False
		self.settings = self.get_settings()

		# override default settings with the explicitly given values
		if num_threads: self.settings['num_threads'] = num_threads
		if request_types: self.settings['request_types'] = request_types
		if process_timeout: self.settings['process_timeout'] = process_timeout
		if scanner_exe: self.settings['scanner_exe'] = scanner_exe
		# scanner_exe is a command line string; split it into an argv list
		self.settings['scanner_exe'] = self.settings['scanner_exe'].split(" ")

		self.db = Database(db_file)
		self.id_assessment = self.db.create_assessment(self.scanner_name, int(time.time()))
		self.pending_requests = self.db.get_requests(self.settings['request_types'])
		self.tot_requests = len(self.pending_requests)
		self._duplicated_requests = []

		# the first request with a given pattern is kept, later ones are
		# recorded as duplicated
		urlpatterns = []
		for req in self.pending_requests:
			patt = RequestPattern(req).pattern
			if patt in urlpatterns:
				self._duplicated_requests.append(req.db_id)
			else:
				urlpatterns.append(patt)

		self.init(scanner_argv if scanner_argv else [])

		self._running = True
		print("Scanner %s started with %d threads" % (self.scanner_name, self.settings['num_threads']))

		for _ in range(self.settings['num_threads']):
			thread = self.Executor(self)
			self.threads.append(thread)
			thread.start()

		try:
			self.wait_executor(self.threads, display_progress)
		except KeyboardInterrupt:
			print("\nTerminated by user")
			self.kill_threads()

		self.save_assessment()
		sys.exit(self._exitcode)

	def get_settings(self):
		"""Return the default settings; subclasses override scanner_exe etc."""
		return dict(
			request_types="xhr,link,redirect,form,json",
			num_threads=10,
			process_timeout=120,
			scanner_exe="",
		)

	def get_cmd(self, url, outfile):
		"""Return the scanner arguments for a request, or False to skip it.

		Executor.run() calls this with (request, tmp_dir). The base version
		adds no arguments.
		"""
		return []

	def scanner_executed(self, id_parent, out, err, out_file, cmd=None):
		"""Hook called after the external command for a request has finished.

		Executor.run() calls it as (request, out, err, tmp_dir, cmd), so the
		default must accept the ``cmd`` argv list; it is optional so that
		four-argument callers keep working. The base version does nothing.
		"""
		return

	def wait_executor(self, threads, display_progress):
		"""Block until every executor thread has finished.

		Threads are joined with a one-second timeout; when display_progress
		is true the progress bar is redrawn after each join.
		"""
		executor_done = False
		while not executor_done:
			executor_done = True
			for th in threads:
				if th.is_alive():
					executor_done = False
				th.join(1)

				if display_progress:
					with self._th_lock:
						scanned = self.performed_requests
						tot = self.tot_requests
					print_progressbar(tot, scanned, self.scan_start_time, "requests scanned")
		if display_progress:
			print("")

	def kill_threads(self):
		"""Set the exit flag on every live executor thread."""
		with self._th_lock:
			for th in self.threads:
				if th.is_alive():
					th.exit = True

	def exit(self, code):
		"""Stop the scan with exit code ``code``.

		While running, the code is recorded and the threads are asked to
		stop; otherwise sys.exit(code) is called directly.
		"""
		if self._running:
			with self._th_lock:
				self._exitcode = code
			self.kill_threads()
			print("kill thread")
			print("")
		else:
			sys.exit(code)

	def save_vulnerability(self, request, type, description):
		"""Store one vulnerability found on ``request`` for this assessment."""
		with self._th_lock_db:
			self.db.insert_vulnerability(self.id_assessment, request.db_id, type, description)

	def save_assessment(self):
		"""Record the current time as the end of this assessment."""
		with self._th_lock_db:
			self.db.save_assessment(self.id_assessment, int(time.time()))

	def is_request_duplicated(self, request):
		"""Return True if ``request`` shares its pattern with an earlier request."""
		return request.db_id in self._duplicated_requests

	class Executor(threading.Thread):
		"""Worker thread that runs the external scanner on pending requests."""

		def __init__(self, scanner):
			threading.Thread.__init__(self)
			self.scanner = scanner
			self.exit = False
			self.thread_uuid = uuid.uuid4()
			# private per-thread scratch directory, removed when run() ends
			self.tmp_dir = "%s%shtcap_tempdir-%s" % (tempfile.gettempdir(), os.sep, self.thread_uuid)
			os.makedirs(self.tmp_dir, 0o700)

		def inc_counter(self):
			"""Increment the scanner's performed_requests counter."""
			with self.scanner._th_lock:
				self.scanner.performed_requests += 1

		def run(self):
			"""Scan pending requests until none are left or exit is requested.

			The temporary directory is removed when the thread stops, even if
			building or running a command raises.
			"""
			try:
				while True:
					with self.scanner._th_lock:
						if self.exit or len(self.scanner.pending_requests) == 0:
							return
						req = self.scanner.pending_requests.pop()

					cmd_options = self.scanner.get_cmd(req, self.tmp_dir)
					# get_cmd returns False to skip a request
					if cmd_options == False:
						self.inc_counter()
						continue

					cmd = self.scanner.settings['scanner_exe'] + cmd_options

					exe = CommandExecutor(cmd, True)
					out, err = exe.execute(self.scanner.settings['process_timeout'])

					self.inc_counter()

					self.scanner.scanner_executed(req, out, err, self.tmp_dir, cmd)
			finally:
				# best-effort cleanup; must not mask an exception from a scan
				shutil.rmtree(self.tmp_dir, ignore_errors=True)
예제 #14
0
class BaseScanner:
    """Base class for scanners that drive an external command-line tool.

    Instantiating a subclass runs the whole scan: the requests to scan are
    loaded from the crawl database, requests sharing an already seen
    RequestPattern are recorded as duplicated and ``num_threads`` Executor
    threads are started. Each executor pops a request, runs
    ``settings['scanner_exe'] + get_cmd(request, tmp_dir)`` through
    CommandExecutor and hands the output to ``scanner_executed()``.

    Subclasses must provide ``init(argv)``. The default get_settings(),
    get_cmd() and scanner_executed() are placeholders (empty scanner_exe,
    no options, no-op hook).
    """

    def __init__(self, db_file, num_threads, request_types, process_timeout,
                 scanner_exe, display_progress, scanner_argv):
        """Run the scan to completion and exit with the scanner's exit code.

        Truthy arguments override the defaults returned by get_settings().
        On ^C the threads are asked to stop via kill_threads(). The
        assessment is saved and sys.exit() is called, so this constructor
        never returns normally.

        :param db_file: path of the crawl database
        :param num_threads: number of executor threads (falsy = default)
        :param request_types: request types to scan (falsy = default)
        :param process_timeout: external process timeout (falsy = default)
        :param scanner_exe: scanner command line (falsy = default)
        :param display_progress: show a progress bar while waiting
        :param scanner_argv: extra arguments passed to self.init()
        """
        self.scan_start_time = int(time.time())
        self.threads = []
        # guards pending_requests, performed_requests and the thread flags
        self._th_lock = threading.Lock()
        # serializes every access to self.db
        self._th_lock_db = threading.Lock()
        self.performed_requests = 0
        self._urlpatterns = []
        self._exitcode = 0
        self.scanner_name = self.__class__.__name__.lower()
        self._running = False
        self.settings = self.get_settings()

        # override default settings with the explicitly given values
        if num_threads: self.settings['num_threads'] = num_threads
        if request_types: self.settings['request_types'] = request_types
        if process_timeout: self.settings['process_timeout'] = process_timeout
        if scanner_exe: self.settings['scanner_exe'] = scanner_exe
        # scanner_exe is a command line string; split it into an argv list
        self.settings['scanner_exe'] = self.settings['scanner_exe'].split(" ")

        self.db = Database(db_file)
        self.id_assessment = self.db.create_assessment(self.scanner_name,
                                                       int(time.time()))
        self.pending_requests = self.db.get_requests(
            self.settings['request_types'])
        self.tot_requests = len(self.pending_requests)
        self._duplicated_requests = []

        # the first request with a given pattern is kept, later ones are
        # recorded as duplicated
        urlpatterns = []
        for req in self.pending_requests:
            patt = RequestPattern(req).pattern
            if patt in urlpatterns:
                self._duplicated_requests.append(req.db_id)
            else:
                urlpatterns.append(patt)

        self.init(scanner_argv if scanner_argv else [])

        self._running = True
        print("Scanner %s started with %d threads" % (
            self.scanner_name, self.settings['num_threads']))

        for _ in range(self.settings['num_threads']):
            thread = self.Executor(self)
            self.threads.append(thread)
            thread.start()

        try:
            self.wait_executor(self.threads, display_progress)
        except KeyboardInterrupt:
            print("\nTerminated by user")
            self.kill_threads()

        self.save_assessment()
        sys.exit(self._exitcode)

    def get_settings(self):
        """Return the default settings; subclasses override scanner_exe etc."""
        return dict(request_types="xhr,link,redirect,form,json",
                    num_threads=10,
                    process_timeout=120,
                    scanner_exe="")

    def get_cmd(self, url, outfile):
        """Return the scanner arguments for a request, or False to skip it.

        Executor.run() calls this with (request, tmp_dir). The base version
        adds no arguments.
        """
        return []

    def scanner_executed(self, id_parent, out, err, out_file, cmd=None):
        """Hook called after the external command for a request has finished.

        Executor.run() calls it as (request, out, err, tmp_dir, cmd), so the
        default must accept the ``cmd`` argv list; it is optional so that
        four-argument callers keep working. The base version does nothing.
        """
        return

    def wait_executor(self, threads, display_progress):
        """Block until every executor thread has finished.

        Threads are joined with a one-second timeout; when display_progress
        is true the progress bar is redrawn after each join.
        """
        executor_done = False
        while not executor_done:
            executor_done = True
            for th in threads:
                if th.is_alive():
                    executor_done = False
                th.join(1)

                if display_progress:
                    with self._th_lock:
                        scanned = self.performed_requests
                        tot = self.tot_requests
                    print_progressbar(tot, scanned, self.scan_start_time,
                                      "requests scanned")
        if display_progress:
            print("")

    def kill_threads(self):
        """Set the exit flag on every live executor thread."""
        with self._th_lock:
            for th in self.threads:
                if th.is_alive():
                    th.exit = True

    def exit(self, code):
        """Stop the scan with exit code ``code``.

        While running, the code is recorded and the threads are asked to
        stop; otherwise sys.exit(code) is called directly.
        """
        if self._running:
            with self._th_lock:
                self._exitcode = code
            self.kill_threads()
            print("kill thread")
            print("")
        else:
            sys.exit(code)

    def save_vulnerability(self, request, type, description):
        """Store one vulnerability found on ``request`` for this assessment."""
        with self._th_lock_db:
            self.db.insert_vulnerability(self.id_assessment, request.db_id,
                                         type, description)

    def save_assessment(self):
        """Record the current time as the end of this assessment."""
        with self._th_lock_db:
            self.db.save_assessment(self.id_assessment, int(time.time()))

    def is_request_duplicated(self, request):
        """Return True if ``request`` shares its pattern with an earlier request."""
        return request.db_id in self._duplicated_requests

    class Executor(threading.Thread):
        """Worker thread that runs the external scanner on pending requests."""

        def __init__(self, scanner):
            threading.Thread.__init__(self)
            self.scanner = scanner
            self.exit = False
            self.thread_uuid = uuid.uuid4()
            # private per-thread scratch directory, removed when run() ends
            self.tmp_dir = "%s%shtcap_tempdir-%s" % (tempfile.gettempdir(),
                                                     os.sep, self.thread_uuid)
            os.makedirs(self.tmp_dir, 0o700)

        def inc_counter(self):
            """Increment the scanner's performed_requests counter."""
            with self.scanner._th_lock:
                self.scanner.performed_requests += 1

        def run(self):
            """Scan pending requests until none are left or exit is requested.

            The temporary directory is removed when the thread stops, even if
            building or running a command raises.
            """
            try:
                while True:
                    with self.scanner._th_lock:
                        if self.exit or len(self.scanner.pending_requests) == 0:
                            return
                        req = self.scanner.pending_requests.pop()

                    cmd_options = self.scanner.get_cmd(req, self.tmp_dir)
                    # get_cmd returns False to skip a request
                    if cmd_options == False:
                        self.inc_counter()
                        continue

                    cmd = self.scanner.settings['scanner_exe'] + cmd_options

                    exe = CommandExecutor(cmd, True)
                    out, err = exe.execute(
                        self.scanner.settings['process_timeout'])

                    self.inc_counter()

                    self.scanner.scanner_executed(req, out, err, self.tmp_dir,
                                                  cmd)
            finally:
                # best-effort cleanup; must not mask an exception from a scan
                shutil.rmtree(self.tmp_dir, ignore_errors=True)
예제 #15
0
    def test___str__(self):
        """str() of a Database must be its name."""
        self.assertEqual('my_db', str(Database('my_db')))
예제 #16
0
    def test_constructor(self):
        """A new Database stores its name and starts without a connection."""
        db = Database('my_db')

        self.assertEqual(db.dbname, 'my_db')
        # identity check against None, with a clearer failure message
        self.assertIsNone(db.conn)