async def sess(async_engine: AsyncEngine) -> AsyncGenerator[Session, None]:
    """Yield an ``AsyncSession`` whose work is discarded after the test.

    The session is bound to a connection that holds an open outer
    transaction, and it starts inside a SAVEPOINT. Whenever the outermost
    SAVEPOINT ends (e.g. the code under test commits), a listener opens a new
    one. On teardown the session is closed and the outer transaction is
    rolled back.

    :param async_engine: Engine used to open the test connection.
    """
    session_factory = sessionmaker(  # type: ignore
        bind=async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )
    async with async_engine.begin() as conn:
        # Bind a session to the top level transaction
        _session = session_factory(bind=conn)
        # Start a savepoint that we can roll back to in the transaction.
        # AsyncSession.begin_nested() only returns an AsyncSessionTransaction;
        # the SAVEPOINT is emitted when that object is awaited.
        await _session.begin_nested()

        @sqlalchemy.event.listens_for(_session.sync_session, "after_transaction_end")
        def restart_savepoint(sess, trans):  # type: ignore
            """Reopen a SAVEPOINT each time the outermost one ends."""
            if trans.nested and not trans._parent.nested:
                # Expire all objects registered against the session
                sess.expire_all()
                sess.begin_nested()

        # A generator fixture must yield exactly once. A second yield makes
        # teardown fail and the cleanup below would never run.
        yield _session

        # Close the session object
        await _session.close()
        # Roll back the outer transaction, eliminating everything that
        # happened through _session
        await conn.rollback()  # type: ignore
def db_session(app):
    """Yield ``db.session`` for a test, with the model factories bound to it.

    A dedicated connection with an outer transaction is opened and rolled
    back at teardown. A listener reopens a SAVEPOINT whenever the outermost
    SAVEPOINT of ``db.session`` ends. The listener is removed at teardown
    because SQLAlchemy resolves a ``scoped_session`` event target to the
    session class. Without removal, every test would add another listener,
    and each one would open an extra SAVEPOINT.

    NOTE(review): ``db.session`` is never bound to ``conn`` in this block;
    only the local ``session`` is. ``trans.rollback()`` therefore only undoes
    work done on ``conn``. Confirm that ``db.session`` shares this connection,
    otherwise committed test data is not isolated.

    :param app: Application fixture; unused in the body (presumably required
        so an app context exists -- confirm).
    """
    conn = db.engine.connect()
    trans = conn.begin()
    session = Session(bind=conn)
    session.begin_nested()

    # then each time that SAVEPOINT ends, reopen it
    @sa.event.listens_for(db.session, "after_transaction_end")
    def restart_savepoint(sess, transaction):
        if transaction.nested and not transaction._parent.nested:
            sess.expire_all()
            sess.begin_nested()

    db.session.begin_nested()

    UsuarioFactory._meta.sqlalchemy_session = db.session
    EmpresaFactory._meta.sqlalchemy_session = db.session
    ProgramaFactory._meta.sqlalchemy_session = db.session
    CartaoFactory._meta.sqlalchemy_session = db.session

    try:
        yield db.session
    finally:
        # Unregister before tearing down so later tests don't inherit it.
        sa.event.remove(db.session, "after_transaction_end", restart_savepoint)
        session.close()
        # rollback everything
        trans.rollback()
        conn.close()
        db.session.remove()
def db(app, pytestconfig):
    """Yield ``app.database``, optionally isolating the test in a transaction.

    Without the ``use_isolation`` pytest option the database is yielded
    unchanged. With it, ``app.database.session`` is replaced by a session
    bound to a dedicated connection. That session runs inside a SAVEPOINT,
    which is reopened every time the outermost one ends. After the test the
    session is closed, the outer transaction is rolled back and the
    connection is closed.
    """
    # https://docs.sqlalchemy.org/en/13/orm/session_transaction.html
    # https://gist.github.com/zzzeek/8443477
    if not pytestconfig.option.use_isolation:
        yield app.database
        return

    conn = app.database.engine.connect()
    outer_txn = conn.begin()
    test_session = Session(bind=conn)

    # start the session in a SAVEPOINT...
    test_session.begin_nested()

    # ...then each time that SAVEPOINT ends, reopen it
    @event.listens_for(test_session, "after_transaction_end")
    def _reopen_savepoint(sess, txn):
        if not (txn.nested and not txn._parent.nested):
            return
        # ensure that state is expired the way session.commit() at the
        # top level normally does (optional step)
        sess.expire_all()
        sess.begin_nested()

    app.database.session = test_session
    _setup_context(app)

    yield app.database

    test_session.close()
    outer_txn.rollback()
    conn.close()
