コード例 #1
0
    def session_scope(self, query_func, need_commit=False):
        """Run ``query_func`` inside a short-lived SQLAlchemy session.

        The engine is created from ``self.connection_string`` the first time
        it is needed and cached on ``self.engine``.

        :param function query_func: callable that performs the queries; the
            SQLAlchemy session object is passed as its first argument.
        :param bool need_commit: whether the work has to be committed. When
            false, the transaction is rolled back once ``query_func`` returns.
        :return: whatever ``query_func`` returned.
        """
        if not self.engine:
            self.engine = create_engine(self.connection_string)
        db_session = Session(bind=self.engine)

        try:
            outcome = query_func(db_session)
            # End the transaction one way or the other before handing back.
            if need_commit:
                db_session.commit()
            else:
                db_session.rollback()
            return outcome
        except BaseException:
            # Any failure (a failing commit included) discards pending work.
            db_session.rollback()
            raise
        finally:
            db_session.close()
コード例 #2
0
class connect(object):
    """Context manager that opens a SQLAlchemy session on the project engine.

    The session lives in ``self.__session__``; it is ``None`` until
    :meth:`open` (or entering a ``with`` block) creates it.
    """

    def __init__(self):
        self.__session__ = None

    def __enter__(self):
        """Open a session and return this object for use in a ``with`` block."""
        # Reuse open() instead of duplicating the session-creation code.
        self.open()
        return self

    def open(self):
        """Create a new session bound to the engine from ``get_engine()``."""
        from sqlalchemy.orm import sessionmaker
        session_factory = sessionmaker(bind=get_engine())
        self.__session__ = session_factory()

    def query(self, tbl):
        """Build a queryable for ``tbl`` on the current session.

        :param tbl: entity to query.
        :return: ``__queryable__(session, tbl)`` when ``tbl`` is a
            ``BaseEntity`` instance, otherwise ``None``.
        """
        from .tables import BaseEntity
        from .queryable import __queryable__
        if isinstance(tbl, BaseEntity):
            return __queryable__(self.__session__, tbl)
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the session if one was opened; exceptions are not suppressed."""
        from sqlalchemy.orm.session import Session
        if isinstance(self.__session__, Session):
            self.__session__.close()
コード例 #3
0
def legacy_database_fixer(database_path, database, database_name,
                          database_exists):
    """Re-export a database that predates Alembic.

    A database without an ``alembic_version`` table is renamed to
    ``Pre_Alembic_<filename>`` in the same directory. When such a file
    exists (already, or after the rename), its posts and their media rows
    are read through the legacy table definitions and handed to
    ``export_sqlite2`` targeting the original path.

    :param database_path: path of the database file to check.
    :param database: object exposing ``api_table()`` and ``media_table()``.
    :param database_name: name forwarded to the legacy API table and to
        ``export_sqlite2``.
    :param database_exists: whether a file exists at ``database_path``.
    :return: None
    """
    database_directory = os.path.dirname(database_path)
    old_database_path = database_path
    old_filename = os.path.basename(old_database_path)
    new_filename = f"Pre_Alembic_{old_filename}"
    pre_alembic_path = os.path.join(database_directory, new_filename)
    pre_alembic_database_exists = os.path.exists(pre_alembic_path)
    if pre_alembic_database_exists:
        database_path = pre_alembic_path

    if database_exists:
        # Only the engine is needed for the inspection; the original also
        # opened a session here that was never used nor closed.
        _, engine = db_helper.create_database_session(database_path)
        if not inspect(engine).has_table("alembic_version"):
            if not pre_alembic_database_exists:
                os.rename(old_database_path, pre_alembic_path)
                pre_alembic_database_exists = True

    if not pre_alembic_database_exists:
        return

    Session, _ = db_helper.create_database_session(pre_alembic_path)
    database_session = Session()
    datas = []
    try:
        api_table = database.api_table()
        media_table = database.media_table()
        legacy_api_table = api_table.legacy(database_name)
        legacy_media_table = media_table.legacy()
        for post in database_session.query(legacy_api_table).all():
            post_id = post.id
            created_at = post.created_at
            media_db = database_session.query(legacy_media_table).filter_by(
                post_id=post_id).all()
            medias = [
                {
                    "media_id": media.id,
                    "post_id": media.post_id,
                    "links": [media.link],
                    "directory": media.directory,
                    "filename": media.filename,
                    "size": media.size,
                    "media_type": media.media_type,
                    "downloaded": media.downloaded,
                    # Media rows inherit the timestamp of their post.
                    "created_at": created_at,
                }
                for media in media_db
            ]
            datas.append({
                "post_id": post_id,
                "text": post.text,
                "price": post.price,
                "paid": post.paid,
                "postedAt": created_at,
                "medias": medias,
            })
    finally:
        # Release the session even when reading the legacy tables fails.
        database_session.close()

    export_sqlite2(old_database_path,
                   datas,
                   database_name,
                   legacy_fixer=True)
コード例 #4
0
def addModelCompartmentalizedComponent(modelId, compartmentalizedComponentId,
                                       compartmentId):
    """Link a model, a compartmentalized component and a compartment.

    Nothing is inserted when a row with the same three ids already exists.

    :param modelId: id of an existing ``Model``.
    :param compartmentalizedComponentId: id of an existing
        ``CompartmentalizedComponent``.
    :param compartmentId: id of an existing ``Compartment``.
    :return: the new ``ModelCompartmentalizedComponent``, or ``None`` when
        the link already existed.
    :raises Exception: whatever ``.one()`` raises when a referenced row
        cannot be loaded (a message is printed first).
    """
    session = Session()
    try:
        already_linked = session.query(ModelCompartmentalizedComponent).filter(
            ModelCompartmentalizedComponent.model_id == modelId).filter(
                ModelCompartmentalizedComponent.compartmentalized_component_id
                == compartmentalizedComponentId).filter(
                    ModelCompartmentalizedComponent.compartment_id ==
                    compartmentId).count()
        if already_linked:
            return None

        try:
            model = session.query(Model).filter(Model.id == modelId).one()
        except Exception:
            print("model does not exist in database")
            raise
        try:
            cc = session.query(CompartmentalizedComponent).filter(
                CompartmentalizedComponent.id ==
                compartmentalizedComponentId).one()
        except Exception:
            print("compartmentalized component does not exist in database")
            raise
        try:
            compartment = session.query(Compartment).filter(
                Compartment.id == compartmentId).one()
        except Exception:
            print("compartment does not exist in database")
            raise

        mcc = ModelCompartmentalizedComponent(
            model_id=model.id,
            compartmentalized_component_id=cc.id,
            compartment_id=compartment.id)
        session.add(mcc)
        session.commit()
        # NOTE(review): the instance is returned after its session is
        # closed; callers presumably only need it as a handle - confirm.
        return mcc
    finally:
        # Close on every path: existing row, lookup failure, or success.
        session.close()
コード例 #5
0
 def test_check_auth_redis_miss(self):
     """``_check_auth`` falls back to the database when Redis has no entry.

     An account and phone number are inserted inside an outer transaction
     that is always rolled back. With ``hget`` returning ``None`` the
     handler must report success, return the phone number and write the
     auth hash and number set through one Redis pipeline.
     """
     db_handler = self.auth_handler.db_handler
     db_conn = db_handler.getEngine().connect()
     db_txn = db_conn.begin()
     try:
         db_session = Session(bind=db_conn)
         try:
             account = Account(auth_id='some_auth_id', username='******')
             db_session.add(account)
             # Flush so that account.id is available for the phone number.
             db_session.flush()
             phonenumber = PhoneNumber(number='9740171794',
                                       account_id=account.id)
             db_session.add(phonenumber)
             db_session.commit()

             # Simulate the cache miss and capture the pipeline calls.
             redis_pipeline_mock = MagicMock()
             self.auth_handler.redis_client.hget.return_value = None
             self.auth_handler.redis_client.pipeline.return_value = (
                 redis_pipeline_mock)

             status, phonenums = self.auth_handler._check_auth(
                 'some_user', 'some_auth_id', db_session)

             self.assertTrue(status)
             self.assertEqual({'9740171794'}, phonenums)
             self.assertEqual(redis_pipeline_mock.hset.call_count, 1)
             redis_pipeline_mock.hset.assert_called_with(
                 self.auth_handler._REDIS_AUTH_HASH, 'some_user',
                 'some_auth_id')
             self.assertEqual(redis_pipeline_mock.sadd.call_count, 1)
             redis_pipeline_mock.sadd.assert_called_with(
                 'some_user', '9740171794')
             self.assertEqual(redis_pipeline_mock.execute.call_count, 1)
         finally:
             db_session.close()
     finally:
         # Undo everything, including the session's commit above.
         db_txn.rollback()
         db_conn.close()
コード例 #6
0
class TrialModelTestCase(unittest.TestCase):
    """Tests for the ``Trial`` model run inside a rolled-back transaction."""

    @classmethod
    def setUpClass(cls):
        """Connect, create the schema in a transaction and load fixtures."""
        global transaction, connection, engine

        # The schema is created inside the top-level transaction so that
        # tearDownClass can discard it again.
        engine = create_engine(TEST_DATABASE_URI)
        connection = engine.connect()
        transaction = connection.begin()
        Trial.metadata.create_all(connection)

        # Sample trials are loaded once and shared by all tests.
        cls.trials = load_sample_trials(
            ['NCT02034110', 'NCT00001160', 'NCT00001163'])

    @classmethod
    def tearDownClass(cls):
        """Roll back the top-level transaction and disconnect."""
        transaction.rollback()
        connection.close()
        engine.dispose()

    def setUp(self):
        """Open a savepoint and a session on the shared connection."""
        self.__transaction = connection.begin_nested()
        self.session = Session(bind=connection)

    def tearDown(self):
        """Close the session, then roll the savepoint back."""
        self.session.close()
        self.__transaction.rollback()

    def test_add(self):
        """Adding a trial built from fixture data must not raise."""
        first_trial = Trial(ct_dict=self.trials[0])
        self.session.add(first_trial)
コード例 #7
0
    def test_session_scope(self) -> Session:
        """Provides a transactional scope around a series of operations that will be rolled back at the end.

        Note:
            SQLite and MySQL storage engine MyISAM do not support rollback transactions, so all the
            modifications performed to the database will persist.

        Yields:
            A session bound to a dedicated connection whose outer transaction is rolled back on exit.

        """
        # NOTE(review): this is a generator function although it is annotated ``-> Session``; presumably a
        # context-manager decorator is applied outside this view - confirm.
        # Connect to the database
        connection = self.connect()
        # Begin a non-ORM transaction
        transaction = connection.begin()
        # Bind an individual Session to the connection
        session = Session(bind=connection)
        # If the database supports SAVEPOINT, starting a savepoint will allow to also use rollback
        connection.begin_nested()
        # Define a new transaction event
        @event.listens_for(session, "after_transaction_end")
        def end_savepoint(session, transaction):  # pylint: disable=unused-variable
            """Open a new savepoint on the connection when none is active any more."""
            if not connection.in_nested_transaction():
                connection.begin_nested()

        try:
            yield session
        finally:
            # Whatever happens, make sure the session and connection are closed, rolling back everything done
            # with the session (including calls to commit())
            session.close()
            transaction.rollback()
            connection.close()
コード例 #8
0
ファイル: update.py プロジェクト: nstoik/fd_device
def get_grainbin_updates(session: Session = None) -> list:
    """Get all grainbin updates as a list for each grainbin.

    Args:
        session: session to use. When omitted, a session is obtained from
            ``get_session()`` and closed again before returning.

    Returns:
        One update per grainbin whose bus is currently connected; grainbins
        on a disconnected bus are skipped with a warning.
    """
    # Only a session opened here is ours to close.
    close_session = not session
    if close_session:
        session = get_session()

    try:
        all_updates: list = []
        grainbins: list[Grainbin] = session.query(Grainbin).all()
        all_busses = get_all_busses()

        for grainbin in grainbins:
            if grainbin.bus_number_string in all_busses:
                all_updates.append(get_indivudual_grainbin_update(grainbin))
            else:
                LOGGER.warning(
                    f"Bus {grainbin.bus_number_string} not currently connected when trying to create update."
                )

        session.commit()
        return all_updates
    finally:
        # Previously a failure above leaked the session opened here.
        if close_session:
            session.close()
コード例 #9
0
    def test_session_scope(self) -> Session:
        """Provides a transactional scope around a series of operations that will be rolled back at the end.

        Note:
            MySQL storage engine MyISAM does not support rollback transactions, so all the modifications
            performed to the database will persist.

        Yields:
            A session bound to a dedicated connection whose outer transaction is rolled back on exit.

        """
        # NOTE(review): this is a generator function although it is annotated ``-> Session``; presumably a
        # context-manager decorator is applied outside this view - confirm.
        # Connect to the database
        connection = self.connect()
        # Begin a non-ORM transaction
        transaction = connection.begin()
        # Bind an individual Session to the connection
        session = Session(bind=connection)
        # Start the session in a SAVEPOINT
        session.begin_nested()
        # Define a new transaction event
        @event.listens_for(session, "after_transaction_end")
        def restart_savepoint(session, transaction):  # pylint: disable=unused-variable
            """Reopen a SAVEPOINT whenever the previous one ends."""
            # Only react to a savepoint whose parent is not itself a savepoint.
            if transaction.nested and not transaction._parent.nested:  # pylint: disable=protected-access
                # Ensure that state is expired the same way session.commit() at the top level normally does
                session.expire_all()
                session.begin_nested()

        try:
            yield session
        finally:
            # Whatever happens, make sure the session and connection are closed, rolling back everything done
            # with the session (including calls to commit())
            session.close()
            transaction.rollback()
            connection.close()
コード例 #10
0
ファイル: test_base.py プロジェクト: digicatapult/fornax
class TestCaseDB(unittest.TestCase):
    """ A base test for setting up and tearing down the database """
    @classmethod
    def setUpClass(cls):
        """ Create the engine, create one connection
        and start a transaction """
        engine = create_engine('sqlite://', echo=False)

        connection = engine.connect()
        cls._engine = engine
        cls._connection = connection
        cls.__transaction = connection.begin()
        Base.metadata.create_all(connection)

    @classmethod
    def tearDownClass(cls):
        """ tear down the top level transaction """
        cls.__transaction.rollback()
        cls._connection.close()
        cls._engine.dispose()

    def setUp(self):
        """ create a new session and a nested transaction """
        self._transaction = self._connection.begin_nested()
        self.session = Session(self._connection)

    def tearDown(self):
        """ close the session, then rollback the nested transaction """
        # Close first so the session releases its transactional state
        # before the enclosing savepoint is rolled back.
        self.session.close()
        self._transaction.rollback()
コード例 #11
0
class ServiceTest(unittest.TestCase):
    """Base test case giving each test a session inside a savepoint."""

    def setUp(self):
        """Open a savepoint and a session on the shared connection."""
        self.__transaction = connection.begin_nested()
        self.session = Session(connection)

    def tearDown(self):
        """Close the session, then roll the savepoint back."""
        self.session.close()
        self.__transaction.rollback()

    # unittest only calls the camel-case hooks; the lowercase spellings
    # are kept as aliases for code that invoked them explicitly.
    setup = setUp
    teardown = tearDown
コード例 #12
0
def session(connection, transaction, monkeypatch):
    """Yield a session whose work is rolled back after the test.

    ``transaction`` is not used in the body; it is only requested.
    """
    savepoint = connection.begin_nested()
    db_session = Session(bind=connection, autoflush=False, autocommit=False)
    # Redirect commit() to flush() so that nothing is really committed.
    monkeypatch.setattr(db_session, 'commit', db_session.flush)
    yield db_session
    db_session.close()
    savepoint.rollback()
コード例 #13
0
ファイル: conftest.py プロジェクト: tiebanchuang/OSMNames
def session(engine):
    """Yield a session on ``engine``; afterwards close it and truncate the
    ``osm_test`` tables."""
    db_session = Session(bind=engine)

    yield db_session

    db_session.close()

    exec_sql("SELECT truncate_tables('osm_test')")
コード例 #14
0
ファイル: csv2db.py プロジェクト: pedluini/Shool-projects
def insert2db(name, address, postcode, city, date, type, businessid):
    """Insert one ``License`` row and return its primary key.

    :param date: stored as ``license_granting_date``.
    :param type: stored as ``license_type``.
    :param businessid: stored as ``business_id``.
    :return: the ``id`` of the inserted row.
    """
    new_row = License(name=name, address=address, postcode=postcode,
                      city=city, license_granting_date=date,
                      license_type=type, business_id=businessid)
    db = Session()
    try:
        db.add(new_row)
        db.commit()
        # Read the key before the session is closed.
        return new_row.id
    except Exception:
        db.rollback()
        raise
    finally:
        # Previously a failing commit left the session open.
        db.close()
コード例 #15
0
ファイル: conftest.py プロジェクト: geometalab/OSMNames
def session(engine):
    """Yield a session on ``engine``; afterwards close it and truncate the
    ``osm_test`` tables."""
    db_session = Session(bind=engine)

    yield db_session

    db_session.close()

    exec_sql("SELECT truncate_tables('osm_test')")
コード例 #16
0
ファイル: test.py プロジェクト: mrbitsdcf/banco_imobiliario
class DatabaseTest(object):
    """Test base giving each test a session inside a savepoint."""

    def setup(self):
        """Open a savepoint on the shared connection and bind a session."""
        self.__transaction = connection.begin_nested()
        self.session = Session(bind=connection)

    def teardown(self):
        """Close the session, then roll the savepoint back."""
        self.session.close()
        self.__transaction.rollback()
コード例 #17
0
class DatabaseTest(unittest.TestCase):
    """Test base running each test inside a rolled-back transaction."""

    def setUp(self):
        """Begin a transaction on the shared connection and bind a session."""
        self.trans = connection.begin()
        self.session = Session(connection)

    def tearDown(self):
        """Close the session, then roll the transaction back."""
        # Close first so the session releases its transactional state
        # before the enclosing transaction is rolled back.
        self.session.close()
        self.trans.rollback()
コード例 #18
0
class DatabaseTest:
    """Test base giving each test a session inside a savepoint."""

    def setup(self):
        """Open a savepoint on the shared connection and bind a session."""
        self.__transaction = connection.begin_nested()
        self.session = Session(bind=connection)

    def teardown(self):
        """Close the session, then roll the savepoint back."""
        self.session.close()
        self.__transaction.rollback()
コード例 #19
0
ファイル: __init__.py プロジェクト: llazzaro/pystock
class DatabaseTest(unittest.TestCase):
    """Test base running each test inside a rolled-back transaction."""

    def setUp(self):
        """Begin a transaction on the shared connection and bind a session."""
        self.trans = connection.begin()
        self.session = Session(connection)

    def tearDown(self):
        """Close the session, then roll the transaction back."""
        # Close first so the session releases its transactional state
        # before the enclosing transaction is rolled back.
        self.session.close()
        self.trans.rollback()
コード例 #20
0
def upgrade():
    """Create one job per existing task and move task columns onto it.

    Steps:
      1. insert a ``jobs`` row for every task, named ``Job from Task <id>``;
      2. add ``tasks.job_id`` and drop ``user_id``, ``spawn_at`` and
         ``terminate_at`` from ``tasks``;
      3. point every task at the job generated from it.
    """
    job_name_prefix = 'Job from Task '

    session = Session(bind=op.get_bind())
    meta = sa.MetaData()
    meta.reflect(bind=op.get_bind())
    Base = automap_base(metadata=meta)
    Base.prepare()
    Task = Base.classes.tasks

    jobs = sa.Table('jobs', meta)

    new_jobs = [
        {'name': 'Job from Task {}'.format(str(task.id)),
         'description': 'Job auto-created from task with id: {}'.format(str(task.id)),
         'user_id': task.user_id,
         'status': task.status,
         '_start_at': task.spawn_at,
         '_stop_at': task.terminate_at}
        for task in session.query(Task).all()
    ]

    op.bulk_insert(jobs, new_jobs)

    naming_convention = {
        "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
    }

    with op.batch_alter_table("tasks", naming_convention=naming_convention) as batch_op:
        batch_op.add_column(sa.Column('job_id', sa.Integer, nullable=True))
        batch_op.create_foreign_key("fk_tasks_job_id_jobs", 'jobs', ['job_id'], ['id'], ondelete='CASCADE')
        batch_op.drop_constraint("fk_tasks_user_id_users", type_="foreignkey")
        batch_op.drop_column('user_id')
        batch_op.drop_column('spawn_at')
        batch_op.drop_column('terminate_at')

    session.commit()
    session.close()

    # The schema changed above, so reflect it again before linking rows.
    session = Session(bind=op.get_bind())
    meta = sa.MetaData()
    meta.reflect(bind=op.get_bind())
    Base = automap_base(metadata=meta)
    Base.prepare()
    Job = Base.classes.jobs

    jobs = session.query(Job).all()
    tasks = sa.Table('tasks', meta)

    for job in jobs:
        name = job.name or ''
        if not name.startswith(job_name_prefix):
            continue
        # Take the whole id after the prefix. Slicing a fixed number of
        # characters only matched single-digit task ids.
        task_id = name[len(job_name_prefix):]
        if task_id.isdigit():
            op.execute(tasks.update()
                       .where(tasks.columns.get('id') == int(task_id))
                       .values(job_id=job.id))

    session.commit()
    session.close()
コード例 #21
0
def addCompartment(name):
    """Insert a ``Compartment`` called ``name`` unless one already exists.

    A message is printed, and nothing is inserted, when a compartment with
    that name is already stored.
    """
    session = Session()
    try:
        # Any existing match counts; testing for exactly one let further
        # duplicates through once two rows shared a name.
        if session.query(Compartment).filter(
                Compartment.name == name).count():
            print("already exists in database")
        else:
            session.add(Compartment(name=name))
            session.commit()
    finally:
        # Previously the "already exists" branch never closed the session.
        session.close()
コード例 #22
0
ファイル: interface.py プロジェクト: AlexWylie/marcotti
 def create_session(self):
     """Yield a session on ``self.connection``.

     The session is committed when the consumer finishes normally, rolled
     back when an exception is raised, and closed in either case.
     """
     session = Session(self.connection)
     try:
         yield session
         session.commit()
     except Exception:
         session.rollback()
         # Bare raise re-raises the active exception unchanged.
         raise
     finally:
         session.close()
コード例 #23
0
ファイル: app.py プロジェクト: SafeBlues/backend
def session_scope():
    """Yield a session on the engine from ``get_engine()``.

    Commits when the consumer finishes normally, rolls back and re-raises
    on any exception, and always closes the session.
    """
    db_session = Session(bind=get_engine())
    try:
        try:
            yield db_session
            db_session.commit()
        except BaseException:
            db_session.rollback()
            raise
    finally:
        db_session.close()
コード例 #24
0
ファイル: db.py プロジェクト: leejh3224/bitrush
def session_scope(session: Session):
    """Wrap ``session`` in a commit-or-rollback scope.

    Yields the session, commits once the consumer finishes normally, rolls
    back and re-raises on ``Exception``, and closes the session in all
    cases.
    """
    try:
        try:
            yield session
            session.commit()
        except Exception:
            # Discard pending work, then let the error propagate.
            session.rollback()
            raise
    finally:
        session.close()
コード例 #25
0
def deleteGene(geneId):
    """Delete the ``Gene`` with primary key ``geneId``.

    :raises Exception: whatever ``.one()`` raises when the gene cannot be
        loaded (a message is printed first).
    """
    session = Session()
    try:
        try:
            gene = session.query(Gene).filter(Gene.id == geneId).one()
        except Exception:
            print("gene does not exist in database")
            raise
        session.delete(gene)
        session.commit()
    finally:
        # Close the session on failure too.
        session.close()
コード例 #26
0
ファイル: c_database.py プロジェクト: ieasysoft/imetadata
 def session_close(self, session: Session):
     """
     Close ``session`` and dispose of the engine it is bound to.

     The session must be closed manually, inside a ``finally`` block.

     :param session: session to close
     :return: None
     """
     # Fetch the bind before closing the session.
     bound_engine = session.get_bind()
     session.close()
     if bound_engine is None:
         return
     bound_engine.dispose()
コード例 #27
0
ファイル: interface.py プロジェクト: AlexWylie/marcotti
 def create_session(self):
     """Yield a session on ``self.connection``.

     The session is committed when the consumer finishes normally, rolled
     back when an exception is raised, and closed in either case.
     """
     session = Session(self.connection)
     try:
         yield session
         session.commit()
     except Exception:
         session.rollback()
         # Bare raise re-raises the active exception unchanged.
         raise
     finally:
         session.close()
コード例 #28
0
def session_scope(isolation_level=None):
    """Yield a session on the engine from ``get_engine(isolation_level=...)``.

    Commits when the consumer finishes normally, rolls back and re-raises
    on any exception, and always closes the session.
    """
    db_session = Session(bind=get_engine(isolation_level=isolation_level))
    try:
        try:
            yield db_session
            db_session.commit()
        except BaseException:
            db_session.rollback()
            raise
    finally:
        db_session.close()
コード例 #29
0
def deleteModel(modelId):
    """Delete the ``Model`` with primary key ``modelId``.

    :raises Exception: whatever ``.one()`` raises when the model cannot be
        loaded (a message is printed first).
    """
    session = Session()
    try:
        try:
            model = session.query(Model).filter(Model.id == modelId).one()
        except Exception:
            print("model does not exist in database")
            raise
        session.delete(model)
        session.commit()
    finally:
        # Close the session on failure too.
        session.close()
コード例 #30
0
def addReaction(name, long_name, type, notes):
    """Insert a ``Reaction`` unless one with the same name already exists.

    :return: the new ``Reaction``, or ``None`` when nothing was inserted.
    """
    session = Session()
    try:
        # The original filtered on an undefined ``reactionId`` and always
        # raised NameError. NOTE(review): matching on ``name`` is assumed
        # to be the intended uniqueness check - confirm.
        if session.query(Reaction).filter(Reaction.name == name).count():
            return None
        reaction = Reaction(name=name,
                            long_name=long_name,
                            type=type,
                            notes=notes)
        session.add(reaction)
        session.commit()
        return reaction
    finally:
        # Close on every path, not only after a successful insert.
        session.close()
コード例 #31
0
def updateModel(modelId, modelDict=None):
    """Apply the column values in ``modelDict`` to the model ``modelId``.

    Nothing is written when ``modelDict`` is ``None`` or empty.

    :raises Exception: whatever ``.one()`` raises when the model cannot be
        loaded (a message is printed first).
    """
    session = Session()
    try:
        try:
            # Existence check only; the row itself is not needed.
            session.query(Model).filter(Model.id == modelId).one()
        except Exception:
            print("model does not exist in database")
            raise
        if modelDict:
            session.query(Model).filter(Model.id == modelId).update(modelDict)
            session.commit()
    finally:
        # Previously the session stayed open when there was nothing to
        # update or when the lookup failed.
        session.close()
コード例 #32
0
ファイル: tests.py プロジェクト: marshallgallatin/cs373-idb
class DatabaseTest(TestCase):
    """
    Base class for our tests that involve the database.
    """
    def setUp(self):
        """Open a savepoint on the shared connection and bind a session."""
        self.__transaction = connection.begin_nested()
        self.session = Session(bind=connection)

    def tearDown(self):
        """Close the session, then roll the savepoint back."""
        self.session.close()
        self.__transaction.rollback()
コード例 #33
0
class DBStorage:
    """Storage engine that persists model objects through SQLAlchemy."""
    __engine = None
    __session = None
    __list_class = [State, User, City, Amenity, Review, Place]

    def __init__(self):
        """Create the MySQL engine from the ``HBNB_MYSQL_*`` variables.

        All tables are dropped when ``HBNB_ENV`` is ``test``.
        """
        self.__engine = create_engine('mysql+mysqldb://{}:{}@{}/{}'.format(
            environ['HBNB_MYSQL_USER'], environ['HBNB_MYSQL_PWD'],
            environ['HBNB_MYSQL_HOST'], environ['HBNB_MYSQL_DB']),
                                      pool_pre_ping=True)
        if 'HBNB_ENV' in environ and environ['HBNB_ENV'] == 'test':
            Base.metadata.drop_all(bind=self.__engine)

    def all(self, cls=None):
        """Return stored objects keyed by ``<ClassName>.<id>``.

        :param cls: class name as a string; when ``None`` every class in
            the storage's class list is queried.
        :return: dict mapping keys to objects.
        """
        new_dic = {}
        if cls is not None:
            # NOTE(review): eval() executes whatever text is passed in;
            # only call this with trusted class names.
            all_obj = self.__session.query(eval(cls))
            for obj in all_obj:
                key = ".".join([cls, obj.id])
                new_dic.update({key: obj})
        else:
            for cl in self.__list_class:
                all_obj = self.__session.query(cl)
                for obj in all_obj:
                    # ``cls`` is None in this branch; name the key after
                    # the class actually being queried.
                    key = ".".join([cl.__name__, obj.id])
                    new_dic.update({key: obj})
        return new_dic

    def new(self, obj):
        """Add ``obj`` to the current session."""
        self.__session.add(obj)

    def save(self):
        """Commit the current session."""
        self.__session.commit()

    def delete(self, obj=None):
        """Mark ``obj`` for deletion; do nothing when it is ``None``."""
        if obj is not None:
            self.__session.delete(obj)

    def reload(self):
        """Create all tables and open a fresh scoped session."""
        Base.metadata.create_all(self.__engine)
        current_se = sessionmaker(bind=self.__engine, expire_on_commit=False)
        Session = scoped_session(current_se)
        self.__session = Session()

    def close(self):
        """Close the SQLAlchemy session."""
        self.__session.close()
コード例 #34
0
def deleteReaction(reactionId):
    """Delete the ``Reaction`` with primary key ``reactionId``.

    :raises Exception: whatever ``.one()`` raises when the reaction cannot
        be loaded (a message is printed first).
    """
    session = Session()
    try:
        try:
            reaction = session.query(Reaction).filter(
                Reaction.id == reactionId).one()
        except Exception:
            print("reaction does not exist in database")
            raise
        session.delete(reaction)
        session.commit()
    finally:
        # Close the session on failure too.
        session.close()
コード例 #35
0
def db_session(test_db):
    """Yield a session on ``test_db`` whose work is always rolled back.

    The session runs inside a nested transaction on a connection whose
    outer transaction is rolled back once the consumer is done.
    """
    connection = test_db.connect()
    outer_txn = connection.begin()
    scoped = Session(bind=connection, autocommit=False, autoflush=False)
    try:
        scoped.begin_nested()
        yield scoped
    finally:
        # Unwind in reverse order of acquisition.
        scoped.close()
        outer_txn.rollback()
        connection.close()
コード例 #36
0
def deleteCompartment(compartmentId):
    """Delete the ``Compartment`` with primary key ``compartmentId``.

    :raises Exception: whatever ``.one()`` raises when the compartment
        cannot be loaded (a message is printed first).
    """
    session = Session()
    try:
        try:
            # The misspelled class name used here before raised NameError
            # on every call.
            c = session.query(Compartment).filter(
                Compartment.id == compartmentId).one()
        except Exception:
            print("compartment does not exist in database")
            raise
        session.delete(c)
        session.commit()
    finally:
        # Close the session on failure too.
        session.close()
コード例 #37
0
ファイル: tasks.py プロジェクト: Sentimentron/senbot
    def run(self, keyword):
        """Look up the id of the ``Keyword`` row whose word is ``keyword``.

        :return: the keyword's ``id``, or ``None`` when no row matches.
        """
        session = Session(bind=self.engine)
        try:
            try:
                kw = session.query(Keyword).filter_by(word=keyword).one()
            except NoResultFound:
                return None
            # Read the id while the session is still open.
            return kw.id
        finally:
            # Previously the not-found return skipped the close.
            session.close()
コード例 #38
0
    def test_session_with_external_transaction(self):
        """Flush an article through a session bound to a connection whose
        transaction is rolled back afterwards."""
        connection = self.engine.connect()
        outer_txn = connection.begin()
        bound_session = Session(bind=connection)

        bound_session.add(self.Article(name=u'My Session Article'))
        bound_session.flush()

        # Unwind in reverse order of acquisition.
        bound_session.close()
        outer_txn.rollback()
        connection.close()
コード例 #39
0
ファイル: mysql.py プロジェクト: liuzelei/walis
 def wrapper(*args, **kwargs):
     # Call the wrapped function first, then commit a session; a commit
     # failure is rolled back and reported through raise_server_exc.
     result = func(*args, **kwargs)
     db_session = Session()
     # tmp
     db_session._model_changes = {}
     try:
         db_session.commit()
     except SQLAlchemyError as err:
         db_session.rollback()
         raise_server_exc(DATABASE_UNKNOWN_ERROR, exc=err)
     finally:
         db_session.close()
     return result
コード例 #40
0
ファイル: interface.py プロジェクト: soccermetrics/marcotti
    def create_session(self):
        """
        Create a session context that communicates with the database.

        Commits all changes to the database before closing the session, and if an exception is raised,
        rollback the session.
        """
        session = Session(self.connection)
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            # Bare raise re-raises the active exception unchanged.
            raise
        finally:
            session.close()
コード例 #41
0
ファイル: datasets.py プロジェクト: paulofreitas/dtb-ibge
    def transaction(cls,
                    session: db_session.Session) -> Iterator[db_session.Session]:
        """
        Wraps a database session in a transactional context.

        The session is committed once the consumer finishes normally,
        rolled back (and the error re-raised) on any exception, and closed
        in every case.

        Args:
            session: The database session instance to wrap

        Yields:
            The wrapped database session instance
        """
        try:
            try:
                yield session
                session.commit()
            except BaseException:
                # Discard pending work, then let the error propagate.
                session.rollback()
                raise
        finally:
            session.close()
コード例 #42
0
ファイル: base.py プロジェクト: soccermetrics/marcotti-mls
    def create_session(self):
        """
        Create a session context that communicates with the database.

        Commits all changes when the consumer finishes normally. When an
        exception is raised the session is rolled back and the error is
        logged; it is not re-raised. Opening, committing and closing the
        session are logged as well.
        """
        session = Session(self.connection)
        opened_msg = "Create session {0} with {1}".format(
            id(session), self._public_db_uri(str(self.engine.url)))
        logger.info(opened_msg)
        try:
            yield session
            session.commit()
            logger.info("Commit transactions to database")
        except Exception:
            session.rollback()
            logger.exception("Database transactions rolled back")
        finally:
            closed_msg = "Session {0} with {1} closed".format(
                id(session), self._public_db_uri(str(self.engine.url)))
            logger.info(closed_msg)
            session.close()
コード例 #43
0
ファイル: db.py プロジェクト: LiberTang0/Python-Carepoint
class DatabaseTest(unittest.TestCase):
    """Test base that builds the Carepoint schema on a SQLite database."""

    @classmethod
    def setUpClass(cls):
        """Connect, begin a transaction and create all tables."""
        cls.engine = Db(drv=Db.SQLITE)
        cls.connection = cls.engine.connect()
        cls.transaction = cls.connection.begin()
        Carepoint.BASE.metadata.create_all(cls.connection)

    @classmethod
    def tearDownClass(cls):
        """Close the connection and dispose of the engine."""
        # NOTE(review): the class-level transaction is not rolled back
        # here (a rollback call used to be commented out) - confirm that
        # this is deliberate.
        cls.connection.close()
        cls.engine.dispose()

    def setUp(self):
        """Open a savepoint and a session on the class connection."""
        self.__transaction = self.connection.begin_nested()
        self.session = Session(bind=self.connection)

    def tearDown(self):
        """Close the per-test session."""
        self.session.close()
コード例 #44
0
ファイル: indexing.py プロジェクト: sfermigier/yaka-core
def index_update(class_name, items):
  """ Apply queued model changes to the search index of one model class.

  class_name: name of an indexed model class.
  items: iterable of (operation, primary key) pairs; entries whose primary
  key is None are skipped.

  Raises ValueError when class_name is not an indexed class.
  """
  cls_registry = {cls.__name__: cls for cls in service.indexed_classes}
  model_class = cls_registry.get(class_name)

  if model_class is None:
    raise ValueError("Invalid class: {}".format(class_name))

  index = service.index_for_model_class(model_class)
  primary_field = model_class.search_query.primary
  indexed_fields = model_class.whoosh_schema.names()

  session = Session(bind=db.session.get_bind(None, None))
  try:
    query = session.query(model_class)

    with AsyncWriter(index) as writer:
      for change_type, model_pk in items:
        if model_pk is None:
          continue
        # delete everything. stuff that's updated or inserted will get
        # added as a new doc. Could probably replace this with a whoosh
        # update.
        writer.delete_by_term(primary_field, unicode(model_pk))

        if change_type in ("new", "changed"):
          model = query.get(model_pk)
          if model is None:
            # deleted after task queued, but before task run
            continue

          # Hack: Load lazy fields
          # This prevents a transaction error in make_document
          for key in indexed_fields:
            getattr(model, key)

          document = service.make_document(model, indexed_fields, primary_field)
          writer.add_document(**document)
  finally:
    # Release the session even when indexing fails part-way.
    session.close()
コード例 #45
0
ファイル: utils.py プロジェクト: tmwilder/d2modeling
class DatabaseTest(object):
    """
        Creates and wipes a fresh test db for each test.
        We can afford the overhead until we have many more tests.

    """
    def setUp(self):
        """Create the SQLite file, its schema, and both kinds of handles."""
        self.db_path = 'd2modeling_test.db'
        self.engine = create_engine('sqlite:///{}'.format(self.db_path))
        schema.Base.metadata.bind = self.engine
        self.connection = self.engine.connect()
        self.transaction = self.connection.begin()
        schema.Base.metadata.create_all(self.connection)
        # A plain DB-API connection to the same file, next to the
        # SQLAlchemy session.
        self.db_api_2_conn = sqlite3.connect(os.path.realpath(self.db_path))
        self.session = Session(bind=self.connection)

    def tearDown(self):
        """Close every handle and delete the database file."""
        self.session.close()
        self.connection.close()
        self.engine.dispose()
        self.db_api_2_conn.close()
        os.remove(self.db_path)
コード例 #46
0
class testTradingCenter(unittest.TestCase):
    """Tests for ``TradingCenter`` order lookup and cancellation.

    Every test runs inside a transaction opened on the module-level
    ``connection``; ``tearDown`` rolls that transaction back.
    """

    def setUp(self):
        self.trans = connection.begin()
        self.session = Session(connection)

        currency = Currency(name='Pesos', code='ARG')
        self.exchange = Exchange(name='Merval', code='MERV', currency=currency)
        self.owner = Owner(name='poor owner')
        self.broker = Broker(name='broker1')
        self.account = Account(owner=self.owner, broker=self.broker)
        self.account.deposit(Money(amount=10000, currency=currency))

    def tearDown(self):
        self.trans.rollback()
        self.session.close()

    def _commit_two_buy_orders(self):
        """Add two buy orders for the same stock, commit, and return them.

        :return: ``(order1, order2)`` in the order they were added.
        """
        stock = Stock(symbol='symbol', description='a stock',
                      ISIN='US123456789', exchange=self.exchange)
        order1 = BuyOrder(account=self.account, security=stock,
                          price=13.2, share=10)
        order2 = BuyOrder(account=self.account, security=stock,
                          price=13.25, share=10)
        self.session.add(order1)
        self.session.add(order2)
        self.session.commit()
        return order1, order2

    def test_open_orders_by_order_id(self):
        order1, _order2 = self._commit_two_buy_orders()

        tc = TradingCenter(self.session)
        order = tc.open_order_by_id(order1.id)
        self.assertEqual(order1, order)

        # An id that was never assigned yields no order.
        order = tc.open_order_by_id(100)
        self.assertIsNone(order)

    def testGetOpenOrdersBySymbol(self):
        order1, order2 = self._commit_two_buy_orders()

        tc = TradingCenter(self.session)
        orders = tc.open_orders_by_symbol('symbol')
        self.assertEqual([order1, order2], list(orders))

    def testCancelOrder(self):
        order1, order2 = self._commit_two_buy_orders()

        tc = TradingCenter(self.session)

        order1.cancel()
        self.assertEqual([order2], tc.open_orders)
        self.assertEqual([order1], tc.cancel_orders)
        self.assertEqual(CancelOrderStage, type(order1.current_stage))

        order2.cancel()
        self.assertEqual([], tc.open_orders)
        self.assertEqual([order1, order2], tc.cancel_orders)

    def testCancelAllOpenOrders(self):
        self._commit_two_buy_orders()

        tc = TradingCenter(self.session)

        tc.cancel_all_open_orders()

        self.assertEqual([], tc.open_orders)

    # Placeholders: not implemented yet.
    def testConsume(self):
        pass

    def testPostConsume(self):
        pass

    def testCreateAccountWithMetrix(self):
        pass
コード例 #47
0
class DatabaseTest(object):

    engine = None
    connection = None

    @classmethod
    def get_database_connection(cls):
        """Connect to the database named by ``Configuration.database_url()``.

        :return: the ``(engine, connection)`` pair produced by
            ``SessionManager.initialize``.
        """
        db_engine, db_connection = SessionManager.initialize(
            Configuration.database_url()
        )
        return db_engine, db_connection

    @classmethod
    def setup_class(cls):
        """Class-level fixture: connect to the database and point the
        configured data directory at a fresh temporary directory."""
        cls.engine, cls.connection = cls.get_database_connection()

        # Keep whatever Configuration.data_directory currently is so that
        # teardown_class() can put it back, then switch to a scratch directory.
        cls.old_data_dir = Configuration.data_directory
        cls.tmp_data_dir = tempfile.mkdtemp(dir="/tmp")
        config = Configuration.instance
        config[Configuration.DATA_DIRECTORY] = cls.tmp_data_dir

        # Avoid CannotLoadConfiguration errors related to CDN integrations:
        # make sure an integrations mapping exists and give it an empty CDN
        # entry.
        integrations = config.get(Configuration.INTEGRATIONS, {})
        config[Configuration.INTEGRATIONS] = integrations
        config[Configuration.INTEGRATIONS][ExternalIntegration.CDN] = {}

    @classmethod
    def teardown_class(cls):
        """Class-level cleanup: drop the database connection, delete the
        temporary data directory and restore the saved data-directory
        setting."""
        # Destroy the database connection and engine.
        cls.connection.close()
        cls.engine.dispose()

        # Only delete the directory when its path really starts with /tmp.
        if cls.tmp_data_dir.startswith("/tmp"):
            logging.debug("Removing temporary directory %s", cls.tmp_data_dir)
            shutil.rmtree(cls.tmp_data_dir)
        else:
            # logging.warn() is a deprecated alias; warning() is the
            # supported spelling.
            logging.warning(
                "Cowardly refusing to remove 'temporary' directory %s",
                cls.tmp_data_dir
            )

        Configuration.instance[Configuration.DATA_DIRECTORY] = cls.old_data_dir

    def setup(self, mock_search=True):
        """Per-test fixture: open a session, begin a nested transaction and
        reset the counters used by the ``_id``/``_time``/``_isbn`` helpers.

        :param mock_search: When true, patch ``ExternalSearchIndex`` in the
            ``external_search`` module with ``MockExternalSearchIndex`` and
            start the patch.
        """
        # New session on the shared connection, plus a nested transaction
        # that teardown() rolls back.
        self._db = Session(self.connection)
        self.transaction = self.connection.begin_nested()

        # Start with a high number so it won't interfere with tests that
        # search for an age or grade.
        self.counter = 2000
        self.time_counter = datetime(2014, 1, 1)
        self.isbns = [
            "9780674368279",
            "0636920028468",
            "9781936460236",
            "9780316075978",
        ]

        self.search_mock = None
        if mock_search:
            patch_target = external_search.__name__ + ".ExternalSearchIndex"
            self.search_mock = mock.patch(patch_target, MockExternalSearchIndex)
            self.search_mock.start()

        # TODO:  keeping this for now, but need to fix it bc it hits _isbn,
        # which pops an isbn off the list and messes tests up.  so exclude
        # _ functions from participating.
        # also attempt to stop nosetest showing docstrings instead of function names.
        #for name, obj in inspect.getmembers(self):
        #    if inspect.isfunction(obj) and obj.__name__.startswith('test_'):
        #        obj.__doc__ = None


    def teardown(self):
        """Per-test cleanup: undo database changes and clear cached state."""
        # Close the session, then roll back everything that happened during
        # the test, whether in that session or in some other one.
        self._db.close()
        self.transaction.rollback()

        # Drop any database objects cached on the model classes; they
        # belonged to the session that was just rolled back.
        for cached_class in (
                Collection, ConfigurationSetting, DataSource,
                DeliveryMechanism, ExternalIntegration, Genre, Library
        ):
            cached_class.reset_cache()

        # Forget the site-configuration bookkeeping recorded in the
        # Configuration instance as well.
        stale_keys = (
            Configuration.SITE_CONFIGURATION_LAST_UPDATE,
            Configuration.LAST_CHECKED_FOR_SITE_CONFIGURATION_UPDATE,
        )
        for stale_key in stale_keys:
            if stale_key in Configuration.instance:
                del Configuration.instance[stale_key]

        if self.search_mock:
            self.search_mock.stop()

    def time_eq(self, a, b):
        """Assert that two times are *approximately* the same.

        :param a: first time value.
        :param b: second time value.
        :raises AssertionError: when the two are 2 seconds or more apart.
        """
        # Order the pair so the subtraction never goes negative.
        earlier, later = (a, b) if a < b else (b, a)
        total_seconds = (later - earlier).total_seconds()
        assert (total_seconds < 2), ("Delta was too large: %.2f seconds." % total_seconds)

    def shortDescription(self):
        """Always return None.

        Stops nosetests displaying docstrings instead of class names when
        verbosity level >= 2.
        """
        return None

    @property
    def _id(self):
        self.counter += 1
        return self.counter

    @property
    def _str(self):
        """Return the next ``_id`` value converted with ``unicode()``."""
        next_id = self._id
        return unicode(next_id)

    @property
    def _time(self):
        v = self.time_counter
        self.time_counter = self.time_counter + timedelta(days=1)
        return v

    @property
    def _isbn(self):
        return self.isbns.pop()

    @property
    def _url(self):
        """Return a URL under http://foo.com/ ending in a fresh ``_str``."""
        suffix = self._str
        return "http://foo.com/" + suffix

    def _patron(self, external_identifier=None, library=None):
        """Find or create a Patron.

        :param external_identifier: Falls back to a fresh ``self._str``.
        :param library: Falls back to ``self._default_library``.
        :return: the first element of the ``get_one_or_create`` result.
        """
        if not external_identifier:
            external_identifier = self._str
        if not library:
            library = self._default_library
        result = get_one_or_create(
            self._db, Patron,
            external_identifier=external_identifier,
            library=library,
        )
        return result[0]

    def _contributor(self, sort_name=None, name=None, **kw_args):
        """Find or create a Contributor keyed on its sort name.

        :param sort_name: Preferred value for the ``sort_name`` lookup.
        :param name: Used as the sort name when ``sort_name`` is not given.
        :param kw_args: Passed through to ``get_one_or_create``.
        :return: the ``get_one_or_create`` result, unmodified (not just its
            first element).
        """
        chosen_name = sort_name or name or self._str
        return get_one_or_create(
            self._db, Contributor, sort_name=unicode(chosen_name), **kw_args
        )

    def _identifier(self, identifier_type=Identifier.GUTENBERG_ID, foreign_id=None):
        """Look up an Identifier via ``Identifier.for_foreign_id``.

        :param identifier_type: Type passed to the lookup.
        :param foreign_id: Identifier value; a fresh ``self._str`` is used
            when this is falsy.
        :return: the first element of the lookup result.
        """
        identifier_value = foreign_id if foreign_id else self._str
        lookup = Identifier.for_foreign_id(
            self._db, identifier_type, identifier_value)
        return lookup[0]

    def _edition(self, data_source_name=DataSource.GUTENBERG,
                 identifier_type=Identifier.GUTENBERG_ID,
                 with_license_pool=False, with_open_access_download=False,
                 title=None, language="eng", authors=None, identifier_id=None,
                 series=None, collection=None, publicationDate=None
    ):
        """Create (or look up) an Edition and fill in basic metadata.

        :param data_source_name: Name handed to ``DataSource.lookup``.
        :param identifier_type: Identifier type for the edition.
        :param with_license_pool: Also create a LicensePool via
            ``_licensepool``.
        :param with_open_access_download: Like ``with_license_pool``, and
            asks ``_licensepool`` for an open-access download.
        :param title: Edition title; a fresh ``self._str`` when falsy.
        :param language: Set on the edition when truthy.
        :param authors: A single name, a sequence of names, or None (one
            generated name). An empty sequence adds no contributors.
        :param identifier_id: Identifier value; a fresh ``self._str`` when
            falsy.
        :param series: Set on the edition when truthy.
        :param collection: Passed to ``_licensepool``.
        :param publicationDate: Stored as ``published`` when truthy.
        :return: ``(edition, pool)`` when a license pool was requested,
            otherwise just the edition.
        """
        foreign_id = identifier_id or self._str
        source = DataSource.lookup(self._db, data_source_name)
        wr = Edition.for_foreign_id(
            self._db, source, identifier_type, foreign_id)[0]
        if not title:
            title = self._str
        wr.title = unicode(title)
        wr.medium = Edition.BOOK_MEDIUM
        if series:
            wr.series = series
        if language:
            wr.language = language
        if authors is None:
            authors = self._str
        if isinstance(authors, basestring):
            authors = [authors]
        # Truthiness instead of ``!= []``: an empty tuple (or any other empty
        # sequence) must be skipped too, otherwise authors[0] raises
        # IndexError.
        if authors:
            wr.add_contributor(unicode(authors[0]), Contributor.PRIMARY_AUTHOR_ROLE)
            wr.author = unicode(authors[0])
        for author in authors[1:]:
            wr.add_contributor(unicode(author), Contributor.AUTHOR_ROLE)
        if publicationDate:
            wr.published = publicationDate

        if with_license_pool or with_open_access_download:
            pool = self._licensepool(
                wr, data_source_name=data_source_name,
                with_open_access_download=with_open_access_download,
                collection=collection
            )

            pool.set_presentation_edition()
            return wr, pool
        return wr

    def _work(self, title=None, authors=None, genre=None, language=None,
              audience=None, fiction=True, with_license_pool=False,
              with_open_access_download=False, quality=0.5, series=None,
              presentation_edition=None, collection=None, data_source_name=None):
        """Create a Work.

        For performance reasons, this method does not calculate a
        presentation edition for the new Work. Tests that rely on this
        information being present should call _slow_work() instead, which
        takes more care to present the sort of Work that would be created
        in a real environment.

        OPDS entries are only calculated (and the work only flagged as
        presentation-ready) when the work ends up with license pools.

        :param title: Work title; a fresh ``self._str`` when falsy.
        :param authors: Passed to ``_edition`` when a new edition is made.
        :param genre: A Genre, or a name looked up with ``autocreate=True``.
        :param language: Defaults to "eng".
        :param audience: Defaults to ``Classifier.AUDIENCE_ADULT``.
        :param fiction: None is treated as True.
        :param with_license_pool: Create a license pool for a new edition.
        :param with_open_access_download: Implies ``with_license_pool`` and
            marks the pool open-access.
        :param quality: Stored on a newly created Work.
        :param series: Passed to ``_edition``.
        :param presentation_edition: Existing edition to use; its
            ``license_pools`` become the work's candidate pools.
        :param collection: Passed to ``_edition``.
        :param data_source_name: Defaults to Overdrive for a children's
            audience and Gutenberg otherwise.
        :return: the Work.
        """
        pools = []
        if with_open_access_download:
            with_license_pool = True
        language = language or "eng"
        title = unicode(title or self._str)
        audience = audience or Classifier.AUDIENCE_ADULT
        if not data_source_name:
            if audience == Classifier.AUDIENCE_CHILDREN:
                # TODO: This is necessary because Gutenberg's childrens books
                # get filtered out at the moment.
                data_source_name = DataSource.OVERDRIVE
            else:
                data_source_name = DataSource.GUTENBERG
        if fiction is None:
            fiction = True
        if not presentation_edition:
            presentation_edition = self._edition(
                title=title, language=language,
                authors=authors,
                with_license_pool=with_license_pool,
                with_open_access_download=with_open_access_download,
                data_source_name=data_source_name,
                series=series,
                collection=collection,
            )
            if with_license_pool:
                # _edition() returned an (edition, pool) pair in this case.
                presentation_edition, pool = presentation_edition
                if with_open_access_download:
                    pool.open_access = True
                pools = [pool]
        else:
            pools = presentation_edition.license_pools
        work, ignore = get_one_or_create(
            self._db, Work, create_method_kwargs=dict(
                audience=audience,
                fiction=fiction,
                quality=quality), id=self._id)
        if genre:
            if not isinstance(genre, Genre):
                genre, ignore = Genre.lookup(self._db, genre, autocreate=True)
            work.genres = [genre]
        work.random = 0.5
        work.set_presentation_edition(presentation_edition)

        if pools:
            # make sure the pool's presentation_edition is set,
            # bc loan tests assume that.
            if not work.license_pools:
                for pool in pools:
                    work.license_pools.append(pool)

            for pool in pools:
                pool.set_presentation_edition()

            # This is probably going to be used in an OPDS feed, so
            # fake that the work is presentation ready.
            work.presentation_ready = True
            work.calculate_opds_entries(verbose=False)

        return work

    def add_to_materialized_view(self, works, true_opds=False):
        """Make sure every work in `works` shows up in the materialized view.

        :param works: A single work or a list of works.
        :param true_opds: Generate real OPDS entries for each work,
            rather than faking it.
        """
        work_list = works if isinstance(works, list) else [works]
        for item in work_list:
            if not true_opds:
                # Fake just enough for the work to count as presentable.
                item.presentation_ready = True
                item.simple_opds_entry = "<entry>an entry</entry>"
                continue
            item.calculate_opds_entries(verbose=False)
        self._db.commit()
        SessionManager.refresh_materialized_views(self._db)

    def _lane(self, display_name=None, library=None,
              parent=None, genres=None, languages=None,
              fiction=None
    ):
        """Create a Lane.

        :param display_name: Falls back to a fresh ``self._str``.
        :param library: Falls back to ``self._default_library``.
        :param parent: Optional parent lane.
        :param genres: A genre or list of genres; string entries are
            resolved with ``Genre.lookup``.
        :param languages: A language or list of languages.
        :param fiction: Passed straight to ``create``.
        :return: the Lane.
        """
        if not display_name:
            display_name = self._str
        if not library:
            library = self._default_library
        lane, is_new = create(
            self._db, Lane,
            library=library,
            parent=parent, display_name=display_name,
            fiction=fiction
        )
        if is_new and parent:
            lane.priority = len(parent.sublanes)-1

        if genres:
            genre_list = genres if isinstance(genres, list) else [genres]
            for entry in genre_list:
                if isinstance(entry, basestring):
                    entry, unused = Genre.lookup(self._db, entry)
                lane.genres.append(entry)

        if languages:
            if isinstance(languages, list):
                lane.languages = languages
            else:
                lane.languages = [languages]
        return lane

    def _slow_work(self, *args, **kwargs):
        """Create a work that closely resembles one found in the wild.

        Builds the work with ``_work()`` and then calculates its
        presentation edition and OPDS entries. Significantly slower than
        ``_work()`` but more reliable.
        """
        realistic_work = self._work(*args, **kwargs)
        realistic_work.calculate_presentation_edition()
        realistic_work.calculate_opds_entries(verbose=False)
        return realistic_work

    def _add_generic_delivery_mechanism(self, license_pool):
        """Give a license pool a generic non-open-access delivery mechanism.

        Registers an EPUB / no-DRM / in-copyright mechanism for the pool's
        data source and identifier.
        """
        LicensePoolDeliveryMechanism.set(
            license_pool.data_source,
            license_pool.identifier,
            Representation.EPUB_MEDIA_TYPE,
            DeliveryMechanism.NO_DRM,
            RightsStatus.IN_COPYRIGHT
        )

    def _coverage_record(self, edition, coverage_source, operation=None,
        status=CoverageRecord.SUCCESS, collection=None, exception=None,
    ):
        """Find or create a CoverageRecord.

        :param edition: An Identifier, or an object whose
            ``primary_identifier`` is used.
        :param coverage_source: Data source recorded on the record.
        :param operation: Optional operation name.
        :param status: Status stored on a newly created record.
        :param collection: Optional collection.
        :param exception: Exception text stored on a newly created record.
        :return: the CoverageRecord.
        """
        identifier = (
            edition if isinstance(edition, Identifier)
            else edition.primary_identifier
        )
        # Timestamp, status and exception only apply when the record is new.
        creation_values = dict(
            timestamp=datetime.utcnow(),
            status=status,
            exception=exception,
        )
        record, unused = get_one_or_create(
            self._db, CoverageRecord,
            identifier=identifier,
            data_source=coverage_source,
            operation=operation,
            collection=collection,
            create_method_kwargs=creation_values
        )
        return record

    def _work_coverage_record(self, work, operation=None,
                              status=CoverageRecord.SUCCESS):
        """Find or create a WorkCoverageRecord for ``work``.

        :param work: The work the record belongs to.
        :param operation: Optional operation name.
        :param status: Status stored on a newly created record.
        :return: the WorkCoverageRecord.
        """
        creation_values = dict(
            timestamp=datetime.utcnow(),
            status=status,
        )
        record, unused = get_one_or_create(
            self._db, WorkCoverageRecord,
            work=work,
            operation=operation,
            create_method_kwargs=creation_values
        )
        return record

    def _licensepool(self, edition, open_access=True,
                     data_source_name=DataSource.GUTENBERG,
                     with_open_access_download=False,
                     set_edition_as_presentation=False,
                     collection=None):
        """Find or create a LicensePool for an edition.

        :param edition: Edition whose primary identifier the pool uses; a
            new one is made with ``_edition`` when falsy.
        :param open_access: ``open_access`` value for a newly created pool.
        :param data_source_name: Name handed to ``DataSource.lookup``.
        :param with_open_access_download: Attach an open-access EPUB link,
            a no-DRM delivery mechanism and a mirrored dummy representation.
            Otherwise the pool gets an Adobe-DRM mechanism and one
            owned/available license.
        :param set_edition_as_presentation: Make ``edition`` the pool's
            presentation edition.
        :param collection: Falls back to ``self._default_collection``.
        :return: the LicensePool.
        """
        source = DataSource.lookup(self._db, data_source_name)
        if not edition:
            edition = self._edition(data_source_name)
        if not collection:
            collection = self._default_collection
        pool, unused = get_one_or_create(
            self._db, LicensePool,
            create_method_kwargs=dict(open_access=open_access),
            identifier=edition.primary_identifier,
            data_source=source,
            collection=collection,
            availability_time=datetime.utcnow()
        )

        if set_edition_as_presentation:
            pool.presentation_edition = edition

        if not with_open_access_download:
            # Add a DeliveryMechanism for this licensepool
            pool.set_delivery_mechanism(
                MediaTypes.EPUB_MEDIA_TYPE,
                DeliveryMechanism.ADOBE_DRM,
                RightsStatus.UNKNOWN,
                None
            )
            pool.licenses_owned = 1
            pool.licenses_available = 1
            return pool

        pool.open_access = True
        download_url = "http://foo.com/" + self._str
        epub = MediaTypes.EPUB_MEDIA_TYPE
        link, link_is_new = pool.identifier.add_link(
            Hyperlink.OPEN_ACCESS_DOWNLOAD, download_url,
            source, epub
        )

        # Add a DeliveryMechanism for this download
        pool.set_delivery_mechanism(
            epub,
            DeliveryMechanism.NO_DRM,
            RightsStatus.GENERIC_OPEN_ACCESS,
            link.resource,
        )

        representation, representation_is_new = self._representation(
            download_url, epub, "Dummy content", mirrored=True)
        link.resource.representation = representation
        return pool

    def _license(self, pool, identifier=None, checkout_url=None, status_url=None,
                 expires=None, remaining_checkouts=None, concurrent_checkouts=None):
        """Find or create a License inside ``pool``.

        ``identifier``, ``checkout_url`` and ``status_url`` each fall back to
        a fresh ``self._str``; the remaining arguments are passed through
        unchanged.

        :return: the License.
        """
        if not identifier:
            identifier = self._str
        if not checkout_url:
            checkout_url = self._str
        if not status_url:
            status_url = self._str
        found, unused = get_one_or_create(
            self._db, License,
            identifier=identifier,
            license_pool=pool,
            checkout_url=checkout_url,
            status_url=status_url,
            expires=expires,
            remaining_checkouts=remaining_checkouts,
            concurrent_checkouts=concurrent_checkouts,
        )
        return found

    def _representation(self, url=None, media_type=None, content=None,
                        mirrored=False):
        """Find or create a Representation for ``url``.

        :param url: Falls back to a generated http://foo.com/ URL.
        :param media_type: Always stored on the representation.
        :param content: Stored, together with a fetch time, only when
            ``media_type`` is also truthy.
        :param mirrored: With content present, also record a mirror URL and
            mirror time.
        :return: ``(representation, is_new)``.
        """
        if not url:
            url = "http://foo.com/" + self._str
        representation, is_new = get_one_or_create(
            self._db, Representation, url=url)
        representation.media_type = media_type

        if not (media_type and content):
            return representation, is_new

        representation.content = content
        representation.fetched_at = datetime.utcnow()
        if mirrored:
            representation.mirror_url = "http://foo.com/" + self._str
            representation.mirrored_at = datetime.utcnow()
        return representation, is_new

    def _customlist(self, foreign_identifier=None,
                    name=None,
                    data_source_name=DataSource.NYT, num_entries=1,
                    entries_exist_as_works=True
    ):
        """Find or create a CustomList and add entries to it.

        :param foreign_identifier: Falls back to a fresh ``self._str``.
        :param name: Name for a newly created list; generated when falsy.
        :param data_source_name: Name handed to ``DataSource.lookup``.
        :param num_entries: Number of entries to add.
        :param entries_exist_as_works: When true each entry is backed by a
            Work made with ``_work``; otherwise by a bare Edition.
        :return: ``(customlist, editions)``.
        """
        data_source = DataSource.lookup(self._db, data_source_name)
        if not foreign_identifier:
            foreign_identifier = self._str
        now = datetime.utcnow()
        customlist, unused = get_one_or_create(
            self._db, CustomList,
            create_method_kwargs=dict(
                created=now,
                updated=now,
                name=name or self._str,
                description=self._str,
            ),
            data_source=data_source,
            foreign_identifier=foreign_identifier
        )

        editions = []
        for position in range(num_entries):
            if not entries_exist_as_works:
                edition = self._edition(
                    data_source_name, title="Item %s" % position)
                edition.permanent_work_id = "Permanent work ID %s" % self._str
            else:
                edition = self._work(
                    with_open_access_download=True).presentation_edition
            customlist.add_entry(
                edition, "Annotation %s" % position, first_appearance=now)
            editions.append(edition)
        return customlist, editions

    def _complaint(self, license_pool, type, source, detail, resolved=None):
        """Register a Complaint via ``Complaint.register`` and return it.

        The second element of the register result is discarded.
        """
        registered, unused = Complaint.register(
            license_pool, type, source, detail, resolved
        )
        return registered

    def _credential(self, data_source_name=DataSource.GUTENBERG,
                    type=None, patron=None):
        """Create a persistent-token Credential.

        :param data_source_name: Name handed to ``DataSource.lookup``.
        :param type: Credential type; a fresh ``self._str`` when falsy.
        :param patron: Owning patron; a new one from ``_patron()`` when
            falsy.
        :return: the Credential.
        """
        data_source = DataSource.lookup(self._db, data_source_name)
        if not type:
            type = self._str
        if not patron:
            patron = self._patron()
        token, unused = Credential.persistent_token_create(
            self._db, data_source, type, patron
        )
        return token

    def _external_integration(self, protocol, goal=None, settings=None,
                              libraries=None, **kwargs
    ):
        """Find or create an ExternalIntegration.

        :param protocol: Protocol of the integration.
        :param goal: Goal of the integration.
        :param settings: Optional dict applied with ``set_setting``.
        :param libraries: A library or list of libraries. When given, an
            existing integration for any one of them is reused; otherwise a
            new integration attached to all of them is created.
        :param kwargs: Extra attributes set directly on the integration.
        :return: the ExternalIntegration.
        """
        integration = None
        if not libraries:
            integration, ignore = get_one_or_create(
                self._db, ExternalIntegration, protocol=protocol, goal=goal
            )
        else:
            if not isinstance(libraries, list):
                libraries = [libraries]

            # Try to find an existing integration for one of the given
            # libraries. Each library is checked in turn; previously only
            # libraries[0] was ever looked up.
            for library in libraries:
                integration = ExternalIntegration.lookup(
                    self._db, protocol, goal, library=library
                )
                if integration:
                    break

            if not integration:
                # Otherwise, create a brand new integration specifically
                # for the library.
                integration = ExternalIntegration(
                    protocol=protocol, goal=goal,
                )
                integration.libraries.extend(libraries)
                self._db.add(integration)

        for attr, value in kwargs.items():
            setattr(integration, attr, value)

        settings = settings or dict()
        for key, value in settings.items():
            integration.set_setting(key, value)

        return integration

    def _delegated_patron_identifier(
            self, library_uri=None, patron_identifier=None,
            identifier_type=DelegatedPatronIdentifier.ADOBE_ACCOUNT_ID,
            identifier=None
    ):
        """Create a sample DelegatedPatronIdentifier.

        :param library_uri: Falls back to a fresh ``self._url``.
        :param patron_identifier: Falls back to a fresh ``self._str``.
        :param identifier_type: Type passed to ``get_one_or_create``.
        :param identifier: Either a callable used as the identifier
            factory, or a fixed value (generated when falsy) that the
            factory returns.
        :return: the DelegatedPatronIdentifier.
        """
        if not library_uri:
            library_uri = self._url
        if not patron_identifier:
            patron_identifier = self._str

        if callable(identifier):
            make_id = identifier
        else:
            fixed_identifier = identifier or self._str

            def make_id():
                return fixed_identifier

        delegated, unused = DelegatedPatronIdentifier.get_one_or_create(
            self._db, library_uri, patron_identifier, identifier_type,
            make_id
        )
        return delegated

    def _sample_ecosystem(self):
        """Create a sample work together with the pools, editions and authors
        that belong to it.

        :return: ``(work, pool_std_ebooks, pool_git, pool_gut,
            edition_std_ebooks, edition_git, edition_gut, alice, bob)``.
        """
        def open_access_edition(source_name, id_type, title, subtitle, authors):
            # One open-access edition/pool pair with no generated authors;
            # the given contributors are attached afterwards, in order.
            edition, pool = self._edition(
                source_name, id_type,
                with_license_pool=True, with_open_access_download=True,
                authors=[])
            edition.title = title
            edition.subtitle = subtitle
            for author in authors:
                edition.add_contributor(author, Contributor.AUTHOR_ROLE)
            return edition, pool

        # make some authors
        [bob], ignore = Contributor.lookup(self._db, u"Bitshifter, Bob")
        bob.family_name, bob.display_name = bob.default_names()
        [alice], ignore = Contributor.lookup(self._db, u"Adder, Alice")
        alice.family_name, alice.display_name = alice.default_names()

        edition_std_ebooks, pool_std_ebooks = open_access_edition(
            DataSource.STANDARD_EBOOKS, Identifier.URI,
            u"The Standard Ebooks Title", u"The Standard Ebooks Subtitle",
            [alice])

        edition_git, pool_git = open_access_edition(
            DataSource.PROJECT_GITENBERG, Identifier.GUTENBERG_ID,
            u"The GItenberg Title", u"The GItenberg Subtitle",
            [bob, alice])

        edition_gut, pool_gut = open_access_edition(
            DataSource.GUTENBERG, Identifier.GUTENBERG_ID,
            u"The GUtenberg Title", u"The GUtenberg Subtitle",
            [bob])

        # The work is built around the GITenberg edition; the other two
        # pools are appended to it afterwards.
        work = self._work(presentation_edition=edition_git)
        for extra_pool in (pool_gut, pool_std_ebooks):
            work.license_pools.append(extra_pool)

        work.calculate_presentation()

        return (work, pool_std_ebooks, pool_git, pool_gut,
            edition_std_ebooks, edition_git, edition_gut, alice, bob)


    def print_database_instance(self):
        """
        Calls the class method that examines the current state of the database model
        (whether it's been committed or not).

        Does nothing except log a warning when the TESTING environment
        variable is not set.

        NOTE:  If you set_trace, and hit "continue", you'll start seeing console output right
        away, without waiting for the whole test to run and the standard output section to display.
        You can also use nosetest --nocapture.
        I use:
        def test_name(self):
            [code...]
            set_trace()
            self.print_database_instance()  # TODO: remove before prod
            [code...]
        """
        if 'TESTING' not in os.environ:
            # we are on production, abort, abort!
            # (logging.warn() is a deprecated alias of logging.warning().)
            logging.warning("Forgot to remove call to testing.py:DatabaseTest.print_database_instance() before pushing to production.")
            return

        DatabaseTest.print_database_class(self._db)


    @classmethod
    def print_database_class(cls, db_connection):
        """
        Prints to the console the entire contents of the database, as the unit test sees it.
        Exists because unit tests don't persist db information, they create a memory
        representation of the db state, and then roll the unit test-derived transactions back.
        So we cannot see what's going on by going into postgres and running selects.
        This is the in-test alternative to going into postgres.

        Can be called from model and metadata classes as well as tests.

        Does nothing except log a warning when the TESTING environment
        variable is not set.

        NOTE: The purpose of this method is for debugging.
        Be careful of leaving it in code and potentially outputting
        vast tracts of data into your output stream on production.

        Call like this:
        set_trace()
        from testing import (
            DatabaseTest,
        )
        _db = Session.object_session(self)
        DatabaseTest.print_database_class(_db)  # TODO: remove before prod

        :param db_connection: Object whose ``query()`` method is used to
            load every Work, Identifier, LicensePool, Edition, DataSource
            and Representation.
        """
        if 'TESTING' not in os.environ:
            # we are on production, abort, abort!
            # (logging.warn() is a deprecated alias of logging.warning().)
            logging.warning("Forgot to remove call to testing.py:DatabaseTest.print_database_class() before pushing to production.")
            return

        works = db_connection.query(Work).all()
        identifiers = db_connection.query(Identifier).all()
        license_pools = db_connection.query(LicensePool).all()
        editions = db_connection.query(Edition).all()
        data_sources = db_connection.query(DataSource).all()
        representations = db_connection.query(Representation).all()

        # Every print below takes exactly one parenthesised expression, so
        # the output is the same whether ``print`` is the Python 2 statement
        # or the Python 3 function.
        if (not works):
            print("NO Work found")
        for wCount, work in enumerate(works):
            # pipe character at end of line helps see whitespace issues
            print("Work[%s]=%s|" % (wCount, work))

            if (not work.license_pools):
                print("    NO Work.LicensePool found")
            for lpCount, license_pool in enumerate(work.license_pools):
                print("    Work.LicensePool[%s]=%s|" % (lpCount, license_pool))

            print("    Work.presentation_edition=%s|" % work.presentation_edition)

        print("__________________________________________________________________\n")
        if (not identifiers):
            print("NO Identifier found")
        for iCount, identifier in enumerate(identifiers):
            print("Identifier[%s]=%s|" % (iCount, identifier))
            print("    Identifier.licensed_through=%s|" % identifier.licensed_through)

        print("__________________________________________________________________\n")
        if (not license_pools):
            print("NO LicensePool found")
        for index, license_pool in enumerate(license_pools):
            print("LicensePool[%s]=%s|" % (index, license_pool))
            print("    LicensePool.work_id=%s|" % license_pool.work_id)
            print("    LicensePool.data_source_id=%s|" % license_pool.data_source_id)
            print("    LicensePool.identifier_id=%s|" % license_pool.identifier_id)
            print("    LicensePool.presentation_edition_id=%s|" % license_pool.presentation_edition_id)
            print("    LicensePool.superceded=%s|" % license_pool.superceded)
            print("    LicensePool.suppressed=%s|" % license_pool.suppressed)

        print("__________________________________________________________________\n")
        if (not editions):
            print("NO Edition found")
        for index, edition in enumerate(editions):
            # pipe character at end of line helps see whitespace issues
            print("Edition[%s]=%s|" % (index, edition))
            print("    Edition.primary_identifier_id=%s|" % edition.primary_identifier_id)
            print("    Edition.permanent_work_id=%s|" % edition.permanent_work_id)
            if (edition.data_source):
                print("    Edition.data_source.id=%s|" % edition.data_source.id)
                print("    Edition.data_source.name=%s|" % edition.data_source.name)
            else:
                print("    No Edition.data_source.")
            if (edition.license_pool):
                print("    Edition.license_pool.id=%s|" % edition.license_pool.id)
            else:
                print("    No Edition.license_pool.")

            print("    Edition.title=%s|" % edition.title)
            print("    Edition.author=%s|" % edition.author)
            if (not edition.author_contributors):
                print("    NO Edition.author_contributor found")
            for acCount, author_contributor in enumerate(edition.author_contributors):
                print("    Edition.author_contributor[%s]=%s|" % (acCount, author_contributor))

        print("__________________________________________________________________\n")
        if (not data_sources):
            print("NO DataSource found")
        for index, data_source in enumerate(data_sources):
            print("DataSource[%s]=%s|" % (index, data_source))
            print("    DataSource.id=%s|" % data_source.id)
            print("    DataSource.name=%s|" % data_source.name)
            print("    DataSource.offers_licenses=%s|" % data_source.offers_licenses)
            print("    DataSource.editions=%s|" % data_source.editions)
            print("    DataSource.license_pools=%s|" % data_source.license_pools)
            print("    DataSource.links=%s|" % data_source.links)

        print("__________________________________________________________________\n")
        if (not representations):
            print("NO Representation found")
        for index, representation in enumerate(representations):
            print("Representation[%s]=%s|" % (index, representation))
            print("    Representation.id=%s|" % representation.id)
            print("    Representation.url=%s|" % representation.url)
            print("    Representation.mirror_url=%s|" % representation.mirror_url)
            print("    Representation.fetch_exception=%s|" % representation.fetch_exception)
            print("    Representation.mirror_exception=%s|" % representation.mirror_exception)

        return


    def _library(self, name=None, short_name=None):
        """Fetch or create a Library through ``get_one_or_create``.

        :param name: Library name; ``self._str`` is used when falsy.
        :param short_name: Library short name; ``self._str`` is used
            when falsy.
        :return: The first element of the ``get_one_or_create`` result.
        """
        if not name:
            name = self._str
        if not short_name:
            short_name = self._str
        # The UUID is passed in create_method_kwargs, separately from the
        # name/short_name keywords.
        creation_kwargs = {'uuid': str(uuid.uuid4())}
        library, _is_new = get_one_or_create(
            self._db, Library, name=name, short_name=short_name,
            create_method_kwargs=creation_kwargs,
        )
        return library

    def _collection(self, name=None, protocol=ExternalIntegration.OPDS_IMPORT,
                    external_account_id=None, url=None, username=None,
                    password=None, data_source_name=None):
        """Fetch or create a Collection and configure its integration.

        :param name: Collection name; ``self._str`` is used when falsy.
        :param protocol: Passed to ``create_external_integration``.
        :param external_account_id: Stored on the collection as-is.
        :param url: Stored on the integration as-is.
        :param username: Stored on the integration as-is.
        :param password: Stored on the integration as-is.
        :param data_source_name: When truthy, assigned to
            ``collection.data_source``.
        :return: The Collection.
        """
        if not name:
            name = self._str
        collection, _is_new = get_one_or_create(
            self._db, Collection, name=name
        )
        collection.external_account_id = external_account_id

        integration = collection.create_external_integration(protocol)
        integration.goal = ExternalIntegration.LICENSE_GOAL
        credentials = (
            ('url', url),
            ('username', username),
            ('password', password),
        )
        for attribute, value in credentials:
            setattr(integration, attribute, value)

        if not data_source_name:
            return collection
        collection.data_source = data_source_name
        return collection

    @property
    def _default_library(self):
        """The Library shared by everything in the current test.

        It is built by ``make_default_library`` on first access and
        stored on the instance as ``_default__library``; later accesses
        return that stored object.
        """
        if hasattr(self, '_default__library'):
            return self._default__library
        library = self.make_default_library(self._db)
        self._default__library = library
        return library

    @property
    def _default_collection(self):
        """The Collection shared by everything in the current test.

        This is the first entry of ``self._default_library.collections``.
        It is looked up on first access and stored on the instance as
        ``_default__collection``; later accesses return that stored
        object, which avoids creating a Collection per LicensePool.
        """
        if hasattr(self, '_default__collection'):
            return self._default__collection
        collection = self._default_library.collections[0]
        self._default__collection = collection
        return collection

    @classmethod
    def make_default_library(cls, _db):
        """Make sure the "default" Library exists in ``_db`` and return it.

        Being a classmethod, this is usable from test-support code that
        does not itself live in a DatabaseTest subclass.

        :param _db: Database session handed to ``get_one_or_create``.
        :return: The Library with short name "default", with the
            "Default Collection" present in its ``collections``.
        """
        library_creation_kwargs = dict(
            uuid=unicode(uuid.uuid4()),
            name="default",
        )
        library, _is_new = get_one_or_create(
            _db, Library, short_name="default",
            create_method_kwargs=library_creation_kwargs,
        )

        collection, _is_new = get_one_or_create(
            _db, Collection, name="Default Collection"
        )
        integration = collection.create_external_integration(
            ExternalIntegration.OPDS_IMPORT
        )
        integration.goal = ExternalIntegration.LICENSE_GOAL

        # Attach the collection only when it is not already attached.
        if collection in library.collections:
            return library
        library.collections.append(collection)
        return library

    def _catalog(self, name=u"Faketown Public Library"):
        """Fetch or create the DataSource with the given name.

        :param name: Name of the DataSource to look up or create.
        :return: The DataSource. Previously the object was created and
            then discarded, so the method returned None; returning it
            matches the other factory helpers in this class.
        """
        source, _is_new = get_one_or_create(self._db, DataSource, name=name)
        return source

    def _integration_client(self, url=None, shared_secret=None):
        """Fetch or create an IntegrationClient.

        :param url: Client URL; ``self._url`` is used when falsy. It is
            passed in ``create_method_kwargs``.
        :param shared_secret: Secret used as the lookup keyword;
            ``u"secret"`` is used when falsy.
        :return: The first element of the ``get_one_or_create`` result.
        """
        if not url:
            url = self._url
        secret = shared_secret if shared_secret else u"secret"
        client, _is_new = get_one_or_create(
            self._db, IntegrationClient, shared_secret=secret,
            create_method_kwargs={'url': url}
        )
        return client

    def _subject(self, type, identifier):
        """Fetch or create the Subject with this type and identifier.

        :return: The first element of the ``get_one_or_create`` result.
        """
        subject, _is_new = get_one_or_create(
            self._db, Subject, type=type, identifier=identifier
        )
        return subject

    def _classification(self, identifier, subject, data_source, weight=1):
        """Fetch or create a Classification for the given combination.

        All four arguments are passed to ``get_one_or_create`` as lookup
        keywords.

        :return: The first element of the ``get_one_or_create`` result.
        """
        lookup = dict(
            identifier=identifier,
            subject=subject,
            data_source=data_source,
            weight=weight,
        )
        classification, _is_new = get_one_or_create(
            self._db, Classification, **lookup
        )
        return classification

    def sample_cover_path(self, name):
        """Build the path of the sample cover file called ``name``.

        The result is ``<directory of this module>/tests/files/covers/<name>``.
        The path is only assembled; nothing checks that the file exists.
        """
        module_directory = os.path.dirname(__file__)
        return os.path.join(
            module_directory, "tests", "files", "covers", name
        )

    def sample_cover_representation(self, name):
        """A Representation of the sample cover with the given filename.

        :param name: Filename resolved through ``sample_cover_path``.
        :return: The first element of what ``self._representation``
            returns when called with media type "image/png" and the
            file's contents.
        :raises IOError: If the sample cover file cannot be opened.
        """
        sample_cover_path = self.sample_cover_path(name)
        # Image data is binary, so read it in 'rb' mode to avoid newline
        # translation or text decoding. The context manager closes the
        # handle instead of leaving that to garbage collection.
        with open(sample_cover_path, 'rb') as cover_file:
            content = cover_file.read()
        return self._representation(
            media_type="image/png", content=content)[0]
