예제 #1
0
파일: app.py 프로젝트: halsayed/xi-iot
class VideoPipeline:
    """Relay DataStreamMessages from one NATS topic to another through ``func``.

    Each message received on the source topic is parsed as an
    ``xi_iot_pb2.DataStreamMessage``, its ``payload`` is replaced by
    ``func(payload)``, and the re-serialized message is published on the
    destination topic.

    The broker endpoint and both topics are obtained through ``_env_var``
    (presumably from the NATS_ENDPOINT / NATS_SRC_TOPIC / NATS_DST_TOPIC
    environment variables -- confirm against its definition).

    Call ``await close()`` on shutdown; ``__del__`` only makes a best-effort
    attempt to drain the connection.
    """

    def __init__(self, func: Callable[[bytes], bytes]):
        self.nc = None  # NATS client, created in run()
        self.func = func
        self.msg_idx = 0  # index of the next message, used in log lines

        self.nats_endpoint = _env_var('NATS_ENDPOINT', 'NATS broker endpoint')
        self.src_topic = _env_var('NATS_SRC_TOPIC', 'source NATS topic')
        self.dst_topic = _env_var('NATS_DST_TOPIC', 'destination NATS topic')
        logging.info("[{}]: {} -> {}".format(self.nats_endpoint,
                                             self.src_topic, self.dst_topic))

    async def run(self, loop):
        """Connect to the broker and subscribe the handler to the source topic.

        Returns once the subscription call completes; incoming messages are
        dispatched to ``_message_handler`` by the NATS client.
        """
        self.nc = NATS()
        await self.nc.connect(loop=loop, servers=[str(self.nats_endpoint)])
        logging.info('connected to NATS')
        await self.nc.subscribe(str(self.src_topic), cb=self._message_handler)

    async def _message_handler(self, msg):
        """Transform one incoming message with ``func`` and publish it."""
        ds_msg = xi_iot_pb2.DataStreamMessage()
        try:
            ds_msg.ParseFromString(msg.data)
        except Exception as e:
            # A malformed message has nothing meaningful to forward: log it,
            # count it and drop it instead of letting the exception escape
            # the callback with the message counter left unchanged.
            logging.error("failed to parse message #{}: {}".format(
                self.msg_idx, e))
            self.msg_idx += 1
            return
        try:
            ds_msg.payload = self.func(ds_msg.payload)  # TODO error handling
        except Exception as e:
            # NOTE(review): on failure the original, unprocessed payload is
            # still published below -- confirm downstream consumers expect it.
            logging.error("failed to process message #{}: {}".format(
                self.msg_idx, e))
        if self.msg_idx % 100 == 0:
            logging.info("processed message #{}".format(self.msg_idx))
        try:
            await self.nc.publish(str(self.dst_topic),
                                  ds_msg.SerializeToString())
        except Exception as e:
            logging.error("failed to publish message #{}: {}".format(
                self.msg_idx, e))
        self.msg_idx += 1

    async def close(self):
        """Drain the NATS connection (if one was created) and forget it.

        Safe to call more than once; later calls are no-ops.
        """
        nc = self.nc
        self.nc = None
        if nc is not None:
            await nc.drain()

    def __del__(self):
        # ``drain()`` is a coroutine, so calling it here without awaiting
        # never runs it. Best effort: schedule it on the running loop, if
        # any; otherwise there is nothing that could drive it.
        nc = getattr(self, 'nc', None)
        self.nc = None
        if nc is None:
            return
        try:
            import asyncio
            loop = asyncio.get_running_loop()
        except (ImportError, RuntimeError):
            # No running loop, or the interpreter is shutting down.
            return
        loop.create_task(nc.drain())
