def setup_method(self, method):
        """Create a versioned Flask-SQLAlchemy app and push a request context.

        Sets up the ``SQLAlchemy`` extension, enables versioning, defines the
        test models, builds a Flask app pointed at the database selected by
        the ``DB`` environment variable (default ``'sqlite'``), pushes a test
        request context and finally creates all tables.
        """
        # Flask-SQLAlchemy has no API for unregistering its session event
        # listeners, so registration is stubbed out here; real listeners
        # would otherwise outlive this test and affect other tests.
        flexmock(_SessionSignalEvents).should_receive('register')

        self.db = SQLAlchemy()
        make_versioned()

        manager = versioning_manager
        manager.transaction_cls = TransactionFactory()
        manager.options['native_versioning'] = uses_native_versioning()

        self.create_models()
        sa.orm.configure_mappers()

        app = Flask(__name__)
        self.app = app
        # Set app.config['SQLALCHEMY_ECHO'] = True to log emitted SQL.
        driver_name = get_driver_name(os.environ.get('DB', 'sqlite'))
        app.config['SQLALCHEMY_DATABASE_URI'] = get_dns_from_driver(driver_name)
        self.db.init_app(app)
        app.secret_key = 'secret'
        app.debug = True

        self.client = app.test_client()
        self.context = app.test_request_context()
        # create_all below runs with this context pushed.
        self.context.push()
        self.db.create_all()
class TestCase(object):
    """Base class for SQLAlchemy-Continuum versioning tests.

    Every test gets a fresh declarative base, freshly versioned models and
    tables created on a dedicated connection; all of it is torn down again in
    :meth:`teardown_method`. Subclasses customise the setup through the class
    attributes below and by overriding :meth:`create_models`.
    """

    versioning_strategy = 'subquery'
    transaction_column_name = 'transaction_id'
    end_transaction_column_name = 'end_transaction_id'
    composite_pk = False
    # NOTE: these instances are created once, at class-definition time, and
    # are therefore shared by every test that uses this class.
    plugins = [TransactionChangesPlugin(), TransactionMetaPlugin()]
    transaction_cls = TransactionFactory()
    user_cls = None
    should_create_models = True

    @property
    def options(self):
        """Versioning options given to ``make_versioned`` and copied into
        each test model's ``__versioned__``."""
        return {
            'create_models': self.should_create_models,
            'native_versioning': uses_native_versioning(),
            'base_classes': (self.Model, ),
            'strategy': self.versioning_strategy,
            'transaction_column_name': self.transaction_column_name,
            'end_transaction_column_name': self.end_transaction_column_name,
        }

    def setup_method(self, method):
        """Enable versioning, define the models and create their tables.

        The database is chosen by the ``DB`` environment variable (default
        ``'sqlite'``). ``self.ArticleVersion`` / ``self.TagVersion`` are set
        when the corresponding model exists and is versioned.
        """
        self.Model = declarative_base()
        make_versioned(options=self.options)

        driver = os.environ.get('DB', 'sqlite')
        self.driver = get_driver_name(driver)
        versioning_manager.plugins = self.plugins
        versioning_manager.transaction_cls = self.transaction_cls
        versioning_manager.user_cls = self.user_cls

        self.engine = create_engine(get_dns_from_driver(self.driver))
        # self.engine.echo = True
        self.create_models()

        sa.orm.configure_mappers()

        self.connection = self.engine.connect()

        if hasattr(self, 'Article'):
            try:
                self.ArticleVersion = version_class(self.Article)
            except ClassNotVersioned:
                pass
        if hasattr(self, 'Tag'):
            try:
                self.TagVersion = version_class(self.Tag)
            except ClassNotVersioned:
                pass
        self.create_tables()

        Session = sessionmaker(bind=self.connection)
        self.session = Session(autoflush=False)
        # Compared against the raw DB value rather than self.driver, since
        # get_driver_name presumably maps 'postgres-native' to a plain driver.
        if driver == 'postgres-native':
            self.session.execute('CREATE EXTENSION IF NOT EXISTS hstore')

    def create_tables(self):
        """Create all tables of ``self.Model`` on the test connection."""
        self.Model.metadata.create_all(self.connection)

    def drop_tables(self):
        """Drop all tables of ``self.Model`` on the test connection."""
        self.Model.metadata.drop_all(self.connection)

    def teardown_method(self, method):
        """Undo :meth:`setup_method`, then fail if the versioning manager
        still tracked units of work or session connections for this test."""
        self.session.rollback()
        # Snapshot the manager's bookkeeping *before* reset() so the leak
        # assertions see what the test left behind, regardless of whether
        # reset() rebinds these mappings or clears them in place.
        uow_leaks = copy(versioning_manager.units_of_work)
        session_map_leaks = copy(versioning_manager.session_connection_map)

        remove_versioning()
        QueryPool.queries = []
        versioning_manager.reset()

        self.session.close_all()
        self.session.expunge_all()
        self.drop_tables()
        # Return the connection to the pool before disposing of the pool, so
        # the underlying DBAPI connection is actually closed here rather than
        # being checked into an already-disposed pool.
        self.connection.close()
        self.engine.dispose()

        assert not uow_leaks
        assert not session_map_leaks

    def create_models(self):
        """Define the versioned ``Article`` and ``Tag`` models on
        ``self.Model`` and expose them as ``self.Article`` / ``self.Tag``."""
        class Article(self.Model):
            __tablename__ = 'article'
            __versioned__ = copy(self.options)

            id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
            name = sa.Column(sa.Unicode(255), nullable=False)
            content = sa.Column(sa.UnicodeText)
            description = sa.Column(sa.UnicodeText)

            # Dynamic column containing all text content data.
            # NOTE(review): in most SQL dialects concatenating NULL yields
            # NULL, so this is presumably NULL unless all three are set.
            fulltext_content = column_property(name + content + description)

        class Tag(self.Model):
            __tablename__ = 'tag'
            __versioned__ = copy(self.options)

            id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            article_id = sa.Column(sa.Integer, sa.ForeignKey(Article.id))
            article = sa.orm.relationship(Article, backref='tags')

        self.Article = Article
        self.Tag = Tag
