Пример #1
0
class Module(ToolTemplate):
    """
    This module runs DNSRecon on a domain or set of domains. This will extract found DNS entries.
    It can also run over IP ranges, looking for additional domains in the PTR records.

    DNSRecon can be installed from https://github.com/darkoperator/dnsrecon
    """

    name = "DNSRecon"
    binary_name = "dnsrecon"

    # For every DNS record type that gets imported: the key of the dnsrecon
    # JSON record that holds the hostname to store.
    _HOSTNAME_KEYS = {
        "A": "name",
        "PTR": "name",
        "MX": "exchange",
        "SRV": "target",
        "NS": "target",
        "SOA": "mname",
    }
    # Record types whose "address" value is attached to the hostname as an IP.
    _ADDRESS_TYPES = ("A", "PTR")

    def __init__(self, db):
        """Keep the database handle and build the repositories used by this module."""
        self.db = db
        self.BaseDomain = BaseDomainRepository(db, self.name)
        self.Domain = DomainRepository(db, self.name)
        self.ScopeCIDR = ScopeCIDRRepository(db, self.name)
        self.IPAddress = IPRepository(db, self.name)

    def set_options(self):
        """Register the DNSRecon command line options on top of the template's options."""
        super(Module, self).set_options()
        self.options.add_argument("-d",
                                  "--domain",
                                  help="Target domain for dnsRecon")
        self.options.add_argument("-f",
                                  "--file",
                                  help="Import domains from file")
        self.options.add_argument(
            "-i",
            "--import_database",
            help="Import domains from database",
            action="store_true",
        )
        self.options.add_argument("-r",
                                  "--range",
                                  help="Range to scan for PTR records")
        self.options.add_argument(
            "-R",
            "--import_range",
            help="Import CIDRs from in-scope ranges in database",
            action="store_true",
        )
        self.options.add_argument("--rescan",
                                  help="Rescan domains already scanned",
                                  action="store_true")
        # self.options.add_argument('--import_output_xml', help="Import XML file")
        # self.options.add_argument('--import_output_json', help="Import json file")

    def get_targets(self, args):
        """
        Build the list of scan targets from the first matching option.

        Options are checked in this order: --domain, --file,
        --import_database, --range, --import_range. Domains given on the
        command line or in a file are registered as passively scoped base
        domains. The output directory (below the project base path) is
        created if missing and remembered in ``self.path``.

        Returns a list of ``{"target": ..., "output": ...}`` dicts, where
        ``output`` is the JSON file dnsrecon should write for that target.
        """
        targets = []
        if args.domain:
            created, domain = self.BaseDomain.find_or_create(
                domain=args.domain, passive_scope=True)
            targets.append(domain.domain)

        elif args.file:
            # Close the file deterministically; strip() also tolerates
            # Windows line endings and stray whitespace around names.
            with open(args.file) as handle:
                names = [line.strip() for line in handle]
            for d in names:
                if d:
                    created, domain = self.BaseDomain.find_or_create(
                        domain=d, passive_scope=True)
                    targets.append(domain.domain)

        elif args.import_database:
            if args.rescan:
                domains = self.BaseDomain.all(scope_type="passive")
            else:
                domains = self.BaseDomain.all(scope_type="passive",
                                              tool=self.name)
            for domain in domains:
                targets.append(domain.domain)

        elif args.range:
            targets.append(args.range)

        elif args.import_range:
            if args.rescan:
                cidrs = self.ScopeCIDR.all()
            else:
                cidrs = self.ScopeCIDR.all(tool=self.name)

            for cidr in cidrs:
                targets.append(cidr.cidr)

        # A leading "/" would make os.path.join discard the project base
        # path, so it is dropped to keep the output inside the project.
        relative_path = args.output_path
        if relative_path.startswith("/"):
            relative_path = relative_path[1:]
        self.path = os.path.join(self.base_config["PROJECT"]["base_path"],
                                 relative_path)

        if not os.path.exists(self.path):
            os.makedirs(self.path)

        return [{
            "target": t,
            # "/" from CIDR notation is not valid inside a file name.
            "output": os.path.join(self.path, t.replace("/", "_") + ".json"),
        } for t in targets]

    def build_cmd(self, args):
        """
        Return the command template for one target.

        Domain style runs use ``-d``; everything else is treated as a range
        and uses ``-s -r``. ``{target}`` and ``{output}`` are left as
        placeholders in the returned string.
        """
        command = self.binary

        if args.domain or args.file or args.import_database:
            command += " -d {target} -j {output} "
        else:
            command += " -s -r {target} -j {output} "

        if args.tool_args:
            command += args.tool_args

        return command

    def process_output(self, cmds):
        """
        Import the JSON written by dnsrecon for every entry of ``cmds``.

        Each entry needs the ``target`` and ``output`` keys produced by
        ``get_targets``. Targets whose output is missing, unparsable or
        empty are reported and skipped; the rest are imported as domains
        (with their IP address for A and PTR records).
        """
        for c in cmds:
            target = c["target"]
            output_path = c["output"]

            try:
                with open(output_path) as handle:
                    res = json.load(handle)
            except (IOError, ValueError):
                display_error("DnsRecon failed for {}".format(target))
                # Without this the loop would go on with an undefined (or
                # the previous target's) result.
                continue

            if not res:
                display_error("DnsRecon failed for {}".format(target))
                continue

            # The first record carries the command line dnsrecon was run
            # with; only domain scans store the raw result on the domain.
            if " -d " in res[0].get("arguments", ""):
                created, dbrec = self.Domain.find_or_create(domain=target)
                dbrec.dns = res
                dbrec.save()

            for record in res:
                record_type = record.get("type")
                hostname_key = self._HOSTNAME_KEYS.get(record_type)
                if hostname_key is None:
                    continue

                hostname = record.get(hostname_key)
                if not hostname:
                    # Incomplete record; nothing to store.
                    continue

                domain = hostname.lower().replace("www.", "")
                if not domain:
                    continue

                created, domain_obj = self.Domain.find_or_create(
                    domain=domain)

                ip = None
                if record_type in self._ADDRESS_TYPES:
                    ip = record.get("address")
                if ip:
                    created, ip_obj = self.IPAddress.find_or_create(
                        ip_address=ip)
                    domain_obj.ip_addresses.append(ip_obj)
                    domain_obj.save()

        self.Domain.commit()