def test_report_primary_error_when_rollback_fails(self):
    """A failing ROLLBACK TO SAVEPOINT after a broken flush is surfaced.

    The savepoint is released behind the session's back, so the ROLLBACK TO
    SAVEPOINT that follows a failing flush itself fails. The test asserts
    that a ``DBAPIError`` mentioning "ROLLBACK TO SAVEPOINT " is raised. It
    also expects a warning matching "...during handling of a previous
    exception...".
    """
    User, users = self.classes.User, self.tables.users

    mapper(User, users)

    session = Session(testing.db)

    with expect_warnings(".*during handling of a previous exception.*"):
        session.begin_nested()
        # Reach into private Connection state to get the SAVEPOINT name.
        savepoint = session.\
            connection()._Connection__transaction._savepoint

        # force the savepoint to disappear
        session.connection().dialect.do_release_savepoint(
            session.connection(), savepoint
        )

        # now do a broken flush (two rows with the same primary key)
        session.add_all([User(id=1), User(id=1)])

        assert_raises_message(
            sa_exc.DBAPIError,
            "ROLLBACK TO SAVEPOINT ",
            session.flush
        )
def restart_savepoint(session: Session, transaction: SessionTransaction):
    """Event hook that reopens a SAVEPOINT when the outermost one ends.

    When *transaction* is a nested (SAVEPOINT) transaction whose parent is
    not nested, all loaded state is expired and a new SAVEPOINT is begun.
    Any other transaction end is ignored.
    """
    closed_outermost_savepoint = transaction.nested and not transaction._parent.nested
    if not closed_outermost_savepoint:
        return
    # Expire state the way a top-level session.commit() normally would
    # (optional step).
    session.expire_all()
    session.begin_nested()
def orphan_header_chain(session: orm.Session, orphans: Sequence[HeaderIR]) -> None:
    """Mark *orphans* non-canonical and detach their transactions.

    Inside a SAVEPOINT, every header whose hash is in *orphans* has
    ``is_canonical`` set to False. Each transaction linked to one of those
    headers has ``block_header_hash`` cleared. Any receipt the transaction
    has is deleted, together with the receipt's logs and their logtopics.

    :param session: ORM session to operate on.
    :param orphans: Headers that are no longer part of the canonical chain.
    """
    with session.begin_nested():
        header_hashes = {header_ir.hash for header_ir in orphans}

        # Set all the now orphaned headers as being non-canonical
        session.query(Header).filter(  # type: ignore
            Header.hash.in_(header_hashes)
        ).update({"is_canonical": False}, synchronize_session=False)

        # Unlink each transaction from the block. We query across the
        # `BlockTransaction` join table because the
        # `Transaction.block_header_hash` field may have already been set to
        # null in the case that this transaction has already been part of
        # another re-org.
        # We can't perform an `.update()` call if we do this with a join, so
        # we first load the matching transactions below and then update each
        # one individually.
        transactions = (
            session.query(Transaction)  # type: ignore
            .join(
                BlockTransaction,
                Transaction.hash == BlockTransaction.transaction_hash,
            )
            .filter(
                or_(
                    BlockTransaction.block_header_hash.in_(header_hashes),
                    Transaction.block_header_hash.in_(header_hashes),
                )
            )
            .all()
        )

        if not transactions:
            logger.debug("No orphaned transactions to unlink....")

        for transaction in transactions:
            # As noted above, block_header_hash may already be null; guard it
            # because humanize_hash presumably expects bytes (confirm).
            old_block_hash = transaction.block_header_hash
            logger.debug(
                "Unlinking txn: %s from block %s",
                humanize_hash(transaction.hash),
                humanize_hash(old_block_hash) if old_block_hash is not None else None,
            )
            transaction.block_header_hash = None

            if transaction.receipt is not None:
                for log_idx, log in enumerate(transaction.receipt.logs):
                    logger.debug("Deleting log #%d", log_idx)
                    for logtopic in log.logtopics:
                        logger.debug(
                            "Deleting logtopic #%d: %s",
                            logtopic.idx,
                            humanize_hash(logtopic.topic_topic),
                        )
                        with session.begin_nested():
                            session.delete(logtopic)  # type: ignore
                    with session.begin_nested():
                        session.delete(log)  # type: ignore
                logger.debug("Deleting txn receipt: %s", humanize_hash(transaction.hash))
                with session.begin_nested():
                    session.delete(transaction.receipt)  # type: ignore
            else:
                # The format string needs its argument; without it logging
                # reports a formatting error instead of the message.
                logger.debug(
                    "Txn %s already has null receipt",
                    humanize_hash(transaction.hash),
                )
