예제 #1
0
class VirusTotalConnector:
    """VirusTotal enrichment connector.

    For each observable sent by OpenCTI, queries the VirusTotal v3 API and
    attaches hashes, size, names, labels, an external reference and
    crowdsourced YARA indicators to it.
    """

    # Pause applied when the VirusTotal API quota is exhausted (1 hour).
    _CONNECTOR_RUN_INTERVAL_SEC = 60 * 60
    _API_URL = "https://www.virustotal.com/api/v3"

    def __init__(self):
        # Instantiate the connector helper from config
        config_file_path = Path(
            __file__).parent.parent.resolve() / "config.yml"

        config = {}
        if config_file_path.is_file():
            # Context manager so the config file handle is always closed.
            with open(config_file_path, encoding="utf-8") as config_file:
                config = yaml.load(config_file, Loader=yaml.FullLoader)
        self.helper = OpenCTIConnectorHelper(config)
        token = get_config_variable("VIRUSTOTAL_TOKEN",
                                    ["virustotal", "token"], config)
        self.max_tlp = get_config_variable("VIRUSTOTAL_MAX_TLP",
                                           ["virustotal", "max_tlp"], config)

        self.client = VirusTotalClient(self._API_URL, token)

        # Cache of YARA rulesets keyed by ruleset id, so that a ruleset that
        # was successfully downloaded is not requested again.
        self.yara_cache = {}

    def _create_yara_indicator(
            self,
            yara: dict,
            valid_from: Optional[int] = None) -> Optional[Indicator]:
        """Create an indicator containing one YARA rule from VirusTotal.

        Args:
            yara: one entry of the ``crowdsourced_yara_results`` attribute
                (keys read: ``ruleset_id``, ``rule_name``, ``description``,
                ``author``, ``source``, ``ruleset_name``).
            valid_from: optional UNIX timestamp used as ``valid_from``;
                ``datetime.datetime.min`` when omitted.

        Returns:
            The created indicator, or None when the result has no ruleset
            id, the ruleset cannot be retrieved, or it does not contain the
            named rule.
        """
        valid_from_date = (datetime.datetime.min if valid_from is None else
                           datetime.datetime.utcfromtimestamp(valid_from))
        ruleset_id = yara.get("ruleset_id")
        if ruleset_id is None:
            # Querying the API with a placeholder id cannot succeed.
            self.helper.log_warning(
                "[VirusTotal] YARA result without ruleset id, skipping.")
            return None
        self.helper.log_info(f"[VirusTotal] Retrieving ruleset {ruleset_id}")

        # Lookup in the cache for the ruleset id, otherwise, request VirusTotal API.
        if ruleset_id in self.yara_cache:
            self.helper.log_debug(
                f"Retrieving YARA ruleset {ruleset_id} from cache.")
            ruleset = self.yara_cache[ruleset_id]
        else:
            self.helper.log_debug(
                f"Retrieving YARA ruleset {ruleset_id} from API.")
            ruleset = self.client.get_yara_ruleset(ruleset_id)
            if not ruleset or "data" not in ruleset:
                # Do not cache unusable answers (e.g. transient API errors),
                # otherwise the ruleset could never be fetched again.
                self.helper.log_warning(
                    f"[VirusTotal] Unable to retrieve YARA ruleset {ruleset_id}")
                return None
            self.yara_cache[ruleset_id] = ruleset

        # Parse the rules to find the correct one.
        parser = plyara.Plyara()
        rules = parser.parse_string(ruleset["data"]["attributes"]["rules"])
        rule_name = yara.get("rule_name")
        rule = next((r for r in rules if r["rule_name"] == rule_name), None)
        if rule is None:
            self.helper.log_warning(f"No YARA rule for rule name {rule_name}")
            return None

        return self.helper.api.indicator.create(
            name=rule_name,
            description=json.dumps({
                "description":
                yara.get("description", "No description provided"),
                "author":
                yara.get("author", "No author provided"),
                "source":
                yara.get("source", "No source provided"),
                "ruleset_id":
                ruleset_id,
                "ruleset_name":
                yara.get("ruleset_name", "No ruleset name provided"),
            }),
            pattern=plyara.utils.rebuild_yara_rule(rule),
            pattern_type="yara",
            valid_from=self.helper.api.stix2.format_date(valid_from_date),
            x_opencti_main_observable_type="StixFile",
        )

    def _get_file_info(self, file_hash):
        """Fetch the VirusTotal report for ``file_hash``.

        When the API answers "Quota exceeded", waits
        ``_CONNECTOR_RUN_INTERVAL_SEC`` seconds and asks again instead of
        giving up on the observable. Any other answer is returned as is.
        """
        while True:
            json_data = self.client.get_file_info(file_hash)
            if ("error" not in json_data
                    or json_data["error"]["message"] != "Quota exceeded"):
                return json_data
            self.helper.log_info("Quota reached, waiting 1 hour.")
            sleep(self._CONNECTOR_RUN_INTERVAL_SEC)

    def _process_file(self, observable):
        """Enrich ``observable`` with its VirusTotal file report.

        Returns:
            A status message describing the outcome.

        Raises:
            ValueError: when VirusTotal returns an error other than
                "not found" (quota errors are waited out and retried).
        """
        json_data = self._get_file_info(observable["observable_value"])
        if "error" in json_data:
            message = json_data["error"]["message"]
            if "not found" in message:
                self.helper.log_info("File not found on VirusTotal.")
                return "File not found on VirusTotal."
            raise ValueError(message)
        if "data" not in json_data:
            # Previously this case silently returned None.
            self.helper.log_warning(
                "[VirusTotal] Unexpected response without data.")
            return "No data returned by VirusTotal."

        attributes = json_data["data"]["attributes"]

        # Update the hashes of the current observable; each update returns
        # the refreshed observable whose id is used for the next call.
        final_observable = observable
        for key, attribute in (
            ("hashes.MD5", "md5"),
            ("hashes.SHA-1", "sha1"),
            ("hashes.SHA-256", "sha256"),
        ):
            final_observable = self.helper.api.stix_cyber_observable.update_field(
                id=final_observable["id"],
                input={
                    "key": key,
                    "value": attributes[attribute]
                },
            )
        final_id = final_observable["id"]

        # Work on a copy so the API response itself is not mutated.
        names = list(attributes.get("names", []))
        if observable["entity_type"] == "StixFile":
            if "size" in attributes:
                self.helper.api.stix_cyber_observable.update_field(
                    id=final_id,
                    input={
                        "key": "size",
                        "value": str(attributes["size"])
                    },
                )
            if observable.get("name") is None and names:
                # The first known name becomes the main name of the file.
                self.helper.api.stix_cyber_observable.update_field(
                    id=final_id,
                    input={
                        "key": "name",
                        "value": names[0]
                    },
                )
                names = names[1:]

        if names:
            self.helper.api.stix_cyber_observable.update_field(
                id=final_id,
                input={
                    "key": "x_opencti_additional_names",
                    "value": names,
                },
            )

        # Create external reference
        external_reference = self.helper.api.external_reference.create(
            source_name="VirusTotal",
            url="https://www.virustotal.com/gui/file/" + attributes["sha256"],
            description=attributes.get("magic"),
        )

        # Create tags
        for tag in attributes.get("tags", []):
            tag_vt = self.helper.api.label.create(value=tag, color="#0059f7")
            self.helper.api.stix_cyber_observable.add_label(
                id=final_id, label_id=tag_vt["id"])

        self.helper.api.stix_cyber_observable.add_external_reference(
            id=final_id,
            external_reference_id=external_reference["id"],
        )

        if "crowdsourced_yara_results" in attributes:
            self.helper.log_info("[VirusTotal] adding yara results to file.")

            # Add YARA rules (only if a rule is given).
            creation_date = attributes.get("creation_date", None)
            yaras = [
                indicator for indicator in (
                    self._create_yara_indicator(yara, creation_date)
                    for yara in attributes["crowdsourced_yara_results"])
                if indicator
            ]

            self.helper.log_debug(f"[VirusTotal] Indicators created: {yaras}")

            # Create the relationships (`related-to`) between the yaras and the file.
            for yara in yaras:
                self.helper.api.stix_core_relationship.create(
                    fromId=final_id,
                    toId=yara["id"],
                    relationship_type="related-to",
                )

        return "File found on VirusTotal, knowledge attached."

    def _process_message(self, data):
        """Handle one enrichment request received from OpenCTI.

        Raises:
            ValueError: when the observable cannot be read, or when its TLP
                marking is above the configured maximum TLP.
        """
        entity_id = data["entity_id"]
        observable = self.helper.api.stix_cyber_observable.read(id=entity_id)
        if observable is None:
            raise ValueError(f"Observable {entity_id} not found")
        # Extract TLP (defaults to TLP:WHITE; the last TLP marking wins).
        tlp = "TLP:WHITE"
        for marking_definition in observable["objectMarking"]:
            if marking_definition["definition_type"] == "TLP":
                tlp = marking_definition["definition"]
        if not OpenCTIConnectorHelper.check_max_tlp(tlp, self.max_tlp):
            raise ValueError(
                "Do not send any data, TLP of the observable is greater than MAX TLP"
            )
        return self._process_file(observable)

    def start(self):
        """Start the main loop."""
        self.helper.listen(self._process_message)
