def test_session_scope(self) -> Session:
    """Provide a transactional scope around a series of operations that is
    rolled back at the end.

    This is a generator: it yields a single ``Session`` bound to a dedicated
    connection. The session starts inside a SAVEPOINT, and a listener reopens
    the SAVEPOINT whenever it ends. As a result, ``commit()``/``rollback()``
    calls made by the caller stay inside the outer connection-level
    transaction, which is rolled back once the caller is finished.

    Note: MySQL storage engine MyISAM does not support rollback transactions,
    so all the modifications performed to the database will persist.

    NOTE(review): ``-> Session`` does not describe a generator;
    ``Iterator[Session]`` would be the accurate annotation.
    """
    from contextlib import ExitStack  # local import keeps this block self-contained

    def restart_savepoint(sess, trans):
        """Reopen a SAVEPOINT whenever the previous one ends."""
        # Only react to a SAVEPOINT that sits directly under a non-nested transaction.
        if trans.nested and not trans._parent.nested:  # pylint: disable=protected-access
            # Ensure that state is expired the same way session.commit() at the
            # top level normally does
            sess.expire_all()
            sess.begin_nested()

    # Each resource registers its cleanup as soon as it exists, so a failure
    # part-way through setup still releases what was already acquired.
    # Callbacks run in reverse order: listener removal, session close,
    # transaction rollback, connection close.
    with ExitStack() as cleanup:
        # Connect to the database
        connection = self.connect()
        cleanup.callback(connection.close)
        # Begin a non-ORM transaction; rolling it back discards everything done
        # with the session (including calls to commit())
        transaction = connection.begin()
        cleanup.callback(transaction.rollback)
        # Bind an individual Session to the connection
        session = Session(bind=connection)
        cleanup.callback(session.close)
        # Start the session in a SAVEPOINT
        session.begin_nested()
        event.listen(session, "after_transaction_end", restart_savepoint)
        # Unregister before session.close() runs, so closing the session does
        # not trigger the listener and reopen a SAVEPOINT during teardown.
        cleanup.callback(
            event.remove, session, "after_transaction_end", restart_savepoint
        )
        yield session
def db_session(test_db):
    """Yield a Session whose work is discarded once the caller is done.

    The session is bound to a dedicated connection with an open
    connection-level transaction, and it starts inside a SAVEPOINT. On exit the
    session is closed, the transaction is rolled back and the connection is
    closed.

    Args:
        test_db: object with a ``connect()`` method; presumably an SQLAlchemy
            ``Engine`` (TODO confirm against the fixture providing it).
    """
    from contextlib import ExitStack  # local import keeps this block self-contained

    # Cleanups run in reverse registration order (session close, rollback,
    # connection close), and they still run if a later setup step fails.
    with ExitStack() as cleanup:
        connection = test_db.connect()
        cleanup.callback(connection.close)
        transaction = connection.begin()
        cleanup.callback(transaction.rollback)
        # autocommit=False was the SQLAlchemy 1.x default and the parameter no
        # longer exists in 2.0, so it is not passed.
        session = Session(autoflush=False, bind=connection)
        cleanup.callback(session.close)
        session.begin_nested()
        yield session
def new_trid(cls, session: Session, pid: int) -> int:
    """
    Creates a new TRID: a random integer in ``1..MAX_TRID`` that the database
    accepts when inserted as a new ``cls`` row for ``pid``.

    We check for existence by inserting and asking the database if it's happy,
    not by asking the database if it exists (since other processes may be
    doing the same thing at the same time).

    Each attempt runs in its own SAVEPOINT, and only that SAVEPOINT is released
    or rolled back, so the caller's enclosing transaction is left alone.
    (``session.commit()``/``session.rollback()`` would, in SQLAlchemy 2.0, act
    on the caller's outermost transaction instead.)

    Returns:
        The accepted TRID. The inserted row remains part of the caller's
        enclosing transaction.

    Note:
        The loop only ends on success; any exception other than
        ``IntegrityError`` propagates.
    """
    while True:
        savepoint = session.begin_nested()
        candidate = random.randint(1, MAX_TRID)
        log.debug("Trying candidate TRID: {}".format(candidate))
        obj = cls(pid=pid, trid=candidate)
        try:
            session.add(obj)
            savepoint.commit()  # may raise IntegrityError
        except IntegrityError:
            # Undo only this attempt; the pending obj is discarded with it.
            savepoint.rollback()
        else:
            return candidate
