Exemple #1
0
    def convert_to_entity(self, opencti_response: List[Dict],
                          helper: OpenCTIConnectorHelper) -> List[Entity]:
        """Build one ``Entity`` per OpenCTI item that yields a search pattern.

        For every item, the values of the fields named in ``self.fields`` are
        collected (a plain string, or the members of a list). Values whose
        lower-cased form is listed in ``self.exclude_values`` are skipped, and
        each remaining value becomes a case-insensitive, word-bounded regex.

        Args:
            opencti_response: Items returned by OpenCTI; each is read with
                ``.get`` for ``"standard_id"`` and the configured fields.
            helper: Connector helper, used here only for logging.

        Returns:
            The entities built from items that produced at least one pattern.
        """
        entities = []
        for item in opencti_response:
            item_values = set()
            for relevant_field in self.fields:
                elem = item.get(relevant_field)
                if isinstance(elem, str):
                    candidates = [elem]
                elif isinstance(elem, list):
                    candidates = elem
                else:
                    continue
                # Keep only non-empty strings: an empty string would compile
                # to r"\b\b", which matches at every word boundary, and a
                # non-string member would fail on ``.lower()`` below.
                item_values.update(
                    candidate for candidate in candidates
                    if isinstance(candidate, str) and candidate)

            indicators = []
            for value in item_values:
                # Remove SDO names which are defined to be excluded in the entity config
                if value.lower() in self.exclude_values:
                    helper.log_debug(
                        f"Entity: Discarding value '{value}' due to explicit exclusion as defined in {self.exclude_values}"
                    )
                    continue

                # Escape the value so it is matched literally, as a whole word.
                pattern = f"\\b{re.escape(value)}\\b"
                try:
                    indicators.append(re.compile(pattern, re.IGNORECASE))
                except re.error as e:
                    helper.log_error(
                        f"Entity {self.name}: Unable to create regex from value '{pattern}' ({e})"
                    )

            # Items without a single usable pattern cannot be matched.
            if not indicators:
                continue

            entities.append(
                Entity(
                    name=self.name,
                    stix_class=self.stix_class,
                    stix_id=item.get("standard_id"),
                    values=item_values,
                    regex=indicators,
                    omit_match_in=self.omit_match_in,
                ))

        return entities
