Ejemplo n.º 1
0
async def sess(async_engine: AsyncEngine) -> AsyncGenerator[Session, None]:
    """Yield an ``AsyncSession`` whose work is discarded after the test.

    The session is bound to the connection opened by ``async_engine.begin()``
    and works inside a SAVEPOINT that is re-opened every time the outermost
    SAVEPOINT ends, so code under test may commit/rollback freely. On teardown
    the session is closed and the connection-level transaction rolled back.
    """
    session_factory = sessionmaker(  # type: ignore
        bind=async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )

    async with async_engine.begin() as conn:

        # Bind a session to the top level transaction
        _session = session_factory(bind=conn)

        # Start a savepoint that we can rollback to in the transaction.
        # AsyncSession.begin_nested() returns an awaitable; without the await
        # no SAVEPOINT is started.
        await _session.begin_nested()

        @sqlalchemy.event.listens_for(_session.sync_session,
                                      "after_transaction_end")
        def restart_savepoint(sess, trans):  # type: ignore
            """Re-open the SAVEPOINT when the outermost nested transaction ends.

            This must stay a plain function: a ``yield`` here would make it a
            generator function whose body never runs when SQLAlchemy calls it.
            """
            if trans.nested and not trans._parent.nested:
                # Expire all objects registered against the session
                sess.expire_all()
                sess.begin_nested()

        try:
            yield _session
        finally:
            # Close the session object
            await _session.close()

            # Roll back the outer transaction, discarding everything that
            # happened through _session
            await conn.rollback()  # type: ignore
Ejemplo n.º 2
0
def db_session(app):
    """Yield ``db.session`` with a SAVEPOINT that is re-opened whenever it ends.

    The model factories are pointed at ``db.session`` for the duration of the
    test. Teardown removes the event listener, closes the helper session,
    rolls back the connection-level transaction, closes the connection and
    removes the scoped session. It runs even if setup after ``connect()`` fails.

    NOTE(review): ``db.session`` is never bound to ``conn`` here, so
    ``trans.rollback()`` only undoes work issued through ``session``/``conn``.
    Confirm whether ``db.session`` should be re-bound to ``conn`` for isolation.
    """
    conn = db.engine.connect()
    trans = conn.begin()

    session = Session(bind=conn)
    session.begin_nested()

    # then each time that SAVEPOINT ends, reopen it
    @sa.event.listens_for(db.session, "after_transaction_end")
    def restart_savepoint(sess, transaction):
        if transaction.nested and not transaction._parent.nested:
            sess.expire_all()
            sess.begin_nested()

    try:
        db.session.begin_nested()

        UsuarioFactory._meta.sqlalchemy_session = db.session
        EmpresaFactory._meta.sqlalchemy_session = db.session
        ProgramaFactory._meta.sqlalchemy_session = db.session
        CartaoFactory._meta.sqlalchemy_session = db.session

        yield db.session
    finally:
        # The listener is attached via db.session rather than to one throw-away
        # Session; without removing it, every use of this fixture adds another.
        sa.event.remove(db.session, "after_transaction_end", restart_savepoint)
        session.close()
        # rollback everything
        trans.rollback()
        conn.close()
        db.session.remove()
Ejemplo n.º 3
0
def db(app, pytestconfig):
    """Yield ``app.database``, isolated in a rolled-back transaction if requested.

    When ``pytestconfig.option.use_isolation`` is falsy, ``app.database`` is
    yielded untouched. Otherwise ``app.database.session`` is replaced by a
    session bound to a dedicated connection whose transaction is rolled back
    on teardown. The session works in a SAVEPOINT that is re-opened whenever
    it ends.
    """
    # https://docs.sqlalchemy.org/en/13/orm/session_transaction.html
    # https://gist.github.com/zzzeek/8443477
    if not pytestconfig.option.use_isolation:
        yield app.database
        return

    connection = app.database.engine.connect()
    transaction = connection.begin()
    session = Session(bind=connection)
    try:
        # start the session in a SAVEPOINT...
        session.begin_nested()

        # then each time that SAVEPOINT ends, reopen it
        @event.listens_for(session, "after_transaction_end")
        def restart_savepoint(sess, trans):
            if trans.nested and not trans._parent.nested:

                # ensure that state is expired the way
                # session.commit() at the top level normally does
                # (optional step)
                sess.expire_all()
                sess.begin_nested()

        app.database.session = session
        _setup_context(app)

        yield app.database
    finally:
        # Runs even if setup above fails, so the connection is never leaked.
        session.close()
        transaction.rollback()
        connection.close()
