Example #1
0
class WhoIsHook(immp.Hook):
    """
    Hook to provide generic lookup of user profiles across one or more identity providers.

    Unless ``public`` is enabled, the caller must themselves be recognised by at least one of the
    configured providers, and only those providers are consulted for the lookup.
    """

    schema = immp.Schema({
        "identities": [str],
        immp.Optional("public", False): bool
    })

    # Identity providers to query, resolved from the ``identities`` config entry.
    _identities = immp.ConfigProperty([IdentityProvider])

    @command("who", parser=CommandParser.none)
    async def who(self, msg, name):
        """
        Recall a known identity and all of its links.
        """
        if not self.config["public"]:
            # Private mode: only search providers that can identify the caller.
            own = await gather(*(provider.identity_from_user(msg.user)
                                 for provider in self._identities))
            providers = [identity.provider for identity in own if identity]
        else:
            providers = self._identities
        if not providers:
            text = "{} Not identified".format(CROSS)
        else:
            # Look up by the mentioned user if the argument starts with a mention, else by name.
            mentioned = name[0].mention
            if mentioned:
                lookups = [provider.identity_from_user(mentioned) for provider in providers]
            else:
                query = str(name)
                lookups = [provider.identity_from_name(query) for provider in providers]
            links = set()
            for identity in await gather(*lookups):
                if identity:
                    links.update(identity.links)
            if not links:
                text = "{} Name not in use".format(CROSS)
            else:
                text = name.clone()
                for segment in text:
                    segment.bold = True
                text.append(immp.Segment(" may appear as:"))
                for member in sorted(links, key=lambda member: member.plug.network_name):
                    display = member.real_name or member.username
                    if member.link:
                        entry = immp.Segment(display, link=member.link)
                    elif member.real_name and member.username:
                        entry = immp.Segment("{} [{}]".format(member.real_name, member.username))
                    else:
                        entry = immp.Segment(display)
                    text.append(immp.Segment("\n"))
                    text.append(immp.Segment("({}) ".format(member.plug.network_name)))
                    text.append(entry)
        await msg.channel.send(immp.Message(text=text))
Example #2
0
class LocalIdentityHook(immp.Hook, AccessPredicate, IdentityProvider):
    """
    Hook for managing physical users with multiple logical links across different plugs.  This
    effectively provides self-service identities, as opposed to being provided externally.

    Groups are tagged with this hook's ``instance`` code, so several instances can keep separate
    sets of identities in the same database.
    """

    schema = immp.Schema({
        immp.Optional("instance"): immp.Nullable(int),
        "plugs": [str],
        immp.Optional("multiple", True): bool
    })

    # Plugs whose users may create or link identities through this hook.
    _plugs = immp.ConfigProperty([immp.Plug])

    def __init__(self, name, config, host):
        super().__init__(name, config, host)
        # Database handle, populated in start().
        self.db = None

    async def start(self):
        if not self.config["instance"]:
            # Find a non-conflicting number and assign it: the lowest positive integer not already
            # used as the instance code of a hook of this class.
            codes = {
                hook.config["instance"]
                for hook in self.host.hooks.values()
                if isinstance(hook, self.__class__)
            }
            code = 1
            while code in codes:
                code += 1
            log.debug("Assigning instance code %d to hook %r", code, self.name)
            self.config["instance"] = code
        self.db = self.host.resources[DatabaseHook].db
        self.db.create_tables([IdentityGroup, IdentityLink, IdentityRole],
                              safe=True)

    def get(self, name):
        """
        Retrieve the identity group using the given name, within this hook's instance.

        Args:
            name (str):
                Existing name to query.

        Returns:
            .IdentityGroup:
                Linked identity, or ``None`` if not linked.
        """
        try:
            # Scope to this instance, as find() and the commands do, so a same-named group
            # belonging to another instance is never returned.
            return IdentityGroup.select_links().where(
                IdentityGroup.instance == self.config["instance"],
                IdentityGroup.name == name).get()
        except IdentityGroup.DoesNotExist:
            return None

    def find(self, user):
        """
        Retrieve the identity that contains the given user, if one exists.

        Args:
            user (.User):
                Existing user to query.

        Returns:
            .IdentityGroup:
                Linked identity, or ``None`` if not linked.
        """
        # Users from plugs outside this hook's configuration are never identified.
        if not user or user.plug not in self._plugs:
            return None
        try:
            return (IdentityGroup.select_links().where(
                IdentityGroup.instance == self.config["instance"],
                IdentityLink.network == user.plug.network_id,
                IdentityLink.user == user.id).get())
        except IdentityGroup.DoesNotExist:
            return None

    async def channel_access(self, channel, user):
        # Model.get() raises DoesNotExist rather than returning a falsy value, so test for
        # existence explicitly: an unlinked user is denied instead of raising.
        return IdentityLink.select().where(
            IdentityLink.network == user.plug.network_id,
            IdentityLink.user == user.id).exists()

    async def identity_from_name(self, name):
        group = self.get(name)
        return await group.to_identity(self.host, self) if group else None

    async def identity_from_user(self, user):
        group = self.find(user)
        return await group.to_identity(self.host, self) if group else None

    def _test(self, channel, user):
        # Commands are only offered in channels on one of the configured plugs.
        return channel.plug in self._plugs

    @command("id-add", scope=CommandScope.private, test=_test)
    async def add(self, msg, name, pwd):
        """
        Create a new identity, or link to an existing one from a second user.
        """
        if not msg.user or msg.user.plug not in self._plugs:
            return
        if self.find(msg.user):
            text = "{} Already identified".format(CROSS)
        else:
            pwd = IdentityGroup.hash(pwd)
            exists = False
            try:
                group = IdentityGroup.get(instance=self.config["instance"],
                                          name=name)
                exists = True
            except IdentityGroup.DoesNotExist:
                # Unknown name: claim it with the supplied password.
                group = IdentityGroup.create(instance=self.config["instance"],
                                             name=name,
                                             pwd=pwd)
            if exists and not group.pwd == pwd:
                text = "{} Password incorrect".format(CROSS)
            elif not self.config["multiple"] and any(
                    link.network == msg.user.plug.network_id
                    for link in group.links):
                # Only one linked user per network is permitted in this mode.
                text = "{} Already identified on {}".format(
                    CROSS, msg.user.plug.network_name)
            else:
                IdentityLink.create(group=group,
                                    network=msg.user.plug.network_id,
                                    user=msg.user.id)
                text = "{} {}".format(TICK, "Added" if exists else "Claimed")
        await msg.channel.send(immp.Message(text=text))

    @command("id-rename", scope=CommandScope.private, test=_test)
    async def rename(self, msg, name):
        """
        Rename the current identity.
        """
        if not msg.user:
            return
        group = self.find(msg.user)
        if not group:
            text = "{} Not identified".format(CROSS)
        elif group.name == name:
            text = "{} No change".format(TICK)
        elif IdentityGroup.select().where(
                IdentityGroup.instance == self.config["instance"],
                IdentityGroup.name == name).exists():
            text = "{} Name already in use".format(CROSS)
        else:
            group.name = name
            group.save()
            text = "{} Claimed".format(TICK)
        await msg.channel.send(immp.Message(text=text))

    @command("id-password", scope=CommandScope.private, test=_test)
    async def password(self, msg, pwd):
        """
        Update the password for the current identity.
        """
        if not msg.user:
            return
        group = self.find(msg.user)
        if not group:
            text = "{} Not identified".format(CROSS)
        else:
            group.pwd = IdentityGroup.hash(pwd)
            group.save()
            text = "{} Changed".format(TICK)
        await msg.channel.send(immp.Message(text=text))

    @command("id-reset", scope=CommandScope.private, test=_test)
    async def reset(self, msg):
        """
        Delete the current identity and all linked users.
        """
        if not msg.user:
            return
        group = self.find(msg.user)
        if not group:
            text = "{} Not identified".format(CROSS)
        else:
            group.delete_instance()
            text = "{} Reset".format(TICK)
        await msg.channel.send(immp.Message(text=text))

    @command("id-role",
             scope=CommandScope.private,
             role=CommandRole.admin,
             test=_test)
    async def role(self, msg, name, role=None):
        """
        List roles assigned to an identity, or add/remove a given role.
        """
        try:
            group = IdentityGroup.get(instance=self.config["instance"],
                                      name=name)
        except IdentityGroup.DoesNotExist:
            text = "{} Name not registered".format(CROSS)
        else:
            if role:
                # Toggle: remove the role if the group has it, otherwise add it.
                count = IdentityRole.delete().where(
                    IdentityRole.group == group,
                    IdentityRole.role == role).execute()
                if count:
                    text = "{} Removed".format(TICK)
                else:
                    IdentityRole.create(group=group, role=role)
                    text = "{} Added".format(TICK)
            else:
                roles = IdentityRole.select().where(
                    IdentityRole.group == group)
                if roles:
                    labels = [role.role for role in roles]
                    text = "Roles for {}: {}".format(name, ", ".join(labels))
                else:
                    text = "No roles for {}.".format(name)
        await msg.channel.send(immp.Message(text=text))
