Esempio n. 1
0
 def KfpMigrate(database: SqliteDatabase):
     """Apply the incremental schema changes for the tables that already exist.

     Each table is handled only when it is present in ``database``. Every
     change is guarded by a ``KfpMigrator.hasColumn`` check against the column
     list read at the start of that table's section, and is executed through
     its own ``migrate`` call.

     Args:
         database: The SQLite database whose schema is inspected and altered.

     Returns:
         Always ``True``.
     """
     existing_tables = database.get_tables()
     schema_migrator = SqliteMigrator(database)

     # rpgcharacter: add "retired" and "last_attack" when they are missing.
     if "rpgcharacter" in existing_tables:
         character_columns = database.get_columns("rpgcharacter")
         if not KfpMigrator.hasColumn("retired", character_columns):
             migrate(
                 schema_migrator.add_column(
                     "rpgcharacter", "retired", BooleanField(default=False)
                 )
             )
         if not KfpMigrator.hasColumn("last_attack", character_columns):
             # The default is a fixed timestamp: one day before the migration runs.
             one_day_ago = datetime.now() - timedelta(days=1)
             migrate(
                 schema_migrator.add_column(
                     "rpgcharacter", "last_attack", DateTimeField(default=one_day_ago)
                 )
             )

     # member: add the "token" balance column.
     if "member" in existing_tables:
         member_columns = database.get_columns("member")
         if not KfpMigrator.hasColumn("token", member_columns):
             migrate(
                 schema_migrator.add_column(
                     "member", 'token', BigIntegerField(default=100)
                 )
             )

     # channel: a missing "channel_id" triggers adding "channel_guild_id" and
     # renaming "channel_discord_id" to "channel_id" in one migrate call.
     if "channel" in existing_tables:
         channel_columns = database.get_columns("channel")
         if not KfpMigrator.hasColumn("channel_id", channel_columns):
             add_guild_id = schema_migrator.add_column(
                 'channel', 'channel_guild_id', IntegerField(default=-1)
             )
             rename_discord_id = schema_migrator.rename_column(
                 'channel', 'channel_discord_id', 'channel_id'
             )
             migrate(add_guild_id, rename_discord_id)

     # item: drop the obsolete columns first, then add the new ones.
     if "item" in existing_tables:
         item_columns = database.get_columns("item")

         for obsolete_name in ("hidden", "buff_type", "buff_value"):
             if KfpMigrator.hasColumn(obsolete_name, item_columns):
                 migrate(schema_migrator.drop_column('item', obsolete_name))

         # Factories keep field construction lazy: a field object is only
         # built when its column actually has to be added.
         new_item_fields = (
             ("type", lambda: CharField(default=ItemType.NONE)),
             ("buff", lambda: BuffField(default=Buff(BuffType.NONE, 0, -1))),
             ("description", lambda: CharField(default="")),
         )
         for column_name, make_field in new_item_fields:
             if not KfpMigrator.hasColumn(column_name, item_columns):
                 migrate(
                     schema_migrator.add_column('item', column_name, make_field())
                 )

     return True
Esempio n. 2
0
def get_table(table_name: str,
              database: peewee.SqliteDatabase) -> peewee.Table:
    """Return a ``peewee.Table`` for ``table_name`` bound to ``database``.

    The table's column names are taken from ``database.get_columns``.

    Args:
        table_name: Name of the table to describe.
        database: Database that is queried for the columns and that the
            resulting table object is bound to.

    Returns:
        The bound ``peewee.Table`` object.
    """
    column_names = [meta.name for meta in database.get_columns(table_name)]
    bound_table = peewee.Table(table_name, tuple(column_names))
    bound_table.bind(database)
    return bound_table
Esempio n. 3
0
from peewee import SqliteDatabase

# Database handle for the application's SQLite file (hard-coded Windows path).
db = SqliteDatabase('c:\\bench\\allocat\\allocat.db')
# Names of all columns of the ``payrecords`` table, in the order returned by
# ``get_columns``. NOTE: this queries the database at import time.
diff_fields = [f.name for f in db.get_columns('payrecords')]
# The same names without the first three columns and without the last one.
rec_fields = diff_fields[3:-1]