def setup_module():
    """Test-module setup (named per the pytest/nose ``setup_module`` convention).

    Opens a connection-level transaction on the ``recordsheet_test`` database,
    creates the schema, inserts fixture rows and replaces ``dbapi._session``.
    After that, code under test uses one Session that stays inside the
    transaction.

    NOTE(review): nothing here commits the outer ``transaction``; presumably a
    matching teardown rolls it back via the module globals. Confirm.
    """
    # Exposed as module globals, presumably so teardown can reach them.
    global transaction, connection, engine
    engine = create_engine('postgresql:///recordsheet_test')
    connection = engine.connect()
    # Outer connection-level transaction; everything below happens inside it.
    transaction = connection.begin()
    Base.metadata.create_all(connection)
    #insert some data
    # NOTE(review): inner_tr is never used after this line; only the
    # SAVEPOINT it opens matters.
    inner_tr = connection.begin_nested()
    ses = Session(connection)
    # ses.begin_nested()
    # Insertion order matters: the Posting rows below reference account_id=1
    # and 2, which assumes these two accounts get ids 1 and 2 (TODO confirm).
    ses.add(dbmodel.Account(name='TEST01', desc='test account 01'))
    ses.add(dbmodel.Account(name='TEST02', desc='test account 02'))
    user = dbmodel.User(username='******', name='Test T. User', password=dbapi.new_pw_hash('passtestword'), locked=False)
    ses.add(user)
    lockeduser = dbmodel.User(username='******', name='Test T. User', password=dbapi.new_pw_hash('passtestword'), locked=True)
    ses.add(lockeduser)
    batch = dbmodel.Batch(user=user)
    ses.add(batch)
    jrnl = dbmodel.Journal(memo='test', batch=batch, datetime='2016-06-05 14:09:00-05')
    ses.add(jrnl)
    # A balanced pair of postings (+100 / -100) in the same journal.
    ses.add(dbmodel.Posting(memo="test", amount=100, account_id=1, journal=jrnl))
    ses.add(dbmodel.Posting(memo="test", amount=-100, account_id=2, journal=jrnl))
    # The Session is bound to a Connection whose transaction is already open;
    # presumably this does not commit the outer transaction (SQLAlchemy 1.x
    # "join external transaction" behaviour). Confirm for the version in use.
    ses.commit()
    # mock a sessionmaker so all queries are in this transaction
    dbapi._session = lambda: ses
    # Start a SAVEPOINT for the code under test to work in.
    ses.begin_nested()
    # Reopen the SAVEPOINT whenever a SAVEPOINT directly under a non-nested
    # transaction ends, so commits/rollbacks by code under test stay inside
    # the outer transaction.
    @event.listens_for(ses, "after_transaction_end")
    def restart_savepoint(session, transaction):
        if transaction.nested and not transaction._parent.nested:
            ses.begin_nested()
def new_trid(cls, session: Session, pid: Union[int, str]) -> int:
    """
    Creates a new TRID: a random integer that's not yet been used as a TRID.

    We check for existence by inserting and asking the database if it's happy,
    not by asking the database if it exists (since other processes may be
    doing the same thing at the same time).

    Each attempt runs in its own SAVEPOINT, and only that SAVEPOINT is released
    or rolled back, so the caller's enclosing transaction is left alone.
    (``session.commit()``/``session.rollback()`` would, in SQLAlchemy 2.0, act
    on the caller's outermost transaction instead.)

    Args:
        session: session whose current transaction receives the new row.
        pid: value stored in the new row's ``pid`` field.

    Returns:
        The accepted TRID, in ``1..MAX_TRID``.

    Note:
        The loop only ends on success; any exception other than
        ``IntegrityError`` propagates.
    """
    while True:
        savepoint = session.begin_nested()
        candidate = random.randint(1, MAX_TRID)
        log.debug(f"Trying candidate TRID: {candidate}")
        # noinspection PyArgumentList
        obj = cls(pid=pid, trid=candidate)
        try:
            session.add(obj)
            savepoint.commit()  # may raise IntegrityError
        except IntegrityError:
            # Undo only this attempt; the pending obj is discarded with it.
            savepoint.rollback()
        else:
            return candidate