Example #3
0
class ChannelAccessHook(immp.Hook, AccessPredicate):
    """
    Hook for controlling membership of, and joins to, secure channels.

    Each configured predicate hook is asked to allow or deny user-channel pairs within its own
    channels.  A denial from any hook wins, then an explicit allow, then the ``default`` setting.
    Denied users are removed from the channel unless ``passive`` is set.
    """

    schema = immp.Schema({immp.Optional("hooks", dict): {str: immp.Nullable([str])},
                          immp.Optional("exclude", dict): {str: [str]},
                          immp.Optional("joins", True): bool,
                          immp.Optional("startup", False): bool,
                          immp.Optional("passive", False): bool,
                          immp.Optional("default", True): bool})

    # Predicate hooks mapped to the channels they control.  An empty/null list means the hook's
    # own access_channels() list is used instead.
    hooks = immp.ConfigProperty({AccessPredicate: [immp.Channel]})

    @property
    def channels(self):
        """
        Inverse of :attr:`hooks`: mapping from each channel to the predicate hooks explicitly
        configured for it.  Hooks without a configured channel list are omitted.
        """
        inverse = defaultdict(list)
        for hook, channels in self.hooks.items():
            if not channels:
                continue
            for channel in channels:
                inverse[channel].append(hook)
        return inverse

    # This hook acts as an example predicate to block all access.

    async def channel_access_multi(self, channels, users=None):
        """
        Deny every requested user-channel pair.

        Args:
            channels ((.Channel, .User set) dict):
                Mapping from channels to users pending verification, as passed by
                :meth:`verify`.  For backwards compatibility, an iterable of channels may be
                given instead, together with ``users``.
            users (.User list):
                Users to pair with every channel; only used when ``channels`` is not a mapping.

        Returns:
            (list, list):
                An empty allowed list, and the list of denied ``(channel, user)`` pairs.
        """
        if users is None:
            denied = [(channel, user)
                      for channel, pending in channels.items()
                      for user in pending]
        else:
            denied = list(product(channels, users))
        return [], denied

    async def channel_access(self, channel, user):
        return False

    async def verify(self, members=None):
        """
        Perform verification of each user in each channel, for all configured access predicates.
        Users who are denied access by any predicate will be removed, unless passive mode is set.

        Args:
            members ((.Channel, .User set) dict):
                Mapping from target channels to a subset of users pending verification.

                If ``None`` is given for a channel's set of users, all members present in the
                channel will be verified.  If ``members`` itself is ``None``, access checks will be
                run against all configured channels.
        """
        # Work out which channels each predicate hook is responsible for.
        everywhere = set()
        grouped = {}
        for hook, scope in self.hooks.items():
            interested = await hook.access_channels()
            if scope and interested:
                log.debug("Hook %r using scope and own list", hook)
                wanted = set(interested).intersection(scope)
            elif scope or interested:
                log.debug("Hook %r using %s", hook, "scope" if scope else "own list")
                wanted = set(scope or interested)
            else:
                log.warning("Hook %r has no declared channels for access control", hook)
                continue
            if members is not None:
                wanted.intersection_update(members)
            if wanted:
                everywhere.update(wanted)
                grouped[hook] = wanted
            else:
                log.debug("Skipping hook %r as member filter doesn't overlap", hook)
        # Collect the users to check in each channel, skipping non-members, excluded and system
        # users.
        targets = defaultdict(set)
        members = members or {}
        for channel in everywhere:
            users = members.get(channel)
            try:
                current = await channel.members()
            except Exception:
                log.warning("Failed to retrieve members for channel %r", channel, exc_info=True)
                continue
            for user in users or current or ():
                if current and user not in current:
                    log.debug("Skipping non-member user %r", user)
                elif user.id in self.config["exclude"].get(user.plug.name, []):
                    log.debug("Skipping excluded user %r", user)
                elif await user.is_system():
                    log.debug("Skipping system user %r", user)
                else:
                    targets[channel].add(user)
        # Ask every hook for decisions in parallel.
        hooks = []
        tasks = []
        for hook, channels in grouped.items():
            # Only ask each hook about channels within its own scope; otherwise a hook
            # responsible for one channel would get to deny the users of every other channel.
            known = {channel: targets[channel] for channel in channels if targets.get(channel)}
            if not known:
                log.debug("No users to verify for hook %r", hook)
                continue
            log.debug("Requesting decisions from %r: %r", hook, set(known))
            hooks.append(hook)
            tasks.append(ensure_future(hook.channel_access_multi(known)))
        allowed = set()
        denied = set()
        for hook, result in zip(hooks, await gather(*tasks, return_exceptions=True)):
            if isinstance(result, Exception):
                log.warning("Failed to verify channel access with hook %r",
                            hook.name, exc_info=result)
                continue
            hook_allowed, hook_denied = result
            allowed.update(hook_allowed)
            if hook_denied:
                log.debug("Hook %r denied %d user-channel pair(s)", hook.name, len(hook_denied))
                denied.update(hook_denied)
        # Resolve each pair: any denial wins, then an explicit allow, then the default.
        removals = defaultdict(set)
        for channel, users in targets.items():
            for user in users:
                pair = (channel, user)
                if pair in denied:
                    allow = False
                elif pair in allowed:
                    allow = True
                else:
                    allow = self.config["default"]
                if allow:
                    log.debug("Allowing access to %r for %r", channel, user)
                else:
                    log.debug("Denying access to %r for %r", channel, user)
                    removals[channel].add(user)
        active = not self.config["passive"]
        for channel, refused in removals.items():
            log.debug("%s %d user(s) from %r: %r", "Removing" if active else "Would remove",
                      len(refused), channel, refused)
            if active:
                await channel.remove_multi(refused)

    async def _startup_check(self):
        log.debug("Running startup access checks")
        await self.verify()
        log.debug("Finished startup access checks")

    def on_ready(self):
        if self.config["startup"]:
            # Keep a reference to the task so it can't be garbage-collected mid-run.
            self._startup_task = ensure_future(self._startup_check())

    async def on_receive(self, sent, source, primary):
        await super().on_receive(sent, source, primary)
        # Verify only the users who just joined, in the channel they joined.
        if self.config["joins"] and primary and sent == source and source.joined:
            await self.verify({sent.channel: source.joined})