Exemple #2
0
class ReportImporter:
    def __init__(self) -> None:
        """Create the connector helper and load the parser configurations.

        ``../config.yml`` (relative to this file) is loaded when it exists;
        otherwise an empty configuration is handed to the helper. The
        observable and entity ``.ini`` files are then parsed.

        Raises:
            FileNotFoundError: If the observable or the entity config file
                does not exist; the message names the file that is missing.
        """
        # Instantiate the connector helper from config
        base_path = os.path.dirname(os.path.abspath(__file__))
        config_file_path = base_path + "/../config.yml"
        config = {}
        if os.path.isfile(config_file_path):
            # Context manager so the file handle is closed deterministically.
            with open(config_file_path) as config_file:
                config = yaml.load(config_file, Loader=yaml.FullLoader)

        self.helper = OpenCTIConnectorHelper(config)
        self.create_indicator = get_config_variable(
            "IMPORT_DOCUMENT_CREATE_INDICATOR",
            ["import_document", "create_indicator"],
            config,
        )

        # Load Entity and Observable configs
        observable_config_file = base_path + "/config/observable_config.ini"
        entity_config_file = base_path + "/config/entity_config.ini"

        # Check each file on its own so the error names the missing one.
        for required_file in (observable_config_file, entity_config_file):
            if not os.path.isfile(required_file):
                raise FileNotFoundError(f"{required_file} was not found")

        self.observable_config = self._parse_config(observable_config_file,
                                                    Observable)
        self.entity_config = self._parse_config(entity_config_file,
                                                EntityConfig)

    def _process_message(self, data: Dict) -> str:
        self.helper.log_info("Processing new message")
        file_name = self._download_import_file(data)
        entity_id = data.get("entity_id", None)
        bypass_validation = data.get("bypass_validation", False)
        entity = (self.helper.api.stix_domain_object.read(
            id=entity_id) if entity_id is not None else None)
        if self.helper.get_only_contextual() and entity is None:
            return "Connector is only contextual and entity is not defined. Nothing was imported"

        # Retrieve entity set from OpenCTI
        entity_indicators = self._collect_stix_objects(self.entity_config)

        # Parse report
        parser = ReportParser(self.helper, entity_indicators,
                              self.observable_config)
        parsed = parser.run_parser(file_name, data["file_mime"])
        os.remove(file_name)

        if not parsed:
            return "No information extracted from report"

        # Process parsing results
        self.helper.log_debug("Results: {}".format(parsed))
        observables, entities = self._process_parsing_results(parsed, entity)
        # Send results to OpenCTI
        observable_cnt = self._process_parsed_objects(entity, observables,
                                                      entities,
                                                      bypass_validation,
                                                      file_name)
        entity_cnt = len(entities)

        if self.helper.get_validate_before_import() and not bypass_validation:
            return "Generated bundle sent for validation"
        else:
            return (
                f"Sent {observable_cnt} observables, 1 report update and {entity_cnt} entity connections as stix "
                f"bundle for worker import ")

    def start(self) -> None:
        self.helper.listen(self._process_message)

    def _download_import_file(self, data: Dict) -> str:
        file_fetch = data["file_fetch"]
        file_uri = self.helper.opencti_url + file_fetch

        # Downloading and saving file to connector
        self.helper.log_info("Importing the file " + file_uri)
        file_name = os.path.basename(file_fetch)
        file_content = self.helper.api.fetch_opencti_file(file_uri, True)

        with open(file_name, "wb") as f:
            f.write(file_content)

        return file_name

    def _collect_stix_objects(
            self, entity_config_list: List[EntityConfig]) -> List[Entity]:
        base_func = self.helper.api
        entity_list = []
        for entity_config in entity_config_list:
            func_format = entity_config.stix_class
            try:
                custom_function = getattr(base_func, func_format)
                entries = custom_function.list(
                    getAll=True,
                    filters=entity_config.filter,
                    customAttributes=entity_config.custom_attributes,
                )
                entity_list += entity_config.convert_to_entity(
                    entries, self.helper)
            except AttributeError:
                e = "Selected parser format is not supported: {}".format(
                    func_format)
                raise NotImplementedError(e)

        return entity_list

    @staticmethod
    def _parse_config(config_file: str,
                      file_class: Callable) -> List[BaseModel]:
        """Turn every section of ``config_file`` into a ``file_class`` object.

        Each section's options are passed to ``file_class`` as keyword
        arguments, with the section title added under the ``name`` key.
        """
        parser = MyConfigParser()
        parser.read(config_file)

        def build(section_name, options):
            # The section title becomes the object's name.
            options["name"] = section_name
            return file_class(**options)

        return [
            build(section_name, options)
            for section_name, options in parser.as_dict().items()
        ]

    def _process_parsing_results(
            self, parsed: List[Dict],
            context_entity: Dict) -> (List[SimpleObservable], List[str]):
        observables = []
        entities = []
        if context_entity is not None:
            object_markings = [
                x["standard_id"]
                for x in context_entity.get("objectMarking", [])
            ]
            # external_references = [x['standard_id'] for x in report.get('externalReferences', [])]
            # labels = [x['standard_id'] for x in report.get('objectLabel', [])]
            author = context_entity.get("createdBy")
        else:
            object_markings = []
            author = None
        if author is not None:
            author = author.get("standard_id", None)
        for match in parsed:
            if match[RESULT_FORMAT_TYPE] == OBSERVABLE_CLASS:
                if match[RESULT_FORMAT_CATEGORY] == "Vulnerability.name":
                    entity = self.helper.api.vulnerability.read(
                        filters={
                            "key": "name",
                            "values": [match[RESULT_FORMAT_MATCH]]
                        })
                    if entity is None:
                        self.helper.log_info(
                            f"Vulnerability with name '{match[RESULT_FORMAT_MATCH]}' could not be "
                            f"found. Is the CVE Connector activated?")
                        continue

                    entities.append(entity["standard_id"])
                elif match[
                        RESULT_FORMAT_CATEGORY] == "Attack-Pattern.x_mitre_id":
                    entity = self.helper.api.attack_pattern.read(
                        filters={
                            "key": "x_mitre_id",
                            "values": [match[RESULT_FORMAT_MATCH]],
                        })
                    if entity is None:
                        self.helper.log_info(
                            f"AttackPattern with MITRE ID '{match[RESULT_FORMAT_MATCH]}' could not be "
                            f"found. Is the MITRE Connector activated?")
                        continue

                    entities.append(entity["standard_id"])
                else:
                    observable = SimpleObservable(
                        id=OpenCTIStix2Utils.generate_random_stix_id(
                            "x-opencti-simple-observable"),
                        key=match[RESULT_FORMAT_CATEGORY],
                        value=match[RESULT_FORMAT_MATCH],
                        x_opencti_create_indicator=self.create_indicator,
                        object_marking_refs=object_markings,
                        created_by_ref=author,
                        # labels=labels,
                        # external_references=external_references
                    )
                    observables.append(observable)

            elif match[RESULT_FORMAT_TYPE] == ENTITY_CLASS:
                entities.append(match[RESULT_FORMAT_MATCH])
            else:
                self.helper.log_info("Odd data received: {}".format(match))

        return observables, entities

    def _process_parsed_objects(
        self,
        entity: Dict,
        observables: List,
        entities: List,
        bypass_validation: bool,
        file_name: str,
    ) -> int:

        if len(observables) == 0 and len(entities) == 0:
            return 0

        if entity is not None and entity["entity_type"] == "Report":
            report = Report(
                id=entity["standard_id"],
                name=entity["name"],
                description=entity["description"],
                published=self.helper.api.stix2.format_date(entity["created"]),
                report_types=entity["report_types"],
                object_refs=observables + entities,
                allow_custom=True,
            )
            observables.append(report)
        elif entity is not None:
            # TODO, relate all object to the entity
            entity_stix_bundle = self.helper.api.stix2.export_entity(
                entity["entity_type"], entity["id"])
            observables = observables + entity_stix_bundle["objects"]
        else:
            timestamp = int(time.time())
            now = datetime.utcfromtimestamp(timestamp)
            report = Report(
                name=file_name,
                description="Automatic import",
                published=now,
                report_types=["threat-report"],
                object_refs=observables + entities,
                allow_custom=True,
            )
            observables.append(report)
        bundles_sent = []
        if len(observables) > 0:
            bundle = Bundle(objects=observables, allow_custom=True).serialize()
            bundles_sent = self.helper.send_stix2_bundle(
                bundle=bundle,
                update=True,
                bypass_validation=bypass_validation,
                file_name=file_name + ".json",
                entity_id=entity["id"] if entity is not None else None,
            )

        # len() - 1 because the report update increases the count by one
        return len(bundles_sent) - 1