Пример #2
0
class Module(ModuleTemplate):
    """
    The Shodan module will either iterate through Shodan search results from net:<cidr>
    for all scoped CIDRs, or a custom search query. The resulting IPs and ports will be
    added to the database, along with a dictionary object of the API results.

    """

    name = "ShodanImport"

    # Placeholders: host, API key.
    _API_HOST_URL = "https://api.shodan.io/shodan/host/{}?key={}"
    # Placeholders: API key, query, page number.
    _API_SEARCH_URL = (
        "https://api.shodan.io/shodan/host/search?key={}&query={}&page={}")

    def __init__(self, db):
        """Keep the database handle and build the repositories used by this module."""
        self.db = db
        self.Port = PortRepository(db, self.name)
        self.IPAddress = IPRepository(db, self.name)
        self.ScopeCidr = ScopeCIDRRepository(db, self.name)

    def set_options(self):
        """Register the Shodan command line options on top of the template's options."""
        super(Module, self).set_options()

        self.options.add_argument("-k",
                                  "--api_key",
                                  help="API Key for accessing Shodan")
        self.options.add_argument(
            "-s", "--search", help="Custom search string (will use credits)")
        self.options.add_argument(
            "-i",
            "--import_db",
            help="Import scoped IPs from the database",
            action="store_true",
        )
        self.options.add_argument("--rescan",
                                  help="Rescan CIDRs already processed",
                                  action="store_true")
        self.options.add_argument("--fast",
                                  help="Use 'net' filter. (May use credits)",
                                  action="store_true")
        self.options.add_argument(
            "--cidr_only",
            help="Import only CIDRs from database (not individual IPs)",
            action="store_true",
        )

    def run(self, args):
        """
        Query Shodan for the requested targets and import the results.

        Targets are the custom ``--search`` string and/or, with
        ``--import_db``, the scoped CIDRs (as ``net:`` searches with
        ``--fast``, otherwise expanded to single hosts) plus the actively
        scoped IPs unless ``--cidr_only`` is set. Without ``--rescan`` only
        entries not yet handled by this tool are requested from the
        repositories. Nothing is done without an API key.
        """
        if not args.api_key:
            display_error("You must supply an API key to use shodan!")
            return

        # Always start from a list so running without --search/--import_db
        # is a no-op instead of an UnboundLocalError, and so --import_db
        # extends a custom search instead of silently discarding it.
        ranges = []
        if args.search:
            ranges.append(args.search)

        if args.import_db:
            # With --rescan everything is fetched; otherwise the
            # repositories are filtered by this tool's name.
            tool_filter = {} if args.rescan else {"tool": self.name}

            scoped = self.ScopeCidr.all(**tool_filter)
            if args.fast:
                ranges += ["net:{}".format(c.cidr) for c in scoped]
            else:
                for cidr in [c.cidr for c in scoped]:
                    ranges += [str(i) for i in IPNetwork(cidr)]

            if not args.cidr_only:
                ranges += [
                    "{}".format(i.ip_address)
                    for i in self.IPAddress.all(scope_type="active",
                                                **tool_filter)
                ]

        for r in ranges:
            time.sleep(1)
            # Anything containing ":" (e.g. "net:1.2.3.0/24") goes to the
            # search endpoint, everything else to the host endpoint.
            if ":" in r:
                self._import_search(r, args.api_key)
            else:
                self._import_host(r, args.api_key)

        self.IPAddress.commit()

    def _fetch_json(self, url):
        """Return the decoded JSON answer for ``url`` or None if the request failed.

        A timeout reported by Shodan is retried once after five seconds.
        """
        try:
            results = json.loads(requests.get(url).text)
            if results.get("error") and "request timed out" in results["error"]:
                display_warning(
                    "Timeout occurred on Shodan's side.. trying again in 5 seconds."
                )
                time.sleep(5)
                results = json.loads(requests.get(url).text)
        except Exception as e:
            display_error("Something went wrong: {}".format(e))
            return None
        return results

    def _import_search(self, query, api_key):
        """Page through the search results for ``query`` and import every match."""
        display("Doing Shodan search: {}".format(query))

        matches = []
        page = 1
        results = self._fetch_json(
            self._API_SEARCH_URL.format(api_key, query, page))
        while results is not None:
            if results.get("error"):
                display_error("Something went wrong: {}".format(
                    results["error"]))
                break
            page_matches = results.get("matches") or []
            if not page_matches:
                # An empty page marks the end of the result set.
                break
            display("Adding {} results from page {}".format(
                len(page_matches), page))
            matches += page_matches
            page += 1
            time.sleep(1)
            results = self._fetch_json(
                self._API_SEARCH_URL.format(api_key, query, page))

        for res in matches:
            # The match itself is stored on the IP: the last fetched page is
            # normally empty and carries nothing about this address.
            self._import_service(res, res)

    def _import_host(self, host, api_key):
        """Look up a single host and import every service Shodan lists for it."""
        results = self._fetch_json(self._API_HOST_URL.format(host, api_key))
        if not results or not results.get("data"):
            return

        display("{} results found for: {}".format(len(results["data"]), host))
        for res in results["data"]:
            self._import_service(res, results)

    def _import_service(self, res, ip_data):
        """Store one Shodan service record ``res``; ``ip_data`` goes into the IP's meta."""
        ip_str = res["ip_str"]
        port_str = res["port"]
        transport = res["transport"]

        display("Processing IP: {} Port: {}/{}".format(
            ip_str, port_str, transport))

        created, IP = self.IPAddress.find_or_create(ip_address=ip_str)
        IP.meta["shodan_data"] = ip_data

        created, port = self.Port.find_or_create(ip_address=IP,
                                                 port_number=port_str,
                                                 proto=transport)
        if created:
            # Only name the service on new ports so an existing value is
            # not overwritten.
            if res.get("ssl", False):
                svc = "https"
            elif res.get("http", False):
                svc = "http"
            else:
                svc = ""
            port.service_name = svc

        port.status = "open"
        port.meta["shodan_data"] = res
        port.save()
Пример #3
0
 def __init__(self, db):
     """Keep the database handle and build the module's repositories.

     Every repository is constructed with ``db`` and ``self.name``.
     """
     self.db = db
     self.Port = PortRepository(db, self.name)
     self.IPAddress = IPRepository(db, self.name)
     self.ScopeCidr = ScopeCIDRRepository(db, self.name)
Пример #4
0
 def __init__(self, db):
     """Keep the database handle and build the module's repositories.

     Every repository is constructed with ``db`` and ``self.name``.
     """
     self.db = db
     self.BaseDomain = BaseDomainRepository(db, self.name)
     self.Domain = DomainRepository(db, self.name)
     self.ScopeCIDR = ScopeCIDRRepository(db, self.name)
     self.IPAddress = IPRepository(db, self.name)