Example #4
0
class _SyncHookBase(immp.Hook):
    """
    Configuration and message-rewriting helpers shared by the sync and forward hooks.
    """

    schema = immp.Schema({"channels": {str: [str]},
                          immp.Optional("joins", False): bool,
                          immp.Optional("renames", False): bool,
                          immp.Optional("identities"): immp.Nullable(str),
                          immp.Optional("reset-author", False): bool,
                          immp.Optional("name-format"): immp.Nullable(str),
                          immp.Optional("strip-name-emoji", False): bool})

    # Optional identity provider, used to annotate author names and translate mentions.
    _identities = immp.ConfigProperty(IdentityProvider)

    def _accept(self, msg, id_=None):
        """
        Decide whether a message should be relayed, based on the ``joins`` and ``renames``
        settings.

        Args:
            msg (.Message):
                Message under consideration.
            id_ (str):
                Identifier of the message, used for logging.  Callers may pass the ID of the
                received copy, since the message object itself may not carry one; when omitted,
                ``msg.id`` is used if present.

        Returns:
            bool:
                ``True`` if the message should be synced.
        """
        if id_ is None:
            id_ = getattr(msg, "id", None)
        if not self.config["joins"] and (msg.joined or msg.left):
            log.debug("Not syncing join/part message: %r", id_)
            return False
        if not self.config["renames"] and msg.title:
            log.debug("Not syncing rename message: %r", id_)
            return False
        return True

    async def _replace_recurse(self, msg, func, *args):
        """
        Switch out the reply and any message attachments for the result of ``func``.

        Args:
            msg (.Message):
                Message whose reply/attachments are replaced in place.
            func (coroutine function):
                Called as ``await func(inner, *args)`` for each nested message; its result
                replaces that message.

        Returns:
            .Message:
                The same ``msg`` object, after replacement.
        """
        # Callers await this helper with coroutine replacers, so results must be awaited here
        # rather than storing un-awaited coroutines on the message.
        if msg.reply_to:
            msg.reply_to = await func(msg.reply_to, *args)
        attachments = []
        for attach in msg.attachments:
            if isinstance(attach, immp.Message):
                attachments.append(await func(attach, *args))
            else:
                attachments.append(attach)
        msg.attachments = attachments
        return msg

    async def _alter_recurse(self, msg, func, *args):
        # Alter properties on existing cloned message objects: the message itself, its reply,
        # and any message attachments.
        await func(msg, *args)
        if msg.reply_to:
            await func(msg.reply_to, *args)
        for attach in msg.attachments:
            if isinstance(attach, immp.Message):
                await func(attach, *args)

    async def _rename_user(self, user, channel):
        """
        Use ``name-format`` or identities to render a suitable author real name.

        Args:
            user (.User):
                Original author; may be ``None``.
            channel (.Channel):
                Channel whose title fills ``channel`` in the name template, or ``None``.

        Returns:
            .User:
                The original user when no renaming applies, a copy with a replaced real name,
                a new name-only user when ``reset-author`` is set, or ``None`` when the author is
                dropped entirely.

        Raises:
            immp.PlugError:
                When ``name-format`` or ``strip-name-emoji`` is used without its optional module
                installed.
        """
        renamed = name = identity = None
        if not self.config["reset-author"]:
            renamed = user
        if self._identities:
            try:
                identity = await self._identities.identity_from_user(user)
            except Exception as e:
                log.warning("Failed to retrieve identity information for %r", user,
                            exc_info=e)
        if self.config["name-format"]:
            if not Template:
                raise immp.PlugError("'jinja2' module not installed")
            title = await channel.title() if channel else None
            context = {"user": user, "identity": identity, "channel": title}
            name = Template(self.config["name-format"]).render(**context)
            if not name and self.config["reset-author"]:
                user = None
        elif self.config["reset-author"]:
            user = None
        elif identity:
            name = "{} ({})".format(user.real_name or user.username, identity.name)
        elif self.config["strip-name-emoji"] and user:
            name = user.real_name or user.username
        if not name:
            return user
        if self.config["strip-name-emoji"]:
            if not EMOJI_REGEX:
                raise immp.PlugError("'emoji' module not installed")
            name = EMOJI_REGEX.sub(_emoji_replace, name).strip()
        if renamed:
            log.debug("Replacing real name: %r -> %r", renamed.real_name, name)
            # Copy so the original user object isn't modified.
            renamed = copy(renamed)
            renamed.real_name = name
        else:
            log.debug("Adding real name: %r", name)
            renamed = immp.User(real_name=name)
        return renamed

    async def _alter_name(self, msg):
        # Only receipts carry a channel to feed into the name template.
        channel = msg.channel if isinstance(msg, immp.Receipt) else None
        msg.user = await self._rename_user(msg.user, channel)

    async def _alter_identities(self, msg, channel):
        """
        Replace mentions of identified users with their linked user on the target plug.

        Mentions with no link on ``channel``'s plug are dropped, falling back to a profile link
        where the user has one.

        Args:
            msg (.Message):
                Message whose text is cloned and rewritten in place.
            channel (.Channel):
                Target channel the message is about to be sent to.
        """
        if not msg.text:
            return
        msg.text = msg.text.clone()
        for segment in msg.text:
            user = segment.mention
            if not user or user.plug == channel.plug:
                # No mention or already matches plug, nothing to do.
                continue
            identity = None
            if self._identities:
                try:
                    identity = await self._identities.identity_from_user(user)
                except Exception as e:
                    log.warning("Failed to retrieve identity information for %r", user, exc_info=e)
            # Try to find an identity link corresponding to the target plug.  Use a separate
            # loop variable so the original mentioned user stays available below.
            links = identity.links if identity else []
            for link in links:
                if link.plug == channel.plug:
                    log.debug("Replacing mention: %r -> %r", user, link)
                    segment.mention = link
                    break
            else:
                # Fallback case: replace mention with a link to the user's profile.
                if user.link:
                    log.debug("Adding fallback mention link: %r -> %r", user, user.link)
                    segment.link = user.link
                else:
                    log.debug("Removing foreign mention: %r", user)
                segment.mention = None
            # Perform name substitution on the mention text.
            if self.config["name-format"]:
                at = "@" if segment.text.startswith("@") else ""
                renamed = await self._rename_user(user, channel)
                # The author may be dropped or rendered empty; keep the existing text then.
                if renamed and renamed.real_name:
                    segment.text = "{}{}".format(at, renamed.real_name)

    async def _send(self, channel, msg):
        # Send to one channel, returning (channel, receipts); failures yield an empty list so
        # one broken channel doesn't abort a parallel send.
        try:
            ids = await channel.send(msg)
            log.debug("Synced IDs in %r: %r", channel, ids)
            return (channel, ids)
        except Exception:
            log.exception("Failed to relay message to channel: %r", channel)
            return (channel, [])