Exemple #3
0
class VirusTotalConnector:
    """VirusTotal connector."""

    _CONNECTOR_RUN_INTERVAL_SEC = 60 * 60
    _API_URL = "https://www.virustotal.com/api/v3"

    def __init__(self):
        """Create the connector helper, the VirusTotal client and the cache.

        ``config.yml`` is looked up in the parent of this file's directory;
        when it is absent an empty configuration is handed to the helper.
        """
        # Instantiate the connector helper from config
        config_file_path = Path(
            __file__).parent.parent.resolve() / "config.yml"

        config = {}
        if config_file_path.is_file():
            # Context manager so the file handle is closed deterministically.
            with open(config_file_path, encoding="utf-8") as config_file:
                config = yaml.load(config_file, Loader=yaml.FullLoader)

        self.helper = OpenCTIConnectorHelper(config)
        token = get_config_variable("VIRUSTOTAL_TOKEN",
                                    ["virustotal", "token"], config)
        self.max_tlp = get_config_variable("VIRUSTOTAL_MAX_TLP",
                                           ["virustotal", "max_tlp"], config)

        self.client = VirusTotalClient(self._API_URL, token)

        # Cache to store YARA rulesets.
        self.yara_cache = {}

    def _create_yara_indicator(
            self,
            yara: dict,
            valid_from: Optional[int] = None) -> Optional[Indicator]:
        """Create an OpenCTI indicator holding the YARA rule named in ``yara``.

        The ruleset is taken from ``self.yara_cache`` when present, otherwise
        requested from the VirusTotal client and cached. The rule whose name
        equals ``yara["rule_name"]`` is rebuilt and stored as the pattern.

        Args:
            yara: One crowdsourced YARA result from VirusTotal.
            valid_from: Unix timestamp for the indicator's ``valid_from``;
                ``datetime.datetime.min`` is used when it is ``None``.

        Returns:
            The created indicator, or ``None`` if no rule in the ruleset has
            the requested name.
        """
        if valid_from is None:
            valid_from_date = datetime.datetime.min
        else:
            valid_from_date = datetime.datetime.utcfromtimestamp(valid_from)

        ruleset_id = yara.get("ruleset_id", "No ruleset id provided")
        self.helper.log_info(f"[VirusTotal] Retrieving ruleset {ruleset_id}")

        # Lookup in the cache for the ruleset id, otherwise, request VirusTotal API.
        if ruleset_id not in self.yara_cache:
            self.helper.log_debug(
                f"Retrieving YARA ruleset {ruleset_id} from API.")
            ruleset = self.client.get_yara_ruleset(ruleset_id)
            self.yara_cache[ruleset_id] = ruleset
        else:
            self.helper.log_debug(
                f"Retrieving YARA ruleset {ruleset_id} from cache.")
            ruleset = self.yara_cache[ruleset_id]

        # Parse the rules to find the correct one.
        parsed_rules = plyara.Plyara().parse_string(
            ruleset["data"]["attributes"]["rules"])
        rule_name = yara.get("rule_name", "No ruleset name provided")
        matching_rules = [
            candidate for candidate in parsed_rules
            if candidate["rule_name"] == rule_name
        ]
        if not matching_rules:
            self.helper.log_warning(f"No YARA rule for rule name {rule_name}")
            return None

        # Provenance details are stored as JSON in the description.
        details = {
            "description": yara.get("description", "No description provided"),
            "author": yara.get("author", "No author provided"),
            "source": yara.get("source", "No source provided"),
            "ruleset_id": ruleset_id,
            "ruleset_name": yara.get("ruleset_name",
                                     "No ruleset name provided"),
        }
        return self.helper.api.indicator.create(
            name=yara.get("rule_name", "No rulename provided"),
            description=json.dumps(details),
            pattern=plyara.utils.rebuild_yara_rule(matching_rules[0]),
            pattern_type="yara",
            valid_from=self.helper.api.stix2.format_date(valid_from_date),
            x_opencti_main_observable_type="StixFile",
        )

    def _process_file(self, observable):
        json_data = self.client.get_file_info(observable["observable_value"])
        if "error" in json_data:
            if json_data["error"]["message"] == "Quota exceeded":
                self.helper.log_info("Quota reached, waiting 1 hour.")
                sleep(self._CONNECTOR_RUN_INTERVAL_SEC)
            elif "not found" in json_data["error"]["message"]:
                self.helper.log_info("File not found on VirusTotal.")
                return "File not found on VirusTotal."
            else:
                raise ValueError(json_data["error"]["message"])
        if "data" in json_data:
            data = json_data["data"]
            attributes = data["attributes"]
            # Update the current observable
            final_observable = self.helper.api.stix_cyber_observable.update_field(
                id=observable["id"],
                input={
                    "key": "hashes.MD5",
                    "value": attributes["md5"]
                },
            )
            final_observable = self.helper.api.stix_cyber_observable.update_field(
                id=final_observable["id"],
                input={
                    "key": "hashes.SHA-1",
                    "value": attributes["sha1"]
                },
            )
            final_observable = self.helper.api.stix_cyber_observable.update_field(
                id=final_observable["id"],
                input={
                    "key": "hashes.SHA-256",
                    "value": attributes["sha256"]
                },
            )
            if observable["entity_type"] == "StixFile":
                self.helper.api.stix_cyber_observable.update_field(
                    id=final_observable["id"],
                    input={
                        "key": "size",
                        "value": str(attributes["size"])
                    },
                )
                if observable["name"] is None and len(attributes["names"]) > 0:
                    self.helper.api.stix_cyber_observable.update_field(
                        id=final_observable["id"],
                        input={
                            "key": "name",
                            "value": attributes["names"][0]
                        },
                    )
                    del attributes["names"][0]

            if len(attributes["names"]) > 0:
                self.helper.api.stix_cyber_observable.update_field(
                    id=final_observable["id"],
                    input={
                        "key": "x_opencti_additional_names",
                        "value": attributes["names"],
                    },
                )

            # Create external reference
            external_reference = self.helper.api.external_reference.create(
                source_name="VirusTotal",
                url="https://www.virustotal.com/gui/file/" +
                attributes["sha256"],
                description=attributes["magic"],
            )

            # Create tags
            for tag in attributes["tags"]:
                tag_vt = self.helper.api.label.create(value=tag,
                                                      color="#0059f7")
                self.helper.api.stix_cyber_observable.add_label(
                    id=final_observable["id"], label_id=tag_vt["id"])

            self.helper.api.stix_cyber_observable.add_external_reference(
                id=final_observable["id"],
                external_reference_id=external_reference["id"],
            )

            if "crowdsourced_yara_results" in attributes:
                self.helper.log_info(
                    "[VirusTotal] adding yara results to file.")

                # Add YARA rules (only if a rule is given).
                yaras = list(
                    filter(
                        None,
                        [
                            self._create_yara_indicator(
                                yara, attributes.get("creation_date", None))
                            for yara in attributes["crowdsourced_yara_results"]
                        ],
                    ))

                self.helper.log_debug(
                    f"[VirusTotal] Indicators created: {yaras}")

                # Create the relationships (`related-to`) between the yaras and the file.
                for yara in yaras:
                    self.helper.api.stix_core_relationship.create(
                        fromId=final_observable["id"],
                        toId=yara["id"],
                        relationship_type="related-to",
                    )

            return "File found on VirusTotal, knowledge attached."

    def _process_message(self, data):
        """Enrich the observable named in ``data`` if its TLP allows it.

        The TLP is taken from the observable's TLP marking definitions (the
        last one wins) and defaults to ``TLP:WHITE`` when there is none.

        Raises:
            ValueError: If the TLP check against ``self.max_tlp`` fails.
        """
        observable = self.helper.api.stix_cyber_observable.read(
            id=data["entity_id"])

        # Extract TLP
        tlp_definitions = [
            marking["definition"]
            for marking in observable["objectMarking"]
            if marking["definition_type"] == "TLP"
        ]
        tlp = tlp_definitions[-1] if tlp_definitions else "TLP:WHITE"

        if OpenCTIConnectorHelper.check_max_tlp(tlp, self.max_tlp):
            return self._process_file(observable)

        raise ValueError(
            "Do not send any data, TLP of the observable is greater than MAX TLP"
        )

    def start(self):
        """Start the main loop."""
        self.helper.listen(self._process_message)