Ejemplo n.º 4
0
    def test_report_primary_error_when_rollback_fails(self):
        """A failing flush whose ROLLBACK TO SAVEPOINT also fails.

        The session's SAVEPOINT is released directly through the dialect, then
        a failing flush is triggered. Expects ``flush()`` to raise a
        ``DBAPIError`` whose message contains "ROLLBACK TO SAVEPOINT ", and a
        warning matching "during handling of a previous exception".
        NOTE(review): going by the test name, that warning presumably carries
        the primary (flush) error; confirm against the code under test.
        """
        User, users = self.classes.User, self.tables.users

        mapper(User, users)

        session = Session(testing.db)

        with expect_warnings(".*during handling of a previous exception.*"):
            session.begin_nested()
            # Read the SAVEPOINT name via the Connection's name-mangled private
            # ``__transaction`` attribute (SQLAlchemy internals).
            savepoint = session.\
                connection()._Connection__transaction._savepoint

            # force the savepoint to disappear
            # (released via the dialect, so the ORM is presumably unaware of it)
            session.connection().dialect.do_release_savepoint(
                session.connection(), savepoint
            )

            # now do a broken flush
            # (two Users with id=1, presumably a primary-key conflict)
            session.add_all([User(id=1), User(id=1)])

            assert_raises_message(
                sa_exc.DBAPIError,
                "ROLLBACK TO SAVEPOINT ",
                session.flush
            )
Ejemplo n.º 5
0
    def test_report_primary_error_when_rollback_fails(self):
        """A failing flush whose ROLLBACK TO SAVEPOINT also fails.

        The session's SAVEPOINT is released directly through the dialect, then
        a failing flush is triggered. Expects ``flush()`` to raise a
        ``DBAPIError`` whose message contains "ROLLBACK TO SAVEPOINT ", and a
        warning matching "during handling of a previous exception".
        NOTE(review): going by the test name, that warning presumably carries
        the primary (flush) error; confirm against the code under test.
        """
        User, users = self.classes.User, self.tables.users

        mapper(User, users)

        session = Session(testing.db)

        with expect_warnings(".*during handling of a previous exception.*"):
            session.begin_nested()
            # Read the SAVEPOINT name via the Connection's name-mangled private
            # ``__transaction`` attribute (SQLAlchemy internals).
            savepoint = session.\
                connection()._Connection__transaction._savepoint

            # force the savepoint to disappear
            # (released via the dialect, so the ORM is presumably unaware of it)
            session.connection().dialect.do_release_savepoint(
                session.connection(), savepoint
            )

            # now do a broken flush
            # (two Users with id=1, presumably a primary-key conflict)
            session.add_all([User(id=1), User(id=1)])

            assert_raises_message(
                sa_exc.DBAPIError,
                "ROLLBACK TO SAVEPOINT ",
                session.flush
            )
Ejemplo n.º 6
0
    def restart_savepoint(session: Session, transaction: SessionTransaction):
        """Begin a fresh SAVEPOINT after the outermost SAVEPOINT ends.

        Presumably registered as an ``after_transaction_end`` listener. When
        ``transaction`` is nested but its parent is not, every object in
        ``session`` is expired (mirroring a top-level commit; optional step)
        and a new SAVEPOINT is started. Other transactions are ignored.
        """
        # ``and`` short-circuits, so ``_parent`` is only read for nested ones.
        ends_outermost_savepoint = (
            transaction.nested and not transaction._parent.nested
        )
        if not ends_outermost_savepoint:
            return

        session.expire_all()
        session.begin_nested()
Ejemplo n.º 7
0
def orphan_header_chain(session: orm.Session,
                        orphans: Sequence[HeaderIR]) -> None:
    """Detach ``orphans`` from the canonical chain.

    Inside a SAVEPOINT this:

    * sets ``is_canonical = False`` on every ``Header`` whose hash matches one
      of the orphans;
    * clears ``block_header_hash`` on every ``Transaction`` linked to one of
      those headers, either directly or through ``BlockTransaction``;
    * deletes those transactions' receipts along with their logs and log
      topics.

    :param session: ORM session the changes are made in.
    :param orphans: Headers that are no longer canonical.
    """
    header_hashes = {header_ir.hash for header_ir in orphans}
    if not header_hashes:
        # Nothing to orphan: skip the savepoint and the empty ``IN ()`` queries.
        logger.debug("No orphaned transactions to unlink....")
        return

    with session.begin_nested():
        # Set all the now orphaned headers as being non-canonical
        session.query(Header).filter(  # type: ignore
            Header.hash.in_(header_hashes)).update({"is_canonical": False},
                                                   synchronize_session=False)

        # Unlink each transaction from the block.  We query across the
        # `BlockTransaction` join table because the
        # `Transaction.block_header_hash` field may have already been set to
        # null in the case that this transaction has already been part of
        # another re-org.
        #
        # `.update()` cannot be combined with a join, so the matching
        # transactions are loaded and unlinked one ORM object at a time.
        transactions = (
            session.query(Transaction)  # type: ignore
            .join(
                BlockTransaction,
                Transaction.hash == BlockTransaction.transaction_hash).filter(
                    or_(
                        BlockTransaction.block_header_hash.in_(header_hashes),
                        Transaction.block_header_hash.in_(header_hashes),
                    )).all())

        if not transactions:
            logger.debug("No orphaned transactions to unlink....")

        for transaction in transactions:
            old_block_hash = transaction.block_header_hash
            logger.debug(
                "Unlinking txn: %s from block %s",
                humanize_hash(transaction.hash),
                # Already null if an earlier re-org unlinked this txn (see the
                # comment above); don't hand None to humanize_hash.
                None if old_block_hash is None else humanize_hash(old_block_hash),
            )
            transaction.block_header_hash = None

            if transaction.receipt is not None:
                for log_idx, log in enumerate(transaction.receipt.logs):
                    logger.debug("Deleting log #%d", log_idx)
                    for logtopic in log.logtopics:
                        logger.debug(
                            "Deleting logtopic #%d: %s",
                            logtopic.idx,
                            humanize_hash(logtopic.topic_topic),
                        )
                        with session.begin_nested():
                            session.delete(logtopic)  # type: ignore
                    with session.begin_nested():
                        session.delete(log)  # type: ignore
                logger.debug("Deleting txn receipt: %s",
                             humanize_hash(transaction.hash))
                with session.begin_nested():
                    session.delete(transaction.receipt)  # type: ignore
            else:
                # The format string has a %s placeholder, so pass the hash.
                logger.debug("Txn %s already has null receipt",
                             humanize_hash(transaction.hash))