Example #5
0
class ForwardHook(_SyncHookBase):
    """
    Hook to propagate messages from a source channel to one or more destination channels.

    If ``users`` is set, only messages from those user IDs are forwarded.
    """

    schema = immp.Schema(
        {
            immp.Optional("users"): immp.Nullable([str]),
            immp.Optional("groups", dict): {
                str: [str]
            }
        }, _SyncHookBase.schema)

    # Direct mapping from source channels to their destination channels.
    _channels = immp.ConfigProperty({immp.Channel: [immp.Channel]})
    # Mapping from groups to destination channels, applied to any source channel in the group.
    _groups = immp.ConfigProperty({immp.Group: [immp.Channel]})

    async def _targets(self, channel):
        # Union of direct destinations and those of every group containing the channel.
        targets = set()
        if channel in self._channels:
            targets.update(self._channels[channel])
        for group, channels in self._groups.items():
            if await group.has_channel(channel):
                targets.update(channels)
        return targets

    async def send(self, msg, channels):
        """
        Send a message to all channels in this forwarding group.

        Args:
            msg (.Message):
                External message to push.
            channels (.Channel iterable):
                Target channels to forward the message to.
        """
        queue = []
        # Author renaming is the same for every target, so do it once up front.
        clone = msg.clone()
        await self._alter_recurse(clone, self._alter_name)
        for synced in channels:
            # Mention rewriting depends on the target plug, so it's done per channel.
            local = clone.clone()
            await self._alter_recurse(local, self._alter_identities, synced)
            queue.append(self._send(synced, local))
        # Send all the messages in parallel.
        await gather(*queue)

    def _accept(self, msg, id_):
        if not super()._accept(msg, id_):
            return False
        if self.config["users"] is not None:
            if not msg.user or msg.user.id not in self.config["users"]:
                # The message may have no author at all, so don't dereference it blindly.
                log.debug("Not syncing message from non-whitelisted user: %r",
                          msg.user.id if msg.user else None)
                return False
        return True

    async def on_receive(self, sent, source, primary):
        await super().on_receive(sent, source, primary)
        if not primary or not self._accept(source, sent.id):
            return
        targets = await self._targets(sent.channel)
        if targets:
            await self.send(source, targets)