class ImportFileStix:
    def __init__(self):
        """Create the connector helper from ``config.yml`` next to this file.

        An empty configuration is handed to the helper when the file is
        absent.
        """
        # Instantiate the connector helper from config
        config_file_path = os.path.dirname(os.path.abspath(__file__)) + "/config.yml"
        config = {}
        if os.path.isfile(config_file_path):
            # Context manager so the file handle is closed deterministically.
            with open(config_file_path) as config_file:
                config = yaml.load(config_file, Loader=yaml.FullLoader)
        self.helper = OpenCTIConnectorHelper(config)

    def _process_message(self, data: Dict) -> str:
        file_fetch = data["file_fetch"]
        file_uri = self.helper.opencti_url + file_fetch
        self.helper.log_info(f"Importing the file {file_uri}")

        file_content = self.helper.api.fetch_opencti_file(file_uri)
        if data["file_mime"] == "text/xml":
            self.helper.log_debug("Stix1 file. Attempting conversion")
            initialize_options()
            file_content = elevate(file_content)

        entity_id = data.get("entity_id", None)
        if entity_id:
            self.helper.log_debug("Contextual import.")

            bundle = parse(file_content)["objects"]

            if self._contains_report(bundle):
                self.helper.log_debug("Bundle contains report.")
            else:
                self.helper.log_debug("No Report in Stix file. Updating current report")
                bundle = self._update_report(bundle, entity_id)

            file_content = Bundle(objects=bundle).serialize()

        bundles_sent = self.helper.send_stix2_bundle(file_content)
        return "Sent " + str(len(bundles_sent)) + " stix bundle(s) for worker import"

    # Start the main loop
    def start(self) -> None:
        self.helper.listen(self._process_message)

    @staticmethod
    def _contains_report(bundle: List) -> bool:
        for elem in bundle:
            if type(elem) == Report:
                return True
        return False

    def _update_report(self, bundle: List, entity_id: int) -> List:
        report = self.helper.api.report.read(id=entity_id)
        # The entity_id can be any SDO
        if report:
            report = Report(
                id=report["standard_id"],
                name=report["name"],
                description=report["description"],
                published=self.helper.api.stix2.format_date(report["created"]),
                report_types=report["report_types"],
                object_refs=bundle,
            )
            bundle.append(report)
        return bundle
