示例#1
0
    def __init__(self, *args, **kwargs):
        """Configure the WormNET IRC bridge and register its IRC event handlers.

        Required keyword arguments:
            username (str): IRC nickname.
            hostname (str): WormNET server host.
            channels (list): channels to track; each gets its own user set,
                handler dict, command flag and activity dict.
            port (int): WormNET server port.
            loop (asyncio.AbstractEventLoop): loop passed to ``IrcProtocol``.

        Optional keyword arguments: ``password`` (default ``None``),
        ``is_ssl`` (default ``False``), ``reply_message`` (default
        ``'ArmaBuddy!'``), ``ignore`` (default ``[]``), ``snooper``
        (default ``'WebSnoop'``). Positional ``args`` are accepted but unused.

        :raises AssertionError: if a required keyword argument is missing or
            has the wrong type.
        """
        # Explicit checks instead of ``assert`` statements: asserts are removed
        # under ``python -O``, which would turn bad input into obscure
        # KeyError/TypeError failures further down. AssertionError is kept so
        # existing callers see the same exception type.
        validations = (
            (bool(kwargs), 'Missing required keyword arguments.'),
            (isinstance(kwargs.get('username'), str), 'Invalid username.'),
            (isinstance(kwargs.get('hostname'), str), 'Invalid WormNET server.'),
            (isinstance(kwargs.get('channels'), list), 'Invalid WormNET channel list.'),
            (isinstance(kwargs.get('port'), int), 'Invalid WormNET port.'),
            (isinstance(kwargs.get('loop'), asyncio.AbstractEventLoop), 'Invalid event loop.'),
        )
        for is_valid, error_message in validations:
            if not is_valid:
                raise AssertionError(error_message)

        self.logger = logging.getLogger('WA_Logger')
        self.wormnet = kwargs.pop('hostname')
        self.nickname = kwargs.pop('username')

        # Per-channel state, one fresh container per channel.
        channel_names = kwargs.get('channels')
        self.channels = {channel: set() for channel in channel_names}
        self.handlers = {channel: {} for channel in channel_names}
        self.commands = {channel: False for channel in channel_names}
        self.activity = {channel: {} for channel in channel_names}

        self.port = kwargs.pop('port')
        self.password = kwargs.pop('password', None)
        self.is_ssl = kwargs.pop('is_ssl', False)
        self.loop = kwargs.pop('loop')
        self.transcode = False
        self.reconnect_delay = 30
        self.server = Server(self.wormnet,
                             self.port,
                             self.is_ssl,
                             password=self.password)
        self.reply_message = kwargs.get('reply_message', 'ArmaBuddy!')
        self.ignore = kwargs.get('ignore', [])
        self.snooper = kwargs.get('snooper', 'WebSnoop')
        # Identity function by default.
        self.forward_message = lambda x: x

        # Register handlers for every IRC event this bridge needs.
        self.connection = IrcProtocol([self.server],
                                      self.nickname,
                                      loop=self.loop)
        self.connection.register_cap('userhost-in-names')
        self.connection.register('*', self.handle_command)
        self.connection.register(
            '002', self.decide_transcode)  # Server name and version
        self.connection.register('376', self.join_channels)  # End of MOTD
        self.connection.register('JOIN', self.handle_entry)
        self.connection.register('PART', self.handle_entry)
        self.connection.register('QUIT', self.handle_entry)
        self.connection.register(
            '353', self.handle_entry
        )  # NAMES reply, lists client name and status for channel

        # HACK: replaces connection_lost on the IrcProtocol class itself, which
        # affects every IrcProtocol instance in the process, not just this one.
        IrcProtocol.connection_lost = __class__.connection_lost
示例#2
0
 async def connect(self) -> None:
     """Connect to the BNC described by ``self.config``.

     Reads ``server``, ``port``, ``pass`` and ``user`` (required) and ``ssl``
     (optional, default ``False``) from ``self.config``. Creates the protocol
     under nick ``"bnc"`` and stores it on ``self._protocol``. Routes every
     incoming line to ``self.handle_line`` and then connects.
     """
     cfg = self.config
     bnc_server = Server(cfg['server'], cfg['port'], cfg.get('ssl', False), cfg['pass'])
     protocol = IrcProtocol([bnc_server], "bnc", user=cfg['user'], loop=self.loop, logger=self.logger)
     self._protocol = protocol
     # Wildcard registration: every command goes through handle_line.
     protocol.register('*', self.handle_line)
     await protocol.connect()
示例#3
0
 async def start_read_chat(self):
     """Wait for ``self.started``, then join this channel's Twitch chat and keep running.

     Logs in to Twitch IRC with ``oauth:<twitch_key>``, routes PRIVMSG lines
     to ``self.handler``, joins ``#<name>`` and flushes queued messages. When
     the currency system is ``streamlabs_local`` it also pings that bot. Then
     it sleeps for ~100 years, so the coroutine effectively never returns.
     """
     # Poll twice a second until the rest of the application flags readiness.
     while not self.started:
         await asyncio.sleep(.5)

     twitch_server = Server("irc.chat.twitch.tv", 6667, password=f'oauth:{self.twitch_key}')
     self.conn = IrcProtocol([twitch_server], nick=self.name, loop=self.loop)
     self.conn.register('PRIVMSG', self.handler)
     await self.conn.connect()
     self.conn.send(f"JOIN #{self.name}")

     ready_text = colored('Ready to read chat commands', 'green')
     hint_text = colored('!stocks', 'magenta')
     print(f"{ready_text}. To see all basic commands type {hint_text} in the twitch chat")

     self.clear_unsent_buffer()
     if self.currency_system == 'streamlabs_local':
         await self.ping_streamlabs_local()

     hundred_years = 100 * 365 * 24 * 60 * 60
     await asyncio.sleep(hundred_years)
示例#4
0
 def __init__(self,
              loader: ModuleLoader,
              config: ServerConfig,
              loop=None) -> None:
     """Set up per-server state and the IRC protocol object (does not connect).

     :param loader: module loader, stored as ``self._loader``.
     :param config: server settings; ``address``, ``port``, ``ssl`` and
         ``nick`` are read here.
     :param loop: event loop to use; when falsy, ``asyncio.get_event_loop()``
         is used instead.
     """
     self._loader = loader
     self._config = config
     self._modules = {}
     self._active_channels = set()
     self._connected = False
     self._loop = loop if loop else asyncio.get_event_loop()

     server = IrcServer(config.address, config.port, config.ssl)
     self._conn = IrcProtocol([server], config.nick, loop=self._loop)
     # One wildcard registration: every server line goes to on_server_message.
     self._conn.register("*", self.on_server_message)
示例#5
0
 async def create(self):
     """Build the IRC protocol for ``self.hosts`` and register all event callbacks.

     Each entry of ``self.hosts`` must provide ``host``, ``port`` and ``ssl``.
     Requests the ``account-notify`` capability. This method does not
     connect; it only configures ``self.conn``.
     """
     servers = [Server(entry['host'], entry['port'], entry['ssl']) for entry in self.hosts]
     self.conn = IrcProtocol(servers=servers,
                             nick=self.nick,
                             realname=self.real_name,
                             logger=self.logger)
     self.conn.register_cap('account-notify')

     # (IRC command or numeric, callback), registered in this order.
     callbacks = (
         ('*', self.log),
         ('001', self.connected),        # welcome
         ('JOIN', self.on_join),
         ('PRIVMSG', self.on_privmsg),
         ('TOPIC', self.on_topic),
         ('NICK', self.on_nick),
         ('PART', self.on_part),
         ('KICK', self.on_kick),
         ('352', self.on_whoreply),      # WHO reply
         ('INVITE', self.on_invite),
         ('MODE', self.on_mode),
         ('332', self.on_topicreply),    # topic reply
     )
     for command, callback in callbacks:
         self.conn.register(command, callback)