def accept_mission_for_volunteer(db: Session, volunteer: models.Volunteer, mission: models.Mission):
    """Mark *mission* as acquired by *volunteer* and commit.

    The Mission and VolunteerMission tables are locked IN EXCLUSIVE MODE
    inside a SAVEPOINT before both state updates are applied.

    :param db: ORM session; committed on success, rolled back on failure.
    :param volunteer: Volunteer accepting the mission.
    :param mission: Mission being accepted.
    :raises MissionAlreadyAccepted: if ``check_mission_not_accepted`` reports
        the mission as already taken.
    :raises Exception: any error during the update is logged, the session is
        rolled back, and the original exception is re-raised.
    """
    if not check_mission_not_accepted(db, mission):
        raise MissionAlreadyAccepted(
            f"Mission {mission.id} was already accepted by another volunteer")

    # NOTE(review): the check above runs before the table locks below are
    # taken, so two concurrent callers could both pass it. Consider
    # re-checking once the locks are held.
    try:
        with db.begin_nested():
            db.execute(
                f'LOCK TABLE {models.Mission.__tablename__} IN EXCLUSIVE MODE;'
            )
            db.execute(
                f'LOCK TABLE {models.VolunteerMission.__tablename__} IN EXCLUSIVE MODE;'
            )
            set_volunteer_mission_state(db, volunteer, mission,
                                        models.VolunteerMissionState.accepted)
            set_mission_state(db, mission, models.MissionState.acquired, commit=False)
            logger.info(
                "Set mission %s to acquired and missionVolunteer %s to accepted",
                mission.uuid, volunteer.uuid,
            )
            # TODO: Add here task to check what's going on with the missions
    except Exception:
        logger.exception(
            "Failed to accept mission. mission:%s, volunteer:%s",
            mission.uuid, volunteer.uuid,
        )
        db.rollback()
        raise
    else:
        # Leaving the ``with`` block already released the SAVEPOINT, so one
        # commit persists the outer transaction. Committing only on success
        # also keeps a failing commit from masking the original error.
        db.commit()
def modify_user_book_ignore_status(db: Session,
                                   user: Union[str, external.NewUser, None] = None,
                                   book: Union[str, external.NewBook, None] = None,
                                   ubl: Union[external.NewUserBook, None] = None,
                                   ignored: bool = True) -> None:
    """
    User book ignore status helper

    Needs to know the UBL model or user and book pair.
    Basically a soft delete.

    :param db: ORM Session
    :param user: User (username or model); used when ``ubl`` is None
    :param book: Book (handle or model); used when ``ubl`` is None
    :param ubl: UserBook model; takes precedence over ``user``/``book``
    :param ignored: Ignored status (default: True)
    :raises ValueError: if neither ``ubl`` nor both ``user`` and ``book`` are given
    :raises NotFound: if the user has no record for the book
    """
    if ubl is None and (user is None or book is None):
        raise ValueError("Either ubl or both user and book must be given")

    if ubl is not None:
        book_handle = ubl.handle
    else:
        book_handle = book if isinstance(book, str) else book.handle
    b = _get_book(book_handle, db)

    if ubl is not None:
        username = ubl.user
    else:
        username = user if isinstance(user, str) else user.username
    u = _get_user(username, db)

    with db.begin_nested():
        try:
            r: internal.UserBook = db.query(internal.UserBook).where(
                internal.UserBook.user_id == u.id).where(
                internal.UserBook.book_id == b.id).one()
        except NoResultFound as err:
            raise NotFound(f"User {u.username} has no record for {b.handle}") from err
        r.ignored = ignored
