Beispiel #1
0
def test_exceptions():
    """A single-slot LRU stores a value but refuses explicit removal.

    After one insertion the cache compares equal to a plain dict with the
    same item, while both ``del cache[key]`` and ``cache.pop(key)`` are
    expected to raise NotImplementedError.
    """
    cache = LRU(1)
    cache['a'] = 1

    assert len(cache) == 1
    assert cache == {'a': 1}

    def delete_key():
        del cache['a']

    def pop_key():
        cache.pop('a')

    # Each removal style must be rejected independently, in this order.
    for removal in (delete_key, pop_key):
        with pytest.raises(NotImplementedError):
            removal()
Beispiel #2
0
 def test_pop(self):
     """LRU.pop removes present keys, honours defaults, and updates stats.

     The expected ``get_stats()`` tuples show that popping a present key
     advances the first counter and popping a missing key advances the
     second (presumably hits and misses).
     """
     lru = LRU(2)
     big = '2' * 4096
     lru[1] = '1'
     lru[2] = big

     # Present key, no default supplied.
     val = lru.pop(1)
     self.assertEqual('1', val)
     self.assertEqual((1, 0), lru.get_stats())

     # Present key: the supplied default must be ignored.
     val = lru.pop(2, 'not used')
     self.assertEqual(big, val)
     # Drop the popped reference and make sure the original object is still
     # alive and usable (presumably guards against a refcount bug in pop).
     del val
     self.assertTrue(big)
     self.assertEqual((2, 0), lru.get_stats())

     # Missing key with a default: the default is returned.
     val = lru.pop(3, '3' * 4096)
     self.assertEqual('3' * 4096, val)
     self.assertEqual((2, 1), lru.get_stats())
     self.assertEqual(0, len(lru))

     # Missing key without a default raises KeyError carrying the key.
     with self.assertRaises(KeyError) as ctx:
         lru.pop(4)
     # Must be checked outside the ``with``: statements after the raising
     # call inside the block never execute. The raised exception lives on
     # the context manager's ``exception`` attribute.
     self.assertEqual(4, ctx.exception.args[0])
     self.assertEqual((2, 2), lru.get_stats())
     self.assertEqual(0, len(lru))

     # A key argument is mandatory.
     with self.assertRaises(TypeError):
         lru.pop()
