Code example #1
def upgrade():
    from sqlalchemy.orm.session import Session
    session = Session(bind=op.get_bind())

    # Find duplicates
    for vcs in session.query(Vcs).group_by(Vcs.repository_id, Vcs.revision).having(func.count(Vcs.id) > 1).all():
        print(vcs)
        # Find all vcs entries with this duplication
        dupes = session.query(Vcs).filter(Vcs.repository_id == vcs.repository_id).filter(Vcs.revision == vcs.revision).all()
        # Keep the first and remove the others - thus we need to update references to others to the first
        for update in dupes[1:]:
            for af in session.query(Artifakt).filter(Artifakt.vcs_id == update.id).all():
                print("Updating artifakt {} to point to vcs {}".format(af.sha1, dupes[0].id))
                af.vcs_id = dupes[0].id
            print("Deleting vcs  {}".format(update.id))
            session.delete(update)
    session.commit()

    if session.bind.dialect.name == "sqlite":
        session.execute("PRAGMA foreign_keys = OFF")
    elif session.bind.dialect.name == "mysql":
        session.execute("SET foreign_key_checks = 0")
    else:
        raise NotImplementedError

    with op.batch_alter_table('vcs', schema=None) as batch_op:
        batch_op.create_unique_constraint('rr', ['repository_id', 'revision'])

    if session.bind.dialect.name == "sqlite":
        session.execute("PRAGMA foreign_keys = ON")
    elif session.bind.dialect.name == "mysql":
        session.execute("SET foreign_key_checks = 1")
Code example #2
File: crawl.py Project: Sentimentron/sentropy
class DBBackedController(object):

    def __init__(self, engine, session=None):

        if isinstance(engine, types.StringTypes):
            logging.info("Using connection string '%s'" % (engine,))
            new_engine = create_engine(engine, encoding='utf-8')
            if "sqlite:" in engine:
                logging.debug("Setting text factory for unicode compat.")
                new_engine.raw_connection().connection.text_factory = str 
            self._engine = new_engine

        else:
            logging.info("Using existing engine...")
            self._engine = engine

        if session is None:
            logging.info("Binding session...")
            self._session = Session(bind=self._engine)
        else:
            self._session = session

        logging.info("Updating metadata...")
        Base.metadata.create_all(self._engine)

    def commit(self):
        logging.info("Commiting...")
        self._session.commit()
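
Either a connection string or a ready-made engine is accepted. A minimal usage sketch (the sqlite URL is illustrative):

controller = DBBackedController("sqlite:///crawl.db")  # builds its own engine
# or hand over an engine that already exists:
# controller = DBBackedController(create_engine("sqlite:///crawl.db"))
controller.commit()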
Code example #3
File: manager.py Project: BtbN/esaupload
def init_settings(session: Session):
    set_new_setting(session, "ffmpeg_path", "ffmpeg")
    set_new_setting(session, "ffmpeg_verbose", False)
    set_new_setting(session, "input_directory", "/set/me")
    set_new_setting(session, "output_directory", "esa_videos")
    set_new_setting(session, "final_output_directory", "esa_final_videos")
    set_new_setting(session, "common_keywords", "ESA")
    set_new_setting(session, "input_video_extension", ".flv")
    set_new_setting(session, "output_video_extension", ".mp4")
    set_new_setting(session, "output_video_prefix", "esa_")
    set_new_setting(session, "upload_logfile", "esa_upload.log")
    set_new_setting(session, "youtube_public", False)
    set_new_setting(session, "youtube_schedule", False)
    set_new_setting(session, "youtube_schedule_start_ts", 0)
    set_new_setting(session, "schedule_offset_hours", 12)
    set_new_setting(session, "upload_to_twitch", True)
    set_new_setting(session, "upload_to_youtube", True)
    set_new_setting(session, "delete_after_upload", True)
    set_new_setting(session, "upload_ratelimit", 2.5)
    set_new_setting(session, "upload_retry_count", 3)
    set_new_setting(session, "twitch_ul_pid", 0)
    set_new_setting(session, "twitch_ul_create_time", 0.0)
    set_new_setting(session, "youtube_ul_pid", 0)
    set_new_setting(session, "youtube_ul_create_time", 0.0)
    set_new_setting(session, "proc_ul_pid", 0)
    set_new_setting(session, "proc_ul_create_time", 0.0)
    set_new_setting(session, "init_proc_status", "")
    session.commit()
Code example #4
File: manager.py Project: BtbN/esaupload
def post_process_video(session: Session, debug):
    vids = session.query(Video)
    vids = vids.filter(Video.processing_done)
    vids = vids.filter(not_(Video.post_processing_done))
    vids = vids.filter(or_(and_(Video.do_twitch, Video.done_twitch), not_(Video.do_twitch)))
    vids = vids.filter(or_(and_(Video.do_youtube, Video.done_youtube), not_(Video.do_youtube)))
    vids = vids.order_by(Video.id.asc())
    vid = vids.first()

    if not vid:
        print("No video in need of processing found")
        return 0

    out_dir = get_setting(session, "output_directory")
    final_dir = get_setting(session, "final_output_directory")
    out_prefix = get_setting(session, "output_video_prefix")
    out_ext = get_setting(session, "output_video_extension")

    out_fname = "%s%s%s" % (out_prefix, vid.id, out_ext)
    out_path = os.path.join(out_dir, out_fname)
    final_path = os.path.join(final_dir, out_fname)

    if out_path != final_path:
        shutil.move(out_path, final_path)
        vid.post_processing_status = "Moved %s to %s" % (out_path, final_path)
    else:
        vid.post_processing_status = "Nothing to do"

    vid.post_processing_done = True
    session.commit()

    return 0
Code example #5
class TestSessions(TestCase):
    plugins = []

    def test_multiple_connections(self):
        self.session2 = Session(bind=self.engine.connect())
        article = self.Article(name=u'Session1 article')
        article2 = self.Article(name=u'Session2 article')
        self.session.add(article)
        self.session2.add(article2)
        self.session.flush()
        self.session2.flush()

        self.session.commit()
        self.session2.commit()
        assert article.versions[-1].transaction_id == 1
        assert article2.versions[-1].transaction_id == 2

    def test_manual_transaction_creation(self):
        uow = versioning_manager.unit_of_work(self.session)
        transaction = uow.create_transaction(self.session)
        self.session.flush()
        assert transaction.id == 1
        article = self.Article(name=u'Session1 article')
        self.session.add(article)
        self.session.flush()
        assert uow.current_transaction.id == 1

        self.session.commit()
        assert article.versions[-1].transaction_id == 1

    def test_commit_without_objects(self):
        self.session.commit()
Code example #6
File: interface.py Project: AlexWylie/marcotti
 def create_session(self):
     session = Session(self.connection)
     try:
         yield session
         session.commit()
     except Exception as ex:
         session.rollback()
         raise ex
     finally:
         session.close()
Code example #7
File: vmdatabase.py Project: bhuvan/devstack-gate
class VMDatabase(object):
    def __init__(self, path=os.path.expanduser("~/vm.db")):
        engine = create_engine('sqlite:///%s' % path, echo=False)
        metadata.create_all(engine)
        Session = sessionmaker(bind=engine, autoflush=True, autocommit=False)
        self.session = Session()

    def print_state(self):
        for provider in self.getProviders():
            print 'Provider:', provider.name
            for base_image in provider.base_images:
                print '  Base image:', base_image.name
                for snapshot_image in base_image.snapshot_images:
                    print '    Snapshot:', snapshot_image.name, \
                          snapshot_image.state
                for machine in base_image.machines:
                    print '    Machine:', machine.id, machine.name, \
                          machine.state, machine.state_time, machine.ip

    def abort(self):
        self.session.rollback()

    def commit(self):
        self.session.commit()

    def delete(self, obj):
        self.session.delete(obj)

    def getProviders(self):
        return self.session.query(Provider).all()

    def getProvider(self, name):
        return self.session.query(Provider).filter_by(name=name)[0]

    def getResult(self, id):
        return self.session.query(Result).filter_by(id=id)[0]

    def getMachine(self, id):
        return self.session.query(Machine).filter_by(id=id)[0]

    def getMachineByJenkinsName(self, name):
        return self.session.query(Machine).filter_by(jenkins_name=name)[0]

    def getMachineForUse(self, image_name):
        """Atomically find a machine that is ready for use, and update
        its state."""
        image = None
        for machine in self.session.query(Machine).filter(
            machine_table.c.state == READY).order_by(
            machine_table.c.state_time):
            if machine.base_image.name == image_name:
                machine.state = USED
                self.commit()
                return machine
        raise Exception("No machine found for image %s" % image_name)
Code example #8
def upgrade():
    context = op.get_context()
    session = Session()
    session.bind = context.bind
    for tp in session.query(TimeEntry).filter_by(customer_request_id=None):
        for trac in tp.project.tracs:
            cr = session.execute('select value from "trac_%s".ticket_custom where name=\'customerrequest\' and ticket=%s' % (trac.trac_name, tp.ticket)).fetchone()
            sql_cr = session.execute('select id from customer_requests where id=\'%s\'' % cr.value).fetchone()
            tp.customer_request_id = sql_cr.id
            print sql_cr.id
    session.commit()
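
The interpolated SQL above is fragile if a custom-field value ever contains a quote. A bound-parameter variant is sketched below; the schema name still has to be interpolated because bind parameters cannot name a schema:

from sqlalchemy import text

sql = ('select value from "trac_{}".ticket_custom '
       "where name = 'customerrequest' and ticket = :ticket").format(trac.trac_name)
cr = session.execute(text(sql), {"ticket": tp.ticket}).fetchone()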
Code example #9
def upgrade():
    ### commands auto generated by Alembic - please adjust! ###
    op.add_column('time_entries', sa.Column('tickettitle', sa.Unicode(), nullable=True))
    context = op.get_context()
    session = Session()
    session.bind = context.bind

    for tp in session.query(TimeEntry):
        for trac in tp.project.tracs:
            ticket = session.execute('select summary from "trac_%s".ticket where id=%s' % (trac.trac_name, tp.ticket)).fetchone()
            tp.tickettitle = ticket.summary
    session.commit()
Code example #10
File: mysql.py Project: liuzelei/walis
 def wrapper(*args, **kwargs):
     ret = func(*args, **kwargs)
     session = Session()
     # tmp
     session._model_changes = {}
     try:
         session.commit()
     except SQLAlchemyError as se:
         session.rollback()
         raise_server_exc(DATABASE_UNKNOWN_ERROR, exc=se)
     finally:
         session.close()
     return ret
Code example #11
def upgrade():
    context = op.get_context()
    session = Session()
    session.bind = context.bind
    session.execute(
    """ALTER TABLE applications DROP CONSTRAINT "applications_project_id_fkey", ADD CONSTRAINT "applications_project_id_fkey" foreign key (project_id) references projects(id) on update cascade;
       ALTER TABLE customer_requests DROP CONSTRAINT "customer_requests_project_id_fkey", ADD CONSTRAINT "customer_requests_project_id_fkey" foreign key (project_id) references projects(id) on update cascade;
       ALTER TABLE groups DROP CONSTRAINT "groups_project_id_fkey", ADD CONSTRAINT "groups_project_id_fkey" foreign key (project_id) references projects(id) on update cascade;
       ALTER TABLE contracts ADD CONSTRAINT "contracts_project_id_fkey" foreign key (project_id) references projects(id) on update cascade;
       ALTER TABLE kanban_projects DROP CONSTRAINT "kanban_projects_project_id_fkey", ADD CONSTRAINT "kanban_projects_project_id_fkey" foreign key (project_id) references projects(id) on update cascade on delete cascade;
       ALTER TABLE favorite_projects DROP CONSTRAINT "favorite_projects_project_id_fkey", ADD CONSTRAINT "favorite_projects_project_id_fkey" foreign key (project_id) references projects(id) on update cascade;
       ALTER TABLE time_entries DROP CONSTRAINT "time_entries_project_id_fkey", ADD CONSTRAINT "time_entries_project_id_fkey" foreign key (project_id) references projects(id) on update cascade;""")
    session.commit()
Code example #12
File: manager.py Project: BtbN/esaupload
def delete_video(session: Session, vid_id: int):
    vid = session.query(Video).get(vid_id)

    if not vid:
        print("Video ID not found")
        return -1

    print("Deleting video %s: %s" % (vid.id, vid.title))

    session.delete(vid)
    session.commit()

    print("Done")

    return 0
Code example #13
File: interface.py Project: soccermetrics/marcotti
    def create_session(self):
        """
        Create a session context that communicates with the database.

        Commits all changes to the database before closing the session, and if an exception is raised,
        rollback the session.
        """
        session = Session(self.connection)
        try:
            yield session
            session.commit()
        except Exception as ex:
            session.rollback()
            raise ex
        finally:
            session.close()
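
The bare yield suggests this method is wrapped with contextlib.contextmanager; if so, callers consume it as a with-block and never call commit, rollback, or close themselves. A sketch with a hypothetical Marcotti instance:

with marcotti.create_session() as session:
    session.add(some_record)
    # leaving the block commits; an exception rolls back and re-raises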
Code example #14
def upgrade():
    context = op.get_context()
    session = Session()
    session.bind = context.bind

    for b in session.query(KanbanBoard):
        if b.json:
            columns = loads(b.json)
            for n, column in enumerate(columns):
                for m, task in enumerate(column['tasks']):
                    new_id = update_id(task['id'])
                    print task['id'], new_id
                    if new_id != task['id']:
                        columns[n]['tasks'][m]['id'] = new_id
                        b.json = dumps(columns)
    session.commit()
Code example #15
def upgrade():
    ### commands auto generated by Alembic - please adjust! ###
    op.add_column('time_entries', sa.Column('customer_request_id', sa.String(), nullable=True))
    op.drop_column('time_entries', u'contract_id')
    op.drop_column('customer_requests', u'placement')

    context = op.get_context()
    session = Session()
    session.bind = context.bind

    for tp in session.query(TimeEntry):
        for trac in tp.project.tracs:
            cr = session.execute('select value from "trac_%s".ticket_custom where name=\'customerrequest\' and ticket=%s' % (trac.trac_name, tp.ticket)).fetchone()
            sql_cr = session.execute('select id from customer_requests where id=\'%s\'' % cr.value).fetchone()
            tp.customer_request_id = sql_cr.id
    session.commit()
Code example #16
File: datasets.py Project: paulofreitas/dtb-ibge
    def transaction(cls,
                    session: db_session.Session) -> Iterator[db_session.Session]:
        """
        Provides a transactional context-based database session.

        Args:
            session: The database session instance to wrap

        Yields:
            The wrapped database session instance
        """
        try:
            yield session
            session.commit()
        except:
            session.rollback()
            raise
        finally:
            session.close()
Code example #17
File: dbhelper.py Project: ericbean/RecordSheet
def setup_module():
    global transaction, connection, engine

    engine = create_engine('postgresql:///recordsheet_test')
    connection = engine.connect()
    transaction = connection.begin()
    Base.metadata.create_all(connection)

    #insert some data
    inner_tr = connection.begin_nested()
    ses = Session(connection)
#    ses.begin_nested()
    ses.add(dbmodel.Account(name='TEST01', desc='test account 01'))
    ses.add(dbmodel.Account(name='TEST02', desc='test account 02'))
    user = dbmodel.User(username='******', name='Test T. User',
                       password=dbapi.new_pw_hash('passtestword'),
                       locked=False)
    ses.add(user)
    lockeduser = dbmodel.User(username='******', name='Test T. User',
                       password=dbapi.new_pw_hash('passtestword'),
                       locked=True)
    ses.add(lockeduser)

    batch = dbmodel.Batch(user=user)
    ses.add(batch)
    jrnl = dbmodel.Journal(memo='test', batch=batch,
                datetime='2016-06-05 14:09:00-05')
    ses.add(jrnl)
    ses.add(dbmodel.Posting(memo="test", amount=100, account_id=1,
                        journal=jrnl))
    ses.add(dbmodel.Posting(memo="test", amount=-100, account_id=2,
                        journal=jrnl))
    ses.commit()
    # mock a sessionmaker so all querys are in this transaction
    dbapi._session = lambda: ses
    ses.begin_nested()

    @event.listens_for(ses, "after_transaction_end")
    def restart_savepoint(session, transaction):
        if transaction.nested and not transaction._parent.nested:
            ses.begin_nested()
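
A matching teardown would release everything setup_module created. A sketch, assuming the module-level names above are still in scope:

def teardown_module():
    # Roll back all test data and release the connection and engine.
    transaction.rollback()
    connection.close()
    engine.dispose()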
Code example #18
File: simple.py Project: devhub/baph
    def setup_databases(self, **kwargs):
        # import all models to populate orm.metadata
        for app in settings.INSTALLED_APPS:
            import_any_module(['%s.models' % app], raise_error=False)

        # determine which schemas we need
        default_schema = orm.engine.url.database
        schemas = set(t.schema or default_schema \
            for t in Base.metadata.tables.values())

        url = deepcopy(orm.engine.url)
        url.database = None
        self.engine = create_engine(url)
        insp = inspect(self.engine)

        # get a list of already-existing schemas
        existing_schemas = set(insp.get_schema_names())

        # if any of the needed schemas exist, do not proceed
        conflicts = schemas.intersection(existing_schemas)
        if conflicts:
            for c in conflicts:
                print 'drop schema %s;' % c
            sys.exit('The following schemas are already present: %s. ' \
                'TestRunner cannot proceed' % ','.join(conflicts))
        
        # create schemas
        session = Session(bind=self.engine)
        for schema in schemas:
            session.execute(CreateSchema(schema))
        session.commit()
        session.bind.dispose()

        # create tables
        if len(orm.Base.metadata.tables) > 0:
            orm.Base.metadata.create_all(checkfirst=False)

        # generate permissions
        call_command('createpermissions')

        return schemas
Code example #19
File: base.py Project: soccermetrics/marcotti-mls
    def create_session(self):
        """
        Create a session context that communicates with the database.

        Commits all changes to the database before closing the session, and if an exception is raised,
        rollback the session.
        """
        session = Session(self.connection)
        logger.info("Create session {0} with {1}".format(
            id(session), self._public_db_uri(str(self.engine.url))))
        try:
            yield session
            session.commit()
            logger.info("Commit transactions to database")
        except Exception:
            session.rollback()
            logger.exception("Database transactions rolled back")
        finally:
            logger.info("Session {0} with {1} closed".format(
                id(session), self._public_db_uri(str(self.engine.url))))
            session.close()
Code example #20
def upgrade(migrate_engine):
    """
    Upgrade operations go here.
    Don't create your own engine; bind migrate_engine to your metadata
    """
    #==========================================================================
    # USER LOGS
    #==========================================================================
    from rhodecode.lib.dbmigrate.schema.db_1_5_0 import UserLog
    tbl = UserLog.__table__
    username = Column("username", String(255, convert_unicode=False,
                                         assert_unicode=None), nullable=True,
                      unique=None, default=None)
    # create username column
    username.create(table=tbl)

    _Session = Session()
    ## after adding that column fix all usernames
    users_log = _Session.query(UserLog)\
            .options(joinedload(UserLog.user))\
            .options(joinedload(UserLog.repository)).all()

    for entry in users_log:
        entry.username = entry.user.username
        _Session.add(entry)
    _Session.commit()

    #alter username to not null
    from rhodecode.lib.dbmigrate.schema.db_1_5_0 import UserLog
    tbl_name = UserLog.__tablename__
    tbl = Table(tbl_name,
                MetaData(bind=migrate_engine), autoload=True,
                autoload_with=migrate_engine)
    col = tbl.columns.username

    # remove nullability from the username field
    col.alter(nullable=False)
Code example #21
def upgrade():
    bind = op.get_bind()
    session = Session(bind=bind)

    session.query(ConfigModel).filter(
        ConfigModel.group == 'station',
        ConfigModel.name == 'required_fields').update({
            'value':
            pickle.dumps([
                'code', 'start_date', 'latitude', 'longitude', 'creation_date',
                'elevation'
            ])
        })
    session.query(ConfigModel).filter(
        ConfigModel.group == 'channel',
        ConfigModel.name == 'required_fields').update({
            'value':
            pickle.dumps([
                'code', 'start_date', 'latitude', 'longitude', 'location_code',
                'elevation', 'depth'
            ])
        })

    session.commit()
Code example #22
    async def create(self, db: Session, *, order_in: OrderCreate) -> Order:

        new_order = Order(business_id=order_in.business_id)

        db.add(new_order)
        db.commit()
        db.refresh(new_order)  # reload server-generated fields such as the new id

        item_orders: List[ItemOrder] = []

        for item in order_in.items_in_order:

            item_db = await crud.item.get(db, id=item.item_id)

            if not item_db:
                raise HTTPException(status_code=404,
                                    detail=f"Item {item.item_id} not found")

            # Create a new relation ItemOrder

            item_order = ItemOrder(
                tot_price=item.tot_price,
                quantity=item.quantity,
                order_id=new_order.id,
                item_id=item_db.id,
            )

            item_orders.append(item_order)

        new_order.items = item_orders

        db.add(new_order)
        db.commit()
        db.refresh(new_order)

        return new_order
Code example #23
    def create(session: Session,
               state: TaskState,
               timestamp: datetime,
               repo_dir: str,
               task_type: TaskType,
               params: dict = None) -> Task:
        """ Creates a new Task record.

        Args:
            session (Session): The database session.
            state (TaskState): The task state.
            timestamp (datetime): The time when the task was created.
            repo_dir (str): The task repository.
            task_type (TaskType): The task type.
            params (dict, optional): The task parameters. Defaults to None.

        Raises:
            ValueError: Thrown when missing state, timestamp, repo_dir or task_type.
            TaskAlreadyExistsError: Thrown when a task with the same task_id already exists.

        Returns:
            Task: The task.
        """
        if state is None or not timestamp or not repo_dir or task_type is None:
            raise ValueError(
                "You cannot create a task without a status, a timestamp and a repo_dir."
            )
        try:
            params_1: str = json.dumps(params)
            task: Task = Task(state.value, timestamp, repo_dir,
                              task_type.value, params_1)
            session.add(task)
            session.commit()
            return task
        except IntegrityError as err:
            raise TaskAlreadyExistsError from err
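
A hypothetical call site; the TaskState and TaskType members and the repository path are illustrative:

from datetime import datetime

try:
    task = Task.create(session, TaskState.PENDING, datetime.utcnow(),
                       "/tmp/repo", TaskType.ANALYSIS, params={"depth": 1})
except TaskAlreadyExistsError:
    pass  # a task with the same task_id already exists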
Code example #24
    def update(
        cls,
        session: Session,
        id: int,
        params: typing.Dict[str, typing.Any],
        commit: bool = True,
        autoflush: bool = False,
    ) -> DeclarativeMeta:
        obj = cls.get(session, id)

        if obj:
            for param in params:

                setattr(obj, param, params[param])

            if commit:
                session.commit()

                session.refresh(obj)

            if autoflush:
                session.flush()

            return obj
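
A usage sketch, assuming a hypothetical User model that mixes in this classmethod:

user = User.update(session, id=42, params={"email": "new@example.com"})
if user is None:
    pass  # no row with that id exists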
Code example #25
 def create_comment(cls,
                    praw_comment: PrawComment,
                    post: Post,
                    session: Session,
                    download_session_id: int,
                    parent_comment_id: Optional[int] = None):
     if cls.check_duplicate_comment(praw_comment.id, session):
         author = cls.get_author(praw_comment, session)
         subreddit = cls.get_subreddit(praw_comment, session)
         comment = Comment(author=author,
                           subreddit=subreddit,
                           post=post,
                           reddit_id=praw_comment.id,
                           body=praw_comment.body,
                           body_html=praw_comment.body_html,
                           score=praw_comment.score,
                           date_posted=datetime.fromtimestamp(
                               praw_comment.created),
                           parent_id=parent_comment_id,
                           download_session_id=download_session_id)
         session.add(comment)
         session.commit()
         return comment
     return None
Code example #26
File: chat.py Project: Ariento89/digirent-api
 async def send_message_to_user(self,
                                user_id: UUID,
                                sender_id: UUID,
                                message: str,
                                session: Session = None):
     db_message = ChatMessage(from_user_id=sender_id,
                              to_user_id=user_id,
                              message=message)
     session.add(db_message)
     session.commit()
     from_chatuser = self.chat_users.get(sender_id)
     to_chatuser = self.chat_users.get(user_id)
     event = ChatEvent(
         event_type=ChatEventType.MESSAGE,
         data={
             "from": str(sender_id),
             "to": str(user_id),
             "message": message
         },
     )
     if from_chatuser:
         await from_chatuser.send_json(event.dict(by_alias=True))
     if to_chatuser:
         await to_chatuser.send_json(event.dict(by_alias=True))
Code example #27
    def updateHardwareProfile(self, session: Session,
                              hardwareProfileObject: HardwareProfile) -> None:
        """
        Update Hardware Profile Object
        """

        try:
            dbHardwareProfile = \
                self._hardwareProfilesDbHandler.getHardwareProfileById(
                    session, hardwareProfileObject.getId())

            self.__populateHardwareProfile(session, hardwareProfileObject,
                                           dbHardwareProfile)
            self._set_tags(dbHardwareProfile, hardwareProfileObject.getTags())
            session.commit()

        except TortugaException:
            session.rollback()
            raise

        except Exception as ex:
            session.rollback()
            self.getLogger().exception('%s' % ex)
            raise
Code example #28
 def test_check_auth_redis_miss_wrong_auth_id(self):
     db_handler = self.auth_handler.db_handler
     db_conn = db_handler.getEngine().connect()
     db_txn = db_conn.begin()
     try:
         db_session = Session(bind=db_conn)
         try:
             account = Account(auth_id='some_auth_id', username='******')
             db_session.add(account)
             db_session.flush()
             phonenumber = PhoneNumber(number='9740171794',
                                       account_id=account.id)
             db_session.add(phonenumber)
             db_session.commit()
             self.auth_handler.redis_client.hget.return_value = None
             status, phonenums = self.auth_handler._check_auth(
                 'some_user', 'faulty_auth_id', db_session)
             self.assertFalse(status)
             self.assertEquals(None, phonenums)
         finally:
             db_session.close()
     finally:
         db_txn.rollback()
         db_conn.close()
Code example #29
def migrate_archived_page_urls(session: Session, cur, table_archives: str):
    t0 = time.time()
    logger.info('Migrating %s', table_archives)
    n_in = 0
    n_out = 0
    cur.execute(
        f'SELECT feed_url, archived_url, date, inserted FROM {table_archives}')
    for feed_url, archived_url, date, inserted in cur:
        n_in += 1
        logger.debug('%d %s', n_in, archived_url)
        if not count(
                session.query(ArchivedPageURL).filter(
                    ArchivedPageURL.feed_url == feed_url,
                    ArchivedPageURL.archived_url == archived_url,
                    ArchivedPageURL.date == date,
                )):
            archived_page_url = ArchivedPageURL(
                feed_url=feed_url,
                archived_url=archived_url,
                date=date,
                inserted=inserted,
            )
            session.add(archived_page_url)
            if n_in % 10000 == 0:
                logger.info('%d flush', n_in)
                session.flush()
            n_out += 1
    logger.info('commit')
    session.commit()
    logger.info(
        'Migrated %s: %d -> %d in %ds',
        table_archives,
        n_in,
        n_out,
        time.time() - t0,
    )
Code example #30
def upgrade():
    from sqlalchemy.orm.session import Session
    session = Session(bind=op.get_bind())

    # Add dummy name where there is none
    for repo in session.query(Repository).all():
        if repo.name == "":
            repo.name = "NoName"
    session.commit()

    if session.bind.dialect.name == "sqlite":
        session.execute("PRAGMA foreign_keys = OFF")
    elif session.bind.dialect.name == "mysql":
        session.execute("SET foreign_key_checks = 0")
    else:
        raise NotImplementedError

    with op.batch_alter_table('repository', schema=None) as batch_op:
        batch_op.create_check_constraint('non_empty_name', 'name != ""')

    if session.bind.dialect.name == "sqlite":
        session.execute("PRAGMA foreign_keys = ON")
    elif session.bind.dialect.name == "mysql":
        session.execute("SET foreign_key_checks = 1")
Code example #31
File: 4e3738cdc34c_.py Project: mathbou/zou
def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column(
        "project_task_type_link",
        sa.Column("created_at", sa.DateTime(), nullable=True),
    )
    op.add_column(
        "project_task_type_link",
        sa.Column(
            "id",
            sqlalchemy_utils.types.uuid.UUIDType(binary=False),
            default=uuid.uuid4,
            nullable=True,
        ),
    )
    op.execute("UPDATE project_task_type_link SET id = '%s'" %
               fields.gen_uuid())
    op.add_column(
        "project_task_type_link",
        sa.Column("updated_at", sa.DateTime(), nullable=True),
    )
    op.add_column(
        "project_task_type_link",
        sa.Column("priority", sa.Integer(), nullable=True),
    )
    session = Session(bind=op.get_bind())
    for link in session.query(ProjectTaskTypeLink).all():
        link.id = fields.gen_uuid()
        session.add(link)
        session.commit()
    op.alter_column("project_task_type_link", "id", nullable=False)
    op.create_unique_constraint(
        "project_tasktype_uc",
        "project_task_type_link",
        ["project_id", "task_type_id"],
    )
Code example #32
    def deleteHardwareProfile(self, session: Session, name: str) -> None:
        """
        Delete hardwareProfile from the db.

            Returns:
                None
            Throws:
                HardwareProfileNotFound
                DbError
                TortugaException
        """

        try:
            hwProfile = self._hardwareProfilesDbHandler.getHardwareProfile(
                session, name)

            if hwProfile.nodes:
                raise TortugaException(
                    'Unable to delete hardware profile with associated nodes')

            # First delete the mappings
            hwProfile.mappedsoftwareprofiles = []

            self._logger.debug('Marking hardware profile [%s] for deletion' %
                               (name))

            session.delete(hwProfile)

            session.commit()
        except TortugaException:
            session.rollback()
            raise
        except Exception as ex:
            session.rollback()
            self._logger.exception(str(ex))
            raise
Code example #33
File: migrations.py Project: CIMAC-CIDC/cidc-api-gae
def migration_session():
    session = Session(bind=op.get_bind())
    task_queue = RollbackableQueue()

    try:
        yield session, task_queue
        print("Commiting SQL session...")
        session.commit()
        print("Session commit succeeded.")
    except Exception as e:
        print(f"Encountered exception: {e.__class__}\n{e}")
        print("Running SQL rollback...")
        session.rollback()
        print("SQL rollback succeeded.")
        if task_queue:
            try:
                print("Running GCS rollback...")
                task_queue.rollback()
                print("GCS rollback succeeded.")
            except Exception as e:
                print(f"GCS rollback failed: {e.__class__}\n{e}")
        raise
    finally:
        session.close()
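
Like the other generator-based helpers here, this is presumably decorated with contextlib.contextmanager, in which case a migration would consume it as a with-block (a sketch; the UPDATE is illustrative):

def upgrade():
    with migration_session() as (session, task_queue):
        # SQL changes commit on success and roll back on any exception;
        # queued GCS work is undone via task_queue.rollback().
        session.execute("UPDATE users SET email = lower(email)")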
Code example #34
    def update_import_errors(session: Session, dagbag: DagBag) -> None:
        """
        For the DAGs in the given DagBag, record any associated import errors and clear
        errors for files that no longer have them. These are usually displayed through the
        Airflow UI so that users know that there are issues parsing DAGs.

        :param session: session for ORM operations
        :type session: sqlalchemy.orm.session.Session
        :param dagbag: DagBag containing DAGs with import errors
        :type dagbag: airflow.DagBag
        """
        # Clear the errors of the processed files
        for dagbag_file in dagbag.file_last_changed:
            session.query(errors.ImportError).filter(
                errors.ImportError.filename.startswith(dagbag_file)).delete(
                    synchronize_session="fetch")

        # Add the errors of the processed files
        for filename, stacktrace in dagbag.import_errors.items():
            session.add(
                errors.ImportError(filename=filename,
                                   timestamp=timezone.utcnow(),
                                   stacktrace=stacktrace))
        session.commit()
Code example #35
File: kitDbApi.py Project: tprestegard/tortuga
    def deleteKit(self, session: Session, name, version, iteration, force=False):
        """
        Delete kit from the db.

        Raises:
            KitNotFound
            KitInUse
            DbError
        """

        try:
            self._kitsDbHandler.deleteKit(
                session, name, version, iteration, force=force)

            session.commit()
        except TortugaException:
            session.rollback()

            raise
        except Exception as ex:
            session.rollback()

            self._logger.exception(str(ex))
            raise
Code example #36
    def create(session: Session, repo_dir: str, issue_id: int, author: str,
               title: str, description: str, labels: list,
               is_pull_request: bool) -> Issue:
        """ Creates a new Issue record.

        Args:
            session (Session): The database session.
            repo_dir (str): The issue repository.
            issue_id (int): The issue identifier.
            author (str): The issue author.
            title (str): The issue title.
            description (str): The issue description.
            labels (list): The issue labels.
            is_pull_request (bool): If true the issue is a pull request, otherwise false.

        Raises:
            ValueError: Thrown when missing repo_dir, issue_id, author or title.
            IssueAlreadyExistsError:
                Thrown when an issue with the same issue_id and repo_dir already exists.

        Returns:
            Issue: The issue.
        """
        if not repo_dir or issue_id is None or not author or not title:
            raise ValueError("You cannot create an issue without a repo_dir, "
                             "an issue_id, an author and a title.")
        try:
            labels_1: str = json.dumps(labels)
            issue: Issue = Issue(repo_dir, issue_id, author, title,
                                 description, labels_1, is_pull_request)
            session.add(issue)
            session.commit()
            return issue
        except IntegrityError as err:
            session.rollback()
            raise IssueAlreadyExistsError from err
Code example #37
File: user.py Project: vddesai1871/hydrus
def add_token(request: LocalProxy, session: Session) -> str:
    """
    Create a new token for the user or return a
    valid existing token to the user.
    """
    token = None
    id_ = int(request.authorization['username'])
    try:
        token = session.query(Token).filter(Token.user_id == id_).one()
        present = datetime.now()
        present = present - token.timestamp
        if present > timedelta(minutes=45):
            update_token = '%030x' % randrange(16**30)
            token.id = update_token
            token.timestamp = datetime.now()
            session.commit()
    except NoResultFound:
        token = '%030x' % randrange(16**30)
        time = datetime.now()
        new_token = Token(user_id=id_, id=token, timestamp=time)
        session.add(new_token)
        session.commit()
        return token
    return token.id
Code example #38
    def addPartition(self, session: Session, partitionName: str,
                     softwareProfileName: str) -> None:
        """
        Add software profile partition.

            Returns:
                partitionId
            Throws:
                PartitionAlreadyExists
                SoftwareProfileNotFound
        """

        try:
            self._softwareProfilesDbHandler.addPartitionToSoftwareProfile(
                session, partitionName, softwareProfileName)

            session.commit()
        except TortugaException:
            session.rollback()
            raise
        except Exception as ex:
            session.rollback()
            self._logger.exception(str(ex))
            raise
Code example #39
def addCompartmentalizedComponent(componentId, compartmentId):
    session = Session()
    if not session.query(CompartmentalizedComponent).filter(
            CompartmentalizedComponent.component_id == componentId).filter(
                CompartmentalizedComponent.compartment_id ==
                compartmentId).count():
        try:
            component = session.query(Component).filter(
                Component.id == componentId).one()
        except:
            print "model does not exist in database"
            raise
        try:
            compartment = session.query(Compartment).filter(
                Compartment.id == compartmentId).one()
        except:
            print "compartment does not exist in database"
            raise
        cc = CompartmentalizedComponent(component_id=component.id,
                                        compartment_id=compartment.id)
        session.add(cc)
        session.commit()
        session.close()
        return cc
Code example #40
def addReactionMatrix(reactionId, compartmentalizedComponentId):
    session = Session()
    if not session.query(ReactionMatrix).filter(
            ReactionMatrix.reaction_id == reactionId).filter(
                ReactionMatrix.compartmentalized_component_id ==
                compartmentalizedComponentId).count():
        try:
            cc = session.query(CompartmentalizedComponent).filter(
                CompartmentalizedComponent.id == compartmentalizedComponentId).one()
        except:
            print "compartmentalized component does not exist in database"
            raise
        try:
            reaction = session.query(Reaction).filter(
                Reaction.id == reactionId).one()
        except:
            print "reaction does not exist in database"
            raise
        rm = ReactionMatrix(reaction_id=reaction.id,
                            compartmentalized_component_id=cc.id)
        session.add(rm)
        session.commit()
        session.close()
        return rm
Code example #41
def process_file(session: Session, file: File,
                 callback: Callable[[], None]) -> bool:
    if file.processing_started_at:
        return False

    # Claim this file by updating the `processing_started_at` timestamp in such
    # a way that it must not have been set before.
    processing_started_at = datetime.datetime.now(timezone.utc)
    result = session.execute(
        update(File.__table__)  # pylint: disable=no-member
        .where(File.id == file.id).where(
            File.processing_started_at.is_(None)).values(
                processing_started_at=processing_started_at))
    if result.rowcount == 0:
        return False
    # If we got this far, `file` is ours to process.
    try:
        session.begin_nested()
        callback()
        file.processing_started_at = processing_started_at
        file.processing_completed_at = datetime.datetime.now(timezone.utc)
        session.add(file)
        session.commit()
        return True
    except Exception as error:
        session.rollback()
        file.processing_started_at = processing_started_at
        file.processing_completed_at = datetime.datetime.now(timezone.utc)
        # Some errors stringify nicely, some don't (e.g. StopIteration) so we
        # have to format them.
        file.processing_error = str(error) or str(
            traceback.format_exception(error.__class__, error,
                                       error.__traceback__))
        if not isinstance(error, UserError):
            raise error
        return True
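
The UPDATE ... WHERE processing_started_at IS NULL acts as an atomic claim, so several workers can poll the same table without processing a file twice. A hypothetical driver loop (do_work is illustrative):

for f in session.query(File).filter(File.processing_started_at.is_(None)):
    if process_file(session, f, lambda f=f: do_work(f)):
        print("processed", f.id)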
Code example #42
def upgrade():
    session = Session(bind=op.get_bind())
    changes = session.query(Sticker).filter(Sticker.file_id.is_(None)).delete()
    session.commit()
Code example #43
class DatabaseTest(object):

    engine = None
    connection = None

    @classmethod
    def get_database_connection(cls):
        url = Configuration.database_url()
        engine, connection = SessionManager.initialize(url)

        return engine, connection

    @classmethod
    def setup_class(cls):
        # Initialize a temporary data directory.
        cls.engine, cls.connection = cls.get_database_connection()
        cls.old_data_dir = Configuration.data_directory
        cls.tmp_data_dir = tempfile.mkdtemp(dir="/tmp")
        Configuration.instance[Configuration.DATA_DIRECTORY] = cls.tmp_data_dir

        # Avoid CannotLoadConfiguration errors related to CDN integrations.
        Configuration.instance[
            Configuration.INTEGRATIONS] = Configuration.instance.get(
                Configuration.INTEGRATIONS, {})
        Configuration.instance[Configuration.INTEGRATIONS][
            ExternalIntegration.CDN] = {}

    @classmethod
    def teardown_class(cls):
        # Destroy the database connection and engine.
        cls.connection.close()
        cls.engine.dispose()

        if cls.tmp_data_dir.startswith("/tmp"):
            logging.debug("Removing temporary directory %s" % cls.tmp_data_dir)
            shutil.rmtree(cls.tmp_data_dir)

        else:
            logging.warn(
                "Cowardly refusing to remove 'temporary' directory %s" %
                cls.tmp_data_dir)

        Configuration.instance[Configuration.DATA_DIRECTORY] = cls.old_data_dir

    def setup(self, mock_search=True):
        # Create a new connection to the database.
        self._db = Session(self.connection)
        self.transaction = self.connection.begin_nested()

        # Start with a high number so it won't interfere with tests that search for an age or grade
        self.counter = 2000

        self.time_counter = datetime(2014, 1, 1)
        self.isbns = [
            "9780674368279", "0636920028468", "9781936460236", "9780316075978"
        ]
        if mock_search:
            self.search_mock = mock.patch(
                external_search.__name__ + ".ExternalSearchIndex",
                MockExternalSearchIndex)
            self.search_mock.start()
        else:
            self.search_mock = None

        # TODO:  keeping this for now, but need to fix it bc it hits _isbn,
        # which pops an isbn off the list and messes tests up.  so exclude
        # _ functions from participating.
        # also attempt to stop nosetest showing docstrings instead of function names.
        #for name, obj in inspect.getmembers(self):
        #    if inspect.isfunction(obj) and obj.__name__.startswith('test_'):
        #        obj.__doc__ = None

    def teardown(self):
        # Close the session.
        self._db.close()

        # Roll back all database changes that happened during this
        # test, whether in the session that was just closed or some
        # other session.
        self.transaction.rollback()

        # Remove any database objects cached in the model classes but
        # associated with the now-rolled-back session.
        Collection.reset_cache()
        ConfigurationSetting.reset_cache()
        DataSource.reset_cache()
        DeliveryMechanism.reset_cache()
        ExternalIntegration.reset_cache()
        Genre.reset_cache()
        Library.reset_cache()

        # Also roll back any record of those changes in the
        # Configuration instance.
        for key in [
                Configuration.SITE_CONFIGURATION_LAST_UPDATE,
                Configuration.LAST_CHECKED_FOR_SITE_CONFIGURATION_UPDATE
        ]:
            if key in Configuration.instance:
                del (Configuration.instance[key])

        if self.search_mock:
            self.search_mock.stop()

    def shortDescription(self):
        return None  # Stop nosetests displaying docstrings instead of class names when verbosity level >= 2.

    @property
    def _id(self):
        self.counter += 1
        return self.counter

    @property
    def _str(self):
        return unicode(self._id)

    @property
    def _time(self):
        v = self.time_counter
        self.time_counter = self.time_counter + timedelta(days=1)
        return v

    @property
    def _isbn(self):
        return self.isbns.pop()

    @property
    def _url(self):
        return "http://foo.com/" + self._str

    def _patron(self, external_identifier=None, library=None):
        external_identifier = external_identifier or self._str
        library = library or self._default_library
        return get_one_or_create(self._db,
                                 Patron,
                                 external_identifier=external_identifier,
                                 library=library)[0]

    def _contributor(self, sort_name=None, name=None, **kw_args):
        name = sort_name or name or self._str
        return get_one_or_create(self._db,
                                 Contributor,
                                 sort_name=unicode(name),
                                 **kw_args)

    def _identifier(self,
                    identifier_type=Identifier.GUTENBERG_ID,
                    foreign_id=None):
        if foreign_id:
            id = foreign_id
        else:
            id = self._str
        return Identifier.for_foreign_id(self._db, identifier_type, id)[0]

    def _edition(self,
                 data_source_name=DataSource.GUTENBERG,
                 identifier_type=Identifier.GUTENBERG_ID,
                 with_license_pool=False,
                 with_open_access_download=False,
                 title=None,
                 language="eng",
                 authors=None,
                 identifier_id=None,
                 series=None,
                 collection=None):
        id = identifier_id or self._str
        source = DataSource.lookup(self._db, data_source_name)
        wr = Edition.for_foreign_id(self._db, source, identifier_type, id)[0]
        if not title:
            title = self._str
        wr.title = unicode(title)
        wr.medium = Edition.BOOK_MEDIUM
        if series:
            wr.series = series
        if language:
            wr.language = language
        if authors is None:
            authors = self._str
        if isinstance(authors, basestring):
            authors = [authors]
        if authors != []:
            wr.add_contributor(unicode(authors[0]),
                               Contributor.PRIMARY_AUTHOR_ROLE)
            wr.author = unicode(authors[0])
        for author in authors[1:]:
            wr.add_contributor(unicode(author), Contributor.AUTHOR_ROLE)

        if with_license_pool or with_open_access_download:
            pool = self._licensepool(
                wr,
                data_source_name=data_source_name,
                with_open_access_download=with_open_access_download,
                collection=collection)

            pool.set_presentation_edition()
            return wr, pool
        return wr

    def _work(self,
              title=None,
              authors=None,
              genre=None,
              language=None,
              audience=None,
              fiction=True,
              with_license_pool=False,
              with_open_access_download=False,
              quality=0.5,
              series=None,
              presentation_edition=None,
              collection=None,
              data_source_name=None):
        """Create a Work.

        For performance reasons, this method does not generate OPDS
        entries or calculate a presentation edition for the new
        Work. Tests that rely on this information being present
        should call _slow_work() instead, which takes more care to present
        the sort of Work that would be created in a real environment.
        """
        pools = []
        if with_open_access_download:
            with_license_pool = True
        language = language or "eng"
        title = unicode(title or self._str)
        audience = audience or Classifier.AUDIENCE_ADULT
        if audience == Classifier.AUDIENCE_CHILDREN and not data_source_name:
            # TODO: This is necessary because Gutenberg's childrens books
            # get filtered out at the moment.
            data_source_name = DataSource.OVERDRIVE
        elif not data_source_name:
            data_source_name = DataSource.GUTENBERG
        if fiction is None:
            fiction = True
        new_edition = False
        if not presentation_edition:
            new_edition = True
            presentation_edition = self._edition(
                title=title,
                language=language,
                authors=authors,
                with_license_pool=with_license_pool,
                with_open_access_download=with_open_access_download,
                data_source_name=data_source_name,
                series=series,
                collection=collection,
            )
            if with_license_pool:
                presentation_edition, pool = presentation_edition
                if with_open_access_download:
                    pool.open_access = True
                pools = [pool]
        else:
            pools = presentation_edition.license_pools
        work, ignore = get_one_or_create(self._db,
                                         Work,
                                         create_method_kwargs=dict(
                                             audience=audience,
                                             fiction=fiction,
                                             quality=quality),
                                         id=self._id)
        if genre:
            if not isinstance(genre, Genre):
                genre, ignore = Genre.lookup(self._db, genre, autocreate=True)
            work.genres = [genre]
        work.random = 0.5
        work.set_presentation_edition(presentation_edition)

        if pools:
            # make sure the pool's presentation_edition is set,
            # bc loan tests assume that.
            if not work.license_pools:
                for pool in pools:
                    work.license_pools.append(pool)

            for pool in pools:
                pool.set_presentation_edition()

            # This is probably going to be used in an OPDS feed, so
            # fake that the work is presentation ready.
            work.presentation_ready = True
            work.calculate_opds_entries(verbose=False)

        return work

    def add_to_materialized_view(self, works, true_opds=False):
        """Make sure all the works in `works` show up in the materialized view.

        :param true_opds: Generate real OPDS entries for each each work,
        rather than faking it.
        """
        if not isinstance(works, list):
            works = [works]
        for work in works:
            if true_opds:
                work.calculate_opds_entries(verbose=False)
            else:
                work.presentation_ready = True
                work.simple_opds_entry = "<entry>an entry</entry>"
        self._db.commit()
        SessionManager.refresh_materialized_views(self._db)

    def _lane(self,
              display_name=None,
              library=None,
              parent=None,
              genres=None,
              languages=None,
              fiction=None):
        display_name = display_name or self._str
        library = library or self._default_library
        lane, is_new = get_one_or_create(
            self._db,
            Lane,
            library=library,
            parent=parent,
            display_name=display_name,
            create_method_kwargs=dict(fiction=fiction))
        if is_new and parent:
            lane.priority = len(parent.sublanes) - 1
        if genres:
            if not isinstance(genres, list):
                genres = [genres]
            for genre in genres:
                if isinstance(genre, basestring):
                    genre, ignore = Genre.lookup(self._db, genre)
                lane.genres.append(genre)
        if languages:
            if not isinstance(languages, list):
                languages = [languages]
            lane.languages = languages
        return lane

    def _slow_work(self, *args, **kwargs):
        """Create a work that closely resembles one that might be found in the
        wild.

        This is significantly slower than _work() but more reliable.
        """
        work = self._work(*args, **kwargs)
        work.calculate_presentation_edition()
        work.calculate_opds_entries(verbose=False)
        return work

    def _add_generic_delivery_mechanism(self, license_pool):
        """Give a license pool a generic non-open-access delivery mechanism."""
        data_source = license_pool.data_source
        identifier = license_pool.identifier
        content_type = Representation.EPUB_MEDIA_TYPE
        drm_scheme = DeliveryMechanism.NO_DRM
        LicensePoolDeliveryMechanism.set(data_source, identifier, content_type,
                                         drm_scheme, RightsStatus.IN_COPYRIGHT)

    def _coverage_record(
        self,
        edition,
        coverage_source,
        operation=None,
        status=CoverageRecord.SUCCESS,
        collection=None,
        exception=None,
    ):
        if isinstance(edition, Identifier):
            identifier = edition
        else:
            identifier = edition.primary_identifier
        record, ignore = get_one_or_create(self._db,
                                           CoverageRecord,
                                           identifier=identifier,
                                           data_source=coverage_source,
                                           operation=operation,
                                           collection=collection,
                                           create_method_kwargs=dict(
                                               timestamp=datetime.utcnow(),
                                               status=status,
                                               exception=exception,
                                           ))
        return record

    def _work_coverage_record(self,
                              work,
                              operation=None,
                              status=CoverageRecord.SUCCESS):
        record, ignore = get_one_or_create(self._db,
                                           WorkCoverageRecord,
                                           work=work,
                                           operation=operation,
                                           create_method_kwargs=dict(
                                               timestamp=datetime.utcnow(),
                                               status=status,
                                           ))
        return record

    def _licensepool(self,
                     edition,
                     open_access=True,
                     data_source_name=DataSource.GUTENBERG,
                     with_open_access_download=False,
                     set_edition_as_presentation=False,
                     collection=None):
        source = DataSource.lookup(self._db, data_source_name)
        if not edition:
            edition = self._edition(data_source_name)
        collection = collection or self._default_collection
        pool, ignore = get_one_or_create(
            self._db,
            LicensePool,
            create_method_kwargs=dict(open_access=open_access),
            identifier=edition.primary_identifier,
            data_source=source,
            collection=collection,
            availability_time=datetime.utcnow())

        if set_edition_as_presentation:
            pool.presentation_edition = edition

        if with_open_access_download:
            pool.open_access = True
            url = "http://foo.com/" + self._str
            media_type = MediaTypes.EPUB_MEDIA_TYPE
            link, new = pool.identifier.add_link(
                Hyperlink.OPEN_ACCESS_DOWNLOAD, url, source, media_type)

            # Add a DeliveryMechanism for this download
            pool.set_delivery_mechanism(
                media_type,
                DeliveryMechanism.NO_DRM,
                RightsStatus.GENERIC_OPEN_ACCESS,
                link.resource,
            )

            representation, is_new = self._representation(url,
                                                          media_type,
                                                          "Dummy content",
                                                          mirrored=True)
            link.resource.representation = representation
        else:

            # Add a DeliveryMechanism for this licensepool
            pool.set_delivery_mechanism(MediaTypes.EPUB_MEDIA_TYPE,
                                        DeliveryMechanism.ADOBE_DRM,
                                        RightsStatus.UNKNOWN, None)
            pool.licenses_owned = pool.licenses_available = 1

        return pool

    def _representation(self,
                        url=None,
                        media_type=None,
                        content=None,
                        mirrored=False):
        url = url or "http://foo.com/" + self._str
        repr, is_new = get_one_or_create(self._db, Representation, url=url)
        repr.media_type = media_type
        if media_type and content:
            repr.content = content
            repr.fetched_at = datetime.utcnow()
            if mirrored:
                repr.mirror_url = "http://foo.com/" + self._str
                repr.mirrored_at = datetime.utcnow()
        return repr, is_new

    def _customlist(self,
                    foreign_identifier=None,
                    name=None,
                    data_source_name=DataSource.NYT,
                    num_entries=1,
                    entries_exist_as_works=True):
        data_source = DataSource.lookup(self._db, data_source_name)
        foreign_identifier = foreign_identifier or self._str
        now = datetime.utcnow()
        customlist, ignore = get_one_or_create(
            self._db,
            CustomList,
            create_method_kwargs=dict(
                created=now,
                updated=now,
                name=name or self._str,
                description=self._str,
            ),
            data_source=data_source,
            foreign_identifier=foreign_identifier)

        editions = []
        for i in range(num_entries):
            if entries_exist_as_works:
                work = self._work(with_open_access_download=True)
                edition = work.presentation_edition
            else:
                edition = self._edition(data_source_name, title="Item %s" % i)
                edition.permanent_work_id = "Permanent work ID %s" % self._str
            customlist.add_entry(edition,
                                 "Annotation %s" % i,
                                 first_appearance=now)
            editions.append(edition)
        return customlist, editions

    def _complaint(self, license_pool, type, source, detail, resolved=None):
        complaint, is_new = Complaint.register(license_pool, type, source,
                                               detail, resolved)
        return complaint

    def _credential(self,
                    data_source_name=DataSource.GUTENBERG,
                    type=None,
                    patron=None):
        data_source = DataSource.lookup(self._db, data_source_name)
        type = type or self._str
        patron = patron or self._patron()
        credential, is_new = Credential.persistent_token_create(
            self._db, data_source, type, patron)
        return credential

    def _external_integration(self,
                              protocol,
                              goal=None,
                              settings=None,
                              libraries=None,
                              **kwargs):
        integration = None
        if not libraries:
            integration, ignore = get_one_or_create(self._db,
                                                    ExternalIntegration,
                                                    protocol=protocol,
                                                    goal=goal)
        else:
            if not isinstance(libraries, list):
                libraries = [libraries]

            # Try to find an existing integration for one of the given
            # libraries.
            for library in libraries:
                integration = ExternalIntegration.lookup(self._db,
                                                         protocol,
                                                         goal,
                                                         library=libraries[0])
                if integration:
                    break

            if not integration:
                # Otherwise, create a brand new integration specifically
                # for the library.
                integration = ExternalIntegration(
                    protocol=protocol,
                    goal=goal,
                )
                integration.libraries.extend(libraries)
                self._db.add(integration)

        for attr, value in kwargs.items():
            setattr(integration, attr, value)

        settings = settings or dict()
        for key, value in settings.items():
            integration.set_setting(key, value)

        return integration

    def _delegated_patron_identifier(
            self,
            library_uri=None,
            patron_identifier=None,
            identifier_type=DelegatedPatronIdentifier.ADOBE_ACCOUNT_ID,
            identifier=None):
        """Create a sample DelegatedPatronIdentifier"""
        library_uri = library_uri or self._url
        patron_identifier = patron_identifier or self._str
        if callable(identifier):
            make_id = identifier
        else:
            if not identifier:
                identifier = self._str

            def make_id():
                return identifier

        patron, is_new = DelegatedPatronIdentifier.get_one_or_create(
            self._db, library_uri, patron_identifier, identifier_type, make_id)
        return patron

    def _sample_ecosystem(self):
        """ Creates an ecosystem of some sample work, pool, edition, and author
        objects that all know each other.
        """
        # make some authors
        [bob], ignore = Contributor.lookup(self._db, u"Bitshifter, Bob")
        bob.family_name, bob.display_name = bob.default_names()
        [alice], ignore = Contributor.lookup(self._db, u"Adder, Alice")
        alice.family_name, alice.display_name = alice.default_names()

        edition_std_ebooks, pool_std_ebooks = self._edition(
            DataSource.STANDARD_EBOOKS,
            Identifier.URI,
            with_license_pool=True,
            with_open_access_download=True,
            authors=[])
        edition_std_ebooks.title = u"The Standard Ebooks Title"
        edition_std_ebooks.subtitle = u"The Standard Ebooks Subtitle"
        edition_std_ebooks.add_contributor(alice, Contributor.AUTHOR_ROLE)

        edition_git, pool_git = self._edition(DataSource.PROJECT_GITENBERG,
                                              Identifier.GUTENBERG_ID,
                                              with_license_pool=True,
                                              with_open_access_download=True,
                                              authors=[])
        edition_git.title = u"The GItenberg Title"
        edition_git.subtitle = u"The GItenberg Subtitle"
        edition_git.add_contributor(bob, Contributor.AUTHOR_ROLE)
        edition_git.add_contributor(alice, Contributor.AUTHOR_ROLE)

        edition_gut, pool_gut = self._edition(DataSource.GUTENBERG,
                                              Identifier.GUTENBERG_ID,
                                              with_license_pool=True,
                                              with_open_access_download=True,
                                              authors=[])
        edition_gut.title = u"The GUtenberg Title"
        edition_gut.subtitle = u"The GUtenberg Subtitle"
        edition_gut.add_contributor(bob, Contributor.AUTHOR_ROLE)

        work = self._work(presentation_edition=edition_git)

        for p in pool_gut, pool_std_ebooks:
            work.license_pools.append(p)

        work.calculate_presentation()

        return (work, pool_std_ebooks, pool_git, pool_gut, edition_std_ebooks,
                edition_git, edition_gut, alice, bob)

    def print_database_instance(self):
        """
        Calls the class method that examines the current state of the database model
        (whether it's been committed or not).

        NOTE:  If you set_trace, and hit "continue", you'll start seeing console output right
        away, without waiting for the whole test to run and the standard output section to display.
        You can also use nosetests --nocapture.
        I use:
        def test_name(self):
            [code...]
            set_trace()
            self.print_database_instance()  # TODO: remove before prod
            [code...]
        """
        if 'TESTING' not in os.environ:
            # we are on production, abort, abort!
            logging.warning(
                "Forgot to remove call to testing.py:DatabaseTest.print_database_instance() before pushing to production."
            )
            return

        DatabaseTest.print_database_class(self._db)
        return

    @classmethod
    def print_database_class(cls, db_connection):
        """
        Prints to the console the entire contents of the database, as the unit test sees it.
        Exists because unit tests don't persist db information: they create an in-memory
        representation of the db state and then roll the test-derived transactions back.
        So we cannot see what's going on by going into postgres and running selects.
        This is the in-test alternative to going into postgres.

        Can be called from model and metadata classes as well as tests.

        NOTE: The purpose of this method is for debugging.
        Be careful of leaving it in code and potentially outputting
        vast tracts of data into your output stream on production.

        Call like this:
        set_trace()
        from testing import (
            DatabaseTest,
        )
        _db = Session.object_session(self)
        DatabaseTest.print_database_class(_db)  # TODO: remove before prod
        """
        if 'TESTING' not in os.environ:
            # we are on production, abort, abort!
            logging.warning(
                "Forgot to remove call to testing.py:DatabaseTest.print_database_class() before pushing to production."
            )
            return

        works = db_connection.query(Work).all()
        identifiers = db_connection.query(Identifier).all()
        license_pools = db_connection.query(LicensePool).all()
        editions = db_connection.query(Edition).all()
        data_sources = db_connection.query(DataSource).all()
        representations = db_connection.query(Representation).all()

        if (not works):
            print "NO Work found"
        for wCount, work in enumerate(works):
            # pipe character at end of line helps see whitespace issues
            print "Work[%s]=%s|" % (wCount, work)

            if (not work.license_pools):
                print "    NO Work.LicensePool found"
            for lpCount, license_pool in enumerate(work.license_pools):
                print "    Work.LicensePool[%s]=%s|" % (lpCount, license_pool)

            print "    Work.presentation_edition=%s|" % work.presentation_edition

        print "__________________________________________________________________\n"
        if (not identifiers):
            print "NO Identifier found"
        for iCount, identifier in enumerate(identifiers):
            print "Identifier[%s]=%s|" % (iCount, identifier)
            print "    Identifier.licensed_through=%s|" % identifier.licensed_through

        print "__________________________________________________________________\n"
        if (not license_pools):
            print "NO LicensePool found"
        for index, license_pool in enumerate(license_pools):
            print "LicensePool[%s]=%s|" % (index, license_pool)
            print "    LicensePool.work_id=%s|" % license_pool.work_id
            print "    LicensePool.data_source_id=%s|" % license_pool.data_source_id
            print "    LicensePool.identifier_id=%s|" % license_pool.identifier_id
            print "    LicensePool.presentation_edition_id=%s|" % license_pool.presentation_edition_id
            print "    LicensePool.superceded=%s|" % license_pool.superceded
            print "    LicensePool.suppressed=%s|" % license_pool.suppressed

        print "__________________________________________________________________\n"
        if (not editions):
            print "NO Edition found"
        for index, edition in enumerate(editions):
            # pipe character at end of line helps see whitespace issues
            print "Edition[%s]=%s|" % (index, edition)
            print "    Edition.primary_identifier_id=%s|" % edition.primary_identifier_id
            print "    Edition.permanent_work_id=%s|" % edition.permanent_work_id
            if (edition.data_source):
                print "    Edition.data_source.id=%s|" % edition.data_source.id
                print "    Edition.data_source.name=%s|" % edition.data_source.name
            else:
                print "    No Edition.data_source."
            if (edition.license_pool):
                print "    Edition.license_pool.id=%s|" % edition.license_pool.id
            else:
                print "    No Edition.license_pool."

            print "    Edition.title=%s|" % edition.title
            print "    Edition.author=%s|" % edition.author
            if (not edition.author_contributors):
                print "    NO Edition.author_contributor found"
            for acCount, author_contributor in enumerate(
                    edition.author_contributors):
                print "    Edition.author_contributor[%s]=%s|" % (
                    acCount, author_contributor)

        print "__________________________________________________________________\n"
        if (not data_sources):
            print "NO DataSource found"
        for index, data_source in enumerate(data_sources):
            print "DataSource[%s]=%s|" % (index, data_source)
            print "    DataSource.id=%s|" % data_source.id
            print "    DataSource.name=%s|" % data_source.name
            print "    DataSource.offers_licenses=%s|" % data_source.offers_licenses
            print "    DataSource.editions=%s|" % data_source.editions
            print "    DataSource.license_pools=%s|" % data_source.license_pools
            print "    DataSource.links=%s|" % data_source.links

        print "__________________________________________________________________\n"
        if (not representations):
            print "NO Representation found"
        for index, representation in enumerate(representations):
            print "Representation[%s]=%s|" % (index, representation)
            print "    Representation.id=%s|" % representation.id
            print "    Representation.url=%s|" % representation.url
            print "    Representation.mirror_url=%s|" % representation.mirror_url
            print "    Representation.fetch_exception=%s|" % representation.fetch_exception
            print "    Representation.mirror_exception=%s|" % representation.mirror_exception

        return

    def _library(self, name=None, short_name=None):
        name = name or self._str
        short_name = short_name or self._str
        library, ignore = get_one_or_create(
            self._db,
            Library,
            name=name,
            short_name=short_name,
            create_method_kwargs=dict(uuid=str(uuid.uuid4())),
        )
        return library

    def _collection(self,
                    name=None,
                    protocol=ExternalIntegration.OPDS_IMPORT,
                    external_account_id=None,
                    url=None,
                    username=None,
                    password=None,
                    data_source_name=None):
        name = name or self._str
        collection, ignore = get_one_or_create(self._db, Collection, name=name)
        collection.external_account_id = external_account_id
        integration = collection.create_external_integration(protocol)
        integration.goal = ExternalIntegration.LICENSE_GOAL
        integration.url = url
        integration.username = username
        integration.password = password

        if data_source_name:
            collection.data_source = data_source_name
        return collection

    @property
    def _default_library(self):
        """A Library that will only be created once throughout a given test.

        By default, the `_default_collection` will be associated with
        the default library.
        """
        if not hasattr(self, '_default__library'):
            self._default__library = self.make_default_library(self._db)
        return self._default__library

    @property
    def _default_collection(self):
        """A Collection that will only be created once throughout
        a given test.

        For most tests there's no need to create a different
        Collection for every LicensePool. Using
        self._default_collection instead of calling self.collection()
        saves time.
        """
        if not hasattr(self, '_default__collection'):
            self._default__collection = self._default_library.collections[0]
        return self._default__collection

    @classmethod
    def make_default_library(cls, _db):
        """Ensure that the default library exists in the given database.

        This can be called by code intended for use in testing but not actually
        within a DatabaseTest subclass.
        """
        library, ignore = get_one_or_create(_db,
                                            Library,
                                            create_method_kwargs=dict(
                                                uuid=unicode(uuid.uuid4()),
                                                name="default",
                                            ),
                                            short_name="default")
        collection, ignore = get_one_or_create(_db,
                                               Collection,
                                               name="Default Collection")
        integration = collection.create_external_integration(
            ExternalIntegration.OPDS_IMPORT)
        integration.goal = ExternalIntegration.LICENSE_GOAL
        if collection not in library.collections:
            library.collections.append(collection)
        return library
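
    # A minimal usage sketch (hypothetical) for calling this helper from
    # setup code that is not itself a DatabaseTest subclass, given an
    # existing SQLAlchemy session `_db`:
    #
    #     from testing import DatabaseTest
    #     library = DatabaseTest.make_default_library(_db)
    #     assert library.short_name == "default"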

    def _catalog(self, name=u"Faketown Public Library"):
        source, ignore = get_one_or_create(self._db, DataSource, name=name)

    def _integration_client(self, url=None, shared_secret=None):
        url = url or self._url
        secret = shared_secret or u"secret"
        return get_one_or_create(self._db,
                                 IntegrationClient,
                                 shared_secret=secret,
                                 create_method_kwargs=dict(url=url))[0]

    def _subject(self, type, identifier):
        return get_one_or_create(self._db,
                                 Subject,
                                 type=type,
                                 identifier=identifier)[0]

    def _classification(self, identifier, subject, data_source, weight=1):
        return get_one_or_create(self._db,
                                 Classification,
                                 identifier=identifier,
                                 subject=subject,
                                 data_source=data_source,
                                 weight=weight)[0]

    def sample_cover_path(self, name):
        """The path to the sample cover with the given filename."""
        base_path = os.path.split(__file__)[0]
        resource_path = os.path.join(base_path, "tests", "files", "covers")
        sample_cover_path = os.path.join(resource_path, name)
        return sample_cover_path

    def sample_cover_representation(self, name):
        """A Representation of the sample cover with the given filename."""
        sample_cover_path = self.sample_cover_path(name)
        return self._representation(media_type="image/png",
                                    content=open(sample_cover_path).read())[0]
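
A hypothetical call for the two cover helpers above, assuming a PNG named "cover.png" has been placed under tests/files/covers (the filename is illustrative, not part of the original listing):

    cover = self.sample_cover_representation("cover.png")
    assert cover.media_type == "image/png"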
Code example #44
0
File: cli.py Project: beenje/quetz
def _fill_test_database(db: Session) -> None:
    """Create dummy users and channels to allow further testing in dev mode."""
    from quetz.dao import Dao

    test_users = []
    dao = Dao(db)
    try:
        for index, username in enumerate(['alice', 'bob', 'carol', 'dave']):
            user = dao.create_user_with_role(username)

            identity = Identity(
                provider='dummy',
                identity_id=str(index),
                username=username,
            )

            profile = Profile(name=username.capitalize(),
                              avatar_url='/avatar.jpg')

            user.identities.append(identity)  # type: ignore
            user.profile = profile
            db.add(user)
            test_users.append(user)

        for channel_index in range(3):
            channel = Channel(
                name=f'channel{channel_index}',
                description=f'Description of channel{channel_index}',
                private=False,
            )

            for package_index in range(random.randint(5, 10)):
                package = Package(
                    name=f'package{package_index}',
                    summary=f'package {package_index} summary text',
                    description=f'Description of package{package_index}',
                )
                channel.packages.append(package)  # type: ignore

                test_user = test_users[random.randint(0, len(test_users) - 1)]
                package_member = PackageMember(package=package,
                                               channel=channel,
                                               user=test_user,
                                               role='owner')

                db.add(package_member)

            test_user = test_users[random.randint(0, len(test_users) - 1)]

            if channel_index == 0:
                package = Package(name='xtensor',
                                  description='Description of xtensor')
                channel.packages.append(package)  # type: ignore

                package_member = PackageMember(package=package,
                                               channel=channel,
                                               user=test_user,
                                               role='owner')

                db.add(package_member)

                # create API key
                key = uuid.uuid4().hex

                key_user = User(id=uuid.uuid4().bytes)
                api_key = ApiKey(key=key,
                                 description='test API key',
                                 user=test_user,
                                 owner=test_user)
                db.add(api_key)
                print(
                    f'Test API key created for user "{test_user.username}": {key}'
                )

                key_package_member = PackageMember(
                    user=key_user,
                    channel_name=channel.name,
                    package_name=package.name,
                    role='maintainer',
                )
                db.add(key_package_member)

            db.add(channel)

            channel_member = ChannelMember(
                channel=channel,
                user=test_user,
                role='owner',
            )

            db.add(channel_member)
        db.commit()
    finally:
        db.close()
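
A minimal sketch of driving _fill_test_database outside the quetz CLI, assuming the quetz model tables already exist on the target engine (the SQLite URL and session wiring here are assumptions, not part of the listing):

    from sqlalchemy import create_engine
    from sqlalchemy.orm import Session

    engine = create_engine("sqlite:///quetz_dev.sqlite")
    _fill_test_database(Session(bind=engine))  # commits and closes the session itself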
Code example #45
0
class DatabaseTest(object):

    engine = None
    connection = None

    @classmethod
    def get_database_connection(cls):
        url = Configuration.database_url()
        engine, connection = SessionManager.initialize(url)

        return engine, connection

    @classmethod
    def setup_class(cls):
        # Initialize a temporary data directory.
        cls.engine, cls.connection = cls.get_database_connection()
        cls.old_data_dir = Configuration.data_directory
        cls.tmp_data_dir = tempfile.mkdtemp(dir="/tmp")
        Configuration.instance[Configuration.DATA_DIRECTORY] = cls.tmp_data_dir

        # Avoid CannotLoadConfiguration errors related to CDN integrations.
        Configuration.instance[Configuration.INTEGRATIONS] = Configuration.instance.get(
            Configuration.INTEGRATIONS, {}
        )
        Configuration.instance[Configuration.INTEGRATIONS][ExternalIntegration.CDN] = {}

    @classmethod
    def teardown_class(cls):
        # Destroy the database connection and engine.
        cls.connection.close()
        cls.engine.dispose()

        if cls.tmp_data_dir.startswith("/tmp"):
            logging.debug("Removing temporary directory %s" % cls.tmp_data_dir)
            shutil.rmtree(cls.tmp_data_dir)

        else:
            logging.warn("Cowardly refusing to remove 'temporary' directory %s" % cls.tmp_data_dir)

        Configuration.instance[Configuration.DATA_DIRECTORY] = cls.old_data_dir

    def setup(self, mock_search=True):
        # Create a new connection to the database.
        self._db = Session(self.connection)
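        # Wrap everything the test does in a SAVEPOINT (nested
        # transaction) so teardown() can roll it all back at once.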
        self.transaction = self.connection.begin_nested()

        # Start with a high number so it won't interfere with tests that search for an age or grade
        self.counter = 2000

        self.time_counter = datetime(2014, 1, 1)
        self.isbns = [
            "9780674368279", "0636920028468", "9781936460236", "9780316075978"
        ]
        if mock_search:
            self.search_mock = mock.patch(external_search.__name__ + ".ExternalSearchIndex", MockExternalSearchIndex)
            self.search_mock.start()
        else:
            self.search_mock = None

        # TODO:  keeping this for now, but need to fix it bc it hits _isbn,
        # which pops an isbn off the list and messes tests up.  so exclude
        # _ functions from participating.
        # also attempt to stop nosetest showing docstrings instead of function names.
        #for name, obj in inspect.getmembers(self):
        #    if inspect.isfunction(obj) and obj.__name__.startswith('test_'):
        #        obj.__doc__ = None


    def teardown(self):
        # Close the session.
        self._db.close()

        # Roll back all database changes that happened during this
        # test, whether in the session that was just closed or some
        # other session.
        self.transaction.rollback()

        # Remove any database objects cached in the model classes but
        # associated with the now-rolled-back session.
        Collection.reset_cache()
        ConfigurationSetting.reset_cache()
        DataSource.reset_cache()
        DeliveryMechanism.reset_cache()
        ExternalIntegration.reset_cache()
        Genre.reset_cache()
        Library.reset_cache()

        # Also roll back any record of those changes in the
        # Configuration instance.
        for key in [
                Configuration.SITE_CONFIGURATION_LAST_UPDATE,
                Configuration.LAST_CHECKED_FOR_SITE_CONFIGURATION_UPDATE
        ]:
            if key in Configuration.instance:
                del Configuration.instance[key]

        if self.search_mock:
            self.search_mock.stop()

    def time_eq(self, a, b):
        "Assert that two times are *approximately* the same -- within 2 seconds."
        if a < b:
            delta = b-a
        else:
            delta = a-b
        total_seconds = delta.total_seconds()
        assert (total_seconds < 2), ("Delta was too large: %.2f seconds." % total_seconds)

    def shortDescription(self):
        return None # Stop nosetests displaying docstrings instead of class names when verbosity level >= 2.

    @property
    def _id(self):
        self.counter += 1
        return self.counter

    @property
    def _str(self):
        return unicode(self._id)

    @property
    def _time(self):
        v = self.time_counter
        self.time_counter = self.time_counter + timedelta(days=1)
        return v

    @property
    def _isbn(self):
        return self.isbns.pop()

    @property
    def _url(self):
        return "http://foo.com/" + self._str

    def _patron(self, external_identifier=None, library=None):
        external_identifier = external_identifier or self._str
        library = library or self._default_library
        return get_one_or_create(
            self._db, Patron, external_identifier=external_identifier,
            library=library
        )[0]

    def _contributor(self, sort_name=None, name=None, **kw_args):
        name = sort_name or name or self._str
        return get_one_or_create(self._db, Contributor, sort_name=unicode(name), **kw_args)

    def _identifier(self, identifier_type=Identifier.GUTENBERG_ID, foreign_id=None):
        if foreign_id:
            id = foreign_id
        else:
            id = self._str
        return Identifier.for_foreign_id(self._db, identifier_type, id)[0]

    def _edition(self, data_source_name=DataSource.GUTENBERG,
                 identifier_type=Identifier.GUTENBERG_ID,
                 with_license_pool=False, with_open_access_download=False,
                 title=None, language="eng", authors=None, identifier_id=None,
                 series=None, collection=None, publicationDate=None
    ):
        id = identifier_id or self._str
        source = DataSource.lookup(self._db, data_source_name)
        wr = Edition.for_foreign_id(
            self._db, source, identifier_type, id)[0]
        if not title:
            title = self._str
        wr.title = unicode(title)
        wr.medium = Edition.BOOK_MEDIUM
        if series:
            wr.series = series
        if language:
            wr.language = language
        if authors is None:
            authors = self._str
        if isinstance(authors, basestring):
            authors = [authors]
        if authors != []:
            wr.add_contributor(unicode(authors[0]), Contributor.PRIMARY_AUTHOR_ROLE)
            wr.author = unicode(authors[0])
        for author in authors[1:]:
            wr.add_contributor(unicode(author), Contributor.AUTHOR_ROLE)
        if publicationDate:
            wr.published = publicationDate

        if with_license_pool or with_open_access_download:
            pool = self._licensepool(
                wr, data_source_name=data_source_name,
                with_open_access_download=with_open_access_download,
                collection=collection
            )

            pool.set_presentation_edition()
            return wr, pool
        return wr

    def _work(self, title=None, authors=None, genre=None, language=None,
              audience=None, fiction=True, with_license_pool=False,
              with_open_access_download=False, quality=0.5, series=None,
              presentation_edition=None, collection=None, data_source_name=None):
        """Create a Work.

        For performance reasons, this method does not generate OPDS
        entries or calculate a presentation edition for the new
        Work. Tests that rely on this information being present
        should call _slow_work() instead, which takes more care to present
        the sort of Work that would be created in a real environment.
        """
        pools = []
        if with_open_access_download:
            with_license_pool = True
        language = language or "eng"
        title = unicode(title or self._str)
        audience = audience or Classifier.AUDIENCE_ADULT
        if audience == Classifier.AUDIENCE_CHILDREN and not data_source_name:
            # TODO: This is necessary because Gutenberg's childrens books
            # get filtered out at the moment.
            data_source_name = DataSource.OVERDRIVE
        elif not data_source_name:
            data_source_name = DataSource.GUTENBERG
        if fiction is None:
            fiction = True
        new_edition = False
        if not presentation_edition:
            new_edition = True
            presentation_edition = self._edition(
                title=title, language=language,
                authors=authors,
                with_license_pool=with_license_pool,
                with_open_access_download=with_open_access_download,
                data_source_name=data_source_name,
                series=series,
                collection=collection,
            )
            if with_license_pool:
                presentation_edition, pool = presentation_edition
                if with_open_access_download:
                    pool.open_access = True
                pools = [pool]
        else:
            pools = presentation_edition.license_pools
        work, ignore = get_one_or_create(
            self._db, Work, create_method_kwargs=dict(
                audience=audience,
                fiction=fiction,
                quality=quality), id=self._id)
        if genre:
            if not isinstance(genre, Genre):
                genre, ignore = Genre.lookup(self._db, genre, autocreate=True)
            work.genres = [genre]
        work.random = 0.5
        work.set_presentation_edition(presentation_edition)

        if pools:
            # make sure the pool's presentation_edition is set,
            # bc loan tests assume that.
            if not work.license_pools:
                for pool in pools:
                    work.license_pools.append(pool)

            for pool in pools:
                pool.set_presentation_edition()

            # This is probably going to be used in an OPDS feed, so
            # fake that the work is presentation ready.
            work.presentation_ready = True
            work.calculate_opds_entries(verbose=False)

        return work
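
    # Usage sketch (hypothetical test body) contrasting this fast factory
    # with the slower, presentation-ready _slow_work() defined below:
    #
    #     work = self._work(title="Fast")       # skips calculate_presentation_edition()
    #     full = self._slow_work(title="Full")  # also computes presentation edition and OPDS entries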

    def add_to_materialized_view(self, works, true_opds=False):
        """Make sure all the works in `works` show up in the materialized view.

        :param true_opds: Generate real OPDS entries for each work,
        rather than faking it.
        """
        if not isinstance(works, list):
            works = [works]
        for work in works:
            if true_opds:
                work.calculate_opds_entries(verbose=False)
            else:
                work.presentation_ready = True
                work.simple_opds_entry = "<entry>an entry</entry>"
        self._db.commit()
        SessionManager.refresh_materialized_views(self._db)

    def _lane(self, display_name=None, library=None,
              parent=None, genres=None, languages=None,
              fiction=None
    ):
        display_name = display_name or self._str
        library = library or self._default_library
        lane, is_new = create(
            self._db, Lane,
            library=library,
            parent=parent, display_name=display_name,
            fiction=fiction
        )
        if is_new and parent:
            lane.priority = len(parent.sublanes)-1
        if genres:
            if not isinstance(genres, list):
                genres = [genres]
            for genre in genres:
                if isinstance(genre, basestring):
                    genre, ignore = Genre.lookup(self._db, genre)
                lane.genres.append(genre)
        if languages:
            if not isinstance(languages, list):
                languages = [languages]
            lane.languages = languages
        return lane

    def _slow_work(self, *args, **kwargs):
        """Create a work that closely resembles one that might be found in the
        wild.

        This is significantly slower than _work() but more reliable.
        """
        work = self._work(*args, **kwargs)
        work.calculate_presentation_edition()
        work.calculate_opds_entries(verbose=False)
        return work

    def _add_generic_delivery_mechanism(self, license_pool):
        """Give a license pool a generic non-open-access delivery mechanism."""
        data_source = license_pool.data_source
        identifier = license_pool.identifier
        content_type = Representation.EPUB_MEDIA_TYPE
        drm_scheme = DeliveryMechanism.NO_DRM
        LicensePoolDeliveryMechanism.set(
            data_source, identifier, content_type, drm_scheme,
            RightsStatus.IN_COPYRIGHT
        )

    def _coverage_record(self, edition, coverage_source, operation=None,
        status=CoverageRecord.SUCCESS, collection=None, exception=None,
    ):
        if isinstance(edition, Identifier):
            identifier = edition
        else:
            identifier = edition.primary_identifier
        record, ignore = get_one_or_create(
            self._db, CoverageRecord,
            identifier=identifier,
            data_source=coverage_source,
            operation=operation,
            collection=collection,
            create_method_kwargs=dict(
                timestamp=datetime.utcnow(),
                status=status,
                exception=exception,
            )
        )
        return record

    def _work_coverage_record(self, work, operation=None,
                              status=CoverageRecord.SUCCESS):
        record, ignore = get_one_or_create(
            self._db, WorkCoverageRecord,
            work=work,
            operation=operation,
            create_method_kwargs=dict(
                timestamp=datetime.utcnow(),
                status=status,
            )
        )
        return record

    def _licensepool(self, edition, open_access=True,
                     data_source_name=DataSource.GUTENBERG,
                     with_open_access_download=False,
                     set_edition_as_presentation=False,
                     collection=None):
        source = DataSource.lookup(self._db, data_source_name)
        if not edition:
            edition = self._edition(data_source_name)
        collection = collection or self._default_collection
        pool, ignore = get_one_or_create(
            self._db, LicensePool,
            create_method_kwargs=dict(
                open_access=open_access),
            identifier=edition.primary_identifier,
            data_source=source,
            collection=collection,
            availability_time=datetime.utcnow()
        )

        if set_edition_as_presentation:
            pool.presentation_edition = edition

        if with_open_access_download:
            pool.open_access = True
            url = "http://foo.com/" + self._str
            media_type = MediaTypes.EPUB_MEDIA_TYPE
            link, new = pool.identifier.add_link(
                Hyperlink.OPEN_ACCESS_DOWNLOAD, url,
                source, media_type
            )

            # Add a DeliveryMechanism for this download
            pool.set_delivery_mechanism(
                media_type,
                DeliveryMechanism.NO_DRM,
                RightsStatus.GENERIC_OPEN_ACCESS,
                link.resource,
            )

            representation, is_new = self._representation(
                url, media_type, "Dummy content", mirrored=True)
            link.resource.representation = representation
        else:

            # Add a DeliveryMechanism for this licensepool
            pool.set_delivery_mechanism(
                MediaTypes.EPUB_MEDIA_TYPE,
                DeliveryMechanism.ADOBE_DRM,
                RightsStatus.UNKNOWN,
                None
            )
            pool.licenses_owned = pool.licenses_available = 1

        return pool

    def _license(self, pool, identifier=None, checkout_url=None, status_url=None,
                 expires=None, remaining_checkouts=None, concurrent_checkouts=None):
        identifier = identifier or self._str
        checkout_url = checkout_url or self._str
        status_url = status_url or self._str
        license, ignore = get_one_or_create(
            self._db, License, identifier=identifier, license_pool=pool,
            checkout_url=checkout_url,
            status_url=status_url, expires=expires,
            remaining_checkouts=remaining_checkouts,
            concurrent_checkouts=concurrent_checkouts,
        )
        return license

    def _representation(self, url=None, media_type=None, content=None,
                        mirrored=False):
        url = url or "http://foo.com/" + self._str
        repr, is_new = get_one_or_create(
            self._db, Representation, url=url)
        repr.media_type = media_type
        if media_type and content:
            repr.content = content
            repr.fetched_at = datetime.utcnow()
            if mirrored:
                repr.mirror_url = "http://foo.com/" + self._str
                repr.mirrored_at = datetime.utcnow()
        return repr, is_new

    def _customlist(self, foreign_identifier=None,
                    name=None,
                    data_source_name=DataSource.NYT, num_entries=1,
                    entries_exist_as_works=True
    ):
        data_source = DataSource.lookup(self._db, data_source_name)
        foreign_identifier = foreign_identifier or self._str
        now = datetime.utcnow()
        customlist, ignore = get_one_or_create(
            self._db, CustomList,
            create_method_kwargs=dict(
                created=now,
                updated=now,
                name=name or self._str,
                description=self._str,
                ),
            data_source=data_source,
            foreign_identifier=foreign_identifier
        )

        editions = []
        for i in range(num_entries):
            if entries_exist_as_works:
                work = self._work(with_open_access_download=True)
                edition = work.presentation_edition
            else:
                edition = self._edition(
                    data_source_name, title="Item %s" % i)
                edition.permanent_work_id="Permanent work ID %s" % self._str
            customlist.add_entry(
                edition, "Annotation %s" % i, first_appearance=now)
            editions.append(edition)
        return customlist, editions

    def _complaint(self, license_pool, type, source, detail, resolved=None):
        complaint, is_new = Complaint.register(
            license_pool,
            type,
            source,
            detail,
            resolved
        )
        return complaint

    def _credential(self, data_source_name=DataSource.GUTENBERG,
                    type=None, patron=None):
        data_source = DataSource.lookup(self._db, data_source_name)
        type = type or self._str
        patron = patron or self._patron()
        credential, is_new = Credential.persistent_token_create(
            self._db, data_source, type, patron
        )
        return credential

    def _external_integration(self, protocol, goal=None, settings=None,
                              libraries=None, **kwargs
    ):
        integration = None
        if not libraries:
            integration, ignore = get_one_or_create(
                self._db, ExternalIntegration, protocol=protocol, goal=goal
            )
        else:
            if not isinstance(libraries, list):
                libraries = [libraries]

            # Try to find an existing integration for one of the given
            # libraries.
            for library in libraries:
                integration = ExternalIntegration.lookup(
                    self._db, protocol, goal, library=libraries[0]
                )
                if integration:
                    break

            if not integration:
                # Otherwise, create a brand new integration specifically
                # for the library.
                integration = ExternalIntegration(
                    protocol=protocol, goal=goal,
                )
                integration.libraries.extend(libraries)
                self._db.add(integration)

        for attr, value in kwargs.items():
            setattr(integration, attr, value)

        settings = settings or dict()
        for key, value in settings.items():
            integration.set_setting(key, value)

        return integration

    def _delegated_patron_identifier(
            self, library_uri=None, patron_identifier=None,
            identifier_type=DelegatedPatronIdentifier.ADOBE_ACCOUNT_ID,
            identifier=None
    ):
        """Create a sample DelegatedPatronIdentifier"""
        library_uri = library_uri or self._url
        patron_identifier = patron_identifier or self._str
        if callable(identifier):
            make_id = identifier
        else:
            if not identifier:
                identifier = self._str
            def make_id():
                return identifier
        patron, is_new = DelegatedPatronIdentifier.get_one_or_create(
            self._db, library_uri, patron_identifier, identifier_type,
            make_id
        )
        return patron

    def _sample_ecosystem(self):
        """ Creates an ecosystem of some sample work, pool, edition, and author
        objects that all know each other.
        """
        # make some authors
        [bob], ignore = Contributor.lookup(self._db, u"Bitshifter, Bob")
        bob.family_name, bob.display_name = bob.default_names()
        [alice], ignore = Contributor.lookup(self._db, u"Adder, Alice")
        alice.family_name, alice.display_name = alice.default_names()

        edition_std_ebooks, pool_std_ebooks = self._edition(DataSource.STANDARD_EBOOKS, Identifier.URI,
            with_license_pool=True, with_open_access_download=True, authors=[])
        edition_std_ebooks.title = u"The Standard Ebooks Title"
        edition_std_ebooks.subtitle = u"The Standard Ebooks Subtitle"
        edition_std_ebooks.add_contributor(alice, Contributor.AUTHOR_ROLE)

        edition_git, pool_git = self._edition(DataSource.PROJECT_GITENBERG, Identifier.GUTENBERG_ID,
            with_license_pool=True, with_open_access_download=True, authors=[])
        edition_git.title = u"The GItenberg Title"
        edition_git.subtitle = u"The GItenberg Subtitle"
        edition_git.add_contributor(bob, Contributor.AUTHOR_ROLE)
        edition_git.add_contributor(alice, Contributor.AUTHOR_ROLE)

        edition_gut, pool_gut = self._edition(DataSource.GUTENBERG, Identifier.GUTENBERG_ID,
            with_license_pool=True, with_open_access_download=True, authors=[])
        edition_gut.title = u"The GUtenberg Title"
        edition_gut.subtitle = u"The GUtenberg Subtitle"
        edition_gut.add_contributor(bob, Contributor.AUTHOR_ROLE)

        work = self._work(presentation_edition=edition_git)

        for p in pool_gut, pool_std_ebooks:
            work.license_pools.append(p)

        work.calculate_presentation()

        return (work, pool_std_ebooks, pool_git, pool_gut,
            edition_std_ebooks, edition_git, edition_gut, alice, bob)


    def print_database_instance(self):
        """
        Calls the class method that examines the current state of the database model
        (whether it's been committed or not).

        NOTE:  If you set_trace, and hit "continue", you'll start seeing console output right
        away, without waiting for the whole test to run and the standard output section to display.
        You can also use nosetests --nocapture.
        I use:
        def test_name(self):
            [code...]
            set_trace()
            self.print_database_instance()  # TODO: remove before prod
            [code...]
        """
        if 'TESTING' not in os.environ:
            # we are on production, abort, abort!
            logging.warning("Forgot to remove call to testing.py:DatabaseTest.print_database_instance() before pushing to production.")
            return

        DatabaseTest.print_database_class(self._db)
        return


    @classmethod
    def print_database_class(cls, db_connection):
        """
        Prints to the console the entire contents of the database, as the unit test sees it.
        Exists because unit tests don't persist db information: they create an in-memory
        representation of the db state and then roll the test-derived transactions back.
        So we cannot see what's going on by going into postgres and running selects.
        This is the in-test alternative to going into postgres.

        Can be called from model and metadata classes as well as tests.

        NOTE: The purpose of this method is for debugging.
        Be careful of leaving it in code and potentially outputting
        vast tracts of data into your output stream on production.

        Call like this:
        set_trace()
        from testing import (
            DatabaseTest,
        )
        _db = Session.object_session(self)
        DatabaseTest.print_database_class(_db)  # TODO: remove before prod
        """
        if 'TESTING' not in os.environ:
            # we are on production, abort, abort!
            logging.warning("Forgot to remove call to testing.py:DatabaseTest.print_database_class() before pushing to production.")
            return

        works = db_connection.query(Work).all()
        identifiers = db_connection.query(Identifier).all()
        license_pools = db_connection.query(LicensePool).all()
        editions = db_connection.query(Edition).all()
        data_sources = db_connection.query(DataSource).all()
        representations = db_connection.query(Representation).all()

        if (not works):
            print "NO Work found"
        for wCount, work in enumerate(works):
            # pipe character at end of line helps see whitespace issues
            print "Work[%s]=%s|" % (wCount, work)

            if (not work.license_pools):
                print "    NO Work.LicensePool found"
            for lpCount, license_pool in enumerate(work.license_pools):
                print "    Work.LicensePool[%s]=%s|" % (lpCount, license_pool)

            print "    Work.presentation_edition=%s|" % work.presentation_edition

        print "__________________________________________________________________\n"
        if (not identifiers):
            print "NO Identifier found"
        for iCount, identifier in enumerate(identifiers):
            print "Identifier[%s]=%s|" % (iCount, identifier)
            print "    Identifier.licensed_through=%s|" % identifier.licensed_through

        print "__________________________________________________________________\n"
        if (not license_pools):
            print "NO LicensePool found"
        for index, license_pool in enumerate(license_pools):
            print "LicensePool[%s]=%s|" % (index, license_pool)
            print "    LicensePool.work_id=%s|" % license_pool.work_id
            print "    LicensePool.data_source_id=%s|" % license_pool.data_source_id
            print "    LicensePool.identifier_id=%s|" % license_pool.identifier_id
            print "    LicensePool.presentation_edition_id=%s|" % license_pool.presentation_edition_id
            print "    LicensePool.superceded=%s|" % license_pool.superceded
            print "    LicensePool.suppressed=%s|" % license_pool.suppressed

        print "__________________________________________________________________\n"
        if (not editions):
            print "NO Edition found"
        for index, edition in enumerate(editions):
            # pipe character at end of line helps see whitespace issues
            print "Edition[%s]=%s|" % (index, edition)
            print "    Edition.primary_identifier_id=%s|" % edition.primary_identifier_id
            print "    Edition.permanent_work_id=%s|" % edition.permanent_work_id
            if (edition.data_source):
                print "    Edition.data_source.id=%s|" % edition.data_source.id
                print "    Edition.data_source.name=%s|" % edition.data_source.name
            else:
                print "    No Edition.data_source."
            if (edition.license_pool):
                print "    Edition.license_pool.id=%s|" % edition.license_pool.id
            else:
                print "    No Edition.license_pool."

            print "    Edition.title=%s|" % edition.title
            print "    Edition.author=%s|" % edition.author
            if (not edition.author_contributors):
                print "    NO Edition.author_contributor found"
            for acCount, author_contributor in enumerate(edition.author_contributors):
                print "    Edition.author_contributor[%s]=%s|" % (acCount, author_contributor)

        print "__________________________________________________________________\n"
        if (not data_sources):
            print "NO DataSource found"
        for index, data_source in enumerate(data_sources):
            print "DataSource[%s]=%s|" % (index, data_source)
            print "    DataSource.id=%s|" % data_source.id
            print "    DataSource.name=%s|" % data_source.name
            print "    DataSource.offers_licenses=%s|" % data_source.offers_licenses
            print "    DataSource.editions=%s|" % data_source.editions
            print "    DataSource.license_pools=%s|" % data_source.license_pools
            print "    DataSource.links=%s|" % data_source.links

        print "__________________________________________________________________\n"
        if (not representations):
            print "NO Representation found"
        for index, representation in enumerate(representations):
            print "Representation[%s]=%s|" % (index, representation)
            print "    Representation.id=%s|" % representation.id
            print "    Representation.url=%s|" % representation.url
            print "    Representation.mirror_url=%s|" % representation.mirror_url
            print "    Representation.fetch_exception=%s|" % representation.fetch_exception
            print "    Representation.mirror_exception=%s|" % representation.mirror_exception

        return


    def _library(self, name=None, short_name=None):
        name = name or self._str
        short_name = short_name or self._str
        library, ignore = get_one_or_create(
            self._db, Library, name=name, short_name=short_name,
            create_method_kwargs=dict(uuid=str(uuid.uuid4())),
        )
        return library

    def _collection(self, name=None, protocol=ExternalIntegration.OPDS_IMPORT,
                    external_account_id=None, url=None, username=None,
                    password=None, data_source_name=None):
        name = name or self._str
        collection, ignore = get_one_or_create(
            self._db, Collection, name=name
        )
        collection.external_account_id = external_account_id
        integration = collection.create_external_integration(protocol)
        integration.goal = ExternalIntegration.LICENSE_GOAL
        integration.url = url
        integration.username = username
        integration.password = password

        if data_source_name:
            collection.data_source = data_source_name
        return collection

    @property
    def _default_library(self):
        """A Library that will only be created once throughout a given test.

        By default, the `_default_collection` will be associated with
        the default library.
        """
        if not hasattr(self, '_default__library'):
            self._default__library = self.make_default_library(self._db)
        return self._default__library

    @property
    def _default_collection(self):
        """A Collection that will only be created once throughout
        a given test.

        For most tests there's no need to create a different
        Collection for every LicensePool. Using
        self._default_collection instead of calling self.collection()
        saves time.
        """
        if not hasattr(self, '_default__collection'):
            self._default__collection = self._default_library.collections[0]
        return self._default__collection

    @classmethod
    def make_default_library(cls, _db):
        """Ensure that the default library exists in the given database.

        This can be called by code intended for use in testing but not actually
        within a DatabaseTest subclass.
        """
        library, ignore = get_one_or_create(
            _db, Library, create_method_kwargs=dict(
                uuid=unicode(uuid.uuid4()),
                name="default",
            ), short_name="default"
        )
        collection, ignore = get_one_or_create(
            _db, Collection, name="Default Collection"
        )
        integration = collection.create_external_integration(
            ExternalIntegration.OPDS_IMPORT
        )
        integration.goal = ExternalIntegration.LICENSE_GOAL
        if collection not in library.collections:
            library.collections.append(collection)
        return library

    def _catalog(self, name=u"Faketown Public Library"):
        source, ignore = get_one_or_create(self._db, DataSource, name=name)

    def _integration_client(self, url=None, shared_secret=None):
        url = url or self._url
        secret = shared_secret or u"secret"
        return get_one_or_create(
            self._db, IntegrationClient, shared_secret=secret,
            create_method_kwargs=dict(url=url)
        )[0]

    def _subject(self, type, identifier):
        return get_one_or_create(
            self._db, Subject, type=type, identifier=identifier
        )[0]

    def _classification(self, identifier, subject, data_source, weight=1):
        return get_one_or_create(
            self._db, Classification, identifier=identifier, subject=subject,
            data_source=data_source, weight=weight
        )[0]

    def sample_cover_path(self, name):
        """The path to the sample cover with the given filename."""
        base_path = os.path.split(__file__)[0]
        resource_path = os.path.join(base_path, "tests", "files", "covers")
        sample_cover_path = os.path.join(resource_path, name)
        return sample_cover_path

    def sample_cover_representation(self, name):
        """A Representation of the sample cover with the given filename."""
        sample_cover_path = self.sample_cover_path(name)
        with open(sample_cover_path, "rb") as cover_file:
            content = cover_file.read()
        return self._representation(media_type="image/png", content=content)[0]
コード例 #46
0
ファイル: api.py プロジェクト: aikuyun/superset
def get_table_metadata(database: Database, table_name: str,
                       schema_name: Optional[str]) -> Dict:
    """
        Get table metadata information, including type, pk, fks.
        This function raises SQLAlchemyError when a schema is not found.


    :param database: The database model
    :param table_name: Table name
    :param schema_name: schema name
    :return: Dict table metadata ready for API response
    """
    keys: List = []
    columns = database.get_columns(table_name, schema_name)
    # column comments, collected below for presto/hive
    comment_dict = {}
    primary_key = database.get_pk_constraint(table_name, schema_name)
    if primary_key and primary_key.get("constrained_columns"):
        primary_key["column_names"] = primary_key.pop("constrained_columns")
        primary_key["type"] = "pk"
        keys += [primary_key]
    # get dialect name
    dialect_name = database.get_dialect().name
    if isinstance(dialect_name, bytes):
        dialect_name = dialect_name.decode()
    # get column comment, presto & hive
    if dialect_name == "presto" or dialect_name == "hive":
        db_engine_spec = database.db_engine_spec
        sql = ParsedQuery("desc {a}.{b}".format(a=schema_name,
                                                b=table_name)).stripped()
        engine = database.get_sqla_engine(schema_name)
        conn = engine.raw_connection()
        cursor = conn.cursor()
        # The Query object only carries `limit` for fetch_data() below; it is
        # never added to the session, so the commit persists nothing.
        query = Query()
        query.executed_sql = sql
        session = Session(bind=engine)
        session.commit()
        db_engine_spec.execute(cursor, sql, async_=False)
        data = db_engine_spec.fetch_data(cursor, query.limit)
        # Parse the DESC rows into a dict; presto and hive put the column
        # comment in different positions.
        comment_idx = 3 if dialect_name == "presto" else 2
        for d in data:
            comment_dict[d[0]] = d[comment_idx]
        conn.commit()

    foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name)
    indexes = get_indexes_metadata(database, table_name, schema_name)
    keys += foreign_keys + indexes
    payload_columns: List[Dict] = []
    for col in columns:
        dtype = get_col_type(col)
        column_payload = {
            "name": col["name"],
            "type": dtype.split("(")[0] if "(" in dtype else dtype,
            "longType": dtype,
            "keys": [k for k in keys if col["name"] in k.get("column_names")],
        }
        if comment_dict:
            # comment fetched via DESC above (presto/hive)
            column_payload["comment"] = comment_dict.get(col["name"])
        elif dialect_name == "mysql":
            column_payload["comment"] = col["comment"]
        payload_columns.append(column_payload)
    return {
        "name": table_name,
        "columns": payload_columns,
        "selectStar": database.select_star(
            table_name,
            schema=schema_name,
            show_cols=True,
            indent=True,
            cols=columns,
            latest_partition=True,
        ),
        "primaryKey": primary_key,
        "foreignKeys": foreign_keys,
        "indexes": keys,
    }
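
For orientation, the payload returned above has roughly this shape (all values invented for illustration):

{
    "name": "orders",
    "columns": [
        {"name": "id", "type": "BIGINT", "longType": "BIGINT",
         "keys": [{"type": "pk", "column_names": ["id"]}]},
    ],
    "selectStar": "SELECT id FROM orders LIMIT 100",
    "primaryKey": {"type": "pk", "column_names": ["id"]},
    "foreignKeys": [],
    "indexes": [{"type": "pk", "column_names": ["id"]}],
}

Note that "indexes" carries the accumulated `keys` list (primary key, foreign keys and indexes together), mirroring the code above.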
コード例 #47
0
    def manage_slas(self, dag: DAG, session: Session = None) -> None:
        """
        Finding all tasks that have SLAs defined, and sending alert emails
        where needed. New SLA misses are also recorded in the database.

        We are assuming that the scheduler runs often, so we only check for
        tasks that should have succeeded in the past hour.
        """
        self.log.info("Running SLA Checks for %s", dag.dag_id)
        if not any(isinstance(ti.sla, timedelta) for ti in dag.tasks):
            self.log.info(
                "Skipping SLA check for %s because no tasks in DAG have SLAs",
                dag)
            return

        qry = (
            session.query(TI.task_id, func.max(TI.execution_date).label('max_ti'))
            .with_hint(TI, 'USE INDEX (PRIMARY)', dialect_name='mysql')
            .filter(TI.dag_id == dag.dag_id)
            .filter(or_(TI.state == State.SUCCESS, TI.state == State.SKIPPED))
            .filter(TI.task_id.in_(dag.task_ids))
            .group_by(TI.task_id)
            .subquery('sq')
        )

        max_tis: List[TI] = (
            session.query(TI)
            .filter(
                TI.dag_id == dag.dag_id,
                TI.task_id == qry.c.task_id,
                TI.execution_date == qry.c.max_ti,
            )
            .all()
        )

        ts = timezone.utcnow()
        for ti in max_tis:
            task = dag.get_task(ti.task_id)
            if task.sla and not isinstance(task.sla, timedelta):
                raise TypeError(
                    f"SLA is expected to be timedelta object, got "
                    f"{type(task.sla)} in {task.dag_id}:{task.task_id}")

            dttm = dag.following_schedule(ti.execution_date)
            while dttm < timezone.utcnow():
                following_schedule = dag.following_schedule(dttm)
                if following_schedule + task.sla < timezone.utcnow():
                    session.merge(
                        SlaMiss(task_id=ti.task_id,
                                dag_id=ti.dag_id,
                                execution_date=dttm,
                                timestamp=ts))
                dttm = dag.following_schedule(dttm)
        session.commit()

        # pylint: disable=singleton-comparison
        slas: List[SlaMiss] = (
            session.query(SlaMiss).filter(SlaMiss.notification_sent == False,
                                          SlaMiss.dag_id == dag.dag_id)  # noqa
            .all())
        # pylint: enable=singleton-comparison

        if slas:  # pylint: disable=too-many-nested-blocks
            sla_dates: List[datetime.datetime] = [
                sla.execution_date for sla in slas
            ]
            fetched_tis: List[TI] = (session.query(TI).filter(
                TI.state != State.SUCCESS, TI.execution_date.in_(sla_dates),
                TI.dag_id == dag.dag_id).all())
            blocking_tis: List[TI] = []
            for ti in fetched_tis:
                if ti.task_id in dag.task_ids:
                    ti.task = dag.get_task(ti.task_id)
                    blocking_tis.append(ti)
                else:
                    session.delete(ti)
                    session.commit()

            task_list = "\n".join(sla.task_id + ' on ' +
                                  sla.execution_date.isoformat()
                                  for sla in slas)
            blocking_task_list = "\n".join(ti.task_id + ' on ' +
                                           ti.execution_date.isoformat()
                                           for ti in blocking_tis)
            # Track whether email or any alert notification sent
            # We consider email or the alert callback as notifications
            email_sent = False
            notification_sent = False
            if dag.sla_miss_callback:
                # Execute the alert callback
                self.log.info('Calling SLA miss callback')
                try:
                    dag.sla_miss_callback(dag, task_list, blocking_task_list,
                                          slas, blocking_tis)
                    notification_sent = True
                except Exception:  # pylint: disable=broad-except
                    self.log.exception(
                        "Could not call sla_miss_callback for DAG %s",
                        dag.dag_id)
            email_content = f"""\
            Here's a list of tasks that missed their SLAs:
            <pre><code>{task_list}\n<code></pre>
            Blocking tasks:
            <pre><code>{blocking_task_list}<code></pre>
            Airflow Webserver URL: {conf.get(section='webserver', key='base_url')}
            """

            tasks_missed_sla = []
            for sla in slas:
                try:
                    task = dag.get_task(sla.task_id)
                except TaskNotFound:
                    # task already deleted from DAG, skip it
                    self.log.warning(
                        "Task %s doesn't exist in DAG anymore, skipping SLA miss notification.",
                        sla.task_id)
                    continue
                tasks_missed_sla.append(task)

            emails: Set[str] = set()
            for task in tasks_missed_sla:
                if task.email:
                    if isinstance(task.email, str):
                        emails |= set(get_email_address_list(task.email))
                    elif isinstance(task.email, (list, tuple)):
                        emails |= set(task.email)
            if emails:
                try:
                    send_email(emails,
                               f"[airflow] SLA miss on DAG={dag.dag_id}",
                               email_content)
                    email_sent = True
                    notification_sent = True
                except Exception:  # pylint: disable=broad-except
                    Stats.incr('sla_email_notification_failure')
                    self.log.exception(
                        "Could not send SLA Miss email notification for DAG %s",
                        dag.dag_id)
            # If we sent any notification, update the sla_miss table
            if notification_sent:
                for sla in slas:
                    sla.email_sent = email_sent
                    sla.notification_sent = True
                    session.merge(sla)
            session.commit()
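
The callback invoked above receives the DAG plus the formatted task lists; a hypothetical sla_miss_callback matching that call site could look like this:

def my_sla_miss_callback(dag, task_list, blocking_task_list, slas, blocking_tis):
    # slas: SlaMiss rows; blocking_tis: task instances still not successful
    # for the missed execution dates.
    print(f"SLA missed in {dag.dag_id}:\n{task_list}")

# hypothetical wiring at DAG definition time:
# dag = DAG("example", sla_miss_callback=my_sla_miss_callback, ...)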
コード例 #48
0
def update_replies(remote_replies: List[SDKReply], local_replies: List[Reply],
                   session: Session, data_dir: str) -> None:
    """
    * Existing replies are updated in the local database.
    * New replies have an entry created in the local database.
    * Local replies not returned in the remote replies are deleted from the
      local database unless they are pending or failed.

    If a reply references a new journalist username, add them to the database
    as a new user.
    """
    local_uuids = {reply.uuid for reply in local_replies}
    for reply in remote_replies:
        if reply.uuid in local_uuids:
            local_reply = [r for r in local_replies if r.uuid == reply.uuid][0]

            user = find_or_create_user(reply.journalist_uuid, reply.journalist_username, session)
            local_reply.journalist_id = user.id
            local_reply.size = reply.size

            local_uuids.remove(reply.uuid)
            logger.debug('Updated reply {}'.format(reply.uuid))
        else:
            # A new reply to be added to the database.
            source_uuid = reply.source_uuid
            source = session.query(Source).filter_by(uuid=source_uuid)[0]
            user = find_or_create_user(
                reply.journalist_uuid,
                reply.journalist_username,
                session)

            nr = Reply(uuid=reply.uuid,
                       journalist_id=user.id,
                       source_id=source.id,
                       filename=reply.filename,
                       size=reply.size)
            session.add(nr)

            # All replies fetched from the server have succeeded in being sent,
            # so we should delete the corresponding draft locally if it exists.
            try:
                draft_reply_db_object = session.query(DraftReply).filter_by(
                    uuid=reply.uuid).one()

                update_draft_replies(session, draft_reply_db_object.source.id,
                                     draft_reply_db_object.timestamp,
                                     draft_reply_db_object.file_counter,
                                     nr.file_counter)
                session.delete(draft_reply_db_object)

            except NoResultFound:
                pass  # No draft locally stored corresponding to this reply.

            logger.debug('Added new reply {}'.format(reply.uuid))

    # The uuids remaining in local_uuids do not exist on the remote server, so
    # delete the related records.
    replies_to_delete = [r for r in local_replies if r.uuid in local_uuids]
    for deleted_reply in replies_to_delete:
        delete_single_submission_or_reply_on_disk(deleted_reply, data_dir)
        session.delete(deleted_reply)
        logger.debug('Deleted reply {}'.format(deleted_reply.uuid))

    session.commit()
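
The reconciliation rule in this function, in miniature: update records present on both sides, insert remote-only records, delete local-only ones (pending/failed drafts aside). A standalone sketch over plain dicts, purely illustrative:

def reconcile(remote, local_ids):
    """remote: list of dicts with a 'uuid' key; local_ids: set of uuids."""
    remote_ids = {r["uuid"] for r in remote}
    to_update = [r for r in remote if r["uuid"] in local_ids]
    to_insert = [r for r in remote if r["uuid"] not in local_ids]
    to_delete = local_ids - remote_ids
    return to_insert, to_update, to_delete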
コード例 #49
0
class testTradingCenter(unittest.TestCase):

    def setUp(self):
        self.trans = connection.begin()
        self.session = Session(connection)

        currency = Currency(name='Pesos', code='ARG')
        self.exchange = Exchange(name='Merval', code='MERV', currency=currency)
        self.owner = Owner(name='poor owner')
        self.broker = Broker(name='broker1')
        self.account = Account(owner=self.owner, broker=self.broker)
        self.account.deposit(Money(amount=10000, currency=currency))

    def tearDown(self):
        self.trans.rollback()
        self.session.close()

    def test_open_orders_by_order_id(self):
        stock=Stock(symbol='symbol', description='a stock', ISIN='US123456789', exchange=self.exchange)
        order1=BuyOrder(account=self.account, security=stock, price=13.2, share=10)
        order2=BuyOrder(account=self.account, security=stock, price=13.25, share=10)
        self.session.add(order1)
        self.session.add(order2)
        self.session.commit()

        tc=TradingCenter(self.session)
        order=tc.open_order_by_id(order1.id)
        self.assertEqual(order1, order)

        order = tc.open_order_by_id(100)
        self.assertIsNone(order)

    def testGetOpenOrdersBySymbol(self):

        stock=Stock(symbol='symbol', description='a stock', ISIN='US123456789', exchange=self.exchange)
        order1=BuyOrder(account=self.account, security=stock, price=13.2, share=10)
        order2=BuyOrder(account=self.account, security=stock, price=13.25, share=10)
        self.session.add(order1)
        self.session.add(order2)
        self.session.commit()

        tc=TradingCenter(self.session)
        orders=tc.open_orders_by_symbol('symbol')
        self.assertEqual([order1, order2], list(orders))

    def testCancelOrder(self):

        stock=Stock(symbol='symbol', description='a stock', ISIN='US123456789', exchange=self.exchange)
        order1=BuyOrder(account=self.account, security=stock, price=13.2, share=10)
        order2=BuyOrder(account=self.account, security=stock, price=13.25, share=10)

        self.session.add(order1)
        self.session.add(order2)
        self.session.commit()

        tc=TradingCenter(self.session)

        order1.cancel()
        self.assertEqual([order2], tc.open_orders)
        self.assertEqual([order1], tc.cancel_orders)
        self.assertEqual(CancelOrderStage, type(order1.current_stage))

        order2.cancel()
        self.assertEqual([], tc.open_orders)
        self.assertEqual([order1, order2], tc.cancel_orders)

    def testCancelAllOpenOrders(self):
        security=Stock(symbol='symbol', description='a stock', ISIN='US123456789', exchange=self.exchange)
        order1=BuyOrder(account=self.account, security=security, price=13.2, share=10)
        order2=BuyOrder(account=self.account, security=security, price=13.25, share=10)

        self.session.add(order1)
        self.session.add(order2)
        self.session.commit()

        tc=TradingCenter(self.session)

        tc.cancel_all_open_orders()

        self.assertEqual([], tc.open_orders)

    def testConsume(self):
        pass

    def testPostConsume(self):
        pass

    def testCreateAccountWithMetrix(self):
        pass
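
The setUp/tearDown pair above is the classic transaction-per-test pattern: each test runs inside one outer transaction that is rolled back afterwards, so nothing a test writes survives. A standalone sketch, assuming a module-level engine:

from sqlalchemy.orm.session import Session

connection = engine.connect()
trans = connection.begin()
session = Session(bind=connection)
try:
    pass  # exercise the ORM; inserts and commits stay inside `trans`
finally:
    session.close()
    trans.rollback()  # discards everything the test wrote
    connection.close()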
コード例 #50
0
ファイル: per_session.py プロジェクト: gajop/ailadder
    from sqlalchemy import create_engine, Column, Integer, String
    from sqlalchemy.orm import sessionmaker
    from sqlalchemy.ext.declarative import declarative_base

    Session = sessionmaker(query_cls=CachingQuery)

    Base = declarative_base(bind=create_engine('sqlite://', echo=True))
    
    class User(Base):
        __tablename__ = 'users'
        id = Column(Integer, primary_key=True)
        name = Column(String(100))
        
        def __repr__(self):
            return "User(name=%r)" % self.name

    Base.metadata.create_all()
    
    sess = Session()
    
    sess.add_all(
        [User(name='u1'), User(name='u2'), User(name='u3')]
    )
    sess.commit()
    
    # cache two user objects
    sess.query(User).with_cache_key('u2andu3').filter(User.name.in_(['u2', 'u3'])).all()
    
    # pull straight from cache
    print sess.query(User).with_cache_key('u2andu3').all()
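
The snippet assumes a CachingQuery class such as the one in SQLAlchemy's query-caching recipe. A minimal dict-backed sketch of the idea, with illustrative names:

from sqlalchemy.orm import Query

_cache = {}

class CachingQuery(Query):
    def with_cache_key(self, key):
        self._cache_key = key
        return self

    def __iter__(self):
        key = getattr(self, '_cache_key', None)
        if key is None:
            return super(CachingQuery, self).__iter__()
        if key not in _cache:
            # first hit: run the query and memoize the fully loaded rows
            _cache[key] = list(super(CachingQuery, self).__iter__())
        return iter(_cache[key])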
    
コード例 #51
0
#!/usr/bin/python3
""" update state with id = 2 from the database hbtn_0e_6_usa"""

from sqlalchemy import create_engine
from sqlalchemy.orm.session import Session
from model_state import Base, State
import sys

if __name__ == '__main__':
    engine = create_engine('mysql+mysqldb://{}:{}@localhost/{}'.format(
        sys.argv[1], sys.argv[2], sys.argv[3]),
                           pool_pre_ping=True)
    Base.metadata.create_all(engine)

    session = Session(engine)
    query_row = session.query(State).filter(State.id == 2).\
        update({"name": "New Mexico"}, synchronize_session="fetch")
    session.commit()
    session.close()
コード例 #52
0
ファイル: purge.py プロジェクト: devhub/baph
    def handle_noargs(self, **options):
        verbosity = 1 #int(options.get('verbosity'))
        interactive = options.get('interactive')
        show_traceback = options.get('traceback')

        self.style = no_style()

        # Import the 'management' module within each installed app, to register
        # dispatcher events.
        for app_name in settings.INSTALLED_APPS:
            try:
                import_module('.management', app_name)
            except ImportError as exc:
                # This is slightly hackish. We want to ignore ImportErrors
                # if the "management" module itself is missing -- but we don't
                # want to ignore the exception if the management module exists
                # but raises an ImportError for some reason. The only way we
                # can do this is to check the text of the exception. Note that
                # we're a bit broad in how we check the text, because different
                # Python implementations may not use the same text.
                # CPython uses the text "No module named management"
                # PyPy uses "No module named myproject.myapp.management"
                msg = exc.args[0]
                if not msg.startswith('No module named') or 'management' not in msg:
                    raise

        db = options.get('database')
        orm = ORM.get(db)
        db_info = orm.settings_dict
        is_test_db = db_info.get('TEST', False)
        if not is_test_db:
            print 'Database "%s" cannot be purged because it is not a test ' \
                  'database.\nTo flag this as a test database, set TEST to ' \
                  'True in the database settings.' % db
            sys.exit()

        if interactive:
            confirm = raw_input('\nYou have requested a purge of database ' \
                '"%s" (%s). This will IRREVERSIBLY DESTROY all data ' \
                'currently in the database, and DELETE ALL TABLES AND ' \
                'SCHEMAS. Are you sure you want to do this?\n\n' \
                'Type "yes" to continue, or "no" to cancel: ' \
                % (db, orm.engine.url))
        else:
            confirm = 'yes'

        if confirm == 'yes':
            # get a list of all schemas used by the app
            default_schema = orm.engine.url.database
            app_schemas = set(orm.Base.metadata._schemas)
            app_schemas.add(default_schema)

            url = deepcopy(orm.engine.url)
            url.database = None
            engine = create_engine(url)
            inspector = inspect(engine)

            # get a list of existing schemas
            db_schemas = set(inspector.get_schema_names())

            schemas = app_schemas.intersection(db_schemas)

            app_tables = set()
            for table in orm.Base.metadata.tables.values():
                schema = table.schema or default_schema
                app_tables.add('%s.%s' % (schema, table.name))

            metadata = MetaData()
            db_tables = []
            all_fks = []

            for schema in schemas:
                for table_name in inspector.get_table_names(schema):
                    fullname = '%s.%s' % (schema, table_name)
                    if fullname not in app_tables:
                        continue
                    fks = []
                    for fk in inspector.get_foreign_keys(table_name, schema=schema):
                        if not fk['name']:
                            continue
                        fks.append(ForeignKeyConstraint((),(),name=fk['name']))
                    t = Table(table_name, metadata, *fks, schema=schema)
                    db_tables.append(t)
                    all_fks.extend(fks)

            session = Session(bind=engine)
            for fkc in all_fks:
                session.execute(DropConstraint(fkc))
            for table in db_tables:
                session.execute(DropTable(table))
            for schema in schemas:
                session.execute(DropSchema(schema))
            session.commit()
            session.bind.dispose()

        else:
            self.stdout.write("Purge cancelled.\n")
コード例 #53
0
def remove(db_session: Session, shopping_list: models.List):
    db_session.delete(shopping_list)
    db_session.commit()
コード例 #54
0
ファイル: session.py プロジェクト: GEverding/inbox
class InboxSession(object):
    """ Inbox custom ORM (with SQLAlchemy compatible API).

    Parameters
    ----------
    engine : <sqlalchemy.engine.Engine>
        A configured database engine to use for this session
    versioned : bool
        Do you want to enable the transaction log?
    ignore_soft_deletes : bool
        Whether or not to ignore soft-deleted objects in query results.
    namespace_id : int
        Namespace to limit query results with.
    """
    def __init__(self, engine, versioned=True, ignore_soft_deletes=True,
                 namespace_id=None):
        # TODO: support limiting on namespaces
        assert engine, "Must set the database engine"

        args = dict(bind=engine, autoflush=True, autocommit=False)
        self.ignore_soft_deletes = ignore_soft_deletes
        if ignore_soft_deletes:
            args['query_cls'] = InboxQuery
        self._session = Session(**args)

        if versioned:
            from inbox.models.transaction import create_revisions

            @event.listens_for(self._session, 'after_flush')
            def after_flush(session, flush_context):
                """
                Hook to log revision snapshots. Must be post-flush in order to
                grab object IDs on new objects.
                """
                create_revisions(session)

    def query(self, *args, **kwargs):
        q = self._session.query(*args, **kwargs)
        if self.ignore_soft_deletes:
            return q.options(IgnoreSoftDeletesOption())
        else:
            return q

    def add(self, instance):
        if not self.ignore_soft_deletes or not instance.is_deleted:
            self._session.add(instance)
        else:
            raise Exception("Why are you adding a deleted object?")

    def add_all(self, instances):
        if not self.ignore_soft_deletes or \
                not any(i.is_deleted for i in instances):
            self._session.add_all(instances)
        else:
            raise Exception("Why are you adding a deleted object?")

    def delete(self, instance):
        if self.ignore_soft_deletes:
            instance.mark_deleted()
            # just to make sure
            self._session.add(instance)
        else:
            self._session.delete(instance)

    def begin(self):
        self._session.begin()

    def commit(self):
        self._session.commit()

    def rollback(self):
        self._session.rollback()

    def flush(self):
        self._session.flush()

    def close(self):
        self._session.close()

    def expunge(self, obj):
        self._session.expunge(obj)

    def merge(self, obj):
        return self._session.merge(obj)

    @property
    def no_autoflush(self):
        return self._session.no_autoflush
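
Hypothetical usage, assuming an engine and a Message model exposing the soft-delete API (is_deleted, mark_deleted()) this class relies on:

session = InboxSession(engine, versioned=False)
msg = session.query(Message).first()
session.delete(msg)  # soft delete: marks the row, keeps it in the table
session.commit()
# subsequent queries skip the row via IgnoreSoftDeletesOption
assert session.query(Message).filter_by(id=msg.id).first() is None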
コード例 #55
0
ファイル: processor.py プロジェクト: abhinavkumar195/airflow
    def process_file(
        self,
        file_path: str,
        callback_requests: List[CallbackRequest],
        pickle_dags: bool = False,
        session: Session = NEW_SESSION,
    ) -> Tuple[int, int]:
        """
        Process a Python file containing Airflow DAGs.

        This includes:

        1. Execute the file and look for DAG objects in the namespace.
        2. Execute any Callbacks if passed to this method.
        3. Serialize the DAGs and save it to DB (or update existing record in the DB).
        4. Pickle the DAG and save it to the DB (if necessary).
        5. Mark any DAGs which are no longer present as inactive
        6. Record any errors importing the file into ORM

        :param file_path: the path to the Python file that should be executed
        :param callback_requests: failure callback to execute
        :param pickle_dags: whether to serialize the DAGs found in the file
            and save them to the db
        :param session: Sqlalchemy ORM Session
        :return: number of dags found, count of import errors
        :rtype: Tuple[int, int]
        """
        self.log.info("Processing file %s for tasks to queue", file_path)

        try:
            dagbag = DagBag(file_path, include_examples=False)
        except Exception:
            self.log.exception("Failed at reloading the DAG file %s",
                               file_path)
            Stats.incr('dag_file_refresh_error', 1, 1)
            return 0, 0

        if len(dagbag.dags) > 0:
            self.log.info("DAG(s) %s retrieved from %s", dagbag.dags.keys(),
                          file_path)
        else:
            self.log.warning("No viable dags retrieved from %s", file_path)
            self.update_import_errors(session, dagbag)
            if callback_requests:
                # If there were callback requests for this file but there was a
                # parse error we still need to progress the state of TIs,
                # otherwise they might be stuck in queued/running for ever!
                self.execute_callbacks_without_dag(callback_requests, session)
            return 0, len(dagbag.import_errors)

        self.execute_callbacks(dagbag, callback_requests, session)
        session.commit()

        # Save individual DAGs in the ORM
        dagbag.sync_to_db(session)
        session.commit()

        if pickle_dags:
            paused_dag_ids = DagModel.get_paused_dag_ids(
                dag_ids=dagbag.dag_ids)

            unpaused_dags: List[DAG] = [
                dag for dag_id, dag in dagbag.dags.items()
                if dag_id not in paused_dag_ids
            ]

            for dag in unpaused_dags:
                dag.pickle(session)

        # Record import errors into the ORM
        try:
            self.update_import_errors(session, dagbag)
        except Exception:
            self.log.exception("Error logging import errors!")

        # Record DAG warnings in the metadatabase.
        try:
            self.update_dag_warnings(session=session, dagbag=dagbag)
        except Exception:
            self.log.exception("Error logging DAG warnings.")

        return len(dagbag.dags), len(dagbag.import_errors)
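
A hypothetical invocation, assuming `processor` is the file-processor instance this method belongs to; an empty callback list is valid:

dag_count, import_error_count = processor.process_file(
    file_path="/path/to/dags/example_dag.py",
    callback_requests=[],
    pickle_dags=False,
)
print(f"found {dag_count} DAGs, {import_error_count} import errors")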
コード例 #56
0
ファイル: db_1_1_0.py プロジェクト: adamscieszko/rhodecode
    @classmethod
    def delete(cls, id_):
        obj = cls.query().get(id_)
        Session.delete(obj)
        Session.commit()
コード例 #57
0
ファイル: test_sqlalchemy.py プロジェクト: polyactis/test
#2008-05-07
s = sqlalchemy.sql.select([Article.c.body, Article.c.headline], Article.c.headline=='Python is cool!')
#connection = eng.connect()
result = connection.execute(s)
print "Fetched from Article table before commit"
for row in result:
	print row
#"""

print dir(Article)
#rollback all table savings
yes_or_no = raw_input("Commit Database Transaction?(y/n):")
yes_or_no = yes_or_no.lower()
#default is rollback. so no need to take care 'n' or 'no'.
if yes_or_no=='y' or yes_or_no=='yes':
	session.commit()
	#transaction.commit()	#it will also execute session.flush() if it's not executed.
else:
	session.rollback()
	#transaction.rollback()

aa = articles.alias()
s = sqlalchemy.sql.select([aa.c.body, aa.c.headline, association.c.article_id], aa.c.article_id==association.c.article_id, order_by=[association.c.article_id])
#connection = eng.connect()
result = connection.execute(s).fetchmany(3)
print "Fetched from Article table"
for row in result:
	print row.headline
	print row['headline']
	print row
コード例 #58
0
class AlchemyFeatureDatabase(FeatureDatabase):

	"""
		Uses SQLAlchemy to provide a FeatureDatabase backed by an SQL database.
		IMPORTANT: all insertions must be followed by a finalize() call
	"""

	def __init__(self, engine):
		"""
			Initializes this feature database using an engine
		"""

		if type(engine) == types.StringType:
			new_engine = create_engine(engine, encoding='utf-8')
			if "sqlite:" in engine:
				new_engine.raw_connection().connection.text_factory = str 
			engine = new_engine
		self._engine = engine
		self._session = Session(bind=self._engine)
		self._cache = None
		self._cache_dirty = False
		self.initialize()

	def initialize(self):
		Base.metadata.create_all(self._engine)

	def finalize(self):
		self._session.commit()

	def _get_session(self):
		return self._session 

	def add_feature(self, feature):
		session = self._get_session()
		feature = Feature(feature)
		session.add(feature)
		return True

	def get_feature(self, qry):
		session = self._get_session()
		for obj in session.query(Feature).filter_by(feature=qry):
			return obj

	def feature_exists(self, feature):
		qry = self.get_feature(feature)
		return qry is not None

	def add_source(self, source):
		session = self._get_session()
		source = Source(source)
		session.add(source)
		return True

	def get_source(self, qry):
		session = self._get_session()
		for obj in session.query(Source).filter_by(source=qry):
			return obj

	def _gen_cache(self):
		self._cache = {}
		cache = self._cache
		session = self._get_session()
		for obj in session.query(Example)\
			.options(joinedload('feature'))\
			.options(joinedload('source')):
			if obj.source.source not in cache:
				cache[obj.source.source] = {}
			sub_cache = cache[obj.source.source]
			feature = obj.feature.feature
			if type(feature) is not types.StringType and type(feature) is not types.UnicodeType:
				sub_cache_key = str(feature)
			else:
				sub_cache_key = feature
			if sub_cache_key not in sub_cache:
				sub_cache[sub_cache_key] = [(obj.label, obj.extra)]
			else:
				sub_cache[sub_cache_key].append((obj.label, obj.extra))
		self._cache_dirty = False


	def add_feature_example(self, feature, label, source=None, extra=None):
		
		super(AlchemyFeatureDatabase, self).add_feature_example(feature, label, source, extra)

		if source is None:
			source = "default"

		if label not in [-1, 0, 1]:
			raise ValueError("label: must be -1, 0 or 1")

		source_obj = self.get_source(source)
		feature_obj = self.get_feature(feature)

		if source_obj is None:
			self.add_source(source)
			return self.add_feature_example(feature, label, source, extra)
		if feature_obj is None:
			self.add_feature(feature)
			return self.add_feature_example(feature, label, source, extra)

		session = self._get_session()

		example_obj = Example(feature_obj, label, source_obj, extra)

		session.add(example_obj)
		self._cache_dirty = True
		return True

	def get_feature_examples(self, feature, sources=None):

		if self._cache is None or self._cache_dirty:
			self._gen_cache()

		if type(feature) is not types.StringType and \
				type(feature) is not types.UnicodeType:
			feature = str(feature)

		for source in self._cache:
			if sources is not None:
				if source not in sources:
					continue
			sc = self._cache[source]
			if feature not in sc:
				continue 
			for pair in sc[feature]:
				yield pair

	def get_all_features(self):
		if self._cache is None or self._cache_dirty:
			self._gen_cache()

		for source in self._cache:
			sc = self._cache[source]
			for feature in sc:
				for label, extra in sc[feature]:
					yield feature, label, extra
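
Hypothetical usage, following the class's own contract that insertions must be followed by finalize(); the Feature/Source/Example models and the -1/0/1 label convention come from the surrounding module:

db = AlchemyFeatureDatabase("sqlite:///:memory:")
db.add_feature_example("contains_exclamation", 1, source="twitter")
db.finalize()  # commits the session, as the class docstring requires

examples = list(db.get_feature_examples("contains_exclamation"))
assert examples == [(1, None)]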
コード例 #59
0
def populate_for_human_evaluation(session: Session,
                                  method_to_result: Dict[str, Path]) -> None:

    if session.query(HumanEvaluation).first() is not None:
        return

    if session.query(GenerationResult).first() is not None:
        return

    method_names = list(method_to_result.keys()) + ['Gold']

    token_delim = '|'

    with method_to_result['Base'].open(mode='r') as f:
        N = sum(1 for _ in f) - 1
        f.seek(0)

        reader = csv.reader(f,
                            delimiter=',',
                            quotechar='"',
                            quoting=csv.QUOTE_ALL)
        next(reader)
        for _ in tqdm(range(N)):
            fields = next(reader)
            article_id = fields[0]
            base_result = GenerationResult(
                article_id, 'Base',
                ''.join(number2kansuuzi(fields[4].split(token_delim)[:-1])))
            session.add(base_result)

            gold_result = GenerationResult(article_id, 'Gold', None)
            session.add(gold_result)

        session.commit()

    for method_name in [m for m in method_names if m not in ['Base', 'Gold']]:

        if not method_to_result[method_name].is_file():
            continue

        with method_to_result[method_name].open(mode='r') as f:
            N = sum(1 for _ in f) - 1
            f.seek(0)

            reader = csv.reader(f,
                                delimiter=',',
                                quotechar='"',
                                quoting=csv.QUOTE_ALL)
            next(reader)
            for _ in tqdm(range(N)):
                fields = next(reader)
                article_id = fields[0]
                result = GenerationResult(
                    article_id, method_name, ''.join(
                        number2kansuuzi(fields[4].split(token_delim)[:-1])))
                session.add(result)
        session.commit()

    groups = session \
        .query(GenerationResult.article_id) \
        .filter(GenerationResult.result != '') \
        .group_by(GenerationResult.article_id) \
        .all()

    orderings = list(permutations(method_names))

    for group in groups:

        i = random.randint(0, math.factorial(len(method_names)) - 1)
        h = HumanEvaluation(group.article_id, ordering=orderings[i])
        session.add(h)

    session.commit()

    results = session.query(HumanEvaluation).all()

    SAMPLE_SIZE = 100
    n_results = len(results)

    if n_results >= SAMPLE_SIZE:
        sample_indices = random.sample(range(n_results), SAMPLE_SIZE)

        for i, result in enumerate(results):
            result.is_target = i in sample_indices

        session.commit()
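
The ordering randomization above, in isolation: each article gets one uniformly random permutation of the method names, so annotators see the systems in varying order. A standalone sketch with illustrative names:

import math
import random
from itertools import permutations

method_names = ["Base", "Proposed", "Gold"]  # illustrative
orderings = list(permutations(method_names))
i = random.randint(0, math.factorial(len(method_names)) - 1)
print(orderings[i])  # e.g. ('Gold', 'Base', 'Proposed')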
コード例 #60
0
ファイル: be.py プロジェクト: jjhoo/putiikki
class Catalog(object):
    def __init__(self, engine):
        self.engine = engine
        self.session = Session(bind=self.engine)

    @subtransaction
    def add_item(self, code, description, long_description=None):
        citem = models.Item(code=code,
                            description=description,
                            long_description=long_description)
        self.session.add(citem)
        return citem

    @subtransaction
    def add_items(self, items):
        for item in items:
            # KeyErrors not caught if missing required field

            try:
                long_desc = item['long description']
            except KeyError:
                long_desc = None

            citem = models.Item(code=item['code'],
                                description=item['description'],
                                description_lower=item['description'].lower(),
                                long_description=long_desc)

            primary=True
            for cat in item['categories']:
                q = self.session.query(models.Category).\
                  filter(models.Category.name == cat)
                cater = q.first()

                icater = models.ItemCategory(item=citem, category=cater,
                                             primary=primary)
                citem.categories.append(icater)
                primary=False
            self.session.add(citem)

    def add_category(self, name):
        self.add_categories([name])

    @subtransaction
    def add_categories(self, names):
        for name in names:
            cater = models.Category(name=name)
            self.session.add(cater)

    # def remove_category: should first ensure the category is unused
    # def rename_category

    @subtransaction
    def add_item_category(self, item_id, category_id, primary=False):
        catitem = models.ItemCategory(item_id=item_id, category_id=category_id,
                                      primary=primary)
        self.session.add(catitem)
        return catitem

    # def remove_item_category

    def get_item(self, code, as_object=False):
        q = self.session.query(models.Item).filter(models.Item.code == code)
        item = q.first()
        if item is None:
            return None
        if as_object is True:
            return item
        return (item.id, item.code, item.description, item.long_description)

    def remove_item(self, code):
        self.session.begin(subtransactions=True)
        self.session.query(models.Item).filter(models.Item.code == code).\
          delete()
        self.session.commit()

    @subtransaction
    def update_item(self, code, description=None, long_description=None,
                    new_code=None):
        q = self.session.query(models.Item).filter(models.Item.code == code)
        item = q.first()

        if new_code is not None:
            item.code = new_code

        if description is not None:
            item.description = description
            item.description_lower = description.lower()

        if long_description is not None:
            item.long_description = long_description

    def get_stock(self, code, as_object=False):
        q = self.session.query(models.Item, models.StockItem).\
          with_entities(models.StockItem).\
          filter(models.Item.code == code).\
          filter(models.Item.id == models.StockItem.item_id)
        res = q.first()
        if res is None:
            return None

        if as_object is True:
            return res
        return (res.id, res.price, res.count)

    @subtransaction
    def add_stock(self, items):

        for item in items:
            q = self.session.query(models.Item).\
              filter(models.Item.code == item['code'])

            citem = q.first()
            stock = models.StockItem(item=citem, count=item['count'],
                                    price=item['price'])
            self.session.add(stock)

    @subtransaction
    def update_stock(self, code, count, price):
        q = self.session.query(models.Item, models.StockItem).\
              with_entities(models.StockItem).\
              join(models.StockItem.item).\
              filter(models.Item.code == code)

        stock = q.first()
        if stock is None:
            q = self.session.query(models.Item).\
              filter(models.Item.code == code)
            item = q.first()
            if item is not None:
                stock = models.StockItem(item=item, count=count, price=price)
                self.session.add(stock)
            else:
                raise KeyError('Unknown item code')
        else:
            stock.count += count
            stock.price = price

    @subtransaction
    def add_items_with_stock(self, items):
        categories = set()
        for item in items:
            for x in item['categories']:
                categories.add(x)
        self.add_categories(categories)

        self.add_items(items)
        self.add_stock(items)

    def list_items(self, sort_key='description',
                   ascending=True, page=1, page_size=10):
        sq = self.session.query(models.Reservation.stock_item_id,
                                func.sum(models.Reservation.count).\
                                label('reserved')).\
                                group_by(models.Reservation.stock_item_id).\
                                subquery()
        q = self.session.query(models.Item, models.Category,
                               models.ItemCategory, models.StockItem,
                               sq.c.reserved).\
          with_entities(models.Item.code, models.Item.description,
                        models.Category.name, models.StockItem.price,
                        models.StockItem.count, sq.c.reserved).\
          join(models.StockItem).\
          join(models.ItemCategory,
               models.Item.id == models.ItemCategory.item_id).\
          join(models.Category,
               models.Category.id == models.ItemCategory.category_id).\
          filter(models.ItemCategory.primary == True).\
          outerjoin(sq, models.StockItem.id == sq.c.stock_item_id)

        q = ordering(q, ascending, sort_key)
        q = pagination(q, page, page_size)

        res = [item_to_json(*x) for x in q]
        return res

    def search_items(self, prefix, price_range, sort_key='description',
                     ascending=True, page=1, page_size=10):
        sq = self.session.query(models.Reservation.stock_item_id,
                                func.sum(models.Reservation.count).\
                                label('reserved')).\
                                group_by(models.Reservation.stock_item_id).\
                                subquery()

        q = self.session.query(models.Item, models.Category,
                               models.ItemCategory, models.StockItem,
                               sq.c.reserved).\
          with_entities(models.Item.code, models.Item.description,
                        models.Category.name, models.StockItem.price,
                        models.StockItem.count,
                        sq.c.reserved).\
          join(models.StockItem).\
          join(models.ItemCategory,
               models.Item.id == models.ItemCategory.item_id).\
          join(models.Category,
               models.Category.id == models.ItemCategory.category_id).\
          filter(models.StockItem.price.between(*price_range),
                 models.Item.description_lower.like('{:s}%'.format(prefix.lower())),
                 models.ItemCategory.primary == True).\
          outerjoin(sq, models.StockItem.id == sq.c.stock_item_id)

        q = ordering(q, ascending, sort_key)
        q = pagination(q, page, page_size)

        res = [item_to_json(*x) for x in q]
        return res

    def list_items_by_prices(self, prices, sort_key='price', prefix=None,
                             ascending=True, page=1, page_size=10):
        pgs = pg_cases(prices)
        pg_case = sqla.case(pgs, else_ = -1).label('price_group')

        sq = self.session.query(models.Reservation.stock_item_id,
                                func.sum(models.Reservation.count).\
                                label('reserved')).\
                                group_by(models.Reservation.stock_item_id).\
                                subquery()

        q = self.session.query(models.Item, models.Category,
                               models.ItemCategory, models.StockItem,
                               sq.c.reserved).\
                               with_entities(pg_case, models.Item.code,
                                             models.Item.description,
                                             models.Category.name,
                                             models.StockItem.price,
                                             models.StockItem.count,
                                             sq.c.reserved)
        if prefix is not None:
            q = q.filter(models.Item.description_lower.like('{:s}%'.format(prefix.lower())))

        q = q.join(models.StockItem.item).\
          outerjoin(sq, models.StockItem.id == sq.c.stock_item_id).\
          filter(models.ItemCategory.item_id == models.Item.id,
                 models.ItemCategory.category_id == models.Category.id,
                 models.ItemCategory.primary == True,
                 pg_case >= 0)

        q = pg_ordering(q, ascending)
        q = pagination(q, page, page_size)
        def to_dict(x):
            tmp = item_to_json(*x[1:])
            tmp.update({'price_group': x[0]})
            return tmp
        res = [to_dict(x) for x in q]
        return res

    def _get_reservations(self, stock_id):
        q = self.session.query(func.sum(models.Reservation.count)).\
          filter(models.StockItem.id == stock_id).\
          filter(models.StockItem.id == models.Reservation.stock_item_id)

        res = q.first()
        if res is None or res[0] is None:
            return 0

        return res[0]

    def _get_reservation(self, basket_item_id):
        q = self.session.query(models.Basket, models.BasketItem,
                               models.Reservation).\
          with_entities(models.Reservation).\
          join(models.Reservation.basket_item).\
          filter(models.BasketItem.id == basket_item_id)
        res = q.first()
        return res

    @subtransaction
    def _update_reservation(self, stock, basket_item):
        reservations = self._get_reservations(basket_item.stock_item_id)
        # can reserve (scount - reservations)
        reservation = self._get_reservation(basket_item.id)

        if reservation is not None:
            rcount = min(basket_item.count,
                         stock.count - reservations + reservation.count)
            reservation.count = rcount
        else:
            rcount = min(basket_item.count, stock.count - reservations)
            reservation = models.Reservation(stock_item=stock,
                                             basket_item=basket_item,
                                             count=rcount)
            self.session.add(reservation)