コード例 #48
0
ファイル: query_caching.py プロジェクト: gajop/ailadder
    
    # Demo fragment: store three users, cache a filtered query under a key,
    # then run a query with the same key from a fresh session.
    # NOTE(review): `Session` and `with_cache_key` are not defined in this
    # fragment; presumably they are set up earlier in the file - confirm.
    # Declarative base created with an engine on 'sqlite://', echo=True.
    Base = declarative_base(engine=create_engine('sqlite://', echo=True))
    
    class User(Base):
        # Mapped to the 'users' table: integer primary key plus a name.
        __tablename__ = 'users'
        id = Column(Integer, primary_key=True)
        name = Column(String(100))
        
        def __repr__(self):
            return "User(name=%r)" % self.name

    # Create the tables registered on Base's metadata.
    Base.metadata.create_all()
    
    sess = Session()
    
    # Insert three users and commit them.
    sess.add_all(
        [User(name='u1'), User(name='u2'), User(name='u3')]
    )
    sess.commit()
    
    # cache two user objects
    sess.query(User).with_cache_key('u2andu3').filter(User.name.in_(['u2', 'u3'])).all()
    
    sess.close()
    
    # Start over with a new session; the same cache key is reused below,
    # this time without the filter.
    sess = Session()
    
    # pull straight from cache
    print sess.query(User).with_cache_key('u2andu3').all()
    