Ejemplo n.º 8
0
def accept_mission_for_volunteer(db: Session, volunteer: models.Volunteer,
                                 mission: models.Mission):
    """Mark ``mission`` as acquired and the volunteer's mission as accepted.

    Both tables are locked in EXCLUSIVE mode inside a SAVEPOINT. The
    "not yet accepted" check is repeated while holding those locks, so two
    volunteers racing for the same mission cannot both pass it. The enclosing
    transaction is committed only on success and rolled back on any failure.

    :raises MissionAlreadyAccepted: if another volunteer already accepted it.
    :raises Exception: any other failure is logged, rolled back and re-raised.
    """
    # Cheap early exit before taking any table locks.
    if not check_mission_not_accepted(db, mission):
        raise MissionAlreadyAccepted(
            f"Mission {mission.id} was already accepted by another volunteer")
    try:
        with db.begin_nested():
            db.execute(
                f'LOCK TABLE {models.Mission.__tablename__} IN EXCLUSIVE MODE;'
            )
            db.execute(
                f'LOCK TABLE {models.VolunteerMission.__tablename__} IN EXCLUSIVE MODE;'
            )
            # Re-check under the locks: another volunteer may have accepted the
            # mission between the first check and acquiring them.
            # NOTE(review): only effective if check_mission_not_accepted reads
            # the database rather than cached attributes on ``mission``.
            if not check_mission_not_accepted(db, mission):
                raise MissionAlreadyAccepted(
                    f"Mission {mission.id} was already accepted by another volunteer")
            set_volunteer_mission_state(db, volunteer, mission,
                                        models.VolunteerMissionState.accepted)
            set_mission_state(db,
                              mission,
                              models.MissionState.acquired,
                              commit=False)
        # Leaving the ``with`` block already released the savepoint; one
        # commit finishes the enclosing transaction.
        db.commit()
    except MissionAlreadyAccepted:
        db.rollback()
        raise
    except Exception:
        logger.error(
            f"Failed to accept mission. mission:{mission.uuid}, volunteer:{volunteer.uuid}"
        )
        db.rollback()
        raise
    logger.info(
        f"Set mission {mission.uuid} to acquired and missionVolunteer {volunteer.uuid} to accepted"
    )
    # TODO: Add here task to check what's going on with the missions
Ejemplo n.º 9
0
def modify_user_book_ignore_status(db: Session,
                                   user: Union[str, external.NewUser] = None,
                                   book: Union[str, external.NewBook] = None,
                                   ubl: external.NewUserBook = None,
                                   ignored: bool = True) -> None:
    """
    User book ignore status helper
    Needs to know the UBL model or user and book pair

    Basically a soft delete
    :param ubl:         UserBook model (takes precedence over user/book)
    :param user:        User (username or model)
    :param book:        Book (handle or model)
    :param db:          ORM Session
    :param ignored:     Ignored status (default: True)
    :raises ValueError: if neither ``ubl`` nor both ``user`` and ``book`` are given
    :raises NotFound:   if the user has no record for the book
    """
    if ubl is None:
        if user is None or book is None:
            raise ValueError(
                "Either ubl or both user and book must be provided")
        handle = book if isinstance(book, str) else book.handle
        username = user if isinstance(user, str) else user.username
    else:
        handle = ubl.handle
        username = ubl.user

    b = _get_book(handle, db)
    u = _get_user(username, db)
    with db.begin_nested():
        try:
            r: internal.UserBook = db.query(internal.UserBook).where(
                internal.UserBook.user_id == u.id).where(
                    internal.UserBook.book_id == b.id).one()
            r.ignored = ignored
        except NoResultFound:
            raise NotFound(f"User {u.username} has no record for {b.handle}")