def get_or_create(cls, session: Session, **kwargs) -> tuple:
    """Gets a record or creates if it does not yet exist.

    The insert runs inside a SAVEPOINT. If it fails with an IntegrityError
    (e.g. another writer inserted the same row first), only that SAVEPOINT
    is rolled back and the existing row is returned. The caller's pending
    work in the outer transaction is kept.

    Args:
        session (Session object): current ORM session
        **kwargs: fields and values for filtering

    Returns:
        (obj, bool) where:
            obj: the created/retrieved instance (may be None if, after an
                 IntegrityError, the conflicting row is not visible to this
                 transaction)
            bool: True if obj is a new object, False otherwise
    """
    instance = session.query(cls).filter_by(**kwargs).first()
    if instance:
        return instance, False

    try:
        with session.begin_nested():
            instance = cls(**kwargs)
            session.add(instance)
            session.flush()
    except IntegrityError:
        # The ``with`` block has already rolled back the SAVEPOINT. Rolling
        # back the whole session here would throw away unrelated work, and
        # the pending ``instance`` was never persisted, so fetch the row
        # that won instead.
        return session.query(cls).filter_by(**kwargs).first(), False
    return instance, True
def test_dirty_state_transferred_deep_nesting(self):
    """Dirty state from a flush in nested SAVEPOINTs moves outward on commit.

    With two nested transactions open (nt1 outer, nt2 inner), the test
    checks four things in order:

    - Changing ``u1.name`` registers nothing until a flush.
    - The flush records ``u1``'s state only in nt2.
    - ``s.commit()`` makes the state also tracked by nt1.
    - ``s.rollback()`` expires ``u1`` so it reloads the committed name.
    """
    User, users = self.classes.User, self.tables.users

    mapper(User, users)

    s = Session(testing.db)
    u1 = User(name='u1')
    s.add(u1)
    s.commit()

    nt1 = s.begin_nested()
    nt2 = s.begin_nested()
    u1.name = 'u2'
    # An unflushed attribute change is not yet recorded by either savepoint.
    assert attributes.instance_state(u1) not in nt2._dirty
    assert attributes.instance_state(u1) not in nt1._dirty
    s.flush()
    # After flushing, only the innermost savepoint knows about it.
    assert attributes.instance_state(u1) in nt2._dirty
    assert attributes.instance_state(u1) not in nt1._dirty

    # commit() here acts on the innermost savepoint (nt2); its dirty states
    # are transferred to the enclosing one.
    s.commit()
    assert attributes.instance_state(u1) in nt2._dirty
    assert attributes.instance_state(u1) in nt1._dirty

    # Rolling back expires the transferred state, so u1 reloads 'u1'.
    s.rollback()
    assert attributes.instance_state(u1).expired
    eq_(u1.name, 'u1')
def delete_review(review: Union[external.NewReview, Tuple[Union[str, external.NewBook], Union[str, external.NewUser]]],
                  db: Session,
                  hard: bool = False) -> None:
    """
    Soft deletes a review

    :param hard: Hard Delete (forwarded to ``__delete``; lookups are told
        not to throw when set)
    :param review: External review model or a Tuple(Book, User)
    :param db: ORM Session
    :raises NotImplementedError: if *review* is not a (book, user) tuple
    """
    if not isinstance(review, tuple):
        # Previously a non-tuple fell through and failed with
        # UnboundLocalError on ``r``. NOTE(review): resolving a NewReview
        # requires knowing its fields, which are not visible here.
        raise NotImplementedError(
            "delete_review currently only accepts a (book, user) tuple")

    user_ref = review[1]
    u: internal.User = _get_user(
        user_ref if isinstance(user_ref, str) else user_ref.username,
        db, throw=not hard)

    book_ref = review[0]
    b: internal.Book = _get_book(
        book_ref if isinstance(book_ref, str) else book_ref.handle, db)

    r: internal.Review = _get_review(u.username, b.handle, db)
    with db.begin_nested() as nested:
        __delete(r, nested, hard)
def go():
    """Add a SomeClass in a session transaction, then query inside a SAVEPOINT.

    Closure used by an enclosing test that is not visible here.
    """
    session = Session(testing.db)
    # Using the session's current transaction as a context manager.
    with session.transaction:
        sc = SomeClass()
        session.add(sc)
        # The query (and any autoflush of ``sc``) runs inside a SAVEPOINT.
        with session.begin_nested():
            session.query(SomeClass).first()