Пример #5
0
class Module(ModuleTemplate):
    """
    The Shodan module will either iterate through Shodan search results from net:<cidr>
    for all scoped CIDRs, or a custom search query. The resulting IPs and ports will be
    added to the database, along with a dictionary object of the API results.

    """

    name = "ShodanImport"

    # Placeholders: host, API key.
    _API_HOST_URL = "https://api.shodan.io/shodan/host/{}?key={}"
    # Placeholders: API key, query, page number.
    _API_SEARCH_URL = (
        "https://api.shodan.io/shodan/host/search?key={}&query={}&page={}")

    _TOTAL_MESSAGE = (
        "Doing a total of {} queries. Estimated time: {} days, {} hours, {} minutes and {} seconds."
    )
    _BATCH_MESSAGE = (
        "Processing {} IPs. Estimated time: {} days, {} hours, {} minutes and {} seconds."
    )

    def __init__(self, db):
        """Keep the database handle and build the repositories used by this module."""
        self.db = db
        self.Port = PortRepository(db, self.name)
        self.IPAddress = IPRepository(db, self.name)
        self.ScopeCidr = ScopeCIDRRepository(db, self.name)
        self.Domain = DomainRepository(db, self.name)

    def set_options(self):
        """Register the Shodan command line options on top of the template's options."""
        super(Module, self).set_options()

        self.options.add_argument("-k",
                                  "--api_key",
                                  help="API Key for accessing Shodan")
        self.options.add_argument(
            "-s", "--search", help="Custom search string (will use credits)")
        self.options.add_argument(
            "-i",
            "--import_db",
            help="Import scoped IPs from the database",
            action="store_true",
        )
        self.options.add_argument("--rescan",
                                  help="Rescan CIDRs already processed",
                                  action="store_true")
        self.options.add_argument("--fast",
                                  help="Use 'net' filter. (May use credits)",
                                  action="store_true")
        self.options.add_argument(
            "--cidr_only",
            help="Import only CIDRs from database (not individual IPs)",
            action="store_true",
        )

        self.options.add_argument("--target",
                                  "-t",
                                  help="Scan a specific CIDR/IP")

    def run(self, args):
        """
        Collect the targets selected by ``args`` and query Shodan for each.

        CIDRs are expanded to single host lookups unless ``--fast`` is set,
        in which case they become ``net:`` searches. After a CIDR, IP or
        ``net:`` search has been processed, its database entry is marked
        with this tool's name. Nothing is done without an API key.
        """
        if not args.api_key:
            display_error("You must supply an API key to use shodan!")
            return

        cidrs = []
        ips = []
        search = []

        if args.search:
            search.append(args.search)

        if args.import_db:
            # With --rescan everything is fetched; otherwise the
            # repositories are filtered by this tool's name.
            tool_filter = {} if args.rescan else {"tool": self.name}

            scoped = self.ScopeCidr.all(**tool_filter)
            if args.fast:
                search += ["net:{}".format(c.cidr) for c in scoped]
            else:
                cidrs += [c.cidr for c in scoped]

            if not args.cidr_only:
                ips += [
                    "{}".format(i.ip_address)
                    for i in self.IPAddress.all(scope_type="active",
                                                **tool_filter)
                ]

        if args.target:
            if '/' not in args.target:
                ips.append(args.target)
            elif args.fast:
                # A "net:" filter is a search query; it must not end up in
                # the list that is expanded with IPNetwork.
                search.append("net:{}".format(args.target))
            else:
                cidrs.append(args.target)

        total = len(ips) + len(search)
        for c in cidrs:
            total += sum(1 for _ in IPNetwork(c))
        display(self._TOTAL_MESSAGE.format(total, *self._estimate(total)))

        for c in cidrs:
            hosts = [str(i) for i in IPNetwork(c)]
            display(
                self._BATCH_MESSAGE.format(c, *self._estimate(len(hosts))))
            for host in hosts:
                self.get_shodan(host, args)
            self._mark_cidr_done(c)

        display(self._BATCH_MESSAGE.format(len(ips),
                                           *self._estimate(len(ips))))
        for i in ips:
            self.get_shodan(i, args)

            created, ip = self.IPAddress.find_or_create(ip_address=i)
            if created:
                # Presumably: the address was not in the database before,
                # so the lookup must not leave a new entry behind.
                ip.delete()
            else:
                ip.set_tool(self.name)
            self.IPAddress.commit()

        for s in search:
            self.get_shodan(s, args)

            if s[:4] == "net:":
                self._mark_cidr_done(s[4:])

    def get_shodan(self, r, args):
        """
        Query Shodan for ``r`` and import the answer into the database.

        ``r`` is treated as a search query when it contains ":" (for
        example ``net:1.2.3.0/24``) and as a single host otherwise. The API
        key is taken from ``args.api_key``. Failed requests are reported
        and skipped.
        """
        time.sleep(1)
        if ":" in r:
            self._import_search(r, args.api_key)
        else:
            self._import_host(r, args.api_key)

    @staticmethod
    def _estimate(queries):
        """Return (days, hours, minutes, seconds) for ``queries`` at one query per second."""
        minutes, seconds = divmod(queries, 60)
        hours, minutes = divmod(minutes, 60)
        days, hours = divmod(hours, 24)
        return days, hours, minutes, seconds

    def _mark_cidr_done(self, cidr):
        """Mark the scope entry for ``cidr`` as handled by this tool."""
        created, cd = self.ScopeCidr.find_or_create(cidr=cidr)
        if created:
            # Presumably: the CIDR was not in the database before, so the
            # lookup must not leave a new scope entry behind.
            cd.delete()
        else:
            cd.set_tool(self.name)
        self.ScopeCidr.commit()

    def _fetch_json(self, url):
        """Return the decoded JSON answer for ``url`` or None if the request failed.

        A timeout reported by Shodan is retried once after five seconds.
        """
        try:
            results = json.loads(requests.get(url).text)
            if results.get("error") and "request timed out" in results["error"]:
                display_warning(
                    "Timeout occurred on Shodan's side.. trying again in 5 seconds."
                )
                time.sleep(5)
                results = json.loads(requests.get(url).text)
        except Exception as e:
            display_error("Something went wrong: {}".format(e))
            return None
        return results

    def _import_search(self, query, api_key):
        """Page through the search results for ``query`` and import every match."""
        display("Doing Shodan search: {}".format(query))

        matches = []
        page = 1
        results = self._fetch_json(
            self._API_SEARCH_URL.format(api_key, query, page))
        while results is not None:
            if results.get("error"):
                display_error("Something went wrong: {}".format(
                    results["error"]))
                break
            page_matches = results.get("matches") or []
            if not page_matches:
                # An empty page marks the end of the result set.
                break
            display("Adding {} results from page {}".format(
                len(page_matches), page))
            matches += page_matches
            page += 1
            time.sleep(1)
            results = self._fetch_json(
                self._API_SEARCH_URL.format(api_key, query, page))

        domains = []
        for res in matches:
            # The match itself is stored on the IP: the last fetched page is
            # normally empty and carries nothing about this address.
            self._import_service(res, res)
            domains += self._domains_from_service(res)

        # NOTE(review): only the search path passes names through
        # only_valid(); the host path stores them unchanged.
        self._store_domains([only_valid(d) for d in domains])

    def _import_host(self, host, api_key):
        """Look up a single host and import every service Shodan lists for it."""
        display("Searching for {}".format(host))
        results = self._fetch_json(self._API_HOST_URL.format(host, api_key))
        if not results or not results.get("data"):
            return

        display("{} results found for: {}".format(len(results["data"]), host))
        domains = []
        for res in results["data"]:
            self._import_service(res, results)
            domains += self._domains_from_service(res)

        self._store_domains(domains)

    def _import_service(self, res, ip_data):
        """Store one Shodan service record ``res``; ``ip_data`` goes into the IP's meta."""
        ip_str = res["ip_str"]
        port_str = res["port"]
        transport = res["transport"]

        display("Processing IP: {} Port: {}/{}".format(
            ip_str, port_str, transport))

        created, IP = self.IPAddress.find_or_create(ip_address=ip_str)
        IP.meta["shodan_data"] = ip_data

        created, port = self.Port.find_or_create(ip_address=IP,
                                                 port_number=port_str,
                                                 proto=transport)
        if created:
            # Only name the service on new ports so an existing value is
            # not overwritten.
            if res.get("ssl", False):
                svc = "https"
            elif res.get("http", False):
                svc = "http"
            else:
                svc = ""
            port.service_name = svc

        port.status = "open"
        port.meta["shodan_data"] = res
        port.save()

    def _domains_from_service(self, res):
        """Return the host names found in one service record.

        Sources are the certificate's subjectAltName extension, its
        subject CN (unless it is a wildcard) and Shodan's ``hostnames``.
        """
        found = []
        cert = res.get("ssl", {}).get('cert', {})

        for extension in cert.get('extensions') or []:
            if extension['name'] == 'subjectAltName':
                # The names are in the extension's data, not in its name.
                alt_names = get_domains_from_data(extension['data'])
                found += alt_names
                display("Domains discovered in subjectAltName: {}".format(
                    ", ".join(alt_names)))

        common_name = cert.get('subject', {}).get('CN')
        if common_name and '*' not in common_name:
            found.append(common_name)

        if res.get('hostnames'):
            found += res['hostnames']

        return found

    def _store_domains(self, domains):
        """Add every distinct, non-empty name in ``domains`` to the database."""
        for d in sorted(set(domains)):
            if not d:
                continue
            display("Adding discovered domain {}".format(d))
            created, domain = self.Domain.find_or_create(domain=d)
Пример #6
0
class Module(ToolTemplate):
    """
    Runs the ``whois`` binary against a domain, a CIDR's network address, or
    the domains and CIDRs held in the database, and stores the raw output in
    the ``whois`` meta entry of the matching database record.
    """

    name = "Whois"
    binary_name = "whois"

    def __init__(self, db):
        """Keep the database handle and build the two repositories this module needs."""
        self.db = db
        self.BaseDomain = BaseDomainRepository(db, self.name)
        self.ScopeCidr = ScopeCIDRRepository(db, self.name)

    def set_options(self):
        """Register the Whois command line options on top of the template's options."""
        super(Module, self).set_options()

        add_option = self.options.add_argument
        add_option("-d", "--domain", help="Domain to query")
        add_option("-c", "--cidr", help="CIDR to query")
        add_option(
            "-s",
            "--rescan",
            action="store_true",
            help="Rescan domains that have already been scanned",
        )
        add_option(
            "-a",
            "--all_data",
            action="store_true",
            help="Scan all data in database, regardless of scope",
        )
        add_option(
            "-i",
            "--import_database",
            action="store_true",
            help="Run WHOIS on all domains and CIDRs in database",
        )

    def get_targets(self, args):
        """
        Build one target dict per whois lookup.

        Exactly one of ``domain`` / ``cidr`` is non-empty in every dict; for
        CIDRs only the part before the "/" is kept. ``output`` is the file
        the lookup writes to, inside the output directory below the project
        base path (created if missing).
        """
        if args.domain:
            targets = [{"domain": args.domain, "cidr": ""}]
        elif args.cidr:
            targets = [{"domain": "", "cidr": args.cidr.split("/")[0]}]
        elif args.import_database:
            scope_type = "" if args.all_data else "passive"
            # Without --rescan both repositories are filtered by this
            # tool's name.
            tool_filter = {} if args.rescan else {"tool": self.name}
            domain_records = self.BaseDomain.all(scope_type=scope_type,
                                                 **tool_filter)
            cidr_records = self.ScopeCidr.all(**tool_filter)

            targets = [{"domain": rec.domain, "cidr": ""}
                       for rec in domain_records]
            targets += [{"domain": "", "cidr": rec.cidr.split("/")[0]}
                        for rec in cidr_records]
        else:
            targets = []

        # A leading "/" would make os.path.join discard the base path.
        relative = (args.output_path[1:]
                    if args.output_path[0] == "/" else args.output_path)
        output_path = os.path.join(self.base_config["PROJECT"]["base_path"],
                                   relative)

        if not os.path.exists(output_path):
            os.makedirs(output_path)

        for entry in targets:
            entry["output"] = os.path.join(output_path,
                                           entry["domain"] + entry["cidr"])

        return targets

    def build_cmd(self, args):
        """
        Return the shell command template for one lookup.

        The binary runs inside ``bash -c`` so its output can be redirected
        to ``{output}``. A missing ``args.tool_args`` is replaced by an
        empty string on ``args`` itself.
        """
        if not args.tool_args:
            args.tool_args = ""

        pieces = [
            'bash -c "',
            self.binary,
            " {domain}{cidr} ",
            args.tool_args,
            '> {output}" ',
        ]
        return "".join(pieces)

    def process_output(self, cmds):
        """Read each lookup's output file and store it on the matching CIDR or domain record."""
        display("Importing data to database")

        for cmd in cmds:
            # CIDR lookups take precedence, then domain lookups; entries
            # with neither are ignored.
            if cmd["cidr"]:
                repository, key = self.ScopeCidr, "cidr"
            elif cmd["domain"]:
                repository, key = self.BaseDomain, "domain"
            else:
                continue

            _, record = repository.find_or_create(**{key: cmd[key]})
            record.meta["whois"] = read_file(cmd["output"])
            display(record.meta["whois"])
            record.update()

        self.BaseDomain.commit()