class TestFlaskPlugin(TestCase):
    """Tests for ``FlaskPlugin``: transactions created while a Flask request
    is handled should be linked to the logged-in user."""

    plugins = [FlaskPlugin()]
    transaction_cls = TransactionFactory()
    user_cls = 'User'

    def setup_method(self, method):
        """Run the base setup, then build a Flask app with Flask-Login and
        push a test request context."""
        TestCase.setup_method(self, method)
        self.app = Flask(__name__)
        self.app.secret_key = 'secret'
        self.app.debug = True
        self.setup_views()
        login_manager = LoginManager()
        login_manager.init_app(self.app)
        self.client = self.app.test_client()
        self.context = self.app.test_request_context()
        self.context.push()

        @login_manager.user_loader
        def load_user(user_id):
            return self.session.query(self.User).get(user_id)

    def teardown_method(self, method):
        """Run the base teardown and always pop the request context."""
        try:
            TestCase.teardown_method(self, method)
        finally:
            # Pop the context pushed in setup_method even if the base
            # teardown fails (e.g. its leak assertions), so it cannot leak
            # into subsequent tests.
            self.context.pop()
            self.context = None
            self.client = None
            self.app = None

    def login(self, user):
        """
        Log *user* in by storing its id in the test client's session.

        :returns: the logged in user
        """
        with self.client.session_transaction() as s:
            s['_user_id'] = user.id
        return user

    def logout(self, user=None):
        """Clear the logged-in user id in the test client's session.

        ``user`` is accepted for symmetry with :meth:`login` and is ignored.
        """
        with self.client.session_transaction() as s:
            s['_user_id'] = None

    def create_models(self):
        """Add a versioned ``User`` model to the base ``Article``/``Tag``."""
        TestCase.create_models(self)

        class User(self.Model):
            __tablename__ = 'user'
            __versioned__ = {'base_classes': (self.Model, )}

            id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
            name = sa.Column(sa.Unicode(255), nullable=False)

        self.User = User

    def setup_views(self):
        """Register the views that the tests request through the client."""
        @self.app.route('/simple-flush')
        def test_simple_flush():
            article = self.Article()
            article.name = u'Some article'
            self.session.add(article)
            self.session.commit()
            return ''

        @self.app.route('/raw-sql-and-flush')
        def test_raw_sql_and_flush():
            # Mix raw SQL inserts with an ORM flush in a single commit.
            self.session.execute(
                "INSERT INTO article (name) VALUES ('some article')")
            article = self.Article()
            article.name = u'Some article'
            self.session.add(article)
            self.session.flush()
            self.session.execute(
                "INSERT INTO article (name) VALUES ('some article')")
            self.session.commit()
            return ''

    def test_versioning_inside_request(self):
        user = self.User(name=u'Rambo')
        self.session.add(user)
        self.session.commit()
        self.login(user)
        self.client.get(url_for('.test_simple_flush'))

        article = self.session.query(self.Article).first()
        tx = article.versions[-1].transaction
        assert tx.user.id == user.id

    def test_raw_sql_and_flush(self):
        user = self.User(name=u'Rambo')
        self.session.add(user)
        self.session.commit()
        self.login(user)
        self.client.get(url_for('.test_raw_sql_and_flush'))
        # Presumably one transaction for creating the user and one for the
        # whole request.
        assert (self.session.query(
            versioning_manager.transaction_cls).count() == 2)