def process_file(session: Session, file: File,
                 callback: Callable[[], None]) -> bool:
    """Claim ``file`` for processing and run ``callback`` for it.

    The claim is an UPDATE that only matches while ``processing_started_at``
    is still NULL in the database, so a file that is already claimed there is
    not claimed again. ``callback`` runs inside a SAVEPOINT. If it raises, or
    the following commit raises, the session is rolled back and the error text
    is stored on ``file.processing_error``.

    Returns:
        False if the file was already claimed (in memory or in the database).
        True once processing finished, including when it failed with a
        ``UserError``.

    Raises:
        Any exception other than ``UserError`` raised while processing is
        re-raised after ``file`` has been updated with the error details.
    """
    if file.processing_started_at:
        return False

    # Claim this file by updating the `processing_started_at` timestamp in such
    # a way that it must not have been set before.
    processing_started_at = datetime.datetime.now(timezone.utc)
    result = session.execute(
        update(File.__table__)  # pylint: disable=no-member
        .where(File.id == file.id)
        .where(File.processing_started_at.is_(None))
        .values(processing_started_at=processing_started_at))
    if result.rowcount == 0:
        return False

    # If we got this far, `file` is ours to process.
    try:
        session.begin_nested()
        callback()
        # Mirror the claimed timestamp on the in-memory object, which the raw
        # UPDATE above did not touch.
        file.processing_started_at = processing_started_at
        file.processing_completed_at = datetime.datetime.now(timezone.utc)
        session.add(file)
        session.commit()
        return True
    except Exception as error:
        session.rollback()
        file.processing_started_at = processing_started_at
        file.processing_completed_at = datetime.datetime.now(timezone.utc)
        # Some errors stringify nicely, some don't (e.g. StopIteration), so we
        # fall back to the formatted traceback, joined into one readable string.
        file.processing_error = str(error) or "".join(
            traceback.format_exception(type(error), error, error.__traceback__))
        if not isinstance(error, UserError):
            raise
        return True