예제 #2
0
class Matcher:
    """Subscribe to ``TOPIC`` on NATS and write each ExtractionTrigger to Dgraph.

    Must be constructed from inside a running event loop: ``__init__`` calls
    ``asyncio.get_running_loop()`` to install SIGTERM/SIGHUP/SIGINT handlers.
    """

    # Resolved by stop() once a termination signal arrives.
    fut_stop: asyncio.Future

    def __init__(self):
        self.nc = NATS()
        client_stub = pydgraph.DgraphClientStub("dgraph:9080")
        self.dg = pydgraph.DgraphClient(client_stub)
        loop = asyncio.get_running_loop()
        # Create the future before wiring the signals so stop() always has
        # something to resolve.
        self.fut_stop = loop.create_future()
        for sig in (signal.SIGTERM, signal.SIGHUP, signal.SIGINT):
            loop.add_signal_handler(sig, self.stop)

    async def setup(self):
        """Connect to NATS, subscribe on_update (queue "matcher"), set schema."""
        await self.nc.connect("nats://nats")
        await self.nc.subscribe(TOPIC, "matcher", self.on_update)
        print("Listening on ", TOPIC)
        self.setup_schema()

    def setup_schema(self):
        """Declare the exact-match indexed predicates used by the triggers."""
        schema = """
        email_address: string @index(exact) .
        id: string @index(exact) .
        """

        op = pydgraph.Operation(schema=schema)
        self.dg.alter(op)

    async def reply(self, msg, resp):
        """Publish ``resp`` to the message's reply subject, if it has one."""
        if msg.reply:
            await self.nc.publish(msg.reply, resp)

    async def on_update(self, msg) -> None:
        """Apply one ExtractionTrigger to Dgraph; reply with ``{"ok": bool}``.

        NOTE(review): the pydgraph calls below are synchronous and block the
        event loop while Dgraph responds; consider run_in_executor.
        """
        try:
            tr = ExtractionTrigger.from_bytes(msg.data)
        except Exception:
            traceback.print_exc()
            await self.reply(msg, orjson.dumps({"ok": False}))
            return

        txn = self.dg.txn()
        nq = None
        query = None
        try:
            nq = tr.nquads()
            query = tr.query()
            mutation = txn.create_mutation(set_nquads=nq)
            request = txn.create_request(query=query,
                                         mutations=[mutation],
                                         commit_now=True)
            # The request must actually be sent; with commit_now=True this
            # also commits it. A separate txn.commit() would only commit an
            # empty transaction (and fails after a commit_now request).
            txn.do_request(request)
        except Exception:
            traceback.print_exc()
            print(nq, query)
            ok = False
        else:
            ok = True
        finally:
            # No-op once the transaction has been committed.
            txn.discard()
        await self.reply(msg, orjson.dumps({"ok": ok}))

    def stop(self) -> None:
        """Signal handler: resolve fut_stop; repeated signals are ignored."""
        if not self.fut_stop.done():
            self.fut_stop.set_result(True)

    async def run_until_done(self) -> None:
        """Wait for a stop signal, then drain the NATS connection."""
        await self.fut_stop
        await self.nc.drain()