def store_user_book(model: external.NewUserBook, db: Session, overwrite: bool = False) -> external.UserBook:
    """
    Create or update the user-book record described by *model*

    :param model: UBL model
    :param db: ORM Session
    :param overwrite: Whether an existing record may be updated in place
    :raises AlreadyExists: if a record exists and ``overwrite`` is False
    :return: User book instance
    """
    owner = _get_user(model.user, db)
    target_book = _get_book(model.handle, db)

    current = db.query(internal.UserBook) \
        .where(internal.UserBook.user_id == owner.id) \
        .where(internal.UserBook.book_id == target_book.id) \
        .first()
    fields = model.dict(exclude_none=True, exclude={'user', 'handle'})

    with db.begin_nested():
        if current and not overwrite:
            raise AlreadyExists(
                f"User {owner.username} already has a record for {target_book.handle}")

        if current:
            record: internal.UserBook = current
            for key, value in fields.items():
                setattr(record, key, value)
        else:
            record = internal.UserBook(user_id=owner.id, book_id=target_book.id, **fields)
            with db.begin_nested():
                db.add(record)

    db.refresh(record)

    # Merge the user-book fields, the book's fields and the owner's username
    # into a single response model.
    user_book_fields = external.UserBookInternalBase.from_orm(record).dict(exclude_none=True)
    book_fields = external.Book.from_orm(record.book).dict(exclude_none=True)
    return external.UserBook(**user_book_fields, **book_fields, user=record.user.username)
def test_db(db: Session):
    """Flushing a second User with the same username string inside a SAVEPOINT raises IntegrityError."""
    from bookclub.data.model.db_models import User

    first = User(username="******")
    duplicate = User(username="******")

    db.add(first)
    db.flush()

    with pytest.raises(IntegrityError), db.begin_nested():
        db.add(duplicate)
        db.flush()
def _run_test(self, update_fn):
    """Shared driver: a change made by *update_fn* inside a SAVEPOINT is undone.

    ``update_fn(session, u2)`` must leave ``u2.name == 'u2modified'``. After
    rolling back, the test checks that ``u1``'s loaded ``name`` is still in
    its ``__dict__`` while ``u2``'s was expired, and that ``u2.name`` reloads
    as ``'u2'``.
    """
    User, users = self.classes.User, self.tables.users

    mapper(User, users)

    s = Session(bind=testing.db)
    u1 = User(name='u1')
    u2 = User(name='u2')
    s.add_all([u1, u2])
    s.commit()

    # Touch both attributes so they are loaded before the SAVEPOINT
    # (presumably re-loading what commit() expired -- these bare expressions
    # are intentional).
    u1.name
    u2.name
    s.begin_nested()
    update_fn(s, u2)
    eq_(u2.name, 'u2modified')
    s.rollback()

    # Only the object modified inside the SAVEPOINT was expired.
    eq_(u1.__dict__['name'], 'u1')
    assert 'name' not in u2.__dict__
    eq_(u2.name, 'u2')
def build_block_chain(
    session: orm.Session,
    topic_factory: ThingGenerator[Hash32],
    address_factory: ThingGenerator[Address],
    num_blocks: int,
) -> Iterator[Header]:
    """Build *num_blocks* random blocks and yield each header as it is added.

    Block ``n`` is parented on the canonical header at height ``n - 1``
    (block 0 has no parent). Each block gets ``int(expovariate(0.1))``
    transactions, each with a receipt and ``int(expovariate(0.1))`` logs.
    All rows for a block are added inside one SAVEPOINT.
    """
    for height in range(num_blocks):
        with session.begin_nested():
            if height == 0:
                parent_hash = None
            else:
                parent_hash = session.query(Header).filter(
                    Header.block_number == height - 1,
                    Header.is_canonical.is_(True),
                ).one().hash

            header = HeaderFactory(block_number=height, _parent_hash=parent_hash)
            block = BlockFactory(header=header)

            txn_count = int(random.expovariate(0.1))
            transactions = tuple(TransactionFactory(block=block) for _ in range(txn_count))
            blocktransactions = tuple(
                BlockTransactionFactory(idx=position, block=block, transaction=txn)
                for position, txn in enumerate(transactions)
            )
            receipts = tuple(ReceiptFactory(transaction=txn) for txn in transactions)

            log_counts = tuple(int(random.expovariate(0.1)) for _ in transactions)
            log_bundles = tuple(
                build_log(session, log_idx, receipt, topic_factory, address_factory)
                for count, receipt in zip(log_counts, receipts)
                for log_idx in range(count)
            )

            if not log_bundles:
                logs, logtopics = (), ()
            else:
                logs, logtopic_bundles = zip(*log_bundles)
                logtopics = tuple(itertools.chain.from_iterable(logtopic_bundles))

            session.add(header)
            session.add(block)
            session.add_all(transactions)
            session.add_all(blocktransactions)
            session.add_all(receipts)
            session.add_all(logs)
            session.add_all(logtopics)

        yield header