Exemple #5
0
class ReportImporter:
    def __init__(self) -> None:
        """Create the connector helper and load the parser configurations.

        ``../config.yml`` (relative to this file) is loaded when it exists;
        otherwise an empty configuration is handed to the helper. The
        observable and entity ``.ini`` files are then parsed.

        Raises:
            FileNotFoundError: If the observable or the entity config file
                does not exist; the message names the file that is missing.
        """
        # Instantiate the connector helper from config
        base_path = os.path.dirname(os.path.abspath(__file__))
        config_file_path = base_path + "/../config.yml"
        config = {}
        if os.path.isfile(config_file_path):
            # Context manager so the file handle is closed deterministically.
            with open(config_file_path) as config_file:
                config = yaml.load(config_file, Loader=yaml.FullLoader)
        self.helper = OpenCTIConnectorHelper(config)
        self.create_indicator = get_config_variable(
            "IMPORT_REPORT_CREATE_INDICATOR",
            ["import_report", "create_indicator"],
            config,
        )

        # Load Entity and Observable configs
        observable_config_file = base_path + "/config/observable_config.ini"
        entity_config_file = base_path + "/config/entity_config.ini"

        # Check each file on its own so the error names the missing one.
        for required_file in (observable_config_file, entity_config_file):
            if not os.path.isfile(required_file):
                raise FileNotFoundError(f"{required_file} was not found")

        self.observable_config = self._parse_config(
            observable_config_file, Observable
        )
        self.entity_config = self._parse_config(entity_config_file, EntityConfig)

    def _process_message(self, data: Dict) -> str:
        file_name = self._download_import_file(data)
        entity_id = data.get("entity_id", None)
        if self._check_context(entity_id):
            raise ValueError(
                "No context defined, connector is get_only_contextual true"
            )

        # Retrieve entity set from OpenCTI
        entity_indicators = self._collect_stix_objects(self.entity_config)

        # Parse peport
        parser = ReportParser(self.helper, entity_indicators, self.observable_config)
        parsed = parser.run_parser(file_name, data["file_mime"])
        os.remove(file_name)

        if not parsed:
            return "No information extracted from report"

        # Process parsing results
        self.helper.log_debug("Results: {}".format(parsed))
        observables, entities = self._process_parsing_results(parsed)
        report = self.helper.api.report.read(id=entity_id)
        # Send results to OpenCTI
        observable_cnt = self._process_observables(report, observables)
        entity_cnt = self._process_entities(report, entities)

        return f"Sent {observable_cnt} stix bundle(s) and {entity_cnt} entity connections for worker import"

    def start(self) -> None:
        self.helper.listen(self._process_message)

    def _download_import_file(self, data: Dict) -> str:
        file_fetch = data["file_fetch"]
        file_uri = self.helper.opencti_url + file_fetch

        # Downloading and saving file to connector
        self.helper.log_info("Importing the file " + file_uri)
        file_name = os.path.basename(file_fetch)
        file_content = self.helper.api.fetch_opencti_file(file_uri, True)

        with open(file_name, "wb") as f:
            f.write(file_content)

        return file_name

    def _check_context(self, entity_id: str) -> bool:
        is_context = entity_id and len(entity_id) > 0
        return self.helper.get_only_contextual() and not is_context

    def _collect_stix_objects(
        self, entity_config_list: List[EntityConfig]
    ) -> List[Entity]:
        base_func = self.helper.api
        entity_list = []
        for entity_config in entity_config_list:
            func_format = entity_config.stix_class
            try:
                custom_function = getattr(base_func, func_format)
                entries = custom_function.list(
                    getAll=True, filters=entity_config.filter
                )
                entity_list += entity_config.convert_to_entity(entries)
            except AttributeError:
                e = "Selected parser format is not supported: {}".format(func_format)
                raise NotImplementedError(e)

        return entity_list

    @staticmethod
    def _parse_config(config_file: str, file_class: Callable) -> List[BaseModel]:
        """Turn every section of ``config_file`` into a ``file_class`` object.

        Each section's options are passed to ``file_class`` as keyword
        arguments, with the section title added under the ``name`` key.
        """
        parser = MyConfigParser()
        parser.read(config_file)

        def build(section_name, options):
            # The section title becomes the object's name.
            options["name"] = section_name
            return file_class(**options)

        return [
            build(section_name, options)
            for section_name, options in parser.as_dict().items()
        ]

    def _process_parsing_results(
        self, parsed: List[Dict]
    ) -> (List[SimpleObservable], List[str]):
        observables = []
        entities = []
        for match in parsed:
            if match[RESULT_FORMAT_TYPE] == OBSERVABLE_CLASS:
                # Hardcoded exceptions since SimpleObservable doesn't support those types yet
                if match[RESULT_FORMAT_CATEGORY] == "Vulnerability.name":
                    observable = self.helper.api.vulnerability.read(
                        filters={"key": "name", "values": [match[RESULT_FORMAT_MATCH]]}
                    )
                    if observable is None:
                        self.helper.log_info(
                            f"Vulnerability with name '{match[RESULT_FORMAT_MATCH]}' could not be "
                            f"found. Is the CVE Connector activated?"
                        )
                        continue
                elif match[RESULT_FORMAT_CATEGORY] == "Attack-Pattern.x_mitre_id":
                    observable = self.helper.api.attack_pattern.read(
                        filters={
                            "key": "x_mitre_id",
                            "values": [match[RESULT_FORMAT_MATCH]],
                        }
                    )
                    if observable is None:
                        self.helper.log_info(
                            f"AttackPattern with MITRE ID '{match[RESULT_FORMAT_MATCH]}' could not be "
                            f"found. Is the MITRE Connector activated?"
                        )
                        continue

                else:
                    observable = self.helper.api.stix_cyber_observable.create(
                        simple_observable_id=OpenCTIStix2Utils.generate_random_stix_id(
                            "x-opencti-simple-observable"
                        ),
                        simple_observable_key=match[RESULT_FORMAT_CATEGORY],
                        simple_observable_value=match[RESULT_FORMAT_MATCH],
                        createIndicator=self.create_indicator,
                    )

                observables.append(observable["id"])

            elif match[RESULT_FORMAT_TYPE] == ENTITY_CLASS:
                entities.append(match[RESULT_FORMAT_MATCH])
            else:
                self.helper.log_info("Odd data received: {}".format(match))

        return observables, entities

    def _process_observables(self, report: Dict, observables: List) -> int:
        if report is None:
            self.helper.log_error(
                "No report found! This is a purely contextual connector and this should not happen"
            )

        if len(observables) == 0:
            return 0

        report = self.helper.api.report.create(
            id=report["standard_id"],
            name=report["name"],
            description=report["description"],
            published=self.helper.api.stix2.format_date(report["created"]),
            report_types=report["report_types"],
            update=True,
        )

        for observable in observables:
            self.helper.api.report.add_stix_object_or_stix_relationship(
                id=report["id"], stixObjectOrStixRelationshipId=observable
            )

        return len(observables)

    def _process_entities(self, report: Dict, entities: List) -> int:
        if report:
            for stix_object in entities:
                self.helper.api.report.add_stix_object_or_stix_relationship(
                    id=report["id"], stixObjectOrStixRelationshipId=stix_object
                )

        return len(entities)