コード例 #49
0
ファイル: types.py プロジェクト: SmartTeleMax/iktomi
class TypeDecoratorsTest(unittest.TestCase):
    """Round-trip tests for the custom column types on ``TypesObject``.

    Each test stores one attribute, reopens the session, and checks the
    value (and, for markup columns, the type) read back from the database.
    """

    def setUp(self):
        """Create the schema on a ``sqlite://`` engine and open a session."""
        engine = create_engine("sqlite://")
        self.engine = engine
        Base.metadata.create_all(engine)
        self.db = Session(bind=engine)

    def tearDown(self):
        """Delete every ``TypesObject`` row, commit, and close the session."""
        self.db.query(TypesObject).delete()
        self.db.commit()
        self.db.close()

    def _save_and_reload(self, instance):
        """Persist ``instance``, swap in a new session, and re-read it.

        The old session is closed and replaced by a new one on the same
        engine before querying, and the first ``TypesObject`` row found
        by that new session is returned.
        """
        self.db.add(instance)
        self.db.commit()
        self.db.close()
        self.db = Session(bind=self.engine)
        return self.db.query(TypesObject).first()

    def test_string_list(self):
        expected = ['one', 'two', 'three', 'four', 'five']
        stored = TypesObject()
        stored.words = expected
        loaded = self._save_and_reload(stored)
        self.assertEqual(expected, loaded.words)

    def test_integer_list(self):
        expected = [1, 5, 10, 15, 20]
        stored = TypesObject()
        stored.numbers = expected
        loaded = self._save_and_reload(stored)
        self.assertEqual(expected, loaded.numbers)

    def test_string_wrapped_in_html(self):
        stored = TypesObject()
        stored.html_string1 = Markupable('<html>value</html>')
        loaded = self._save_and_reload(stored)
        self.assertIsInstance(loaded.html_string1, Markup)
        self.assertEqual('<html>value</html>', loaded.html_string1)

    def test_html_string(self):
        stored = TypesObject()
        stored.html_string2 = Markupable('<html>value</html>')
        loaded = self._save_and_reload(stored)
        self.assertIsInstance(loaded.html_string2, Markup)
        self.assertEqual('<html>value</html>', loaded.html_string2)

    def test_html_text(self):
        stored = TypesObject()
        expected = "<html>" + "the sample_text " * 100 + "</html>"
        stored.html_text = Markupable(expected)
        loaded = self._save_and_reload(stored)
        self.assertIsInstance(loaded.html_text, Markup)
        self.assertEqual(expected, loaded.html_text)

    def test_html_custom_markup(self):
        stored = TypesObject()
        stored.html_custom = Markupable('<html>   value   </html>')
        loaded = self._save_and_reload(stored)
        self.assertIsInstance(loaded.html_custom, CustomMarkup)
        # The expected value has the inner runs of spaces collapsed.
        self.assertEqual('<html> value </html>', loaded.html_custom)