def construct_log(
    session: orm.Session,
    *,
    block_number: Optional[BlockNumber] = None,
    address: Optional[Address] = None,
    topics: Sequence[Hash32] = (),
    data: bytes = b"",
    is_canonical: bool = True,
) -> Log:
    """Create a Log (plus its LogTopic rows) and return it refreshed.

    When *block_number* is given, the existing header at that height with the
    requested canonical flag is reused; if none exists one is created. Without
    *block_number* a new header is always created. A missing *address* is
    generated with ``AddressFactory``.
    """
    with session.begin_nested():
        if block_number is None:
            header = HeaderFactory(is_canonical=is_canonical)
        else:
            try:
                header = (
                    session.query(Header)  # type: ignore
                    .filter(Header.is_canonical.is_(is_canonical))  # type: ignore
                    .filter(Header.block_number == block_number)
                    .one()
                )
            except NoResultFound:
                header = HeaderFactory(
                    is_canonical=is_canonical, block_number=block_number
                )

        if address is None:
            address = AddressFactory()

        session.add(header)

        topic_rows = _get_or_create_topics(session, topics)
        session.add_all(topic_rows)  # type: ignore

        log = LogFactory(
            receipt__blocktransaction__block__header=header, address=address, data=data
        )
        topic_links = tuple(
            LogTopic(
                idx=position,
                topic_topic=topic,
                log_idx=log.idx,
                log_transaction_hash=log.receipt.transaction_hash,
                log_block_header_hash=log.receipt.block_header_hash,
            )
            for position, topic in enumerate(topics)
        )

        session.add(log)
        session.add_all(topic_links)  # type: ignore

    session.refresh(log)
    return log
def delete_comment(comment: Union[int, external.Comment], db: Session, hard: bool = False) -> None:
    """
    Soft delete a comment

    :param hard: Hard Delete (forwarded to ``__delete``)
    :param comment: Comment (uuid or instance)
    :param db: ORM Session
    """
    # The signature accepts a bare identifier as well as a model; previously
    # ``comment.uuid`` was read unconditionally and failed for an int.
    comment_id = comment if isinstance(comment, int) else comment.uuid
    c = _get_comment(comment_id, db)
    with db.begin_nested() as nested:
        __delete(c, nested, hard)
def get_db_typer():
    """Yield a session wrapped in a SAVEPOINT, then commit and close it.

    A ``PendingRollbackError`` raised by the consumer or by the final commit
    is swallowed (best effort). Other exceptions propagate, but the session
    is always closed.
    """
    session = Session(sync_engine)
    try:
        with session.begin_nested():
            yield session
    except PendingRollbackError:
        pass
    finally:
        try:
            session.commit()
        except PendingRollbackError:
            pass
        finally:
            # Previously close() was skipped if commit() raised anything
            # other than PendingRollbackError, leaking the connection.
            session.close()
class TestConfig:
    """Test base that loads Todo fixtures into ``test.db`` and isolates each test in a SAVEPOINT."""

    def setUp(self):
        self.engine = create_engine('sqlite:///test.db', echo=True)
        session_factory = sessionmaker(bind=self.engine)
        self.session = session_factory()
        self._load_fixture()
        # NOTE(review): with the default pysqlite driver, SAVEPOINT needs
        # SQLAlchemy's documented event-hook workaround to behave; verify
        # this actually isolates tests.
        self.session.begin_nested()

    def _store_fixture_data(self, model, items):
        """Insert *items* into *model*'s table and commit (best effort)."""
        try:
            table = model.__table__
            self.session.execute(table.insert().values(items))
            self.session.commit()
        except Exception:
            # Fixture loading is deliberately best effort (the rows may, for
            # instance, already exist in the persistent test.db). The failed
            # statement must still be rolled back, otherwise the session is
            # left unusable for begin_nested() and the test.
            self.session.rollback()

    def _load_fixture(self):
        self._store_fixture_data(Todo, todos)

    def tearDown(self):
        self.session.rollback()
        self.session.close()
        # Release pooled connections opened by the engine created in setUp.
        self.engine.dispose()