Ejemplo n.º 10
0
    def get_or_create(cls, session: Session, **kwargs) -> tuple:
        """Gets a record or creates if it does not yet exist.

        The new row is added and flushed inside a SAVEPOINT. If a concurrent
        insert wins the race (``IntegrityError``), only that savepoint is
        rolled back, so the caller's transaction survives, and the existing
        row is re-queried.

        Args:
            session (Session object): current ORM session
            **kwargs: fields and values for filtering

        Returns:
            (obj, bool) where:
                obj: the created/retrieved instance (``None`` only if the
                     conflicting row vanished before it could be re-read)
                bool: True if obj is a new object, False otherwise
        """
        instance = session.query(cls).filter_by(**kwargs).first()
        if instance is not None:
            return instance, False

        try:
            with session.begin_nested():
                instance = cls(**kwargs)
                session.add(instance)
                session.flush()
        except IntegrityError:
            # Another writer inserted the row first. The savepoint has been
            # rolled back, so fetch the row that won.
            return session.query(cls).filter_by(**kwargs).first(), False

        return instance, True
Ejemplo n.º 11
0
    def test_dirty_state_transferred_deep_nesting(self):
        """Dirty state moves up one SAVEPOINT per commit and is expired on rollback."""
        User, users = self.classes.User, self.tables.users
        mapper(User, users)

        session = Session(testing.db)
        user = User(name='u1')
        session.add(user)
        session.commit()

        outer = session.begin_nested()
        inner = session.begin_nested()
        user.name = 'u2'
        state = attributes.instance_state(user)

        # Not flushed yet: neither savepoint has recorded the change.
        assert state not in inner._dirty
        assert state not in outer._dirty

        # A flush records it in the innermost savepoint only.
        session.flush()
        assert state in inner._dirty
        assert state not in outer._dirty

        # Committing the inner savepoint hands the state to the outer one too.
        session.commit()
        assert state in inner._dirty
        assert state in outer._dirty

        # Rolling back expires the object and the original name reloads.
        session.rollback()
        assert state.expired
        eq_(user.name, 'u1')
Ejemplo n.º 12
0
    def test_dirty_state_transferred_deep_nesting(self):
        """Dirty state moves up one SAVEPOINT per commit and is expired on rollback."""
        User, users = self.classes.User, self.tables.users
        mapper(User, users)

        session = Session(testing.db)
        user = User(name='u1')
        session.add(user)
        session.commit()

        outer = session.begin_nested()
        inner = session.begin_nested()
        user.name = 'u2'
        state = attributes.instance_state(user)

        # Not flushed yet: neither savepoint has recorded the change.
        assert state not in inner._dirty
        assert state not in outer._dirty

        # A flush records it in the innermost savepoint only.
        session.flush()
        assert state in inner._dirty
        assert state not in outer._dirty

        # Committing the inner savepoint hands the state to the outer one too.
        session.commit()
        assert state in inner._dirty
        assert state in outer._dirty

        # Rolling back expires the object and the original name reloads.
        session.rollback()
        assert state.expired
        eq_(user.name, 'u1')
Ejemplo n.º 13
0
def delete_review(review: Union[external.NewReview,
                                Tuple[Union[str, external.NewBook],
                                      Union[str, external.NewUser]]],
                  db: Session,
                  hard: bool = False) -> None:
    """
    Soft deletes a review (hard deletes it when ``hard`` is True)

    :param hard:    Hard Delete; the user lookup is done with ``throw=not hard``
    :param review:  A Tuple(Book, User), each given as a handle/username or model
    :param db:      ORM Session
    :raises NotImplementedError: if ``review`` is not a tuple. Deleting by an
                                 ``external.NewReview`` model is not implemented
                                 (previously this case silently did nothing).
    """
    if not isinstance(review, tuple):
        raise NotImplementedError(
            "delete_review only supports a (book, user) tuple")

    book_ref, user_ref = review[0], review[1]

    u: internal.User = _get_user(
        user_ref if isinstance(user_ref, str) else user_ref.username,
        db,
        throw=not hard)
    b: internal.Book = _get_book(
        book_ref if isinstance(book_ref, str) else book_ref.handle, db)

    r: internal.Review = _get_review(u.username, b.handle, db)
    with db.begin_nested() as nested:
        __delete(r, nested, hard)
Ejemplo n.º 14
0
        def go():
            # Inside the session's current transaction: add a SomeClass
            # instance, then run a query within a nested SAVEPOINT.
            sess = Session(testing.db)
            with sess.transaction:
                obj = SomeClass()
                sess.add(obj)
                with sess.begin_nested():
                    sess.query(SomeClass).first()
