Esempio n. 1
0
class Consumer(threading.Thread):
    """Worker thread that consumes one RabbitMQ queue and imports STIX bundles.

    Each message is a JSON document carrying a base64-encoded STIX bundle. It
    is imported on a short-lived helper thread while this thread keeps the
    pika connection serviced; acknowledgements and stop requests are handed
    back to the connection with ``add_callback_threadsafe``.
    """

    def __init__(self, connector, opencti_url, opencti_token):
        """Create the OpenCTI client and open a channel for ``connector``.

        :param connector: connector description; ``connector['config']['push']``
            is the queue to consume and ``connector['config']['uri']`` the
            AMQP URI.
        :param opencti_url: base URL of the OpenCTI platform.
        :param opencti_token: default API token, restored after every import.
        """
        threading.Thread.__init__(self)
        self.opencti_url = opencti_url
        self.opencti_token = opencti_token
        self.api = OpenCTIApiClient(self.opencti_url, self.opencti_token)
        self.queue_name = connector['config']['push']
        self.pika_connection = pika.BlockingConnection(
            pika.URLParameters(connector['config']['uri']))
        self.channel = self.pika_connection.channel()
        # At most one unacknowledged message is delivered at a time.
        self.channel.basic_qos(prefetch_count=1)

    def get_id(self):
        """Return the id used to signal this thread, or None if it is not running."""
        if hasattr(self, '_thread_id'):
            return self._thread_id
        # Only report the ident of a live thread: idents of finished threads
        # can be reused by unrelated threads.
        return self.ident if self.is_alive() else None

    def terminate(self):
        """Best-effort kill: asynchronously raise SystemExit inside this thread."""
        thread_id = self.get_id()
        if thread_id is None:
            logging.info('Unable to kill the thread (not running)')
            return
        # The C parameter is an unsigned long (Python >= 3.7); a bare Python int
        # would be passed as a C int and overflow for 64-bit thread ids.
        c_thread_id = ctypes.c_ulong(thread_id)
        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
            c_thread_id, ctypes.py_object(SystemExit))
        if res > 1:
            # More than one thread state was modified: revoke the request.
            ctypes.pythonapi.PyThreadState_SetAsyncExc(c_thread_id, None)
            logging.info('Unable to kill the thread')

    def ack_message(self, channel, delivery_tag):
        """Acknowledge ``delivery_tag`` on ``channel`` if the channel is still open.

        Scheduled on the connection by :meth:`data_handler` via
        ``add_callback_threadsafe``.
        """
        if channel.is_open:
            logging.info('Message (delivery_tag=%s) acknowledged', delivery_tag)
            channel.basic_ack(delivery_tag)
        else:
            logging.info(
                'Message (delivery_tag=%s) NOT acknowledged (channel closed)',
                delivery_tag)

    def stop_consume(self, channel):
        """Stop consuming on ``channel`` unless it is already closed."""
        if channel.is_open:
            channel.stop_consuming()

    # Callable for consuming a message
    def _process_message(self, channel, method, properties, body):
        """pika consumer callback: import ``body`` on a helper thread and wait.

        The wait uses ``pika_connection.sleep`` instead of ``join`` so the
        connection keeps servicing broker frames and the threadsafe callbacks
        that :meth:`data_handler` schedules while the import runs.
        """
        data = json.loads(body)
        logging.info(
            'Processing a new message (delivery_tag=%s), launching a thread...',
            method.delivery_tag)
        worker = threading.Thread(
            target=self.data_handler,
            args=[self.pika_connection, channel, method.delivery_tag, data])
        worker.start()

        while worker.is_alive():  # Loop while the thread is processing
            self.pika_connection.sleep(1.0)
        logging.info('Message processed, thread terminated')

    # Data handling
    def data_handler(self, connection, channel, delivery_tag, data):
        """Import the STIX bundle carried by one queue message.

        Runs on a helper thread; every channel operation is scheduled on
        ``connection`` with ``add_callback_threadsafe``.

        :param connection: the pika connection owning ``channel``.
        :param channel: channel the message was delivered on.
        :param delivery_tag: delivery tag of the message.
        :param data: decoded message. ``content`` (base64 bundle) is required;
            ``job_id``, ``token``, ``entities_types`` and ``update`` are optional.
        :returns: True when the bundle was imported. False on a connection
            error (message left unacknowledged, consumption stopped) or on any
            other error (message acknowledged, job marked as ``error``).
        """
        job_id = data.get('job_id')
        token = data.get('token')
        try:
            content = base64.b64decode(data['content']).decode('utf-8')
            entity_types = data.get('entities_types', [])
            update = data.get('update', False)
            if token:
                self.api.set_token(token)
            try:
                imported_data = self.api.stix2.import_bundle_from_json(
                    content, update, entity_types)
            finally:
                # Restore the default token even when the import fails, so a
                # per-message token never leaks into later messages.
                self.api.set_token(self.opencti_token)
            if job_id is not None:
                # groupby only merges adjacent items: sort by type first so each
                # type gets exactly one summary line.
                ordered = sorted(imported_data, key=lambda x: x['type'])
                messages = [
                    str(len(list(grp))) + ' imported ' + key
                    for key, grp in groupby(ordered, key=lambda x: x['type'])
                ]
                self.api.job.update_job(job_id, 'complete', messages)
            cb = functools.partial(self.ack_message, channel, delivery_tag)
            connection.add_callback_threadsafe(cb)
            return True
        except RequestException as err:
            logging.error('A connection error occurred: { %s }', err)
            logging.info('Message (delivery_tag=%s) NOT acknowledged', delivery_tag)
            cb = functools.partial(self.stop_consume, channel)
            connection.add_callback_threadsafe(cb)
            return False
        except Exception as e:
            logging.error('An unexpected error occurred: { %s }', e)
            # The message is acknowledged (dropped) rather than redelivered.
            cb = functools.partial(self.ack_message, channel, delivery_tag)
            connection.add_callback_threadsafe(cb)
            if job_id is not None:
                self.api.job.update_job(job_id, 'error', [str(e)])
            return False

    def run(self):
        """Consume ``self.queue_name`` until consumption is stopped or fails."""
        try:
            # Consume the queue
            logging.info('Thread for queue %s started', self.queue_name)
            self.channel.basic_consume(
                queue=self.queue_name,
                on_message_callback=self._process_message)
            self.channel.start_consuming()
        finally:
            # Guarded stop: stopping a closed channel may raise and would hide
            # the exception that ended consumption.
            self.stop_consume(self.channel)
            logging.info('Thread for queue %s terminated', self.queue_name)