예제 #2
0
class RiskIQConnector:
    """RiskIQ Connector main class.

    Periodically retrieves RiskIQ articles and imports each of them through
    ``ArticleImporter``; run and article timestamps are kept in the
    connector state.
    """

    _DEFAULT_AUTHOR = "RiskIQ"

    # Default run interval
    _CONNECTOR_RUN_INTERVAL_SEC = 60
    _STATE_LATEST_RUN_TIMESTAMP = "latest_run_timestamp"

    def __init__(self):
        # Instantiate the connector helper from config
        config_file_path = Path(
            __file__).parent.parent.resolve() / "config.yml"

        config = {}
        if config_file_path.is_file():
            # Context manager so the config file handle is always closed.
            with open(config_file_path, encoding="utf8") as config_file:
                config = yaml.load(config_file, Loader=yaml.FullLoader)

        self.helper = OpenCTIConnectorHelper(config)

        self.base_url = get_config_variable("RISKIQ_BASE_URL",
                                            ["riskiq", "base_url"], config)
        self.interval_sec = get_config_variable("RISKIQ_INTERVAL_SEC",
                                                ["riskiq", "interval_sec"],
                                                config)
        user = get_config_variable("RISKIQ_USER", ["riskiq", "user"], config)
        password = get_config_variable("RISKIQ_PASSWORD",
                                       ["riskiq", "password"], config)
        # Create the author for all reports.
        self.author = Identity(
            name=self._DEFAULT_AUTHOR,
            identity_class="organization",
            description=
            " RiskIQ is a cyber security company based in San Francisco, California."
            " It provides cloud - based software as a service(SaaS) for organizations"
            " to detect phishing, fraud, malware, and other online security threats.",
            confidence=self.helper.connect_confidence_level,
        )
        # Initialization of the client
        self.client = RiskIQClient(self.base_url, user, password)

    @staticmethod
    def _current_unix_timestamp() -> int:
        """Return the current UNIX time, in whole seconds."""
        return int(time.time())

    def _get_interval(self) -> int:
        """Return the configured run interval, in seconds."""
        return int(self.interval_sec)

    @staticmethod
    def _get_state_value(state: Optional[Mapping[str, Any]],
                         key: str,
                         default: Optional[Any] = None) -> Any:
        """Return ``state[key]``, or ``default`` when missing or no state."""
        if state is not None:
            return state.get(key, default)
        return default

    def _initiate_work(self, timestamp: int) -> str:
        """Register a new OpenCTI work for the run started at ``timestamp``."""
        now = datetime.datetime.utcfromtimestamp(timestamp)
        friendly_name = "RiskIQ run @ " + now.strftime("%Y-%m-%d %H:%M:%S")
        work_id = self.helper.api.work.initiate_work(self.helper.connect_id,
                                                     friendly_name)
        self.helper.log_info(f"[RiskIQ] workid {work_id} initiated")
        return work_id

    def _is_scheduled(self, last_run: Optional[int],
                      current_time: int) -> bool:
        """Tell whether a run is due (always true on a clean run)."""
        if last_run is None:
            self.helper.log_info("RiskIQ connector clean run")
            return True

        time_diff = current_time - last_run
        return time_diff >= self._get_interval()

    def _get_next_interval(self, run_interval: int, timestamp: int,
                           last_run: int) -> int:
        """Get the delay for the next interval."""
        next_run = self._get_interval() - (timestamp - last_run)
        return min(run_interval, next_run)

    def _load_state(self) -> dict[str, Any]:
        """Return the connector state, or an empty dict when there is none."""
        current_state = self.helper.get_state()
        if not current_state:
            return {}
        return current_state

    @classmethod
    def _sleep(cls, delay_sec: Optional[int] = None) -> None:
        """Sleep ``delay_sec`` seconds (default: the run interval constant)."""
        sleep_delay = (delay_sec if delay_sec is not None else
                       cls._CONNECTOR_RUN_INTERVAL_SEC)
        time.sleep(sleep_delay)

    def _import_articles(self, work_id: str, current_state: dict[str, Any],
                         last_run: Optional[int]) -> None:
        """Fetch the new articles, import them and persist the new state."""
        new_state = current_state.copy()
        last_article = self._get_state_value(
            current_state, ArticleImporter._LATEST_ARTICLE_TIMESTAMP)

        self.helper.log_info(f"[RiskIQ] last run: {last_run}")
        # Resume from the last imported article. The article timestamp itself
        # must be present: a previous run may have imported no article, and
        # converting a missing timestamp used to crash the connector.
        last_article_date = (timestamp_to_datetime(last_article).date()
                             if last_run is not None
                             and last_article is not None else None)
        self.helper.log_debug(
            f"[RiskIQ] retrieving data from {last_article_date}")
        response = self.client.get_articles(last_article_date)

        if not self.client.is_correct(response):
            self.helper.log_warning("[RiskIQ] failed to retrieve articles")
            self.helper.log_info(
                f"[RiskIQ] next run in {self._CONNECTOR_RUN_INTERVAL_SEC} seconds"
            )
            return

        for article in response["articles"]:
            importer = ArticleImporter(self.helper, article, self.author)
            importer_state = importer.run(work_id, current_state)
            if importer_state:
                self.helper.log_info(
                    f"[RiskIQ] Updating state {importer_state}")
                new_state.update(importer_state)

        # Set the new state once, after the loop. It used to be written only
        # inside the loop, so a run returning no article never recorded its
        # timestamp and the connector ran again at every tick.
        new_state[
            self._STATE_LATEST_RUN_TIMESTAMP] = self._current_unix_timestamp()
        self.helper.log_info(f"[RiskIQ] Storing new state: {new_state}")
        self.helper.set_state(new_state)

    def run(self):
        """Run RiskIQ connector (blocks forever, exits the process on error)."""
        self.helper.log_info("Starting RiskIQ connector...")

        while True:
            self.helper.log_info("Running RiskIQ connector...")
            run_interval = self._CONNECTOR_RUN_INTERVAL_SEC

            try:
                self.helper.log_info(f"Connector interval sec: {run_interval}")
                timestamp = self._current_unix_timestamp()
                current_state = self._load_state()
                self.helper.log_info(f"[RiskIQ] loaded state: {current_state}")

                last_run = self._get_state_value(
                    current_state, self._STATE_LATEST_RUN_TIMESTAMP)
                if self._is_scheduled(last_run, timestamp):
                    work_id = self._initiate_work(timestamp)
                    self._import_articles(work_id, current_state, last_run)
                else:
                    run_interval = self._get_next_interval(
                        run_interval, timestamp, last_run)
                    self.helper.log_info(
                        f"[RiskIQ] Connector will not run, next run in {run_interval} seconds"
                    )

                self._sleep(delay_sec=run_interval)
            except (KeyboardInterrupt, SystemExit):
                self.helper.log_info("RiskIQ connector stop")
                sys.exit(0)
            except Exception as e:
                self.helper.log_error(str(e))
                # Non-zero status so process supervisors see the failure.
                sys.exit(1)