Пример #7
0
 def __init__(self, db):
     """Keep the database handle and build the module's repositories.

     Every repository is constructed with ``db`` and ``self.name``.
     """
     self.db = db
     self.BaseDomain = BaseDomainRepository(db, self.name)
     self.ScopeCidr = ScopeCIDRRepository(db, self.name)
Пример #8
0
class Module(ModuleTemplate):

    name = "Nessus"

    def __init__(self, db):
        """Keep the database handle and build the repositories this module uses."""
        self.db = db

        # (attribute name, repository class) pairs; every repository is
        # constructed with the database handle and this module's name.
        repositories = (
            ("BaseDomain", BaseDomainRepository),
            ("Domain", DomainRepository),
            ("IPAddress", IPRepository),
            ("Port", PortRepository),
            ("Vulnerability", VulnRepository),
            ("CVE", CVERepository),
            ("ScopeCIDR", ScopeCIDRRepository),
        )
        for attribute, repository_cls in repositories:
            setattr(self, attribute, repository_cls(db, self.name))

    def set_options(self):
        """Register the Nessus-specific command line options.

        The parent's options are registered first; the options below are then
        added to ``self.options`` in the listed order.
        """
        super(Module, self).set_options()

        # One (flag, add_argument keyword arguments) entry per option.
        option_specs = [
            (
                "--import_file",
                {
                    "help": "Import separated Nessus files separated by a space. DO NOT USE QUOTES OR COMMAS",
                    "nargs": "+",
                },
            ),
            (
                "--launch",
                {
                    "help": "Launch Nessus scan using Actively scoped IPs and domains in the database",
                    "action": "store_true",
                },
            ),
            (
                "--job_name",
                {"help": "Job name inside Nessus", "default": "Armory Job"},
            ),
            ("--username", {"help": "Nessus Username"}),
            ("--password", {"help": "Nessus Password"}),
            (
                "--host",
                {
                    "help": "Hostname:Port of Nessus web interface (ie localhost:8835"
                },
            ),
            ("--uuid", {"help": "UUID of Nessus Policy to run"}),
            ("--policy_id", {"help": "Policy ID to use"}),
            ("--folder_id", {"help": "ID for folder to store job in"}),
            (
                "--download",
                {
                    "help": "Download Nessus job from server and import",
                    "action": "store_true",
                },
            ),
            ("--job_id", {"help": "Job ID to download and import"}),
            (
                "--output_path",
                {
                    "help": "Path to store downloaded file (Default: Nessus)",
                    "default": self.name,
                },
            ),
        ]
        for flag, settings in option_specs:
            self.options.add_argument(flag, **settings)

    def run(self, args):
        """Import, launch, or download-and-import a Nessus scan.

        The first matching action wins, checked in this order:

        * ``--import_file``: parse each given file with ``process_data``.
        * ``--launch``: start a scan against the stored active IPs, the scope
          CIDRs and the active domains, then print the new job ID.
        * ``--download``: export job ``--job_id`` into the project output
          directory and import the resulting file.

        Missing connection details are reported through ``display_error`` and
        no request is attempted.
        """
        if args.import_file:
            for nFile in args.import_file:
                self.process_data(nFile, args)
            return

        if args.launch:
            # A single missing credential is enough to refuse; the previous
            # check only fired when *every* option was missing.
            if not (args.username and args.password and args.host):
                display_error(
                    "You must supply a username, password, and host to launch a Nessus job"
                )
                return

            n = NessusRequest(
                args.username,
                args.password,
                args.host,
                uuid=args.uuid,
                policy_id=args.policy_id,
                folder_id=args.folder_id,
            )

            ips = [
                ip.ip_address
                for ip in self.IPAddress.all(scope_type="active", tool=self.name)
            ]
            cidrs = [cidr.cidr for cidr in self.ScopeCIDR.all(tool=self.name)]
            domains = [
                domain.domain
                for domain in self.Domain.all(scope_type="active", tool=self.name)
            ]
            targets = ", ".join(merge_ranges(ips + cidrs) + domains)

            res = n.launch_job(targets, args.job_name)
            display("New Nessus job launched with ID {}".format(res))
            display(
                "Remember this number! You'll need it to download the job once it is done."
            )
            return

        if args.download:
            if not (
                args.username and args.password and args.host and args.job_id
            ):
                display_error(
                    "You must supply host, username, password and job_id to download a report to import"
                )
                return

            # No proxy is forced here: the previous hard-coded
            # 127.0.0.1:8080 proxy made the download depend on a local
            # intercepting proxy being up, unlike the launch branch.
            n = NessusRequest(args.username, args.password, args.host)

            # Leading slashes are dropped so the path always lands below the
            # project base path (this also tolerates an empty value).
            output_dir = os.path.join(
                self.base_config["PROJECT"]["base_path"],
                args.output_path.lstrip("/"),
            )
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)

            # Timestamped name so repeated downloads do not overwrite.
            output_file = os.path.join(
                output_dir, "Nessus-export-{}.nessus".format(int(time.time()))
            )
            n.export_file(args.job_id, output_file)

            self.process_data(output_file, args)

    def nessCheckPlugin(self, tag):
        """Pull a short detail string out of the output of selected plugins.

        Args:
            tag: A ``ReportItem`` element; its ``pluginID`` attribute and its
                ``plugin_output`` child are read.

        Returns:
            ``False`` when the plugin ID is not one of the handled ones.
            Otherwise a string: the extracted detail, or ``""`` when the item
            has no plugin output or nothing relevant was found.

        Raises:
            IndexError: when the output of a handled plugin does not have the
                layout the parsing below expects.
        """
        handled_plugins = {
            "10759",
            "77026",
            "20089",
            "56984",
            "71049",
            "70658",
            "40984",
            "11411",
        }

        pluginID = tag.get("pluginID")
        if pluginID not in handled_plugins:
            return False

        output_tag = tag.find("plugin_output")
        text = output_tag.text if output_tag is not None else None
        if not text:
            # Every handled plugin reports "nothing to show" as an empty
            # string (some branches used to fall through and return None).
            return ""

        # Web server internal IP disclosure / Exchange client access server
        # information disclosure: the value is the fourth blank-line
        # separated section of the output.
        if pluginID in ("10759", "77026"):
            return text.split("\n\n")[3].strip()

        # SSH weak MAC / CBC algorithms: the list follows the first colon.
        if pluginID in ("71049", "70658"):
            algorithms = text.split(":")[1]
            algorithms = algorithms.split("\n\n")[1].replace("  ", "")
            return ", ".join(algorithms.split("\n"))

        # SSL / TLS versions supported: keep only SSLv* and TLSv1.0.
        if pluginID == "56984":
            versions = text.split("This port supports ")[1].strip().split("/")
            bad = [v for v in versions if "SSLv" in v or "TLSv1.0" in v]
            return ", ".join(bad).rstrip(".")

        # Browsable web directories: everything after the heading.
        if pluginID == "40984":
            return text.split("The following directories are browsable :")[
                1
            ].strip()

        # Backup files disclosure: collect the "URL : ..." lines.
        if pluginID == "11411":
            listing = text.split(
                "It is possible to read the following backup file"
            )[1].strip()
            urls = []
            for line in listing.split("\n"):
                if "URL" in line:
                    # Split on the first colon only; the URL itself contains
                    # colons and used to be cut down to its scheme.
                    urls.append(line.split(":", 1)[1].lstrip())
            return "\n".join(urls)

        # F5 cookie (plugin 20089).
        f5Output = []
        cookieVal = []
        for line in text.strip().split("\n"):
            line = line.split(":")
            # NOTE(review): ``line`` is shortened while it is being
            # enumerated, so fields right after a "Cookie" field are
            # skipped. Kept as is -- confirm against real plugin output.
            for i, item in enumerate(line):
                item = item.strip()
                if "Cookie" in item:
                    line.pop(i)
                    cookieVal.append(line.pop(i))
                else:
                    f5Output.append(item)

        summary = " : ".join(f5Output).replace(" :  : ", ", ")
        summary += " [" + ", ".join(cookieVal) + "]"

        tokens = summary.split()
        colons = 0
        for i, token in enumerate(tokens):
            if token == ":":
                colons += 1
                if colons % 2 == 0:
                    tokens[i] = " "
                    # NOTE(review): returns at the second stand-alone colon;
                    # this may have been meant to run after the loop.
                    return "".join(tokens).replace("[", " [")
        return ""

    def getVulns(self, ip, ReportHost):
        """Store ports, vulnerabilities and CVEs for every ReportItem of a host.

        For each ``ReportItem`` the port is looked up or created, its service
        name and extra info are updated, the finding is created or merged into
        an existing vulnerability of the same name, and every CVE that is not
        stored yet is looked up online and created.

        Args:
            ip: Stored IP record of the host; only ``ip.id`` is used here.
            ReportHost: The host's ``ReportHost`` element.
        """
        for tag in ReportHost.iter("ReportItem"):
            proto = tag.get("protocol")
            port = tag.get("port")
            svc_name = tag.get("svc_name").replace("?", "")
            pluginID = tag.get("pluginID")
            findingName = tag.get("pluginName")

            # Service name: well-known web ports first, then the generic
            # "www" service, where a TLS/SSL plugin name implies https.
            tmpPort = (proto + "/" + port).lower()
            if tmpPort == "tcp/443":
                portName = "https"
            elif tmpPort == "tcp/80":
                portName = "http"
            elif svc_name == "www":
                lowered_name = findingName.lower()
                if "tls" in lowered_name or "ssl" in lowered_name:
                    portName = "https"
                else:
                    portName = "http"
            else:
                portName = svc_name

            _, db_port = self.Port.find_or_create(
                port_number=port,
                status="open",
                proto=proto,
                ip_address_id=ip.id,
            )

            # A stored "https" is never changed and a stored "http" is only
            # ever upgraded to "https"; anything else is overwritten.
            current_service = db_port.service_name
            if current_service == "http":
                if portName == "https":
                    db_port.service_name = portName
            elif current_service != "https":
                db_port.service_name = portName
            db_port.save()

            # Two plugins get a fixed severity instead of the reported one.
            if pluginID == "56984":
                severity = 1
            elif pluginID == "11411":
                severity = 3
            else:
                severity = int(tag.get("severity"))

            description = tag.find("description").text

            # Compare the element's *text* with "n/a"; comparing the element
            # itself with a string was always unequal.
            solution_tag = tag.find("solution")
            solution_text = solution_tag.text if solution_tag is not None else None
            if solution_text and solution_text.strip().lower() != "n/a":
                solution = solution_text
            else:
                solution = "No Remediation From Nessus"

            nessCheck = self.nessCheckPlugin(tag)
            if nessCheck:
                if not db_port.info:
                    db_port.info = {findingName: nessCheck}
                else:
                    db_port.info[findingName] = nessCheck
                db_port.save()

            # The element carries "true"/"false"; its mere presence does not
            # mean an exploit exists.
            exploit_tag = tag.find("exploit_available")
            exploitable = (
                exploit_tag is not None
                and (exploit_tag.text or "").strip().lower() == "true"
            )

            vuln_refs = {}
            for ref_name in ("metasploit_name", "edb-id"):
                ref_tags = tag.findall(ref_name)
                if ref_tags:
                    # Metasploit names are stored under the key "metasploit".
                    key = "metasploit" if ref_name == "metasploit_name" else ref_name
                    vuln_refs[key] = [ref_tag.text for ref_tag in ref_tags]

            # Unique CVE IDs in document order.
            cves = []
            for cve_tag in tag.findall("cve"):
                if cve_tag.text not in cves:
                    cves.append(cve_tag.text)

            db_vuln = self.Vulnerability.find(name=findingName)
            if not db_vuln:
                _, db_vuln = self.Vulnerability.find_or_create(
                    name=findingName,
                    severity=severity,
                    description=description,
                    remediation=solution,
                )
                db_vuln.ports.append(db_port)
                db_vuln.exploitable = exploitable
                if exploitable:
                    display_new("exploit avalable for " + findingName)

                if vuln_refs:
                    db_vuln.exploit_reference = vuln_refs
            else:
                db_vuln.ports.append(db_port)
                db_vuln.exploitable = exploitable
                if vuln_refs:
                    if db_vuln.exploit_reference is not None:
                        # Merge new references into the stored ones without
                        # duplicating entries.
                        for key, refs in vuln_refs.items():
                            if key not in db_vuln.exploit_reference:
                                db_vuln.exploit_reference[key] = refs
                            else:
                                for ref in refs:
                                    if ref not in db_vuln.exploit_reference[key]:
                                        db_vuln.exploit_reference[key].append(ref)
                    else:
                        db_vuln.exploit_reference = vuln_refs

            for cve in cves:
                if self.CVE.find(name=cve):
                    # CVEs that are already stored are left untouched.
                    continue

                # Best-effort lookup: any failure just leaves the details
                # empty.
                try:
                    res = json.loads(
                        requests.get("http://cve.circl.lu/api/cve/%s" % cve).text
                    )
                    cveDescription = res["summary"]
                    cvss = float(res["cvss"])
                except Exception:
                    cveDescription = None
                    cvss = None

                db_cve = self.CVE.find(name=cve)
                if not db_cve:
                    _, db_cve = self.CVE.find_or_create(
                        name=cve,
                        description=cveDescription,
                        temporal_score=cvss,
                    )
                else:
                    # Only fill in details that are still missing.
                    if db_cve.description is None and cveDescription is not None:
                        db_cve.description = cveDescription
                    if db_cve.temporal_score is None and cvss is not None:
                        db_cve.temporal_score = cvss
                db_cve.vulnerabilities.append(db_vuln)

    def process_data(self, nFile, args):
        """Parse a Nessus XML export and store every host that has an IP.

        For each ``ReportHost`` the ``host-ip``, ``host-fqdn`` and
        ``operating-system`` properties are read. Hosts without an IP are
        skipped; for the others the IP (and domain, if any) is looked up or
        created, the OS names are merged into ``ip.OS`` and the findings are
        stored through ``getVulns``.

        Args:
            nFile: Path of the file to parse.
            args: Parsed command line arguments (not used here).
        """
        display("Reading " + nFile)
        # NOTE(review): the file is parsed with the standard XML parser;
        # consider a hardened parser if untrusted files can be imported.
        root = ET.parse(nFile).getroot()

        for ReportHost in root.iter("ReportHost"):
            # Not called ``os`` so the ``os`` module is not shadowed.
            os_names = []
            hostname = ""
            hostIP = ""
            for HostProperties in ReportHost.iter("HostProperties"):
                for tag in HostProperties:
                    prop = tag.get("name")
                    # An empty element has ``text`` set to None.
                    value = tag.text or ""

                    if prop == "host-ip":
                        hostIP = value
                    elif prop == "host-fqdn":
                        hostname = value.lower()
                        # Drop a leading "www." only; replacing it anywhere
                        # in the name would corrupt names like
                        # "awww.example.com".
                        if hostname.startswith("www."):
                            hostname = hostname[len("www."):]
                    elif prop == "operating-system":
                        os_names = [line for line in value.split("\n") if line]

            if not hostIP:
                # Apparently Nessus does not always have an IP to work with.
                continue

            if hostname:
                display(
                    "Gathering Nessus info for {} ( {} )".format(hostIP, hostname)
                )
            else:
                display("Gathering Nessus info for {}".format(hostIP))

            _, ip = self.IPAddress.find_or_create(ip_address=hostIP)

            if hostname:
                _, domain = self.Domain.find_or_create(domain=hostname)
                if ip not in domain.ip_addresses:
                    ip.save()
                    domain.ip_addresses.append(ip)
                    domain.save()

            # OS candidates are kept as one " OR "-separated string.
            for os_name in os_names:
                if not ip.OS:
                    ip.OS = os_name
                elif os_name not in ip.OS.split(" OR "):
                    ip.OS += " OR " + os_name

            self.getVulns(ip, ReportHost)
            self.IPAddress.commit()