Esempio n. 2
0
class Consumer(threading.Thread):
    """Worker thread that consumes one RabbitMQ queue and imports STIX bundles.

    Each message is a JSON document carrying a base64-encoded STIX bundle. It
    is imported on a short-lived helper thread while this thread keeps the
    pika connection serviced; ack/nack requests are handed back to the
    connection with ``add_callback_threadsafe``.
    """

    def __init__(self, connector, opencti_url, opencti_token):
        """Create the OpenCTI client and open a channel for ``connector``.

        :param connector: connector description; ``connector["config"]["push"]``
            is the queue to consume and ``connector["config"]["uri"]`` the
            AMQP URI.
        :param opencti_url: base URL of the OpenCTI platform.
        :param opencti_token: default API token, restored after every import.
        """
        threading.Thread.__init__(self)
        self.opencti_url = opencti_url
        self.opencti_token = opencti_token
        self.api = OpenCTIApiClient(self.opencti_url, self.opencti_token)
        self.queue_name = connector["config"]["push"]
        self.pika_connection = pika.BlockingConnection(
            pika.URLParameters(connector["config"]["uri"])
        )
        self.channel = self.pika_connection.channel()
        # At most one unacknowledged message is delivered at a time.
        self.channel.basic_qos(prefetch_count=1)
        # Attempt number for the message currently being imported.
        self.processing_count = 0

    def get_id(self):
        """Return the id used to signal this thread, or None if it is not running."""
        if hasattr(self, "_thread_id"):
            return self._thread_id
        # Only report the ident of a live thread: idents of finished threads
        # can be reused by unrelated threads.
        return self.ident if self.is_alive() else None

    def terminate(self):
        """Best-effort kill: asynchronously raise SystemExit inside this thread."""
        thread_id = self.get_id()
        if thread_id is None:
            logging.info("Unable to kill the thread (not running)")
            return
        # The C parameter is an unsigned long (Python >= 3.7); a bare Python int
        # would be passed as a C int and overflow for 64-bit thread ids.
        c_thread_id = ctypes.c_ulong(thread_id)
        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
            c_thread_id, ctypes.py_object(SystemExit)
        )
        if res > 1:
            # More than one thread state was modified: revoke the request.
            ctypes.pythonapi.PyThreadState_SetAsyncExc(c_thread_id, None)
            logging.info("Unable to kill the thread")

    def nack_message(self, channel, delivery_tag):
        """Reject ``delivery_tag`` on ``channel`` if the channel is still open.

        pika's ``basic_nack`` defaults to ``requeue=True``, so the broker
        redelivers the message.
        """
        if channel.is_open:
            logging.info("Message (delivery_tag=%s) rejected", delivery_tag)
            channel.basic_nack(delivery_tag)
        else:
            logging.info(
                "Message (delivery_tag=%s) NOT rejected (channel closed)", delivery_tag
            )

    def ack_message(self, channel, delivery_tag):
        """Acknowledge ``delivery_tag`` on ``channel`` if the channel is still open."""
        if channel.is_open:
            logging.info("Message (delivery_tag=%s) acknowledged", delivery_tag)
            channel.basic_ack(delivery_tag)
        else:
            logging.info(
                "Message (delivery_tag=%s) NOT acknowledged (channel closed)",
                delivery_tag,
            )

    def stop_consume(self, channel):
        """Stop consuming on ``channel`` unless it is already closed."""
        if channel.is_open:
            channel.stop_consuming()

    # Callable for consuming a message
    def _process_message(self, channel, method, properties, body):
        """pika consumer callback: import ``body`` on a helper thread and wait.

        The wait uses ``pika_connection.sleep`` instead of ``join`` so the
        connection keeps servicing broker frames and the threadsafe callbacks
        that :meth:`data_handler` schedules while the import runs.
        """
        data = json.loads(body)
        logging.info(
            "Processing a new message (delivery_tag=%s), launching a thread...",
            method.delivery_tag,
        )
        worker = threading.Thread(
            target=self.data_handler,
            args=[self.pika_connection, channel, method.delivery_tag, data],
        )
        worker.start()

        while worker.is_alive():  # Loop while the thread is processing
            self.pika_connection.sleep(0.05)
        logging.info("Message processed, thread terminated")

    # Data handling
    def data_handler(self, connection, channel, delivery_tag, data):
        """Import the STIX bundle carried by one queue message, with retries.

        Runs on a helper thread; every channel operation is scheduled on
        ``connection`` with ``add_callback_threadsafe``. A failed import is
        retried up to 5 times (6 attempts in total, 2 s apart) unless the
        error mentions ``UnsupportedError``.

        :param connection: the pika connection owning ``channel``.
        :param channel: channel the message was delivered on.
        :param delivery_tag: delivery tag of the message.
        :param data: decoded message. ``content`` (base64 bundle) is required;
            ``job_id``, ``token``, ``entities_types`` and ``update`` are optional.
        :returns: True when the bundle was imported. False on a connection
            error (message nacked after a 60 s pause) or when retries are
            exhausted (message acknowledged, job marked as ``error``).
        """
        job_id = data.get("job_id")
        token = data.get("token")
        # The attempt counter is per message: reset it for every new message.
        self.processing_count = 0
        while True:
            self.processing_count += 1
            try:
                content = base64.b64decode(data["content"]).decode("utf-8")
                # An absent, null or empty list means "no type filter".
                entity_types = data.get("entities_types") or None
                update = data.get("update", False)
                if token:
                    self.api.set_token(token)
                try:
                    imported_data = self.api.stix2.import_bundle_from_json(
                        content, update, entity_types
                    )
                finally:
                    # Restore the default token even when the import fails, so
                    # a per-message token never leaks into later messages.
                    self.api.set_token(self.opencti_token)
                if job_id is not None:
                    # groupby only merges adjacent items: sort by type first so
                    # each type gets exactly one summary line.
                    ordered = sorted(imported_data, key=lambda x: x["type"])
                    messages = [
                        str(len(list(grp))) + " imported " + key
                        for key, grp in groupby(ordered, key=lambda x: x["type"])
                    ]
                    self.api.job.update_job(job_id, "complete", messages)
                cb = functools.partial(self.ack_message, channel, delivery_tag)
                connection.add_callback_threadsafe(cb)
                return True
            except RequestException as err:
                logging.error("A connection error occurred: { %s }", err)
                # Give the platform time to recover before the message is requeued.
                time.sleep(60)
                logging.info(
                    "Message (delivery_tag=%s) NOT acknowledged", delivery_tag
                )
                cb = functools.partial(self.nack_message, channel, delivery_tag)
                connection.add_callback_threadsafe(cb)
                return False
            except Exception as e:
                error = str(e)
                if "UnsupportedError" not in error and self.processing_count <= 5:
                    time.sleep(2)
                    logging.info(
                        "Message (delivery_tag=%s) reprocess (retry nb: %s)",
                        delivery_tag,
                        self.processing_count,
                    )
                    continue
                logging.error(error)
                self.processing_count = 0
                # The message is acknowledged (dropped) rather than redelivered.
                cb = functools.partial(self.ack_message, channel, delivery_tag)
                connection.add_callback_threadsafe(cb)
                if job_id is not None:
                    self.api.job.update_job(job_id, "error", [error])
                return False

    def run(self):
        """Consume ``self.queue_name`` until consumption is stopped or fails."""
        try:
            # Consume the queue
            logging.info("Thread for queue %s started", self.queue_name)
            self.channel.basic_consume(
                queue=self.queue_name, on_message_callback=self._process_message
            )
            self.channel.start_consuming()
        finally:
            # Guarded stop: stopping a closed channel may raise and would hide
            # the exception that ended consumption.
            self.stop_consume(self.channel)
            logging.info("Thread for queue %s terminated", self.queue_name)