def upgrade():
    from sqlalchemy.orm.session import Session
    session = Session(bind=op.get_bind())
    # Find duplicates
    for vcs in session.query(Vcs).group_by(Vcs.repository_id, Vcs.revision).having(func.count(Vcs.id) > 1).all():
        print(vcs)
        # Find all vcs entries with this duplication
        dupes = session.query(Vcs).filter(Vcs.repository_id == vcs.repository_id).filter(Vcs.revision == vcs.revision).all()
        # Keep the first and remove the others - thus we need to update references to the others to the first
        for update in dupes[1:]:
            for af in session.query(Artifakt).filter(Artifakt.vcs_id == update.id).all():
                print("Updating artifakt {} to point to vcs {}".format(af.sha1, dupes[0].id))
                af.vcs_id = dupes[0].id
            print("Deleting vcs {}".format(update.id))
            session.delete(update)
    session.commit()

    if session.bind.dialect.name == "sqlite":
        session.execute("PRAGMA foreign_keys = OFF")
    elif session.bind.dialect.name == "mysql":
        session.execute("SET foreign_key_checks = 0")
    else:
        raise NotImplementedError

    with op.batch_alter_table('vcs', schema=None) as batch_op:
        batch_op.create_unique_constraint('rr', ['repository_id', 'revision'])

    if session.bind.dialect.name == "sqlite":
        session.execute("PRAGMA foreign_keys = ON")
    elif session.bind.dialect.name == "mysql":
        session.execute("SET foreign_key_checks = 1")
class DBBackedController(object):
    def __init__(self, engine, session=None):
        if isinstance(engine, str):
            # A connection string was passed rather than an Engine.
            logging.info("Using connection string '%s'" % (engine,))
            new_engine = create_engine(engine, encoding='utf-8')
            if "sqlite:" in engine:
                logging.debug("Setting text factory for unicode compat.")
                new_engine.raw_connection().connection.text_factory = str
            self._engine = new_engine
        else:
            logging.info("Using existing engine...")
            self._engine = engine

        if session is None:
            logging.info("Binding session...")
            self._session = Session(bind=self._engine)
        else:
            self._session = session

        logging.info("Updating metadata...")
        Base.metadata.create_all(self._engine)

    def commit(self):
        logging.info("Committing...")
        self._session.commit()
def init_settings(session: Session): set_new_setting(session, "ffmpeg_path", "ffmpeg") set_new_setting(session, "ffmpeg_verbose", False) set_new_setting(session, "input_directory", "/set/me") set_new_setting(session, "output_directory", "esa_videos") set_new_setting(session, "final_output_directory", "esa_final_videos") set_new_setting(session, "common_keywords", "ESA") set_new_setting(session, "input_video_extension", ".flv") set_new_setting(session, "output_video_extension", ".mp4") set_new_setting(session, "output_video_prefix", "esa_") set_new_setting(session, "upload_logfile", "esa_upload.log") set_new_setting(session, "youtube_public", False) set_new_setting(session, "youtube_schedule", False) set_new_setting(session, "youtube_schedule_start_ts", 0) set_new_setting(session, "schedule_offset_hours", 12) set_new_setting(session, "upload_to_twitch", True) set_new_setting(session, "upload_to_youtube", True) set_new_setting(session, "delete_after_upload", True) set_new_setting(session, "upload_ratelimit", 2.5) set_new_setting(session, "upload_retry_count", 3) set_new_setting(session, "twitch_ul_pid", 0) set_new_setting(session, "twitch_ul_create_time", 0.0) set_new_setting(session, "youtube_ul_pid", 0) set_new_setting(session, "youtube_ul_create_time", 0.0) set_new_setting(session, "proc_ul_pid", 0) set_new_setting(session, "proc_ul_create_time", 0.0) set_new_setting(session, "init_proc_status", "") session.commit()
def post_process_video(session: Session, debug): vids = session.query(Video) vids = vids.filter(Video.processing_done) vids = vids.filter(not_(Video.post_processing_done)) vids = vids.filter(or_(and_(Video.do_twitch, Video.done_twitch), not_(Video.do_twitch))) vids = vids.filter(or_(and_(Video.do_youtube, Video.done_youtube), not_(Video.do_youtube))) vids = vids.order_by(Video.id.asc()) vid = vids.first() if not vid: print("No video in need of processing found") return 0 out_dir = get_setting(session, "output_directory") final_dir = get_setting(session, "final_output_directory") out_prefix = get_setting(session, "output_video_prefix") out_ext = get_setting(session, "output_video_extension") out_fname = "%s%s%s" % (out_prefix, vid.id, out_ext) out_path = os.path.join(out_dir, out_fname) final_path = os.path.join(final_dir, out_fname) if out_path != final_path: shutil.move(out_path, final_path) vid.post_processing_status = "Moved %s to %s" % (out_path, final_path) else: vid.post_processing_status = "Nothing to do" vid.post_processing_done = True session.commit() return 0
class TestSessions(TestCase): plugins = [] def test_multiple_connections(self): self.session2 = Session(bind=self.engine.connect()) article = self.Article(name=u'Session1 article') article2 = self.Article(name=u'Session2 article') self.session.add(article) self.session2.add(article2) self.session.flush() self.session2.flush() self.session.commit() self.session2.commit() assert article.versions[-1].transaction_id == 1 assert article2.versions[-1].transaction_id == 2 def test_manual_transaction_creation(self): uow = versioning_manager.unit_of_work(self.session) transaction = uow.create_transaction(self.session) self.session.flush() assert transaction.id == 1 article = self.Article(name=u'Session1 article') self.session.add(article) self.session.flush() assert uow.current_transaction.id == 1 self.session.commit() assert article.versions[-1].transaction_id == 1 def test_commit_without_objects(self): self.session.commit()
def create_session(self):
    session = Session(self.connection)
    try:
        yield session
        session.commit()
    except Exception as ex:
        session.rollback()
        raise ex
    finally:
        session.close()
class VMDatabase(object):
    def __init__(self, path=os.path.expanduser("~/vm.db")):
        engine = create_engine('sqlite:///%s' % path, echo=False)
        metadata.create_all(engine)
        Session = sessionmaker(bind=engine, autoflush=True, autocommit=False)
        self.session = Session()

    def print_state(self):
        for provider in self.getProviders():
            print('Provider:', provider.name)
            for base_image in provider.base_images:
                print('  Base image:', base_image.name)
                for snapshot_image in base_image.snapshot_images:
                    print('    Snapshot:', snapshot_image.name,
                          snapshot_image.state)
                for machine in base_image.machines:
                    print('    Machine:', machine.id, machine.name,
                          machine.state, machine.state_time, machine.ip)

    def abort(self):
        self.session.rollback()

    def commit(self):
        self.session.commit()

    def delete(self, obj):
        self.session.delete(obj)

    def getProviders(self):
        return self.session.query(Provider).all()

    def getProvider(self, name):
        return self.session.query(Provider).filter_by(name=name)[0]

    def getResult(self, id):
        return self.session.query(Result).filter_by(id=id)[0]

    def getMachine(self, id):
        return self.session.query(Machine).filter_by(id=id)[0]

    def getMachineByJenkinsName(self, name):
        return self.session.query(Machine).filter_by(jenkins_name=name)[0]

    def getMachineForUse(self, image_name):
        """Atomically find a machine that is ready for use, and update its
        state."""
        for machine in self.session.query(Machine).filter(
                machine_table.c.state == READY).order_by(
                machine_table.c.state_time):
            if machine.base_image.name == image_name:
                machine.state = USED
                self.commit()
                return machine
        raise Exception("No machine found for image %s" % image_name)
def upgrade():
    context = op.get_context()
    session = Session()
    session.bind = context.bind
    for tp in session.query(TimeEntry).filter_by(customer_request_id=None):
        for trac in tp.project.tracs:
            cr = session.execute(
                'select value from "trac_%s".ticket_custom '
                'where name=\'customerrequest\' and ticket=%s'
                % (trac.trac_name, tp.ticket)).fetchone()
            sql_cr = session.execute(
                'select id from customer_requests where id=\'%s\''
                % cr.value).fetchone()
            tp.customer_request_id = sql_cr.id
            print(sql_cr.id)
    session.commit()
def upgrade():
    ### commands auto generated by Alembic - please adjust! ###
    op.add_column('time_entries',
                  sa.Column('tickettitle', sa.Unicode(), nullable=True))
    context = op.get_context()
    session = Session()
    session.bind = context.bind
    for tp in session.query(TimeEntry):
        for trac in tp.project.tracs:
            ticket = session.execute(
                'select summary from "trac_%s".ticket where id=%s'
                % (trac.trac_name, tp.ticket)).fetchone()
            tp.tickettitle = ticket.summary
    session.commit()
def wrapper(*args, **kwargs):
    ret = func(*args, **kwargs)
    session = Session()  # tmp
    session._model_changes = {}
    try:
        session.commit()
    except SQLAlchemyError as se:
        session.rollback()
        raise_server_exc(DATABASE_UNKNOWN_ERROR, exc=se)
    finally:
        session.close()
    return ret
def upgrade(): context = op.get_context() session = Session() session.bind = context.bind session.execute( """ALTER TABLE applications DROP CONSTRAINT "applications_project_id_fkey", ADD CONSTRAINT "applications_project_id_fkey" foreign key (project_id) references projects(id) on update cascade; ALTER TABLE customer_requests DROP CONSTRAINT "customer_requests_project_id_fkey", ADD CONSTRAINT "customer_requests_project_id_fkey" foreign key (project_id) references projects(id) on update cascade; ALTER TABLE groups DROP CONSTRAINT "groups_project_id_fkey", ADD CONSTRAINT "groups_project_id_fkey" foreign key (project_id) references projects(id) on update cascade; ALTER TABLE contracts ADD CONSTRAINT "contracts_project_id_fkey" foreign key (project_id) references projects(id) on update cascade; ALTER TABLE kanban_projects DROP CONSTRAINT "kanban_projects_project_id_fkey", ADD CONSTRAINT "kanban_projects_project_id_fkey" foreign key (project_id) references projects(id) on update cascade on delete cascade; ALTER TABLE favorite_projects DROP CONSTRAINT "favorite_projects_project_id_fkey", ADD CONSTRAINT "favorite_projects_project_id_fkey" foreign key (project_id) references projects(id) on update cascade; ALTER TABLE time_entries DROP CONSTRAINT "time_entries_project_id_fkey", ADD CONSTRAINT "time_entries_project_id_fkey" foreign key (project_id) references projects(id) on update cascade;""") session.commit()
def delete_video(session: Session, vid_id: int):
    vid = session.query(Video).get(vid_id)
    if not vid:
        print("Video ID not found")
        return -1
    print("Deleting video %s: %s" % (vid.id, vid.title))
    session.delete(vid)
    session.commit()
    print("Done")
    return 0
def create_session(self): """ Create a session context that communicates with the database. Commits all changes to the database before closing the session, and if an exception is raised, rollback the session. """ session = Session(self.connection) try: yield session session.commit() except Exception as ex: session.rollback() raise ex finally: session.close()
def upgrade():
    context = op.get_context()
    session = Session()
    session.bind = context.bind
    for b in session.query(KanbanBoard):
        if b.json:
            columns = loads(b.json)
            for n, column in enumerate(columns):
                for m, task in enumerate(column['tasks']):
                    new_id = update_id(task['id'])
                    print(task['id'], new_id)
                    if new_id != task['id']:
                        columns[n]['tasks'][m]['id'] = new_id
            b.json = dumps(columns)
    session.commit()
def upgrade():
    ### commands auto generated by Alembic - please adjust! ###
    op.add_column('time_entries',
                  sa.Column('customer_request_id', sa.String(), nullable=True))
    op.drop_column('time_entries', u'contract_id')
    op.drop_column('customer_requests', u'placement')
    context = op.get_context()
    session = Session()
    session.bind = context.bind
    for tp in session.query(TimeEntry):
        for trac in tp.project.tracs:
            cr = session.execute(
                'select value from "trac_%s".ticket_custom '
                'where name=\'customerrequest\' and ticket=%s'
                % (trac.trac_name, tp.ticket)).fetchone()
            sql_cr = session.execute(
                'select id from customer_requests where id=\'%s\''
                % cr.value).fetchone()
            tp.customer_request_id = sql_cr.id
    session.commit()
def transaction(cls, session: db_session.Session) -> Iterator[db_session.Session]:
    """
    Provides a transactional context-based database session.

    Args:
        session: The database session instance to wrap

    Yields:
        The wrapped database session instance
    """
    try:
        yield session
        session.commit()
    except:
        session.rollback()
        raise
    finally:
        session.close()
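# Usage sketch for the transactional generator above: wrapped with
# contextlib.contextmanager it becomes a `with`-block helper that commits on
# success and rolls back on any exception. The factory and model names below
# (make_engine, User) are illustrative assumptions, not taken from the snippet.
from contextlib import contextmanager
from sqlalchemy.orm import Session as SASession


@contextmanager
def transactional(session: SASession):
    try:
        yield session
        session.commit()     # commit only if the block ran without error
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()


# with transactional(SASession(bind=make_engine())) as session:
#     session.add(User(name="example"))  # committed when the block exits cleanly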
def setup_module(): global transaction, connection, engine engine = create_engine('postgresql:///recordsheet_test') connection = engine.connect() transaction = connection.begin() Base.metadata.create_all(connection) #insert some data inner_tr = connection.begin_nested() ses = Session(connection) # ses.begin_nested() ses.add(dbmodel.Account(name='TEST01', desc='test account 01')) ses.add(dbmodel.Account(name='TEST02', desc='test account 02')) user = dbmodel.User(username='******', name='Test T. User', password=dbapi.new_pw_hash('passtestword'), locked=False) ses.add(user) lockeduser = dbmodel.User(username='******', name='Test T. User', password=dbapi.new_pw_hash('passtestword'), locked=True) ses.add(lockeduser) batch = dbmodel.Batch(user=user) ses.add(batch) jrnl = dbmodel.Journal(memo='test', batch=batch, datetime='2016-06-05 14:09:00-05') ses.add(jrnl) ses.add(dbmodel.Posting(memo="test", amount=100, account_id=1, journal=jrnl)) ses.add(dbmodel.Posting(memo="test", amount=-100, account_id=2, journal=jrnl)) ses.commit() # mock a sessionmaker so all querys are in this transaction dbapi._session = lambda: ses ses.begin_nested() @event.listens_for(ses, "after_transaction_end") def restart_savepoint(session, transaction): if transaction.nested and not transaction._parent.nested: ses.begin_nested()
def setup_databases(self, **kwargs):
    # import all models to populate orm.metadata
    for app in settings.INSTALLED_APPS:
        import_any_module(['%s.models' % app], raise_error=False)

    # determine which schemas we need
    default_schema = orm.engine.url.database
    schemas = set(t.schema or default_schema
                  for t in Base.metadata.tables.values())
    url = deepcopy(orm.engine.url)
    url.database = None
    self.engine = create_engine(url)
    insp = inspect(self.engine)

    # get a list of already-existing schemas
    existing_schemas = set(insp.get_schema_names())

    # if any of the needed schemas exist, do not proceed
    conflicts = schemas.intersection(existing_schemas)
    if conflicts:
        for c in conflicts:
            print('drop schema %s;' % c)
        sys.exit('The following schemas are already present: %s. '
                 'TestRunner cannot proceed' % ','.join(conflicts))

    # create schemas
    session = Session(bind=self.engine)
    for schema in schemas:
        session.execute(CreateSchema(schema))
    session.commit()
    session.bind.dispose()

    # create tables
    if len(orm.Base.metadata.tables) > 0:
        orm.Base.metadata.create_all(checkfirst=False)

    # generate permissions
    call_command('createpermissions')
    return schemas
def create_session(self): """ Create a session context that communicates with the database. Commits all changes to the database before closing the session, and if an exception is raised, rollback the session. """ session = Session(self.connection) logger.info("Create session {0} with {1}".format( id(session), self._public_db_uri(str(self.engine.url)))) try: yield session session.commit() logger.info("Commit transactions to database") except Exception: session.rollback() logger.exception("Database transactions rolled back") finally: logger.info("Session {0} with {1} closed".format( id(session), self._public_db_uri(str(self.engine.url)))) session.close()
def upgrade(migrate_engine): """ Upgrade operations go here. Don't create your own engine; bind migrate_engine to your metadata """ #========================================================================== # USER LOGS #========================================================================== from rhodecode.lib.dbmigrate.schema.db_1_5_0 import UserLog tbl = UserLog.__table__ username = Column("username", String(255, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None) # create username column username.create(table=tbl) _Session = Session() ## after adding that column fix all usernames users_log = _Session.query(UserLog)\ .options(joinedload(UserLog.user))\ .options(joinedload(UserLog.repository)).all() for entry in users_log: entry.username = entry.user.username _Session.add(entry) _Session.commit() #alter username to not null from rhodecode.lib.dbmigrate.schema.db_1_5_0 import UserLog tbl_name = UserLog.__tablename__ tbl = Table(tbl_name, MetaData(bind=migrate_engine), autoload=True, autoload_with=migrate_engine) col = tbl.columns.username # remove nullability from revision field col.alter(nullable=False)
def upgrade(): bind = op.get_bind() session = Session(bind=bind) session.query(ConfigModel).filter( ConfigModel.group == 'station', ConfigModel.name == 'required_fields').update({ 'value': pickle.dumps([ 'code', 'start_date', 'latitude', 'longitude', 'creation_date', 'elevation' ]) }) session.query(ConfigModel).filter( ConfigModel.group == 'channel', ConfigModel.name == 'required_fields').update({ 'value': pickle.dumps([ 'code', 'start_date', 'latitude', 'longitude', 'location_code', 'elevation', 'depth' ]) }) session.commit()
async def create(self, db: Session, *, order_in: OrderCreate) -> Order:
    new_order = Order(business_id=order_in.business_id)
    db.add(new_order)
    db.commit()
    # Reload the order so its generated primary key is available.
    db.refresh(new_order)
    item_orders: List[ItemOrder] = []
    for item in order_in.items_in_order:
        item_db = await crud.item.get(db, id=item.item_id)
        if not item_db:
            raise HTTPException(status_code=404,
                                detail=f"Item {item.item_id} not found")
        # Create a new relation ItemOrder
        item_order = ItemOrder(
            tot_price=item.tot_price,
            quantity=item.quantity,
            order_id=new_order.id,
            item_id=item_db.id,
        )
        item_orders.append(item_order)
    new_order.items = item_orders
    db.add(new_order)
    db.commit()
    db.refresh(new_order)
    return new_order
def create(session: Session, state: TaskState, timestamp: datetime,
           repo_dir: str, task_type: TaskType, params: dict = None) -> Task:
    """
    Creates a new Task record.

    Args:
        session (Session): The database session.
        state (TaskState): The task state.
        timestamp (datetime): The time when the task was created.
        repo_dir (str): The task repository.
        task_type (TaskType): The task type.
        params (dict, optional): The task parameters. Defaults to None.

    Raises:
        ValueError: Thrown when state, timestamp, repo_dir or task_type is
            missing.
        TaskAlreadyExistsError: Thrown when a task with the same task_id
            already exists.

    Returns:
        Task: The task.
    """
    if state is None or not timestamp or not repo_dir or task_type is None:
        raise ValueError(
            "You cannot create a task without a state, a timestamp and a repo_dir."
        )
    try:
        params_1: str = json.dumps(params)
        task: Task = Task(state.value, timestamp, repo_dir, task_type.value,
                          params_1)
        session.add(task)
        session.commit()
        return task
    except IntegrityError as err:
        raise TaskAlreadyExistsError from err
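# A hedged usage sketch for the create() helper above. The session factory,
# enum members and repo path are illustrative assumptions only; note that the
# helper serialises `params` to JSON before persisting the Task and does not
# roll back on IntegrityError itself.
from datetime import datetime

session = Session()  # assumes a configured sessionmaker / Session factory
try:
    task = create(
        session,
        state=TaskState.PENDING,        # hypothetical enum member
        timestamp=datetime.utcnow(),
        repo_dir="/tmp/example-repo",   # illustrative path
        task_type=TaskType.ANALYSIS,    # hypothetical enum member
        params={"branch": "main"},
    )
except TaskAlreadyExistsError:
    session.rollback()                  # caller cleans up the failed transaction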
def update(
    cls,
    session: Session,
    id: int,
    params: typing.Dict[str, typing.Any],
    commit: bool = True,
    autoflush: bool = False,
) -> DeclarativeMeta:
    obj = cls.get(session, id)
    if obj:
        for param in params:
            setattr(obj, param, params[param])
        if commit:
            session.commit()
            session.refresh(obj)
        if autoflush:
            session.flush()
    return obj
def create_comment(cls, praw_comment: PrawComment, post: Post, session: Session, download_session_id: int, parent_comment_id: Optional[int] = None): if cls.check_duplicate_comment(praw_comment.id, session): author = cls.get_author(praw_comment, session) subreddit = cls.get_subreddit(praw_comment, session) comment = Comment(author=author, subreddit=subreddit, post=post, reddit_id=praw_comment.id, body=praw_comment.body, body_html=praw_comment.body_html, score=praw_comment.score, date_posted=datetime.fromtimestamp( praw_comment.created), parent_id=parent_comment_id, download_session_id=download_session_id) session.add(comment) session.commit() return comment return None
async def send_message_to_user(self, user_id: UUID, sender_id: UUID, message: str, session: Session = None): db_message = ChatMessage(from_user_id=sender_id, to_user_id=user_id, message=message) session.add(db_message) session.commit() from_chatuser = self.chat_users.get(sender_id) to_chatuser = self.chat_users.get(user_id) event = ChatEvent( event_type=ChatEventType.MESSAGE, data={ "from": str(sender_id), "to": str(user_id), "message": message }, ) if from_chatuser: await from_chatuser.send_json(event.dict(by_alias=True)) if to_chatuser: await to_chatuser.send_json(event.dict(by_alias=True))
def updateHardwareProfile(self, session: Session, hardwareProfileObject: HardwareProfile) -> None: """ Update Hardware Profile Object """ try: dbHardwareProfile = \ self._hardwareProfilesDbHandler.getHardwareProfileById( session, hardwareProfileObject.getId()) self.__populateHardwareProfile(session, hardwareProfileObject, dbHardwareProfile) self._set_tags(dbHardwareProfile, hardwareProfileObject.getTags()) session.commit() except TortugaException: session.rollback() raise except Exception as ex: session.rollback() self.getLogger().exception('%s' % ex) raise
def test_check_auth_redis_miss_wrong_auth_id(self):
    db_handler = self.auth_handler.db_handler
    db_conn = db_handler.getEngine().connect()
    db_txn = db_conn.begin()
    try:
        db_session = Session(bind=db_conn)
        try:
            account = Account(auth_id='some_auth_id', username='******')
            db_session.add(account)
            db_session.flush()
            phonenumber = PhoneNumber(number='9740171794',
                                      account_id=account.id)
            db_session.add(phonenumber)
            db_session.commit()
            self.auth_handler.redis_client.hget.return_value = None
            status, phonenums = self.auth_handler._check_auth(
                'some_user', 'faulty_auth_id', db_session)
            self.assertFalse(status)
            self.assertEqual(None, phonenums)
        finally:
            db_session.close()
    finally:
        db_txn.rollback()
        db_conn.close()
def migrate_archived_page_urls(session: Session, cur, table_archives: str): t0 = time.time() logger.info('Migrating %s', table_archives) n_in = 0 n_out = 0 cur.execute( f'SELECT feed_url, archived_url, date, inserted FROM {table_archives}') for feed_url, archived_url, date, inserted in cur: n_in += 1 logger.debug('%d %s', n_in, archived_url) if not count( session.query(ArchivedPageURL).filter( ArchivedPageURL.feed_url == feed_url, ArchivedPageURL.archived_url == archived_url, ArchivedPageURL.date == date, )): archived_page_url = ArchivedPageURL( feed_url=feed_url, archived_url=archived_url, date=date, inserted=inserted, ) session.add(archived_page_url) if n_in % 10000 == 0: logger.info('%d flush', n_in) session.flush() n_out += 1 logger.info('commit') session.commit() logger.info( 'Migrated %s: %d -> %d in %ds', table_archives, n_in, n_out, time.time() - t0, )
def upgrade():
    from sqlalchemy.orm.session import Session
    session = Session(bind=op.get_bind())
    # Add dummy name where there is none
    for repo in session.query(Repository).all():
        if repo.name == "":
            repo.name = "NoName"
    session.commit()

    if session.bind.dialect.name == "sqlite":
        session.execute("PRAGMA foreign_keys = OFF")
    elif session.bind.dialect.name == "mysql":
        session.execute("SET foreign_key_checks = 0")
    else:
        raise NotImplementedError

    with op.batch_alter_table('repository', schema=None) as batch_op:
        batch_op.create_check_constraint('non_empty_name', 'name != ""')

    if session.bind.dialect.name == "sqlite":
        session.execute("PRAGMA foreign_keys = ON")
    elif session.bind.dialect.name == "mysql":
        session.execute("SET foreign_key_checks = 1")
def upgrade(): # ### commands auto generated by Alembic - please adjust! ### op.add_column( "project_task_type_link", sa.Column("created_at", sa.DateTime(), nullable=True), ) op.add_column( "project_task_type_link", sa.Column( "id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), default=uuid.uuid4, nullable=True, ), ) op.execute("UPDATE project_task_type_link SET id = '%s'" % fields.gen_uuid()) op.add_column( "project_task_type_link", sa.Column("updated_at", sa.DateTime(), nullable=True), ) op.add_column( "project_task_type_link", sa.Column("priority", sa.Integer(), nullable=True), ) session = Session(bind=op.get_bind()) for link in session.query(ProjectTaskTypeLink).all(): link.id = fields.gen_uuid() session.add(link) session.commit() op.alter_column("project_task_type_link", "id", nullable=False) op.create_unique_constraint( "project_tasktype_uc", "project_task_type_link", ["project_id", "task_type_id"], )
def deleteHardwareProfile(self, session: Session, name: str) -> None: """ Delete hardwareProfile from the db. Returns: None Throws: HardwareProfileNotFound DbError TortugaException """ try: hwProfile = self._hardwareProfilesDbHandler.getHardwareProfile( session, name) if hwProfile.nodes: raise TortugaException( 'Unable to delete hardware profile with associated nodes') # First delete the mappings hwProfile.mappedsoftwareprofiles = [] self._logger.debug('Marking hardware profile [%s] for deletion' % (name)) session.delete(hwProfile) session.commit() except TortugaException: session.rollback() raise except Exception as ex: session.rollback() self._logger.exception(str(ex)) raise
def migration_session():
    session = Session(bind=op.get_bind())
    task_queue = RollbackableQueue()
    try:
        yield session, task_queue
        print("Committing SQL session...")
        session.commit()
        print("Session commit succeeded.")
    except Exception as e:
        print(f"Encountered exception: {e.__class__}\n{e}")
        print("Running SQL rollback...")
        session.rollback()
        print("SQL rollback succeeded.")
        if task_queue:
            try:
                print("Running GCS rollback...")
                task_queue.rollback()
                print("GCS rollback succeeded.")
            except Exception as e:
                print(f"GCS rollback failed: {e.__class__}\n{e}")
        raise
    finally:
        session.close()
def update_import_errors(session: Session, dagbag: DagBag) -> None: """ For the DAGs in the given DagBag, record any associated import errors and clears errors for files that no longer have them. These are usually displayed through the Airflow UI so that users know that there are issues parsing DAGs. :param session: session for ORM operations :type session: sqlalchemy.orm.session.Session :param dagbag: DagBag containing DAGs with import errors :type dagbag: airflow.DagBag """ # Clear the errors of the processed files for dagbag_file in dagbag.file_last_changed: session.query(errors.ImportError).filter( errors.ImportError.filename.startswith(dagbag_file)).delete( synchronize_session="fetch") # Add the errors of the processed files for filename, stacktrace in dagbag.import_errors.items(): session.add( errors.ImportError(filename=filename, timestamp=timezone.utcnow(), stacktrace=stacktrace)) session.commit()
def deleteKit(self, session: Session, name, version, iteration, force=False):
    """
    Delete kit from the db.

    Raises:
        KitNotFound
        KitInUse
        DbError
    """
    try:
        self._kitsDbHandler.deleteKit(
            session, name, version, iteration, force=force)
        session.commit()
    except TortugaException:
        session.rollback()
        raise
    except Exception as ex:
        session.rollback()
        self._logger.exception(str(ex))
        raise
def create(session: Session, repo_dir: str, issue_id: int, author: str,
           title: str, description: str, labels: list,
           is_pull_request: bool) -> Issue:
    """
    Creates a new Issue record.

    Args:
        session (Session): The database session.
        repo_dir (str): The issue repository.
        issue_id (int): The issue identifier.
        author (str): The issue author.
        title (str): The issue title.
        description (str): The issue description.
        labels (list): The issue labels.
        is_pull_request (bool): If true the issue is a pull request,
            otherwise false.

    Raises:
        ValueError: Thrown when missing repo_dir, issue_id, author or title.
        IssueAlreadyExistsError: Thrown when an issue with the same issue_id
            and repo_dir already exists.

    Returns:
        Issue: The issue.
    """
    if not repo_dir or issue_id is None or not author or not title:
        raise ValueError("You cannot create an issue without a repo_dir, "
                         "an issue_id, an author and a title.")
    try:
        labels_1: str = json.dumps(labels)
        issue: Issue = Issue(repo_dir, issue_id, author, title, description,
                             labels_1, is_pull_request)
        session.add(issue)
        session.commit()
        return issue
    except IntegrityError as err:
        session.rollback()
        raise IssueAlreadyExistsError from err
def add_token(request: LocalProxy, session: Session) -> str:
    """
    Create a new token for the user or return a valid existing token to the
    user.
    """
    token = None
    id_ = int(request.authorization['username'])
    try:
        token = session.query(Token).filter(Token.user_id == id_).one()
        present = datetime.now()
        present = present - token.timestamp
        # Regenerate the token if the existing one is older than 45 minutes.
        if present > timedelta(0, 0, 0, 0, 45, 0, 0):
            update_token = '%030x' % randrange(16**30)
            token.id = update_token
            token.timestamp = datetime.now()
            session.commit()
    except NoResultFound:
        token = '%030x' % randrange(16**30)
        time = datetime.now()
        new_token = Token(user_id=id_, id=token, timestamp=time)
        session.add(new_token)
        session.commit()
        return token
    return token.id
def addPartition(self, session: Session, partitionName: str, softwareProfileName: str) -> None: """ Add software profile partition. Returns: partitionId Throws: PartitionAlreadyExists SoftwareProfileNotFound """ try: self._softwareProfilesDbHandler.addPartitionToSoftwareProfile( session, partitionName, softwareProfileName) session.commit() except TortugaException: session.rollback() raise except Exception as ex: session.rollback() self._logger.exception(str(ex)) raise
def addCompartmentalizedComponent(componentId, compartmentId):
    session = Session()
    if not session.query(CompartmentalizedComponent).filter(
            CompartmentalizedComponent.component_id == componentId).filter(
            CompartmentalizedComponent.compartment_id == compartmentId).count():
        try:
            component = session.query(Component).filter(
                Component.id == componentId).one()
        except Exception:
            print("component does not exist in database")
            raise
        try:
            compartment = session.query(Compartment).filter(
                Compartment.id == compartmentId).one()
        except Exception:
            print("compartment does not exist in database")
            raise
        cc = CompartmentalizedComponent(component_id=component.id,
                                        compartment_id=compartment.id)
        session.add(cc)
        session.commit()
        session.close()
        return cc
def addReactionMatrix(reactionId, compartmentalizedComponentId):
    session = Session()
    if not session.query(ReactionMatrix).filter(
            ReactionMatrix.reaction_id == reactionId).filter(
            ReactionMatrix.compartmentalized_component_id ==
            compartmentalizedComponentId).count():
        try:
            cc = session.query(CompartmentalizedComponent).filter(
                CompartmentalizedComponent.id ==
                compartmentalizedComponentId).one()
        except Exception:
            print("compartmentalized component does not exist in database")
            raise
        try:
            reaction = session.query(Reaction).filter(
                Reaction.id == reactionId).one()
        except Exception:
            print("reaction does not exist in database")
            raise
        rm = ReactionMatrix(reaction_id=reaction.id,
                            compartmentalized_component_id=cc.id)
        session.add(rm)
        session.commit()
        session.close()
        return rm
def process_file(session: Session, file: File,
                 callback: Callable[[], None]) -> bool:
    if file.processing_started_at:
        return False

    # Claim this file by updating the `processing_started_at` timestamp in such
    # a way that it must not have been set before.
    processing_started_at = datetime.datetime.now(timezone.utc)
    result = session.execute(
        update(File.__table__)  # pylint: disable=no-member
        .where(File.id == file.id)
        .where(File.processing_started_at.is_(None))
        .values(processing_started_at=processing_started_at))
    if result.rowcount == 0:
        return False

    # If we got this far, `file` is ours to process.
    try:
        session.begin_nested()
        callback()
        file.processing_started_at = processing_started_at
        file.processing_completed_at = datetime.datetime.now(timezone.utc)
        session.add(file)
        session.commit()
        return True
    except Exception as error:
        session.rollback()
        file.processing_started_at = processing_started_at
        file.processing_completed_at = datetime.datetime.now(timezone.utc)
        # Some errors stringify nicely, some don't (e.g. StopIteration) so we
        # have to format them.
        file.processing_error = str(error) or str(
            traceback.format_exception(error.__class__, error,
                                       error.__traceback__))
        if not isinstance(error, UserError):
            raise error
        return True
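# The claim step above hinges on a conditional UPDATE: only the worker whose
# UPDATE still finds processing_started_at IS NULL gets rowcount == 1 and may
# run the callback. A stripped-down sketch of that idea, reusing the File
# model from the snippet; the helper name try_claim is illustrative and the
# caller is assumed to commit afterwards.
import datetime
from sqlalchemy import update


def try_claim(session, file_id) -> bool:
    now = datetime.datetime.now(datetime.timezone.utc)
    result = session.execute(
        update(File.__table__)
        .where(File.id == file_id)
        .where(File.processing_started_at.is_(None))
        .values(processing_started_at=now))
    # rowcount == 0 means another worker claimed the file first.
    return result.rowcount == 1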
def upgrade():
    session = Session(bind=op.get_bind())
    changes = session.query(Sticker).filter(Sticker.file_id.is_(None)).delete()
    session.commit()
class DatabaseTest(object): engine = None connection = None @classmethod def get_database_connection(cls): url = Configuration.database_url() engine, connection = SessionManager.initialize(url) return engine, connection @classmethod def setup_class(cls): # Initialize a temporary data directory. cls.engine, cls.connection = cls.get_database_connection() cls.old_data_dir = Configuration.data_directory cls.tmp_data_dir = tempfile.mkdtemp(dir="/tmp") Configuration.instance[Configuration.DATA_DIRECTORY] = cls.tmp_data_dir # Avoid CannotLoadConfiguration errors related to CDN integrations. Configuration.instance[ Configuration.INTEGRATIONS] = Configuration.instance.get( Configuration.INTEGRATIONS, {}) Configuration.instance[Configuration.INTEGRATIONS][ ExternalIntegration.CDN] = {} @classmethod def teardown_class(cls): # Destroy the database connection and engine. cls.connection.close() cls.engine.dispose() if cls.tmp_data_dir.startswith("/tmp"): logging.debug("Removing temporary directory %s" % cls.tmp_data_dir) shutil.rmtree(cls.tmp_data_dir) else: logging.warn( "Cowardly refusing to remove 'temporary' directory %s" % cls.tmp_data_dir) Configuration.instance[Configuration.DATA_DIRECTORY] = cls.old_data_dir def setup(self, mock_search=True): # Create a new connection to the database. self._db = Session(self.connection) self.transaction = self.connection.begin_nested() # Start with a high number so it won't interfere with tests that search for an age or grade self.counter = 2000 self.time_counter = datetime(2014, 1, 1) self.isbns = [ "9780674368279", "0636920028468", "9781936460236", "9780316075978" ] if mock_search: self.search_mock = mock.patch( external_search.__name__ + ".ExternalSearchIndex", MockExternalSearchIndex) self.search_mock.start() else: self.search_mock = None # TODO: keeping this for now, but need to fix it bc it hits _isbn, # which pops an isbn off the list and messes tests up. so exclude # _ functions from participating. # also attempt to stop nosetest showing docstrings instead of function names. #for name, obj in inspect.getmembers(self): # if inspect.isfunction(obj) and obj.__name__.startswith('test_'): # obj.__doc__ = None def teardown(self): # Close the session. self._db.close() # Roll back all database changes that happened during this # test, whether in the session that was just closed or some # other session. self.transaction.rollback() # Remove any database objects cached in the model classes but # associated with the now-rolled-back session. Collection.reset_cache() ConfigurationSetting.reset_cache() DataSource.reset_cache() DeliveryMechanism.reset_cache() ExternalIntegration.reset_cache() Genre.reset_cache() Library.reset_cache() # Also roll back any record of those changes in the # Configuration instance. for key in [ Configuration.SITE_CONFIGURATION_LAST_UPDATE, Configuration.LAST_CHECKED_FOR_SITE_CONFIGURATION_UPDATE ]: if key in Configuration.instance: del (Configuration.instance[key]) if self.search_mock: self.search_mock.stop() def shortDescription(self): return None # Stop nosetests displaying docstrings instead of class names when verbosity level >= 2. 
@property def _id(self): self.counter += 1 return self.counter @property def _str(self): return unicode(self._id) @property def _time(self): v = self.time_counter self.time_counter = self.time_counter + timedelta(days=1) return v @property def _isbn(self): return self.isbns.pop() @property def _url(self): return "http://foo.com/" + self._str def _patron(self, external_identifier=None, library=None): external_identifier = external_identifier or self._str library = library or self._default_library return get_one_or_create(self._db, Patron, external_identifier=external_identifier, library=library)[0] def _contributor(self, sort_name=None, name=None, **kw_args): name = sort_name or name or self._str return get_one_or_create(self._db, Contributor, sort_name=unicode(name), **kw_args) def _identifier(self, identifier_type=Identifier.GUTENBERG_ID, foreign_id=None): if foreign_id: id = foreign_id else: id = self._str return Identifier.for_foreign_id(self._db, identifier_type, id)[0] def _edition(self, data_source_name=DataSource.GUTENBERG, identifier_type=Identifier.GUTENBERG_ID, with_license_pool=False, with_open_access_download=False, title=None, language="eng", authors=None, identifier_id=None, series=None, collection=None): id = identifier_id or self._str source = DataSource.lookup(self._db, data_source_name) wr = Edition.for_foreign_id(self._db, source, identifier_type, id)[0] if not title: title = self._str wr.title = unicode(title) wr.medium = Edition.BOOK_MEDIUM if series: wr.series = series if language: wr.language = language if authors is None: authors = self._str if isinstance(authors, basestring): authors = [authors] if authors != []: wr.add_contributor(unicode(authors[0]), Contributor.PRIMARY_AUTHOR_ROLE) wr.author = unicode(authors[0]) for author in authors[1:]: wr.add_contributor(unicode(author), Contributor.AUTHOR_ROLE) if with_license_pool or with_open_access_download: pool = self._licensepool( wr, data_source_name=data_source_name, with_open_access_download=with_open_access_download, collection=collection) pool.set_presentation_edition() return wr, pool return wr def _work(self, title=None, authors=None, genre=None, language=None, audience=None, fiction=True, with_license_pool=False, with_open_access_download=False, quality=0.5, series=None, presentation_edition=None, collection=None, data_source_name=None): """Create a Work. For performance reasons, this method does not generate OPDS entries or calculate a presentation edition for the new Work. Tests that rely on this information being present should call _slow_work() instead, which takes more care to present the sort of Work that would be created in a real environment. """ pools = [] if with_open_access_download: with_license_pool = True language = language or "eng" title = unicode(title or self._str) audience = audience or Classifier.AUDIENCE_ADULT if audience == Classifier.AUDIENCE_CHILDREN and not data_source_name: # TODO: This is necessary because Gutenberg's childrens books # get filtered out at the moment. 
data_source_name = DataSource.OVERDRIVE elif not data_source_name: data_source_name = DataSource.GUTENBERG if fiction is None: fiction = True new_edition = False if not presentation_edition: new_edition = True presentation_edition = self._edition( title=title, language=language, authors=authors, with_license_pool=with_license_pool, with_open_access_download=with_open_access_download, data_source_name=data_source_name, series=series, collection=collection, ) if with_license_pool: presentation_edition, pool = presentation_edition if with_open_access_download: pool.open_access = True pools = [pool] else: pools = presentation_edition.license_pools work, ignore = get_one_or_create(self._db, Work, create_method_kwargs=dict( audience=audience, fiction=fiction, quality=quality), id=self._id) if genre: if not isinstance(genre, Genre): genre, ignore = Genre.lookup(self._db, genre, autocreate=True) work.genres = [genre] work.random = 0.5 work.set_presentation_edition(presentation_edition) if pools: # make sure the pool's presentation_edition is set, # bc loan tests assume that. if not work.license_pools: for pool in pools: work.license_pools.append(pool) for pool in pools: pool.set_presentation_edition() # This is probably going to be used in an OPDS feed, so # fake that the work is presentation ready. work.presentation_ready = True work.calculate_opds_entries(verbose=False) return work def add_to_materialized_view(self, works, true_opds=False): """Make sure all the works in `works` show up in the materialized view. :param true_opds: Generate real OPDS entries for each each work, rather than faking it. """ if not isinstance(works, list): works = [works] for work in works: if true_opds: work.calculate_opds_entries(verbose=False) else: work.presentation_ready = True work.simple_opds_entry = "<entry>an entry</entry>" self._db.commit() SessionManager.refresh_materialized_views(self._db) def _lane(self, display_name=None, library=None, parent=None, genres=None, languages=None, fiction=None): display_name = display_name or self._str library = library or self._default_library lane, is_new = get_one_or_create( self._db, Lane, library=library, parent=parent, display_name=display_name, create_method_kwargs=dict(fiction=fiction)) if is_new and parent: lane.priority = len(parent.sublanes) - 1 if genres: if not isinstance(genres, list): genres = [genres] for genre in genres: if isinstance(genre, basestring): genre, ignore = Genre.lookup(self._db, genre) lane.genres.append(genre) if languages: if not isinstance(languages, list): languages = [languages] lane.languages = languages return lane def _slow_work(self, *args, **kwargs): """Create a work that closely resembles one that might be found in the wild. This is significantly slower than _work() but more reliable. 
""" work = self._work(*args, **kwargs) work.calculate_presentation_edition() work.calculate_opds_entries(verbose=False) return work def _add_generic_delivery_mechanism(self, license_pool): """Give a license pool a generic non-open-access delivery mechanism.""" data_source = license_pool.data_source identifier = license_pool.identifier content_type = Representation.EPUB_MEDIA_TYPE drm_scheme = DeliveryMechanism.NO_DRM LicensePoolDeliveryMechanism.set(data_source, identifier, content_type, drm_scheme, RightsStatus.IN_COPYRIGHT) def _coverage_record( self, edition, coverage_source, operation=None, status=CoverageRecord.SUCCESS, collection=None, exception=None, ): if isinstance(edition, Identifier): identifier = edition else: identifier = edition.primary_identifier record, ignore = get_one_or_create(self._db, CoverageRecord, identifier=identifier, data_source=coverage_source, operation=operation, collection=collection, create_method_kwargs=dict( timestamp=datetime.utcnow(), status=status, exception=exception, )) return record def _work_coverage_record(self, work, operation=None, status=CoverageRecord.SUCCESS): record, ignore = get_one_or_create(self._db, WorkCoverageRecord, work=work, operation=operation, create_method_kwargs=dict( timestamp=datetime.utcnow(), status=status, )) return record def _licensepool(self, edition, open_access=True, data_source_name=DataSource.GUTENBERG, with_open_access_download=False, set_edition_as_presentation=False, collection=None): source = DataSource.lookup(self._db, data_source_name) if not edition: edition = self._edition(data_source_name) collection = collection or self._default_collection pool, ignore = get_one_or_create( self._db, LicensePool, create_method_kwargs=dict(open_access=open_access), identifier=edition.primary_identifier, data_source=source, collection=collection, availability_time=datetime.utcnow()) if set_edition_as_presentation: pool.presentation_edition = edition if with_open_access_download: pool.open_access = True url = "http://foo.com/" + self._str media_type = MediaTypes.EPUB_MEDIA_TYPE link, new = pool.identifier.add_link( Hyperlink.OPEN_ACCESS_DOWNLOAD, url, source, media_type) # Add a DeliveryMechanism for this download pool.set_delivery_mechanism( media_type, DeliveryMechanism.NO_DRM, RightsStatus.GENERIC_OPEN_ACCESS, link.resource, ) representation, is_new = self._representation(url, media_type, "Dummy content", mirrored=True) link.resource.representation = representation else: # Add a DeliveryMechanism for this licensepool pool.set_delivery_mechanism(MediaTypes.EPUB_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM, RightsStatus.UNKNOWN, None) pool.licenses_owned = pool.licenses_available = 1 return pool def _representation(self, url=None, media_type=None, content=None, mirrored=False): url = url or "http://foo.com/" + self._str repr, is_new = get_one_or_create(self._db, Representation, url=url) repr.media_type = media_type if media_type and content: repr.content = content repr.fetched_at = datetime.utcnow() if mirrored: repr.mirror_url = "http://foo.com/" + self._str repr.mirrored_at = datetime.utcnow() return repr, is_new def _customlist(self, foreign_identifier=None, name=None, data_source_name=DataSource.NYT, num_entries=1, entries_exist_as_works=True): data_source = DataSource.lookup(self._db, data_source_name) foreign_identifier = foreign_identifier or self._str now = datetime.utcnow() customlist, ignore = get_one_or_create( self._db, CustomList, create_method_kwargs=dict( created=now, updated=now, name=name or self._str, 
description=self._str, ), data_source=data_source, foreign_identifier=foreign_identifier) editions = [] for i in range(num_entries): if entries_exist_as_works: work = self._work(with_open_access_download=True) edition = work.presentation_edition else: edition = self._edition(data_source_name, title="Item %s" % i) edition.permanent_work_id = "Permanent work ID %s" % self._str customlist.add_entry(edition, "Annotation %s" % i, first_appearance=now) editions.append(edition) return customlist, editions def _complaint(self, license_pool, type, source, detail, resolved=None): complaint, is_new = Complaint.register(license_pool, type, source, detail, resolved) return complaint def _credential(self, data_source_name=DataSource.GUTENBERG, type=None, patron=None): data_source = DataSource.lookup(self._db, data_source_name) type = type or self._str patron = patron or self._patron() credential, is_new = Credential.persistent_token_create( self._db, data_source, type, patron) return credential def _external_integration(self, protocol, goal=None, settings=None, libraries=None, **kwargs): integration = None if not libraries: integration, ignore = get_one_or_create(self._db, ExternalIntegration, protocol=protocol, goal=goal) else: if not isinstance(libraries, list): libraries = [libraries] # Try to find an existing integration for one of the given # libraries. for library in libraries: integration = ExternalIntegration.lookup(self._db, protocol, goal, library=libraries[0]) if integration: break if not integration: # Otherwise, create a brand new integration specifically # for the library. integration = ExternalIntegration( protocol=protocol, goal=goal, ) integration.libraries.extend(libraries) self._db.add(integration) for attr, value in kwargs.items(): setattr(integration, attr, value) settings = settings or dict() for key, value in settings.items(): integration.set_setting(key, value) return integration def _delegated_patron_identifier( self, library_uri=None, patron_identifier=None, identifier_type=DelegatedPatronIdentifier.ADOBE_ACCOUNT_ID, identifier=None): """Create a sample DelegatedPatronIdentifier""" library_uri = library_uri or self._url patron_identifier = patron_identifier or self._str if callable(identifier): make_id = identifier else: if not identifier: identifier = self._str def make_id(): return identifier patron, is_new = DelegatedPatronIdentifier.get_one_or_create( self._db, library_uri, patron_identifier, identifier_type, make_id) return patron def _sample_ecosystem(self): """ Creates an ecosystem of some sample work, pool, edition, and author objects that all know each other. 
""" # make some authors [bob], ignore = Contributor.lookup(self._db, u"Bitshifter, Bob") bob.family_name, bob.display_name = bob.default_names() [alice], ignore = Contributor.lookup(self._db, u"Adder, Alice") alice.family_name, alice.display_name = alice.default_names() edition_std_ebooks, pool_std_ebooks = self._edition( DataSource.STANDARD_EBOOKS, Identifier.URI, with_license_pool=True, with_open_access_download=True, authors=[]) edition_std_ebooks.title = u"The Standard Ebooks Title" edition_std_ebooks.subtitle = u"The Standard Ebooks Subtitle" edition_std_ebooks.add_contributor(alice, Contributor.AUTHOR_ROLE) edition_git, pool_git = self._edition(DataSource.PROJECT_GITENBERG, Identifier.GUTENBERG_ID, with_license_pool=True, with_open_access_download=True, authors=[]) edition_git.title = u"The GItenberg Title" edition_git.subtitle = u"The GItenberg Subtitle" edition_git.add_contributor(bob, Contributor.AUTHOR_ROLE) edition_git.add_contributor(alice, Contributor.AUTHOR_ROLE) edition_gut, pool_gut = self._edition(DataSource.GUTENBERG, Identifier.GUTENBERG_ID, with_license_pool=True, with_open_access_download=True, authors=[]) edition_gut.title = u"The GUtenberg Title" edition_gut.subtitle = u"The GUtenberg Subtitle" edition_gut.add_contributor(bob, Contributor.AUTHOR_ROLE) work = self._work(presentation_edition=edition_git) for p in pool_gut, pool_std_ebooks: work.license_pools.append(p) work.calculate_presentation() return (work, pool_std_ebooks, pool_git, pool_gut, edition_std_ebooks, edition_git, edition_gut, alice, bob) def print_database_instance(self): """ Calls the class method that examines the current state of the database model (whether it's been committed or not). NOTE: If you set_trace, and hit "continue", you'll start seeing console output right away, without waiting for the whole test to run and the standard output section to display. You can also use nosetest --nocapture. I use: def test_name(self): [code...] set_trace() self.print_database_instance() # TODO: remove before prod [code...] """ if not 'TESTING' in os.environ: # we are on production, abort, abort! logging.warn( "Forgot to remove call to testing.py:DatabaseTest.print_database_instance() before pushing to production." ) return DatabaseTest.print_database_class(self._db) return @classmethod def print_database_class(cls, db_connection): """ Prints to the console the entire contents of the database, as the unit test sees it. Exists because unit tests don't persist db information, they create a memory representation of the db state, and then roll the unit test-derived transactions back. So we cannot see what's going on by going into postgres and running selects. This is the in-test alternative to going into postgres. Can be called from model and metadata classes as well as tests. NOTE: The purpose of this method is for debugging. Be careful of leaving it in code and potentially outputting vast tracts of data into your output stream on production. Call like this: set_trace() from testing import ( DatabaseTest, ) _db = Session.object_session(self) DatabaseTest.print_database_class(_db) # TODO: remove before prod """ if not 'TESTING' in os.environ: # we are on production, abort, abort! logging.warn( "Forgot to remove call to testing.py:DatabaseTest.print_database_class() before pushing to production." 
) return works = db_connection.query(Work).all() identifiers = db_connection.query(Identifier).all() license_pools = db_connection.query(LicensePool).all() editions = db_connection.query(Edition).all() data_sources = db_connection.query(DataSource).all() representations = db_connection.query(Representation).all() if (not works): print "NO Work found" for wCount, work in enumerate(works): # pipe character at end of line helps see whitespace issues print "Work[%s]=%s|" % (wCount, work) if (not work.license_pools): print " NO Work.LicensePool found" for lpCount, license_pool in enumerate(work.license_pools): print " Work.LicensePool[%s]=%s|" % (lpCount, license_pool) print " Work.presentation_edition=%s|" % work.presentation_edition print "__________________________________________________________________\n" if (not identifiers): print "NO Identifier found" for iCount, identifier in enumerate(identifiers): print "Identifier[%s]=%s|" % (iCount, identifier) print " Identifier.licensed_through=%s|" % identifier.licensed_through print "__________________________________________________________________\n" if (not license_pools): print "NO LicensePool found" for index, license_pool in enumerate(license_pools): print "LicensePool[%s]=%s|" % (index, license_pool) print " LicensePool.work_id=%s|" % license_pool.work_id print " LicensePool.data_source_id=%s|" % license_pool.data_source_id print " LicensePool.identifier_id=%s|" % license_pool.identifier_id print " LicensePool.presentation_edition_id=%s|" % license_pool.presentation_edition_id print " LicensePool.superceded=%s|" % license_pool.superceded print " LicensePool.suppressed=%s|" % license_pool.suppressed print "__________________________________________________________________\n" if (not editions): print "NO Edition found" for index, edition in enumerate(editions): # pipe character at end of line helps see whitespace issues print "Edition[%s]=%s|" % (index, edition) print " Edition.primary_identifier_id=%s|" % edition.primary_identifier_id print " Edition.permanent_work_id=%s|" % edition.permanent_work_id if (edition.data_source): print " Edition.data_source.id=%s|" % edition.data_source.id print " Edition.data_source.name=%s|" % edition.data_source.name else: print " No Edition.data_source." if (edition.license_pool): print " Edition.license_pool.id=%s|" % edition.license_pool.id else: print " No Edition.license_pool." 
print " Edition.title=%s|" % edition.title print " Edition.author=%s|" % edition.author if (not edition.author_contributors): print " NO Edition.author_contributor found" for acCount, author_contributor in enumerate( edition.author_contributors): print " Edition.author_contributor[%s]=%s|" % ( acCount, author_contributor) print "__________________________________________________________________\n" if (not data_sources): print "NO DataSource found" for index, data_source in enumerate(data_sources): print "DataSource[%s]=%s|" % (index, data_source) print " DataSource.id=%s|" % data_source.id print " DataSource.name=%s|" % data_source.name print " DataSource.offers_licenses=%s|" % data_source.offers_licenses print " DataSource.editions=%s|" % data_source.editions print " DataSource.license_pools=%s|" % data_source.license_pools print " DataSource.links=%s|" % data_source.links print "__________________________________________________________________\n" if (not representations): print "NO Representation found" for index, representation in enumerate(representations): print "Representation[%s]=%s|" % (index, representation) print " Representation.id=%s|" % representation.id print " Representation.url=%s|" % representation.url print " Representation.mirror_url=%s|" % representation.mirror_url print " Representation.fetch_exception=%s|" % representation.fetch_exception print " Representation.mirror_exception=%s|" % representation.mirror_exception return def _library(self, name=None, short_name=None): name = name or self._str short_name = short_name or self._str library, ignore = get_one_or_create( self._db, Library, name=name, short_name=short_name, create_method_kwargs=dict(uuid=str(uuid.uuid4())), ) return library def _collection(self, name=None, protocol=ExternalIntegration.OPDS_IMPORT, external_account_id=None, url=None, username=None, password=None, data_source_name=None): name = name or self._str collection, ignore = get_one_or_create(self._db, Collection, name=name) collection.external_account_id = external_account_id integration = collection.create_external_integration(protocol) integration.goal = ExternalIntegration.LICENSE_GOAL integration.url = url integration.username = username integration.password = password if data_source_name: collection.data_source = data_source_name return collection @property def _default_library(self): """A Library that will only be created once throughout a given test. By default, the `_default_collection` will be associated with the default library. """ if not hasattr(self, '_default__library'): self._default__library = self.make_default_library(self._db) return self._default__library @property def _default_collection(self): """A Collection that will only be created once throughout a given test. For most tests there's no need to create a different Collection for every LicensePool. Using self._default_collection instead of calling self.collection() saves time. """ if not hasattr(self, '_default__collection'): self._default__collection = self._default_library.collections[0] return self._default__collection @classmethod def make_default_library(cls, _db): """Ensure that the default library exists in the given database. This can be called by code intended for use in testing but not actually within a DatabaseTest subclass. 
""" library, ignore = get_one_or_create(_db, Library, create_method_kwargs=dict( uuid=unicode(uuid.uuid4()), name="default", ), short_name="default") collection, ignore = get_one_or_create(_db, Collection, name="Default Collection") integration = collection.create_external_integration( ExternalIntegration.OPDS_IMPORT) integration.goal = ExternalIntegration.LICENSE_GOAL if collection not in library.collections: library.collections.append(collection) return library def _catalog(self, name=u"Faketown Public Library"): source, ignore = get_one_or_create(self._db, DataSource, name=name) def _integration_client(self, url=None, shared_secret=None): url = url or self._url secret = shared_secret or u"secret" return get_one_or_create(self._db, IntegrationClient, shared_secret=secret, create_method_kwargs=dict(url=url))[0] def _subject(self, type, identifier): return get_one_or_create(self._db, Subject, type=type, identifier=identifier)[0] def _classification(self, identifier, subject, data_source, weight=1): return get_one_or_create(self._db, Classification, identifier=identifier, subject=subject, data_source=data_source, weight=weight)[0] def sample_cover_path(self, name): """The path to the sample cover with the given filename.""" base_path = os.path.split(__file__)[0] resource_path = os.path.join(base_path, "tests", "files", "covers") sample_cover_path = os.path.join(resource_path, name) return sample_cover_path def sample_cover_representation(self, name): """A Representation of the sample cover with the given filename.""" sample_cover_path = self.sample_cover_path(name) return self._representation(media_type="image/png", content=open(sample_cover_path).read())[0]
def _fill_test_database(db: Session) -> NoReturn: """Create dummy users and channels to allow further testing in dev mode.""" from quetz.dao import Dao test_users = [] dao = Dao(db) try: for index, username in enumerate(['alice', 'bob', 'carol', 'dave']): user = dao.create_user_with_role(username) identity = Identity( provider='dummy', identity_id=str(index), username=username, ) profile = Profile(name=username.capitalize(), avatar_url='/avatar.jpg') user.identities.append(identity) # type: ignore user.profile = profile db.add(user) test_users.append(user) for channel_index in range(3): channel = Channel( name=f'channel{channel_index}', description=f'Description of channel{channel_index}', private=False, ) for package_index in range(random.randint(5, 10)): package = Package( name=f'package{package_index}', summary=f'package {package_index} summary text', description=f'Description of package{package_index}', ) channel.packages.append(package) # type: ignore test_user = test_users[random.randint(0, len(test_users) - 1)] package_member = PackageMember(package=package, channel=channel, user=test_user, role='owner') db.add(package_member) test_user = test_users[random.randint(0, len(test_users) - 1)] if channel_index == 0: package = Package(name='xtensor', description='Description of xtensor') channel.packages.append(package) # type: ignore package_member = PackageMember(package=package, channel=channel, user=test_user, role='owner') db.add(package_member) # create API key key = uuid.uuid4().hex key_user = User(id=uuid.uuid4().bytes) api_key = ApiKey(key=key, description='test API key', user=test_user, owner=test_user) db.add(api_key) print( f'Test API key created for user "{test_user.username}": {key}' ) key_package_member = PackageMember( user=key_user, channel_name=channel.name, package_name=package.name, role='maintainer', ) db.add(key_package_member) db.add(channel) channel_member = ChannelMember( channel=channel, user=test_user, role='owner', ) db.add(channel_member) db.commit() finally: db.close()
class DatabaseTest(object): engine = None connection = None @classmethod def get_database_connection(cls): url = Configuration.database_url() engine, connection = SessionManager.initialize(url) return engine, connection @classmethod def setup_class(cls): # Initialize a temporary data directory. cls.engine, cls.connection = cls.get_database_connection() cls.old_data_dir = Configuration.data_directory cls.tmp_data_dir = tempfile.mkdtemp(dir="/tmp") Configuration.instance[Configuration.DATA_DIRECTORY] = cls.tmp_data_dir # Avoid CannotLoadConfiguration errors related to CDN integrations. Configuration.instance[Configuration.INTEGRATIONS] = Configuration.instance.get( Configuration.INTEGRATIONS, {} ) Configuration.instance[Configuration.INTEGRATIONS][ExternalIntegration.CDN] = {} @classmethod def teardown_class(cls): # Destroy the database connection and engine. cls.connection.close() cls.engine.dispose() if cls.tmp_data_dir.startswith("/tmp"): logging.debug("Removing temporary directory %s" % cls.tmp_data_dir) shutil.rmtree(cls.tmp_data_dir) else: logging.warn("Cowardly refusing to remove 'temporary' directory %s" % cls.tmp_data_dir) Configuration.instance[Configuration.DATA_DIRECTORY] = cls.old_data_dir def setup(self, mock_search=True): # Create a new connection to the database. self._db = Session(self.connection) self.transaction = self.connection.begin_nested() # Start with a high number so it won't interfere with tests that search for an age or grade self.counter = 2000 self.time_counter = datetime(2014, 1, 1) self.isbns = [ "9780674368279", "0636920028468", "9781936460236", "9780316075978" ] if mock_search: self.search_mock = mock.patch(external_search.__name__ + ".ExternalSearchIndex", MockExternalSearchIndex) self.search_mock.start() else: self.search_mock = None # TODO: keeping this for now, but need to fix it bc it hits _isbn, # which pops an isbn off the list and messes tests up. so exclude # _ functions from participating. # also attempt to stop nosetest showing docstrings instead of function names. #for name, obj in inspect.getmembers(self): # if inspect.isfunction(obj) and obj.__name__.startswith('test_'): # obj.__doc__ = None def teardown(self): # Close the session. self._db.close() # Roll back all database changes that happened during this # test, whether in the session that was just closed or some # other session. self.transaction.rollback() # Remove any database objects cached in the model classes but # associated with the now-rolled-back session. Collection.reset_cache() ConfigurationSetting.reset_cache() DataSource.reset_cache() DeliveryMechanism.reset_cache() ExternalIntegration.reset_cache() Genre.reset_cache() Library.reset_cache() # Also roll back any record of those changes in the # Configuration instance. for key in [ Configuration.SITE_CONFIGURATION_LAST_UPDATE, Configuration.LAST_CHECKED_FOR_SITE_CONFIGURATION_UPDATE ]: if key in Configuration.instance: del(Configuration.instance[key]) if self.search_mock: self.search_mock.stop() def time_eq(self, a, b): "Assert that two times are *approximately* the same -- within 2 seconds." if a < b: delta = b-a else: delta = a-b total_seconds = delta.total_seconds() assert (total_seconds < 2), ("Delta was too large: %.2f seconds." % total_seconds) def shortDescription(self): return None # Stop nosetests displaying docstrings instead of class names when verbosity level >= 2. 
@property def _id(self): self.counter += 1 return self.counter @property def _str(self): return unicode(self._id) @property def _time(self): v = self.time_counter self.time_counter = self.time_counter + timedelta(days=1) return v @property def _isbn(self): return self.isbns.pop() @property def _url(self): return "http://foo.com/" + self._str def _patron(self, external_identifier=None, library=None): external_identifier = external_identifier or self._str library = library or self._default_library return get_one_or_create( self._db, Patron, external_identifier=external_identifier, library=library )[0] def _contributor(self, sort_name=None, name=None, **kw_args): name = sort_name or name or self._str return get_one_or_create(self._db, Contributor, sort_name=unicode(name), **kw_args) def _identifier(self, identifier_type=Identifier.GUTENBERG_ID, foreign_id=None): if foreign_id: id = foreign_id else: id = self._str return Identifier.for_foreign_id(self._db, identifier_type, id)[0] def _edition(self, data_source_name=DataSource.GUTENBERG, identifier_type=Identifier.GUTENBERG_ID, with_license_pool=False, with_open_access_download=False, title=None, language="eng", authors=None, identifier_id=None, series=None, collection=None, publicationDate=None ): id = identifier_id or self._str source = DataSource.lookup(self._db, data_source_name) wr = Edition.for_foreign_id( self._db, source, identifier_type, id)[0] if not title: title = self._str wr.title = unicode(title) wr.medium = Edition.BOOK_MEDIUM if series: wr.series = series if language: wr.language = language if authors is None: authors = self._str if isinstance(authors, basestring): authors = [authors] if authors != []: wr.add_contributor(unicode(authors[0]), Contributor.PRIMARY_AUTHOR_ROLE) wr.author = unicode(authors[0]) for author in authors[1:]: wr.add_contributor(unicode(author), Contributor.AUTHOR_ROLE) if publicationDate: wr.published = publicationDate if with_license_pool or with_open_access_download: pool = self._licensepool( wr, data_source_name=data_source_name, with_open_access_download=with_open_access_download, collection=collection ) pool.set_presentation_edition() return wr, pool return wr def _work(self, title=None, authors=None, genre=None, language=None, audience=None, fiction=True, with_license_pool=False, with_open_access_download=False, quality=0.5, series=None, presentation_edition=None, collection=None, data_source_name=None): """Create a Work. For performance reasons, this method does not generate OPDS entries or calculate a presentation edition for the new Work. Tests that rely on this information being present should call _slow_work() instead, which takes more care to present the sort of Work that would be created in a real environment. """ pools = [] if with_open_access_download: with_license_pool = True language = language or "eng" title = unicode(title or self._str) audience = audience or Classifier.AUDIENCE_ADULT if audience == Classifier.AUDIENCE_CHILDREN and not data_source_name: # TODO: This is necessary because Gutenberg's childrens books # get filtered out at the moment. 
data_source_name = DataSource.OVERDRIVE elif not data_source_name: data_source_name = DataSource.GUTENBERG if fiction is None: fiction = True new_edition = False if not presentation_edition: new_edition = True presentation_edition = self._edition( title=title, language=language, authors=authors, with_license_pool=with_license_pool, with_open_access_download=with_open_access_download, data_source_name=data_source_name, series=series, collection=collection, ) if with_license_pool: presentation_edition, pool = presentation_edition if with_open_access_download: pool.open_access = True pools = [pool] else: pools = presentation_edition.license_pools work, ignore = get_one_or_create( self._db, Work, create_method_kwargs=dict( audience=audience, fiction=fiction, quality=quality), id=self._id) if genre: if not isinstance(genre, Genre): genre, ignore = Genre.lookup(self._db, genre, autocreate=True) work.genres = [genre] work.random = 0.5 work.set_presentation_edition(presentation_edition) if pools: # make sure the pool's presentation_edition is set, # bc loan tests assume that. if not work.license_pools: for pool in pools: work.license_pools.append(pool) for pool in pools: pool.set_presentation_edition() # This is probably going to be used in an OPDS feed, so # fake that the work is presentation ready. work.presentation_ready = True work.calculate_opds_entries(verbose=False) return work def add_to_materialized_view(self, works, true_opds=False): """Make sure all the works in `works` show up in the materialized view. :param true_opds: Generate real OPDS entries for each each work, rather than faking it. """ if not isinstance(works, list): works = [works] for work in works: if true_opds: work.calculate_opds_entries(verbose=False) else: work.presentation_ready = True work.simple_opds_entry = "<entry>an entry</entry>" self._db.commit() SessionManager.refresh_materialized_views(self._db) def _lane(self, display_name=None, library=None, parent=None, genres=None, languages=None, fiction=None ): display_name = display_name or self._str library = library or self._default_library lane, is_new = create( self._db, Lane, library=library, parent=parent, display_name=display_name, fiction=fiction ) if is_new and parent: lane.priority = len(parent.sublanes)-1 if genres: if not isinstance(genres, list): genres = [genres] for genre in genres: if isinstance(genre, basestring): genre, ignore = Genre.lookup(self._db, genre) lane.genres.append(genre) if languages: if not isinstance(languages, list): languages = [languages] lane.languages = languages return lane def _slow_work(self, *args, **kwargs): """Create a work that closely resembles one that might be found in the wild. This is significantly slower than _work() but more reliable. 
""" work = self._work(*args, **kwargs) work.calculate_presentation_edition() work.calculate_opds_entries(verbose=False) return work def _add_generic_delivery_mechanism(self, license_pool): """Give a license pool a generic non-open-access delivery mechanism.""" data_source = license_pool.data_source identifier = license_pool.identifier content_type = Representation.EPUB_MEDIA_TYPE drm_scheme = DeliveryMechanism.NO_DRM LicensePoolDeliveryMechanism.set( data_source, identifier, content_type, drm_scheme, RightsStatus.IN_COPYRIGHT ) def _coverage_record(self, edition, coverage_source, operation=None, status=CoverageRecord.SUCCESS, collection=None, exception=None, ): if isinstance(edition, Identifier): identifier = edition else: identifier = edition.primary_identifier record, ignore = get_one_or_create( self._db, CoverageRecord, identifier=identifier, data_source=coverage_source, operation=operation, collection=collection, create_method_kwargs = dict( timestamp=datetime.utcnow(), status=status, exception=exception, ) ) return record def _work_coverage_record(self, work, operation=None, status=CoverageRecord.SUCCESS): record, ignore = get_one_or_create( self._db, WorkCoverageRecord, work=work, operation=operation, create_method_kwargs = dict( timestamp=datetime.utcnow(), status=status, ) ) return record def _licensepool(self, edition, open_access=True, data_source_name=DataSource.GUTENBERG, with_open_access_download=False, set_edition_as_presentation=False, collection=None): source = DataSource.lookup(self._db, data_source_name) if not edition: edition = self._edition(data_source_name) collection = collection or self._default_collection pool, ignore = get_one_or_create( self._db, LicensePool, create_method_kwargs=dict( open_access=open_access), identifier=edition.primary_identifier, data_source=source, collection=collection, availability_time=datetime.utcnow() ) if set_edition_as_presentation: pool.presentation_edition = edition if with_open_access_download: pool.open_access = True url = "http://foo.com/" + self._str media_type = MediaTypes.EPUB_MEDIA_TYPE link, new = pool.identifier.add_link( Hyperlink.OPEN_ACCESS_DOWNLOAD, url, source, media_type ) # Add a DeliveryMechanism for this download pool.set_delivery_mechanism( media_type, DeliveryMechanism.NO_DRM, RightsStatus.GENERIC_OPEN_ACCESS, link.resource, ) representation, is_new = self._representation( url, media_type, "Dummy content", mirrored=True) link.resource.representation = representation else: # Add a DeliveryMechanism for this licensepool pool.set_delivery_mechanism( MediaTypes.EPUB_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM, RightsStatus.UNKNOWN, None ) pool.licenses_owned = pool.licenses_available = 1 return pool def _license(self, pool, identifier=None, checkout_url=None, status_url=None, expires=None, remaining_checkouts=None, concurrent_checkouts=None): identifier = identifier or self._str checkout_url = checkout_url or self._str status_url = status_url or self._str license, ignore = get_one_or_create( self._db, License, identifier=identifier, license_pool=pool, checkout_url=checkout_url, status_url=status_url, expires=expires, remaining_checkouts=remaining_checkouts, concurrent_checkouts=concurrent_checkouts, ) return license def _representation(self, url=None, media_type=None, content=None, mirrored=False): url = url or "http://foo.com/" + self._str repr, is_new = get_one_or_create( self._db, Representation, url=url) repr.media_type = media_type if media_type and content: repr.content = content repr.fetched_at = datetime.utcnow() 
if mirrored: repr.mirror_url = "http://foo.com/" + self._str repr.mirrored_at = datetime.utcnow() return repr, is_new def _customlist(self, foreign_identifier=None, name=None, data_source_name=DataSource.NYT, num_entries=1, entries_exist_as_works=True ): data_source = DataSource.lookup(self._db, data_source_name) foreign_identifier = foreign_identifier or self._str now = datetime.utcnow() customlist, ignore = get_one_or_create( self._db, CustomList, create_method_kwargs=dict( created=now, updated=now, name=name or self._str, description=self._str, ), data_source=data_source, foreign_identifier=foreign_identifier ) editions = [] for i in range(num_entries): if entries_exist_as_works: work = self._work(with_open_access_download=True) edition = work.presentation_edition else: edition = self._edition( data_source_name, title="Item %s" % i) edition.permanent_work_id="Permanent work ID %s" % self._str customlist.add_entry( edition, "Annotation %s" % i, first_appearance=now) editions.append(edition) return customlist, editions def _complaint(self, license_pool, type, source, detail, resolved=None): complaint, is_new = Complaint.register( license_pool, type, source, detail, resolved ) return complaint def _credential(self, data_source_name=DataSource.GUTENBERG, type=None, patron=None): data_source = DataSource.lookup(self._db, data_source_name) type = type or self._str patron = patron or self._patron() credential, is_new = Credential.persistent_token_create( self._db, data_source, type, patron ) return credential def _external_integration(self, protocol, goal=None, settings=None, libraries=None, **kwargs ): integration = None if not libraries: integration, ignore = get_one_or_create( self._db, ExternalIntegration, protocol=protocol, goal=goal ) else: if not isinstance(libraries, list): libraries = [libraries] # Try to find an existing integration for one of the given # libraries. for library in libraries: integration = ExternalIntegration.lookup( self._db, protocol, goal, library=libraries[0] ) if integration: break if not integration: # Otherwise, create a brand new integration specifically # for the library. integration = ExternalIntegration( protocol=protocol, goal=goal, ) integration.libraries.extend(libraries) self._db.add(integration) for attr, value in kwargs.items(): setattr(integration, attr, value) settings = settings or dict() for key, value in settings.items(): integration.set_setting(key, value) return integration def _delegated_patron_identifier( self, library_uri=None, patron_identifier=None, identifier_type=DelegatedPatronIdentifier.ADOBE_ACCOUNT_ID, identifier=None ): """Create a sample DelegatedPatronIdentifier""" library_uri = library_uri or self._url patron_identifier = patron_identifier or self._str if callable(identifier): make_id = identifier else: if not identifier: identifier = self._str def make_id(): return identifier patron, is_new = DelegatedPatronIdentifier.get_one_or_create( self._db, library_uri, patron_identifier, identifier_type, make_id ) return patron def _sample_ecosystem(self): """ Creates an ecosystem of some sample work, pool, edition, and author objects that all know each other. 
""" # make some authors [bob], ignore = Contributor.lookup(self._db, u"Bitshifter, Bob") bob.family_name, bob.display_name = bob.default_names() [alice], ignore = Contributor.lookup(self._db, u"Adder, Alice") alice.family_name, alice.display_name = alice.default_names() edition_std_ebooks, pool_std_ebooks = self._edition(DataSource.STANDARD_EBOOKS, Identifier.URI, with_license_pool=True, with_open_access_download=True, authors=[]) edition_std_ebooks.title = u"The Standard Ebooks Title" edition_std_ebooks.subtitle = u"The Standard Ebooks Subtitle" edition_std_ebooks.add_contributor(alice, Contributor.AUTHOR_ROLE) edition_git, pool_git = self._edition(DataSource.PROJECT_GITENBERG, Identifier.GUTENBERG_ID, with_license_pool=True, with_open_access_download=True, authors=[]) edition_git.title = u"The GItenberg Title" edition_git.subtitle = u"The GItenberg Subtitle" edition_git.add_contributor(bob, Contributor.AUTHOR_ROLE) edition_git.add_contributor(alice, Contributor.AUTHOR_ROLE) edition_gut, pool_gut = self._edition(DataSource.GUTENBERG, Identifier.GUTENBERG_ID, with_license_pool=True, with_open_access_download=True, authors=[]) edition_gut.title = u"The GUtenberg Title" edition_gut.subtitle = u"The GUtenberg Subtitle" edition_gut.add_contributor(bob, Contributor.AUTHOR_ROLE) work = self._work(presentation_edition=edition_git) for p in pool_gut, pool_std_ebooks: work.license_pools.append(p) work.calculate_presentation() return (work, pool_std_ebooks, pool_git, pool_gut, edition_std_ebooks, edition_git, edition_gut, alice, bob) def print_database_instance(self): """ Calls the class method that examines the current state of the database model (whether it's been committed or not). NOTE: If you set_trace, and hit "continue", you'll start seeing console output right away, without waiting for the whole test to run and the standard output section to display. You can also use nosetest --nocapture. I use: def test_name(self): [code...] set_trace() self.print_database_instance() # TODO: remove before prod [code...] """ if not 'TESTING' in os.environ: # we are on production, abort, abort! logging.warn("Forgot to remove call to testing.py:DatabaseTest.print_database_instance() before pushing to production.") return DatabaseTest.print_database_class(self._db) return @classmethod def print_database_class(cls, db_connection): """ Prints to the console the entire contents of the database, as the unit test sees it. Exists because unit tests don't persist db information, they create a memory representation of the db state, and then roll the unit test-derived transactions back. So we cannot see what's going on by going into postgres and running selects. This is the in-test alternative to going into postgres. Can be called from model and metadata classes as well as tests. NOTE: The purpose of this method is for debugging. Be careful of leaving it in code and potentially outputting vast tracts of data into your output stream on production. Call like this: set_trace() from testing import ( DatabaseTest, ) _db = Session.object_session(self) DatabaseTest.print_database_class(_db) # TODO: remove before prod """ if not 'TESTING' in os.environ: # we are on production, abort, abort! 
logging.warn("Forgot to remove call to testing.py:DatabaseTest.print_database_class() before pushing to production.") return works = db_connection.query(Work).all() identifiers = db_connection.query(Identifier).all() license_pools = db_connection.query(LicensePool).all() editions = db_connection.query(Edition).all() data_sources = db_connection.query(DataSource).all() representations = db_connection.query(Representation).all() if (not works): print "NO Work found" for wCount, work in enumerate(works): # pipe character at end of line helps see whitespace issues print "Work[%s]=%s|" % (wCount, work) if (not work.license_pools): print " NO Work.LicensePool found" for lpCount, license_pool in enumerate(work.license_pools): print " Work.LicensePool[%s]=%s|" % (lpCount, license_pool) print " Work.presentation_edition=%s|" % work.presentation_edition print "__________________________________________________________________\n" if (not identifiers): print "NO Identifier found" for iCount, identifier in enumerate(identifiers): print "Identifier[%s]=%s|" % (iCount, identifier) print " Identifier.licensed_through=%s|" % identifier.licensed_through print "__________________________________________________________________\n" if (not license_pools): print "NO LicensePool found" for index, license_pool in enumerate(license_pools): print "LicensePool[%s]=%s|" % (index, license_pool) print " LicensePool.work_id=%s|" % license_pool.work_id print " LicensePool.data_source_id=%s|" % license_pool.data_source_id print " LicensePool.identifier_id=%s|" % license_pool.identifier_id print " LicensePool.presentation_edition_id=%s|" % license_pool.presentation_edition_id print " LicensePool.superceded=%s|" % license_pool.superceded print " LicensePool.suppressed=%s|" % license_pool.suppressed print "__________________________________________________________________\n" if (not editions): print "NO Edition found" for index, edition in enumerate(editions): # pipe character at end of line helps see whitespace issues print "Edition[%s]=%s|" % (index, edition) print " Edition.primary_identifier_id=%s|" % edition.primary_identifier_id print " Edition.permanent_work_id=%s|" % edition.permanent_work_id if (edition.data_source): print " Edition.data_source.id=%s|" % edition.data_source.id print " Edition.data_source.name=%s|" % edition.data_source.name else: print " No Edition.data_source." if (edition.license_pool): print " Edition.license_pool.id=%s|" % edition.license_pool.id else: print " No Edition.license_pool." 
print " Edition.title=%s|" % edition.title print " Edition.author=%s|" % edition.author if (not edition.author_contributors): print " NO Edition.author_contributor found" for acCount, author_contributor in enumerate(edition.author_contributors): print " Edition.author_contributor[%s]=%s|" % (acCount, author_contributor) print "__________________________________________________________________\n" if (not data_sources): print "NO DataSource found" for index, data_source in enumerate(data_sources): print "DataSource[%s]=%s|" % (index, data_source) print " DataSource.id=%s|" % data_source.id print " DataSource.name=%s|" % data_source.name print " DataSource.offers_licenses=%s|" % data_source.offers_licenses print " DataSource.editions=%s|" % data_source.editions print " DataSource.license_pools=%s|" % data_source.license_pools print " DataSource.links=%s|" % data_source.links print "__________________________________________________________________\n" if (not representations): print "NO Representation found" for index, representation in enumerate(representations): print "Representation[%s]=%s|" % (index, representation) print " Representation.id=%s|" % representation.id print " Representation.url=%s|" % representation.url print " Representation.mirror_url=%s|" % representation.mirror_url print " Representation.fetch_exception=%s|" % representation.fetch_exception print " Representation.mirror_exception=%s|" % representation.mirror_exception return def _library(self, name=None, short_name=None): name=name or self._str short_name = short_name or self._str library, ignore = get_one_or_create( self._db, Library, name=name, short_name=short_name, create_method_kwargs=dict(uuid=str(uuid.uuid4())), ) return library def _collection(self, name=None, protocol=ExternalIntegration.OPDS_IMPORT, external_account_id=None, url=None, username=None, password=None, data_source_name=None): name = name or self._str collection, ignore = get_one_or_create( self._db, Collection, name=name ) collection.external_account_id = external_account_id integration = collection.create_external_integration(protocol) integration.goal = ExternalIntegration.LICENSE_GOAL integration.url = url integration.username = username integration.password = password if data_source_name: collection.data_source = data_source_name return collection @property def _default_library(self): """A Library that will only be created once throughout a given test. By default, the `_default_collection` will be associated with the default library. """ if not hasattr(self, '_default__library'): self._default__library = self.make_default_library(self._db) return self._default__library @property def _default_collection(self): """A Collection that will only be created once throughout a given test. For most tests there's no need to create a different Collection for every LicensePool. Using self._default_collection instead of calling self.collection() saves time. """ if not hasattr(self, '_default__collection'): self._default__collection = self._default_library.collections[0] return self._default__collection @classmethod def make_default_library(cls, _db): """Ensure that the default library exists in the given database. This can be called by code intended for use in testing but not actually within a DatabaseTest subclass. 
""" library, ignore = get_one_or_create( _db, Library, create_method_kwargs=dict( uuid=unicode(uuid.uuid4()), name="default", ), short_name="default" ) collection, ignore = get_one_or_create( _db, Collection, name="Default Collection" ) integration = collection.create_external_integration( ExternalIntegration.OPDS_IMPORT ) integration.goal = ExternalIntegration.LICENSE_GOAL if collection not in library.collections: library.collections.append(collection) return library def _catalog(self, name=u"Faketown Public Library"): source, ignore = get_one_or_create(self._db, DataSource, name=name) def _integration_client(self, url=None, shared_secret=None): url = url or self._url secret = shared_secret or u"secret" return get_one_or_create( self._db, IntegrationClient, shared_secret=secret, create_method_kwargs=dict(url=url) )[0] def _subject(self, type, identifier): return get_one_or_create( self._db, Subject, type=type, identifier=identifier )[0] def _classification(self, identifier, subject, data_source, weight=1): return get_one_or_create( self._db, Classification, identifier=identifier, subject=subject, data_source=data_source, weight=weight )[0] def sample_cover_path(self, name): """The path to the sample cover with the given filename.""" base_path = os.path.split(__file__)[0] resource_path = os.path.join(base_path, "tests", "files", "covers") sample_cover_path = os.path.join(resource_path, name) return sample_cover_path def sample_cover_representation(self, name): """A Representation of the sample cover with the given filename.""" sample_cover_path = self.sample_cover_path(name) return self._representation( media_type="image/png", content=open(sample_cover_path).read())[0]
def get_table_metadata(database: Database, table_name: str, schema_name: Optional[str]) -> Dict: """ Get table metadata information, including type, pk, fks. This function raises SQLAlchemyError when a schema is not found. :param database: The database model :param table_name: Table name :param schema_name: schema name :return: Dict table metadata ready for API response """ keys: List = [] columns = database.get_columns(table_name, schema_name) # define comment dict by tsl comment_dict = {} primary_key = database.get_pk_constraint(table_name, schema_name) if primary_key and primary_key.get("constrained_columns"): primary_key["column_names"] = primary_key.pop("constrained_columns") primary_key["type"] = "pk" keys += [primary_key] # get dialect name dialect_name = database.get_dialect().name if isinstance(dialect_name, bytes): dialect_name = dialect_name.decode() # get column comment, presto & hive if dialect_name == "presto" or dialect_name == "hive": db_engine_spec = database.db_engine_spec sql = ParsedQuery("desc {a}.{b}".format(a=schema_name, b=table_name)).stripped() engine = database.get_sqla_engine(schema_name) conn = engine.raw_connection() cursor = conn.cursor() query = Query() session = Session(bind=engine) query.executed_sql = sql query.__tablename__ = table_name session.commit() db_engine_spec.execute(cursor, sql, async_=False) data = db_engine_spec.fetch_data(cursor, query.limit) # parse list data into dict by tsl; hive and presto is different if dialect_name == "presto": for d in data: d[3] comment_dict[d[0]] = d[3] else: for d in data: d[2] comment_dict[d[0]] = d[2] conn.commit() foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name) indexes = get_indexes_metadata(database, table_name, schema_name) keys += foreign_keys + indexes payload_columns: List[Dict] = [] for col in columns: dtype = get_col_type(col) if len(comment_dict) > 0: payload_columns.append({ "name": col["name"], "type": dtype.split("(")[0] if "(" in dtype else dtype, "longType": dtype, "keys": [k for k in keys if col["name"] in k.get("column_names")], "comment": comment_dict[col["name"]], }) elif dialect_name == "mysql": payload_columns.append({ "name": col["name"], "type": dtype.split("(")[0] if "(" in dtype else dtype, "longType": dtype, "keys": [k for k in keys if col["name"] in k.get("column_names")], "comment": col["comment"], }) else: payload_columns.append({ "name": col["name"], "type": dtype.split("(")[0] if "(" in dtype else dtype, "longType": dtype, "keys": [k for k in keys if col["name"] in k.get("column_names")], # "comment": col["comment"], }) return { "name": table_name, "columns": payload_columns, "selectStar": database.select_star( table_name, schema=schema_name, show_cols=True, indent=True, cols=columns, latest_partition=True, ), "primaryKey": primary_key, "foreignKeys": foreign_keys, "indexes": keys, }
def manage_slas(self, dag: DAG, session: Session = None) -> None: """ Finding all tasks that have SLAs defined, and sending alert emails where needed. New SLA misses are also recorded in the database. We are assuming that the scheduler runs often, so we only check for tasks that should have succeeded in the past hour. """ self.log.info("Running SLA Checks for %s", dag.dag_id) if not any(isinstance(ti.sla, timedelta) for ti in dag.tasks): self.log.info( "Skipping SLA check for %s because no tasks in DAG have SLAs", dag) return qry = (session.query( TI.task_id, func.max(TI.execution_date).label('max_ti')).with_hint( TI, 'USE INDEX (PRIMARY)', dialect_name='mysql').filter(TI.dag_id == dag.dag_id).filter( or_(TI.state == State.SUCCESS, TI.state == State.SKIPPED)).filter( TI.task_id.in_(dag.task_ids)).group_by( TI.task_id).subquery('sq')) max_tis: List[TI] = (session.query(TI).filter( TI.dag_id == dag.dag_id, TI.task_id == qry.c.task_id, TI.execution_date == qry.c.max_ti, ).all()) ts = timezone.utcnow() for ti in max_tis: task = dag.get_task(ti.task_id) if task.sla and not isinstance(task.sla, timedelta): raise TypeError( f"SLA is expected to be timedelta object, got " f"{type(task.sla)} in {task.dag_id}:{task.task_id}") dttm = dag.following_schedule(ti.execution_date) while dttm < timezone.utcnow(): following_schedule = dag.following_schedule(dttm) if following_schedule + task.sla < timezone.utcnow(): session.merge( SlaMiss(task_id=ti.task_id, dag_id=ti.dag_id, execution_date=dttm, timestamp=ts)) dttm = dag.following_schedule(dttm) session.commit() # pylint: disable=singleton-comparison slas: List[SlaMiss] = ( session.query(SlaMiss).filter(SlaMiss.notification_sent == False, SlaMiss.dag_id == dag.dag_id) # noqa .all()) # pylint: enable=singleton-comparison if slas: # pylint: disable=too-many-nested-blocks sla_dates: List[datetime.datetime] = [ sla.execution_date for sla in slas ] fetched_tis: List[TI] = (session.query(TI).filter( TI.state != State.SUCCESS, TI.execution_date.in_(sla_dates), TI.dag_id == dag.dag_id).all()) blocking_tis: List[TI] = [] for ti in fetched_tis: if ti.task_id in dag.task_ids: ti.task = dag.get_task(ti.task_id) blocking_tis.append(ti) else: session.delete(ti) session.commit() task_list = "\n".join(sla.task_id + ' on ' + sla.execution_date.isoformat() for sla in slas) blocking_task_list = "\n".join(ti.task_id + ' on ' + ti.execution_date.isoformat() for ti in blocking_tis) # Track whether email or any alert notification sent # We consider email or the alert callback as notifications email_sent = False notification_sent = False if dag.sla_miss_callback: # Execute the alert callback self.log.info('Calling SLA miss callback') try: dag.sla_miss_callback(dag, task_list, blocking_task_list, slas, blocking_tis) notification_sent = True except Exception: # pylint: disable=broad-except self.log.exception( "Could not call sla_miss_callback for DAG %s", dag.dag_id) email_content = f"""\ Here's a list of tasks that missed their SLAs: <pre><code>{task_list}\n<code></pre> Blocking tasks: <pre><code>{blocking_task_list}<code></pre> Airflow Webserver URL: {conf.get(section='webserver', key='base_url')} """ tasks_missed_sla = [] for sla in slas: try: task = dag.get_task(sla.task_id) except TaskNotFound: # task already deleted from DAG, skip it self.log.warning( "Task %s doesn't exist in DAG anymore, skipping SLA miss notification.", sla.task_id) continue tasks_missed_sla.append(task) emails: Set[str] = set() for task in tasks_missed_sla: if task.email: if isinstance(task.email, str): 
emails |= set(get_email_address_list(task.email)) elif isinstance(task.email, (list, tuple)): emails |= set(task.email) if emails: try: send_email(emails, f"[airflow] SLA miss on DAG={dag.dag_id}", email_content) email_sent = True notification_sent = True except Exception: # pylint: disable=broad-except Stats.incr('sla_email_notification_failure') self.log.exception( "Could not send SLA Miss email notification for DAG %s", dag.dag_id) # If we sent any notification, update the sla_miss table if notification_sent: for sla in slas: sla.email_sent = email_sent sla.notification_sent = True session.merge(sla) session.commit()
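# manage_slas above invokes dag.sla_miss_callback(dag, task_list,
# blocking_task_list, slas, blocking_tis) and compares each task's sla, a
# timedelta, against the schedule. A sketch of a DAG wired up to feed it,
# assuming a standard Airflow 2 install; the DAG id, task id and callback body
# are made up for illustration.

from datetime import datetime, timedelta

from airflow import DAG
from airflow.operators.bash import BashOperator


def notify_sla_miss(dag, task_list, blocking_task_list, slas, blocking_tis):
    # Matches the positional signature used by manage_slas above.
    print(f"SLA missed in {dag.dag_id}:\n{task_list}")


with DAG(
    dag_id="sla_example",
    start_date=datetime(2021, 1, 1),
    schedule_interval="@hourly",
    sla_miss_callback=notify_sla_miss,
) as dag:
    # An SLA is a timedelta relative to the scheduled run, not a wall-clock time.
    BashOperator(task_id="quick_task", bash_command="sleep 1", sla=timedelta(minutes=5))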
def update_replies(remote_replies: List[SDKReply], local_replies: List[Reply], session: Session, data_dir: str) -> None: """ * Existing replies are updated in the local database. * New replies have an entry created in the local database. * Local replies not returned in the remote replies are deleted from the local database unless they are pending or failed. If a reply references a new journalist username, add them to the database as a new user. """ local_uuids = {reply.uuid for reply in local_replies} for reply in remote_replies: if reply.uuid in local_uuids: local_reply = [r for r in local_replies if r.uuid == reply.uuid][0] user = find_or_create_user(reply.journalist_uuid, reply.journalist_username, session) local_reply.journalist_id = user.id local_reply.size = reply.size local_uuids.remove(reply.uuid) logger.debug('Updated reply {}'.format(reply.uuid)) else: # A new reply to be added to the database. source_uuid = reply.source_uuid source = session.query(Source).filter_by(uuid=source_uuid)[0] user = find_or_create_user( reply.journalist_uuid, reply.journalist_username, session) nr = Reply(uuid=reply.uuid, journalist_id=user.id, source_id=source.id, filename=reply.filename, size=reply.size) session.add(nr) # All replies fetched from the server have succeeded in being sent, # so we should delete the corresponding draft locally if it exists. try: draft_reply_db_object = session.query(DraftReply).filter_by( uuid=reply.uuid).one() update_draft_replies(session, draft_reply_db_object.source.id, draft_reply_db_object.timestamp, draft_reply_db_object.file_counter, nr.file_counter) session.delete(draft_reply_db_object) except NoResultFound: pass # No draft locally stored corresponding to this reply. logger.debug('Added new reply {}'.format(reply.uuid)) # The uuids remaining in local_uuids do not exist on the remote server, so # delete the related records. replies_to_delete = [r for r in local_replies if r.uuid in local_uuids] for deleted_reply in replies_to_delete: delete_single_submission_or_reply_on_disk(deleted_reply, data_dir) session.delete(deleted_reply) logger.debug('Deleted reply {}'.format(deleted_reply.uuid)) session.commit()
class testTradingCenter(unittest.TestCase): def setUp(self): self.trans = connection.begin() self.session = Session(connection) currency = Currency(name='Pesos', code='ARG') self.exchange = Exchange(name='Merval', code='MERV', currency=currency) self.owner = Owner(name='poor owner') self.broker = Broker(name='broker1') self.account = Account(owner=self.owner, broker=self.broker) self.account.deposit(Money(amount=10000, currency=currency)) def tearDown(self): self.trans.rollback() self.session.close() def test_open_orders_by_order_id(self): stock=Stock(symbol='symbol', description='a stock', ISIN='US123456789', exchange=self.exchange) order1=BuyOrder(account=self.account, security=stock, price=13.2, share=10) order2=BuyOrder(account=self.account, security=stock, price=13.25, share=10) self.session.add(order1) self.session.add(order2) self.session.commit() tc=TradingCenter(self.session) order=tc.open_order_by_id(order1.id) self.assertEquals(order1, order) order=tc.open_order_by_id(100) self.assertEquals(None, order) def testGetOpenOrdersBySymbol(self): stock=Stock(symbol='symbol', description='a stock', ISIN='US123456789', exchange=self.exchange) order1=BuyOrder(account=self.account, security=stock, price=13.2, share=10) order2=BuyOrder(account=self.account, security=stock, price=13.25, share=10) self.session.add(order1) self.session.add(order2) self.session.commit() tc=TradingCenter(self.session) orders=tc.open_orders_by_symbol('symbol') self.assertEquals([order1, order2], list(orders)) def testCancelOrder(self): stock=Stock(symbol='symbol', description='a stock', ISIN='US123456789', exchange=self.exchange) order1=BuyOrder(account=self.account, security=stock, price=13.2, share=10) order2=BuyOrder(account=self.account, security=stock, price=13.25, share=10) self.session.add(order1) self.session.add(order2) self.session.commit() tc=TradingCenter(self.session) order1.cancel() self.assertEquals([order2], tc.open_orders) self.assertEquals([order1], tc.cancel_orders) self.assertEquals(CancelOrderStage, type(order1.current_stage)) order2.cancel() self.assertEquals([], tc.open_orders) self.assertEquals([order1, order2], tc.cancel_orders) def testCancelAllOpenOrders(self): security=Stock(symbol='symbol', description='a stock', ISIN='US123456789', exchange=self.exchange) order1=BuyOrder(account=self.account, security=security, price=13.2, share=10) order2=BuyOrder(account=self.account, security=security, price=13.25, share=10) self.session.add(order1) self.session.add(order2) self.session.commit() tc=TradingCenter(self.session) tc.cancel_all_open_orders() self.assertEquals([], tc.open_orders) def testConsume(self): pass def testPostConsume(self): pass def testCreateAccountWithMetrix(self): pass
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base

Session = sessionmaker(query_cls=CachingQuery)
Base = declarative_base(engine=create_engine('sqlite://', echo=True))


class User(Base):
    __tablename__ = 'users'

    id = Column(Integer, primary_key=True)
    name = Column(String(100))

    def __repr__(self):
        return "User(name=%r)" % self.name


Base.metadata.create_all()

sess = Session()
sess.add_all(
    [User(name='u1'), User(name='u2'), User(name='u3')]
)
sess.commit()

# cache two user objects
sess.query(User).with_cache_key('u2andu3').filter(User.name.in_(['u2', 'u3'])).all()

# pull straight from cache
print sess.query(User).with_cache_key('u2andu3').all()
#!/usr/bin/python3
""" update state with id = 2 from the database hbtn_0e_6_usa"""
from sqlalchemy import create_engine
from sqlalchemy.orm.session import Session
from model_state import Base, State
import sys


if __name__ == '__main__':
    engine = create_engine('mysql+mysqldb://{}:{}@localhost/{}'.format(
        sys.argv[1], sys.argv[2], sys.argv[3]), pool_pre_ping=True)
    Base.metadata.create_all(engine)
    session = Session(engine)
    query_row = session.query(State).filter(State.id == 2).\
        update({"name": "New Mexico"}, synchronize_session="fetch")
    session.commit()
    session.close()
def handle_noargs(self, **options): verbosity = 1 #int(options.get('verbosity')) interactive = options.get('interactive') show_traceback = options.get('traceback') self.style = no_style() # Import the 'management' module within each installed app, to register # dispatcher events. for app_name in settings.INSTALLED_APPS: try: import_module('.management', app_name) except ImportError as exc: # This is slightly hackish. We want to ignore ImportErrors # if the "management" module itself is missing -- but we don't # want to ignore the exception if the management module exists # but raises an ImportError for some reason. The only way we # can do this is to check the text of the exception. Note that # we're a bit broad in how we check the text, because different # Python implementations may not use the same text. # CPython uses the text "No module named management" # PyPy uses "No module named myproject.myapp.management" msg = exc.args[0] if not msg.startswith('No module named') or 'management' not in msg: raise db = options.get('database') orm = ORM.get(db) db_info = orm.settings_dict is_test_db = db_info.get('TEST', False) if not is_test_db: print 'Database "%s" cannot be purged because it is not a test ' \ 'database.\nTo flag this as a test database, set TEST to ' \ 'True in the database settings.' % db sys.exit() if interactive: confirm = raw_input('\nYou have requested a purge of database ' \ '"%s" (%s). This will IRREVERSIBLY DESTROY all data ' \ 'currently in the database, and DELETE ALL TABLES AND ' \ 'SCHEMAS. Are you sure you want to do this?\n\n' \ 'Type "yes" to continue, or "no" to cancel: ' \ % (db, orm.engine.url)) else: confirm = 'yes' if confirm == 'yes': # get a list of all schemas used by the app default_schema = orm.engine.url.database app_schemas = set(orm.Base.metadata._schemas) app_schemas.add(default_schema) url = deepcopy(orm.engine.url) url.database = None engine = create_engine(url) inspector = inspect(engine) # get a list of existing schemas db_schemas = set(inspector.get_schema_names()) schemas = app_schemas.intersection(db_schemas) app_tables = set() for table in orm.Base.metadata.tables.values(): schema = table.schema or default_schema app_tables.add('%s.%s' % (schema, table.name)) metadata = MetaData() db_tables = [] all_fks = [] for schema in schemas: for table_name in inspector.get_table_names(schema): fullname = '%s.%s' % (schema, table_name) if fullname not in app_tables: continue fks = [] for fk in inspector.get_foreign_keys(table_name, schema=schema): if not fk['name']: continue fks.append(ForeignKeyConstraint((),(),name=fk['name'])) t = Table(table_name, metadata, *fks, schema=schema) db_tables.append(t) all_fks.extend(fks) session = Session(bind=engine) for fkc in all_fks: session.execute(DropConstraint(fkc)) for table in db_tables: session.execute(DropTable(table)) for schema in schemas: session.execute(DropSchema(schema)) session.commit() session.bind.dispose() else: self.stdout.write("Purge cancelled.\n")
def remove(db_session: Session, shopping_list: models.List):
    db_session.delete(shopping_list)
    db_session.commit()
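# A small sketch of a caller for remove(). The remove_by_id helper and the
# integer primary key are assumptions made for illustration, not part of the
# source.

def remove_by_id(db_session: Session, list_id: int) -> bool:
    # Look the list up first so a missing id never reaches session.delete()
    # as None.
    shopping_list = db_session.query(models.List).get(list_id)
    if shopping_list is None:
        return False
    remove(db_session, shopping_list)
    return True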
class InboxSession(object): """ Inbox custom ORM (with SQLAlchemy compatible API). Parameters ---------- engine : <sqlalchemy.engine.Engine> A configured database engine to use for this session versioned : bool Do you want to enable the transaction log? ignore_soft_deletes : bool Whether or not to ignore soft-deleted objects in query results. namespace_id : int Namespace to limit query results with. """ def __init__(self, engine, versioned=True, ignore_soft_deletes=True, namespace_id=None): # TODO: support limiting on namespaces assert engine, "Must set the database engine" args = dict(bind=engine, autoflush=True, autocommit=False) self.ignore_soft_deletes = ignore_soft_deletes if ignore_soft_deletes: args['query_cls'] = InboxQuery self._session = Session(**args) if versioned: from inbox.models.transaction import create_revisions @event.listens_for(self._session, 'after_flush') def after_flush(session, flush_context): """ Hook to log revision snapshots. Must be post-flush in order to grab object IDs on new objects. """ create_revisions(session) def query(self, *args, **kwargs): q = self._session.query(*args, **kwargs) if self.ignore_soft_deletes: return q.options(IgnoreSoftDeletesOption()) else: return q def add(self, instance): if not self.ignore_soft_deletes or not instance.is_deleted: self._session.add(instance) else: raise Exception("Why are you adding a deleted object?") def add_all(self, instances): if True not in [i.is_deleted for i in instances] or \ not self.ignore_soft_deletes: self._session.add_all(instances) else: raise Exception("Why are you adding a deleted object?") def delete(self, instance): if self.ignore_soft_deletes: instance.mark_deleted() # just to make sure self._session.add(instance) else: self._session.delete(instance) def begin(self): self._session.begin() def commit(self): self._session.commit() def rollback(self): self._session.rollback() def flush(self): self._session.flush() def close(self): self._session.close() def expunge(self, obj): self._session.expunge(obj) def merge(self, obj): return self._session.merge(obj) @property def no_autoflush(self): return self._session.no_autoflush
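# A minimal usage sketch for the wrapper above. The SQLite URL is a
# placeholder, Message stands in for one of the project's mapped models, and
# versioned=False simply keeps the sketch from needing the transaction-log
# hook.

from sqlalchemy import create_engine

engine = create_engine("sqlite://")          # placeholder engine
db_session = InboxSession(engine, versioned=False)
try:
    # query() appends IgnoreSoftDeletesOption, so soft-deleted rows never
    # come back unless ignore_soft_deletes=False was passed.
    messages = db_session.query(Message).filter(Message.namespace_id == 1).all()
    for message in messages:
        print(message.id)
finally:
    db_session.close()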
def process_file( self, file_path: str, callback_requests: List[CallbackRequest], pickle_dags: bool = False, session: Session = NEW_SESSION, ) -> Tuple[int, int]: """ Process a Python file containing Airflow DAGs. This includes: 1. Execute the file and look for DAG objects in the namespace. 2. Execute any Callbacks if passed to this method. 3. Serialize the DAGs and save it to DB (or update existing record in the DB). 4. Pickle the DAG and save it to the DB (if necessary). 5. Mark any DAGs which are no longer present as inactive 6. Record any errors importing the file into ORM :param file_path: the path to the Python file that should be executed :param callback_requests: failure callback to execute :param pickle_dags: whether serialize the DAGs found in the file and save them to the db :param session: Sqlalchemy ORM Session :return: number of dags found, count of import errors :rtype: Tuple[int, int] """ self.log.info("Processing file %s for tasks to queue", file_path) try: dagbag = DagBag(file_path, include_examples=False) except Exception: self.log.exception("Failed at reloading the DAG file %s", file_path) Stats.incr('dag_file_refresh_error', 1, 1) return 0, 0 if len(dagbag.dags) > 0: self.log.info("DAG(s) %s retrieved from %s", dagbag.dags.keys(), file_path) else: self.log.warning("No viable dags retrieved from %s", file_path) self.update_import_errors(session, dagbag) if callback_requests: # If there were callback requests for this file but there was a # parse error we still need to progress the state of TIs, # otherwise they might be stuck in queued/running for ever! self.execute_callbacks_without_dag(callback_requests, session) return 0, len(dagbag.import_errors) self.execute_callbacks(dagbag, callback_requests, session) session.commit() # Save individual DAGs in the ORM dagbag.sync_to_db(session) session.commit() if pickle_dags: paused_dag_ids = DagModel.get_paused_dag_ids( dag_ids=dagbag.dag_ids) unpaused_dags: List[DAG] = [ dag for dag_id, dag in dagbag.dags.items() if dag_id not in paused_dag_ids ] for dag in unpaused_dags: dag.pickle(session) # Record import errors into the ORM try: self.update_import_errors(session, dagbag) except Exception: self.log.exception("Error logging import errors!") # Record DAG warnings in the metadatabase. try: self.update_dag_warnings(session=session, dagbag=dagbag) except Exception: self.log.exception("Error logging DAG warnings.") return len(dagbag.dags), len(dagbag.import_errors)
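# A sketch of driving process_file from outside. `processor` is assumed to be
# an instance of the class that owns this method (Airflow's DagFileProcessor),
# and leaving `session` at its NEW_SESSION default assumes the usual
# provide_session wiring fills it in; the file path is illustrative.

dag_files = ["/opt/airflow/dags/example_dag.py"]

for path in dag_files:
    found, import_errors = processor.process_file(
        file_path=path,
        callback_requests=[],   # nothing to replay
        pickle_dags=False,      # skip step 4 from the docstring
    )
    print(f"{path}: {found} dag(s), {import_errors} import error(s)")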
def delete(cls, id_):
    obj = cls.query().get(id_)
    Session.delete(obj)
    Session.commit()
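# Calling Session.delete(...) and Session.commit() at class level, as the
# method above does, only works when Session is a scoped_session registry (or
# a similar proxy). A sketch of the setup this pattern assumes; the engine URL
# is a placeholder.

from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker

engine = create_engine("sqlite://")   # placeholder engine

# scoped_session proxies Session methods on the registry object itself, which
# is what allows Session.delete(obj) / Session.commit() without explicitly
# instantiating a session.
Session = scoped_session(sessionmaker(bind=engine))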
# 2008-05-07
s = sqlalchemy.sql.select([Article.c.body, Article.c.headline],
                          Article.c.headline == 'Python is cool!')
# connection = eng.connect()
result = connection.execute(s)
print "Fetched from Article table before commit"
for row in result:
    print row

print dir(Article)

# Roll back all table writes unless the user confirms the commit.
yes_or_no = raw_input("Commit Database Transaction?(y/n):")
yes_or_no = yes_or_no.lower()
# Default is rollback, so there is no need to handle 'n' or 'no' explicitly.
if yes_or_no == 'y' or yes_or_no == 'yes':
    session.commit()
    # transaction.commit()  # would also run session.flush() if it has not run yet
else:
    session.rollback()
    # transaction.rollback()

aa = articles.alias()
s = sqlalchemy.sql.select([aa.c.body, aa.c.headline, association.c.article_id],
                          aa.c.article_id == association.c.article_id,
                          order_by=[association.c.article_id])
# connection = eng.connect()
result = connection.execute(s).fetchmany(3)
print "Fetched from Article table"
for row in result:
    print row.headline
    print row['headline']
    print row
class AlchemyFeatureDatabase(FeatureDatabase): """ Uses SQLAlchemy to provide a a FeatureDatabase backed by an SQL database. IMPORTANT: all insertions must be followed by a finalize() call """ def __init__(self, engine): """ Initializes this feature database using an engine """ if type(engine) == types.StringType: new_engine = create_engine(engine, encoding='utf-8') if "sqlite:" in engine: new_engine.raw_connection().connection.text_factory = str engine = new_engine self._engine = engine self._session = Session(bind=self._engine) self._cache = None self._cache_dirty = False self.initialize() def initialize(self): Base.metadata.create_all(self._engine) def finalize(self): self._session.commit() def _get_session(self): return self._session def add_feature(self, feature): session = self._get_session() feature = Feature(feature) session.add(feature) return True def get_feature(self, qry): session = self._get_session() for obj in session.query(Feature).filter_by(feature=qry): return obj def feature_exists(self, feature): qry = self.get_feature(feature) return qry != None def add_source(self, source): session = self._get_session() source = Source(source) session.add(source) return True def get_source(self, qry): session = self._get_session() for obj in session.query(Source).filter_by(source=qry): return obj def _gen_cache(self): self._cache = {} cache = self._cache session = self._get_session() for obj in session.query(Example)\ .options(joinedload('feature'))\ .options(joinedload('source')): if obj.source.source not in cache: cache[obj.source.source] = {} sub_cache = cache[obj.source.source] feature = obj.feature.feature if type(feature) is not types.StringType and type(feature) is not types.UnicodeType: sub_cache_key = str(obj.feature.feature) else: sub_cache_key = feature if sub_cache_key not in sub_cache: sub_cache[sub_cache_key] = [(obj.label, obj.extra)] else: sub_cache[sub_cache_key].append((obj.label, obj.extra)) self._cache_dirty = False def add_feature_example(self, feature, label, source=None, extra = None): super(AlchemyFeatureDatabase, self).add_feature_example(feature, label, source, extra) if source is None: source = "default" if label not in [-1, 0, 1]: raise ValueError("label: must be -1, 0 or 1") source_obj = self.get_source(source) feature_obj = self.get_feature(feature) if source_obj is None: self.add_source(source) return self.add_feature_example(feature, label, source, extra) if feature_obj is None: self.add_feature(feature) return self.add_feature_example(feature, label, source, extra) session = self._get_session() example_obj = Example(feature_obj, label, source_obj, extra) session.add(example_obj) self._cache_dirty = True return True def get_feature_examples(self, feature, sources=None): if self._cache is None or self._cache_dirty: self._gen_cache() if type(feature) is not types.StringType: if type(feature) is not types.UnicodeType: feature = str(feature) for source in self._cache: if sources is not None: if source not in sources: continue sc = self._cache[source] if feature not in sc: continue for pair in sc[feature]: yield pair def get_all_features(self): if self._cache is None or self._cache_dirty: self._gen_cache() for source in self._cache: sc = self._cache[source] for feature in sc: for label, extra in sc[feature]: yield feature, label, extra
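# Since the docstring insists that insertions be followed by finalize(), a
# usage sketch might look like the following. It is written in Python 3 syntax
# even though the class above is Python 2-era code, and the feature names,
# labels and source are invented for illustration.

# In-memory SQLite keeps the sketch self-contained.
db = AlchemyFeatureDatabase("sqlite://")

# Examples are buffered in the session until finalize() commits them.
db.add_feature_example("contains_url", 1, source="twitter")
db.add_feature_example("all_caps", -1, source="twitter")
db.finalize()

# get_feature_examples() yields (label, extra) pairs from the cache.
for label, extra in db.get_feature_examples("contains_url"):
    print(label, extra)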
def populate_for_human_evaluation(session: Session, method_to_result: Dict[str, Path]) -> None: if session.query(HumanEvaluation).first() is not None: return if session.query(GenerationResult).first() is not None: return method_names = list(method_to_result.keys()) + ['Gold'] token_delim = '|' with method_to_result['Base'].open(mode='r') as f: N = sum(1 for _ in f) - 1 f.seek(0) reader = csv.reader(f, delimiter=',', quotechar='"', quoting=csv.QUOTE_ALL) next(reader) for _ in tqdm(range(N)): fields = next(reader) article_id = fields[0] base_result = GenerationResult( article_id, 'Base', ''.join(number2kansuuzi(fields[4].split(token_delim)[:-1]))) session.add(base_result) gold_result = GenerationResult(article_id, 'Gold', None) session.add(gold_result) session.commit() for method_name in [m for m in method_names if m not in ['Base', 'Gold']]: if not method_to_result[method_name].is_file(): continue with method_to_result[method_name].open(mode='r') as f: N = sum(1 for _ in f) - 1 f.seek(0) reader = csv.reader(f, delimiter=',', quotechar='"', quoting=csv.QUOTE_ALL) next(reader) for _ in tqdm(range(N)): fields = next(reader) article_id = fields[0] result = GenerationResult( article_id, method_name, ''.join( number2kansuuzi(fields[4].split(token_delim)[:-1]))) session.add(result) session.commit() groups = session \ .query(GenerationResult.article_id) \ .filter(GenerationResult.result != '') \ .group_by(GenerationResult.article_id) \ .all() orderings = list(permutations([m for m in method_names])) for group in groups: i = random.randint(0, math.factorial(len(method_names)) - 1) h = HumanEvaluation(group.article_id, ordering=orderings[i]) session.add(h) session.commit() results = session.query(HumanEvaluation).all() SAMPLE_SIZE = 100 n_results = len(results) if n_results >= SAMPLE_SIZE: sample_indices = random.sample(range(n_results), SAMPLE_SIZE) for i, result in enumerate(results): result.is_target = i in sample_indices session.commit()
class Catalog(object):
    def __init__(self, engine):
        self.engine = engine
        self.session = Session(bind=self.engine)

    @subtransaction
    def add_item(self, code, description, long_description=None):
        citem = models.Item(code=code,
                            description=description,
                            long_description=long_description)
        self.session.add(citem)
        return citem

    @subtransaction
    def add_items(self, items):
        for item in items:
            # KeyErrors not caught if missing required field
            try:
                long_desc = item['long description']
            except KeyError:
                long_desc = None
            citem = models.Item(code=item['code'],
                                description=item['description'],
                                description_lower=item['description'].lower(),
                                long_description=long_desc)
            primary = True
            for cat in item['categories']:
                q = self.session.query(models.Category).\
                    filter(models.Category.name == cat)
                cater = q.first()
                icater = models.ItemCategory(item=citem,
                                             category=cater,
                                             primary=primary)
                citem.categories.append(icater)
                primary = False
            self.session.add(citem)

    def add_category(self, name):
        self.add_categories([name])

    @subtransaction
    def add_categories(self, names):
        for name in names:
            cater = models.Category(name=name)
            self.session.add(cater)

    # def remove_category: should make sure that zero use
    # def rename_category

    @subtransaction
    def add_item_category(self, item_id, category_id, primary=False):
        catitem = models.ItemCategory(item_id=item_id,
                                      category_id=category_id,
                                      primary=primary)
        self.session.add(catitem)
        return catitem

    # def remove_item_category

    def get_item(self, code, as_object=False):
        q = self.session.query(models.Item).filter(models.Item.code == code)
        item = q.first()
        if item is None:
            return None
        if as_object is True:
            return item
        return (item.id, item.code, item.description, item.long_description)

    def remove_item(self, code):
        self.session.begin(subtransactions=True)
        q = self.session.query(models.Item).\
            filter(models.Item.code == code).delete()
        self.session.commit()

    @subtransaction
    def update_item(self, code, description=None, long_description=None,
                    new_code=None):
        q = self.session.query(models.Item).filter(models.Item.code == code)
        item = q.first()
        if new_code is not None:
            item.code = new_code
        if description is not None:
            item.description = description
            item.description_lower = description.lower()
        if long_description is not None:
            item.long_description = long_description

    def get_stock(self, code, as_object=False):
        q = self.session.query(models.Item, models.StockItem).\
            with_entities(models.StockItem).\
            filter(models.Item.code == code).\
            filter(models.Item.id == models.StockItem.item_id)
        res = q.first()
        if res is None:
            return None
        if as_object is True:
            return res
        return (res.id, res.price, res.count)

    @subtransaction
    def add_stock(self, items):
        for item in items:
            q = self.session.query(models.Item).\
                filter(models.Item.code == item['code'])
            citem = q.first()
            stock = models.StockItem(item=citem,
                                     count=item['count'],
                                     price=item['price'])
            self.session.add(stock)

    @subtransaction
    def update_stock(self, code, count, price):
        q = self.session.query(models.Item, models.StockItem).\
            with_entities(models.StockItem).\
            join(models.StockItem.item).\
            filter(models.Item.code == code)
        stock = q.first()
        if stock is None:
            q = self.session.query(models.Item).\
                filter(models.Item.code == code)
            item = q.first()
            if item is not None:
                stock = models.StockItem(item=item, count=count, price=price)
                self.session.add(stock)
            else:
                raise KeyError('Unknown item code')
        else:
            stock.count += count
            stock.price = price

    @subtransaction
    def add_items_with_stock(self, items):
        categories = set()
        for item in items:
            for x in item['categories']:
                categories.add(x)
        self.add_categories(categories)
        self.add_items(items)
        self.add_stock(items)

    def list_items(self, sort_key='description', ascending=True, page=1,
                   page_size=10):
        sq = self.session.query(models.Reservation.stock_item_id,
                                func.sum(models.Reservation.count).
                                label('reserved')).\
            group_by(models.Reservation.stock_item_id).\
            subquery()
        q = self.session.query(models.Item, models.Category,
                               models.ItemCategory, models.StockItem,
                               sq.c.reserved).\
            with_entities(models.Item.code, models.Item.description,
                          models.Category.name, models.StockItem.price,
                          models.StockItem.count, sq.c.reserved).\
            join(models.StockItem).\
            join(models.ItemCategory,
                 models.Item.id == models.ItemCategory.item_id).\
            join(models.Category,
                 models.Category.id == models.ItemCategory.category_id).\
            filter(models.ItemCategory.primary == True).\
            outerjoin(sq, models.StockItem.id == sq.c.stock_item_id)
        q = ordering(q, ascending, sort_key)
        q = pagination(q, page, page_size)
        res = [item_to_json(*x) for x in q]
        return res

    def search_items(self, prefix, price_range, sort_key='description',
                     ascending=True, page=1, page_size=10):
        sq = self.session.query(models.Reservation.stock_item_id,
                                func.sum(models.Reservation.count).
                                label('reserved')).\
            group_by(models.Reservation.stock_item_id).\
            subquery()
        q = self.session.query(models.Item, models.Category,
                               models.ItemCategory, models.StockItem,
                               sq.c.reserved).\
            with_entities(models.Item.code, models.Item.description,
                          models.Category.name, models.StockItem.price,
                          models.StockItem.count, sq.c.reserved).\
            join(models.StockItem).\
            join(models.ItemCategory,
                 models.Item.id == models.ItemCategory.item_id).\
            join(models.Category,
                 models.Category.id == models.ItemCategory.category_id).\
            filter(models.StockItem.price.between(*price_range),
                   models.Item.description_lower.like(
                       '{:s}%'.format(prefix.lower())),
                   models.ItemCategory.primary == True).\
            outerjoin(sq, models.StockItem.id == sq.c.stock_item_id)
        q = ordering(q, ascending, sort_key)
        q = pagination(q, page, page_size)
        res = [item_to_json(*x) for x in q]
        return res

    def list_items_by_prices(self, prices, sort_key='price', prefix=None,
                             ascending=True, page=1, page_size=10):
        pgs = pg_cases(prices)
        pg_case = sqla.case(pgs, else_=-1).label('price_group')
        sq = self.session.query(models.Reservation.stock_item_id,
                                func.sum(models.Reservation.count).
                                label('reserved')).\
            group_by(models.Reservation.stock_item_id).\
            subquery()
        q = self.session.query(models.Item, models.Category,
                               models.ItemCategory, models.StockItem,
                               sq.c.reserved).\
            with_entities(pg_case, models.Item.code, models.Item.description,
                          models.Category.name, models.StockItem.price,
                          models.StockItem.count, sq.c.reserved)
        if prefix is not None:
            q = q.filter(models.Item.description_lower.like(
                '{:s}%'.format(prefix.lower())))
        q = q.join(models.StockItem.item).\
            outerjoin(sq, models.StockItem.id == sq.c.stock_item_id).\
            filter(models.ItemCategory.item_id == models.Item.id,
                   models.ItemCategory.category_id == models.Category.id,
                   models.ItemCategory.primary == True,
                   pg_case >= 0)
        q = pg_ordering(q, ascending)
        q = pagination(q, page, page_size)

        def to_dict(x):
            tmp = item_to_json(*x[1:])
            tmp.update({'price_group': x[0]})
            return tmp

        res = [to_dict(x) for x in q]
        return res

    def _get_reservations(self, stock_id):
        q = self.session.query(func.sum(models.Reservation.count)).\
            filter(models.StockItem.id == stock_id).\
            filter(models.StockItem.id == models.Reservation.stock_item_id)
        res = q.first()
        if res is None or res[0] is None:
            return 0
        return res[0]

    def _get_reservation(self, basket_item_id):
        q = self.session.query(models.Basket, models.BasketItem,
                               models.Reservation).\
            with_entities(models.Reservation).\
            join(models.Reservation.basket_item).\
            filter(models.BasketItem.id == basket_item_id)
        res = q.first()
        return res

    @subtransaction
    def _update_reservation(self, stock, basket_item):
        reservations = self._get_reservations(basket_item.stock_item_id)
        # can reserve (scount - reservations)
        reservation = self._get_reservation(basket_item.id)
        if reservation is not None:
            rcount = min(basket_item.count,
                         stock.count - reservations + reservation.count)
            reservation.count = rcount
        else:
            rcount = min(basket_item.count, stock.count - reservations)
            reservation = models.Reservation(stock_item=stock,
                                             basket_item=basket_item,
                                             count=rcount)
            self.session.add(reservation)
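# --- Hedged usage sketch (not part of the original source) ------------------
# Minimal driver showing how Catalog might be exercised end to end. The
# in-memory engine, the item data, and models.Base as the declarative base are
# assumptions for illustration; the @subtransaction decorator (defined
# elsewhere) is expected to manage the nested transactions.
def _example_catalog():
    from sqlalchemy import create_engine

    engine = create_engine('sqlite:///:memory:')
    models.Base.metadata.create_all(engine)      # assumed declarative Base
    catalog = Catalog(engine)
    catalog.add_items_with_stock([
        {'code': 'A001', 'description': 'Apple', 'categories': ['Fruit'],
         'count': 10, 'price': 1.5},
    ])
    catalog.session.commit()
    print(catalog.get_stock('A001'))             # e.g. (1, 1.5, 10)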