Exemple #6
0
class RiskIQConnector:
    """RiskIQ Connector main class.

    Runs an endless loop that, whenever the configured interval has elapsed
    since the last recorded run, fetches articles through the RiskIQ client
    and hands each of them to an ``ArticleImporter``. Progress is kept in the
    connector state exposed by the OpenCTI helper.
    """

    _DEFAULT_AUTHOR = "RiskIQ"

    # Default run interval: seconds slept between two passes of the main loop.
    _CONNECTOR_RUN_INTERVAL_SEC = 60
    # State key under which the unix timestamp of the latest run is stored.
    _STATE_LATEST_RUN_TIMESTAMP = "latest_run_timestamp"

    def __init__(self):
        """Load the configuration and build the helper, author and client."""
        # Instantiate the connector helper from config
        config_file_path = Path(__file__).parent.parent.resolve() / "config.yml"

        config = {}
        if config_file_path.is_file():
            # The context manager closes the file handle once parsing is done.
            with open(config_file_path, encoding="utf8") as config_file:
                config = yaml.load(config_file, Loader=yaml.FullLoader)

        self.helper = OpenCTIConnectorHelper(config)

        self.base_url = get_config_variable(
            "RISKIQ_BASE_URL", ["riskiq", "base_url"], config
        )
        self.interval_sec = get_config_variable(
            "RISKIQ_INTERVAL_SEC", ["riskiq", "interval_sec"], config
        )
        user = get_config_variable("RISKIQ_USER", ["riskiq", "user"], config)
        password = get_config_variable(
            "RISKIQ_PASSWORD", ["riskiq", "password"], config
        )
        # Create the author for all reports.
        self.author = Identity(
            name=self._DEFAULT_AUTHOR,
            identity_class="organization",
            description=" RiskIQ is a cyber security company based in San Francisco, California."
            " It provides cloud - based software as a service(SaaS) for organizations"
            " to detect phishing, fraud, malware, and other online security threats.",
            confidence=self.helper.connect_confidence_level,
        )
        # Initialization of the client
        self.client = RiskIQClient(self.base_url, user, password)

    @staticmethod
    def _current_unix_timestamp() -> int:
        """Return the current time as a whole number of seconds since the epoch."""
        return int(time.time())

    def _get_interval(self) -> int:
        """Return the configured polling interval, in seconds, as an ``int``."""
        return int(self.interval_sec)

    @staticmethod
    def _get_state_value(
        state: Optional[Mapping[str, Any]], key: str, default: Optional[Any] = None
    ) -> Any:
        """Return ``state[key]``, or ``default`` when the state or key is missing."""
        if state is not None:
            return state.get(key, default)
        return default

    def _initiate_work(self, timestamp: int) -> str:
        """Register a new work item named after the run time and return its id.

        Args:
            timestamp: Unix timestamp of the run; rendered in UTC in the
                work's friendly name.
        """
        # ``utcfromtimestamp`` is deprecated since Python 3.12; an aware UTC
        # datetime formats to exactly the same string.
        now = datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc)
        friendly_name = "RiskIQ run @ " + now.strftime("%Y-%m-%d %H:%M:%S")
        work_id = self.helper.api.work.initiate_work(
            self.helper.connect_id, friendly_name
        )
        self.helper.log_info(f"[RiskIQ] workid {work_id} initiated")
        return work_id

    def _is_scheduled(self, last_run: Optional[int], current_time: int) -> bool:
        """Tell whether a run is due.

        A run is due when there is no previous run, or when at least the
        configured interval has elapsed since ``last_run``.
        """
        if last_run is None:
            self.helper.log_info("RiskIQ connector clean run")
            return True

        time_diff = current_time - last_run
        return time_diff >= self._get_interval()

    def _get_next_interval(
        self, run_interval: int, timestamp: int, last_run: int
    ) -> int:
        """Get the delay for the next interval.

        Returns the smaller of ``run_interval`` and the time left until the
        configured interval has elapsed since ``last_run``, never below 0.
        """
        next_run = self._get_interval() - (timestamp - last_run)
        # ``time.sleep`` rejects negative values, so an overdue run yields 0.
        return max(0, min(run_interval, next_run))

    def _load_state(self) -> dict[str, Any]:
        """Return the connector state, or an empty dict when none is stored."""
        current_state = self.helper.get_state()
        if not current_state:
            return {}
        return current_state

    @classmethod
    def _sleep(cls, delay_sec: Optional[int] = None) -> None:
        """Sleep for ``delay_sec`` seconds, or the default run interval if ``None``."""
        sleep_delay = (
            delay_sec if delay_sec is not None else cls._CONNECTOR_RUN_INTERVAL_SEC
        )
        time.sleep(sleep_delay)

    def run(self):
        """Run RiskIQ connector.

        Loops forever. On each pass the state is loaded and, if a run is due,
        a work item is initiated, articles are fetched from the date of the
        last imported article (when known) and each article is imported. The
        state is stored after every imported article. The loop then sleeps
        before the next pass.

        The process exits through ``sys.exit(0)`` on ``KeyboardInterrupt``,
        ``SystemExit`` or any other exception raised during a pass.
        """
        self.helper.log_info("Starting RiskIQ connector...")

        while True:
            self.helper.log_info("Running RiskIQ connector...")
            run_interval = self._CONNECTOR_RUN_INTERVAL_SEC

            try:
                self.helper.log_info(f"Connector interval sec: {run_interval}")
                timestamp = self._current_unix_timestamp()
                current_state = self._load_state()
                self.helper.log_info(f"[RiskIQ] loaded state: {current_state}")

                last_run = self._get_state_value(
                    current_state, self._STATE_LATEST_RUN_TIMESTAMP
                )
                if self._is_scheduled(last_run, timestamp):
                    work_id = self._initiate_work(timestamp)
                    new_state = current_state.copy()
                    last_article = self._get_state_value(
                        current_state, ArticleImporter._LATEST_ARTICLE_TIMESTAMP
                    )

                    self.helper.log_info(f"[RiskIQ] last run: {last_run}")
                    # Decide on the article timestamp itself: a state holding
                    # a run timestamp but no article timestamp must not be
                    # passed to the conversion helper.
                    last_article_date = (
                        timestamp_to_datetime(last_article).date()
                        if last_article is not None
                        else None
                    )
                    self.helper.log_debug(
                        f"[RiskIQ] retrieving data from {last_article_date}"
                    )
                    response = self.client.get_articles(last_article_date)

                    if self.client.is_correct(response):
                        # NOTE(review): the state is only stored when at least
                        # one article is returned; an empty response leaves
                        # the last-run timestamp untouched - confirm intended.
                        for article in response["articles"]:
                            importer = ArticleImporter(
                                self.helper, article, self.author
                            )
                            importer_state = importer.run(work_id, current_state)
                            if importer_state:
                                self.helper.log_info(
                                    f"[RiskIQ] Updating state {importer_state}"
                                )
                                new_state.update(importer_state)

                            # Set the new state
                            new_state[self._STATE_LATEST_RUN_TIMESTAMP] = (
                                self._current_unix_timestamp()
                            )
                            self.helper.log_info(
                                f"[RiskIQ] Storing new state: {new_state}"
                            )
                            self.helper.set_state(new_state)
                    else:
                        self.helper.log_warning("[RiskIQ] failed to retrieve articles")
                        self.helper.log_info(
                            f"[RiskIQ] next run in {run_interval} seconds"
                        )
                else:
                    run_interval = self._get_next_interval(
                        run_interval, timestamp, last_run
                    )
                    self.helper.log_info(
                        f"[RiskIQ] Connector will not run, next run in {run_interval} seconds"
                    )

                self._sleep(delay_sec=run_interval)
            except (KeyboardInterrupt, SystemExit):
                self.helper.log_info("RiskIQ connector stop")
                sys.exit(0)
            except Exception as e:
                # NOTE(review): exits with status 0 even on unexpected errors;
                # confirm whether a non-zero exit code is wanted here.
                self.helper.log_error(str(e))
                sys.exit(0)