Ejemplo n.º 15
0
def store_user_book(model: external.NewUserBook,
                    db: Session,
                    overwrite: bool = False) -> external.UserBook:
    """
    Create or update the user-book record described by ``model``

    :param model:       UBL model (``user`` and ``handle`` identify the record)
    :param db:          ORM Session
    :param overwrite:   Whether to overwrite (update) existing records
    :return:            User book instance
    :raises AlreadyExists: if a record already exists and ``overwrite`` is False
    """
    u = _get_user(model.user, db)
    b = _get_book(model.handle, db)

    existing = (db.query(internal.UserBook)
                .where(internal.UserBook.user_id == u.id)
                .where(internal.UserBook.book_id == b.id)
                .first())

    # Only the fields the caller set; user/handle were resolved above.
    fields = model.dict(exclude_none=True, exclude={'user', 'handle'})
    with db.begin_nested():
        record: internal.UserBook
        if existing:
            if not overwrite:
                raise AlreadyExists(
                    f"User {u.username} already has a record for {b.handle}")
            record = existing
            for name, value in fields.items():
                setattr(record, name, value)
        else:
            record = internal.UserBook(user_id=u.id, book_id=b.id, **fields)
            with db.begin_nested():
                db.add(record)
        # Reload the row from the database (presumably to pick up DB-side values).
        db.refresh(record)
        # Flatten the user-book columns, the book's columns and the username
        # into one external model.
        return external.UserBook(
            **external.UserBookInternalBase.from_orm(record).dict(
                exclude_none=True),
            **external.Book.from_orm(record.book).dict(exclude_none=True),
            user=record.user.username,
        )
Ejemplo n.º 16
0
def test_db(db: Session):
    """A flush raising IntegrityError inside ``begin_nested()`` propagates out."""
    from bookclub.data.model.db_models import User

    first = User(username="******")
    second = User(username="******")

    db.add(first)
    db.flush()

    # Presumably the second username collides with the first (the literals
    # are masked in this copy).
    with pytest.raises(IntegrityError), db.begin_nested():
        db.add(second)
        db.flush()
Ejemplo n.º 17
0
    def _run_test(self, update_fn):
        """Apply ``update_fn`` inside a SAVEPOINT, then check rollback restores state.

        ``update_fn(session, user)`` is expected to rename ``user`` to
        'u2modified'.
        """
        User, users = self.classes.User, self.tables.users
        mapper(User, users)

        session = Session(bind=testing.db)
        first, second = User(name='u1'), User(name='u2')
        session.add_all([first, second])
        session.commit()

        # Touch both names so they are loaded into each instance's __dict__.
        for user in (first, second):
            user.name

        session.begin_nested()
        update_fn(session, second)
        eq_(second.name, 'u2modified')
        session.rollback()

        # The untouched user keeps its loaded value; the modified one is
        # expired and reloads its original name.
        eq_(first.__dict__['name'], 'u1')
        assert 'name' not in second.__dict__
        eq_(second.name, 'u2')
Ejemplo n.º 18
0
    def _run_test(self, update_fn):
        """Apply ``update_fn`` inside a SAVEPOINT, then check rollback restores state.

        ``update_fn(session, user)`` is expected to rename ``user`` to
        'u2modified'.
        """
        User, users = self.classes.User, self.tables.users
        mapper(User, users)

        session = Session(bind=testing.db)
        first, second = User(name='u1'), User(name='u2')
        session.add_all([first, second])
        session.commit()

        # Touch both names so they are loaded into each instance's __dict__.
        for user in (first, second):
            user.name

        session.begin_nested()
        update_fn(session, second)
        eq_(second.name, 'u2modified')
        session.rollback()

        # The untouched user keeps its loaded value; the modified one is
        # expired and reloads its original name.
        eq_(first.__dict__['name'], 'u1')
        assert 'name' not in second.__dict__
        eq_(second.name, 'u2')
Ejemplo n.º 19
0
def build_block_chain(
    session: orm.Session,
    topic_factory: ThingGenerator[Hash32],
    address_factory: ThingGenerator[Address],
    num_blocks: int,
) -> Iterator[Header]:
    """Build ``num_blocks`` chained blocks, yielding each block's header.

    Block ``n`` uses the ``is_canonical`` header at height ``n - 1`` as its
    parent. Each block gets a random (exponentially distributed) number of
    transactions, each with a receipt, and each receipt a random number of
    logs built with ``build_log``. Every block is written in its own
    SAVEPOINT, which is released *before* the header is yielded. A consumer
    that stops iterating early therefore does not roll back the block it was
    just handed.
    """
    for block_number in range(num_blocks):
        with session.begin_nested():
            if block_number == 0:
                parent_hash = None
            else:
                parent = session.query(Header).filter(
                    Header.block_number == block_number - 1,
                    Header.is_canonical.is_(True),
                ).one()
                parent_hash = parent.hash

            header = HeaderFactory(block_number=block_number,
                                   _parent_hash=parent_hash)
            block = BlockFactory(header=header)

            num_transactions = int(random.expovariate(0.1))

            transactions = tuple(
                TransactionFactory(block=block)
                for _ in range(num_transactions))
            blocktransactions = tuple(
                BlockTransactionFactory(
                    idx=idx, block=block, transaction=transaction)
                for idx, transaction in enumerate(transactions))
            receipts = tuple(
                ReceiptFactory(transaction=transaction)
                for transaction in transactions)
            num_logs_per_transaction = tuple(
                int(random.expovariate(0.1)) for _ in transactions)
            log_bundles = tuple(
                build_log(session, idx, receipt, topic_factory,
                          address_factory) for num_logs, receipt in zip(
                              num_logs_per_transaction, receipts)
                for idx in range(num_logs))
            if log_bundles:
                logs, logtopic_bundles = zip(*log_bundles)
                logtopics = tuple(itertools.chain(*logtopic_bundles))
            else:
                logs, logtopics = (), ()

            session.add(header)
            session.add(block)
            session.add_all(transactions)
            session.add_all(blocktransactions)
            session.add_all(receipts)
            session.add_all(logs)
            session.add_all(logtopics)

        # Yield only after the savepoint is released: closing the generator
        # while suspended inside the ``with`` would roll this block back.
        yield header