Example #6
0
class SyncHook(_SyncHookBase):
    """
    Hook to propagate messages between two or more channels.

    Attributes:
        plug (.SyncPlug):
            Virtual plug for this sync, if configured.
    """

    # Sync-specific options layered over the shared base schema.  The joins and renames entries
    # here default to True, whereas the base schema defaults them to False; presumably these
    # entries take precedence over the base ones -- confirm against immp.Schema merging.
    schema = immp.Schema(
        {
            immp.Optional("edits", True): bool,
            immp.Optional("joins", True): bool,
            immp.Optional("renames", True): bool,
            immp.Optional("plug"): immp.Nullable(str),
            immp.Optional("titles", dict): {
                str: str
            }
        }, _SyncHookBase.schema)

    # Author used for this hook's own command replies (sync-members, sync-list).
    user = immp.User(real_name="Sync", suggested=True)

    def __init__(self, name, config, host):
        super().__init__(name, config, host)
        # IDs of every synced message, grouped by channel.
        self._cache = SyncCache(self)
        # Held while a send is in progress, so retrieving messages waits for IDs to be cached.
        self._lock = BoundedSemaphore()
        virtual = self.config["plug"]
        if not virtual:
            self.plug = None
        else:
            # Expose the sync as a virtual plug on the host, with one channel per sync label,
            # for external subscribers.
            log.debug("Creating virtual plug: %r", virtual)
            self.plug = SyncPlug(virtual, self, host)
            host.add_plug(self.plug)
            for label in self.config["channels"]:
                host.add_channel(label, immp.Channel(self.plug, label))
        # Whether the async database is available; set properly in on_load().
        self._db = False

    def on_load(self):
        # Register the back-reference model with the async database hook, if one is loaded,
        # and record whether that succeeded.
        try:
            self.host.resources[AsyncDatabaseHook].add_models(SyncBackRef)
        except KeyError:
            available = False
        else:
            available = True
        self._db = available

    # Sync label -> list of channels bridged under that label (from the "channels" config).
    channels = immp.ConfigProperty({None: [immp.Channel]})

    def label_for_channel(self, channel):
        """
        Find the sync label that a channel is bridged under.

        Args:
            channel (.Channel):
                Bridged channel to look up.

        Returns:
            str:
                The single label whose channel list contains ``channel``.

        Raises:
            immp.ConfigError:
                If the channel isn't bridged, or appears under more than one label.
        """
        matches = [label for label, bridged in self.channels.items() if channel in bridged]
        if not matches:
            raise immp.ConfigError("Channel {} not bridged".format(
                repr(channel)))
        if len(matches) > 1:
            raise immp.ConfigError("Channel {} defined more than once".format(
                repr(channel)))
        return matches[0]

    def _test(self, channel, user):
        # Commands are offered only in channels that belong to one of the sync's bridges.
        for bridged in self.channels.values():
            if channel in bridged:
                return True
        return False

    @command("sync-members", test=_test)
    async def members(self, msg):
        """
        List all members of the current conversation, across all channels.
        """
        # _test() only admits real bridged channels, whose source is a plug-specific ID rather
        # than a sync label, so resolve the label; the virtual sync channel (if any) already
        # uses the label as its source.
        if self.plug is not None and msg.channel.plug == self.plug:
            label = msg.channel.source
        else:
            label = self.label_for_channel(msg.channel)
        members = defaultdict(list)
        missing = False
        for synced in self.channels[label]:
            local = (await synced.members())
            if local:
                members[synced.plug.network_name] += local
            else:
                # No member list for this channel, so the output will be partial.
                missing = True
        if not members:
            return
        text = immp.RichText([immp.Segment("Members of this conversation:")])
        for network in sorted(members):
            text.append(immp.Segment("\n{}".format(network), bold=True))
            for member in sorted(
                    members[network],
                    key=lambda member: member.real_name or member.username):
                name = member.real_name or member.username
                text.append(immp.Segment("\n"))
                if member.link:
                    text.append(immp.Segment(name, link=member.link))
                elif member.real_name and member.username:
                    text.append(
                        immp.Segment("{} [{}]".format(name, member.username)))
                else:
                    text.append(immp.Segment(name))
        if missing:
            text.append(immp.Segment("\n"),
                        immp.Segment("(list may be incomplete)"))
        await msg.channel.send(immp.Message(user=self.user, text=text))

    @command("sync-list", test=_test)
    async def list(self, msg):
        """
        List all channels connected to this conversation.
        """
        # _test() only admits real bridged channels, whose source is a plug-specific ID rather
        # than a sync label, so resolve the label; the virtual sync channel (if any) already
        # uses the label as its source.
        if self.plug is not None and msg.channel.plug == self.plug:
            label = msg.channel.source
        else:
            label = self.label_for_channel(msg.channel)
        text = immp.RichText([immp.Segment("Channels in this sync:")])
        for synced in self.channels[label]:
            text.append(immp.Segment("\n{}".format(synced.plug.network_name)))
            title = await synced.title()
            if title:
                text.append(immp.Segment(": {}".format(title)))
        await msg.channel.send(immp.Message(user=self.user, text=text))

    async def send(self, label, msg, origin=None, ref=None, update=False):
        """
        Send a message to all channels in this synced group.

        Args:
            label (str):
                Bridge that defines the underlying synced channels to send to.
            msg (.Message):
                External message to push.  This should be the source copy when syncing a message
                from another channel.
            origin (.Receipt):
                Raw message that triggered this sync; if set and part of the sync, it will be
                skipped (used to avoid retransmitting a message we just received).  This should be
                the plug-native copy of a message when syncing from another channel.
            ref (.SyncRef):
                Existing sync reference, if message has been partially synced.
            update (bool):
                ``True`` to force resending an updated message to all synced channels.

        Returns:
            .SyncRef:
                Reference holding the IDs of every synced copy (``ref`` itself if given).
        """
        base = immp.Message(text=msg.text,
                            user=msg.user,
                            edited=msg.edited,
                            action=msg.action,
                            reply_to=msg.reply_to,
                            joined=msg.joined,
                            left=msg.left,
                            title=msg.title,
                            attachments=msg.attachments,
                            raw=msg)
        queue = []
        for synced in self.channels[label]:
            if origin and synced == origin.channel:
                continue
            # Use .get(): the reference may not have an entry for every channel yet, and a
            # missing channel simply means it still needs syncing.
            elif not update and ref and ref.ids.get(synced):
                log.debug("Skipping already-synced target channel %r: %r",
                          synced, ref)
                continue
            # Per-channel copy: replies/attachments, mentions and author are all rewritten for
            # the target channel.
            local = base.clone()
            await self._replace_recurse(local, self._replace_ref, synced)
            await self._alter_recurse(local, self._alter_identities, synced)
            await self._alter_recurse(local, self._alter_name)
            queue.append(self._send(synced, local))
        # Just like with plugs, when sending a new (external) message to all channels in a sync, we
        # need to wait for all plugs to complete and have their IDs cached before processing any
        # further messages.
        async with self._lock:
            all_receipts = dict(await gather(*queue))
            ids = {
                channel: [receipt.id for receipt in receipts]
                for channel, receipts in all_receipts.items()
            }
            if ref:
                ref.ids.update(ids)
            else:
                ref = SyncRef(ids, source=msg, origin=origin)
            await self._cache.add(ref)
        # Push a copy of the message to the sync channel, if running.
        if self.plug:
            sent = immp.SentMessage(id_=ref.key,
                                    channel=immp.Channel(self.plug, label),
                                    text=msg.text,
                                    user=msg.user,
                                    action=msg.action,
                                    reply_to=msg.reply_to,
                                    joined=msg.joined,
                                    left=msg.left,
                                    title=msg.title,
                                    attachments=msg.attachments,
                                    raw=msg)
            self.plug.queue(sent)
        return ref

    async def delete(self, ref, sent=None):
        """
        Delete every synced copy of a message recorded in a sync reference.

        Args:
            ref: sync reference whose ``ids`` maps each channel to a list of message IDs.
            sent: optional message to leave in place; the copy matching both its channel and its
                ID is skipped.
        """
        def excluded(channel, id_):
            return bool(sent) and sent.channel == channel and sent.id == id_

        deletions = [immp.Receipt(id_, channel).delete()
                     for channel, ids in ref.ids.items()
                     for id_ in ids
                     if not excluded(channel, id_)]
        if deletions:
            await gather(*deletions)

    async def _replace_ref(self, msg, channel):
        """
        Rewrite a referenced message (reply target or message attachment) for a target channel.

        Args:
            msg: the referenced message taken from an outgoing message.
            channel (immp.Channel): the sync target the outgoing message is being sent to.

        Returns:
            * ``msg`` itself if it is not an :class:`immp.Receipt`, or if the sync cache knows
              it but holds no copy in ``channel`` and ``msg`` shares ``channel``'s plug;
            * an :class:`immp.SentMessage` for the first synced ID in ``channel`` when the cache
              holds one, carrying the cached source's content (or ``msg``'s if there is none);
            * otherwise a content-only :class:`immp.Message` copy of ``msg``, or ``None`` when
              ``msg`` is a receipt but not an :class:`immp.SentMessage`.
        """
        if not isinstance(msg, immp.Receipt):
            log.debug("Not replacing non-receipt message: %r", msg)
            return msg
        # Content-only fallback, detached from any particular channel.
        detached = None
        if isinstance(msg, immp.SentMessage):
            detached = immp.Message(text=msg.text,
                                    user=msg.user,
                                    action=msg.action,
                                    reply_to=msg.reply_to,
                                    joined=msg.joined,
                                    left=msg.left,
                                    title=msg.title,
                                    attachments=msg.attachments,
                                    raw=msg.raw)
        try:
            ref = await self._cache.get(msg)
        except KeyError:
            log.debug("No match for source message: %r", msg)
            return detached
        if not ref.ids.get(channel):
            # No copy in the target channel: keep same-plug references, detach the rest.
            if channel.plug == msg.channel.plug:
                log.debug("Referenced message has origin plug, not modifying: %r",
                          msg)
                return msg
            log.debug(
                "Origin message not referenced in the target channel: %r", msg)
            return detached
        # The referenced message already has a copy in this channel: point at that copy.
        log.debug("Found reference to previously synced message: %r",
                  ref.key)
        origin = ref.source
        at = origin.at if isinstance(origin, immp.Receipt) else None
        best = origin or msg
        return immp.SentMessage(id_=ref.ids[channel][0],
                                channel=channel,
                                at=at,
                                text=best.text,
                                user=best.user,
                                action=best.action,
                                reply_to=best.reply_to,
                                joined=best.joined,
                                left=best.left,
                                title=best.title,
                                attachments=best.attachments,
                                raw=best.raw)

    async def on_receive(self, sent, source, primary):
        """
        Relay a received (new, edited or deleted) message to the other channels of its sync.

        Args:
            sent: the message as received by a plug; used for the cache lookup, channel label and
                edit/delete flags.
            source: the message content; checked by ``_accept`` and forwarded to :meth:`send`.
            primary: only events with this flag set are processed; others are ignored here
                after the parent handler runs.
        """
        await super().on_receive(sent, source, primary)
        # Skip non-primary events and message types the config excludes (joins/parts, renames).
        if not primary or not self._accept(source, sent.id):
            return
        try:
            label = self.label_for_channel(sent.channel)
        except immp.ConfigError:
            # Channel is not part of any sync.
            return
        async with self._lock:
            # No critical section here, just wait for any pending messages to be sent.
            pass
        ref = None
        update = False
        try:
            ref = await self._cache.get(sent)
        except KeyError:
            # Not seen before: a fresh message (unless it is a delete we can't map to copies).
            if sent.deleted:
                log.debug("Ignoring deleted message not in sync cache: %r",
                          sent)
                return
            else:
                log.debug("Incoming message not in sync cache: %r", sent)
        else:
            if sent.deleted:
                # With "edits" enabled, deletions are mirrored to every synced copy.
                if self.config["edits"]:
                    log.debug("Incoming message is a delete, needs sync: %r",
                              sent)
                    await self.delete(ref)
                else:
                    log.debug("Ignoring deleted message: %r", sent)
                return
            # NOTE(review): ref.revision(sent) is only evaluated when the first operand is false;
            # its exact semantics (and whether it records the revision) are defined elsewhere.
            elif (sent.edited and not ref.revisions) or ref.revision(sent):
                if self.config["edits"]:
                    log.debug("Incoming message is an update, needs sync: %r",
                              sent)
                    update = True
                else:
                    log.debug("Ignoring updated message: %r", sent)
                    return
            elif all(ref.ids[channel] for channel in self.channels[label]):
                log.debug("Incoming message already synced: %r", sent)
                return
            else:
                # send() skips channels that already have IDs for this ref when not updating.
                log.debug("Incoming message partially synced: %r", sent)
        log.debug("Sending message to synced channel %r: %r", label, sent.id)
        await self.send(label, source, sent, ref, update)
