Beispiel #1
0
    async def test_new_indexstore(self, tempdir):
        """Index two events, then check lookup, search results and context.

        The test is skipped when indexing support is not enabled.
        """
        if not INDEXING_ENABLED:
            pytest.skip("Indexing needs to be enabled to test this")

        # Imported only after the skip check above, so the index module is
        # never imported when indexing is disabled.
        from pantalaimon.index import IndexStore

        store = IndexStore("example", tempdir)

        store.add_event(self.test_event, TEST_ROOM, None, None)
        store.add_event(self.another_event, TEST_ROOM, None, None)
        await store.commit_events()

        assert store.event_in_store(self.test_event.event_id, TEST_ROOM)
        assert not store.event_in_store("FAKE", TEST_ROOM)

        result = await store.search("test",
                                    TEST_ROOM,
                                    after_limit=10,
                                    before_limit=10)
        # Dump the raw search result to ease debugging of a failing run.
        pprint.pprint(result)

        # Exactly one hit is expected, with the second event as its context.
        assert len(result["results"]) == 1
        assert result["count"] == 1
        assert result["results"][0]["result"] == self.test_event.source
        assert (result["results"][0]["context"]["events_after"][0] ==
                self.another_event.source)
Beispiel #2
0
    def __init__(
        self,
        server_name,
        pan_store,
        pan_conf,
        homeserver,
        queue=None,
        user_id="",
        device_id="",
        store_path="",
        config=None,
        ssl=None,
        proxy=None,
        store_class=None,
    ):
        """Set up the client, its optional search index and its callbacks.

        Args:
            server_name: Name of the server; stored on the client and used
                as part of the index directory path.
            pan_store: Store object kept on the client as ``pan_store``.
            pan_conf: Configuration object kept on the client as
                ``pan_conf``.
            homeserver: Passed through to the parent constructor.
            queue: Optional queue kept on the client as ``queue``.
            user_id: Passed to the parent constructor and used as part of
                the index directory path.
            device_id: Passed through to the parent constructor.
            store_path: Passed to the parent constructor and used as the
                root of the index directory path.
            config: Client configuration; when falsy a default
                ``AsyncClientConfig`` named "pan.db" is created.
            ssl: Passed through to the parent constructor.
            proxy: Passed through to the parent constructor.
            store_class: Store class used for the default configuration,
                ``SqliteStore`` when not given.
        """
        config = config or AsyncClientConfig(
            store=store_class or SqliteStore, store_name="pan.db"
        )
        super().__init__(homeserver, user_id, device_id, store_path, config, ssl, proxy)

        index_dir = os.path.join(store_path, server_name, user_id)

        # An already existing directory is the normal case. Any other
        # failure stays non-fatal, as before, but is no longer silent.
        try:
            os.makedirs(index_dir, exist_ok=True)
        except OSError as e:
            logger.warn(f"Unable to create the index directory {index_dir}: {e}")

        self.server_name = server_name
        self.pan_store = pan_store
        self.pan_conf = pan_conf

        if INDEXING_ENABLED:
            logger.info("Indexing enabled.")
            from pantalaimon.index import IndexStore

            self.index = IndexStore(self.user_id, index_dir)
        else:
            logger.info("Indexing disabled.")
            self.index = None

        self.task = None
        self.queue = queue

        # Those two events are mainly used for testing.
        self.new_fetch_task = asyncio.Event()
        self.fetch_loop_event = asyncio.Event()

        self.room_members_fetched = defaultdict(bool)

        self.send_semaphores = defaultdict(asyncio.Semaphore)
        self.send_decision_queues = dict()  # type: Dict[Any, asyncio.Queue]
        self.last_sync_token = None

        self.history_fetcher_task = None
        self.history_fetch_queue = asyncio.Queue()

        self.add_to_device_callback(self.key_verification_cb, KeyVerificationEvent)
        self.add_to_device_callback(
            self.key_request_cb, (RoomKeyRequest, RoomKeyRequestCancellation)
        )
        self.add_event_callback(self.undecrypted_event_cb, MegolmEvent)

        # Room events are only fed to the index when indexing is available.
        if INDEXING_ENABLED:
            self.add_event_callback(
                self.store_message_cb,
                (
                    RoomMessageText,
                    RoomMessageMedia,
                    RoomEncryptedMedia,
                    RoomTopicEvent,
                    RoomNameEvent,
                ),
            )

        self.add_response_callback(self.keys_query_cb, KeysQueryResponse)
        self.add_response_callback(self.sync_tasks, SyncResponse)