コード例 #50
0
class DatabaseTest(object):
    """Fixture base that opens a ``Session`` per test and closes it after."""

    def setUp(self):
        """Open a session on the externally defined ``connection``."""
        db_session = Session(connection)
        self.session = db_session

    def tearDown(self):
        """Close the session opened by ``setUp``."""
        self.session.close()
コード例 #51
0
ファイル: books_example.py プロジェクト: smartkiwi/pugip-demo

# Walk-through script: look up one book, attach it to a new author, then
# run a grouped CASE query over the books.
# NOTE(review): `engine`, `Book`, `Author`, `case` and `count` come from
# outside this fragment - confirm against the file's imports.
session = Session(bind=engine)
q = session.query(Book).filter(Book.title == "Essential SQLAlchemy")
# Print the query object itself, before it is executed with .one().
print q
book = q.one()
print (book.id, book.title)


# Create an author and link the fetched book to it.
author = Author(name="Rick Copeland")
author.books.append(book)
# NOTE(review): `book` is what gets added here, not `author`; presumably
# the relationship cascade pulls the new author in - confirm.
session.add(book)
session.flush()

####
# select CASE WHEN (BOOK.pages_count > 200) THEN 1 ELSE 0 END is_novel, count(*)
# from BOOK
# group by CASE WHEN (BOOK.pages_count > 200) THEN 1 ELSE 0 END
# order by CASE WHEN (BOOK.pages_count > 200) THEN 1 ELSE 0 END
#
# The same CASE expression is reused as the selected column (labelled
# "is_alias"), the GROUP BY key and the ORDER BY key.
is_novel_column = case([(Book.pages_count > 200, 1)], else_=0)
novel_query = (
    session.query(is_novel_column.label("is_alias"), count()).group_by(is_novel_column).order_by(is_novel_column)
)