Ejemplo n.º 20
0
def construct_log(
    session: orm.Session,
    *,
    block_number: Optional[BlockNumber] = None,
    address: Optional[Address] = None,
    topics: Sequence[Hash32] = (),
    data: bytes = b"",
    is_canonical: bool = True,
) -> Log:
    """Create a ``Log`` (with its topics) inside a SAVEPOINT and return it refreshed.

    With ``block_number`` given, the existing header at that height with the
    requested ``is_canonical`` flag is reused, or a new one is made with
    ``HeaderFactory``. Without it, a fresh header is always made. A missing
    ``address`` is generated with ``AddressFactory``.
    """
    with session.begin_nested():
        if block_number is None:
            header = HeaderFactory(is_canonical=is_canonical)
        else:
            try:
                header = (
                    session.query(Header)  # type: ignore
                    .filter(Header.is_canonical.is_(is_canonical))  # type: ignore
                    .filter(Header.block_number == block_number)
                    .one()
                )
            except NoResultFound:
                # No header at that height yet: fabricate one.
                header = HeaderFactory(
                    is_canonical=is_canonical, block_number=block_number
                )

        log_address = AddressFactory() if address is None else address

        session.add(header)

        topic_rows = _get_or_create_topics(session, topics)
        session.add_all(topic_rows)  # type: ignore

        log = LogFactory(
            receipt__blocktransaction__block__header=header,
            address=log_address,
            data=data,
        )

        receipt = log.receipt
        log_topics = tuple(
            LogTopic(
                idx=position,
                topic_topic=topic,
                log_idx=log.idx,
                log_transaction_hash=receipt.transaction_hash,
                log_block_header_hash=receipt.block_header_hash,
            )
            for position, topic in enumerate(topics)
        )

        session.add(log)
        session.add_all(log_topics)  # type: ignore

    # Reload once the savepoint has been released.
    session.refresh(log)
    return log
Ejemplo n.º 21
0
def delete_comment(comment: Union[int, external.Comment],
                   db: Session,
                   hard: bool = False) -> None:
    """
    Soft delete a comment (hard delete when ``hard`` is True)

    :param hard:    Hard Delete
    :param comment: Comment (uuid or instance)
    :param db:      ORM Session
    """
    # The annotation allows the identifier itself; only read ``.uuid`` from objects.
    comment_id = comment if isinstance(comment, int) else comment.uuid
    c = _get_comment(comment_id, db)
    with db.begin_nested() as nested:
        __delete(c, nested, hard)
Ejemplo n.º 22
0
def get_db_typer():
    """Yield a ``Session`` on ``sync_engine``; the caller's work runs in a SAVEPOINT.

    Afterwards the session is committed and then always closed.
    ``PendingRollbackError`` is deliberately swallowed, both around the
    yielded work and around the final commit.
    """
    session = Session(sync_engine)
    try:
        with session.begin_nested():
            yield session
    except PendingRollbackError:
        # Deliberate best-effort: cleanup below still runs.
        pass
    finally:
        try:
            session.commit()
        except PendingRollbackError:
            pass
        finally:
            # Close even when commit() raises something else, so the
            # connection is never leaked.
            session.close()