Пример #9
0
class Module(ToolTemplate):
    """
    Module for running nmap. Make sure to pass all nmap-specific arguments at the end, after --tool_args

    """

    name = "Nmap"
    binary_name = "nmap"

    def __init__(self, db):
        """Keep the database handle and build the repositories this tool uses."""
        tool_name = self.name
        self.db = db

        # Hosts and ports.
        self.BaseDomain = BaseDomainRepository(db, tool_name)
        self.Domain = DomainRepository(db, tool_name)
        self.IPAddress = IPRepository(db, tool_name)
        self.Port = PortRepository(db, tool_name)

        # Findings and scope.
        self.Vulnerability = VulnRepository(db, tool_name)
        self.CVE = CVERepository(db, tool_name)
        self.ScopeCIDR = ScopeCIDRRepository(db, tool_name)

    def set_options(self):
        """Register the nmap-specific command line options.

        The parent's options are registered first. ``timeout`` gets a default
        of ``None`` on the parser.
        """
        super(Module, self).set_options()
        add_option = self.options.add_argument

        # Where the targets come from.
        add_option(
            "--hosts",
            nargs="+",
            help="Things to scan separated by a space. DO NOT USE QUOTES OR COMMAS",
        )
        add_option("--hosts_file", help="File containing hosts")
        add_option(
            "-i",
            "--hosts_database",
            action="store_true",
            help="Use unscanned hosts from the database",
        )
        add_option(
            "--rescan", action="store_true", help="Overwrite files without asking"
        )

        # Where the results go.
        add_option(
            "--filename",
            help="Output filename. By default will use the current timestamp.",
        )
        self.options.set_defaults(timeout=None)
        add_option("--import_file", help="Import results from an Nmap XML file.")

    def get_targets(self, args):
        """Build the single nmap job: a target-list file plus an output path.

        With ``--import_file`` nothing is scanned: ``args.no_binary`` is set
        and the given file is returned as the output to import.

        Otherwise targets are collected from ``--hosts``, the database
        (``--hosts_database``) and ``--hosts_file``, expanded to individual
        addresses, de-duplicated and written one per line to a temporary
        file.

        Returns:
            A one-element list holding a dict with the keys ``target`` (path
            of the temporary target list, or ``""`` when importing) and
            ``output`` (path of the XML report).
        """
        if args.import_file:
            args.no_binary = True
            return [{"target": "", "output": args.import_file}]

        targets = []

        if args.hosts:
            # A single non-list value is handled like a one-element list
            # (that branch used to reference an undefined name).
            hosts = args.hosts if isinstance(args.hosts, list) else [args.hosts]
            for host in hosts:
                targets += self._targets_for_host(host)

        if args.hosts_database:
            if args.rescan:
                targets += [
                    h.ip_address for h in self.IPAddress.all(scope_type="active")
                ]
                targets += [h.cidr for h in self.ScopeCIDR.all()]
            else:
                targets += [
                    h.ip_address
                    for h in self.IPAddress.all(tool=self.name, scope_type="active")
                ]
                targets += [h.cidr for h in self.ScopeCIDR.all(tool=self.name)]

        if args.hosts_file:
            with open(args.hosts_file) as hosts_fh:
                file_hosts = [l for l in hosts_fh.read().split("\n") if l]
            for host in file_hosts:
                targets += self._targets_for_host(host)

        # Expanding every entry to single addresses de-duplicates the
        # targets and folds IPs that also sit inside a listed CIDR.
        addresses = set()
        for target in targets:
            addresses.update(str(address) for address in IPNetwork(target))

        # mkstemp returns an open descriptor; write through it so it is
        # closed and the list is flushed before the path is handed out.
        fd, file_name = tempfile.mkstemp()
        with os.fdopen(fd, "w") as target_fh:
            # Sorted only to make the file content deterministic.
            target_fh.write("\n".join(sorted(addresses)))

        # Leading slashes are dropped so the path always lands below the
        # project base path (this also tolerates an empty value).
        self.path = os.path.join(
            self.base_config["PROJECT"]["base_path"], args.output_path.lstrip("/")
        )
        if not os.path.exists(self.path):
            os.makedirs(self.path)

        if args.filename:
            output_path = os.path.join(self.path, args.filename)
        else:
            output_path = os.path.join(
                self.path,
                "nmap-scan-%s.xml"
                % datetime.datetime.now().strftime("%Y.%m.%d-%H.%M.%S"),
            )

        return [{"target": file_name, "output": output_path}]

    def _targets_for_host(self, host):
        """Return ``[host]`` for an IP, else the IPs stored for that domain."""
        if check_if_ip(host):
            return [host]
        _, domain = self.Domain.find_or_create(domain=host)
        return [i.ip_address for i in domain.ip_addresses]

    def build_cmd(self, args):
        """Return the nmap command template, run through ``sudo``.

        ``{output}`` and ``{target}`` are returned unformatted (presumably
        filled in later from the target dict -- confirm against the tool
        template). ``args.tool_args`` is appended verbatim when it is set.
        """
        pieces = ["sudo ", self.binary, " -oX {output} -iL {target} "]

        extra_args = args.tool_args
        if extra_args:
            pieces.append(extra_args)

        return "".join(pieces)

    def process_output(self, cmds):
        """Import the report of the first command, then delete its target list.

        Only ``cmds[0]`` is looked at. The file named by its ``target`` entry
        is removed after the import unless that entry is empty.
        """
        job = cmds[0]
        self.import_nmap(job["output"])

        target_list = job["target"]
        if target_list:
            os.unlink(target_list)

    def parseHeaders(self, httpHeaders):
        """Keep only the headers that are not on the ignore list.

        Args:
            httpHeaders: Sequence of raw ``Name: value`` lines.

        Returns:
            A dict mapping header name to value for every kept line, or
            ``""`` when no header is kept (the empty string is what callers
            have always received in that case).
        """
        # Exact names to drop. "(Request type" is the name part of a
        # "(Request type: ...)" line.
        ignored = frozenset(
            (
                "Pragma",
                "Expires",
                "Date",
                "Transfer-Encoding",
                "Connection",
                "X-Content-Type-Options",
                "Cache-Control",
                "X-Frame-Options",
                "Content-Type",
                "Content-Length",
                "(Request type",
            )
        )

        keepHeaders = {}
        for raw_line in httpHeaders:
            if not raw_line.strip():
                continue

            # Split on the first colon only so values that contain colons
            # (URLs, timestamps) are kept intact.
            name, _, value = raw_line.partition(":")
            name = name.strip()

            # Exact match instead of a substring test against the joined
            # list, which also dropped names such as "Type" or "Control".
            if not name or name in ignored:
                continue
            keepHeaders[name] = value.strip()

        return keepHeaders or ""

    def import_nmap(self, filename):
        """Load an nmap XML report into the database.

        For every ``host`` element the IP is looked up or created, reported
        hostnames are linked to it, and each port that has a state is stored
        together with selected script results (certificate, vulners, banner
        and a few HTTP scripts). Changes are committed once per host.

        A report that cannot be read or parsed is reported on stdout and
        skipped.

        Args:
            filename: Path of the nmap XML file.
        """

        def strip_www(name):
            """Lower-case ``name`` and drop one leading ``www.`` label."""
            name = name.lower()
            # Prefix only: replacing "www." anywhere in the name would turn
            # "awww.example.com" into "aexample.com".
            return name[len("www."):] if name.startswith("www.") else name

        try:
            hosts = ET.parse(filename).getroot().findall("host")
        except Exception:
            print(filename + " doesn't exist somehow...skipping")
            return

        for host in hosts:
            hostIP = host.find("address").get("addr")
            _, ip = self.IPAddress.find_or_create(ip_address=hostIP)

            for hostname_tag in host.findall("hostnames/hostname"):
                hostname = strip_www(hostname_tag.get("name"))

                _, domain = self.Domain.find_or_create(domain=hostname)
                if ip not in domain.ip_addresses:
                    domain.ip_addresses.append(ip)
                    domain.save()

            for port in host.findall("ports/port"):
                portState = port.find("state").get("state")
                if not portState:
                    continue

                hostPort = port.get("portid")
                portProto = port.get("protocol")

                port_created, db_port = self.Port.find_or_create(
                    port_number=hostPort,
                    status=portState,
                    proto=portProto,
                    ip_address=ip,
                )

                service = port.find("service")
                if service is not None:
                    portName = service.get("name")
                    # Plain "http" reported on 443 is recorded as https.
                    if portName == "http" and hostPort == "443":
                        portName = "https"
                else:
                    portName = "Unknown"

                # The service name is only set on newly created ports.
                if port_created:
                    db_port.service_name = portName

                info = db_port.info or {}

                for script in port.findall("script"):
                    script_id = script.get("id")
                    script_output = script.get("output")

                    if script_id == "ssl-cert":
                        db_port.cert = script_output
                        # Domains named in the certificate are stored too.
                        for cert_domain in self.get_domains_from_cert(script_output):
                            cert_domain = strip_www(cert_domain)
                            domain_created, _ = self.Domain.find_or_create(
                                domain=cert_domain
                            )
                            if domain_created:
                                print("New domain found: %s" % cert_domain)

                    elif script_id == "vulners":
                        print(
                            "Gathering vuln info for {} : {}/{}\n".format(
                                hostIP, portProto, hostPort
                            )
                        )
                        self.parseVulners(script_output, db_port)

                    elif script_id == "http-headers":
                        info["http-headers"] = self.parseHeaders(
                            script_output.strip().split("\n")
                        )

                    elif script_id in ("banner", "http-auth", "http-title"):
                        # Stored verbatim under the script's own ID.
                        info[script_id] = script_output

                db_port.info = info
                db_port.save()

            self.IPAddress.commit()

    def parseVulners(self, scriptOutput, db_port):
        """Store the CVEs listed in ``vulners`` script output for a port.

        For a CVE that is already stored, the port is linked to each of that
        CVE's vulnerabilities. For a new CVE the vulners page is scanned for
        exploit-db links, the CVE details are fetched, and a vulnerability
        plus CVE record are created or updated and committed. Any failure
        while handling a new CVE is reported on stdout and that CVE is
        skipped.

        Args:
            scriptOutput: Text output of the ``vulners`` script.
            db_port: Stored port record the findings belong to.
        """
        urls = re.findall(r"(https://vulners.com/cve/CVE-\d*-\d*)", scriptOutput)
        for url in urls:
            cve = url.split("/cve/")[1]

            db_cve = self.CVE.find(name=cve)
            if db_cve:
                # Known CVE: only link the port. This is checked before any
                # download because the exploit lookup is not used here.
                for db_vulns in db_cve.vulnerabilities:
                    if db_port not in db_vulns.ports:
                        db_vulns.ports.append(db_port)
                continue

            # Exploit-db IDs referenced on the vulners page, without
            # duplicates.
            vuln_refs = []
            vulners = requests.get("https://vulners.com/cve/%s" % cve).text
            exploitdb = re.findall(
                r"https://www.exploit-db.com/exploits/\d{,7}", vulners
            )
            for edb in exploitdb:
                edb_id = edb.split("/exploits/")[1]
                if edb_id not in vuln_refs:
                    vuln_refs.append(edb_id)
            exploitable = bool(exploitdb)

            try:
                res = json.loads(
                    requests.get("http://cve.circl.lu/api/cve/%s" % cve).text
                )
                cveDescription = res["summary"]
                cvss = float(res["cvss"])
                findingName = res["oval"][0]["title"]

                # Map the CVSS score onto the 1-4 severity scale. Floor
                # division keeps the result an integer on Python 3.
                score = int(cvss)
                if score <= 3:
                    severity = 1
                elif score // 2 == 5:
                    severity = 4
                else:
                    severity = score // 2

                db_vuln = self.Vulnerability.find(name=findingName)
                if not db_vuln:
                    _, db_vuln = self.Vulnerability.find_or_create(
                        name=findingName,
                        severity=severity,
                        description=cveDescription,
                    )
                    db_vuln.ports.append(db_port)
                    db_vuln.exploitable = exploitable
                    if vuln_refs:
                        db_vuln.exploit_reference = {"edb-id": vuln_refs}
                        db_vuln.save()
                else:
                    db_vuln.ports.append(db_port)
                    db_vuln.exploitable = exploitable

                    if vuln_refs:
                        references = db_vuln.exploit_reference
                        if references is None:
                            db_vuln.exploit_reference = {"edb-id": vuln_refs}
                        elif "edb-id" in references:
                            # Merge without duplicating stored IDs.
                            for ref in vuln_refs:
                                if ref not in references["edb-id"]:
                                    references["edb-id"].append(ref)
                        else:
                            references["edb-id"] = vuln_refs

                    db_vuln.save()

                db_cve = self.CVE.find(name=cve)
                if not db_cve:
                    _, db_cve = self.CVE.find_or_create(
                        name=cve, description=cveDescription, temporal_score=cvss
                    )
                db_cve.vulnerabilities.append(db_vuln)
                db_cve.save()

                self.Vulnerability.commit()
                self.CVE.commit()
            except Exception:
                # Best effort: report the CVE that failed and move on.
                print("something went wrong with the vuln/cve info gathering")
                if vulners:
                    print(
                        "Vulners report was found but no exploit-db was discovered"
                    )
                print("Affected CVE")
                print(cve)

    def get_domains_from_cert(self, cert):
        """Return the distinct domain-like names found in ``cert``.

        Args:
            cert: Certificate text to search.

        Returns:
            A list without duplicates, in no particular order.
        """
        # Dotted labels followed by an alphabetic TLD of 2-6 letters.
        pattern = r"(?:[a-zA-Z0-9](?:[a-zA-Z0-9\-]{,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,6}"

        unique_names = set()
        for candidate in re.findall(pattern, cert):
            if "*" in candidate:
                continue
            unique_names.add(candidate)

        return list(unique_names)