示例#6
0
class API:
    """Glue between the stocks minigame, Twitch chat and the configured points backend.

    Loads Twitch / Streamlabs / StreamElements credentials from
    ``database.Settings`` and reads chat over Twitch IRC. Prefixed chat
    messages are dispatched to registered commands. Viewer points are
    added or removed through the backend named by the ``currency_system``
    setting (``streamlabs``, ``stream_elements`` or ``streamlabs_local``).
    """

    def __init__(self, overlord=None, loop=None, prefix="!"):
        """
        :param overlord: owning object; ``get_and_format`` reads its
            ``messages``, ``currency_name`` and ``max_stocks_owned``.
        :param loop: event loop handed to the IRC protocol.
        :param prefix: chat prefix that marks a message as a command.
        """
        self.overlord = overlord
        self.twitch_client_id = 'q4nn0g7b07xfo6g1lwhp911spgutps'
        # Cache read by ``tokens_ready`` and invalidated through ``mark_dirty``
        # (presumably also backs ``CachedProperty`` — confirm).
        self._cache = {}
        self.streamlabs_key = ''
        self.twitch_key = ''
        self.twitch_key_requires_refreshing = False
        self.twitch_key_just_refreshed = False
        self.stream_elements_key_requires_refreshing = False
        self.stream_elements_key_just_refreshed = False
        self.stream_elements_key = ''
        # Must run after the key attributes exist: it overwrites them from the DB.
        self.load_keys()
        self._name = None
        self.users = []
        self.conn = None
        self.prefix = prefix
        self.commands: Dict[str, Union[commands.Command, commands.Group]] = {}
        self.loop = loop
        self.started = False
        self.console_buffer = ['placeholder']

        # Keys look like (alias, parent-or-None) -> canonical command name.
        self.command_names = {}
        self.load_command_names()

        self.console_buffer_done = []
        self.not_sent_buffer = []

        # Single-slot request/response exchange used by request_streamlabs_local_message.
        self.streamlabs_local_send_buffer = ''
        self.streamlabs_local_receive_buffer = ''
        self.streamlabs_local_buffer_lock = asyncio.Lock()
        self.streamlabs_local_send_buffer_event = asyncio.Event()
        self.streamlabs_local_receive_buffer_event = asyncio.Event()

    def load_keys(self):
        """Load all credentials from the database, then validate the Twitch token."""
        with contextlib.closing(database.Session()) as session:
            for key_name in ('streamlabs_key', 'twitch_key', 'stream_elements_key'):
                self.load_key(key_name=key_name, session=session)
        self.validate_twitch_token()

    def load_key(self, key_name, session=None):
        """Load credential ``key_name`` into the attribute of the same name.

        The value comes from ``database.Settings`` when present. Otherwise a
        legacy pickled file ``lib/<key_name>`` is migrated into the database,
        deleted, and its value used.

        :param key_name: settings key, also the attribute name on ``self``.
        :param session: optional database session; when omitted, a session
            is created and closed here.
        """
        owns_session = session is None
        if owns_session:
            session = database.Session()
        try:
            key_db = session.query(database.Settings).get(key_name)
            if key_db:
                setattr(self, key_name, key_db.value)
                return
            legacy_path = f'lib/{key_name}'
            if os.path.exists(legacy_path):
                # NOTE(security): pickle.load can execute arbitrary code. This is
                # assumed to be a local legacy file written by this app — never
                # point it at untrusted data.
                with open(legacy_path, 'rb') as f:
                    key = pickle.load(f)
                session.add(database.Settings(key=f'{key_name}', value=key))
                session.commit()
                os.remove(legacy_path)
                # Previously the migrated value only reached the database, leaving
                # the attribute empty until the next restart.
                setattr(self, key_name, key)
        finally:
            if owns_session:
                session.close()

    def get_user(self):
        """:returns: Information regarding the Streamer through Twitch API. Currently used just for fetching the name.

        :raises ValueError: if the request fails or the tokens are not ready.
        """
        if self.tokens_ready:
            url = "https://api.twitch.tv/helix/users"
            headers = {"Authorization": f'Bearer {self.twitch_key}', 'Client-ID': f'{self.twitch_client_id}'}
            res = requests.get(url, headers=headers)
            if res.status_code == 200:
                return res.json()
            else:
                raise ValueError("Tried fetching user info, but failed. Probably an invalid Twitch Token. Tell Razbi.")
        # Previously this exception object was *returned*, so callers indexing the
        # result failed with a confusing TypeError instead of this message.
        raise ValueError("Tried fetching user info, but tokens are not ready for use. Tell Razbi. "
                         "If this actually happened, it's really really bad.")

    async def create_context(self, username, session):
        """Build a ``contexts.UserContext`` for ``username`` (creating the user if needed)."""
        user = await self.generate_user(username, session=session)
        return contexts.UserContext(user=user, api=self, session=session)

    async def handler(self, conn, message: Union[Message, str]):
        """Dispatch a chat PRIVMSG to a registered command if it starts with ``self.prefix``.

        Command lookup is case-insensitive (the text is lower-cased). Aliases
        from ``self.command_names`` may expand to ``"command group"``, in which
        case the group name becomes the first argument. Usage errors and
        ``commands.CommandError`` messages are sent back to chat.
        """
        # PRIVMSG parameters are [target, text]; ignore lines without text.
        if len(message.parameters) < 2:
            return
        text: str = message.parameters[1].lower()
        if not text.startswith(self.prefix):
            return
        username = message.prefix.user

        words = text[len(self.prefix):].split()
        if not words:
            # A bare prefix (e.g. just "!") used to crash on tuple unpacking.
            return
        old_command_name, *args = words
        command_name = self.command_names.get((old_command_name, None), old_command_name)
        command_name, _, group_name = command_name.partition(" ")
        if group_name:
            args.insert(0, group_name)
        if command_name in self.commands:
            with contextlib.closing(database.Session()) as session:
                ctx = await self.create_context(username, session)
                try:
                    # noinspection PyTypeChecker
                    await self.commands[command_name](ctx, *args)
                except commands.BadArgumentCount as e:
                    self.send_chat_message(f'@{ctx.user.name} Usage: {self.prefix}{e.usage(name=old_command_name)}')
                except commands.CommandError as e:
                    self.send_chat_message(e.msg)

    @staticmethod
    async def generate_user(name, session=None):
        """Return the ``User`` called ``name``, creating and committing it if missing.

        When ``session`` is omitted a temporary session is used and closed
        before returning (the returned object is then presumably detached —
        confirm before touching lazy attributes).
        """
        local_session = session is None
        if local_session:
            session = database.Session()
        try:
            user = session.query(User).filter_by(name=name).first()
            if not user:
                user = User(name=name)
                session.add(user)
                session.commit()
            return user
        finally:
            if local_session:
                session.close()

    async def start_read_chat(self):
        """Wait for ``self.started``, then join this channel's Twitch chat and keep running.

        Logs in with ``oauth:<twitch_key>``, routes PRIVMSG lines to
        ``self.handler`` and flushes queued messages. Sleeps ~100 years at the
        end, so the coroutine effectively never returns.
        """
        while not self.started:
            await asyncio.sleep(.5)
        server = [Server("irc.chat.twitch.tv", 6667, password=f'oauth:{self.twitch_key}')]
        self.conn = IrcProtocol(server, nick=self.name, loop=self.loop)
        self.conn.register('PRIVMSG', self.handler)
        await self.conn.connect()
        self.conn.send(f"JOIN #{self.name}")
        print(f"{colored('Ready to read chat commands', 'green')}. "
              f"To see all basic commands type {colored('!stocks', 'magenta')} in the twitch chat")
        self.clear_unsent_buffer()
        if self.currency_system == 'streamlabs_local':
            await self.ping_streamlabs_local()
        await asyncio.sleep(24 * 60 * 60 * 365 * 100)

    def send_chat_message(self, message: str):
        """Send ``message`` to the channel, or queue it if the IRC connection is not up."""
        if self.conn and self.conn.connected:
            self.conn.send(f"PRIVMSG #{self.name} :{message}")
            if message != '':
                print(f"{colored('Message sent:', 'cyan')} {colored(message, 'yellow')}")
                self.console_buffer.append(str(message))
        else:
            self.not_sent_buffer.append(message)

    def clear_unsent_buffer(self):
        """Retry every queued message (messages that still cannot be sent are re-queued)."""
        if self.not_sent_buffer:
            temp_buffer = self.not_sent_buffer
            self.not_sent_buffer = []
            for element in temp_buffer:
                self.send_chat_message(element)

    async def add_points(self, user: str, amount: int):
        """Add ``amount`` points (negative to remove) to ``user`` via the active backend.

        :returns: the backend's new point total (Streamlabs / StreamElements),
            or a placeholder string for ``streamlabs_local``.
        :raises ValueError: for an unknown currency system, an HTTP error from
            StreamElements, or when tokens are not ready.
        """
        if self.tokens_ready:
            if self.currency_system == 'streamlabs':
                # Adds by "subtracting" the negated amount, as before.
                # Presumably Streamlabs has no plain add endpoint — confirm.
                url = "https://streamlabs.com/api/v1.0/points/subtract"
                querystring = {"access_token": self.streamlabs_key,
                               "username": user,
                               "channel": self.name,
                               "points": -amount}

                return requests.post(url, data=querystring).json()["points"]
            elif self.currency_system == 'stream_elements':
                url = f'https://api.streamelements.com/kappa/v2/points/{self.stream_elements_id}/{user}/{amount}'
                headers = {'Authorization': f'OAuth {self.stream_elements_key}'}
                res = requests.put(url, headers=headers)
                if res.status_code == 200:
                    return res.json()["newAmount"]
                raise ValueError(f"Error encountered while adding points with stream_elements system. HTTP Code {res.status_code}. "
                                 "Please tell Razbi about it.")
            elif self.currency_system == 'streamlabs_local':
                await self.request_streamlabs_local_message(f'!add_points {user} {amount}')
                return 'it worked, I guess'

            raise ValueError(f"Unavailable Currency System: {self.currency_system}")
        raise ValueError("Tokens not ready for use. Tell Razbi about this.")

    async def upgraded_add_points(self, user: User, amount: int, session):
        """Record ``amount`` in the user's gain/lost stats, commit, then apply it via ``add_points``."""
        if amount > 0:
            user.gain += amount
        elif amount < 0:
            user.lost -= amount
        session.commit()
        await self.add_points(user.name, amount)

    def command(self, **kwargs):
        """Decorator registering a chat command in ``self.commands``."""
        return commands.command(registry=self.commands, **kwargs)

    def group(self, **kwargs):
        """Decorator registering a chat command group in ``self.commands``."""
        return commands.group(registry=self.commands, **kwargs)

    @property
    def name(self):
        """Streamer login name from the Twitch API (cached), or ``None`` if tokens are not ready."""
        if self._name:
            return self._name
        if self.tokens_ready:
            self._name = self.get_user()['data'][0]['login']
            return self._name
        return None

    @CachedProperty
    def currency_system(self):
        """Configured points backend name; creates an empty setting row if none exists."""
        with contextlib.closing(database.Session()) as session:
            currency_system_db = session.query(database.Settings).get('currency_system')
            if currency_system_db is not None:
                return currency_system_db.value
            session.add(database.Settings(key='currency_system', value=''))
            session.commit()
            return ''

    def mark_dirty(self, setting):
        """Drop ``setting`` from ``self._cache`` so it is recomputed on next access."""
        self._cache.pop(setting, None)

    def validate_twitch_token(self):
        """Check the Twitch token against Twitch's validate endpoint.

        :returns: True when valid; False when missing or expired (401). An
            expired token also sets ``twitch_key_requires_refreshing``.
        :raises ValueError: on any other HTTP status.
        """
        if not self.twitch_key:
            return False
        url = 'https://id.twitch.tv/oauth2/validate'
        headers = {'Authorization': f'OAuth {self.twitch_key}'}
        res = requests.get(url=url, headers=headers)
        if res.status_code == 200:
            self.twitch_key_requires_refreshing = False
            return True
        elif res.status_code == 401:
            if not self.twitch_key_requires_refreshing:
                print("Twitch Token expired. Refreshing Token whenever possible...")
                self.twitch_key_requires_refreshing = True
            return False
        raise ValueError(
            f"A response code appeared that Razbi didn't handle when validating a twitch token, maybe tell him? Response Code: {res.status_code}")

    def validate_stream_elements_key(self):
        """Check the StreamElements token against StreamElements' validate endpoint.

        :returns: True when valid; False when missing, expired (401) or on a 5xx.
        :raises ValueError: on any other HTTP status.
        """
        if not self.stream_elements_key:
            return False
        url = 'https://api.streamelements.com/oauth2/validate'
        # Previously sent the *Twitch* key here, so this check never validated
        # the StreamElements token it guards on.
        headers = {'Authorization': f'OAuth {self.stream_elements_key}'}
        res = requests.get(url=url, headers=headers)
        if res.status_code == 200:
            self.stream_elements_key_requires_refreshing = False
            return True
        elif res.status_code == 401:
            print("Stream_elements Token expired. Refreshing Token whenever possible...")
            self.stream_elements_key_requires_refreshing = True
            return False
        elif res.status_code >= 500:
            print("server errored or... something. better tell Razbi")
            return False
        raise ValueError(
            f"A response code appeared that Razbi didn't handle when validating a stream_elements token, maybe tell him? Response Code: {res.status_code}")

    @property
    def tokens_ready(self):
        """True once a currency system is set and the keys it needs are present.

        Requires a valid Twitch token, plus the backend key for Streamlabs or
        StreamElements. Only a True result is cached; False is recomputed
        (including a Twitch validate request) on every access.
        """
        if 'tokens_ready' in self._cache:
            return self._cache['tokens_ready']
        if self.currency_system and self.twitch_key and self.validate_twitch_token():
            if self.currency_system == 'streamlabs' and self.streamlabs_key or \
                    self.currency_system == 'stream_elements' and self.stream_elements_key:
                self._cache['tokens_ready'] = True
                return True

            if self.currency_system == 'streamlabs_local':
                self._cache['tokens_ready'] = True
                return True
        return False

    @CachedProperty
    def stream_elements_id(self):
        """StreamElements channel id, from the DB or fetched once and persisted.

        Returns None when tokens are not ready or the lookup fails.
        """
        if not self.tokens_ready:
            return None
        with contextlib.closing(database.Session()) as session:
            stream_elements_id_db = session.query(database.Settings).get('stream_elements_id')
            if stream_elements_id_db:
                return stream_elements_id_db.value
            url = f'https://api.streamelements.com/kappa/v2/channels/{self.name}'
            headers = {'accept': 'application/json'}
            res = requests.get(url, headers=headers)
            if res.status_code == 200:
                stream_elements_id = res.json()['_id']
                session.add(database.Settings(key='stream_elements_id', value=stream_elements_id))
                # Previously never committed, so the id was re-fetched every run.
                session.commit()
                return stream_elements_id
        return None

    @staticmethod
    def _load_int_setting(setting):
        """Return ``int`` of the Settings row ``setting``, or None if the row is absent."""
        with contextlib.closing(database.Session()) as session:
            row = session.query(database.Settings).get(setting)
            return int(row.value) if row else None

    @CachedProperty
    def twitch_key_expires_at(self):
        """Unix time at which the Twitch key expires, or None if unknown."""
        return self._load_int_setting('twitch_key_expires_at')

    @CachedProperty
    def stream_elements_key_expires_at(self):
        """Unix time at which the StreamElements key expires, or None if unknown."""
        return self._load_int_setting('stream_elements_key_expires_at')

    def _store_refreshed_key(self, key_name, expires_at_name, payload):
        """Apply a refresh response: update ``self``, the DB rows, and invalidate the cached expiry.

        :param payload: decoded JSON with ``access_token`` and ``expires_in`` (seconds).
        """
        access_token = payload['access_token']
        # Keep the in-memory key in sync. Previously only the DB row changed, so
        # API calls, IRC logins and later refreshes kept using the old token.
        setattr(self, key_name, access_token)
        with contextlib.closing(database.Session()) as session:
            key_db = session.query(database.Settings).get(key_name)
            if key_db:
                key_db.value = access_token
            expires_at_db = session.query(database.Settings).get(expires_at_name)
            if expires_at_db:
                expires_at_db.value = str(int(time.time()) + payload['expires_in'])
            self.mark_dirty(expires_at_name)
            session.commit()

    async def twitch_key_auto_refresher(self):
        """Background loop: every 60 s, refresh a token expiring within the next minute.

        The Twitch key is checked first. The StreamElements key (only with that
        backend) is checked on iterations where the Twitch key did not need
        refreshing. Refreshes go through the razbi.funcity.org proxy.

        :raises ValueError: on an unexpected HTTP status from the refresh endpoint.
        """
        while True:
            if self.tokens_ready and self.twitch_key_expires_at and time.time() + 60 > self.twitch_key_expires_at:
                url = 'https://razbi.funcity.org/stocks-chat-minigame/twitch/refresh_token'
                querystring = {'access_token': self.twitch_key}
                res = requests.get(url, params=querystring)
                if res.status_code == 200:
                    self._store_refreshed_key('twitch_key', 'twitch_key_expires_at', res.json())
                    print("Twitch key refreshed successfully.")
                elif res.status_code == 500:
                    print(
                        "Tried refreshing the twitch token, but the server is down or smth, please tell Razbi about this. ")
                else:
                    raise ValueError('Unhandled status code when refreshing the twitch key. TELL RAZBI',
                                     res.status_code)

            elif self.tokens_ready and self.currency_system == 'stream_elements' and \
                    self.stream_elements_key_expires_at and time.time() + 60 > self.stream_elements_key_expires_at:
                url = 'https://razbi.funcity.org/stocks-chat-minigame/stream_elements/refresh_token'
                querystring = {'access_token': self.stream_elements_key}
                res = requests.get(url, params=querystring)
                if res.status_code == 200:
                    self._store_refreshed_key('stream_elements_key', 'stream_elements_key_expires_at', res.json())
                    print("Stream_elements key refreshed successfully.")
                elif res.status_code == 500:
                    print(
                        "Tried refreshing the stream_elements token, but the server is down or smth, please tell Razbi about this. ")
                else:
                    raise ValueError('Unhandled status code when refreshing the stream_elements key. TELL RAZBI',
                                     res.status_code)

            await asyncio.sleep(60)

    async def ping_streamlabs_local(self):
        """Announce the minigame in chat and request the streamer's points from the local bot."""
        self.send_chat_message('!connect_minigame')
        await self.request_streamlabs_local_message(f'!get_user_points {self.name}')

    async def request_streamlabs_local_message(self, message: str):
        """Hand ``message`` to the streamlabs_local bridge and wait for its reply.

        The lock serialises requests because both directions share one-slot buffers.
        """
        async with self.streamlabs_local_buffer_lock:
            self.streamlabs_local_send_buffer = message
            self.streamlabs_local_send_buffer_event.set()
            await self.streamlabs_local_receive_buffer_event.wait()
            self.streamlabs_local_receive_buffer_event.clear()
            response = self.streamlabs_local_receive_buffer
        return response

    def get_and_format(self, ctx: contexts.UserContext, message_name: str, **formats):
        """Format the overlord's message template ``message_name`` for ``ctx``.

        Missing templates are filled from ``load_message_templates()``; unknown
        names yield ``''``. Extra ``formats`` are passed to ``str.format``.
        """
        if message_name not in self.overlord.messages:
            default_messages = load_message_templates()
            if message_name in default_messages:
                self.overlord.messages[message_name] = default_messages[message_name]
            else:
                return ''

        return self.overlord.messages[message_name].format(stocks_alias=self.overlord.messages['stocks_alias'],
                                                           company_alias=self.overlord.messages['company_alias'],
                                                           user_name=ctx.user.name,
                                                           currency_name=self.overlord.currency_name,
                                                           stocks_limit=f'{self.overlord.max_stocks_owned:,}',
                                                           **formats)

    def load_command_names(self):
        """Load command aliases from the DB, forcing every default alias to its default target.

        Falls back to the module-level ``load_command_names()`` defaults when
        nothing is stored.
        """
        with contextlib.closing(database.Session()) as session:
            command_names = session.query(database.Settings).get('command_names')
            stored_value = command_names.value if command_names else None
        if stored_value is not None:
            self.command_names = BidirectionalMap(ast.literal_eval(stored_value))
            default_command_names = load_command_names()
            for key in default_command_names:
                if key not in self.command_names or self.command_names[key] != default_command_names[key]:
                    self.command_names[key] = default_command_names[key]
        else:
            self.command_names = load_command_names()