# Print the query object first, then the rows it returns.
print novel_query
print novel_query.all()


# The script only flushes; it closes the session without calling commit.
session.close()
コード例 #52
0
ファイル: session.py プロジェクト: GEverding/inbox
class InboxSession(object):
    """ Inbox custom ORM (with SQLAlchemy compatible API).

    Wraps a SQLAlchemy ``Session`` and, when ``ignore_soft_deletes`` is
    set, refuses to add objects flagged ``is_deleted`` and turns
    ``delete`` into a soft delete.

    Parameters
    ----------
    engine : <sqlalchemy.engine.Engine>
        A configured database engine to use for this session
    versioned : bool
        Do you want to enable the transaction log?
    ignore_soft_deletes : bool
        Whether or not to ignore soft-deleted objects in query results.
    namespace_id : int
        Namespace to limit query results with.
    """
    def __init__(self, engine, versioned=True, ignore_soft_deletes=True,
                 namespace_id=None):
        # TODO: support limiting on namespaces
        assert engine, "Must set the database engine"

        args = dict(bind=engine, autoflush=True, autocommit=False)
        self.ignore_soft_deletes = ignore_soft_deletes
        if ignore_soft_deletes:
            # Use the custom query class only when filtering soft deletes.
            args['query_cls'] = InboxQuery
        self._session = Session(**args)

        if versioned:
            # Imported here rather than at module level, as in the original.
            from inbox.models.transaction import create_revisions

            @event.listens_for(self._session, 'after_flush')
            def after_flush(session, flush_context):
                """
                Hook to log revision snapshots. Must be post-flush in order to
                grab object IDs on new objects.
                """
                create_revisions(session)

    def query(self, *args, **kwargs):
        """Build a query on the wrapped session.

        When ``ignore_soft_deletes`` is set, ``IgnoreSoftDeletesOption`` is
        applied to the query before it is returned.
        """
        q = self._session.query(*args, **kwargs)
        if self.ignore_soft_deletes:
            return q.options(IgnoreSoftDeletesOption())
        else:
            return q

    def add(self, instance):
        """Add ``instance`` to the wrapped session.

        Raises ``Exception`` if soft deletes are being ignored and the
        instance is flagged ``is_deleted``.
        """
        if not self.ignore_soft_deletes or not instance.is_deleted:
            self._session.add(instance)
        else:
            raise Exception("Why are you adding a deleted object?")

    def add_all(self, instances):
        """Add every object in ``instances`` to the wrapped session.

        ``instances`` may be any iterable, including a generator. Raises
        ``Exception`` if soft deletes are being ignored and any instance is
        flagged ``is_deleted``; in that case nothing is added.
        """
        # Materialize once. The soft-delete check below iterates the
        # collection, and would otherwise exhaust a generator before it
        # reached the underlying session, silently adding nothing.
        instances = list(instances)
        # Same test as add(): skip the is_deleted lookup entirely when
        # soft deletes are not being ignored, and judge by truthiness.
        if not self.ignore_soft_deletes or \
                not any(i.is_deleted for i in instances):
            self._session.add_all(instances)
        else:
            raise Exception("Why are you adding a deleted object?")

    def delete(self, instance):
        """Delete ``instance``.

        With ``ignore_soft_deletes`` set, the instance is marked deleted
        and re-added to the session; otherwise it is deleted through the
        wrapped session.
        """
        if self.ignore_soft_deletes:
            instance.mark_deleted()
            # just to make sure
            self._session.add(instance)
        else:
            self._session.delete(instance)

    def begin(self):
        """Call ``begin()`` on the wrapped session."""
        self._session.begin()

    def commit(self):
        """Call ``commit()`` on the wrapped session."""
        self._session.commit()

    def rollback(self):
        """Call ``rollback()`` on the wrapped session."""
        self._session.rollback()

    def flush(self):
        """Call ``flush()`` on the wrapped session."""
        self._session.flush()

    def close(self):
        """Call ``close()`` on the wrapped session."""
        self._session.close()

    def expunge(self, obj):
        """Call ``expunge(obj)`` on the wrapped session."""
        self._session.expunge(obj)

    def merge(self, obj):
        """Call ``merge(obj)`` on the wrapped session and return the result."""
        return self._session.merge(obj)

    @property
    def no_autoflush(self):
        """The wrapped session's ``no_autoflush`` attribute."""
        return self._session.no_autoflush