Ejemplo n.º 23
0
class TestConfig:
    """Test scaffolding backed by ``sqlite:///test.db``.

    ``setUp`` loads the Todo fixture and then opens a SAVEPOINT. ``tearDown``
    rolls it back, closes the session and disposes of the engine.
    """

    def setUp(self):
        self.engine = create_engine('sqlite:///test.db', echo=True)
        Session = sessionmaker(bind=self.engine)
        self.session = Session()
        self._load_fixture()
        self.session.begin_nested()

    def _store_fixture_data(self, model, items):
        """Insert ``items`` into ``model``'s table and commit.

        A failure stays non-fatal, as before. The session is rolled back so it
        is not left mid-transaction, and a warning is emitted instead of the
        error being silently dropped.
        """
        import warnings

        try:
            table = model.__table__
            self.session.execute(table.insert().values(items))
            self.session.commit()
        except Exception as exc:
            self.session.rollback()
            warnings.warn(f"Could not load fixture data for {model!r}: {exc}")

    def _load_fixture(self):
        self._store_fixture_data(Todo, todos)

    def tearDown(self):
        self.session.rollback()
        self.session.close()
        # Release pooled connections held by the per-test engine.
        self.engine.dispose()
Ejemplo n.º 24
0
def _modify(i: T, d: Dict[str, Any], db: Session) -> bool:
    """
    Modify orm instance in a transaction

    :param i:   Instance to modify
    :param d:   Dict for field references
    :param db:  ORM Session
    :return:    Whether the object changed as a result
    """
    with db.begin_nested():
        changed = False
        for k in d:
            if getattr(i, k) != d[k]:
                changed = True
                setattr(i, k, d[k])
        return changed
Ejemplo n.º 25
0
def delete_club(club: Union[str, external.NewClub],
                db: Session,
                hard: bool = False) -> None:
    """
    Soft delete a club (hard delete when ``hard`` is True)

    The return annotation is ``None``: the function returns normally, so
    ``NoReturn`` was wrong and made code after any call look unreachable.

    :param hard:    Hard Delete; the lookup is done with ``throw=not hard``
    :param club:    Club (handle or model)
    :param db:      ORM Session
    """
    handle = club if isinstance(club, str) else club.handle
    c: internal.Club = _get_club(handle, db, throw=not hard)
    with db.begin_nested() as nested:
        __delete(c, nested, hard)
Ejemplo n.º 26
0
def simulate_entity(
    sess: Session,
    entity: "ReplaceableEntity",
    dependencies: Optional[List["ReplaceableEntity"]] = None,
):
    """Create *entity* (after its *dependencies*) in a transaction and yield
    the session, so the postgres-rendered definition can be retrieved.

    Everything is rolled back when the caller is done. A materialized view
    with ``with_data`` set is simulated from a deep copy with
    ``with_data=False``, so it is not populated. Drop-then-create is tried
    first. If that fails before the drop succeeded (e.g. the entity did not
    exist yet), creation is retried without the drop.
    """

    # When simulating materialized view, don't populate them with data
    from alembic_utils.pg_materialized_view import PGMaterializedView

    if isinstance(entity, PGMaterializedView) and entity.with_data:
        entity = copy.deepcopy(entity)
        entity.with_data = False

    deps: List["ReplaceableEntity"] = dependencies or []

    try:
        sess.begin_nested()

        dependency_managers = [simulate_entity(sess, x) for x in deps]

        with ExitStack() as stack:
            # Setup all the possible deps
            for mgr in dependency_managers:
                stack.enter_context(mgr)

            did_drop = False
            try:
                sess.begin_nested()
                sess.execute(entity.to_sql_statement_drop(cascade=True))
                did_drop = True
                sess.execute(entity.to_sql_statement_create())
                yield sess
            except Exception:
                # Not a bare ``except``: KeyboardInterrupt/SystemExit must not
                # trigger the retry below. GeneratorExit still reaches the
                # ``finally`` clauses unchanged.
                if did_drop:
                    # The drop was successful, so either create was not, or the
                    # error came from user code after the yield.
                    # Anyway, we can exit now.
                    raise

                # Try again without the drop in case the drop raised
                # a does not exist error
                sess.rollback()
                sess.begin_nested()
                sess.execute(entity.to_sql_statement_create())
                yield sess
            finally:
                sess.rollback()
    finally:
        sess.rollback()
Ejemplo n.º 27
0
def delete_book(book: Union[str, external.NewBook],
                db: Session,
                hard: bool = False) -> None:
    """
    Soft delete a book (hard delete when ``hard`` is True)

    The return annotation is ``None``: the function returns normally, so
    ``NoReturn`` was wrong and made code after any call look unreachable.

    :param hard:    Hard Delete; the lookup is done with ``throw=not hard``
    :param book:    Book (handle or instance)
    :param db:      ORM Session
    """
    handle: str = book if isinstance(book, str) else book.handle
    b = _get_book(handle, db, throw=not hard)
    with db.begin_nested() as nested:
        __delete(b, nested, hard)
Ejemplo n.º 28
0
def delete_user(user: Union[str, external.NewUser],
                db: Session,
                hard: bool = False) -> None:
    """
    Soft delete a user (hard delete when ``hard`` is True)

    Does nothing when ``user`` is None. The return annotation is ``None``:
    the function returns normally, so ``NoReturn`` was wrong.

    :param hard:    Hard Delete; the lookup is done with ``throw=not hard``
    :param user:    User (username or model)
    :param db:      ORM Session
    """
    if user is None:
        return
    username: str = user if isinstance(user, str) else user.username
    u = _get_user(username, db, throw=not hard)
    with db.begin_nested() as nested:
        __delete(u, nested, hard)