def _modify(i: T, d: Dict[str, Any], db: Session) -> bool:
    """
    Copy the values of *d* onto *i* inside a SAVEPOINT

    A field is only assigned when its current value compares unequal (``!=``)
    to the new one.

    :param i: Instance to modify
    :param d: Mapping of field name to new value
    :param db: ORM Session
    :return: True if at least one field was assigned, else False
    """
    with db.begin_nested():
        any_changed = False
        for field, new_value in d.items():
            if getattr(i, field) != new_value:
                setattr(i, field, new_value)
                any_changed = True
    return any_changed
def delete_club(club: Union[str, external.NewClub], db: Session, hard: bool = False) -> None:
    """
    Soft delete a club

    The return type was ``NoReturn``, which means "never returns normally";
    this function returns None after the SAVEPOINT completes.

    :param hard: Hard Delete (forwarded to ``__delete``; the lookup is told
        not to throw when set)
    :param club: Club (handle or model)
    :param db: ORM Session
    """
    handle = club if isinstance(club, str) else club.handle
    c: internal.Club = _get_club(handle, db, throw=not hard)
    with db.begin_nested() as nested:
        __delete(c, nested, hard)
def simulate_entity( sess: Session, entity: "ReplaceableEntity", dependencies: Optional[List["ReplaceableEntity"]] = None, ): """Creates *entiity* in a transaction so postgres rendered definition can be retrieved """ # When simulating materialized view, don't populate them with data from alembic_utils.pg_materialized_view import PGMaterializedView if isinstance(entity, PGMaterializedView) and entity.with_data: entity = copy.deepcopy(entity) entity.with_data = False deps: List["ReplaceableEntity"] = dependencies or [] try: sess.begin_nested() dependency_managers = [simulate_entity(sess, x) for x in deps] with ExitStack() as stack: # Setup all the possible deps for mgr in dependency_managers: stack.enter_context(mgr) did_drop = False try: sess.begin_nested() sess.execute(entity.to_sql_statement_drop(cascade=True)) did_drop = True sess.execute(entity.to_sql_statement_create()) yield sess except: if did_drop: # The drop was successful, so either create was not, or the # error came from user code after the yield. # Anyway, we can exit now. raise # Try again without the drop in case the drop raised # a does not exist error sess.rollback() sess.begin_nested() sess.execute(entity.to_sql_statement_create()) yield sess finally: sess.rollback() finally: sess.rollback()
def delete_book(book: Union[str, external.NewBook], db: Session, hard: bool = False) -> None:
    """
    Soft delete a book

    The return type was ``NoReturn``, which means "never returns normally";
    this function returns None after the SAVEPOINT completes.

    :param hard: Hard Delete (forwarded to ``__delete``; the lookup is told
        not to throw when set)
    :param book: Book (handle or instance)
    :param db: ORM Session
    """
    handle = book if isinstance(book, str) else book.handle
    b = _get_book(handle, db, throw=not hard)
    with db.begin_nested() as nested:
        __delete(b, nested, hard)
def delete_user(user: Union[str, external.NewUser], db: Session, hard: bool = False) -> None:
    """
    Soft delete a user

    Does nothing when *user* is None. The return type was ``NoReturn``, which
    means "never returns normally"; this function returns None.

    :param hard: Hard Delete (forwarded to ``__delete``; the lookup is told
        not to throw when set)
    :param user: User (username or model)
    :param db: ORM Session
    """
    if user is None:
        return

    username = user if isinstance(user, str) else user.username
    u = _get_user(username, db, throw=not hard)
    with db.begin_nested() as nested:
        __delete(u, nested, hard)