示例#7
0
async def run():
    """Poll an IRC network with LUSERS every ``interval`` seconds and store the counts in InfluxDB.

    The database named in ``config['database']['url']`` is created if missing.
    After the MOTD, the numeric LUSERS replies (251/252/254/266) resolve a set
    of futures. Once all of them resolve within ``interval``, user, channel and
    server counts are written with ``config['database']['tags']``. A timed-out
    poll is dropped and retried immediately; a write failure is printed and
    skipped. Runs forever.
    """
    client = InfluxDBClient.from_dsn(config['database']['url'], verify_ssl=True)
    known_databases = [entry['name'] for entry in client.get_list_database()]
    if client._database not in known_databases:
        client.create_database(client._database)

    irc_conf = config['irc']
    servers = [Server(irc_conf['server'], irc_conf['port'], irc_conf.get('ssl', False))]
    interval = irc_conf['interval']
    conn_commands = irc_conf.get('connect_commands', [])
    result_names = (
        'opers', 'channels', 'user_count', 'user_max', 'server_count', 'normal_users', 'invisible_users'
    )

    async with IrcProtocol(servers, irc_conf['nick'], irc_conf['user']) as proto:
        await proto.connect()
        await proto.wait_for('376')  # end of MOTD

        # Shared dict: the handlers resolve whatever futures it holds at reply time.
        futures = {}
        numeric_handlers = (
            ('251', on_luser_start),
            ('252', on_opers_online),
            ('254', on_chans_formed),
            ('266', on_global_users),
        )
        for numeric, make_handler in numeric_handlers:
            proto.register(numeric, make_handler(futures))

        for cmd in conn_commands:
            proto.send(cmd)

        while True:
            for result_name in result_names:
                futures[result_name] = proto.loop.create_future()

            proto.send("LUSERS")
            try:
                await asyncio.wait_for(asyncio.gather(*futures.values()), interval)
            except asyncio.TimeoutError:
                # Incomplete reply: poll again straight away with fresh futures.
                continue

            data = {key: await fut for key, fut in futures.items()}
            tags = config['database']['tags']

            fields_by_measurement = {
                'user_count': {
                    "oper_count": data['opers'],
                    "user_count": data['user_count'],
                    "user_max": data['user_max'],
                    "visible_users": data["normal_users"],
                    "invisible_users": data["invisible_users"],
                },
                'channel_count': {
                    "channels": data['channels'],
                },
                'server_count': {
                    "server_count": data['server_count'],
                },
            }
            body = [
                {'measurement': measurement, 'fields': fields}
                for measurement, fields in fields_by_measurement.items()
            ]

            try:
                client.write_points(body, tags=tags)
            except RequestException:
                traceback.print_exc()

            futures.clear()

            await asyncio.sleep(interval)