예제 #3
0
class IvreConnector:
    """The conector object. Instanciate and .start()."""
    def __init__(self):
        """Instantiate the connector helper from the configuration.

        Settings are resolved with ``get_config_variable`` from an
        environment variable or from the ``config.yml`` file located next
        to this module; IVRE database URLs default to the matching
        ``ivre_config`` attributes when those exist.

        """
        config_file_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "config.yml")
        if os.path.isfile(config_file_path):
            with open(config_file_path, encoding="utf-8") as fdesc:
                # An empty YAML file loads as None; keep `config` a mapping.
                config = yaml.load(fdesc, Loader=yaml.FullLoader) or {}
        else:
            config = {}
        self.helper = OpenCTIConnectorHelper(config)
        # Feature flags. `config` used to be omitted from these calls, so
        # values set in config.yml were silently ignored.
        self.use_data = get_config_variable("IVRE_USE_DATA",
                                            ["ivre", "use_data"],
                                            config,
                                            default=True)
        self.use_passive = get_config_variable("IVRE_USE_PASSIVE",
                                               ["ivre", "use_passive"],
                                               config,
                                               default=True)
        self.use_passive_as = get_config_variable("IVRE_USE_PASSIVE_AS",
                                                  ["ivre", "use_passive_as"],
                                                  config,
                                                  default=True)
        self.use_passive_domain = get_config_variable(
            "IVRE_USE_PASSIVE_DOMAIN", ["ivre", "use_passive_domain"],
            config,
            default=True)
        self.use_scans = get_config_variable("IVRE_USE_SCANS",
                                             ["ivre", "use_scans"],
                                             config,
                                             default=True)
        self.use_scans_as = get_config_variable("IVRE_USE_SCANS_AS",
                                                ["ivre", "use_scans_as"],
                                                config,
                                                default=False)
        self.use_scans_domain = get_config_variable(
            "IVRE_USE_SCANS_DOMAIN", ["ivre", "use_scans_domain"],
            config,
            default=False)
        # Per-purpose database URLs; purposes without a configured (truthy)
        # URL are left out of `urls`.
        db_urls = {}
        for name, attr in DATABASES:
            url = get_config_variable(
                f"IVRE_DB_URL_{name.upper()}",
                ["ivre", f"db_url_{name}"],
                config,
                default=getattr(ivre_config, f"DB_{attr.upper()}", None),
            )
            if url:
                db_urls[attr] = url
        self.dbase = MetaDB(
            get_config_variable(
                "IVRE_DB_URL",
                ["ivre", "db_url"],
                config,
                default=getattr(ivre_config, "DB", None),
            ),
            urls=db_urls,
        )
        self.databases = {
            name: getattr(self.dbase, attr)
            for name, attr in DATABASES
        }
        self.ivre_instance_name = get_config_variable("CONNECTOR_NAME",
                                                      ["connector", "name"],
                                                      config,
                                                      default="IVRE")
        self.confidence = int(self.helper.connect_confidence_level)
        self.max_tlp = get_config_variable("IVRE_MAX_TLP", ["ivre", "max_tlp"],
                                           config)

    @property
    def ivre_entity(self):
        """This property is used to create an organization for the IVRE
        instance, and returns its id.

        """
        try:
            return self._ivre_entity
        except AttributeError:
            pass
        ivre_entity = self.helper.api.stix_domain_object.get_by_stix_id_or_name(
            name=self.ivre_instance_name)
        if not ivre_entity:
            self.helper.log_info(f"Creating entity {self.ivre_instance_name}")
            self._ivre_entity = self.helper.api.identity.create(
                type="Organization",
                name=self.ivre_instance_name,
                description=
                f"IVRE instance {self.ivre_instance_name}\nSee <https://ivre.rocks/>",
            )["id"]
        else:
            self._ivre_entity = ivre_entity["id"]
        return self._ivre_entity

    def add_asn(self, asnum, asname=None):
        """Create (or update) an Autonomous-System observable.

        When ``asname`` is empty or None, the name ``AS<asnum>`` is used.
        Returns the id of the observable.

        """
        observable_data = {
            "type": TYPE_AS.lower(),
            "number": asnum,
            "name": asname if asname else f"AS{asnum}",
        }
        created = self.helper.api.stix_cyber_observable.create(
            observableData=observable_data,
            update=True,
        )
        return created["id"]

    def add_addr(self, addr):
        """Create (or update) an IP address observable and return its ID.

        An address containing ":" is submitted as IPv6, any other as IPv4.
        """
        addr_type = TYPE_IPV6_ADDR if ":" in addr else TYPE_IPV4_ADDR
        created = self.helper.api.stix_cyber_observable.create(
            observableData={"type": addr_type.lower(), "value": addr},
            update=True,
        )
        return created["id"]

    def add_mac(self, addr):
        """Given a MAC address, creates an observable and returns its ID.

        The type is lower-cased like in the sibling ``add_*`` helpers, and
        the address is lower-cased because STIX 2.1 requires ``mac-addr``
        values to be lowercase.

        """
        return self.helper.api.stix_cyber_observable.create(
            observableData={
                "type": TYPE_MAC_ADDR.lower(),
                "value": addr.lower(),
            },
            update=True,
        )["id"]

    def add_domain(self, name):
        """Create (or update) a domain-name observable and return its ID.

        The name is lower-cased before submission.
        """
        payload = dict(type=TYPE_DOMAIN.lower(), value=name.lower())
        created = self.helper.api.stix_cyber_observable.create(
            observableData=payload,
            update=True,
        )
        return created["id"]

    def add_country(self, name, code):
        return self.helper.api.location.create(
            name=name,
            type="Country",
            country=name,
            custom_properties={
                "x_opencti_location_type": "Country",
                "x_opencti_aliases": [name, code],
            },
        )["id"]

    def add_city(self, city_name, country_name, country_code):
        country_id = self.add_country(country_name, country_code)
        city_id = self.helper.api.location.create(
            name=city_name,
            type="City",
            country=country_name,
            custom_properties={"x_opencti_location_type": "City"},
        )["id"]
        self.link_core(city_id, country_id, rel_type="located-at")
        return city_id

    def add_and_link_cert(self, cert, obs_id, firstseen, lastseen):
        """Given a parsed certificate (content of the "infos" field in the
        passive database, or the "ssl-cert" structured script output in the
        scans database), the observable id, the firstseen and lastseen values
        (as datetime.datetime instances), produce the object and the
        relationship between the observable and that object.

        When the certificate is not self-signed and both its issuer and
        subject match ``TOR_CERT_SUBJECT``, "Possible TOR ..." labels are
        added to the observable and to the certificate.

        """
        # TODO: serial_number and version are not extracted yet.
        data = {
            "type": TYPE_CERT.lower(),
            "hashes": {
                key: cert[key]
                for key in ("md5", "sha1", "sha256") if key in cert
            },
        }
        # `self_signed` is treated as optional below (cert.get); indexing it
        # unconditionally raised KeyError for certificates lacking it.
        if "self_signed" in cert:
            data["is_self_signed"] = cert["self_signed"]
        for fld in ("not_after", "not_before"):
            if fld in cert:
                data[f"validity_{fld}"] = cert[fld].strftime(
                    "%Y-%m-%dT%H:%M:%SZ")
        for fld in ("issuer", "subject"):
            if f"{fld}_text" in cert:
                # Turn "/"-separated text into ", "-separated text.
                data[fld] = cert[f"{fld}_text"].replace("/", ", ")
        pubkey = cert.get("pubkey") or {}
        for src, dst in (
            ("type", "subject_public_key_algorithm"),
            ("exponent", "subject_public_key_exponent"),
            ("modulus", "subject_public_key_modulus"),
        ):
            if src in pubkey:
                data[dst] = pubkey[src]
        cert_id = self.helper.api.stix_cyber_observable.create(
            observableData=data,
            update=True,
        )["id"]
        self.link_cyber(obs_id, cert_id, firstseen, lastseen)
        if not cert.get("self_signed") and all(
                TOR_CERT_SUBJECT.search(cert.get(f"{fld}_text", ""))
                for fld in ("issuer", "subject")):
            self.add_and_link_label(
                "Possible TOR Node",
                obs_id,
                color="#7e4ec2",
            )
            self.add_and_link_label(
                "Possible TOR Certificate",
                cert_id,
                color="#7e4ec2",
            )

    def add_and_link_label(self, value, obs_id, color="#ffffff"):
        label_id = self.helper.api.label.create(value=value, color=color)["id"]
        self.helper.api.stix_cyber_observable.add_label(id=obs_id,
                                                        label_id=label_id)

    def link_cyber(self,
                   from_id,
                   to_id,
                   firstseen,
                   lastseen,
                   rel_type="x_opencti_linked-to"):
        self.helper.api.stix_cyber_observable_relationship.create(
            fromId=from_id,
            toId=to_id,
            createdBy=self.ivre_entity,
            relationship_type=rel_type,
            update=True,
            confidence=self.confidence,
            start_time=firstseen.strftime("%Y-%m-%dT%H:%M:%SZ"),
            stop_time=lastseen.strftime("%Y-%m-%dT%H:%M:%SZ"),
        )

    def link_core(self, from_id, to_id, rel_type="related-to"):
        self.helper.api.stix_core_relationship.create(
            fromId=from_id,
            toId=to_id,
            createdBy=self.ivre_entity,
            relationship_type=rel_type,
            update=True,
            confidence=self.confidence,
        )

    def link_domain_parent(self, domain, parent, parent_id):
        """Link a domain to one of its parent, creating all the sub-domains
        needed. The caller **has** to make sure that `domain` is a subdomain
        of `parent`! Returns the ID of the observable created for `domain`.

        """
        subdomain = domain[:-(len(parent) + 1)]
        domain_id = cur_dom = self.add_domain(domain)
        while "." in subdomain:
            subdomain = subdomain.split(".", 1)[1]
            next_dom = self.add_domain(f"{subdomain}.{parent}")
            self.link_core(cur_dom, next_dom)
            cur_dom = next_dom
        self.link_core(cur_dom, parent_id)
        return domain_id

    def process_scans_record(self, record, observable):
        """Process a `record` from the scans (nmap) purpose; the query was
        made based on `observable`.

        Links the scanned address to its A/PTR hostnames (``resolves-to``),
        reusing or chaining to `observable` when it is a parent domain, and
        to the certificates reported by the ``ssl-cert`` script.

        """
        obs_id = observable["id"]
        obs_type = observable["entity_type"]
        firstseen = record.get("starttime", record.get("endtime"))
        lastseen = record.get("endtime", record.get("starttime"))
        addr_id = (obs_id if obs_type in TYPES_IP_ADDR else self.add_addr(
            record["addr"]))
        is_domain = obs_type == TYPE_DOMAIN
        obs_name = observable["value"].lower().rstrip(".") if is_domain else None

        for hname in record.get("hostnames", []):
            # Should we add all the hostnames?
            if hname["type"] not in {"A", "PTR"}:
                continue
            name = hname["name"].lower().rstrip(".")
            if is_domain and name == obs_name:
                name_id = obs_id
            elif is_domain and name.endswith(f".{obs_name}"):
                name_id = self.link_domain_parent(name, obs_name, obs_id)
            else:
                name_id = self.add_domain(name)
            self.link_cyber(name_id,
                            addr_id,
                            firstseen,
                            lastseen,
                            rel_type="resolves-to")

        for port in record.get("ports", []):
            for script in port.get("scripts", []):
                if script["id"] != "ssl-cert":
                    continue
                for cert in script.get("ssl-cert", []):
                    self.add_and_link_cert(cert, addr_id, firstseen, lastseen)

    def process_passive_record(self, record, observable):
        """Process a `record` from the passive purpose; the query was made
        based on `observable`.

        Depending on the record content and on the entity type of
        `observable`, this creates address / domain / MAC / certificate
        observables and links them (via `link_cyber` / `link_core`), or
        adds "Scanner" labels to the record's address. Records that cannot
        be used are skipped by an early `return`; nothing is returned.

        """
        obs_id = observable["id"]
        obs_type = observable["entity_type"]
        # Each bound falls back to the other one when missing (may be None).
        firstseen = record.get("firstseen", record.get("lastseen"))
        lastseen = record.get("lastseen", record.get("firstseen"))
        # Records with no addr fields are only handled for DNS
        if record.get("addr") is None:
            if "targetval" not in record:
                return
            if obs_type != TYPE_DOMAIN:
                return
            if record["recontype"] != "DNS_ANSWER":
                return
            # Both `value` and `targetval` are treated as domain names here
            # (presumably a name-to-name answer such as CNAME — confirm).
            obs_name = observable["value"].lower().rstrip(".")
            new_ids = {}
            for fld in ["value", "targetval"]:
                val = record[fld].lower().rstrip(".")
                if val == obs_name:
                    new_ids[fld] = obs_id
                elif val.endswith(f".{obs_name}"):
                    # Subdomain of the observable: create the chain of
                    # intermediate domains up to it.
                    new_ids[fld] = self.link_domain_parent(
                        val, obs_name, obs_id)
                else:
                    new_ids[fld] = self.add_domain(val)
            try:
                self.link_cyber(
                    new_ids["value"],
                    new_ids["targetval"],
                    firstseen,
                    lastseen,
                    rel_type="resolves-to",
                )
            except ValueError:
                # Workaround for a bug fixed in
                # e38bf150ab70b145bafcdea77351bf4199078401 (GH#1692)
                # Retries with link_cyber's default relationship type.
                self.link_cyber(
                    new_ids["value"],
                    new_ids["targetval"],
                    firstseen,
                    lastseen,
                )
            return
        addr_id = self.add_addr(record["addr"])
        # No `return` here: for an AS, the record is also processed by the
        # generic address handling at the end of this method.
        if obs_type == TYPE_AS:
            self.link_core(addr_id, obs_id, rel_type="belongs-to")
        if obs_type == TYPE_CERT:
            self.link_cyber(addr_id, obs_id, firstseen, lastseen)
            return
        if obs_type == TYPE_MAC_ADDR:
            self.link_cyber(addr_id,
                            obs_id,
                            firstseen,
                            lastseen,
                            rel_type="resolves-to")
            return
        if obs_type == TYPE_DOMAIN:
            obs_name = observable["value"].lower().rstrip(".")
            value = record["value"].lower().rstrip(".")
            if value == obs_name:
                self.link_cyber(obs_id,
                                addr_id,
                                firstseen,
                                lastseen,
                                rel_type="resolves-to")
            elif value.endswith(f".{obs_name}"):
                new_obs_id = self.link_domain_parent(value, obs_name, obs_id)
                self.link_cyber(new_obs_id,
                                addr_id,
                                firstseen,
                                lastseen,
                                rel_type="resolves-to")
            else:
                # The query was made on this domain, so the record's value
                # is expected to be the domain itself or a subdomain.
                self.helper.log_warning(
                    f"BUG! Unexpected record found for domain {obs_name} [{record!r}]"
                )
            return
        # obs_type is either an IP address or a
        # "generator" of IP addresses (e.g., an AS)
        if record["recontype"] == "DNS_ANSWER":
            value = record["value"].lower().rstrip(".")
            name_id = self.add_domain(value)
            self.link_cyber(name_id,
                            addr_id,
                            firstseen,
                            lastseen,
                            rel_type="resolves-to")
            return
        if record["recontype"] == "SSL_SERVER":
            # Only records carrying a parsed certificate are used.
            if record.get("source") != "cert":
                return
            if "infos" not in record:
                return
            self.add_and_link_cert(record["infos"], addr_id, firstseen,
                                   lastseen)
            return
        if record["recontype"] == "MAC_ADDRESS":
            self.link_cyber(
                addr_id,
                self.add_mac(record["value"]),
                firstseen,
                lastseen,
                rel_type="resolves-to",
            )
            return
        # Remaining records: label the address as a scanner, either from
        # the identified service product or from a honeypot hit type.
        if record.get("infos", {}).get("service_name") == "scanner":
            # if record["recontype"] == "UDP_HONEYPOT_HIT":  # spoofable
            self.add_and_link_label(
                f"Scanner {record['infos'].get('service_product', '(unknown)')}",
                addr_id,
                color="#ff8178",
            )
        elif record["recontype"] in {
                "HTTP_HONEYPOT_REQUEST",
                "DNS_HONEYPOT_QUERY",
                "TCP_HONEYPOT_HIT",
                "UDP_HONEYPOT_HIT",
        }:
            self.add_and_link_label("Scanner (unknown)",
                                    addr_id,
                                    color="#ff8178")

    def process_data_observable(self, observable):
        """Attach IVRE "data" (geolocation / AS) knowledge to an IP observable.

        Observables whose ``entity_type`` is not in ``TYPES_IP_ADDR`` are
        ignored, as are addresses for which ``infos_byip`` returns nothing.
        Otherwise the observable is linked with "located-at" to its city (or
        country), with "located-at" to its registered country when that
        differs from the geolocated country, and with "belongs-to" to its AS.
        """
        if observable["entity_type"] not in TYPES_IP_ADDR:
            return
        infos = self.dbase.data.infos_byip(observable["value"])
        if not infos:
            return

        # Geolocation: prefer the city when IVRE knows it.
        if "country_name" in infos:
            name, code = infos["country_name"], infos["country_code"]
            location_id = (self.add_city(infos["city"], name, code)
                           if "city" in infos else self.add_country(name, code))
            self.link_core(observable["id"], location_id, rel_type="located-at")

        # Registration country, only when it adds information.
        if "registered_country_name" in infos:
            registered_name = infos["registered_country_name"]
            if registered_name != infos.get("country_name"):
                registered_id = self.add_country(
                    registered_name, infos["registered_country_code"])
                self.link_core(observable["id"],
                               registered_id,
                               rel_type="located-at")

        # Autonomous system membership.
        if "as_num" in infos:
            self.link_core(observable["id"],
                           self.add_asn(infos["as_num"], infos.get("as_name")),
                           rel_type="belongs-to")

    def process_passive_observable(self, observable):
        """Look up an observable in IVRE's passive database and process hits.

        AS and domain lookups are gated by ``self.use_passive_as`` and
        ``self.use_passive_domain``. IP addresses, MAC addresses and
        X509-Certificates are always searched. Every matching record is
        handed to ``process_passive_record``. Observables of any other type
        are ignored.
        """
        obs_type = observable["entity_type"]
        passive = self.dbase.passive
        if obs_type == TYPE_AS:
            if not self.use_passive_as:
                return
            flts = [passive.searchasnum(observable["number"])]
        elif obs_type == TYPE_DOMAIN:
            if not self.use_passive_domain:
                return
            domain = observable["value"].lower().rstrip(".")
            # Search both directions so records pointing to the domain are
            # found as well as records of the domain itself.
            flts = [
                passive.searchdns(domain),
                passive.searchdns(domain, reverse=True),
            ]
        elif obs_type in TYPES_IP_ADDR:
            flts = [passive.searchhost(observable["value"])]
        elif obs_type == TYPE_MAC_ADDR:
            flts = [passive.searchmac(observable["value"])]
        elif obs_type == TYPE_CERT:
            flt_args = self._cert_filter_args(observable)
            if flt_args is None:
                self.helper.log_warning(
                    f"Cannot process X509-Certificate observable [{observable!r}]"
                )
                return
            flts = [passive.searchcert(**flt_args)]
        else:
            # Unsupported observable type: nothing to search. Without this
            # branch ``flts`` would be unbound below (UnboundLocalError).
            return
        for flt in flts:
            for rec in passive.get(flt):
                self.process_passive_record(rec, observable)

    def process_scans_observable(self, observable):
        """Look up an observable in IVRE's scans (nmap) database.

        AS and domain lookups are gated by ``self.use_scans_as`` and
        ``self.use_scans_domain``. IP addresses, MAC addresses and
        X509-Certificates are always searched. Every matching record is
        handed to ``process_scans_record``. Observables of any other type
        are ignored.
        """
        obs_type = observable["entity_type"]
        nmap = self.dbase.nmap
        if obs_type == TYPE_AS:
            if not self.use_scans_as:
                return
            flt = nmap.searchasnum(observable["number"])
        elif obs_type == TYPE_DOMAIN:
            if not self.use_scans_domain:
                return
            flt = nmap.searchdomain(observable["value"].lower().rstrip("."))
        elif obs_type in TYPES_IP_ADDR:
            flt = nmap.searchhost(observable["value"])
        elif obs_type == TYPE_MAC_ADDR:
            flt = nmap.searchmac(observable["value"])
        elif obs_type == TYPE_CERT:
            flt_args = self._cert_filter_args(observable)
            if flt_args is None:
                self.helper.log_warning(
                    f"Cannot process X509-Certificate observable [{observable!r}]"
                )
                return
            flt = nmap.searchcert(**flt_args)
        else:
            # Unsupported observable type: nothing to search. Without this
            # branch ``flt`` would be unbound below (UnboundLocalError).
            return
        for rec in nmap.get(flt):
            self.process_scans_record(rec, observable)

    @staticmethod
    def _cert_filter_args(observable):
        """Return ``searchcert`` keyword arguments for a certificate observable.

        An explicit entry of ``observable["hashes"]`` is preferred, choosing
        sha256, then sha1, then md5. Otherwise a hex string found in the
        ``observable_value`` or ``value`` field is used, with its length
        (32 / 40 / 64) selecting md5 / sha1 / sha256. Returns ``None`` when no
        usable hash is found.
        """
        hashes = observable.get("hashes") or ()
        for algo in ("sha256", "sha1", "md5"):
            for entry in hashes:
                if entry["algorithm"].lower() == algo:
                    return {algo: entry["hash"].lower()}
        for field in ("observable_value", "value"):
            value = observable.get(field)
            if not isinstance(value, str) or not HEX.search(value):
                continue
            algo = {32: "md5", 40: "sha1", 64: "sha256"}.get(len(value))
            if algo is not None:
                return {algo: value.lower()}
        return None

    def _process_message(self, data):
        """Run the enabled IVRE enrichments for the observable in ``data``.

        The observable is read from OpenCTI by ``data["entity_id"]``. Unknown
        ids are silently ignored. The first TLP marking, or TLP:WHITE if
        none, is checked against ``self.max_tlp`` before any enrichment runs.

        Raises:
            ValueError: if the observable's TLP exceeds the configured maximum.
        """
        observable = self.helper.api.stix_cyber_observable.read(
            id=data["entity_id"])
        if observable is None:
            return

        tlp = next(
            (marking["definition"]
             for marking in observable["objectMarking"]
             if marking["definition_type"] == "TLP"),
            "TLP:WHITE",
        )
        if not OpenCTIConnectorHelper.check_max_tlp(tlp, self.max_tlp):
            raise ValueError(
                "Do not send any data, TLP of the observable is greater than MAX TLP"
            )

        # Each enrichment source is independently switchable.
        if self.use_data:
            self.process_data_observable(observable)
        if self.use_passive:
            self.process_passive_observable(observable)
        if self.use_scans:
            self.process_scans_observable(observable)

    def start(self):
        """Start the connector by registering ``_process_message`` with ``self.helper.listen``."""
        handler = self._process_message
        self.helper.listen(handler)