def main():
    """Launch the Qt application.

    Creates the ``QApplication``, applies the stylesheet read from
    ``views/allocat.stylesheet``, starts the ``PayrunController`` and exits
    the process with the status returned by the Qt event loop.
    """
    import sys
    from PyQt5.QtWidgets import QApplication
    from controllers import PayrunController

    qt_app = QApplication(sys.argv)

    # Read the whole stylesheet first, then hand it to Qt.
    stylesheet_path = 'views/allocat.stylesheet'
    with open(stylesheet_path, "r") as stylesheet:
        css = stylesheet.read()
    qt_app.setStyleSheet(css)

    controller = PayrunController()
    controller.runit()

    exit_code = qt_app.exec_()
    sys.exit(exit_code)


if __name__ == '__main__':
    main()
Esempio n. 4
0
class DatabaseManager:
    """SQLite-backed storage for chat associations, message logs and cached
    slave chat information.

    The peewee models are declared inside ``__init__`` so that they are bound
    to this instance's database, and are exposed as ``self.ChatAssoc``,
    ``self.MsgLog`` and ``self.SlaveChatInfo``.
    """

    logger = logging.getLogger(__name__)

    def __init__(self, channel: EFBChannel):
        """Open ``tgdata.db`` in the channel's data directory and make sure
        the schema exists and is current.

        Args:
            channel: Channel whose ``channel_id`` selects the data directory.
        """
        base_path = utils.get_data_path(channel.channel_id)

        self.db = SqliteDatabase(str(base_path / 'tgdata.db'))

        self.db.connect()

        class BaseModel(Model):
            class Meta:
                database = self.db

        class ChatAssoc(BaseModel):
            master_uid = TextField()
            slave_uid = TextField()

        class MsgLog(BaseModel):
            master_msg_id = TextField(unique=True, primary_key=True)
            master_msg_id_alt = TextField(null=True)
            slave_message_id = TextField()
            text = TextField()
            slave_origin_uid = TextField()
            slave_origin_display_name = TextField(null=True)
            slave_member_uid = TextField(null=True)
            slave_member_display_name = TextField(null=True)
            media_type = TextField(null=True)
            mime = TextField(null=True)
            file_id = TextField(null=True)
            msg_type = TextField()
            sent_to = TextField()
            time = DateTimeField(default=datetime.datetime.now, null=True)

        class SlaveChatInfo(BaseModel):
            slave_channel_id = TextField()
            slave_channel_emoji = CharField()
            slave_chat_uid = TextField()
            slave_chat_name = TextField()
            slave_chat_alias = TextField(null=True)
            slave_chat_type = CharField()

        self.BaseModel = BaseModel
        self.ChatAssoc = ChatAssoc
        self.MsgLog = MsgLog
        self.SlaveChatInfo = SlaveChatInfo

        # A missing ChatAssoc table is treated as a fresh database; otherwise
        # a missing "file_id" column marks a schema that predates migration 0.
        if not ChatAssoc.table_exists():
            self._create()
        elif "file_id" not in {i.name for i in self.db.get_columns("MsgLog")}:
            self._migrate(0)

    def _create(self):
        """
        Initializing tables.
        """
        self.db.execute_sql("PRAGMA journal_mode = OFF")
        self.db.create_tables([self.ChatAssoc, self.MsgLog, self.SlaveChatInfo])

    def _migrate(self, i):
        """
        Run migrations.

        Args:
            i: Migration ID

        Returns:
            bool: True when a migration was applied,
            False when migration ID is not found.
        """
        migrator = SqliteMigrator(self.db)
        if i >= 0:
            # Migration 0: Add media file ID and editable message ID
            # 2019JAN08
            migrate(
                migrator.add_column("msglog", "file_id", self.MsgLog.file_id),
                migrator.add_column("msglog", "media_type", self.MsgLog.media_type),
                migrator.add_column("msglog", "mime", self.MsgLog.mime),
                migrator.add_column("msglog", "master_msg_id_alt", self.MsgLog.master_msg_id_alt)
            )
            return True
        # Historical migrations, no longer applied by this method and kept
        # only for reference:
        #   2016JUN15 - added the "time" column to the MsgLog table.
        #   2017FEB25 - added the SlaveChatInfo table and the
        #               "slave_message_id" column on msglog.
        return False

    def add_chat_assoc(self, master_uid, slave_uid, multiple_slave=False):
        """
        Add chat associations (chat links).
        One Master channel with many Slave channel.

        Existing links of ``slave_uid`` are always removed first; existing
        links of ``master_uid`` are removed unless ``multiple_slave`` is set.

        Args:
            master_uid (str): Master channel UID ("%(chat_id)s")
            slave_uid (str): Slave channel UID ("%(channel_id)s.%(chat_id)s")
            multiple_slave (bool): Keep the master's existing links.

        Returns:
            ChatAssoc: The created association row.
        """
        if not multiple_slave:
            self.remove_chat_assoc(master_uid=master_uid)
        self.remove_chat_assoc(slave_uid=slave_uid)
        return self.ChatAssoc.create(master_uid=master_uid, slave_uid=slave_uid)

    def remove_chat_assoc(self, master_uid=None, slave_uid=None):
        """
        Remove chat associations (chat links).
        Only one parameter is to be provided.

        Args:
            master_uid (str): Master channel UID ("%(chat_id)s")
            slave_uid (str): Slave channel UID ("%(channel_id)s.%(chat_id)s")

        Returns:
            The result of executing the delete query, or 0 on ``DoesNotExist``.

        Raises:
            ValueError: When both or neither of the parameters are provided.
        """
        # Validate before touching the database; exactly one UID must be set.
        if bool(master_uid) == bool(slave_uid):
            raise ValueError("Only one parameter is to be provided.")
        try:
            if master_uid:
                return self.ChatAssoc.delete().where(self.ChatAssoc.master_uid == master_uid).execute()
            return self.ChatAssoc.delete().where(self.ChatAssoc.slave_uid == slave_uid).execute()
        except DoesNotExist:
            return 0

    def get_chat_assoc(self, master_uid: str = None, slave_uid: str = None) -> List[str]:
        """
        Get chat association (chat link) information.
        Only one parameter is to be provided.

        Args:
            master_uid (str): Master channel UID ("%(chat_id)s")
            slave_uid (str): Slave channel UID ("%(channel_id)s.%(chat_id)s")

        Returns:
            list: The counterpart ID.

        Raises:
            ValueError: When both or neither of the parameters are provided.
        """
        if bool(master_uid) == bool(slave_uid):
            raise ValueError("Only one parameter is to be provided.")
        try:
            # Iterating an empty result already yields an empty list, so no
            # separate length check is needed.
            if master_uid:
                slaves = self.ChatAssoc.select().where(self.ChatAssoc.master_uid == master_uid)
                return [i.slave_uid for i in slaves]
            masters = self.ChatAssoc.select().where(self.ChatAssoc.slave_uid == slave_uid)
            return [i.master_uid for i in masters]
        except DoesNotExist:
            return []

    def get_last_msg_from_chat(self, chat_id):
        """Get last message from the selected chat from Telegram

        Args:
            chat_id (int|str): Telegram chat ID

        Returns:
            MsgLog: The last message from the chat
        """
        try:
            return self.MsgLog.select().where(self.MsgLog.master_msg_id.startswith("%s." % chat_id)).order_by(
                self.MsgLog.time.desc()).first()
        except DoesNotExist:
            return None

    def add_msg_log(self, **kwargs):
        """
        Add an entry to message log.

        Display name is defined as `alias or name`.

        Args:
            master_msg_id (str): Telegram message ID ("%(chat_id)s.%(msg_id)s")
            text (str): String representation of the message
            slave_origin_uid (str): Slave chat ID ("%(channel_id)s.%(chat_id)s")
            msg_type (str): String of the message type.
            sent_to (str): "master" or "slave"
            slave_origin_display_name (str): Display name of slave chat.
            slave_member_uid (str|None):
                User ID of the slave chat member (sender of the message, for group chat only).
                ("%(channel_id)s.%(chat_id)s"), None if not available.
            slave_member_display_name (str|None):
                Display name of the member, None if not available.
            update (bool): Update a previous record. Default: False.
            slave_message_id (str): the corresponding message uid from slave channel.
            master_msg_id_alt (str|None): Value stored in ``master_msg_id_alt``.
            media_type (str|None): Value stored in ``media_type``.
            file_id (str|None): Value stored in ``file_id``.
            mime (str|None): Value stored in ``mime``.

        Returns:
            MsgLog: The added/updated entry.
        """
        master_msg_id = kwargs.get('master_msg_id')
        text = kwargs.get('text')
        slave_origin_uid = kwargs.get('slave_origin_uid')
        msg_type = kwargs.get('msg_type')
        sent_to = kwargs.get('sent_to')
        slave_origin_display_name = kwargs.get('slave_origin_display_name', None)
        slave_member_uid = kwargs.get('slave_member_uid', None)
        slave_member_display_name = kwargs.get('slave_member_display_name', None)
        slave_message_id = kwargs.get('slave_message_id')
        master_msg_id_alt = kwargs.get('master_msg_id_alt', None)
        media_type = kwargs.get('media_type', None)
        file_id = kwargs.get('file_id', None)
        mime = kwargs.get('mime', None)
        update = kwargs.get('update', False)
        if update:
            # Falsy arguments leave the stored value untouched.
            msg_log = self.MsgLog.get(self.MsgLog.master_msg_id == master_msg_id)
            msg_log.text = text or msg_log.text
            msg_log.msg_type = msg_type or msg_log.msg_type
            msg_log.sent_to = sent_to or msg_log.sent_to
            msg_log.slave_origin_uid = slave_origin_uid or msg_log.slave_origin_uid
            msg_log.slave_origin_display_name = slave_origin_display_name or msg_log.slave_origin_display_name
            msg_log.slave_member_uid = slave_member_uid or msg_log.slave_member_uid
            msg_log.slave_member_display_name = slave_member_display_name or msg_log.slave_member_display_name
            msg_log.slave_message_id = slave_message_id or msg_log.slave_message_id
            # NOTE(review): unlike every other field, master_msg_id_alt is
            # overwritten even when the new value is None - confirm intended.
            msg_log.master_msg_id_alt = master_msg_id_alt
            msg_log.media_type = media_type or msg_log.media_type
            msg_log.file_id = file_id or msg_log.file_id
            msg_log.mime = mime or msg_log.mime
            msg_log.save()
            return msg_log
        else:
            return self.MsgLog.create(master_msg_id=master_msg_id,
                                      slave_message_id=slave_message_id,
                                      text=text,
                                      slave_origin_uid=slave_origin_uid,
                                      msg_type=msg_type,
                                      sent_to=sent_to,
                                      slave_origin_display_name=slave_origin_display_name,
                                      slave_member_uid=slave_member_uid,
                                      slave_member_display_name=slave_member_display_name,
                                      master_msg_id_alt=master_msg_id_alt,
                                      media_type=media_type,
                                      file_id=file_id,
                                      mime=mime
                                      )

    def get_msg_log(self,
                    master_msg_id: Optional[str] = None,
                    slave_msg_id: Optional[str] = None,
                    slave_origin_uid: Optional[str] = None) -> Optional['MsgLog']:
        """Get message log by message ID.

        Either ``master_msg_id`` alone, or ``slave_msg_id`` together with
        ``slave_origin_uid``, must be given.

        Args:
            master_msg_id: Telegram message ID in string
            slave_msg_id: Slave message identifier in string
            slave_origin_uid: Slave chat identifier in string

        Returns:
            MsgLog|None: The queried entry, None if not exist.

        Raises:
            ValueError: When the arguments do not form one of the two
                accepted combinations.
        """
        if (master_msg_id and (slave_msg_id or slave_origin_uid)) \
                or not (master_msg_id or (slave_msg_id or slave_origin_uid)):
            raise ValueError('master_msg_id and slave_msg_id is mutual exclusive')
        if not master_msg_id and not (slave_msg_id and slave_origin_uid):
            raise ValueError('slave_msg_id and slave_origin_uid must exists together.')
        try:
            if master_msg_id:
                return self.MsgLog.select().where(self.MsgLog.master_msg_id == master_msg_id) \
                    .order_by(self.MsgLog.time.desc()).first()
            else:
                return self.MsgLog.select().where((self.MsgLog.slave_message_id == slave_msg_id) &
                                                  (self.MsgLog.slave_origin_uid == slave_origin_uid)
                                                  ).order_by(self.MsgLog.time.desc()).first()
        except DoesNotExist:
            return None

    def delete_msg_log(self,
                       master_msg_id: Optional[str] = None,
                       slave_msg_id: Optional[str] = None,
                       slave_origin_uid: Optional[str] = None):
        """Remove a message log by message ID.

        Accepts the same argument combinations as ``get_msg_log``.

        Args:
            master_msg_id: Telegram message ID in string
            slave_msg_id: Slave message identifier in string
            slave_origin_uid: Slave chat identifier in string

        Raises:
            ValueError: When the arguments do not form one of the two
                accepted combinations.
        """
        if (master_msg_id and (slave_msg_id or slave_origin_uid)) \
                or not (master_msg_id or (slave_msg_id or slave_origin_uid)):
            raise ValueError('master_msg_id and slave_msg_id is mutual exclusive')
        if not master_msg_id and not (slave_msg_id and slave_origin_uid):
            raise ValueError('slave_msg_id and slave_origin_uid must exists together.')
        try:
            if master_msg_id:
                self.MsgLog.delete().where(self.MsgLog.master_msg_id == master_msg_id).execute()
            else:
                self.MsgLog.delete().where((self.MsgLog.slave_message_id == slave_msg_id) &
                                           (self.MsgLog.slave_origin_uid == slave_origin_uid)
                                           ).execute()
        except DoesNotExist:
            return

    def get_slave_chat_info(self, slave_channel_id=None, slave_chat_uid=None) -> Optional['SlaveChatInfo']:
        """
        Get cached slave chat info from database.

        Args:
            slave_channel_id: Slave channel ID to match.
            slave_chat_uid: Slave chat UID to match.

        Returns:
            SlaveChatInfo|None: The matching slave chat info, None if not exist.

        Raises:
            ValueError: When either argument is None.
        """
        if slave_channel_id is None or slave_chat_uid is None:
            raise ValueError("Both slave_channel_id and slave_chat_id should be provided.")
        try:
            return self.SlaveChatInfo.select()\
                .where((self.SlaveChatInfo.slave_channel_id == slave_channel_id) &
                       (self.SlaveChatInfo.slave_chat_uid == slave_chat_uid)).first()
        except DoesNotExist:
            return None

    def set_slave_chat_info(self,
                            slave_channel_id=None,
                            slave_channel_name=None,
                            slave_channel_emoji=None,
                            slave_chat_uid=None,
                            slave_chat_name=None,
                            slave_chat_alias="",
                            slave_chat_type=None):
        """
        Insert or update slave chat info entry

        Args:
            slave_channel_id (str): Slave channel ID
            slave_channel_name (str): Slave channel name
            slave_channel_emoji (str): Slave channel emoji
            slave_chat_uid (str): Slave chat UID
            slave_chat_name (str): Slave chat name
            slave_chat_alias (str): Slave chat alias, "" (empty string) if not available
            slave_chat_type (channel.ChatType): Slave chat type

        Returns:
            SlaveChatInfo: The inserted or updated row

        Raises:
            ValueError: When ``slave_channel_id`` or ``slave_chat_uid`` is None.
        """
        # Reuse the row found by the lookup instead of querying for it again.
        chat_info = self.get_slave_chat_info(slave_channel_id=slave_channel_id,
                                             slave_chat_uid=slave_chat_uid)
        if chat_info is not None:
            chat_info.slave_channel_name = slave_channel_name
            chat_info.slave_channel_emoji = slave_channel_emoji
            chat_info.slave_chat_name = slave_chat_name
            chat_info.slave_chat_alias = slave_chat_alias
            chat_info.slave_chat_type = slave_chat_type.value
            chat_info.save()
            return chat_info
        return self.SlaveChatInfo.create(slave_channel_id=slave_channel_id,
                                         slave_channel_name=slave_channel_name,
                                         slave_channel_emoji=slave_channel_emoji,
                                         slave_chat_uid=slave_chat_uid,
                                         slave_chat_name=slave_chat_name,
                                         slave_chat_alias=slave_chat_alias,
                                         slave_chat_type=slave_chat_type.value)

    def delete_slave_chat_info(self, slave_channel_id, slave_chat_uid):
        """Delete the cached slave chat info matching both identifiers.

        Args:
            slave_channel_id: Slave channel ID to match.
            slave_chat_uid: Slave chat UID to match.

        Returns:
            The result of executing the delete query.
        """
        return self.SlaveChatInfo.delete()\
            .where((self.SlaveChatInfo.slave_channel_id == slave_channel_id) &
                   (self.SlaveChatInfo.slave_chat_uid == slave_chat_uid)).execute()

    def get_recent_slave_chats(self, master_chat_id, limit=5):
        """List distinct slave chat UIDs from the message log of a master chat.

        Rows are selected where ``master_msg_id`` starts with
        ``"<master_chat_id>."``, ordered by ``time`` descending.

        Args:
            master_chat_id: Master chat ID used as the message ID prefix.
            limit: Maximum number of rows requested from the query.

        Returns:
            list: ``slave_origin_uid`` values of the selected rows.
        """
        return [i.slave_origin_uid for i in
                self.MsgLog.select(self.MsgLog.slave_origin_uid)
                    .distinct()
                    .where(self.MsgLog.master_msg_id.startswith("%s." % master_chat_id))
                    .order_by(self.MsgLog.time.desc())
                    .limit(limit)]