示例#8
0
class Conn:
    """ZNC/BNC management bot connection.

    Owns the IRC connection to the bouncer, the JSON config (``config.json``),
    the cached BNC queue/user data (``bnc.json``) and the logging setup. All
    paths are resolved against the working directory at construction time.
    """

    def __init__(self, handlers) -> None:
        self.run_dir = Path().resolve()
        self._protocol = None
        # Hook registry; handle_line() runs handlers['raw'][''] for every line.
        self.handlers = handlers
        # Futures waiting for a reply to a request, keyed by request name
        # ("user_list", "bindhost", "bncadmin"); presumably resolved by line
        # hooks that are not visible here.
        self.futures = {}
        self.locks = defaultdict(asyncio.Lock)
        self.loop = asyncio.get_event_loop()
        self.bnc_data = {}
        # Completed by shutdown(); its result (the restart flag) is returned by run().
        self.stopped_future = self.loop.create_future()
        self.get_users_state = 0
        self.config = {}
        if not self.log_dir.exists():
            self.log_dir.mkdir()

        # The config is still empty here, so this yields the console-only setup;
        # run() re-applies it once config.json has been loaded.
        self.setup_logger()
        self.logger = logging.getLogger("bncbot")

    def setup_logger(self) -> None:
        """Configure the "bncbot" and "asyncio" loggers from ``self.config``.

        Both loggers always log DEBUG+ to stdout. With ``log_to_file`` set,
        "bncbot" also writes INFO+ to ``logs/bot.log``; with ``debug`` set as
        well, "asyncio" also writes DEBUG+ to ``logs/debug.log``. Both files
        rotate at ~1 MB and keep 5 backups.
        """
        from logging.config import dictConfig

        do_debug = self.config.get("debug", False)
        log_to_file = self.config.get("log_to_file", False)

        logging_conf = {
            "version": 1,
            # Keep loggers that already exist outside this config enabled.
            "disable_existing_loggers": False,
            "formatters": {
                "brief": {
                    "format": "[%(asctime)s] [%(levelname)s] %(message)s",
                    "datefmt": "%H:%M:%S"
                },
                "full": {
                    "format": "[%(asctime)s] [%(levelname)s] %(message)s",
                    "datefmt": "%Y-%m-%d][%H:%M:%S"
                }
            },
            "handlers": {
                "console": {
                    "class": "logging.StreamHandler",
                    "formatter": "brief",
                    "level": "DEBUG",
                    "stream": "ext://sys.stdout"
                }
            },
            "loggers": {
                "bncbot": {
                    "level": "DEBUG",
                    "handlers": ["console"]
                },
                "asyncio": {
                    "level": "DEBUG",
                    "handlers": ["console"]
                }
            }
        }

        if log_to_file:
            logging_conf['handlers']['file'] = {
                "class": "logging.handlers.RotatingFileHandler",
                "maxBytes": 1000000,
                "backupCount": 5,
                "formatter": "full",
                "level": "INFO",
                "encoding": "utf-8",
                "filename": self.log_dir / "bot.log"
            }

            logging_conf['loggers']['bncbot']['handlers'].append('file')

            if do_debug:
                logging_conf['handlers']['debug_file'] = {
                    "class": "logging.handlers.RotatingFileHandler",
                    "maxBytes": 1000000,
                    "backupCount": 5,
                    "formatter": "full",
                    "encoding": "utf-8",
                    "level": "DEBUG",
                    "filename": self.log_dir / "debug.log"
                }
                logging_conf['loggers']['asyncio']['handlers'].append('debug_file')

        # Previously the dict was built and then discarded, so none of these
        # handlers or levels ever took effect.
        dictConfig(logging_conf)

    def load_config(self) -> None:
        """Read ``config.json`` into ``self.config``."""
        with self.config_file.open(encoding='utf8') as f:
            self.config = json.load(f)

    def load_data(self, update: bool = False) -> None:
        """Load cached BNC information from ``bnc.json`` (if it exists).

        Guarantees the 'queue' and 'users' sections exist and writes the file
        back. When *update* is true and no users are cached, schedules
        get_user_hosts() to rebuild the user list from ZNC.
        """
        self.bnc_data = {}
        if self.data_file.exists():
            with self.data_file.open(encoding='utf8') as f:
                self.bnc_data = json.load(f)

        self.bnc_data.setdefault('queue', {})
        self.bnc_data.setdefault('users', {})
        self.save_data()
        if update and not self.bnc_users:
            asyncio.ensure_future(self.get_user_hosts(), loop=self.loop)

    def save_data(self) -> None:
        """Write ``self.bnc_data`` to ``bnc.json`` (indented, keys sorted)."""
        with self.data_file.open('w', encoding='utf8') as f:
            json.dump(self.bnc_data, f, indent=2, sort_keys=True)

    def run(self) -> bool:
        """Connect, start timers and block until shutdown(); return the restart flag."""
        self.load_config()
        # Re-apply logging now that the "debug"/"log_to_file" settings are known.
        self.setup_logger()
        self.loop.run_until_complete(self.connect())
        self.load_data(True)
        self.start_timers()
        restart = self.loop.run_until_complete(self.stopped_future)
        self.loop.stop()
        return restart

    def create_timer(self, interval, func, *args, initial_interval=None):
        """Schedule the project-level ``timer`` coroutine for *func* on our loop."""
        asyncio.ensure_future(timer(interval, func, *args, initial_interval=initial_interval), loop=self.loop)

    def start_timers(self) -> None:
        """Resync the user/bindhost list every 8 hours."""
        self.create_timer(timedelta(hours=8), self.get_user_hosts)

    def send(self, *parts) -> None:
        """Send a raw IRC line built by joining *parts* with spaces."""
        self._protocol.send(' '.join(parts))

    def module_msg(self, name: str, cmd: str) -> None:
        """Message the ZNC module *name* (status prefix + name) with *cmd*."""
        self.msg(self.prefix + name, cmd)

    async def get_user_hosts(self) -> None:
        """Rebuild the user -> BindHost map from ZNC and warn about duplicates.

        Should only be run periodically to keep the user list in sync.
        """
        self.get_users_state = 0
        self.bnc_users.clear()
        user_list_fut = self.loop.create_future()
        self.futures["user_list"] = user_list_fut
        try:
            self.send("znc listusers")
            await user_list_fut
        finally:
            self.futures.pop("user_list", None)

        # Iterate a snapshot: the dict can change (e.g. add_user) while we await.
        for user in list(self.bnc_users):
            bindhost_fut = self.loop.create_future()
            self.futures["bindhost"] = bindhost_fut
            try:
                self.module_msg("controlpanel", f"Get BindHost {user}")
                self.bnc_users[user] = await bindhost_fut
            finally:
                self.futures.pop("bindhost", None)

        self.save_data()
        self.load_data()
        host_map = defaultdict(list)
        for user, host in self.bnc_users.items():
            host_map[host].append(user)

        hosts = {
            host: users
            for host, users in host_map.items()
            if host and len(users) > 1
        }
        if hosts:
            self.chan_log(
                "WARNING: Duplicate BindHosts found: {}".format(
                    hosts
                )
            )

    async def connect(self) -> None:
        """Create the IRC protocol for the configured server and connect."""
        servers = [
            Server(
                self.config['server'], self.config['port'], self.config.get('ssl', False), self.config['pass']
            )
        ]
        self._protocol = IrcProtocol(
            servers, "bnc", user=self.config['user'], loop=self.loop, logger=self.logger
        )
        self._protocol.register('*', self.handle_line)
        await self._protocol.connect()

    def close(self) -> None:
        """Quit the IRC connection."""
        self._protocol.quit()

    async def shutdown(self, restart=False):
        """Announce, quit IRC and resolve stopped_future with *restart*."""
        self.chan_log("Bot {}...".format("shutting down" if not restart else "restarting"))
        await asyncio.sleep(1)
        self.close()
        # asyncio.sleep() lost its ``loop`` argument in Python 3.10.
        await asyncio.sleep(1)
        self.stopped_future.set_result(restart)

    async def handle_line(self, proto: 'IrcProtocol', line: 'Message') -> None:
        """Wrap every received line in an event and run all 'raw' hooks on it."""
        raw_event = irc.make_event(self, line, proto)
        for handler in self.handlers.get('raw', {}).get('', []):
            await self.launch_hook(raw_event, handler)

    async def launch_hook(self, event, func) -> bool:
        """Call *func* with the event attributes named by its parameters.

        Returns True on success; on any exception, logs it, reports it to the
        log channel and returns False.
        """
        try:
            params = [
                getattr(event, name)
                for name in inspect.signature(func).parameters.keys()
            ]
            await call_func(func, *params)
        except Exception as e:
            self.logger.exception("Error occurred in hook")
            self.chan_log(f"Error occurred in hook {func.__name__} '{type(e).__name__}: {e}'")
            return False
        else:
            return True

    def is_admin(self, mask: str) -> bool:
        """Case-insensitively match *mask* against the configured admin globs."""
        return any(fnmatch(mask.lower(), pat.lower()) for pat in self.admins)

    async def is_bnc_admin(self, name) -> bool:
        """Ask controlpanel whether ZNC user *name* is an admin.

        Requests are serialized by a lock; a fresh future is used per request
        and always removed afterwards, so a failed/cancelled request can no
        longer leave a stale future that a later call would reuse.
        """
        async with self.locks["controlpanel_bncadmin"]:
            fut = self.loop.create_future()
            self.futures["bncadmin"] = fut
            try:
                self.module_msg("controlpanel", "Get Admin {}".format(name))
                return await fut
            finally:
                self.futures.pop("bncadmin", None)

    def add_queue(self, nick: str, registered_time: str) -> None:
        """Queue *nick* for a BNC and persist."""
        self.bnc_queue[nick] = registered_time
        self.save_data()

    def rem_queue(self, nick: str) -> None:
        """Remove *nick* from the queue (if present) and persist."""
        if nick in self.bnc_queue:
            del self.bnc_queue[nick]
            self.save_data()

    def chan_log(self, msg: str) -> None:
        """Send *msg* to the configured log channel, if any."""
        if self.log_chan:
            self.msg(self.log_chan, msg)

    def add_user(self, nick: str) -> bool:
        """Create a BNC account for *nick* and MemoServ the credentials to them.

        The account is cloned from the "BNCClient" user via controlpanel and
        given a random password and a free BindHost. Returns False if no free
        BindHost could be found, True otherwise.
        """
        if not util.is_username_valid(nick):
            username = util.sanitize_username(nick)
            self.chan_log(f"WARNING: Invalid username '{nick}'; sanitizing to {username}")
        else:
            username = nick

        passwd = util.gen_pass()
        try:
            host = self.get_bind_host()
        except ValueError:
            return False

        self.module_msg('controlpanel', f"cloneuser BNCClient {username}")
        self.module_msg('controlpanel', f"Set Password {username} {passwd}")
        self.module_msg('controlpanel', f"Set BindHost {username} {host}")
        self.module_msg('controlpanel', f"Set Nick {username} {nick}")
        self.module_msg('controlpanel', f"Set AltNick {username} {nick}_")
        self.module_msg('controlpanel', f"Set Ident {username} {nick}")
        self.module_msg('controlpanel', f"Set Realname {username} {nick}")
        self.send('znc saveconfig')
        self.module_msg('controlpanel', f"reconnect {username} Snoonet")
        self.msg(
            "MemoServ",
            f"SEND {nick} Your BNC auth is Username: {username} Password: "
            f"{passwd} (Ports: 5457 for SSL - 5456 for NON-SSL) Help: "
            f"/server bnc.snoonet.org 5456 and /PASS {username}:{passwd}"
        )
        self.bnc_users[username] = host
        self.save_data()
        return True

    def get_bind_host(self) -> str:
        """Return a random address in bind_host_net not already assigned.

        After 50 colliding attempts, logs an error to the log channel and
        raises ValueError.
        """
        for _ in range(50):
            host = str(util.get_random_address(self.bind_host_net))
            if host not in self.bnc_users.values():
                return host

        self.chan_log(
            "ERROR: get_bind_host() has hit a bindhost collision"
        )
        raise ValueError

    def msg(self, target: str, *messages: str) -> None:
        """Send each message to *target* as a PRIVMSG."""
        for message in messages:
            self.send(f"PRIVMSG {target} :{message}")

    def notice(self, target: str, *messages: str) -> None:
        """Send each message to *target* as a NOTICE."""
        for message in messages:
            self.send(f"NOTICE {target} :{message}")

    @property
    def admins(self) -> List[str]:
        return self.config.get('admins', [])

    @property
    def bnc_queue(self) -> Dict[str, str]:
        return self.bnc_data.setdefault('queue', {})

    @property
    def bnc_users(self) -> Dict[str, str]:
        return self.bnc_data.setdefault('users', {})

    @property
    def prefix(self) -> str:
        return self.config.get('status_prefix', '*')

    @property
    def cmd_prefix(self):
        return self.config.get('command_prefix', '.')

    @property
    def log_chan(self) -> Optional[str]:
        return self.config.get('log_channel')

    @property
    def bind_host_net(self):
        return ipaddress.ip_network(self.config.get('bind_host_net', "127.0.0.0/16"))

    @property
    def log_dir(self):
        return self.run_dir / "logs"

    @property
    def data_file(self):
        return self.run_dir / "bnc.json"

    @property
    def config_file(self):
        return self.run_dir / "config.json"

    @property
    def nick(self) -> str:
        return self._protocol.nick

    @nick.setter
    def nick(self, value: str) -> None:
        self._protocol.nick = value
