Example 1
    def test_session_scope(self) -> Session:
        """Provides a transactional scope around a series of operations that will be rolled back at the end.

        Note:
            The MyISAM MySQL storage engine does not support transactions, so rollbacks are no-ops and
            any modifications made to the database will persist.

        """
        # Connect to the database
        connection = self.connect()
        # Begin a non-ORM transaction
        transaction = connection.begin()
        # Bind an individual Session to the connection
        session = Session(bind=connection)
        # Start the session in a SAVEPOINT
        session.begin_nested()
        # Define a new transaction event
        @event.listens_for(session, "after_transaction_end")
        def restart_savepoint(session, transaction):  # pylint: disable=unused-variable
            """Reopen a SAVEPOINT whenever the previous one ends."""
            if transaction.nested and not transaction._parent.nested:  # pylint: disable=protected-access
                # Ensure that state is expired the same way session.commit() at the top level normally does
                session.expire_all()
                session.begin_nested()

        try:
            yield session
        finally:
            # Whatever happens, make sure the session and connection are closed, rolling back everything done
            # with the session (including calls to commit())
            session.close()
            transaction.rollback()
            connection.close()
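For context, a hedged usage sketch of the scope above: a test can freely call session.commit(), which only releases the current SAVEPOINT (the event listener then opens a fresh one), while everything is discarded by the outer transaction.rollback() when the scope exits. The User model and the fixture wiring are assumptions, not part of the example:

def test_create_user(session):
    """`session` is assumed to be provided by the transactional scope above."""
    user = User(name="alice")        # hypothetical mapped class
    session.add(user)
    session.commit()                 # releases the SAVEPOINT only
    assert session.query(User).filter_by(name="alice").count() == 1
    # When the scope's finally-block runs, transaction.rollback() discards
    # the row, so the next test starts from a clean database.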
Example 2
def db_session(test_db):
    connection = test_db.connect()
    transaction = connection.begin()
    session = Session(autocommit=False, autoflush=False, bind=connection)
    try:
        session.begin_nested()
        yield session
    finally:
        session.close()
        transaction.rollback()
        connection.close()
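The test_db argument is assumed to be another fixture that yields an Engine; a minimal sketch of such a fixture (the database URL and the fixture scope are assumptions):

import pytest
from sqlalchemy import create_engine


@pytest.fixture(scope="session")
def test_db():
    # Hypothetical test database URL; any backend with transaction support works.
    engine = create_engine("postgresql:///myapp_test")
    yield engine
    engine.dispose()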
Example 3
 def new_trid(cls, session: Session, pid: int) -> int:
     """
     We check for existence by inserting and asking the database if it's
     happy, not by asking the database if it exists (since other processes
     may be doing the same thing at the same time).
     """
     while True:
         session.begin_nested()
         candidate = random.randint(1, MAX_TRID)
         log.debug("Trying candidate TRID: {}".format(candidate))
         obj = cls(pid=pid, trid=candidate)
         try:
             session.add(obj)
             session.commit()  # may raise IntegrityError
             return candidate
         except IntegrityError:
             session.rollback()
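A brief, hedged usage sketch, assuming new_trid is exposed as a classmethod on a mapped class (here called TridRecord, a hypothetical name):

# Loops until a TRID not already present in the table is inserted;
# a collision raises IntegrityError, which is rolled back and retried.
trid = TridRecord.new_trid(session, pid=42)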
Example 4
def setup_module():
    global transaction, connection, engine

    engine = create_engine('postgresql:///recordsheet_test')
    connection = engine.connect()
    transaction = connection.begin()
    Base.metadata.create_all(connection)

    #insert some data
    inner_tr = connection.begin_nested()
    ses = Session(connection)
#    ses.begin_nested()
    ses.add(dbmodel.Account(name='TEST01', desc='test account 01'))
    ses.add(dbmodel.Account(name='TEST02', desc='test account 02'))
    user = dbmodel.User(username='******', name='Test T. User',
                       password=dbapi.new_pw_hash('passtestword'),
                       locked=False)
    ses.add(user)
    lockeduser = dbmodel.User(username='******', name='Test T. User',
                       password=dbapi.new_pw_hash('passtestword'),
                       locked=True)
    ses.add(lockeduser)

    batch = dbmodel.Batch(user=user)
    ses.add(batch)
    jrnl = dbmodel.Journal(memo='test', batch=batch,
                datetime='2016-06-05 14:09:00-05')
    ses.add(jrnl)
    ses.add(dbmodel.Posting(memo="test", amount=100, account_id=1,
                        journal=jrnl))
    ses.add(dbmodel.Posting(memo="test", amount=-100, account_id=2,
                        journal=jrnl))
    ses.commit()
    # mock a sessionmaker so all queries are in this transaction
    dbapi._session = lambda: ses
    ses.begin_nested()

    @event.listens_for(ses, "after_transaction_end")
    def restart_savepoint(session, transaction):
        if transaction.nested and not transaction._parent.nested:
            ses.begin_nested()
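The usual counterpart of this setup, not shown in the example, is a teardown_module that rolls back the outer transaction opened in setup_module; a minimal sketch under that assumption:

def teardown_module():
    # Undo everything done in setup_module and during the tests, then
    # release the connection and the engine.
    transaction.rollback()
    connection.close()
    engine.dispose()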
Example 5
    def new_trid(cls, session: Session, pid: Union[int, str]) -> int:
        """
        Creates a new TRID: a random integer that's not yet been used as a
        TRID.

        We check for existence by inserting and asking the database if it's
        happy, not by asking the database if it exists (since other processes
        may be doing the same thing at the same time).
        """
        while True:
            session.begin_nested()
            candidate = random.randint(1, MAX_TRID)
            log.debug(f"Trying candidate TRID: {candidate}")
            # noinspection PyArgumentList
            obj = cls(pid=pid, trid=candidate)
            try:
                session.add(obj)
                session.commit()  # may raise IntegrityError
                return candidate
            except IntegrityError:
                session.rollback()
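For the IntegrityError to fire on a duplicate, the table must enforce uniqueness of the TRID; a minimal sketch of the kind of mapping assumed here (class name, table name and column types are illustrative):

from sqlalchemy import Column, Integer
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class TridRecord(Base):  # hypothetical stand-in for `cls` in the example
    __tablename__ = "trid_record"
    id = Column(Integer, primary_key=True, autoincrement=True)
    pid = Column(Integer, nullable=False, index=True)
    # The unique constraint is what makes a duplicate candidate raise
    # IntegrityError at commit time instead of silently succeeding.
    trid = Column(Integer, nullable=False, unique=True)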
Example 6
def process_file(session: Session, file: File,
                 callback: Callable[[], None]) -> bool:
    if file.processing_started_at:
        return False

    # Claim this file by updating the `processing_started_at` timestamp in such
    # a way that it must not have been set before.
    processing_started_at = datetime.datetime.now(timezone.utc)
    result = session.execute(
        update(File.__table__)  # pylint: disable=no-member
        .where(File.id == file.id).where(
            File.processing_started_at.is_(None)).values(
                processing_started_at=processing_started_at))
    if result.rowcount == 0:
        return False
    # If we got this far, `file` is ours to process.
    try:
        session.begin_nested()
        callback()
        file.processing_started_at = processing_started_at
        file.processing_completed_at = datetime.datetime.now(timezone.utc)
        session.add(file)
        session.commit()
        return True
    except Exception as error:
        session.rollback()
        file.processing_started_at = processing_started_at
        file.processing_completed_at = datetime.datetime.now(timezone.utc)
        # Some errors stringify nicely, some don't (e.g. StopIteration) so we
        # have to format them.
        file.processing_error = str(error) or str(
            traceback.format_exception(error.__class__, error,
                                       error.__traceback__))
        if not isinstance(error, UserError):
            raise error
        return True
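A hedged usage sketch: process_file returns False when another worker has already claimed the file, and True once processing (or a handled UserError) has been recorded. File, do_work and the driver function below are illustrative:

import logging

log = logging.getLogger(__name__)


def process_pending(session: Session, file_id: int) -> None:
    """Illustrative driver; File and do_work come from the application code."""
    file = session.query(File).filter_by(id=file_id).one()
    if process_file(session, file, callback=lambda: do_work(file)):
        log.info("File %s processed", file_id)
    else:
        log.info("File %s already claimed or previously processed", file_id)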
Example 7
class MetadataApi:
    def __init__(self, env_id: str, storage: Storage, initialize: bool = False):
        self.env_id = env_id
        self.storage = storage
        self.engine = self.storage.get_api().get_engine()
        self.active_session = None
        # self.Session = sessionmaker(self.engine)
        if initialize:
            self.initialize_metadata_database()

    @contextmanager
    def begin(self) -> Iterator[Session]:
        try:
            if self.active_session is None:
                self.active_session = Session(self.engine)
            yield self.active_session
        finally:
            self.active_session.commit()
            # self.active_session.close()
            # self.active_session = None

    ensure_session = begin

    @contextmanager
    def begin_nested(self) -> Iterator[SessionTransaction]:
        assert self.active_session is not None
        with self.active_session.begin_nested() as sess_tx:
            yield sess_tx

    # @contextmanager
    # def ensure_session(self) -> Iterator[Session]:
    #     if self.active_session is None:
    #         with self.Session.begin() as sess:
    #             self.active_session = sess
    #             yield sess
    #         self.active_session = None
    #     else:
    #         yield self.active_session

    def get_session(self) -> Session:
        if self.active_session is None:
            raise ValueError(
                "No metadata session active. Call MetadataApi.begin() beforehand"
            )
        return self.active_session

    def augment_statement(
        self, stmt: Union[Select, Update, Delete], filter_env: bool = True
    ) -> Select:
        if filter_env:
            stmt = stmt.filter_by(env_id=self.env_id)
        return stmt

    def execute(
        self, stmt: Union[Select, Update, Delete], filter_env: bool = True
    ) -> Result:
        stmt = self.augment_statement(stmt, filter_env=filter_env)
        return self.get_session().execute(stmt)

    def count(self, stmt: Select, filter_env: bool = True) -> int:
        stmt = select(func.count()).select_from(stmt.subquery())
        return self.execute(stmt).scalar_one()

    def add(self, obj: Any, set_env: bool = True):
        if obj.env_id is None and set_env:
            obj.env_id = self.env_id
        self.get_session().add(obj)

    def add_all(self, objects: Iterable, set_env: bool = True):
        for obj in objects:
            if obj.env_id is None and set_env:
                obj.env_id = self.env_id
        self.get_session().add_all(objects)

    def flush(self, objects=None):
        if objects:
            self.get_session().flush(objects)
        else:
            self.get_session().flush()

    def delete(self, obj):
        sess = self.get_session()
        if obj in sess.new:
            sess.expunge(obj)
        else:
            sess.delete(obj)

    def commit(self):
        self.get_session().commit()

    ### Alembic

    def initialize_metadata_database(self):
        if not issubclass(
            self.storage.storage_engine.storage_class, DatabaseStorageClass
        ):
            raise ValueError(
                f"metadata storage expected a database, got {self.storage}"
            )
        # BaseModel.metadata.create_all(conn)
        # try:
        self.migrate_metdata_database()
        # except SQLAlchemyError as e:
        #     # Catch database exception, meaning already created, just stamp
        #     # For initial migration
        #     # TODO: remove once all 0.2 systems migrated?
        #     logger.warning(e)
        #     self.stamp_metadata_database()
        #     self.migrate_metdata_database()

    def migrate_metdata_database(self):
        alembic_cfg = self.get_alembic_config()
        if self.engine is not None:
            alembic_cfg.attributes["connection"] = self.engine
        command.upgrade(alembic_cfg, "head")

    def get_alembic_config(self) -> Config:
        dir_path = pathlib.Path(__file__).parent.absolute()
        cfg_path = dir_path / "../../migrations/alembic.ini"
        alembic_cfg = Config(str(cfg_path))
        alembic_cfg.set_main_option("sqlalchemy.url", self.storage.url)
        return alembic_cfg

    def stamp_metadata_database(self):
        alembic_cfg = self.get_alembic_config()
        if self.engine is not None:
            alembic_cfg.attributes["connection"] = self.engine
        command.stamp(alembic_cfg, "23dd1cc88eb2")
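A hedged usage sketch of the begin() / begin_nested() helpers above, assuming a storage object and a mapped SomeModel class from the surrounding application:

from sqlalchemy.exc import IntegrityError

api = MetadataApi(env_id="dev", storage=storage, initialize=True)

with api.begin():
    api.add(SomeModel(name="example"))       # env_id is filled in by add()
    try:
        # Wrap a step that may violate a constraint in a SAVEPOINT so that
        # only this part is rolled back on failure.
        with api.begin_nested():
            api.add(SomeModel(name="might-conflict"))
            api.flush()
    except IntegrityError:
        pass  # the SAVEPOINT was rolled back; the first add() is kept
# Exiting begin() commits the active session (see MetadataApi.begin above).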