Example #7
0
class _SyncHookBase(immp.Hook):
    """
    Common base for sync hooks.

    Declares the shared configuration schema and provides helpers to filter incoming messages,
    rewrite author names and user mentions for a target channel, and send copies to a channel
    without letting one failure abort the rest.
    """

    # Settings that may also be set per plug, under ``plugs.<plug name>`` (see _plug_config).
    _override_config = {
        immp.Optional("joins", immp.Optional.MISSING): bool,
        immp.Optional("renames", immp.Optional.MISSING): bool,
        immp.Optional("reset-author", immp.Optional.MISSING): bool,
        immp.Optional("name-format", immp.Optional.MISSING):
        immp.Nullable(str),
        immp.Optional("strip-name-emoji", immp.Optional.MISSING): bool
    }

    schema = immp.Schema({
        "channels": {
            str: [str]
        },
        # Per-plug overrides of the _override_config keys, keyed by plug name.
        immp.Optional("plugs", dict): {
            str: _override_config
        },
        immp.Optional("joins", False): bool,
        immp.Optional("renames", False): bool,
        immp.Optional("identities"): immp.Nullable(str),
        immp.Optional("reset-author", False): bool,
        immp.Optional("name-format"): immp.Nullable(str),
        immp.Optional("strip-name-emoji", False): bool
    })

    # NOTE(review): presumably resolves the provider named by the "identities" config key,
    # or is empty/None when that key is unset -- confirm against immp.ConfigProperty.
    _identities = immp.ConfigProperty(IdentityProvider)

    def _plug_config(self, channel):
        """
        Compute the effective overridable settings for a channel's plug.

        Args:
            channel (immp.Channel): channel whose plug overrides apply, or ``None`` to get the
                global values only.

        Returns:
            dict: each key of ``_override_config`` mapped to its global config value, replaced
            by the value under ``plugs.<plug name>`` where one is set.
        """
        keys = tuple(
            immp.Optional.unwrap(key)[0] for key in self._override_config)
        config = {key: self.config[key] for key in keys}
        if channel and channel.plug:
            override = self.config["plugs"].get(channel.plug.name) or {}
            config.update(
                {key: override[key]
                 for key in keys if key in override})
        return config

    def _accept(self, msg, id_):
        """
        Decide whether a message should be synced at all.

        Args:
            msg (immp.Message): candidate message; per-plug overrides apply if it is a receipt.
            id_: message ID, used only for logging.

        Returns:
            bool: ``False`` for join/part messages when ``joins`` is off, or for rename messages
            (those with a ``title``) when ``renames`` is off; ``True`` otherwise.
        """
        config = self._plug_config(
            msg.channel if isinstance(msg, immp.Receipt) else None)
        if not config["joins"] and (msg.joined or msg.left):
            log.debug("Not syncing join/part message: %r", id_)
            return False
        if not config["renames"] and msg.title:
            log.debug("Not syncing rename message: %r", id_)
            return False
        return True

    async def _replace_recurse(self, msg, func, *args):
        """
        Replace a message's reply target and message attachments with ``func``'s results.

        ``func(message, *args)`` is awaited for the reply target (if any) and for each attachment
        that is an :class:`immp.Message`; its return value takes that message's place.  Other
        attachments are kept as-is.  ``msg`` is modified in place and also returned.
        """
        if msg.reply_to:
            msg.reply_to = await func(msg.reply_to, *args)
        attachments = []
        for attach in msg.attachments:
            if isinstance(attach, immp.Message):
                attachments.append(await func(attach, *args))
            else:
                attachments.append(attach)
        msg.attachments = attachments
        return msg

    async def _alter_recurse(self, msg, func, *args):
        """
        Await ``func(message, *args)`` on a message, its reply target and its message attachments.

        Return values are ignored: ``func`` is expected to modify the (already cloned) messages
        in place.
        """
        await func(msg, *args)
        if msg.reply_to:
            await func(msg.reply_to, *args)
        for attach in msg.attachments:
            if isinstance(attach, immp.Message):
                await func(attach, *args)

    async def _rename_user(self, user, channel):
        """
        Build the author to display for a message relayed into ``channel``.

        The new name comes from the ``name-format`` Jinja2 template when set (context:
        ``user``, ``identity`` and the ``channel`` title).  Otherwise, if the identity provider
        knows the user, it is "<name> (<identity name>)".  With ``strip-name-emoji``, emoji are
        then removed from that name (or from the user's own name if no new one was produced).

        Args:
            user (immp.User): original author, may be ``None``.
            channel (immp.Channel): target channel, may be ``None``.

        Returns:
            The original ``user`` if no new name was produced.  Otherwise a new
            :class:`immp.User` carrying only the name when ``reset-author`` is set or there was
            no user, else a copy of ``user`` with the new real name (and no username when the
            template was used, so the formatted name is always shown).

        Raises:
            immp.PlugError: if ``name-format`` or ``strip-name-emoji`` is enabled but the
                required module (jinja2 or emoji) is not installed.
        """
        config = self._plug_config(channel)
        base = (user.real_name or user.username) if user else None
        name = None
        identity = None
        force = False
        if user and self._identities:
            try:
                identity = await self._identities.identity_from_user(user)
            except Exception:
                # Best effort: continue without identity information.
                log.warning("Failed to retrieve identity information for %r",
                            user,
                            exc_info=True)
        if config["name-format"]:
            if not Template:
                raise immp.PlugError("'jinja2' module not installed")
            title = await channel.title() if channel else None
            context = {"user": user, "identity": identity, "channel": title}
            try:
                name = Template(config["name-format"]).render(**context)
            except TemplateError:
                log.warning("Bad name format template", exc_info=True)
            else:
                # Remove the user's username, so that this name is always used.
                force = True
        elif identity:
            name = "{} ({})".format(base, identity.name)
        if config["strip-name-emoji"]:
            if not EMOJI_REGEX:
                raise immp.PlugError("'emoji' module not installed")
            current = name or base
            if current:
                name = EMOJI_REGEX.sub(_emoji_replace, current).strip()
        if not name:
            return user
        elif config["reset-author"] or not user:
            log.debug("Creating unlinked user with real name: %r", name)
            return immp.User(real_name=name,
                             suggested=(user.suggested if user else False))
        else:
            log.debug("Copying user with new real name: %r -> %r", user, name)
            return immp.User(id_=user.id,
                             plug=user.plug,
                             real_name=name,
                             username=(None if force else user.username),
                             avatar=user.avatar,
                             link=user.link,
                             suggested=user.suggested)

    async def _alter_name(self, msg):
        """
        Replace a message's author with the renamed user for the message's own channel.

        The channel is taken from ``msg`` when it is a receipt, else ``None``.
        """
        channel = msg.channel if isinstance(msg, immp.Receipt) else None
        msg.user = await self._rename_user(msg.user, channel)

    async def _alter_identities(self, msg, channel):
        """
        Rewrite user mentions in a message's text for a target channel.

        Each mention of a user from a different plug is pointed at the same person's account on
        ``channel``'s plug when the identity provider links one.  If ``name-format`` applies to
        ``channel``, the mention text is also replaced by the formatted name of the mentioned
        user, keeping a leading ``@`` if there was one.

        Args:
            msg (immp.Message): message whose ``text`` is cloned and then modified.
            channel (immp.Channel): target channel the message is being sent to.
        """
        if not msg.text:
            return
        # Clone first so segment edits don't affect the text object this was copied from.
        msg.text = msg.text.clone()
        # The per-plug config is the same for every segment.
        config = self._plug_config(channel)
        for segment in msg.text:
            user = segment.mention
            if not user or user.plug == channel.plug:
                # No mention or already matches plug, nothing to do.
                continue
            identity = None
            if self._identities:
                try:
                    identity = await self._identities.identity_from_user(user)
                except Exception:
                    # Best effort: continue without identity information.
                    log.warning(
                        "Failed to retrieve identity information for %r",
                        user,
                        exc_info=True)
            # Try to find a linked account on the target plug.  A separate loop variable keeps
            # ``user`` bound to the mentioned user for the rename below.
            links = identity.links if identity else []
            for link in links:
                if link.plug == channel.plug:
                    log.debug("Replacing mention: %r -> %r", user, link)
                    segment.mention = link
                    break
            # Perform name substitution on the mention text.
            if config["name-format"]:
                at = "@" if segment.text.startswith("@") else ""
                renamed = await self._rename_user(user, channel)
                # _rename_user hands back the user unchanged if no name was produced (e.g. a
                # template error); keep the original text rather than writing "None".
                if renamed.real_name:
                    segment.text = "{}{}".format(at, renamed.real_name)

    async def _send(self, channel, msg):
        """
        Send a message to one channel, isolating failures.

        Returns:
            tuple: ``(channel, receipts)``, where ``receipts`` is the list returned by
            ``channel.send``, or an empty list if sending raised (the error is logged).
        """
        try:
            receipts = await channel.send(msg)
            log.debug("Synced IDs in %r: %r", channel,
                      [receipt.id for receipt in receipts])
            return (channel, receipts)
        except Exception:
            log.exception("Failed to relay message to channel: %r", channel)
            return (channel, [])