예제 #4
0
class ThreatBusConnector(object):
    """Bidirectional bridge between OpenCTI and Threat Bus.

    STIX-2 Sightings and Indicators received from Threat Bus are written into
    OpenCTI. Indicator events read from the OpenCTI SSE stream are serialized
    and published to Threat Bus on the "stix2/indicator" topic.
    """

    # strftime format for sighting first/last-seen values. The literal "Z"
    # declares UTC, so timestamps must be converted to UTC before formatting.
    _SIGHTING_TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"

    def __init__(self):
        """Load configuration and build the OpenCTI and Threat Bus helpers.

        Settings are read from ``config.yml`` next to this file (if present)
        via ``get_config_variable``, which is also given the environment
        variable name of each setting.
        """
        config_file_path = os.path.dirname(
            os.path.abspath(__file__)) + "/config.yml"
        config = {}
        if os.path.isfile(config_file_path):
            # Context manager so the handle is closed after parsing. An empty
            # YAML file loads as None, hence the ``or {}`` fallback.
            with open(config_file_path, encoding="utf-8") as config_file:
                config = yaml.load(config_file, Loader=yaml.FullLoader) or {}

        # Connector configuration
        self.entity_name = get_config_variable("CONNECTOR_ENTITY_NAME",
                                               ["connector", "entity_name"],
                                               config)
        self.entity_desc = get_config_variable(
            "CONNECTOR_ENTITY_DESCRIPTION",
            ["connector", "entity_description"], config)
        self.forward_all_iocs = get_config_variable(
            "CONNECTOR_FORWARD_ALL_IOCS", ["connector", "forward_all_iocs"],
            config)
        # Lazily populated by _get_threatbus_entity().
        self.threatbus_entity = None

        # Custom configuration for Threat Bus & ZeroMQ plugin endpoint
        self.threatbus_zmq_host = get_config_variable(
            "THREATBUS_ZMQ_HOST", ["threatbus", "zmq_host"], config)
        self.threatbus_zmq_port = get_config_variable(
            "THREATBUS_ZMQ_PORT", ["threatbus", "zmq_port"], config)
        threatbus_snapshot = get_config_variable(
            "THREATBUS_SNAPSHOT",
            ["threatbus", "snapshot"],
            config,
            isNumber=True,
            default=0,
        )

        # Helper initialization
        self.opencti_helper = OpenCTIConnectorHelper(config)
        zmq_endpoint = f"{self.threatbus_zmq_host}:{self.threatbus_zmq_port}"
        self.threatbus_helper = ThreatBusConnectorHelper(
            zmq_endpoint,
            self._handle_threatbus_message,
            self.opencti_helper.log_info,
            self.opencti_helper.log_error,
            subscribe_topics=["stix2/sighting", "stix2/indicator"],
            publish_topic="stix2/indicator",
            snapshot=threatbus_snapshot,
        )

    def _get_threatbus_entity(self) -> dict:
        """
        Get the Threat Bus OpenCTI entity (as returned by the OpenCTI API).
        Creates a new Organization identity if it does not exist yet. The
        result is cached on the instance.
        """

        # Use cached:
        if self.threatbus_entity is not None:
            return self.threatbus_entity

        # Try and fetch existing:
        threatbus_entity = (
            self.opencti_helper.api.stix_domain_object.get_by_stix_id_or_name(
                name=self.entity_name))
        if threatbus_entity is not None and threatbus_entity.get("id", None):
            self.threatbus_entity = threatbus_entity
            return self.threatbus_entity

        # Create a new one:
        self.opencti_helper.log_info(
            f"Creating new OpenCTI Threat Bus entity '{self.entity_name}'")
        self.threatbus_entity = self.opencti_helper.api.identity.create(
            type="Organization",
            name=self.entity_name,
            description=self.entity_desc,
        )
        return self.threatbus_entity

    def _handle_threatbus_message(self, msg: str):
        """
        Processes a JSON message from Threat Bus (either a serialized STIX-2
        Sighting or STIX-2 Indicator) and forwards it to OpenCTI. Messages
        that fail to parse or have another type are logged and dropped.
        """
        try:
            stix_msg = parse(msg, allow_custom=True)
        except Exception as e:
            self.opencti_helper.log_error(
                f"Error parsing message from Threat Bus. Expected a STIX-2 Sighting or Indicator: {e}"
            )
            return
        if isinstance(stix_msg, Sighting):
            self._report_sighting(stix_msg)
        elif isinstance(stix_msg, Indicator):
            self._handle_indicator(stix_msg)
        else:
            self.opencti_helper.log_warning(
                f"Discarding Threat Bus message with unsupported type: {type(stix_msg)}. Hint: SnapshotRequests are not yet supported."
            )

    def _handle_indicator(self, indicator: Indicator):
        """
        Handles a STIX-2 Indicator update received via Threat Bus. Does nothing
        in case the indicator already exists and the new indicator does not add
        any new fields/values to the existing indicator. By doing so, this
        function effectively avoids double updates that otherwise would result
        in SSE events without a real change. Removal updates are ignored.
        @param indicator The STIX-2 Indicator received from Threat Bus
        """
        if not isinstance(indicator, Indicator):
            self.opencti_helper.log_error(
                f"Error ingesting indicator from Threat Bus. Expected a STIX-2 Indicator: {indicator}"
            )
            return
        if (ThreatBusSTIX2Constants.X_THREATBUS_UPDATE.value
                in indicator.object_properties()
                and indicator.x_threatbus_update == Operation.REMOVE.value):
            # OpenCTI does not support indicator removal via API calls (yet)
            return
        lookup_resp = self.opencti_helper.api.indicator.read(id=indicator.id)
        if not lookup_resp:
            # No indicator with that ID exists already.
            self._create_or_update_indicator(indicator)
            return
        lookup_resp["id"] = lookup_resp["standard_id"]
        lookup_indicator = Indicator(**lookup_resp, allow_custom=True)

        # We found an existing indicator. To avoid double updates in the SSE
        # stream we check if the indicator from Threat Bus adds anything new.
        # The id and custom ("x_") properties are not compared.
        for prop, new_value in indicator.items():
            if prop == "id" or prop.startswith("x_"):
                continue
            existing_value = lookup_indicator.get(prop, None)
            if existing_value is None or new_value != existing_value:
                self._create_or_update_indicator(indicator)
                return

    def _create_or_update_indicator(self, indicator: Indicator):
        """
        Creates or updates a STIX-2 Indicator in OpenCTI
        @param indicator The STIX-2 Indicator
        """
        ioc_dct = json.loads(indicator.serialize())
        ioc_dct["name"] = ioc_dct.get("name", indicator.id)  # default to the STIX id
        ioc_dct["stix_id"] = indicator.id
        del ioc_dct["id"]
        obs_type = ioc_dct.get("x_opencti_main_observable_type", "Unknown")
        ioc_dct["x_opencti_main_observable_type"] = obs_type
        resp = self.opencti_helper.api.indicator.create(**ioc_dct)
        self.opencti_helper.log_info(f"Created or added to indicator: {resp}")

    @classmethod
    def _format_sighting_time(cls, timestamp):
        """Return ``timestamp`` rendered as a UTC string, or None if unset.

        ``datetime.astimezone()`` without an argument converts to the host's
        local zone, which would mislabel local time with the "Z" suffix, so
        the target zone is given explicitly.
        """
        if not timestamp:
            return None
        from datetime import timezone

        return timestamp.astimezone(timezone.utc).strftime(
            cls._SIGHTING_TIME_FORMAT)

    def _report_sighting(self, sighting: Sighting):
        """
        Reports a STIX-2 Sighting to OpenCTI, modeled as a
        `OpenCTI.stix_sighting_relation` from the sighted object to the
        Threat Bus entity.
        @param sighting The STIX-2 Sighting object to report
        """
        if not isinstance(sighting, Sighting):
            self.opencti_helper.log_error(
                f"Error reporting sighting from Threat Bus. Expected a STIX-2 Sighting: {sighting}"
            )
            return
        entity_id = self._get_threatbus_entity().get("id", None)
        resp = self.opencti_helper.api.stix_sighting_relationship.create(
            fromId=sighting.sighting_of_ref,
            toId=entity_id,
            createdBy=entity_id,
            first_seen=self._format_sighting_time(sighting.get("first_seen")),
            last_seen=self._format_sighting_time(sighting.get("last_seen")),
            confidence=50,
            externalReferences=[sighting.sighting_of_ref],
            count=1,
        )
        self.opencti_helper.log_info(f"Created sighting {resp}")

    def _map_to_threatbus(self, data: dict,
                          opencti_action: str) -> Union[Indicator, None]:
        """
        Inspects the given OpenCTI data point and either returns a valid STIX-2
        Indicator or None.
        @param data A dict object with OpenCTI SSE data
        @param opencti_action A string indicating what happened to this item
            (either `create`, `update` or `delete`)
        @return a STIX-2 Indicator or None
        """
        opencti_id: str = data.get("x_opencti_id", None)
        if not opencti_id:
            self.opencti_helper.log_error(
                "Cannot process data without 'x_opencti_id' field")
            return

        event_id = data.get("id", None)
        update = data.get("x_data_update", {})
        added = update.get("add", {})
        added_ids = added.get("x_opencti_stix_ids", [])
        type_ = data.get("type", None)
        if type_ == "indicator" and len(
                added_ids) == 1 and added_ids[0] == event_id:
            # Discard the update if it was empty. An update is empty when the
            # only "changed" attribute is the stix_id and it changed to its own
            # already existing value. Example:
            # data ~ {'id': 'XXX', 'x_data_update': {'add': {'x_opencti_stix_ids': ['XXX']}}}
            return

        if opencti_action == "delete":
            # Copy so the caller's event payload is not mutated.
            indicator: dict = dict(data)
            indicator[ThreatBusSTIX2Constants.X_THREATBUS_UPDATE.
                      value] = Operation.REMOVE.value
        else:
            indicator: dict = self.opencti_helper.api.indicator.read(
                id=opencti_id)
            if not indicator:
                # we are only interested in indicators at this time
                return
            # overwrite custom OpenCTI ID
            indicator["id"] = indicator.get("standard_id")
            if opencti_action == "update":
                indicator[ThreatBusSTIX2Constants.X_THREATBUS_UPDATE.
                          value] = Operation.EDIT.value

        # only propagate indicators that are toggled for detection or the user
        # enabled forwarding of all indicators regardless of the toggle
        detection_enabled: bool = indicator.get("x_opencti_detection", False)
        if not detection_enabled and self.forward_all_iocs is not True:
            return

        return Indicator(**indicator, allow_custom=True)

    def _process_message(self, sse_msg: Event):
        """
        Invoked for every incoming SSE message from the OpenCTI endpoint.
        Errors are logged rather than propagated so the stream keeps running.
        @param sse_msg: the received SSE Event
        """
        try:
            data: dict = json.loads(sse_msg.data).get("data", None)
            if not data:
                return
            indicator = self._map_to_threatbus(data, sse_msg.event)
            if not indicator:
                return
            self.threatbus_helper.send(indicator.serialize())

        except Exception as e:
            self.opencti_helper.log_error(
                f"Error forwarding indicator to Threat Bus: {e}")

    def start(self):
        """Start the Threat Bus helper, then listen to the OpenCTI stream."""
        self.opencti_helper.log_info("Starting Threat Bus connector")

        # Fork a new Thread to communicate with Threat Bus
        self.threatbus_helper.start()
        atexit.register(self.threatbus_helper.stop)

        # Send the main loop into a busy loop for processing OpenCTI events
        self.opencti_helper.listen_stream(self._process_message)