def noisey_get_one_or_create(session: Session, model: BaseType, **kwargs) -> Base:
    """
    Get an instance of `model` from the database if it exists or create it

    The keyword arguments are first passed through ``_prepare_model_params``.
    If the insert fails with IntegrityError (the row was added by someone
    else after our lookup), only the SAVEPOINT is rolled back and the
    existing row is fetched and returned.
    """
    params = _prepare_model_params(session, noisey_get_one_or_create, **kwargs)
    lookup = session.query(model).filter_by(**params)

    existing = lookup.one_or_none()
    if existing is not None:
        return existing

    candidate = model(**params)
    try:
        # Roll back just this insert if the row appeared before we could
        # create it.
        with session.begin_nested():
            session.add(candidate)
    except IntegrityError:
        return lookup.one()
    return candidate
def test_contextmanager_nested_rollback(self):
    """A failing flush inside ``begin_nested()`` raises DBAPIError; a later SAVEPOINT still works."""
    users, User = self.tables.users, self.classes.User

    mapper(User, users)

    session = Session()

    def flush_user_without_name():
        with session.begin_nested():
            session.add(User())  # name can't be null
            session.flush()

    # The NOT NULL violation surfaces as DBAPIError, and not InvalidRequestError
    assert_raises(sa_exc.DBAPIError, flush_user_without_name)

    with session.begin_nested():
        session.add(User(name='u1'))

    eq_(session.query(User).count(), 1)
def _add(e: Any, i: Type[T], db: Session, exclude: Set[str] = None, extra: Dict[str, Any] = None) -> T:
    """
    Build an internal entity from *e* and add it to the session in a SAVEPOINT

    :param e: External entity model
    :param i: Internal entity model
    :param db: ORM Session
    :param exclude: Field names of *e* to leave out of the constructor call
    :param extra: Additional constructor keyword arguments
    :return: Created internal entity (constructed with ``deleted=False``)
    """
    extra_fields = {} if extra is None else extra
    with db.begin_nested() as nested:
        entity = i(**e.dict(exclude_none=True, exclude=exclude), **extra_fields, deleted=False)
        nested.session.add(entity)
    return entity
def test_contextmanager_nested_rollback(self):
    """A failing flush inside ``begin_nested()`` raises DBAPIError; a later SAVEPOINT still works."""
    users, User = self.tables.users, self.classes.User

    mapper(User, users)

    session = Session()

    def flush_user_without_name():
        with session.begin_nested():
            session.add(User())  # name can't be null
            session.flush()

    # The NOT NULL violation surfaces as DBAPIError, and not InvalidRequestError
    assert_raises(sa_exc.DBAPIError, flush_user_without_name)

    with session.begin_nested():
        session.add(User(name='u1'))

    eq_(session.query(User).count(), 1)
def create_user(session: Session, user_data: UserInputType, user: User) -> http.JSONResponse:
    """Create a new user on behalf of *user* and return it with status 201.

    :raises BadRequest: if an id is supplied, if *user* may not create a user
        with the requested role, or if the insert hits an IntegrityError
        (reported as "username already exists").
    """
    new_user = User(**dict(user_data))
    if new_user.id is not None:
        raise BadRequest({'error': 'user ID cannot be set'})

    if not can_user_create_user(user, new_user):
        msg = 'user cannot create user with role "{}"'
        raise BadRequest({'error': msg.format(new_user.role.name)})

    # Open the SAVEPOINT only after validation. Previously both BadRequest
    # raises above left it open and never committed or rolled back.
    txn = session.begin_nested()
    session.add(new_user)
    try:
        txn.commit()
    except IntegrityError:
        txn.rollback()
        raise BadRequest({'error': 'username already exists'})

    return http.JSONResponse(UserType(new_user), status_code=201)
def resolve(
        self, request: http.Request, session: Session, token: http.QueryParam  # pylint: disable=arguments-differ
) -> User:
    """Return the user authenticated by *request*, or ``None``.

    A non-empty ``token`` query parameter takes priority: it must parse as a
    UUID matching a ``Token`` row, otherwise ``None`` is returned. Without a
    token, the session id from the request headers is resolved through
    ``resolve_with_session`` inside a SAVEPOINT.
    """
    if not token:
        session_id = self.get_session_id(request.headers)
        if not session_id:
            return None
        with session.begin_nested():
            return self.resolve_with_session(session, session_id)

    try:
        token_id = uuid.UUID(token)
    except ValueError:
        return None

    return session.query(User) \
        .join(Token) \
        .filter(Token.id == token_id) \
        .first()