Пример #10
0
class Module(ModuleTemplate):
    """
    Ingests domains and IPs. Domains get ip info and cidr info, and IPs get
    CIDR info.

    """

    name = "Ingestor"

    def __init__(self, db):
        """Keep the database handle and build the repositories this module uses.

        Args:
            db: Database handle. It is stored on the instance and passed,
                together with the module name, to each repository constructor.
        """
        self.db = db

        # (attribute name, repository class), listed in creation order.
        repository_specs = (
            ("BaseDomain", BaseDomainRepository),
            ("Domain", DomainRepository),
            ("IPAddress", IPRepository),
            ("CIDR", CIDRRepository),
            ("ScopeCIDR", ScopeCIDRRepository),
        )
        for attribute, repository_cls in repository_specs:
            setattr(self, attribute, repository_cls(db, self.name))

    def set_options(self):
        """Register the Ingestor command-line arguments after the inherited ones."""
        super(Module, self).set_options()

        store_true = "store_true"

        # (flag names, add_argument keyword settings), in registration order.
        option_table = (
            (
                ("-d", "--import_domains"),
                {
                    "help": "Either domain to import or file containing domains to import. One per line"
                },
            ),
            (
                ("-i", "--import_ips"),
                {
                    "help": "Either IP/range to import or file containing IPs and ranges, one per line."
                },
            ),
            (
                ("-a", "--active"),
                {
                    "help": "Set scoping on imported data as active",
                    "action": store_true,
                },
            ),
            (
                ("-p", "--passive"),
                {
                    "help": "Set scoping on imported data as passive",
                    "action": store_true,
                },
            ),
            (
                ("-sc", "--scope_cidrs"),
                {
                    "help": "Cycle through out of scope networks and decide if you want to add them in scope",
                    "action": store_true,
                },
            ),
            (
                ("-sb", "--scope_base_domains"),
                {
                    "help": "Cycle through out of scope base domains and decide if you want to add them in scope",
                    "action": store_true,
                },
            ),
            (("--descope",), {"help": "Descope an IP, domain, or CIDR"}),
            (
                ("-Ii", "--import_database_ips"),
                {"help": "Import IPs from database", "action": store_true},
            ),
            (
                ("--force",),
                {
                    "help": "Force processing again, even if already processed",
                    "action": store_true,
                },
            ),
        )

        for flags, settings in option_table:
            self.options.add_argument(*flags, **settings)

    def run(self, args):
        """Apply the requested descoping, imports and re-scoping.

        Args:
            args: Parsed options. This method reads ``active``, ``passive``,
                ``descope``, ``import_ips``, ``import_domains`` and
                ``scope_base_domains``. The ``scope_cidrs``,
                ``import_database_ips`` and ``force`` options registered in
                ``set_options`` are not consulted here.
        """
        self.in_scope = args.active
        self.passive_scope = args.passive

        if args.descope:
            if "/" in args.descope:
                self.descope_cidr(args.descope)
            elif check_string(args.descope):
                # NOTE(review): presumably the "domain" case from the
                # --descope help text. Nothing is implemented for it, so tell
                # the user instead of silently doing nothing.
                display_error(
                    "Nothing was descoped for {}: only IPs and CIDRs are currently handled".format(
                        args.descope
                    )
                )
            else:
                self.descope_ip(args.descope)

        if args.import_ips:
            for entry in self._read_import_entries(args.import_ips):
                # Both "a.b.c.d/nn" and "start-end" forms go to the CIDR path.
                if "/" in entry or "-" in entry:
                    self.process_cidr(entry)
                else:
                    self.process_ip(entry, force_scope=True)
                self.Domain.commit()

        if args.import_domains:
            for entry in self._read_import_entries(args.import_domains):
                self.process_domain(entry)
                self.Domain.commit()

        if args.scope_base_domains:
            base_domains = self.BaseDomain.all(in_scope=False, passive_scope=False)

            for bd in base_domains:
                self.reclassify_domain(bd)

            self.BaseDomain.commit()

    @staticmethod
    def _read_import_entries(source):
        """Return the stripped, non-blank entries named by ``source``.

        ``source`` is first tried as a path to a file with one entry per
        line. If it cannot be opened it is treated as a single literal entry.
        Only the open/read is guarded, so an IOError raised later while an
        entry is being processed is no longer mistaken for "not a file", and
        the file handle is always closed.
        """
        try:
            with open(source) as handle:
                raw_entries = handle.readlines()
        except IOError:
            raw_entries = [source]

        return [entry.strip() for entry in raw_entries if entry.strip()]

    def get_domain_ips(self, domain):
        """Return the addresses from an "A" lookup of ``domain``.

        Any exception raised during the lookup, or while reading the answers,
        yields an empty list instead of propagating.
        """
        try:
            return [record.address for record in dns.resolver.query(domain, "A")]
        except Exception:
            return []

    def process_domain(self, domain_str):
        """Create ``domain_str`` with the current scoping, or re-scope it.

        A record that already exists with different active/passive flags is
        updated to ``self.in_scope`` / ``self.passive_scope``. When its base
        domain has the same name, the base domain is updated as well.
        """
        active = self.in_scope
        passive = self.passive_scope

        created, domain = self.Domain.find_or_create(
            only_tool=True, domain=domain_str, in_scope=active, passive_scope=passive
        )

        # New records already carry the requested scoping.
        if created:
            return

        # Existing record whose flags already match: nothing to do.
        if domain.in_scope == active and domain.passive_scope == passive:
            return

        display(
            "Domain %s already exists with different scoping. Updating to Active Scope: %s Passive Scope: %s"
            % (domain_str, active, passive)
        )
        domain.in_scope = active
        domain.passive_scope = passive
        domain.update()

        base = domain.base_domain
        if base.domain == domain.domain:
            display("Name also matches a base domain. Updating that as well.")
            base.in_scope = active
            base.passive_scope = passive
            base.update()

    def process_ip(self, ip_str, force_scope=True):
        """Create ``ip_str`` with the current scoping, or re-scope it.

        Args:
            ip_str: IP address to look up or create.
            force_scope: Accepted for callers; this method does not read it.

        Returns:
            The IP record returned by the repository.
        """
        active = self.in_scope
        passive = self.passive_scope

        created, ip = self.IPAddress.find_or_create(
            only_tool=True, ip_address=ip_str, in_scope=active, passive_scope=passive
        )

        # Freshly created records already carry the requested flags.
        if created:
            return ip

        # Existing record whose flags already match: return it untouched.
        if ip.in_scope == active and ip.passive_scope == passive:
            return ip

        display(
            "IP %s already exists with different scoping. Updating to Active Scope: %s Passive Scope: %s"
            % (ip_str, active, passive)
        )
        ip.in_scope = active
        ip.passive_scope = passive
        ip.update()
        return ip

    def process_cidr(self, line):
        """Add a CIDR, or every CIDR covering an IP range, to the scoped CIDRs.

        Args:
            line: Either ``network/prefix`` or ``start-end``. The end of a
                range may be a bare last octet (``10.0.0.5-20``), which is
                completed with the first three octets of the start address.
                Surrounding whitespace is ignored.
        """
        display("Processing %s" % line)
        entry = line.strip()

        if "/" in entry:
            self._add_scope_cidr(entry)

        elif "-" in entry:
            start_ip, end_ip = entry.replace(" ", "").split("-")
            if "." not in end_ip:
                # Short form: reuse the first three octets of the start.
                end_ip = ".".join(start_ip.split(".")[:3] + [end_ip])

            for network in iprange_to_cidrs(start_ip, end_ip):
                self._add_scope_cidr(str(network))

    def _add_scope_cidr(self, cidr_str):
        """Find or create the ScopeCIDR ``cidr_str``; mark new ones in scope."""
        created, cidr = self.ScopeCIDR.find_or_create(cidr=cidr_str)
        if created:
            # Report the CIDR actually added. For ranges the old message
            # repeated the whole input range once per generated CIDR.
            display_new("Adding %s to scoped CIDRs in database" % cidr_str)
            cidr.in_scope = True
            cidr.update()

    def reclassify_domain(self, bd):
        """Show a base domain's whois data and ask the user how to scope it.

        Args:
            bd: Base domain record. Its ``meta["whois"]`` entry is printed;
                ``in_scope`` and ``passive_scope`` are set from the answer
                and the record is saved.

        An answer of "a" sets both flags, "p" sets only the passive flag, and
        anything else clears both. Records without whois data are left
        untouched and an error is displayed instead.
        """
        if not bd.meta.get("whois", False):
            display_error(
                "Unfortunately, there is no whois information for {}. Please populate it using the Whois module".format(
                    bd.domain
                )
            )
            return

        display_new("Whois data found for {}".format(bd.domain))
        print(bd.meta["whois"])

        # six exposes the portable prompt as six.moves.input; there is no
        # top-level six.input, so the old call raised AttributeError.
        answer = six.moves.input(
            "Should this domain be scoped (A)ctive, (P)assive, or (N)ot? [a/p/N] "
        )
        # Normalise once so "A", " a " and "a" are all treated alike.
        choice = answer.strip().lower()

        bd.in_scope = choice == "a"
        bd.passive_scope = choice in ("a", "p")
        bd.save()

    def descope_ip(self, ip):
        """Clear both scope flags on the IP records matching ``ip``.

        Every domain attached to a descoped record is checked too: when none
        of its IP addresses is still actively or passively scoped, its scope
        flags are cleared as well. The repository is committed once at the
        end, and only when at least one record matched.
        """
        matches = self.IPAddress.all(ip_address=ip)
        if not matches:
            return

        for record in matches:
            display("Removing IP {} from scope".format(record.ip_address))
            record.in_scope = False
            record.passive_scope = False
            record.update()

            for domain in record.domains:
                still_scoped = any(
                    address.in_scope or address.passive_scope
                    for address in domain.ip_addresses
                )
                if still_scoped:
                    continue

                display(
                    "Domain {} has no more scoped IPs. Removing from scope.".format(
                        domain.domain
                    )
                )
                domain.in_scope = False
                domain.passive_scope = False

        self.IPAddress.commit()

    def descope_cidr(self, cidr):
        """Remove ``cidr`` from the scoped CIDRs and descope the stored IPs inside it.

        Args:
            cidr: Network string; it is used both for the ScopeCIDR lookup
                and to build an ``IPNetwork``.
        """
        # Delete every ScopeCIDR record returned for this exact string.
        # NOTE(review): no commit for the deletion is visible in this method;
        # confirm whether delete() persists on its own.
        CIDR = self.ScopeCIDR.all(cidr=cidr)
        if CIDR:
            for c in CIDR:
                display("Removing {} from ScopeCIDRs".format(c.cidr))
                c.delete()
        cnet = IPNetwork(cidr)
        # Walk every stored IP record and descope those inside the network.
        # The bare IPAddress name is a module-level callable (presumably
        # netaddr's), not the self.IPAddress repository.
        for ip in self.IPAddress.all():
            if IPAddress(ip.ip_address) in cnet:

                self.descope_ip(ip.ip_address)