Esempio n. 5
0
class Migrations(object):
    """
    Migrations

    Handle all migrations during application start.
    """

    logger = None
    database = None
    # Only set by __init__ when the configured database type is 'SQLITE'.
    router = None

    def __init__(self, config):
        """
        Configure logging and, for SQLite, the database and migration router.

        :param config: Mapping with at least 'TYPE'; for 'SQLITE' also 'FILE',
            and optionally 'MIGRATIONS_HISTORY_VERSION' and 'MIGRATIONS_DIR'.
        """
        unmanic_logging = unlogger.UnmanicLogger.__call__()
        self.logger = unmanic_logging.get_logger(__class__.__name__)

        # Based on configuration, select database to connect to.
        if config['TYPE'] == 'SQLITE':
            # Create SQLite directory if not exists
            self._ensure_db_directory(config['FILE'])
            self.database = SqliteDatabase(config['FILE'])

            self.router = Router(database=self.database,
                                 migrate_table='migratehistory_{}'.format(
                                     config.get('MIGRATIONS_HISTORY_VERSION')),
                                 migrate_dir=config.get('MIGRATIONS_DIR'),
                                 logger=self.logger)

    @staticmethod
    def _ensure_db_directory(db_file):
        """
        Create the directory that will contain the database file, if needed.

        A path without a directory component (a bare file name) refers to the
        current working directory, so there is nothing to create; passing the
        resulting empty string to os.makedirs would raise.

        :param db_file: Path of the database file.
        :return:
        """
        db_file_directory = os.path.dirname(db_file)
        if db_file_directory:
            # exist_ok avoids failing when the directory already exists or is
            # created by someone else between a check and the call.
            os.makedirs(db_file_directory, exist_ok=True)

    def __log(self, message, level='info'):
        """
        Send a message to the logger method named by `level`, or print it
        when no logger is configured.
        """
        if self.logger:
            getattr(self.logger, level)(message)
        else:
            print(message)

    def __run_all_migrations(self):
        """
        Run all new migrations.
        Migrations that have already been run will be ignored.

        :return:
        """
        self.router.run()

    def update_schema(self):
        """
        Updates the Unmanic database schema.

        Newly added tables/models and columns/fields will be automatically generated by this function.
        This way we do not need to create a migration script unless we:
            - rename a column/field
            - delete a column/field
            - delete a table/model

        Any exception raised while creating tables or adding columns is
        logged and re-raised.

        :return:
        """
        # Fetch all model classes
        all_models = []
        all_base_models = []
        models_module = importlib.import_module("unmanic.libs.unmodels")
        for model in list_all_models():
            imported_model = getattr(models_module, model)
            if not inspect.isclass(imported_model):
                continue
            if issubclass(imported_model, BaseModel):
                # Add this model to both the 'all_models' list and our list of base models
                all_models.append(imported_model)
                all_base_models.append(imported_model)
            elif issubclass(imported_model, Model):
                # If the model is not one of the base models, it is an in-build model from peewee.
                # For, this list of models we will not run a migration, but we will still ensure that the
                #   table is created in the DB
                all_models.append(imported_model)

        # Start by creating all models
        self.__log("Initialising database tables")
        try:
            with self.database.transaction():
                for model in all_models:
                    self.router.migrator.create_table(model)
                self.router.migrator.run()
        except Exception:
            self.database.rollback()
            self.__log("Initialising tables failed", level='exception')
            raise

        # Migrations will only be used for removing obsolete columns
        self.__run_all_migrations()

        # Newly added fields can be auto added with this function... no need for a migration script
        # Ensure all files are also present for each of the model classes
        self.__log("Updating database fields")
        for model in all_base_models:
            # Fetch all peewee fields for the model class
            # https://stackoverflow.com/questions/22573558/peewee-determining-meta-data-about-model-at-run-time
            fields = model._meta.fields
            # Read the table's current columns once per model instead of once per field
            existing_columns = {
                column.name for column in self.database.get_columns(model._meta.name)
            }
            # Loop over a snapshot of the keys: adding a column through the migrator
            #   may touch the model's field mapping while we are iterating
            for field_key in list(fields):
                field = fields.get(field_key)
                if not isinstance(field, Field):
                    continue
                if field.name in existing_columns:
                    continue
                # Field does not exist in DB table
                self.__log("Adding missing column")
                try:
                    with self.database.transaction():
                        self.router.migrator.add_columns(
                            model, **{field.name: field})
                        self.router.migrator.run()
                except Exception:
                    self.database.rollback()
                    self.__log("Update failed", level='exception')
                    raise
                existing_columns.add(field.name)