示例#9
0
class WA_IRC:
    """WormNET IRC client that tracks channel members/activity and relays messages.

    Required keyword arguments: username (str), hostname (str), channels
    (list of channel names without the leading '#'), port (int), loop
    (asyncio event loop). Optional: password, is_ssl, reply_message, ignore,
    snooper. Incoming channel names are lower-cased before lookup, so the
    configured names are presumably expected in lower case (confirm).
    """

    # Membership prefixes stripped from NAMES (353) entries.
    _MODE_PREFIXES = str.maketrans('', '', '@+$%')

    def __init__(self, *args, **kwargs):
        assert kwargs, 'Missing required keyword arguments.'
        assert isinstance(kwargs.get('username'), str), 'Invalid username.'
        assert isinstance(kwargs.get('hostname'),
                          str), 'Invalid WormNET server.'
        assert isinstance(kwargs.get('channels'),
                          list), 'Invalid WormNET channel list.'
        assert isinstance(kwargs.get('port'), int), 'Invalid WormNET port.'
        assert isinstance(kwargs.get('loop'),
                          asyncio.AbstractEventLoop), 'Invalid event loop.'

        self.logger = logging.getLogger('WA_Logger')
        self.wormnet = kwargs.pop('hostname')
        self.nickname = kwargs.pop('username')
        channel_names = kwargs.get('channels')
        # Nicks currently in each channel (maintained by handle_entry).
        self.channels = {name: set() for name in channel_names}
        # Per-channel {IRC command: coroutine handler} dispatch tables.
        self.handlers = {name: {} for name in channel_names}
        # Whether '!' commands are allowed in each channel.
        self.commands = {name: False for name in channel_names}
        # Per-channel {nick: UTC datetime of their last PRIVMSG}.
        self.activity = {name: {} for name in channel_names}
        self.port = kwargs.pop('port')
        self.password = kwargs.pop('password', None)
        self.is_ssl = kwargs.pop('is_ssl', False)
        self.loop = kwargs.pop('loop')
        self.transcode = False
        self.reconnect_delay = 30
        self.server = Server(self.wormnet,
                             self.port,
                             self.is_ssl,
                             password=self.password)
        self.reply_message = kwargs.get('reply_message', 'ArmaBuddy!')
        self.ignore = kwargs.get('ignore', [])
        self.snooper = kwargs.get('snooper', 'WebSnoop')
        # Replaced by the owner with the coroutine that relays channel messages.
        self.forward_message = lambda x: x

        # register handlers for every needed internal
        self.connection = IrcProtocol([self.server],
                                      self.nickname,
                                      loop=self.loop)
        self.connection.register_cap('userhost-in-names')
        self.connection.register('*', self.handle_command)
        self.connection.register(
            '002', self.decide_transcode)  # Server name and version
        self.connection.register('376', self.join_channels)  # End of MOTD
        self.connection.register('JOIN', self.handle_entry)
        self.connection.register('PART', self.handle_entry)
        self.connection.register('QUIT', self.handle_entry)
        self.connection.register(
            '353', self.handle_entry
        )  # NAMES reply, lists client name and status for channel

        # horrible horrible hack for a horrible horrible library
        IrcProtocol.connection_lost = __class__.connection_lost

    async def connect(self):
        """Connect to WormNET and keep reconnecting forever.

        This is a loop rather than a recursive call: each recursive
        ``await self.connect()`` kept the previous coroutine alive, so a
        long-running bot would eventually hit the recursion limit.
        """
        while True:
            self.logger.warning(' * Connecting to WormNET.')
            if await self.connection._connect(server=self.server):
                self.logger.warning(' * Connected to WormNET.')

            # wait for end of MOTD to signal proper connection
            if not await self.connection.wait_for('376',
                                                  timeout=self.reconnect_delay):
                self.logger.warning(
                    ' ! Unable to connect to properly WormNET, attempting to reconnect.'
                )
                continue

            # wait until we lose connection
            while self.connection._connected:
                await asyncio.sleep(1)

            self.logger.warning(
                f' ! Disconnected from WormNET, attempting to reconnect in {self.reconnect_delay} seconds.'
            )
            await asyncio.sleep(self.reconnect_delay)

    async def decide_transcode(self, conn, message):
        """On 002 (server name/version): choose whether incoming lines are transcoded.

        The community server ('ae.net.irc.server/WormNET') gets plain parsing;
        any other server gets WA_Encoder decoding. Either way this
        monkey-patches IrcProtocol.data_received.
        """
        if len(message.parameters
               ) >= 2 and 'ae.net.irc.server/WormNET' in message.parameters[1]:
            self.logger.warning(
                ' * Disabled transcoding for WormNET messages.')
            IrcProtocol.data_received = __class__.transcode_off
            self.transcode = False
        else:
            self.logger.warning(' * Enabled transcoding for WormNET messages.')
            IrcProtocol.data_received = __class__.transcode_on
            self.transcode = True

    @staticmethod
    def transcode_off(self, data: bytes) -> None:
        # Installed as IrcProtocol.data_received, so *self* is the protocol.
        self._buff += data
        while b'\r\n' in self._buff:
            raw_line, self._buff = self._buff.split(b'\r\n', 1)
            message = Message.parse(raw_line)
            for trigger, func in self.handlers.values():
                if trigger in (message.command, '*'):
                    self.loop.create_task(func(self, message))

    @staticmethod
    def transcode_on(self, data: bytes) -> None:
        # Same as transcode_off, but each line is decoded with WA_Encoder first.
        self._buff += data
        while b'\r\n' in self._buff:
            raw_line, self._buff = self._buff.split(b'\r\n', 1)
            raw_line = WA_Encoder.decode(raw_line)
            message = Message.parse(raw_line)
            for trigger, func in self.handlers.values():
                if trigger in (message.command, '*'):
                    self.loop.create_task(func(self, message))

    @staticmethod
    def connection_lost(self, exc) -> None:
        # Installed as IrcProtocol.connection_lost: mark the protocol
        # disconnected (connect() polls _connected) instead of its default.
        self._transport = None
        self._connected = False
        if self._quitting:
            self.quit_future.set_result(None)

    async def log(self, conn, message):
        self.logger.info(f' * IRC_RAW {message}')

    async def handle_entry(self, conn, message):
        """Keep self.channels in sync from JOIN, PART, QUIT and NAMES (353)."""
        command = message.command
        if command == 'JOIN':  # add user to channel set if joining
            channel = message.parameters[0][1:].lower()
            if channel in self.channels:
                self.channels[channel].add(message.prefix.nick)
        elif command == 'PART':  # remove user from channel set if parting
            channel = message.parameters[0][1:].lower()
            if channel in self.channels:
                self.channels[channel].discard(message.prefix.nick)
        elif command == 'QUIT':
            # QUIT names no channel (and may carry no parameters at all), so
            # remove the user from every tracked channel.
            for members in self.channels.values():
                members.discard(message.prefix.nick)
        elif command == '353':
            channel = message.parameters[2][1:].lower()
            # Ignore NAMES for channels we do not track.
            if channel not in self.channels:
                return
            # strip any modes from users, should not be set on WormNET, but
            # will make testing a pain on regular networks
            no_modes = message.parameters[3].translate(self._MODE_PREFIXES)
            for user in no_modes.split():
                # userhost-in-names gives nick!user@host; keep only the nick
                self.channels[channel].add(user.split('!')[0])

    async def join_channels(self, conn, message):
        """On end of MOTD (376): JOIN every configured channel in turn.

        Raises Exception if a channel's end of NAMES (366) does not arrive
        within 30 seconds.
        """
        for channel_name in self.channels:
            self.logger.warning(
                f' * Joining WormNET channel: #{channel_name}!')
            self.connection.send(f'JOIN #{channel_name}')

            # give server a few seconds to give us NAMES, if none has been
            # received after timeout propagate error
            result = await self.connection.wait_for('366', timeout=30)
            if result is None:
                raise Exception(
                    f'Never received NAMES after joining WormNET channel #{channel_name}.'
                )

    async def send_message(self, guild, origin, channel, message):
        """Relay *message* from *origin* on *guild* to WormNET channel #*channel*."""
        # Keep only the first line (cut at \r or \n) so a sender cannot
        # smuggle extra IRC commands into the output.
        message = re.split(r'[\r\n]', message, maxsplit=1)[0]
        message = f'PRIVMSG #{channel} :{message}'

        # keep message under 250 characters at least, in reality max length is 512
        message = (message[:250] + '[...]') if len(message) > 250 else message
        self.logger.warning(
            f' * Forwarding message from #{origin} on "{guild}" to WormNET #{channel}: {message}'
        )
        await self.transport_write(message)

    async def send_private(self, user, message):
        """Send *message* to WormNET *user* as a private message."""
        # Keep only the first line (cut at \r or \n) so a sender cannot
        # smuggle extra IRC commands into the output.
        message = re.split(r'[\r\n]', message, maxsplit=1)[0]
        message = f'PRIVMSG {user} :{message}'

        self.logger.warning(
            f' * Forwarding PM to WormNET user {user}: {message}')
        await self.transport_write(message)

    async def transport_write(self, message):
        """Write one raw line to the socket, encoding it as the server expects."""
        if self.connection._connected:
            message = message + '\r\n'
            # W:A client transforms some of the cyrillic characters to latin
            # when typing, in addition to encoding
            message = WA_Encoder.translate(message)
            message = WA_Encoder.encode(
                message) if self.transcode else message.encode()
            self.connection._transport.write(message)
        else:
            self.logger.warning(
                ' ! Could not forward message due to connection to IRC being down!'
            )

    async def handle_command(self, connection, message):
        """Catch-all handler: nick errors, per-channel dispatch and PM auto-reply."""
        if message.command == '432':
            raise Exception('Requested nickname contains invalid characters')
        if message.command == '433':
            raise Exception('Requested nickname is already in use')

        # ignore commands triggered by self
        if message.prefix.nick == self.nickname:
            return

        # Nothing to dispatch on for messages without a target parameter.
        if not message.parameters or not message.parameters[0]:
            return
        target = message.parameters[0]

        # if destination is a channel call handler
        if target[0] == '#':
            channel = target[1:].lower()
            # Only configured channels have activity/handler tables.
            if channel not in self.channels:
                return

            # if user writes in a channel, update activity
            if message.command == 'PRIVMSG':
                self.activity[channel][message.prefix.nick] = datetime.now(
                    timezone.utc)

            if message.command in self.handlers[channel]:
                return await self.handlers[channel][message.command](
                    connection, message)

        # reply to all PM with predefined phrase
        elif message.command == 'PRIVMSG':
            self.logger.warning(
                f' * Received PM on WormNET from {message.prefix.nick}: {message.parameters[1]}'
            )
            return await self.send_private(user=message.prefix.nick,
                                           message=self.reply_message)

    async def default_privmsg_handler(self, connection, message):
        """Filter a channel PRIVMSG and pass it to self.forward_message.

        Drops '!' commands in channels without command permission, CTCP
        ACTIONs, and messages from ignored senders. Messages relayed by the
        web snooper ("sender> text") are re-attributed to the real sender.
        """
        # lowercase channel / username
        message.parameters[0] = message.parameters[0].lower()

        # check if channel allows commands
        if message.parameters[1].startswith('!'):
            if not self.commands[message.parameters[0][1:]]:
                return self.logger.warning(
                    f' * Ignoring command in {message.parameters[0]} from {message.prefix.nick}: {message.parameters[1]}'
                )

        # handle actions
        if message.parameters[1].startswith('\x01ACTION'):
            message.parameters[
                1] = f'~ {message.prefix.nick} {message.parameters[1][8:-1]} ~'
            # @ngrfisk you could check for "\x01ACTION is joining a game" or "\x01ACTION is hosting a game"
            return self.logger.warning(
                f' * Ignoring action in {message.parameters[0]} from {message.prefix.nick}: {message.parameters[1]}'
            )

        # process PRIVMSG
        channel = message.parameters[0][1:].lower()
        sender = message.prefix.nick
        message = message.parameters[1]
        snooper = None

        # if user is sending from web snoop, then we could still match avatar
        if sender == self.snooper:
            snooper = sender
            match = re.match(r'^(?P<sender>.*?)>\s(?P<message>.*)$', message)
            # A snooper line not in "sender> text" form is forwarded as-is
            # instead of crashing on a None match.
            if match is not None:
                sender = match.group('sender')
                message = match.group('message')

        # ignore messages from users on ignore list
        if sender in self.ignore:
            return self.logger.warning(
                f' * Ignored WormNET message from {sender}.')

        await self.forward_message(irc_channel=channel,
                                   sender=sender,
                                   message=message,
                                   snooper=snooper)