Ejemplo n.º 29
0
def noisey_get_one_or_create(session: Session, model: BaseType,
                             **kwargs) -> Base:
    """
    Return the single ``model`` row matching ``kwargs``, creating it if absent

    ``kwargs`` go through ``_prepare_model_params`` first. The insert happens
    inside a SAVEPOINT. If releasing it raises ``IntegrityError`` (someone
    else created the row first), the row is looked up again and returned.
    """
    params = _prepare_model_params(session, noisey_get_one_or_create, **kwargs)
    try:
        # Fast path: the row already exists.
        return session.query(model).filter_by(**params).one()
    except NoResultFound:
        pass

    candidate = model(**params)
    try:
        # Only this savepoint is rolled back if the insert conflicts.
        with session.begin_nested():
            session.add(candidate)
            return candidate
    except IntegrityError:
        return session.query(model).filter_by(**params).one()
Ejemplo n.º 30
0
    def test_contextmanager_nested_rollback(self):
        """A failing flush in ``with begin_nested()`` raises the DBAPI error,
        and the session can still open and commit another savepoint."""
        users, User = self.tables.users, self.classes.User

        mapper(User, users)

        session = Session()

        def flush_user_without_name():
            with session.begin_nested():
                session.add(User())  # name can't be null
                session.flush()

        # The failure must surface as DBAPIError, and not InvalidRequestError
        assert_raises(sa_exc.DBAPIError, flush_user_without_name)

        with session.begin_nested():
            session.add(User(name='u1'))

        eq_(session.query(User).count(), 1)
Ejemplo n.º 31
0
def _add(e: Any,
         i: Type[T],
         db: Session,
         exclude: Set[str] = None,
         extra: Dict[str, Any] = None) -> T:
    """
    Add an ORM entity to session

    :param e:   External entity model
    :param i:   Internal entity model
    :param db:  ORM Session
    :return:    Created internal entity
    """
    if extra is None:
        extra = {}
    with db.begin_nested() as nested:
        no = i(**e.dict(exclude_none=True, exclude=exclude),
               **extra,
               deleted=False)
        nested.session.add(no)
        return no
Ejemplo n.º 32
0
    def test_contextmanager_nested_rollback(self):
        """A failing flush in ``with begin_nested()`` raises the DBAPI error,
        and the session can still open and commit another savepoint."""
        users, User = self.tables.users, self.classes.User

        mapper(User, users)

        session = Session()

        def flush_user_without_name():
            with session.begin_nested():
                session.add(User())   # name can't be null
                session.flush()

        # The failure must surface as DBAPIError, and not InvalidRequestError
        assert_raises(sa_exc.DBAPIError, flush_user_without_name)

        with session.begin_nested():
            session.add(User(name='u1'))

        eq_(session.query(User).count(), 1)
Ejemplo n.º 33
0
def create_user(session: Session, user_data: UserInputType,
                user: User) -> http.JSONResponse:
    """Create a user from ``user_data`` on behalf of ``user``.

    The insert runs inside a SAVEPOINT. That savepoint is always either
    committed or rolled back, including when validation fails, so it is
    never left open on the session.

    :raises BadRequest: if an ID is supplied, ``user`` may not create a user
        with the requested role, or the username already exists.
    """
    txn = session.begin_nested()
    try:
        new_user = User(**dict(user_data))

        if new_user.id is not None:
            raise BadRequest({'error': 'user ID cannot be set'})

        if not can_user_create_user(user, new_user):
            msg = 'user cannot create user with role "{}"'
            raise BadRequest({'error': msg.format(new_user.role.name)})

        session.add(new_user)
    except Exception:
        # Validation (or construction) failed: close the savepoint we opened.
        txn.rollback()
        raise

    try:
        txn.commit()
    except IntegrityError:
        txn.rollback()
        raise BadRequest({'error': 'username already exists'})

    return http.JSONResponse(UserType(new_user), status_code=201)
Ejemplo n.º 34
0
    def resolve(
            self, request: http.Request, session: Session,
            token: http.QueryParam
        # pylint: disable=arguments-differ
    ) -> User:
        """Return the requesting ``User``, or None if it cannot be determined.

        A non-empty ``token`` query parameter takes priority: it must parse as
        a UUID (otherwise None) and is matched against ``Token.id``. Without a
        token, the session id from the request headers is resolved inside a
        SAVEPOINT via ``resolve_with_session``.
        """
        if not token:
            session_id = self.get_session_id(request.headers)
            if not session_id:
                return None
            with session.begin_nested():
                return self.resolve_with_session(session, session_id)

        try:
            token_id = uuid.UUID(token)
        except ValueError:
            return None

        return (session.query(User)
                .join(Token)
                .filter(Token.id == token_id)
                .first())