Beispiel #3
0
class PanClient(AsyncClient):
    """A wrapper class around a nio AsyncClient extending its functionality."""

    def __init__(
        self,
        server_name,
        pan_store,
        pan_conf,
        homeserver,
        queue=None,
        user_id="",
        device_id="",
        store_path="",
        config=None,
        ssl=None,
        proxy=None,
        store_class=None,
    ):
        """Set up the client, its optional search index and its callbacks.

        Args:
            server_name: Name of the server; stored on the client and used
                as part of the index directory path.
            pan_store: Store object kept on the client as ``pan_store``.
            pan_conf: Configuration object kept on the client as
                ``pan_conf``.
            homeserver: Passed through to the parent constructor.
            queue: Optional queue kept on the client as ``queue``.
            user_id: Passed to the parent constructor and used as part of
                the index directory path.
            device_id: Passed through to the parent constructor.
            store_path: Passed to the parent constructor and used as the
                root of the index directory path.
            config: Client configuration; when falsy a default
                ``AsyncClientConfig`` named "pan.db" is created.
            ssl: Passed through to the parent constructor.
            proxy: Passed through to the parent constructor.
            store_class: Store class used for the default configuration,
                ``SqliteStore`` when not given.
        """
        config = config or AsyncClientConfig(
            store=store_class or SqliteStore, store_name="pan.db"
        )
        super().__init__(homeserver, user_id, device_id, store_path, config, ssl, proxy)

        index_dir = os.path.join(store_path, server_name, user_id)

        # An already existing directory is the normal case. Any other
        # failure stays non-fatal, as before, but is no longer silent.
        try:
            os.makedirs(index_dir, exist_ok=True)
        except OSError as e:
            logger.warn(f"Unable to create the index directory {index_dir}: {e}")

        self.server_name = server_name
        self.pan_store = pan_store
        self.pan_conf = pan_conf

        if INDEXING_ENABLED:
            logger.info("Indexing enabled.")
            from pantalaimon.index import IndexStore

            self.index = IndexStore(self.user_id, index_dir)
        else:
            logger.info("Indexing disabled.")
            self.index = None

        self.task = None
        self.queue = queue

        # Those two events are mainly used for testing.
        self.new_fetch_task = asyncio.Event()
        self.fetch_loop_event = asyncio.Event()

        self.room_members_fetched = defaultdict(bool)

        self.send_semaphores = defaultdict(asyncio.Semaphore)
        self.send_decision_queues = dict()  # type: Dict[Any, asyncio.Queue]
        self.last_sync_token = None

        self.history_fetcher_task = None
        self.history_fetch_queue = asyncio.Queue()

        self.add_to_device_callback(self.key_verification_cb, KeyVerificationEvent)
        self.add_to_device_callback(
            self.key_request_cb, (RoomKeyRequest, RoomKeyRequestCancellation)
        )
        self.add_event_callback(self.undecrypted_event_cb, MegolmEvent)

        # Room events are only fed to the index when indexing is available.
        if INDEXING_ENABLED:
            self.add_event_callback(
                self.store_message_cb,
                (
                    RoomMessageText,
                    RoomMessageMedia,
                    RoomEncryptedMedia,
                    RoomTopicEvent,
                    RoomNameEvent,
                ),
            )

        self.add_response_callback(self.keys_query_cb, KeysQueryResponse)
        self.add_response_callback(self.sync_tasks, SyncResponse)

    def store_message_cb(self, room, event):
        """Event callback that adds a received room event to the index.

        Events from unencrypted rooms are left out when the
        ``index_encrypted_only`` option is set.
        """
        assert INDEXING_ENABLED

        sender_name = room.user_name(event.sender)
        sender_avatar = room.avatar_url(event.sender)

        if room.encrypted or not self.pan_conf.index_encrypted_only:
            self.index.add_event(event, room.room_id, sender_name, sender_avatar)

    @property
    def unable_to_decrypt(self):
        """Room event signaling that the message couldn't be decrypted."""
        return {
            "type": "m.room.message",
            "content": {
                "msgtype": "m.text",
                "body": (
                    "** Unable to decrypt: The sender's device has not "
                    "sent us the keys for this message. **"
                ),
            },
        }

    async def send_message(self, message):
        """Send a thread message to the UI thread."""
        if self.queue:
            await self.queue.put(message)

    async def send_update_devices(self, devices):
        """Send a dictionary of devices to the UI thread.

        ``devices`` is a nested mapping whose inner values are device
        objects. Each device is converted to a flat dictionary and the
        result is grouped by user id and device id.
        """
        flattened = defaultdict(dict)

        for per_user in devices.values():
            for device in per_user.values():
                # Work on a plain dictionary copy: the "keys" sub-dict is
                # merged into the top level and then removed, and the
                # display name is stored under "device_display_name".
                # Since all the keys and values are strings this also
                # copies them, making it thread safe.
                entry = dict(device.as_dict())
                entry.update(entry["keys"])
                del entry["keys"]
                entry["device_display_name"] = entry.pop("display_name")

                flattened[device.user_id][device.id] = entry

        await self.send_message(UpdateDevicesMessage(self.user_id, flattened))

    async def send_update_device(self, device):
        """Send a single device to the UI thread to be updated."""
        # Wrap the device in the nested user -> device id -> device mapping
        # that send_update_devices() iterates over.
        single = {device.user_id: {device.id: device}}
        await self.send_update_devices(single)

    def delete_fetcher_task(self, task):
        self.pan_store.delete_fetcher_task(self.server_name, self.user_id, task)

    async def fetcher_loop(self):
        """Fetch older room history in the background and index it.

        Stored fetch tasks are loaded from the pan store and queued first.
        The loop then repeatedly sleeps for ``history_fetch_delay``, takes
        one task from the queue, requests a batch of room messages and adds
        the indexable events to the index. A follow-up task is queued as
        long as the last event of a batch is not yet in the index.

        The loop returns when it is cancelled or interrupted.
        """
        assert INDEXING_ENABLED

        # Event types that are added to the index.
        indexed_types = (
            RoomMessageText,
            RoomMessageMedia,
            RoomEncryptedMedia,
            RoomTopicEvent,
            RoomNameEvent,
        )

        for t in self.pan_store.load_fetcher_tasks(self.server_name, self.user_id):
            await self.history_fetch_queue.put(t)

        while True:
            # Pulse the event so that waiters (mainly tests) can observe
            # that a new loop iteration has started.
            self.fetch_loop_event.set()
            self.fetch_loop_event.clear()

            try:
                await asyncio.sleep(self.pan_conf.history_fetch_delay)
                fetch_task = await self.history_fetch_queue.get()

                try:
                    room = self.rooms[fetch_task.room_id]
                except KeyError:
                    # The room is missing from our client, we probably left the
                    # room.
                    self.delete_fetcher_task(fetch_task)
                    continue

                try:
                    logger.debug(
                        f"Fetching room history for {room.display_name} "
                        f"({room.room_id}), token {fetch_task.token}."
                    )
                    response = await self.room_messages(
                        fetch_task.room_id,
                        fetch_task.token,
                        limit=self.pan_conf.indexing_batch_size,
                    )
                except ClientConnectionError as e:
                    logger.debug(f"Error fetching room history: {e}")
                    # Retry the task later. Without the `continue` the code
                    # below would run with an unbound or stale `response`.
                    await self.history_fetch_queue.put(fetch_task)
                    continue

                # NOTE(review): an error response object without a `chunk`
                # attribute would raise here - confirm what room_messages()
                # returns on server-side errors.
                # The chunk was empty, we're at the start of the timeline.
                if not response.chunk:
                    self.delete_fetcher_task(fetch_task)
                    continue

                for event in response.chunk:
                    if not isinstance(event, indexed_types):
                        continue

                    display_name = room.user_name(event.sender)
                    avatar_url = room.avatar_url(event.sender)
                    self.index.add_event(event, room.room_id, display_name, avatar_url)

                last_event = response.chunk[-1]

                if not self.index.event_in_store(last_event.event_id, room.room_id):
                    # There may be even more events to fetch, add a new task to
                    # the queue.
                    task = FetchTask(room.room_id, response.end)
                    self.pan_store.replace_fetcher_task(
                        self.server_name, self.user_id, fetch_task, task
                    )
                    await self.history_fetch_queue.put(task)
                    self.new_fetch_task.set()
                    self.new_fetch_task.clear()
                else:
                    # The last event is already known to the index: commit
                    # what was added and drop the finished task.
                    await self.index.commit_events()
                    self.delete_fetcher_task(fetch_task)

            except (asyncio.CancelledError, KeyboardInterrupt):
                return

    async def sync_tasks(self, response):
        if self.index:
            await self.index.commit_events()

        if self.last_sync_token == self.next_batch:
            return

        self.last_sync_token = self.next_batch

        self.pan_store.save_token(self.server_name, self.user_id, self.next_batch)

        for room_id, room_info in response.rooms.join.items():
            if room_info.timeline.limited:
                room = self.rooms[room_id]

                if not room.encrypted and self.pan_conf.index_encrypted_only:
                    continue

                logger.info(
                    "Room {} had a limited timeline, queueing "
                    "room for history fetching.".format(room.display_name)
                )
                task = FetchTask(room_id, room_info.timeline.prev_batch)
                self.pan_store.save_fetcher_task(self.server_name, self.user_id, task)

                await self.history_fetch_queue.put(task)
                self.new_fetch_task.set()
                self.new_fetch_task.clear()

    async def keys_query_cb(self, response):
        if response.changed:
            await self.send_update_devices(response.changed)

    async def undecrypted_event_cb(self, room, event):
        """Event callback that requests the room key for an undecrypted event.

        No request is made when one for the same session id is already
        present in ``outgoing_key_requests``.
        """
        logger.info(
            "Unable to decrypt event from {} via {}.".format(
                event.sender, event.device_id
            )
        )

        if event.session_id in self.outgoing_key_requests:
            return

        logger.info("Requesting room key for undecrypted event.")

        # TODO we may want to retry this
        try:
            await self.request_room_key(event)
        except ClientConnectionError:
            # Deliberately ignored, see the TODO above.
            pass

    async def key_request_cb(self, event):
        """To-device callback for room key requests and their cancellations.

        Logs the event and forwards it to the UI thread wrapped in a
        ``KeyRequestMessage``.
        """
        if isinstance(event, RoomKeyRequest):
            outcome = " requested room keys from  us."
        elif isinstance(event, RoomKeyRequestCancellation):
            outcome = " canceled its key request."
        else:
            assert False
            # Only reached when assertions are disabled (python -O).
            return

        logger.info(f"{event.sender} via {event.requesting_device_id} has {outcome}")

        await self.send_message(KeyRequestMessage(self.user_id, event))

    async def key_verification_cb(self, event):
        """To-device callback translating key verification events for the UI.

        A start event produces an ``InviteSasSignal``, a key event a
        ``ShowSasSignal`` carrying the emoji, and a MAC event a
        ``SasDoneSignal`` plus a device update once the SAS is verified.
        Key and MAC events without a known verification are ignored.
        """
        logger.info("Received key verification event: {}".format(event))

        if isinstance(event, KeyVerificationStart):
            logger.info(
                f"{event.sender} via {event.from_device} has started "
                f"a key verification process."
            )
            invite = InviteSasSignal(
                self.user_id, event.sender, event.from_device, event.transaction_id
            )
            await self.send_message(invite)
            return

        if not isinstance(event, (KeyVerificationKey, KeyVerificationMac)):
            return

        # Both remaining event types refer to an existing verification.
        sas = self.key_verifications.get(event.transaction_id, None)
        if not sas:
            return

        device = sas.other_olm_device

        if isinstance(event, KeyVerificationKey):
            emoji = sas.get_emoji()
            await self.send_message(
                ShowSasSignal(
                    self.user_id, device.user_id, device.id, sas.transaction_id, emoji
                )
            )
        elif sas.verified:
            await self.send_message(
                SasDoneSignal(
                    self.user_id, device.user_id, device.id, sas.transaction_id
                )
            )
            await self.send_update_device(device)

    def start_loop(self, loop_sleep_time=100):
        """Start a loop that runs forever and keeps on syncing with the server.

        When indexing is enabled the history fetcher loop is started as
        well. The loop can be stopped with the loop_stop() method.

        Returns the task running the sync loop.
        """
        assert not self.task

        logger.info(f"Starting sync loop for {self.user_id}")

        event_loop = asyncio.get_event_loop()

        if INDEXING_ENABLED:
            self.history_fetcher_task = event_loop.create_task(self.fetcher_loop())

        # Resume from the stored sync token.
        since_token = self.pan_store.load_token(self.server_name, self.user_id)
        self.last_sync_token = since_token

        # We don't store any room state so initial sync needs to be with the
        # full_state parameter. Subsequent ones are normal.
        sync_coroutine = self.sync_forever(
            30000,
            {"room": {"state": {"lazy_load_members": True}}},
            full_state=True,
            since=since_token,
            loop_sleep_time=loop_sleep_time,
        )
        self.task = event_loop.create_task(sync_coroutine)

        return self.task

    async def start_sas(self, message, device):
        """Start a key verification with ``device`` and report the outcome.

        A ``DaemonResponse`` answering ``message`` is sent to the UI thread,
        either "m.ok" or "m.connection_error".
        """
        try:
            await self.start_key_verification(device)
            reply = DaemonResponse(
                message.message_id,
                self.user_id,
                "m.ok",
                "Successfully started the key verification request",
            )
            await self.send_message(reply)
        except ClientConnectionError as error:
            failure = DaemonResponse(
                message.message_id, self.user_id, "m.connection_error", str(error)
            )
            await self.send_message(failure)

    async def accept_sas(self, message):
        """Accept the active key verification named by ``message``.

        The outcome is reported to the UI thread as a ``DaemonResponse``:
        the transaction id error when no verification is active, "m.ok" on
        success, or the matching error for protocol and connection failures.
        """
        msg_id = message.message_id
        sas = self.get_active_sas(message.user_id, message.device_id)

        if not sas:
            missing = DaemonResponse(
                msg_id, self.user_id, Sas._txid_error[0], Sas._txid_error[1]
            )
            await self.send_message(missing)
            return

        try:
            await self.accept_key_verification(sas.transaction_id)
            accepted = DaemonResponse(
                msg_id,
                self.user_id,
                "m.ok",
                "Successfully accepted the key verification request",
            )
            await self.send_message(accepted)
        except LocalProtocolError as error:
            unexpected = DaemonResponse(
                msg_id, self.user_id, Sas._unexpected_message_error[0], str(error)
            )
            await self.send_message(unexpected)
        except ClientConnectionError as error:
            unreachable = DaemonResponse(
                msg_id, self.user_id, "m.connection_error", str(error)
            )
            await self.send_message(unreachable)

    async def cancel_sas(self, message):
        """Cancel the active key verification named by ``message``.

        The outcome is reported to the UI thread as a ``DaemonResponse``:
        the transaction id error when no verification is active, "m.ok" on
        success, or "m.connection_error" when the request failed.
        """
        msg_id = message.message_id
        sas = self.get_active_sas(message.user_id, message.device_id)

        if not sas:
            missing = DaemonResponse(
                msg_id, self.user_id, Sas._txid_error[0], Sas._txid_error[1]
            )
            await self.send_message(missing)
            return

        try:
            await self.cancel_key_verification(sas.transaction_id)
            cancelled = DaemonResponse(
                msg_id,
                self.user_id,
                "m.ok",
                "Successfully canceled the key verification request",
            )
            await self.send_message(cancelled)
        except ClientConnectionError as error:
            unreachable = DaemonResponse(
                msg_id, self.user_id, "m.connection_error", str(error)
            )
            await self.send_message(unreachable)

    async def confirm_sas(self, message):
        """Confirm the short auth string of the active key verification.

        The outcome is reported to the UI thread: the transaction id error
        when no verification is active, a ``SasDoneSignal`` plus a device
        update once the SAS is verified, or an "m.ok" response telling that
        the other side still needs to confirm. Protocol and connection
        errors are reported as ``DaemonResponse`` errors.
        """
        user_id = message.user_id
        device_id = message.device_id

        sas = self.get_active_sas(user_id, device_id)

        if not sas:
            await self.send_message(
                DaemonResponse(
                    message.message_id,
                    self.user_id,
                    Sas._txid_error[0],
                    Sas._txid_error[1],
                )
            )
            return

        try:
            await self.confirm_short_auth_string(sas.transaction_id)
        except LocalProtocolError as e:
            # Handled the same way accept_sas() does, so that a
            # confirmation arriving in the wrong protocol state is reported
            # to the UI instead of propagating to the caller.
            # NOTE(review): confirm against nio which states make
            # confirm_short_auth_string() raise LocalProtocolError.
            await self.send_message(
                DaemonResponse(
                    message.message_id,
                    self.user_id,
                    Sas._unexpected_message_error[0],
                    str(e),
                )
            )

            return
        except ClientConnectionError as e:
            await self.send_message(
                DaemonResponse(
                    message.message_id, self.user_id, "m.connection_error", str(e)
                )
            )

            return

        device = sas.other_olm_device

        if sas.verified:
            await self.send_update_device(device)
            await self.send_message(
                SasDoneSignal(
                    self.user_id, device.user_id, device.id, sas.transaction_id
                )
            )
        else:
            # Our side confirmed, the verification finishes once the other
            # side confirms too.
            await self.send_message(
                DaemonResponse(
                    message.message_id,
                    self.user_id,
                    "m.ok",
                    f"Waiting for {device.user_id} to confirm.",
                )
            )

    async def handle_key_request_message(self, message):
        """Continue or cancel the active key shares named by ``message``.

        For a ``ContinueKeyShare`` every active key request of the given
        user and device is continued; the first one that can't be continued
        aborts with an "m.error" response. For a ``CancelKeyShare`` every
        active key request is cancelled. In both cases a ``DaemonResponse``
        with the outcome is sent to the UI thread. Other message types are
        ignored.
        """
        if isinstance(message, ContinueKeyShare):
            continued = False
            for share in self.get_active_key_requests(
                message.user_id, message.device_id
            ):

                continued = True

                if not self.continue_key_share(share):
                    await self.send_message(
                        DaemonResponse(
                            message.message_id,
                            self.user_id,
                            "m.error",
                            (
                                f"Unable to continue the key sharing for "
                                f"{message.user_id} via {message.device_id}: The "
                                f"device is still not verified."
                            ),
                        )
                    )
                    return

            if continued:
                try:
                    await self.send_to_device_messages()
                except ClientConnectionError:
                    # We can safely ignore this since this will be retried
                    # after the next sync in the sync_forever method.
                    pass

                response = (
                    f"Succesfully continued the key requests from "
                    f"{message.user_id} via {message.device_id}"
                )
                ret = "m.ok"
            else:
                response = (
                    f"No active key requests from {message.user_id} "
                    f"via {message.device_id} found."
                )
                ret = "m.error"

            await self.send_message(
                DaemonResponse(message.message_id, self.user_id, ret, response)
            )

        elif isinstance(message, CancelKeyShare):
            cancelled = False

            for share in self.get_active_key_requests(
                message.user_id, message.device_id
            ):
                # Remember any successful cancellation. Assigning the
                # result directly would let a later failure overwrite an
                # earlier success.
                if self.cancel_key_share(share):
                    cancelled = True

            if cancelled:
                response = (
                    f"Succesfully cancelled key requests from "
                    f"{message.user_id} via {message.device_id}"
                )
                ret = "m.ok"
            else:
                response = (
                    f"No active key requests from {message.user_id} "
                    f"via {message.device_id} found."
                )
                ret = "m.error"

            await self.send_message(
                DaemonResponse(message.message_id, self.user_id, ret, response)
            )

    async def loop_stop(self):
        """Stop the client loop.

        Cancels the sync task and the history fetcher task if they are
        still running and waits for them to finish, closes the store when
        it is a ``SqliteQueueDatabase`` and replaces the history fetch queue
        with an empty one.
        """
        logger.info("Stopping the sync loop")

        if self.task and not self.task.done():
            self.task.cancel()

            # Awaiting a cancelled task raises CancelledError when the task
            # does not swallow the cancellation itself, e.g. when it never
            # got to run. That must not abort the rest of the shutdown.
            try:
                await self.task
            except (asyncio.CancelledError, KeyboardInterrupt):
                pass

            self.task = None

        if self.history_fetcher_task and not self.history_fetcher_task.done():
            self.history_fetcher_task.cancel()

            try:
                await self.history_fetcher_task
            except (asyncio.CancelledError, KeyboardInterrupt):
                pass

            self.history_fetcher_task = None

        if isinstance(self.store, SqliteQueueDatabase):
            self.store.close()

        # Drop any fetch tasks that are still queued.
        self.history_fetch_queue = asyncio.Queue()

    def pan_decrypt_event(self, event_dict, room_id=None, ignore_failures=True):
        # type: (Dict[Any, Any], Optional[str], bool) -> bool
        """Decrypt a megolm encrypted event dictionary in place.

        Args:
            event_dict: The encrypted event; updated in place with the
                decrypted source, or with the "unable to decrypt"
                placeholder when decryption fails and failures are ignored.
            room_id: Room id set on the parsed event when it has none.
            ignore_failures: When false, decryption errors are re-raised.

        Returns True if the event was decrypted, False otherwise.
        """
        parsed = Event.parse_encrypted_event(event_dict)

        if not isinstance(parsed, MegolmEvent):
            logger.warn(
                "Encrypted event is not a megolm event:\n{}".format(
                    pformat(event_dict)
                )
            )
            return False

        if not parsed.room_id:
            parsed.room_id = room_id

        try:
            plain = self.decrypt_event(parsed)
            logger.info("Decrypted event: {}".format(plain))

            event_dict.update(plain.source)
            event_dict["decrypted"] = True
            event_dict["verified"] = plain.verified

            return True

        except EncryptionError as error:
            logger.warn(error)

            if not ignore_failures:
                raise

            event_dict.update(self.unable_to_decrypt)
            return False

    def decrypt_messages_body(self, body, ignore_failures=True):
        # type: (Dict[Any, Any], bool) -> Dict[Any, Any]
        """Go through a messages response and decrypt megolm encrypted events.

        Args:
            body (Dict[Any, Any]): The dictionary of a Sync response.

        Returns the json response with decrypted events.
        """
        if "chunk" not in body:
            return body

        logger.info("Decrypting room messages")

        for event in body["chunk"]:
            if "type" not in event:
                continue

            if event["type"] != "m.room.encrypted":
                logger.debug("Event is not encrypted: " "\n{}".format(pformat(event)))
                continue

            self.pan_decrypt_event(event, ignore_failures=ignore_failures)

        return body

    def handle_to_device_from_sync_body(self, body):
        to_device_events = body.get("to_device")

        if not to_device_events or "events" not in to_device_events:
            return

        for event in to_device_events["events"]:
            event = ToDeviceEvent.parse_encrypted_event(event)

            if not isinstance(event, ToDeviceEvent):
                continue

            self.olm.handle_to_device_event(event)

    def decrypt_sync_body(self, body, ignore_failures=True):
        # type: (Dict[Any, Any], bool) -> Dict[Any, Any]
        """Go through a json sync response and decrypt megolm encrypted events.

        To-device events found in the body are handled first. Rooms that
        are unknown to the client or not encrypted are skipped.

        Args:
            body (Dict[Any, Any]): The dictionary of a Sync response.
            ignore_failures (bool): Passed on to pan_decrypt_event().

        Returns the json response with decrypted events.
        """
        logger.info("Decrypting sync")

        self.handle_to_device_from_sync_body(body)

        # The "rooms" and "join" sections may be left out of a sync
        # response, treat them as empty instead of raising a KeyError.
        joined_rooms = body.get("rooms", {}).get("join", {})

        for room_id, room_dict in joined_rooms.items():
            try:
                if not self.rooms[room_id].encrypted:
                    logger.info(
                        "Room {} is not encrypted skipping...".format(
                            self.rooms[room_id].display_name
                        )
                    )
                    continue
            except KeyError:
                logger.info("Unknown room {} skipping...".format(room_id))
                continue

            # A room entry may come without a timeline as well.
            timeline_events = room_dict.get("timeline", {}).get("events", [])

            for event in timeline_events:
                if "type" not in event:
                    continue

                if event["type"] != "m.room.encrypted":
                    continue

                self.pan_decrypt_event(event, room_id, ignore_failures)

        return body

    async def search(self, search_terms):
        # type: (Dict[Any, Any]) -> Dict[Any, Any]
        """Search the local index for room events.

        Args:
            search_terms (Dict[Any, Any]): The search request body; the
                query is read from
                ``search_terms["search_categories"]["room_events"]``.

        Returns a dictionary holding the index results under
        ``["search_categories"]["room_events"]``.

        Raises:
            InvalidLimit: If the result limit is not strictly positive or
                a context limit is negative.
            InvalidOrderByError: If ``order_by`` is neither "rank" nor
                "recent".
        """
        assert INDEXING_ENABLED

        # Room state collected by add_context(), keyed by room id.
        state_cache = dict()

        async def add_context(event_dict, room_id, event_id, include_state):
            """Add the context tokens (and room state) from the server."""
            try:
                context = await self.room_context(room_id, event_id, limit=0)
            except ClientConnectionError:
                # Best effort: the result is returned without server context.
                return

            if isinstance(context, RoomContextError):
                return

            if include_state:
                state_cache[room_id] = [e.source for e in context.state]

            event_dict["context"]["start"] = context.start
            event_dict["context"]["end"] = context.end

        search_terms = search_terms["search_categories"]["room_events"]

        term = search_terms["search_term"]
        search_filter = search_terms["filter"]
        limit = search_filter.get("limit", 10)

        if limit <= 0:
            raise InvalidLimit("The limit must be strictly greater than 0.")

        rooms = search_filter.get("rooms", [])

        # The index search is only restricted to a room when exactly one
        # room was requested.
        room_id = rooms[0] if len(rooms) == 1 else None

        order_by = search_terms.get("order_by")

        if order_by not in ["rank", "recent"]:
            raise InvalidOrderByError(f"Invalid order by: {order_by}")

        order_by_recent = order_by == "recent"

        before_limit = 0
        after_limit = 0
        include_profile = False

        event_context = search_terms.get("event_context")
        include_state = search_terms.get("include_state")

        if event_context:
            before_limit = event_context.get("before_limit", 5)
            # This used to read "before_limit" a second time, ignoring the
            # requested number of events after the match.
            after_limit = event_context.get("after_limit", 5)

        if before_limit < 0 or after_limit < 0:
            raise InvalidLimit(
                "Invalid context limit, the limit must be a " "positive number"
            )

        response_dict = await self.index.search(
            term,
            room=room_id,
            max_results=limit,
            order_by_recent=order_by_recent,
            include_profile=include_profile,
            before_limit=before_limit,
            after_limit=after_limit,
        )

        # Server round trips are only made when the configuration allows
        # search requests.
        if (event_context or include_state) and self.pan_conf.search_requests:
            for event_dict in response_dict["results"]:
                await add_context(
                    event_dict,
                    event_dict["result"]["room_id"],
                    event_dict["result"]["event_id"],
                    include_state,
                )

        if include_state:
            response_dict["state"] = state_cache

        return {"search_categories": {"room_events": response_dict}}