示例#10
0
class IRC(Service):

    # DB* types for this service; the nested IRC* classes below copy the
    # matching one onto their own ``db_type`` in __init__.
    db_type = DBIRCService
    db_chat_type = DBIRCChat
    db_user_type = DBIRCUser
    db_chat_user_type = DBIRCChatUser
    db_message_type = DBIRCMessage
    db_event_type = DBIRCEvent
    db_bridge_chat_type = DBIRCBridgeChat

    class IRCChat(Chat):
        """An IRC channel as a Chat; topic changes are saved immediately."""

        def __init__(self, service, name, topic=None):
            self.db_type = service.db_chat_type
            self._topic = topic
            # Extra attributes for the base class to handle (presumably the
            # persisted fields — confirm against Chat).
            self.child_attrs = ['topic']
            super().__init__(service, name)

        @property
        def topic(self):
            return self._topic

        @topic.setter
        def topic(self, value):
            self._topic = value
            self.save()

        async def join(self):
            """Send JOIN for this channel."""
            self.service.conn.send("JOIN {}".format(self.name))

        async def send(self, message):
            """Send *message* to this channel as one PRIVMSG per line.

            conn.send is assumed not to sanitise line breaks, so a raw CR/LF
            inside one PRIVMSG would let the text after it be read by the
            server as a separate command. The message is therefore split on
            line breaks and empty lines are skipped.
            """
            text = "{0}".format(message)
            for line in text.replace("\r\n", "\n").replace("\r", "\n").split("\n"):
                if line:
                    self.service.conn.send("PRIVMSG {0} :{1}".format(
                        self.name, line))

        async def receive(self, message, chat_user):
            """Pass an incoming channel message on to the bridge."""
            # NOTE(review): the slice drops the first character of
            # parameters[1]; confirm the parser here keeps the leading ':'.
            text = message.parameters[1][1:]
            await self.bridge_chat.receive(text, chat_user)

        async def query_users(self):
            """Request a WHO listing for this channel."""
            self.service.conn.send("WHO {0}".format(self.name))

    class IRCUser(User):
        """An IRC user: nick plus ident, host, real name and server.

        Assigning any of those four attributes saves the user immediately.
        """

        def __init__(self,
                     service,
                     name,
                     ident=None,
                     host=None,
                     real_name=None,
                     server=None):
            self.db_type = service.db_user_type
            self._ident, self._host = ident, host
            self._real_name, self._server = real_name, server
            # Extra attributes for the base class to handle.
            self.child_attrs = ['ident', 'host', 'real_name', 'server']

            super().__init__(service, name)

        def _saving_property(attr):
            """Build a property stored in ``_<attr>`` whose setter calls save()."""
            slot = '_' + attr

            def getter(self):
                return getattr(self, slot)

            def setter(self, value):
                setattr(self, slot, value)
                self.save()

            return property(getter, setter)

        ident = _saving_property('ident')
        host = _saving_property('host')
        real_name = _saving_property('real_name')
        server = _saving_property('server')
        del _saving_property

    class IRCChatUser(ChatUser):
        """A user's membership in an IRC channel, with operator/voice flags.

        Setting either flag saves immediately.
        """

        def __init__(self, service, chat, user):
            self.db_type = service.db_chat_user_type
            self._operator = False
            self._voiced = False
            # Was ``_child_attrs``. IRCChat and IRCUser both use ``child_attrs``,
            # so the op/voice flags were presumably never seen by the base class.
            self.child_attrs = ['operator', 'voiced']

            super().__init__(service, chat, user)

        @property
        def operator(self):
            return self._operator

        @operator.setter
        def operator(self, value: bool):
            self._operator = value
            self.logger.debug("Set operator status for {0} to {1}".format(
                self, value))
            self.save()

        @property
        def voiced(self):
            return self._voiced

        @voiced.setter
        def voiced(self, value: bool):
            self._voiced = value
            self.save()

    class IRCBridgeChat(BridgeChat):
        """Bridge endpoint for an IRC chat."""

        def __init__(self, bridge, chat):
            # The DB type comes from the service that owns the chat.
            owning_service = chat.service
            self.db_type = owning_service.db_bridge_chat_type
            super().__init__(bridge, chat)

    class IRCMessage(Message):
        """A chat message on IRC, typed with the service's message DB type."""

        def __init__(self, service, chat, user, message):
            db_model = service.db_message_type
            self.db_type = db_model
            super().__init__(service, chat, user, message)

    class IRCEvent(Event):
        """Generic IRC event; subclasses supply the event name and values."""

        def __init__(self,
                     service,
                     event,
                     new_value=None,
                     old_value=None,
                     chat=None,
                     user=None):
            self.db_type = service.db_event_type
            event_kwargs = dict(new_value=new_value,
                                old_value=old_value,
                                chat=chat,
                                user=user)
            super().__init__(service, event, **event_kwargs)

    class IRCTopicEvent(IRCEvent):
        """'topic_set' event; the old value is the chat's current topic."""

        def __init__(self, service, chat, user, topic):
            previous_topic = chat.topic
            super().__init__(service,
                             'topic_set',
                             chat=chat,
                             user=user,
                             new_value=topic,
                             old_value=previous_topic)

    class IRCJoinEvent(IRCEvent):
        """'user_joined' event for *user* in *chat*."""

        def __init__(self, service, chat, user):
            super().__init__(service, 'user_joined', user=user, chat=chat)

    class IRCLeaveEvent(IRCEvent):
        """'user_left' event for *user* in *chat*."""

        def __init__(self, service, chat, user):
            super().__init__(service, 'user_left', user=user, chat=chat)

    class IRCQuitEvent(IRCEvent):
        """'user_quit' event for *user*, recorded against *chat*."""

        def __init__(self, service, chat, user):
            super().__init__(service, 'user_quit', user=user, chat=chat)

    class IRCNickEvent(IRCEvent):
        """'user_nick' event: *user* renamed from user.name to *new_nick*."""

        def __init__(self, service, user, new_nick):
            old_nick = user.name
            super().__init__(service,
                             'user_nick',
                             new_value=new_nick,
                             old_value=old_nick,
                             user=user)

    def __init__(self, bot, id, name, enabled, hosts, nick, real_name,
                 channels):
        """Set up an IRC service from its configuration.

        hosts is a list of {'host', 'port', 'ssl'} mappings (see create());
        channels is a list of mappings with at least a 'name' key.
        """
        # Plug the IRC-specific model classes in before the base initialiser runs.
        (self.chat_class, self.user_class, self.chat_user_class,
         self.message_class, self.event_class, self.bridge_chat_class) = (
             self.IRCChat, self.IRCUser, self.IRCChatUser,
             self.IRCMessage, self.IRCEvent, self.IRCBridgeChat)

        super().__init__(bot, id, name, enabled)

        self.id = id
        self.hosts = hosts
        self.nick = nick
        self.real_name = real_name
        self.channels = channels

        self.logger.info("Initialising IRC server {0}".format(id))
        summary = "{0}: hosts: {1}, nick: {2}, realname: {3}".format(
            self.name, self.hosts, self.nick, self.real_name)
        self.logger.debug(summary)

    async def create(self):
        """Build the IrcProtocol for the configured hosts and register handlers."""
        servers = [Server(entry['host'], entry['port'], entry['ssl'])
                   for entry in self.hosts]
        self.conn = IrcProtocol(servers=servers,
                                nick=self.nick,
                                realname=self.real_name,
                                logger=self.logger)
        self.conn.register_cap('account-notify')

        # (command, handler) pairs, registered in this order; '*' sees every line.
        routes = (
            ('*', self.log),
            ('001', self.connected),
            ('JOIN', self.on_join),
            ('PRIVMSG', self.on_privmsg),
            ('TOPIC', self.on_topic),
            ('NICK', self.on_nick),
            ('PART', self.on_part),
            ('KICK', self.on_kick),
            ('352', self.on_whoreply),
            ('INVITE', self.on_invite),
            ('MODE', self.on_mode),
            ('332', self.on_topicreply),
        )
        for command, handler in routes:
            self.conn.register(command, handler)

    async def start(self):
        """Create and connect the IRC client, unless the service is disabled."""
        if not self.enabled:
            self.logger.info("{0} is currently disabled".format(self.id))
            return

        await self.create()
        await self.conn.connect()

    def connection_lost(self):
        """Log that the connection was lost; nothing else is done here."""
        self.logger.debug("Connection lost.")

    async def connected(self, conn, message):
        """On 001 (welcome): create and join every configured channel."""
        self.logger.info("Connected!")

        for channel_cfg in self.channels:
            channel_name = channel_cfg['name']
            self.logger.debug("Joining channel {0}...".format(channel_name))
            await self.create_chat(channel_cfg)

    async def create_chat(self, channel):
        """Create the IRCChat for a configured channel and join it.

        *channel* is a config mapping; only its 'name' key is read here.
        """
        chat = self.IRCChat(self, channel['name'])

        # TODO: bridge registration (from channel.get('bridges')) and storing
        # the chat in self.chats are not implemented here; the previous
        # commented-out attempt was removed as dead code.
        await chat.join()

    async def log(self, conn, message):
        """Catch-all handler: log every received message at DEBUG level."""
        self.logger.debug(message)

    async def user_from_message(self, message):
        """Resolve the sender of ``message`` to a user object.

        ``message.prefix`` is unpacked as a ``(nick, ident, host)`` triple and
        handed to ``user_from_tuple``.
        """
        sender_nick, sender_ident, sender_host = message.prefix
        user = await self.user_from_tuple(sender_nick, sender_ident,
                                          sender_host)
        return user

    async def user_from_tuple(self, nick, ident, host):
        """Return the known user with this exact nick/ident/host, or create one.

        ``self.users`` is scanned in order and the first exact match wins.
        Every candidate is debug-logged along the way. When nothing matches,
        a new ``IRCUser`` is built, registered via ``add_user`` and returned.
        """
        wanted = (nick, ident, host)
        for candidate in self.users:
            self.logger.debug(candidate)
            self.logger.debug(
                "Matching {0} = {1}, {2} = {3}, {4} = {5}".format(
                    nick, candidate.name, ident, candidate.ident, host,
                    candidate.host))
            if (candidate.name, candidate.ident, candidate.host) == wanted:
                self.logger.debug(
                    "Found existing user {0} in service".format(candidate))
                return candidate

        self.logger.debug(
            "Didn't find user in service already, creating new object...")
        created = self.IRCUser(self, nick, ident, host)
        self.add_user(created)
        return created

    async def chat_from_message(self, message):
        """Resolve the chat a message targets, taken from its first parameter."""
        channel_name = message.parameters[0]
        chat = await self.chat_from_name(channel_name)
        return chat

    async def get_chat_and_user_from_message(self, message):
        """Resolve ``message`` into a ``(chat, user, chat_user)`` triple.

        The chat is resolved first, then the sender, then the sender's
        membership object within that chat.
        """
        chat = await self.chat_from_message(message)
        user = await self.user_from_message(message)
        return chat, user, await chat.get_chat_user(user)

    async def chat_from_name(self, name):
        """Return the chat registered under ``name``, creating it if needed.

        A newly created ``IRCChat`` is registered through ``add_chat`` before
        it is returned.
        """
        if name not in self.chats:
            self.logger.debug(
                "Didn't find chat in service already, creating new object...")
            created = self.IRCChat(self, name)
            self.add_chat(created)
            return created

        existing = self.chats[name]
        self.logger.debug(
            "Found existing chat {0} in service".format(existing))
        return existing

    async def on_join(self, conn, message):
        """Handle JOIN for ourselves and for other users.

        When the joiner is us, the chat's member list is queried and the chat
        is flagged as joined. For anyone else, their membership is marked
        active. An ``IRCJoinEvent`` is constructed in both cases.
        """
        chat, user, chat_user = await self.get_chat_and_user_from_message(
            message)
        self.logger.debug("{0} joined {1}".format(user.name, chat.name))

        if user.name != self.nick:
            chat_user.active = True
        else:
            self.logger.debug("We've joined {0}".format(chat.name))
            await chat.query_users()
            # TODO - fix: presumably ``chat`` is already self.chats[chat.name]
            # (chat_from_name registers it); confirm and use it directly.
            self.chats[chat.name].joined = True

        # NOTE(review): the event is built but not dispatched anywhere visible;
        # presumably its constructor does that — confirm.
        self.IRCJoinEvent(self, chat, user)

    async def on_part(self, conn, message):
        """Handle PART for ourselves and for other users.

        When we are the one leaving, the chat is flagged as no longer joined.
        Otherwise the leaver's membership is marked inactive. An
        ``IRCLeaveEvent`` is constructed in both cases.
        """
        chat, user, chat_user = await self.get_chat_and_user_from_message(
            message)
        self.logger.debug("{0} left {1}".format(user.name, chat.name))

        if user.name != self.nick:
            chat_user.active = False
        else:
            self.logger.debug("We've left {0}".format(chat.name))
            # TODO - fix: presumably ``chat`` is already self.chats[chat.name];
            # confirm and use it directly.
            self.chats[chat.name].joined = False

        self.IRCLeaveEvent(self, chat, user)

    async def on_quit(self, conn, message):
        """Handle QUIT: a user (or we) left the network, and so every chat.

        Unlike JOIN/PART, a QUIT message carries no channel;
        ``message.parameters[0]`` is the quit reason. The previous version
        passed it to ``chat_from_message`` and so created a bogus chat named
        after the reason. This version walks every known chat instead and
        constructs one ``IRCQuitEvent`` per chat.

        NOTE(review): the visible handler registrations include no
        ``register('QUIT', ...)``, so this may never be called; confirm
        whether it should be registered.
        """
        user = await self.user_from_message(message)
        self.logger.debug("{0} quit".format(user.name))
        we_quit = user.name == self.nick

        # Snapshot the keys in case the calls below add chats while we iterate.
        for name in list(self.chats):
            chat = self.chats[name]
            if we_quit:
                chat.joined = False
            else:
                # NOTE(review): get_chat_user may create a membership entry for
                # chats the user was never in; confirm that is acceptable.
                chat_user = await chat.get_chat_user(user)
                chat_user.active = False
            self.IRCQuitEvent(self, chat, user)

    async def on_topic(self, conn, message):
        """Handle TOPIC: a user changed a chat's topic.

        ``message.parameters`` is ``[channel, topic]``. An ``IRCTopicEvent``
        is constructed, and then the new topic is stored on the chat. A missing
        topic parameter is treated as an empty topic.
        """
        chat, user, chat_user = await self.get_chat_and_user_from_message(
            message)
        params = message.parameters
        topic = params[1] if len(params) > 1 else ''
        # Strip the ':' trailing-parameter marker only when it is present. An
        # unconditional [1:] cut a real character whenever the parser had
        # already removed the marker.
        if topic.startswith(':'):
            topic = topic[1:]

        self.logger.debug("{0} changed topic on {1} to {2}".format(
            user.name, chat.name, topic))

        # The event is built before the chat's topic is updated, as before.
        self.IRCTopicEvent(self, chat, user, topic)

        chat.topic = topic

    async def on_topicreply(self, conn, message):
        """Handle RPL_TOPIC (332): store the current topic of a chat.

        ``message.parameters`` is ``[client, channel, topic]``.
        """
        chat = await self.chat_from_name(message.parameters[1])
        topic = message.parameters[2]
        # Strip the ':' trailing marker only if the parser left it in place;
        # an unconditional [1:] would otherwise eat the first topic character.
        if topic.startswith(':'):
            topic = topic[1:]

        self.logger.debug("TOPIC for {0} is {1}".format(chat, topic))

        chat.topic = topic

    async def on_nick(self, conn, message):
        """Handle NICK: rename the sending user.

        The new nick is ``message.parameters[0]``. On the wire it may arrive
        with or without a leading ':'. Nicks cannot start with ':', so it is
        stripped only when present. The old unconditional [1:] truncated the
        nick whenever there was no colon.
        """
        user = await self.user_from_message(message)
        new_nick = message.parameters[0]
        if new_nick.startswith(':'):
            new_nick = new_nick[1:]

        self.logger.debug("{0} changed nick to {1}".format(
            user.name, new_nick))

        # The event is built while the user still carries the old nick.
        self.IRCNickEvent(self, user, new_nick)
        user.name = new_nick

    async def on_privmsg(self, conn, message):
        """Handle PRIVMSG: build an ``IRCMessage`` and pass the raw message on.

        ``message.parameters`` is ``[target, text]``. The registered chat's
        ``receive`` gets the raw message together with the sender's
        membership object.
        """
        chat, user, chat_user = await self.get_chat_and_user_from_message(
            message)
        text = message.parameters[1]
        # Only strip a ':' trailing marker that is actually there, so text is
        # not truncated when the parser has already removed it.
        if text.startswith(':'):
            text = text[1:]
        # NOTE(review): the IRCMessage is built but not used further here;
        # presumably its constructor has the side effect that matters — confirm.
        self.IRCMessage(self, chat, user, text)

        self.logger.debug("PRIVMSG received")

        await self.chats[chat.name].receive(message, chat_user)

    async def on_kick(self, conn, message):
        """Handle KICK: log who kicked whom from which chat, and why.

        ``message.parameters`` is ``[channel, kicked nick, comment?]``. The
        comment is optional in IRC, so indexing ``[2]`` unconditionally raised
        IndexError for comment-less kicks; it now defaults to ''.
        """
        chat, user, chat_user = await self.get_chat_and_user_from_message(
            message)
        params = message.parameters

        # NOTE(review): user_by_identifier is called without await, unlike the
        # other lookups here; confirm it is a plain (non-async) method.
        kicked_user = self.user_by_identifier(params[1])
        kick_message = params[2] if len(params) > 2 else ''
        if kick_message.startswith(':'):
            kick_message = kick_message[1:]

        self.logger.debug("{0} kicked {1} from {2} ({3})".format(
            user.name, kicked_user, chat.name, kick_message))

        # TODO - fix: the kicked user is not yet removed from the chat.
        #chat.remove_user(kicked_user)

    async def on_whoreply(self, conn, message):
        """Handle RPL_WHOREPLY (352): record a channel member's details.

        ``message.parameters`` is
        ``[client, channel, ident, host, server, nick, flags, "<hops> <real name>"]``.
        The user's real name and server are stored, and the membership is
        marked active. It is marked operator only when the flags carry a
        channel-operator (or higher) prefix. Previously every user was marked
        operator unconditionally.
        """
        params = message.parameters
        chat = await self.chat_from_name(params[1])
        user = await self.user_from_tuple(params[5], params[2], params[3])

        # Trailing field is "<hopcount> <real name>", possibly still carrying
        # the ':' marker. Split off the hopcount instead of assuming exactly
        # ":<one digit> " precedes the name (the old [3:]).
        trailing = params[7]
        if trailing.startswith(':'):
            trailing = trailing[1:]
        user.real_name = trailing.partition(' ')[2]
        user.server = params[4]

        chat_user = await chat.get_chat_user(user)
        chat_user.active = True
        # Flags look like "H", "G@", "H*+": '@' is channel op; '~' (owner) and
        # '&' (admin) rank above it on servers that use them. '*' is an IRC
        # operator, which is not a channel role.
        flags = params[6]
        chat_user.operator = any(symbol in flags for symbol in '~&@')

    async def on_invite(self, conn, message):
        """Handle INVITE; for now the raw message is only debug-logged."""
        logger = self.logger
        logger.debug(message)

    async def on_mode(self, conn, message):
        """Handle MODE; for now the raw message is only debug-logged."""
        logger = self.logger
        logger.debug(message)

    async def quit(self):
        """Ask the protocol to quit, logging before and after the call."""
        logger = self.logger
        logger.debug("Quitting...")
        self.conn.quit()
        logger.debug("Disconnected!")