class MetadataApi:
    """Wrapper around one SQLAlchemy Session on the metadata database.

    By default, statements run through :meth:`execute` are filtered by
    ``env_id``. Objects passed to :meth:`add`/:meth:`add_all` get ``env_id``
    filled in when it is unset.
    """

    def __init__(self, env_id: str, storage: Storage, initialize: bool = False):
        self.env_id = env_id
        self.storage = storage
        self.engine = self.storage.get_api().get_engine()
        # Created lazily by begin() and then reused; it is never closed here.
        self.active_session = None
        if initialize:
            self.initialize_metadata_database()

    @contextmanager
    def begin(self) -> Iterator[Session]:
        """Yield the active session, creating it on first use.

        When the ``with`` body completes normally the session is committed. If
        it raises, the session is rolled back and the exception propagates, so
        partial work from a failed block is not persisted. The session stays
        open for later calls.
        """
        if self.active_session is None:
            self.active_session = Session(self.engine)
        session = self.active_session
        try:
            yield session
        except BaseException:
            session.rollback()
            raise
        else:
            session.commit()

    ensure_session = begin

    @contextmanager
    def begin_nested(self) -> Iterator[SessionTransaction]:
        """Run the ``with`` body inside a SAVEPOINT on the active session.

        Raises:
            ValueError: if no session is active (see :meth:`get_session`).
        """
        with self.get_session().begin_nested() as sess_tx:
            yield sess_tx

    def get_session(self) -> Session:
        """Return the active session.

        Raises:
            ValueError: if :meth:`begin` has not created one yet.
        """
        if self.active_session is None:
            raise ValueError(
                "No metadata session active. Call MetadataApi.begin() beforehand"
            )
        return self.active_session

    def augment_statement(
        self, stmt: Union[Select, Update, Delete], filter_env: bool = True
    ) -> Union[Select, Update, Delete]:
        """Return ``stmt``, with ``filter_by(env_id=self.env_id)`` added when
        ``filter_env`` is true."""
        if filter_env:
            stmt = stmt.filter_by(env_id=self.env_id)
        return stmt

    def execute(
        self, stmt: Union[Select, Update, Delete], filter_env: bool = True
    ) -> Result:
        """Execute ``stmt`` on the active session, env-filtered by default."""
        stmt = self.augment_statement(stmt, filter_env=filter_env)
        return self.get_session().execute(stmt)

    def count(self, stmt: Select, filter_env: bool = True) -> int:
        """Return the number of rows ``stmt`` produces.

        ``stmt`` is wrapped as a subquery of ``SELECT count(*)``. The env
        filter is applied to that outer statement only when ``filter_env`` is
        true; before this change the parameter was ignored and the filter was
        always applied.
        """
        count_stmt = select(func.count()).select_from(stmt.subquery())
        return self.execute(count_stmt, filter_env=filter_env).scalar_one()

    def add(self, obj: Any, set_env: bool = True):
        """Add ``obj`` to the session, defaulting its ``env_id`` if unset."""
        if obj.env_id is None and set_env:
            obj.env_id = self.env_id
        self.get_session().add(obj)

    def add_all(self, objects: Iterable, set_env: bool = True):
        """Add every object in ``objects``, defaulting unset ``env_id`` values."""
        # Materialize first: a generator would otherwise be exhausted by the
        # env_id loop, leaving nothing for session.add_all().
        objects = list(objects)
        if set_env:
            for obj in objects:
                if obj.env_id is None:
                    obj.env_id = self.env_id
        self.get_session().add_all(objects)

    def flush(self, objects=None):
        """Flush ``objects`` if given (non-empty), otherwise the whole session."""
        if objects:
            self.get_session().flush(objects)
        else:
            self.get_session().flush()

    def delete(self, obj):
        """Remove ``obj``: expunge it if it is still pending (in ``session.new``),
        otherwise mark it for deletion."""
        sess = self.get_session()
        if obj in sess.new:
            sess.expunge(obj)
        else:
            sess.delete(obj)

    def commit(self):
        """Commit the active session."""
        self.get_session().commit()

    ### Alembic

    def initialize_metadata_database(self):
        """Migrate the metadata database schema to the latest Alembic revision.

        Raises:
            ValueError: if the storage is not database-backed.
        """
        if not issubclass(
            self.storage.storage_engine.storage_class, DatabaseStorageClass
        ):
            raise ValueError(
                f"metadata storage expected a database, got {self.storage}"
            )
        # NOTE(review): databases created before Alembic tracking may need
        # stamp_metadata_database() first; previously a commented-out fallback.
        self.migrate_metdata_database()

    def migrate_metdata_database(self):
        """Run ``alembic upgrade head`` (method name misspelling kept for callers)."""
        alembic_cfg = self.get_alembic_config()
        if self.engine is not None:
            alembic_cfg.attributes["connection"] = self.engine
        command.upgrade(alembic_cfg, "head")

    def get_alembic_config(self) -> Config:
        """Build the Alembic config from the bundled ``alembic.ini``, pointed at
        this storage's URL."""
        dir_path = pathlib.Path(__file__).parent.absolute()
        cfg_path = dir_path / "../../migrations/alembic.ini"
        alembic_cfg = Config(str(cfg_path))
        alembic_cfg.set_main_option("sqlalchemy.url", self.storage.url)
        return alembic_cfg

    def stamp_metadata_database(self, revision: str = "23dd1cc88eb2"):
        """Mark the database as being at ``revision`` without running migrations.

        The default is the revision previously hard-coded here.
        """
        alembic_cfg = self.get_alembic_config()
        if self.engine is not None:
            alembic_cfg.attributes["connection"] = self.engine
        command.stamp(alembic_cfg, revision)