예제 #5
0
class AbuseIPDBConnector:
    """Enrich IP address observables with AbuseIPDB "check" results.

    The observable's score is set from AbuseIPDB's confidence score.
    Whitelisted addresses get a "whitelist" label and an external reference.
    Each report's reporter country is sighted, and its categories are added
    as labels.
    """

    _API_URL = "https://api.abuseipdb.com/api/v2/check"
    # Seconds to wait for AbuseIPDB. requests has no default timeout and
    # would otherwise block the worker indefinitely on a stalled connection.
    _REQUEST_TIMEOUT_SEC = 30
    # Reference: https://www.abuseipdb.com/categories
    _CATEGORIES = {
        "1": "DNS Compromise",
        "2": "DNS Poisoning",
        "3": "Fraud Orders",
        "4": "DDOS Attack",
        "5": "FTP Brute-Force",
        "6": "Ping of Death",
        "7": "Phishing",
        "8": "Fraud VOIP",
        "9": "Open Proxy",
        "10": "Web Spam",
        "11": "Email Spam",
        "12": "Blog Spam",
        "13": "VPN IP",
        "14": "Port Scan",
        "15": "Hacking",
        "16": "SQL Injection",
        "17": "Spoofing",
        "18": "Brute Force",
        "19": "Bad Web Bot",
        "20": "Exploited Host",
        "21": "Web App Attack",
        "22": "SSH",
        "23": "IoT Targeted",
    }

    def __init__(self):
        """Load configuration, build the helper and ensure the whitelist label."""
        # Instantiate the connector helper from config
        config_file_path = os.path.dirname(os.path.abspath(__file__)) + "/config.yml"
        config = {}
        if os.path.isfile(config_file_path):
            # Close the handle after parsing; an empty file loads as None.
            with open(config_file_path, encoding="utf-8") as config_file:
                config = yaml.load(config_file, Loader=yaml.FullLoader) or {}
        self.helper = OpenCTIConnectorHelper(config)
        self.api_key = get_config_variable(
            "ABUSEIPDB_API_KEY", ["abuseipdb", "api_key"], config
        )
        self.max_tlp = get_config_variable(
            "ABUSEIPDB_MAX_TLP", ["abuseipdb", "max_tlp"], config
        )
        self.whitelist_label = self.helper.api.label.create(
            value="whitelist", color="#4caf50"
        )

    @staticmethod
    def extract_abuse_ipdb_category(category_number):
        """Return the AbuseIPDB category name for an id (int or str).

        Unknown ids map to "unknown category".
        """
        return AbuseIPDBConnector._CATEGORIES.get(
            str(category_number), "unknown category"
        )

    def _process_message(self, data):
        """Enrich the observable ``data["entity_id"]`` with AbuseIPDB data.

        Returns a short status message.

        Raises:
            ValueError: if the observable does not exist, or its TLP is above
                ``self.max_tlp``.
            requests.HTTPError: if AbuseIPDB answers with an error status.
        """
        entity_id = data["entity_id"]
        observable = self.helper.api.stix_cyber_observable.read(id=entity_id)
        if observable is None:
            raise ValueError(f"Observable {entity_id} not found")
        # Extract TLP: the first TLP marking wins, unmarked means TLP:WHITE.
        tlp = "TLP:WHITE"
        for marking_definition in observable["objectMarking"]:
            if marking_definition["definition_type"] == "TLP":
                tlp = marking_definition["definition"]
                break

        if not OpenCTIConnectorHelper.check_max_tlp(tlp, self.max_tlp):
            raise ValueError(
                "Do not send any data, TLP of the observable is greater than MAX TLP"
            )
        # Extract IP from entity data
        observable_id = observable["standard_id"]
        observable_value = observable["value"]
        headers = {
            "Accept": "application/json",
            "Content-Type": "application/x-www-form-urlencoded",
            "Key": "%s" % self.api_key,
        }
        params = {"maxAgeInDays": 365, "verbose": "True", "ipAddress": observable_value}
        response = requests.get(
            self._API_URL,
            headers=headers,
            params=params,
            timeout=self._REQUEST_TIMEOUT_SEC,
        )
        response.raise_for_status()
        result = response.json()["data"]
        self.helper.api.stix_cyber_observable.update_field(
            id=observable_id,
            key="x_opencti_score",
            value=str(result["abuseConfidenceScore"]),
        )
        if result["isWhitelisted"]:
            external_reference = self.helper.api.external_reference.create(
                source_name="AbuseIPDB (whitelist)",
                url="https://www.abuseipdb.com/check/" + observable_value,
                description="This IP address is from within our whitelist.",
            )
            self.helper.api.stix_cyber_observable.add_external_reference(
                id=observable_id, external_reference_id=external_reference["id"]
            )
            self.helper.api.stix_cyber_observable.add_label(
                id=observable_id, label_id=self.whitelist_label["id"]
            )
            return "IP found in AbuseIPDB WHITELIST."

        reports = result["reports"]
        if not reports:
            return "IP found in AbuseIPDB without reports."

        # Many reports share a reporter country and categories, so cache
        # country lookups and add each label only once.
        countries = {}
        labelled = set()
        for report in reports:
            country_code = report["reporterCountryCode"]
            if country_code not in countries:
                countries[country_code] = self.helper.api.location.read(
                    filters=[
                        {
                            "key": "x_opencti_aliases",
                            "values": [country_code],
                        }
                    ],
                    getAll=True,
                )
            country = countries[country_code]
            if country is None:
                self.helper.log_warning(
                    f"No country found with Alpha 2 code {country_code}"
                )
            else:
                self.helper.api.stix_sighting_relationship.create(
                    fromId=observable_id,
                    toId=country["id"],
                    count=1,
                    first_seen=report["reportedAt"],
                    last_seen=report["reportedAt"],
                )
            for category in report["categories"]:
                category_text = self.extract_abuse_ipdb_category(category)
                if category_text in labelled:
                    continue
                labelled.add(category_text)
                label = self.helper.api.label.create(value=category_text)
                self.helper.api.stix_cyber_observable.add_label(
                    id=observable_id, label_id=label["id"]
                )
        return "IP found in AbuseIPDB with reports, knowledge attached."

    # Start the main loop
    def start(self):
        """Register ``_process_message`` with ``self.helper.listen``."""
        self.helper.listen(self._process_message)