예제 #3
0
class NATSHelper(object):
    """Receive and publish xi_iot DataStreamMessages over NATS.

    Each of ``nats_broker_url``, ``nats_src_topic`` and ``nats_dst_topic`` is
    taken from the argument when given, otherwise from the NATS_ENDPOINT /
    NATS_SRC_TOPIC / NATS_DST_TOPIC environment variable respectively.

    Raises:
        Exception: if a setting is neither passed nor set (or is empty).
    """

    def __init__(self,
                 nats_broker_url=None,
                 nats_src_topic=None,
                 nats_dst_topic=None):
        # Initialised first so close() (also run from __del__) works even if
        # the configuration checks below raise.
        self.subscribe_id = None
        self.connected = False
        self.nats_client = NATS()

        self._get_config_from_env_var_('nats_broker_url', 'NATS_ENDPOINT',
                                       nats_broker_url, 'NATS broker')
        self._get_config_from_env_var_('nats_src_topic', 'NATS_SRC_TOPIC',
                                       nats_src_topic, 'NATS source topic')
        self._get_config_from_env_var_('nats_dst_topic', 'NATS_DST_TOPIC',
                                       nats_dst_topic,
                                       'NATS destination topic')

        logger.info("broker: {b}, src topic: {s}, dst_topic: {d}".format(
            b=self.nats_broker_url,
            s=self.nats_src_topic,
            d=self.nats_dst_topic))

    def __del__(self):
        # Best-effort cleanup on garbage collection. Exceptions cannot
        # propagate out of __del__, so report them instead.
        try:
            self.close()
        except RuntimeError as e:
            logger.warning("NATSHelper cleanup skipped: {}".format(e))

    def _get_config_from_env_var_(self, attr, var, default, what):
        """Set ``self.<attr>`` from ``default`` if given, else from env ``var``.

        Raises:
            Exception: if the resulting value is missing or empty.
        """
        if default is None:
            value = os.environ.get(var)
        else:
            value = default
        if not value:  # also check empty strings
            raise Exception(
                '{what} not provided in environment var {var}'.format(
                    what=what, var=var))
        setattr(self, attr, value)

    @staticmethod
    def payload_from_message(msg):
        """
        Return the ``payload`` of the DataStreamMessage serialized in ``msg.data``.
        """
        _msg = xi_iot_pb2.DataStreamMessage()
        _msg.ParseFromString(msg.data)
        return _msg.payload

    @staticmethod
    def message_from_payload(payload):
        """
        Wrap ``payload`` in a DataStreamMessage and return its serialized bytes.
        """
        msg = xi_iot_pb2.DataStreamMessage()
        msg.payload = payload
        return msg.SerializeToString()

    async def publish(self, payload):
        """Wrap ``payload`` in a DataStreamMessage and publish it to the dst topic."""
        try:
            payload = self.message_from_payload(payload)
            # RFC: We could leverage `reply` topic as the destination topic which would not require NATS_DST_TOPIC to be provided
            # await nc.publish(reply, data)
            logger.info("Sending message to topic '{}'".format(
                self.nats_dst_topic))
            await self.nats_client.publish(self.nats_dst_topic, payload)
        except Exception as e:
            # Catch and display errors which are otherwise not shown
            logger.error("{}".format(e))
            raise

    async def connect(self, loop, message_handler_cb):
        """Connect to the broker and subscribe ``message_handler_cb`` to the src topic.

        ``message_handler_cb`` is awaited with the unwrapped payload bytes of
        every received message.
        """
        # Define a helper function to unbox the message and catch errors for display
        async def receive_cb(msg):
            try:
                logger.info("Received a message!")
                payload = self.payload_from_message(msg)
                await message_handler_cb(payload)
            except Exception as e:
                # Catch and display errors which are otherwise not shown
                logger.error("{}".format(e))
                raise

        try:
            # This will return immediately if the server is not listening on the given URL
            await self.nats_client.connect(self.nats_broker_url, loop=loop)
            self.connected = True
            logger.info(
                "Connected to broker, subscribing to topic '{}'".format(
                    self.nats_src_topic))

            self.subscribe_id = await self.nats_client.subscribe(
                self.nats_src_topic, cb=receive_cb)
        except Exception as e:
            # Catch and display errors which are otherwise not shown
            logger.error("{}".format(e))
            raise

    async def aclose(self):
        """Unsubscribe, then drain and close the connection (async version).

        Safe to call repeatedly; state is reset so later calls are no-ops.
        """
        # Remove interest in subscription.
        if self.subscribe_id is not None:
            await self.nats_client.unsubscribe(self.subscribe_id)
            self.subscribe_id = None

        # Terminate connection to NATS. drain() is a coroutine and must be
        # awaited, otherwise it never runs.
        if self.nats_client is not None and self.connected:
            await self.nats_client.drain()
            await self.nats_client.close()
            self.nats_client = None
            self.connected = False

    def close(self):
        """Synchronous counterpart of aclose() for use outside a running loop.

        Raises:
            RuntimeError: if there is something to release but the current
                event loop is already running (use ``await aclose()``) or
                has been closed.
        """
        if self.subscribe_id is None and not self.connected:
            return  # nothing to release; don't touch the event loop
        loop = asyncio.get_event_loop()
        # Check up front so no un-awaited aclose() coroutine is created.
        if loop.is_running():
            raise RuntimeError(
                'NATSHelper.close() called while the event loop is running; '
                'use "await aclose()" instead')
        if loop.is_closed():
            raise RuntimeError(
                'NATSHelper.close() called after the event loop was closed')
        loop.run_until_complete(self.aclose())