Beispiel #3
0
class StarboardEntries:
    """A way of managing starboard entries.
    Sort of like an ORM, but also not fully.

    Lookups go through an in-memory LRU cache before falling back to the
    ``starboard`` table. Writes (``upsert`` / ``delete``) update the cache
    immediately and queue their SQL on ``_sql_queries``. A background task
    (``_sql_loop``) executes the queued statements one after another.
    """

    _pool: asyncpg.Pool = attr.ib()
    # note: entry cache isn't really a dict, but for typehinting purposes this works
    _entry_cache: typing.Dict[int, StarboardEntry] = attr.ib()
    _sql_loop_task: asyncio.Task = attr.ib()
    _sql_queries: cclass.SetUpdateAsyncQueue = attr.ib()

    def __init__(self, pool: asyncpg.Pool, cache_size: int = 200):
        """Set up the cache and queue, and start the background SQL task.

        Args:
            pool: The asyncpg pool every query runs against.
            cache_size: Maximum number of keys the LRU entry cache holds.
                The default of 200 should be raised as the bot grows bigger.
        """
        self._pool = pool
        self._entry_cache = LRU(cache_size)
        self._sql_queries = cclass.SetUpdateAsyncQueue()

        loop = asyncio.get_event_loop()
        self._sql_loop_task = loop.create_task(self._sql_loop())

    def stop(self):
        """Stops the SQL task loop.

        Statements still waiting in the queue are not run by this class
        after cancellation.
        """
        self._sql_loop_task.cancel()

    async def _sql_loop(self):
        """Actually runs SQL updating, hopefully one after another.

        Saves speed on adding, deleting, and updating by offloading
        this step here. A failing statement is logged and skipped so
        that later statements still run. Cancellation ends the loop.
        """
        logger = logging.getLogger("discord")
        try:
            while True:
                entry = await self._sql_queries.get()
                logger.debug(f"Running {entry.query}.")
                try:
                    await self._pool.execute(entry.query, *entry.args, timeout=60)
                except asyncio.CancelledError:
                    # Re-raise so the outer handler ends the loop; on older
                    # Pythons CancelledError subclasses Exception.
                    raise
                except Exception:
                    # Without this, one bad statement would kill the task and
                    # every later write would sit in the queue forever.
                    logger.exception(f"Failed to run {entry.query}.")
                finally:
                    # The item was dequeued either way, so mark it done.
                    self._sql_queries.task_done()
        except asyncio.CancelledError:
            pass

    def _get_required_from_entry(self, entry: StarboardEntry):
        """Transforms data into the form needed for databases.

        The tuple order matches the ``$1``..``$11`` placeholders used by
        ``_handle_upsert``. Reactor collections are converted to lists.
        """
        return (
            entry.ori_mes_id,
            entry.ori_chan_id,
            entry.star_var_id,
            entry.starboard_id,
            entry.author_id,
            list(entry.ori_reactors),
            list(entry.var_reactors),
            entry.guild_id,
            entry.forced,
            entry.frozen,
            entry.trashed,
        )

    def _str_builder_to_insert(
        self, str_builder: typing.List[str], entry: StarboardEntry
    ):
        """Takes data from a string builder list and eventually
        puts the data needed into the _sql_queries variable."""
        query = "".join(str_builder)
        args = self._get_required_from_entry(entry)
        self._sql_queries.put_nowait(StarboardSQLEntry(query, args))

    def _handle_upsert(self, entry: StarboardEntry):
        """Upserts an entry by using an INSERT with an ON CONFLICT clause.
        This is a PostgreSQL-specific feature, so that's nice!"""
        str_builder = [
            "INSERT INTO starboard(ori_mes_id, ori_chan_id, star_var_id, ",
            "starboard_id, author_id, ori_reactors, var_reactors, ",
            "guild_id, forced, frozen, trashed) VALUES($1, $2, $3, $4, ",
            "$5, $6, $7, $8, $9, $10, $11) ON CONFLICT (ori_mes_id) DO UPDATE ",
            "SET ori_chan_id = $2, star_var_id = $3, starboard_id = $4, ",
            "author_id = $5, ori_reactors = $6, var_reactors = $7, guild_id = $8, ",
            "forced = $9, frozen = $10, trashed = $11",
        ]
        self._str_builder_to_insert(str_builder, entry)

    def upsert(self, entry: StarboardEntry):
        """Either adds or updates an entry in the collection of entries.

        The entry is cached under its original message id and, when set,
        under its ``star_var_id`` too. Then the database upsert is queued.
        """
        # Assign keys one by one: the cache is keyed by ints, and
        # ``update(**mapping)`` rejects non-string keys with a TypeError.
        self._entry_cache[entry.ori_mes_id] = entry
        if entry.star_var_id:
            self._entry_cache[entry.star_var_id] = entry
        self._handle_upsert(entry)

    def delete(self, entry_id: int):
        """Removes an entry from the collection of entries.

        ``entry_id`` is the original message id; the queued DELETE filters
        on ``ori_mes_id``. If the entry was cached, its alias under
        ``star_var_id`` is evicted too, so ``get`` cannot serve it from the
        cache afterwards.
        """
        entry = self._entry_cache.pop(entry_id, None)
        if entry is not None and entry.star_var_id:
            self._entry_cache.pop(entry.star_var_id, None)
        self._sql_queries.put_nowait(
            StarboardSQLEntry("DELETE FROM starboard WHERE ori_mes_id = $1", [entry_id])
        )

    async def get(
        self, entry_id: int, check_for_var: bool = False
    ) -> typing.Optional[StarboardEntry]:
        """Gets an entry from the collection of entries.

        The lookup order is:
        1. the cache, by key;
        2. a scan of the cached entries for a matching ``star_var_id``;
        3. the database, matching ``ori_mes_id`` or ``star_var_id``.
        A database hit is cached under ``entry_id``.

        Args:
            entry_id: An original message id or a starboard message id.
            check_for_var: If true, entries without a ``star_var_id`` are
                treated as missing.

        Returns:
            The matching entry, or None.
        """
        if self._entry_cache.has_key(entry_id):  # type: ignore
            entry = self._entry_cache[entry_id]
        else:
            entry = discord.utils.find(
                lambda e: e and e.star_var_id == entry_id, self._entry_cache.values()
            )

        if not entry:
            async with self._pool.acquire() as conn:
                # Bound parameter instead of string interpolation, so a
                # non-integer id cannot inject SQL.
                data = await conn.fetchrow(
                    "SELECT * FROM starboard WHERE ori_mes_id = $1 OR"
                    " star_var_id = $1",
                    entry_id,
                )
            if data:
                entry = StarboardEntry.from_row(data)
                self._entry_cache[entry_id] = entry

        if entry and check_for_var and not entry.star_var_id:
            return None

        return entry

    @staticmethod
    def _rows_to_entries(
        rows,
    ) -> typing.Optional[typing.Tuple[StarboardEntry, ...]]:
        """Convert fetched rows to entries; ``None`` when nothing matched."""
        if not rows:
            return None
        return tuple(StarboardEntry.from_row(row) for row in rows)

    async def select_query(self, query: str):
        """Selects the starboard database directly for entries based on the query.

        ``query`` is pasted verbatim after ``WHERE``; it must be trusted SQL.
        Returns a tuple of entries, or None when nothing matched.
        """
        async with self._pool.acquire() as conn:
            data = await conn.fetch(f"SELECT * FROM starboard WHERE {query}")
        return self._rows_to_entries(data)

    async def raw_query(self, query: str):
        """Runs the raw query against the pool, assuming the results are starboard entries.

        ``query`` must be trusted SQL. Returns a tuple of entries, or None
        when nothing matched.
        """
        async with self._pool.acquire() as conn:
            data = await conn.fetch(query)
        return self._rows_to_entries(data)

    async def super_raw_query(self, query: str):
        """You want a raw query? You'll get one.

        Returns the fetched rows unchanged; ``query`` must be trusted SQL.
        """
        async with self._pool.acquire() as conn:
            return await conn.fetch(query)

    async def query_entries(
        self, seperator: str = "AND", **conditions: typing.Any
    ) -> typing.Optional[typing.Tuple[StarboardEntry, ...]]:
        """Queries entries based on conditions provided.

        For example, you could do `query_entries(guild_id=143425)` to get
        entries with that guild id. Conditions are joined with
        ``seperator`` and at least one is required; an empty WHERE clause
        is invalid SQL.

        Returns a tuple of entries, or None when nothing matched.
        """
        # NOTE(review): keys and values are interpolated into the SQL text
        # verbatim, so callers must only pass trusted column names and
        # SQL-literal values.
        sql_conditions: typing.List[str] = [
            f"{key} = {value}" for key, value in conditions.items()
        ]
        combined_statements = f" {seperator} ".join(sql_conditions)

        async with self._pool.acquire() as conn:
            data = await conn.fetch(
                f"SELECT * FROM starboard WHERE {combined_statements}"
            )
        return self._rows_to_entries(data)

    async def get_random(self, guild_id: int) -> typing.Optional[StarboardEntry]:
        """Gets a random entry from a guild.

        Only entries with a ``star_var_id`` are considered. Returns None
        when the guild has no such entries.
        """
        # query adapted from
        # https://github.com/Rapptz/RoboDanny/blob/1fb95d76d1b7685e2e2ff950e11cddfc96efbfec/cogs/stars.py#L1082
        query = """SELECT *
                   FROM starboard
                   WHERE guild_id=$1
                   AND star_var_id IS NOT NULL
                   OFFSET FLOOR(RANDOM() * (
                       SELECT COUNT(*)
                       FROM starboard
                       WHERE guild_id=$1
                       AND star_var_id IS NOT NULL
                   ))
                   LIMIT 1
                """

        async with self._pool.acquire() as conn:
            data = await conn.fetchrow(query, guild_id)
            if not data:
                return None
            return StarboardEntry.from_row(data)