示例#11
0
def test_imports():
    """Smoke test: asyncirc's IrcProtocol imports and builds with no servers."""
    from asyncirc.protocol import IrcProtocol

    IrcProtocol([], '')
示例#12
0
class Server:
    def __init__(self,
                 loader: ModuleLoader,
                 config: ServerConfig,
                 loop=None) -> None:
        """Store the loader and config, pick an event loop, build the protocol.

        ``loop`` falls back to ``asyncio.get_event_loop()`` when not supplied.
        The protocol is created but not connected here; see ``connect``.
        """
        self._loader = loader
        self._config = config
        self._modules = {}
        self._active_channels = set()
        self._connected = False
        self._loop = loop if loop else asyncio.get_event_loop()
        irc_server = IrcServer(config.address, config.port, config.ssl)
        self._conn = IrcProtocol([irc_server], config.nick, loop=self._loop)
        # Every incoming line is routed through on_server_message.
        self._conn.register("*", self.on_server_message)

    @property
    def config(self) -> ServerConfig:
        """The ``ServerConfig`` this server was created or last reloaded with."""
        current = self._config
        return current

    @property
    def address(self) -> str:
        """Server address taken from the current ``config``."""
        current = self.config
        return current.address

    @property
    def port(self) -> int:
        """Server port taken from the current ``config``."""
        current = self.config
        return current.port

    @property
    def close_future(self):
        """The underlying protocol's ``quit_future``.

        NOTE(review): presumably this completes when the connection quits;
        confirm against asyncirc's ``IrcProtocol``.
        """
        protocol = self._conn
        return protocol.quit_future

    @property
    def loop(self):
        """Event loop this server creates tasks and schedules timers on."""
        event_loop = self._loop
        return event_loop

    async def connect(self) -> None:
        """Load the configured modules first, then open the IRC connection."""
        await self.load_modules()
        connection = self._conn
        await connection.connect()

    async def disconnect(self) -> None:
        """
        Unloads all modules and closes the IRC connection.

        The connection is closed even if a module's unload hook raises; that
        error still propagates to the caller afterwards. Previously such a
        failure skipped ``quit()`` and left the connection open. The set of
        active channels is cleared as well, since after quitting we are in no
        channel.
        """
        log.debug("Disconnecting from %s", self.address)
        try:
            await self.unload_modules()
        finally:
            self._connected = False
            self._active_channels.clear()
            self._conn.quit()

    async def reload(self, config: ServerConfig) -> None:
        """
        Reloads a server based on a new configuration.

        Only module-level settings may change here; address, port and SSL
        must stay the same (enforced with an assert). After the new config is
        stored, modules are reloaded against it.
        """
        old = self.config
        assert (old.address == config.address and old.port == config.port
                and old.ssl == config.ssl), (
                    "changing a connection must be done through the server manager")
        self._config = config
        await self.reload_modules()

    async def reload_modules(self) -> None:
        """
        Unloads, and then reloads all modules for this server.

        A loaded module is unloaded when it no longer appears in the config,
        when its config changed, or when its (current) config sets
        ``always_reload``. Everything missing is then loaded, and channels are
        re-matched.
        """
        log.debug("Reloading modules")
        wanted = self.config.modules
        stale = []
        for module in self._modules.values():
            if module.name not in wanted:
                stale.append(module.name)
                continue
            if (module.config != wanted[module.name]
                    or module.config.always_reload):
                log.debug("Scheduling %s for reload", module.name)
                stale.append(module.name)

        await self.unload_modules(stale)
        await self.load_modules()
        self.match_channels()

    async def load_modules(self) -> None:
        """
        Loads all modules that have not yet been loaded for this server.

        Each module class comes from the loader and is constructed with
        ``(config, self)``. Its ``on_load`` coroutine is awaited, and so is its
        ``on_connect`` if the server is already connected. Only then is it
        registered in ``self._modules``. A module that fails to load is logged
        and skipped, and the remaining modules still load.
        """
        log.debug("Loading modules")
        for config in self.config.modules.values():
            if config.name in self._modules:
                continue
            on_load = None
            try:
                ctor = self._loader.load_module(config.name)
                loaded = ctor(config, self)
                on_load = self.loop.create_task(loaded.on_load())
                await on_load
                if self._connected:
                    await loaded.on_connect()
                self._modules[config.name] = loaded
            except KeyboardInterrupt:
                if on_load is not None:
                    on_load.cancel()
                raise
            except Exception:
                # Was a bare ``except:``, which also swallowed SystemExit and,
                # on Python 3.8+, asyncio.CancelledError (blocking cancellation).
                log.exception("Could not load module %s", config.name)

    async def unload_modules(self,
                             which: Optional[Sequence[str]] = None) -> None:
        """
        Unloads specified modules for this server.

        If nothing is specified, all modules are unloaded.

        If explicitly zero modules are specified (i.e. an empty set), no modules are unloaded.

        Each module is unloaded from the loader and removed from this server.
        Then all their ``on_unload`` coroutines are awaited together.

        Raises:
            KeyError: if a name in ``which`` is not currently loaded. This is
                now checked before anything is unloaded. Previously the error
                surfaced midway, after the loader had already been told to
                unload that name, and left earlier ``on_unload()`` coroutines
                un-awaited.
        """
        log.debug("Unloading modules")
        if which is None:
            which = set(self._modules.keys())
        for module_name in which:
            if module_name not in self._modules:
                raise KeyError(module_name)
        unloaded = []
        for module_name in which:
            self._loader.unload_module(module_name)
            unloaded.append(self._modules.pop(module_name).on_unload())

        await asyncio.gather(*unloaded)

    def match_channels(self):
        """
        Sends JOIN/PART so the channels we are in match what modules need.

        The needed channels are the union of every loaded module's
        ``config.channels``. IRC channel names are case-insensitive (RFC 2812
        section 1.3), and the server may echo a JOIN with different
        capitalisation than the config used. Comparing raw strings then made
        the bot PART a channel it had just joined and JOIN it again on every
        rematch. Names are therefore compared with ``str.lower()``, which
        approximates IRC casemapping. The JOIN uses the configured spelling;
        the PART uses the spelling recorded from the server.
        """
        need = {
            chan.lower(): chan
            for module in self._modules.values()
            for chan in module.config.channels
        }
        active = {chan.lower(): chan for chan in self._active_channels}
        for key in need.keys() - active.keys():
            self._conn.send("JOIN " + need[key])
        for key in active.keys() - need.keys():
            self._conn.send("PART " + active[key])

    async def on_server_message(self, conn, msg) -> None:
        """
        Callback that is called whenever a message is received.

        Registered for every command ("*"). RPL_WELCOME (001) marks the server
        connected and runs ``on_connect``. KICK, PART and JOIN go to their
        dedicated handlers; anything else is passed to ``on_message``.
        """
        command = msg.command
        if command == "001":
            self._connected = True
            await self.on_connect()
            return

        handlers = {
            "KICK": self.on_kick,
            "PART": self.on_part,
            "JOIN": self.on_join,
        }
        handler = handlers.get(command, self.on_message)
        await handler(msg)

    async def on_connect(self) -> None:
        """
        Callback that is run when this server connects.

        A fresh registration means we are in no channels yet. That includes a
        reconnect, where ``_active_channels`` would otherwise still hold the
        previous session's channels and ``match_channels`` would skip
        rejoining them. The set is therefore cleared before matching. Every
        module's ``on_connect`` then runs concurrently. Their errors are
        logged, as in the other event callbacks.
        """
        self._active_channels.clear()
        self.match_channels()
        futures = [module.on_connect() for module in self._modules.values()]
        # No ``loop=``: it was deprecated in Python 3.8 and removed in 3.10.
        # From inside a coroutine, gather already uses the running loop.
        tasks = asyncio.gather(*futures)
        try:
            await tasks
        except KeyboardInterrupt:
            tasks.cancel()
            raise
        except Exception:
            log.exception("Error while handling connect callback")

    async def on_kick(self, msg):
        """
        Callback that is run when a user is kicked from a channel.

        Modules receive ``on_kick(channel, who)``, where ``who`` is ``None``
        when we were the one kicked. In that case channels are re-matched (and
        so rejoined if still needed) after a 3-second delay. Errors raised by
        module callbacks are logged, not propagated.
        """
        channel = msg.parameters[0]
        who = msg.parameters[1]
        if who == self.config.nick:
            who = None
            # discard, not remove: a channel we never recorded as active must
            # not abort the handler with KeyError before modules are told.
            self._active_channels.discard(channel)
        futures = [
            module.on_kick(channel, who) for module in self._modules.values()
        ]
        # No ``loop=``: deprecated in 3.8 and removed in 3.10; gather uses the
        # running loop when called from a coroutine.
        tasks = asyncio.gather(*futures)
        try:
            await tasks
        except KeyboardInterrupt:
            tasks.cancel()
            raise
        except Exception:
            log.exception("Error while handling kick callback")

        if who is None:
            self.loop.call_later(3.0, self.match_channels)

    async def on_part(self, msg):
        """
        Callback that is run when a user leaves a channel.

        Modules receive ``on_part(channel, who)``, where ``who`` is ``None``
        when we left. In that case channels are re-matched after a 3-second
        delay. Errors raised by module callbacks are logged, not propagated.
        """
        channel = msg.parameters[0]
        who = msg.prefix.nick
        if who == self.config.nick:
            who = None
            # discard, not remove: a PART for a channel we never recorded must
            # not abort the handler with KeyError.
            self._active_channels.discard(channel)
        futures = [
            module.on_part(channel, who) for module in self._modules.values()
        ]
        # No ``loop=``: deprecated in 3.8 and removed in 3.10; gather uses the
        # running loop when called from a coroutine.
        tasks = asyncio.gather(*futures)
        try:
            await tasks
        except KeyboardInterrupt:
            tasks.cancel()
            raise
        except Exception:
            log.exception("Error while handling part callback")

        if who is None:
            self.loop.call_later(3.0, self.match_channels)

    async def on_join(self, msg):
        """
        Callback that is run when a user joins a channel.

        Modules receive ``on_join(channel, who)``, where ``who`` is ``None``
        when we joined. In that case the channel is recorded as active and
        channels are re-matched after a 3-second delay. Errors raised by module
        callbacks are logged, not propagated.
        """
        channel = msg.parameters[0]
        who = msg.prefix.nick
        if who == self.config.nick:
            who = None
            self._active_channels.add(channel)
        futures = [
            module.on_join(channel, who) for module in self._modules.values()
        ]
        # No ``loop=``: deprecated in 3.8 and removed in 3.10; gather uses the
        # running loop when called from a coroutine.
        tasks = asyncio.gather(*futures)
        try:
            await tasks
        except KeyboardInterrupt:
            tasks.cancel()
            raise
        except Exception:
            log.exception("Error while handling join callback")

        if who is None:
            self.loop.call_later(3.0, self.match_channels)

    async def on_message(self, msg):
        """
        Callback that is run when a PRIVMSG (i.e. a channel or private message) is received.

        In practice this receives every message not handled more specifically
        (including non-PRIVMSG commands). Modules opt in through
        ``should_handle(msg)``. ``channel`` is ``None`` when the first
        parameter is not one of our active channels (e.g. a private message).
        ``who`` is ``None`` when the message came from us. Messages without a
        prefix or without any parameters are ignored; the latter previously
        raised IndexError.
        """
        if msg.prefix is None or not msg.parameters:
            return
        channel = msg.parameters[0]
        if channel not in self._active_channels:
            # private message to us
            channel = None
        who = msg.prefix.nick
        if who == self.config.nick:
            who = None
        text = " ".join(msg.parameters[1:])
        futures = [
            module.on_message(channel, who, text)
            for module in self._modules.values() if module.should_handle(msg)
        ]
        # No ``loop=``: deprecated in 3.8 and removed in 3.10; gather uses the
        # running loop when called from a coroutine.
        tasks = asyncio.gather(*futures)
        try:
            await tasks
        except KeyboardInterrupt:
            tasks.cancel()
            raise
        except Exception:
            log.exception("Error handling channel message")

    def send_message(self, target: str, message: str) -> None:
        """
        Sends ``message`` as a PRIVMSG to ``target`` (a channel or nick).

        The text goes out as the ':'-prefixed trailing parameter. Without the
        colon, the server treats only the first word as the message, so
        multi-word messages were truncated. Each line of a multi-line message
        is sent as its own PRIVMSG, so an embedded CR/LF can no longer inject
        extra IRC commands. Empty lines (and an empty message) send nothing.
        """
        for line in message.splitlines():
            if line:
                self._conn.send("PRIVMSG {} :{}".format(target, line))