Example #4
0
def test_versioning(native_versioning, versioning_strategy,
                    property_mod_tracking):
    """Benchmark inserting versioned rows under one option combination.

    Versioning is configured from the arguments, ``Article``/``Tag`` models
    are created in a local PostgreSQL test database, 20 batches of 20
    articles (each with two tags) are inserted and committed, and the elapsed
    time is printed. Tables, versioning state and connections are cleaned up
    even if the benchmark fails.

    :param native_versioning: value for the ``native_versioning`` option.
    :param versioning_strategy: value for the ``strategy`` option.
    :param property_mod_tracking: if true, ``PropertyModTrackerPlugin`` is
        enabled in addition to the transaction plugins.
    """
    transaction_column_name = 'transaction_id'
    end_transaction_column_name = 'end_transaction_id'
    plugins = [TransactionChangesPlugin(), TransactionMetaPlugin()]

    if property_mod_tracking:
        plugins.append(PropertyModTrackerPlugin())
    transaction_cls = TransactionFactory()
    user_cls = None

    Model = declarative_base()

    options = {
        'create_models': True,
        'native_versioning': native_versioning,
        'base_classes': (Model, ),
        'strategy': versioning_strategy,
        'transaction_column_name': transaction_column_name,
        'end_transaction_column_name': end_transaction_column_name,
    }

    make_versioned(options=options)

    # 'postgresql' is the dialect name every SQLAlchemy version accepts; the
    # old 'postgres' alias was removed in SQLAlchemy 1.4.
    dsn = 'postgresql://postgres@localhost/sqlalchemy_continuum_test'
    versioning_manager.plugins = plugins
    versioning_manager.transaction_cls = transaction_cls
    versioning_manager.user_cls = user_cls

    engine = create_engine(dsn)

    # engine.echo = True

    class Article(Model):
        __tablename__ = 'article'
        __versioned__ = copy(options)

        id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
        name = sa.Column(sa.Unicode(255), nullable=False)
        content = sa.Column(sa.UnicodeText)
        description = sa.Column(sa.UnicodeText)

    class Tag(Model):
        __tablename__ = 'tag'
        __versioned__ = copy(options)

        id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
        article_id = sa.Column(sa.Integer, sa.ForeignKey(Article.id))
        article = sa.orm.relationship(Article, backref='tags')

    sa.orm.configure_mappers()

    connection = engine.connect()
    Session = sessionmaker(bind=connection)
    session = Session(autoflush=False)
    try:
        Model.metadata.create_all(connection)
        session.execute('CREATE EXTENSION IF NOT EXISTS hstore')

        start = time()
        for _batch in range(20):
            for _ in range(20):
                session.add(Article(name=u'Article', tags=[Tag(), Tag()]))
            session.commit()
        elapsed = time() - start

        # Single-argument print() calls behave the same on Python 2 and 3.
        print('Testing with:')
        print('   native_versioning=%r' % native_versioning)
        print('   versioning_strategy=%r' % versioning_strategy)
        print('   property_mod_tracking=%r' % property_mod_tracking)
        print(colored('%r seconds' % elapsed, 'red'))
    finally:
        # Abandon any half-finished transaction (e.g. a failed commit) so the
        # DDL below does not run inside an aborted PostgreSQL transaction.
        session.rollback()
        Model.metadata.drop_all(connection)

        remove_versioning()
        versioning_manager.reset()

        session.close_all()
        session.expunge_all()
        # Release the connection to the pool before disposing of the pool.
        connection.close()
        engine.dispose()