Beispiel #4
0
class MTPStorage(MTPEntry, MTPRefresh):
   """A storage area of an MTP device, presented as a directory.

   With ``pstorage`` None this object is a synthetic root ('/'). Its
   ``directories`` hold one MTPFolder per storage description reported by
   ``mtp.get_storage_descriptions()``. Otherwise it wraps the storage record
   ``pstorage.contents`` and delegates listings to an MTPFolder ``root``.
   Resolved paths are cached in ``self.contents`` (an LRU of
   PATH_CACHE_SIZE entries), keyed by the utf8-converted path.
   """

   def __init__(self, mtp, pstorage=None):
      MTPRefresh.__init__(self)
      self.mtp = mtp
      self.libmtp = mtp.libmtp
      self.open_device = mtp.open_device
      self.directories = None
      # PATH_CACHE_SIZE is a module-level setting; it is only read here, so
      # no ``global`` statement is needed.
      self.contents = LRU(PATH_CACHE_SIZE)
      if pstorage is None:
         # Synthetic root: every storage becomes a top-level folder. The -3 /
         # -2 ids look like sentinel values (presumably "not a real object")
         # -- TODO confirm.
         MTPEntry.__init__(self, -3, '/')
         self.storage = None
         self.directories = [
            MTPFolder(path=dirname, id=-3, storageid=-3, folderid=-2, is_refresh=False)
            for dirname in self.mtp.get_storage_descriptions()
         ]
         self.root = None
         self.contents[utf8(os.sep)] = self
      else:
         self.storage = pstorage
         storage = pstorage.contents
         self.type = storage.StorageType
         self.freespace = storage.FreeSpaceInBytes
         self.capacity = storage.MaxCapacity
         path = os.sep + storage.StorageDescription
         MTPEntry.__init__(self, storage.id, path, storageid=None, folderid=0)
         self.root = MTPFolder(path=path, id=0, storageid=storage.id, folderid=0, mtp=self.mtp)
         self.contents[utf8(path)] = self.root

   def is_directory(self):
      """A storage is always presented as a directory."""
      return True

   def get_attributes(self):
      """Return stat-style attributes for a read/execute-able directory."""
      # 0o755 is valid on both Python 2.6+ and Python 3 (0755 is not on 3).
      return { 'st_atime': self.timestamp, 'st_ctime': self.timestamp, 'st_gid': os.getgid(),
               'st_mode': stat.S_IFDIR | 0o755, 'st_mtime': self.timestamp, 'st_nlink': 1,
               'st_size': 0, 'st_uid': os.getuid() }

   def get_directories(self):
      """Return child directories: the storage list for the synthetic root,
      the root folder's directories otherwise, or () when neither exists."""
      if self.directories is not None:
         return self.directories
      if self.root is None:
         return ()
      return self.root.get_directories()

   def get_files(self):
      """Return the root folder's files, or () when there is no root."""
      if self.root is None:
         return ()
      return self.root.get_files()

   def add_file(self, file):
      """Add ``file`` to the root folder; a no-op for the synthetic root."""
      if self.root is not None:
         # ``self.root.add_file`` is already bound to the folder. The
         # original also passed ``self`` as an extra positional argument.
         self.root.add_file(file)

   def __str__(self):
      s = "MTPStorage %s: id=%d, device=%s%s" % (self.name, self.id, self.open_device, os.linesep)
      return s

   def find_entry(self, path):
      """Resolve ``path`` to an entry on this storage, or return None.

      A blank path means this storage itself. The cache is checked first
      (cached directories are refreshed if flagged). Otherwise the first
      path component must equal this storage's name and the rest are walked
      from ``root``. Any error is logged and yields None.
      """
      path = utf8(path)
      self.log.debug('find_entry(%s)', path)
      try:
         if path.strip() == '':
            path = os.sep + self.name
         entry = self.contents.get(path)
         if entry is None:
            components = [comp for comp in path.split(os.sep) if comp.strip()]
            if not components:
               return None
            if components[0] != self.name:
               raise LookupError('Invalid storage (expected %s, was %s)' % (self.name, components[0]))
            entry = self.__find_entry(self.root, components[1:])
         elif entry.is_directory() and entry.must_refresh:
            entry.refresh()
         return entry
      except Exception:
         # Best-effort lookup: log and report "not found", but do not swallow
         # KeyboardInterrupt/SystemExit as the former bare ``except:`` did.
         self.log.exception("")
         return None

   def __find_entry(self, entry, components):
      """Walk ``components`` below ``entry``, caching directories found.

      Intermediate components must be directories; the last one may be a
      file, which is looked up with ``find_file`` (not cached).
      """
      self.log.debug("__find_entry(%s, %s)", entry, components)
      if not components:
         return entry
      name = components[0]
      path = entry.path + os.sep + name
      en = self.contents.get(utf8(path))
      if en is not None:
         if en.is_directory() and en.must_refresh:
            en.refresh()
         return self.__find_entry(en, components[1:])
      en = entry.find_directory(name)
      if en is not None and en.is_directory():
         self.contents[utf8(path)] = en
         if en.must_refresh:
            en.refresh()
         return self.__find_entry(en, components[1:])
      return entry.find_file(name)

   def remove_entry(self, path):
      """Drop ``path`` from the path cache; return the removed entry or None.

      A missing key (KeyError from ``pop``) is the expected "not cached"
      case and is treated as a no-op.
      """
      try:
         return self.contents.pop(utf8(path))
      except Exception:
         return None

   def refresh(self):
      """Refresh the root folder if flagged; stays flagged if refresh fails."""
      if self.root is not None and self.must_refresh:
         self.must_refresh = not self.root.refresh()

   def close(self):
      """Nothing to release for a storage."""
      pass