def test_ui_cancel_withdraw( logged_in_wallet_user_browser: DriverAPI, dbsession: Session, user_phone_number, top_up_user, eth_asset_id ): """Create new account through UI.""" # Record balance before cancel with transaction.manager: uca = dbsession.query(UserCryptoAddress).first() asset = dbsession.query(Asset).get(eth_asset_id) original_balance = uca.get_crypto_account(asset).account.get_balance() # Go to address b = logged_in_wallet_user_browser b.find_by_css("#nav-wallet").click() b.find_by_css("#row-asset-{} a.withdraw-asset".format(eth_asset_id)).click() b.fill("address", TEST_ADDRESS) b.fill("amount", "0.1") b.find_by_css("button[name='process']").click() # We should arrive to the confirmation page assert b.is_element_present_by_css("#heading-confirm-withdraw") b.find_by_css("button[name='cancel']").click() # We got text box telling withdraw was cancelled assert b.is_element_present_by_css("#msg-withdraw-cancelled") # Balance should be back to original with transaction.manager: uca = dbsession.query(UserCryptoAddress).first() asset = dbsession.query(Asset).get(eth_asset_id) assert original_balance == uca.get_crypto_account(asset).account.get_balance()
def test_get_object(self):
    """Test the method _get_object() using a sync key

    Test scenario: Get the object with sync_key
    """
    obj_type = self.powerVCMapping.obj_type
    sync_key = self.powerVCMapping.sync_key
    # Stub the chain session.query(...).filter_by(...).one() so it returns
    # the canned mapping; mox verifies each call happens exactly as recorded.
    self.aMox.StubOutWithMock(session, 'query')
    session.query(model.PowerVCMapping).AndReturn(query)
    self.aMox.StubOutWithMock(query, 'filter_by')
    query.filter_by(
        obj_type=obj_type, sync_key=sync_key).AndReturn(query)
    self.aMox.StubOutWithMock(query, 'one')
    query.one().AndReturn(self.powerVCMapping)
    self.aMox.ReplayAll()
    returnValue = self.powervcagentdb._get_object(
        obj_type=obj_type, sync_key=sync_key)
    self.aMox.VerifyAll()
    self.assertEqual(returnValue, self.powerVCMapping)
    self.aMox.UnsetStubs()
def upgrade():
    """Merge duplicate (repository_id, revision) Vcs rows, re-point their
    Artifakt references at the surviving row, then add a unique constraint.

    Foreign-key enforcement is disabled around the batch table rewrite so the
    constraint can be created without tripping FK checks.
    """
    from sqlalchemy.orm.session import Session
    session = Session(bind=op.get_bind())

    # Find duplicates
    for vcs in session.query(Vcs).group_by(Vcs.repository_id, Vcs.revision).having(func.count(Vcs.id) > 1).all():
        print(vcs)
        # Find all vcs entries with this duplication
        dupes = session.query(Vcs).filter(Vcs.repository_id == vcs.repository_id).filter(Vcs.revision == vcs.revision).all()
        # Keep the first and remove the others - thus we need to update references to others to the first
        for update in dupes[1:]:
            for af in session.query(Artifakt).filter(Artifakt.vcs_id == update.id).all():
                print("Updating artifakt {} to point to vcs {}".format(af.sha1, dupes[0].id))
                af.vcs_id = dupes[0].id
            print("Deleting vcs {}".format(update.id))
            session.delete(update)
    session.commit()

    if session.bind.dialect.name == "sqlite":
        session.execute("PRAGMA foreign_keys = OFF")
    elif session.bind.dialect.name == "mysql":
        session.execute("SET foreign_key_checks = 0")
    else:
        # BUG FIX: `raise NotImplemented` raised a TypeError (NotImplemented
        # is a sentinel value, not an exception class).
        raise NotImplementedError(
            "unsupported dialect: {}".format(session.bind.dialect.name))

    with op.batch_alter_table('vcs', schema=None) as batch_op:
        batch_op.create_unique_constraint('rr', ['repository_id', 'revision'])

    if session.bind.dialect.name == "sqlite":
        session.execute("PRAGMA foreign_keys = ON")
    elif session.bind.dialect.name == "mysql":
        session.execute("SET foreign_key_checks = 1")
def list_videos(session: Session, count):
    """Print a human-readable report of stored videos, oldest first.

    A positive ``count`` limits output to the newest ``count`` videos
    (still printed in ascending id order); otherwise all videos print.
    """
    limit = int(count)
    if limit > 0:
        videos = reversed(
            session.query(Video).order_by(Video.id.desc()).limit(limit).all())
    else:
        videos = session.query(Video).order_by(Video.id.asc())

    for video in videos:
        print("Video %s:" % video.id)
        print("\tTitle: %s" % video.title)
        print("\tDate Added: %s" % video.date_added)
        if video.description:
            print("\tDescription:\n\t\t%s" % video.description.replace("\n", "\n\t\t"))
        else:
            print("\tDescription: None")
        print("\tKeywords: %s" % video.keywords)
        print("\tGame: %s" % video.game)
        print("\tStart Segment: %s" % video.start_segment)
        print("\tEnd Segment: %s" % video.end_segment)
        print("\tProcessing:")
        print("\t\tDone: %s" % video.processing_done)
        print("\t\tStatus: %s" % str(video.processing_status).replace("\n", "\n\t\t\t"))
        print("\tTwitch:")
        print("\t\tDo: %s" % video.do_twitch)
        print("\t\tDone: %s" % video.done_twitch)
        print("\t\tStatus: %s" % str(video.twitch_status).replace("\n", "\n\t\t\t"))
        print("\tYouTube:")
        print("\t\tDo: %s" % video.do_youtube)
        print("\t\tDone: %s" % video.done_youtube)
        print("\t\tStatus: %s" % str(video.youtube_status).replace("\n", "\n\t\t\t"))
        print("\t\tPublish Date: %s" % video.youtube_pubdate)
        print("\t\tPublic: %s" % video.youtube_public)
        if video.comment:
            print("\tComment:\n\t\t%s" % video.comment.replace("\n", "\n\t\t"))
def test_set_object_local_id(self):
    """Test the method _set_object_local_id(self, obj, local_id)

    Test scenario: Set the local_id of the specified object when the
    pvc_id is none
    """
    obj_id = self.powerVCMapping.id
    # Reset ids/status so the "pvc_id is None" code path is exercised.
    self.powerVCMapping.pvc_id = None
    self.powerVCMapping.local_id = None
    self.powerVCMapping.status = None
    # Stub session.query(...).filter_by(id=...).one() and session.merge().
    self.aMox.StubOutWithMock(session, 'query')
    session.query(model.PowerVCMapping).AndReturn(query)
    self.aMox.StubOutWithMock(query, 'filter_by')
    query.filter_by(id=obj_id).AndReturn(query)
    self.aMox.StubOutWithMock(query, 'one')
    query.one().AndReturn(self.powerVCMapping)
    self.aMox.StubOutWithMock(session, 'merge')
    session.merge(self.powerVCMapping).AndReturn("")
    self.aMox.ReplayAll()
    self.powervcagentdb._set_object_local_id(self.powerVCMapping, 'test')
    self.aMox.VerifyAll()
    # With pvc_id unset, setting local_id should mark the mapping 'Creating'.
    self.assertEqual(self.powerVCMapping.status, 'Creating')
    self.assertEqual(self.powerVCMapping.local_id, 'test')
    self.aMox.UnsetStubs()
def test_delete_existing_object(self):
    """Test the method _delete_object(self, obj) when the object exists

    Test scenario: When the data is in the database, the delete
    operation should complete successfully
    """
    # Stub lookup of the row, the subtransaction begin, and the delete call.
    self.aMox.StubOutWithMock(session, 'query')
    session.query(model.PowerVCMapping).AndReturn(query)
    self.aMox.StubOutWithMock(query, 'filter_by')
    query.filter_by(id=self.powerVCMapping['id']).AndReturn(query)
    self.aMox.StubOutWithMock(query, 'one')
    query.one().AndReturn(self.powerVCMapping)
    self.aMox.StubOutWithMock(session, 'begin')
    session.begin(subtransactions=True).AndReturn(transaction(None, None))
    self.aMox.StubOutWithMock(session, 'delete')
    returnValue = session.delete(self.powerVCMapping).AndReturn(True)
    self.aMox.ReplayAll()
    self.powervcagentdb._delete_object(self.powerVCMapping)
    self.aMox.VerifyAll()
    self.assertEqual(returnValue, True)
    self.aMox.UnsetStubs()
class SbciFinanceDB(object):
    '''Access to the SBCI finance SQLite database via SQLAlchemy automap.

    Reflects the existing schema and exposes one mapped class and one
    pre-built query per table.
    '''

    def __init__(self, verbose=0, *args, **kwds):  # @UnusedVariable
        super(SbciFinanceDB, self).__init__(*args, **kwds)

        # Fail fast if the database file is missing or not read/write.
        if not os.access(FINANCEDB_FILE, os.R_OK | os.W_OK):
            raise RuntimeError('cannot access Finance DB file ({}) for R/W!'
                               .format(FINANCEDB_FILE))

        # Reflect the schema; mapped classes are generated from table names.
        self.Base = automap_base()
        self.engine = create_engine('sqlite:///' + FINANCEDB_FILE)
        self.Base.prepare(self.engine, reflect=True)

        self.Categories = self.Base.classes.categories
        self.Seasons = self.Base.classes.seasons
        self.Cheques = self.Base.classes.cheques
        self.Transactions = self.Base.classes.transactions
        self.Trybooking = self.Base.classes.trybooking

        # One shared session and a reusable base query per table.
        self.dbsession = Session(self.engine)

        self.categories_query = self.dbsession.query(self.Categories)
        self.seasons_query = self.dbsession.query(self.Seasons)
        self.cheques_query = self.dbsession.query(self.Cheques)
        self.transactions_query = self.dbsession.query(self.Transactions)
        self.trybooking_query = self.dbsession.query(self.Trybooking)
def test_ui_confirm_withdraw( logged_in_wallet_user_browser: DriverAPI, dbsession: Session, user_phone_number, top_up_user, eth_asset_id ): """Create new account through UI.""" # Go to address b = logged_in_wallet_user_browser b.find_by_css("#nav-wallet").click() b.find_by_css("#row-asset-{} a.withdraw-asset".format(eth_asset_id)).click() b.fill("address", TEST_ADDRESS) b.fill("amount", "0.1") b.find_by_css("button[name='process']").click() # We should arrive to the confirmation page assert b.is_element_present_by_css("#heading-confirm-withdraw") # Peek into SMS code with transaction.manager: # Withdraw is the firt user op on the stack withdraw = ( dbsession.query(UserCryptoOperation) .join(CryptoOperation) .order_by(CryptoOperation.created_at.desc()) .first() ) confirmation = UserWithdrawConfirmation.get_pending_confirmation(withdraw) sms_code = confirmation.other_data["sms_code"] b.fill("code", sms_code) b.find_by_css("button[name='confirm']").click() # We got text box telling withdraw was success assert b.is_element_present_by_css("#msg-withdraw-confirmed")
def post_process_video(session: Session, debug):
    """Post-process the oldest fully-uploaded video: move its output file to
    the final directory and mark it done. Returns 0 always.

    A video qualifies when processing is done, post-processing is not, and
    each enabled upload target (Twitch/YouTube) has completed.
    """
    candidates = (
        session.query(Video)
        .filter(Video.processing_done)
        .filter(not_(Video.post_processing_done))
        .filter(or_(and_(Video.do_twitch, Video.done_twitch), not_(Video.do_twitch)))
        .filter(or_(and_(Video.do_youtube, Video.done_youtube), not_(Video.do_youtube)))
        .order_by(Video.id.asc())
    )
    video = candidates.first()
    if not video:
        print("No video in need of processing found")
        return 0

    # Settings are read in the same order as before (each is a DB lookup).
    src_dir = get_setting(session, "output_directory")
    dst_dir = get_setting(session, "final_output_directory")
    prefix = get_setting(session, "output_video_prefix")
    extension = get_setting(session, "output_video_extension")

    file_name = "%s%s%s" % (prefix, video.id, extension)
    src_path = os.path.join(src_dir, file_name)
    dst_path = os.path.join(dst_dir, file_name)

    if src_path == dst_path:
        video.post_processing_status = "Nothing to do"
    else:
        shutil.move(src_path, dst_path)
        video.post_processing_status = "Moved %s to %s" % (src_path, dst_path)

    video.post_processing_done = True
    session.commit()
    return 0
def upgrade():
    # NOTE: legacy Python 2 migration (print statement below); kept as-is.
    context = op.get_context()
    session = Session()
    session.bind = context.bind
    # Backfill customer_request_id for time entries missing it, by reading the
    # 'customerrequest' custom field from each project's trac schema.
    for tp in session.query(TimeEntry).filter_by(customer_request_id=None):
        for trac in tp.project.tracs:
            cr = session.execute('select value from "trac_%s".ticket_custom where name=\'customerrequest\' and ticket=%s' % (trac.trac_name, tp.ticket)).fetchone()
            # Resolve the trac value to an existing customer_requests row.
            sql_cr = session.execute('select id from customer_requests where id=\'%s\'' % cr.value).fetchone()
            tp.customer_request_id = sql_cr.id
            print sql_cr.id
    session.commit()
def job_status(self, job_id):
    """Return information about a job request"""
    db = Session()
    job = db.query(JobRequest).get(job_id)
    if not job:
        # Unknown id -> 404 response
        return self._failed("Job %s not found" % job_id, 404)
    return self._ok(job.asDict())
def jobs(self):
    """Return a list of past self-serve requests"""
    db = Session()
    try:
        num_jobs = IntValidator.to_python(request.GET.get('num', '100'))
        offset = IntValidator.to_python(request.GET.get('offset', '0'))
    except formencode.Invalid:
        # Bad query params -> fall back to defaults
        num_jobs, offset = 100, 0
    recent = (db.query(JobRequest)
              .order_by(JobRequest.when.desc())
              .limit(num_jobs)
              .offset(offset))
    return self._ok([job.asDict() for job in recent])
def run(self, keyword):
    """Look up *keyword* in the keywords table.

    Returns the keyword's id, or None if no such keyword exists.

    BUG FIX: the session was never closed on the NoResultFound path (and
    would also leak if .one() raised anything else); close it in a finally.
    """
    session = Session(bind=self.engine)
    try:
        it = session.query(Keyword).filter_by(word=keyword)
        try:
            kw = it.one()
        except NoResultFound:
            return None
        return kw.id
    finally:
        session.close()
def upgrade():
    ### commands auto generated by Alembic - please adjust! ###
    op.add_column('time_entries', sa.Column('tickettitle', sa.Unicode(), nullable=True))

    migration_context = op.get_context()
    session = Session()
    session.bind = migration_context.bind

    # Populate the new column from each entry's trac ticket summary.
    for entry in session.query(TimeEntry):
        for trac in entry.project.tracs:
            row = session.execute(
                'select summary from "trac_%s".ticket where id=%s'
                % (trac.trac_name, entry.ticket)).fetchone()
            entry.tickettitle = row.summary
    session.commit()
class SbciTeamsDB(object):
    '''Access to the SBCI teams SQLite database via SQLAlchemy automap.

    Reflects the existing schema with custom naming hooks and exposes one
    mapped class and one pre-built query per table.
    '''

    def __init__(self, verbose=0, *args, **kwds):
        super(SbciTeamsDB, self).__init__(*args, **kwds)

        # Fail fast if the database file is missing or not read/write.
        if not os.access(TEAMSDB_FILE, os.R_OK | os.W_OK):
            raise RuntimeError('cannot access DB file ({}) for R/W!'
                               .format(TEAMSDB_FILE))

        self.Base = automap_base()
        self.engine = create_engine('sqlite:///' + TEAMSDB_FILE)
        # Custom naming hooks control generated class and relationship names.
        self.Base.prepare(
            self.engine, reflect=True,
            classname_for_table=_classname_for_table,
            name_for_scalar_relationship=_name_for_scalar_relationship,
            name_for_collection_relationship=_name_for_collection_relationship
        )

        self.Competitions = self.Base.classes.competitions
        self.People = self.Base.classes.people
        self.Venues = self.Base.classes.venues
        self.Teams = self.Base.classes.teams
        self.Sessions = self.Base.classes.sessions

        # Debug aid: dump the attributes of each generated class.
        if verbose > 1:
            for cls in (self.Competitions, self.People, self.Venues,
                        self.Teams, self.Sessions):
                print('{} = {}'.format(cls.__name__, dir(cls)))

        self.dbsession = Session(self.engine)

        self.competitions_query = self.dbsession.query(self.Competitions)
        self.people_query = self.dbsession.query(self.People)
        self.venues_query = self.dbsession.query(self.Venues)
        self.teams_query = self.dbsession.query(self.Teams)
        self.sessions_query = self.dbsession.query(self.Sessions)
def main():
    """Queue the ids of all unprocessed raw articles for processing."""
    core.configure_logging()
    queue = ProcessQueue()

    conn_string = core.get_database_engine_string()
    logging.info("Using connection string '%s'" % (conn_string,))
    db_engine = create_engine(conn_string, encoding='utf-8',
                              isolation_level="READ COMMITTED")

    logging.info("Binding session...")
    session = Session(bind=db_engine, autocommit=False)

    for article in session.query(RawArticle).filter_by(status='Unprocessed'):
        queue.add_id(article.id)
def get_last_publish_date(session: Session):
    """Return the most recent YouTube publish date among stored videos.

    Falls back to the configured schedule start (minus the configured hour
    offset), or to the current UTC time when no start is configured.
    """
    latest = (session.query(Video)
              .filter(Video.youtube_pubdate.isnot(None))
              .order_by(Video.youtube_pubdate.desc())
              .first())
    if latest:
        return latest.youtube_pubdate

    start_ts = get_setting(session, "youtube_schedule_start_ts")
    if start_ts <= 0:
        return datetime.datetime.utcnow()
    offset_hours = get_setting(session, "schedule_offset_hours")
    return (datetime.datetime.utcfromtimestamp(start_ts)
            - datetime.timedelta(hours=offset_hours))
def check_empty_site_init(self, dbsession: Session, user: UserMixin):
    """Call after user creation to see if this user is the first user and
    should get initial admin rights."""
    assert user.id, "Please flush your db"

    # Reflect the related group class from the User model's relationships.
    mapper_info = inspection.inspect(user.__class__)
    Group = mapper_info.relationships["groups"].mapper.entity

    # Only an empty site (no groups yet) gets initialized.
    if dbsession.query(Group).count() == 0:
        self.init_empty_site(dbsession, user)
def delete_video(session: Session, vid_id: int):
    """Delete the video with the given id. Returns 0 on success, -1 if the
    id does not exist."""
    video = session.query(Video).get(vid_id)
    if not video:
        print("Video ID not found")
        return -1
    print("Deleting video %s: %s" % (video.id, video.title))
    session.delete(video)
    session.commit()
    print("Done")
    return 0
def get_setting(session: Session, name: str):
    """Fetch a setting by name and coerce it according to its value_class.

    Returns None when the setting does not exist; the raw string when no
    value_class is set; otherwise the value converted to bool/str/int/float
    or unpickled for "pickle".
    """
    sett = session.query(Setting).filter_by(name=name).first()
    if not sett:
        return None
    if not sett.value_class:
        return sett.value
    elif sett.value_class == "bool":
        return sett.value.lower() in ["yes", "true", "t", "y", "1"]
    elif sett.value_class in ["str", "int", "float"]:
        # SECURITY/idiom fix: was eval(sett.value_class) on a DB value; an
        # explicit map gives the same result without executing arbitrary code.
        act_class = {"str": str, "int": int, "float": float}[sett.value_class]
        return act_class(sett.value)
    elif sett.value_class == "pickle":
        # NOTE(review): pickle.loads on stored data -- only safe because the
        # value was written by set_new_setting/set_setting, not external input.
        return pickle.loads(sett.value)
    else:
        return sett.value
def cache_domains():
    """Copy every domain's key -> id mapping from the database into redis."""
    core.configure_logging('debug')
    from backend.db import Domain

    conn_string = core.get_database_engine_string()
    logging.info("Using connection string '%s'" % (conn_string,))
    db_engine = create_engine(conn_string, encoding='utf-8',
                              isolation_level="READ UNCOMMITTED")
    session = Session(bind=db_engine, autocommit=False)

    logging.debug("Establishing connection to redis...")
    redis_conn = get_redis_instance(2)

    for domain in session.query(Domain):
        redis_conn.set(domain.key, domain.id)
        logging.info("Sent %s to the cache.", domain)
def test_get_objects_with_status(self):
    """Test the method def _get_objects(self, obj_type, status)

    Test scenario: Get the object when the status is not None
    """
    # Stub session.query(...).filter_by(...).all() to return the canned
    # mapping; mox verifies the filter receives both obj_type and status.
    self.aMox.StubOutWithMock(session, 'query')
    session.query(model.PowerVCMapping).AndReturn(query)
    self.aMox.StubOutWithMock(query, 'filter_by')
    query.filter_by(obj_type=self.powerVCMapping.obj_type,
                    status=self.powerVCMapping.status).AndReturn(query)
    self.aMox.StubOutWithMock(query, 'all')
    query.all().AndReturn(self.powerVCMapping)
    self.aMox.ReplayAll()
    returnValue = self.powervcagentdb._get_objects(
        obj_type=self.powerVCMapping.obj_type,
        status=self.powerVCMapping.status)
    self.aMox.VerifyAll()
    self.assertEqual(returnValue, self.powerVCMapping)
    self.aMox.UnsetStubs()
def set_new_setting(session: Session, name: str, value):
    """Create or overwrite a setting, storing scalars as strings and any
    other value pickled. Always returns True."""
    # Scalars are stored as text with their class name; everything else is
    # serialized with pickle.
    if isinstance(value, (str, int, float, bool)):
        value_class = type(value).__name__
        value = str(value)
    else:
        value_class = "pickle"
        value = pickle.dumps(value)

    sett = session.query(Setting).filter_by(name=name).first()
    if not sett:
        session.add(Setting(name=name, value=value, value_class=value_class))
    else:
        sett.value = value
        sett.value_class = value_class
    return True
def upgrade():
    # NOTE: legacy Python 2 migration (print statement below); kept as-is.
    context = op.get_context()
    session = Session()
    session.bind = context.bind
    # Rewrite every task id embedded in each board's JSON payload.
    for b in session.query(KanbanBoard):
        if b.json:
            columns = loads(b.json)
            for n, column in enumerate(columns):
                for m, task in enumerate(column['tasks']):
                    new_id = update_id(task['id'])
                    print task['id'], new_id
                    if new_id != task['id']:
                        columns[n]['tasks'][m]['id'] = new_id
            # Re-serialize even when nothing changed (harmless rewrite).
            b.json = dumps(columns)
    session.commit()
def set_setting(session: Session, name: str, value):
    """Update an existing setting, coercing *value* per its value_class.

    Returns True on success, False (with a message) when the setting does
    not already exist -- use set_new_setting to create settings.
    """
    sett = session.query(Setting).filter_by(name=name).first()
    if not sett:
        print("Trying to set unknown setting: %s" % name)
        return False
    else:
        if sett.value_class in ["str", "int", "float"]:
            # SECURITY/idiom fix: was eval(sett.value_class) on a DB value;
            # an explicit map avoids executing arbitrary code.
            vc = {"str": str, "int": int, "float": float}[sett.value_class]
            value = vc(value)
        elif sett.value_class == "bool":
            value = str(value).lower() in ["yes", "true", "t", "y", "1"]
        if sett.value_class == "pickle":
            sett.value = pickle.dumps(value)
        else:
            sett.value = str(value)
        return True
def upgrade():
    ### commands auto generated by Alembic - please adjust! ###
    op.add_column('time_entries', sa.Column('customer_request_id', sa.String(), nullable=True))
    op.drop_column('time_entries', u'contract_id')
    op.drop_column('customer_requests', u'placement')

    migration_context = op.get_context()
    session = Session()
    session.bind = migration_context.bind

    # Backfill the new column from each project's trac custom ticket field.
    for entry in session.query(TimeEntry):
        for trac in entry.project.tracs:
            cr_row = session.execute(
                'select value from "trac_%s".ticket_custom where name=\'customerrequest\' and ticket=%s'
                % (trac.trac_name, entry.ticket)).fetchone()
            cr_id_row = session.execute(
                'select id from customer_requests where id=\'%s\'' % cr_row.value).fetchone()
            entry.customer_request_id = cr_id_row.id
    session.commit()
class VMDatabase(object):
    # NOTE: legacy Python 2 code (print statements); kept as-is.
    """SQLite-backed store of VM providers, images, and machines."""

    def __init__(self, path=os.path.expanduser("~/vm.db")):
        engine = create_engine('sqlite:///%s' % path, echo=False)
        metadata.create_all(engine)
        Session = sessionmaker(bind=engine, autoflush=True, autocommit=False)
        self.session = Session()

    def print_state(self):
        # Dump the full provider -> image -> snapshot/machine tree.
        for provider in self.getProviders():
            print 'Provider:', provider.name
            for base_image in provider.base_images:
                print '  Base image:', base_image.name
                for snapshot_image in base_image.snapshot_images:
                    print '    Snapshot:', snapshot_image.name, \
                        snapshot_image.state
                for machine in base_image.machines:
                    print '    Machine:', machine.id, machine.name, \
                        machine.state, machine.state_time, machine.ip

    def abort(self):
        self.session.rollback()

    def commit(self):
        self.session.commit()

    def delete(self, obj):
        self.session.delete(obj)

    def getProviders(self):
        return self.session.query(Provider).all()

    def getProvider(self, name):
        return self.session.query(Provider).filter_by(name=name)[0]

    def getResult(self, id):
        return self.session.query(Result).filter_by(id=id)[0]

    def getMachine(self, id):
        return self.session.query(Machine).filter_by(id=id)[0]

    def getMachineByJenkinsName(self, name):
        return self.session.query(Machine).filter_by(jenkins_name=name)[0]

    def getMachineForUse(self, image_name):
        """Atomically find a machine that is ready for use, and update
        its state."""
        image = None
        # Oldest READY machine of the requested image wins; mark it USED
        # and commit immediately so no other caller can grab it.
        for machine in self.session.query(Machine).filter(
                machine_table.c.state == READY).order_by(
                machine_table.c.state_time):
            if machine.base_image.name == image_name:
                machine.state = USED
                self.commit()
                return machine
        raise Exception("No machine found for image %s" % image_name)
def init_empty_site(self, dbsession: Session, user: UserMixin):
    """When the first user signs up build the admin groups and make the
    user member of it.

    Make the first member of the site to be admin and superuser.
    """
    # Reflect the related group class from the User model's relationships.
    mapper_info = inspection.inspect(user.__class__)
    Group = mapper_info.relationships["groups"].mapper.entity

    # Already-initialized site: nothing to do.
    if dbsession.query(Group).count() > 0:
        return

    admin_group = Group(name=Group.DEFAULT_ADMIN_GROUP_NAME)
    dbsession.add(admin_group)
    admin_group.users.append(user)
def index_update(class_name, items):
    """
    items: dict of model class name => list of (operation, primary key)
    """
    # NOTE: legacy Python 2 code (`unicode` builtin below); kept as-is.
    cls_registry = dict([(cls.__name__, cls) for cls in service.indexed_classes])

    model_class = cls_registry.get(class_name)

    if model_class is None:
        raise ValueError("Invalid class: {}".format(class_name))

    index = service.index_for_model_class(model_class)
    primary_field = model_class.search_query.primary
    indexed_fields = model_class.whoosh_schema.names()

    session = Session(bind=db.session.get_bind(None, None))

    query = session.query(model_class)

    with AsyncWriter(index) as writer:
        for change_type, model_pk in items:
            if model_pk is None:
                continue
            # delete everything. stuff that's updated or inserted will get
            # added as a new doc. Could probably replace this with a whoosh
            # update.
            writer.delete_by_term(primary_field, unicode(model_pk))

            if change_type in ("new", "changed"):
                model = query.get(model_pk)
                if model is None:
                    # deleted after task queued, but before task run
                    continue

                # Hack: Load lazy fields
                # This prevents a transaction error in make_document
                for key in indexed_fields:
                    getattr(model, key)

                document = service.make_document(model, indexed_fields, primary_field)
                writer.add_document(**document)

    session.close()
def __init__(self, *args, **kwargs):
    """Build a query wrapper around either a provided SQLAlchemy query,
    a session-bound query, or a detached SQLAlchemy Query object.

    NOTE: legacy Python 2 code -- relies on `filter()` returning a list
    (`len()` is called on the result below); kept as-is.
    """
    self.session = kwargs.pop("__session", None)
    self.query_tree = None
    self.entity_class_registry = None
    self._autoflush = True
    self.read_deleted = "no"

    # Sometimes the query must return attributes rather than objects. The following block
    # is in charge of finding which attributes should be returned by the query.
    attributes = []
    for arg in args:
        from sqlalchemy.orm.attributes import InstrumentedAttribute
        if type(arg) is InstrumentedAttribute:
            attributes += [arg]
    self.required_attributes = attributes

    # Create the SQLAlchemy query
    if "__query" in kwargs:
        self.sa_query = kwargs["__query"]
    else:
        if self.session:
            from sqlalchemy.orm.session import Session
            temporary_session = Session()
            self.sa_query = temporary_session.query(*args, **kwargs)
        else:
            from sqlalchemy.orm import Query as SqlAlchemyQuery
            from rome.core.session.session import Session as RomeSession
            # self.sa_query = SqlAlchemyQuery(*args, **kwargs)
            # Split positional args into mapped entities and an optional
            # RomeSession instance smuggled in via *args.
            entities = filter(lambda arg: type(arg) is not RomeSession and
                              arg is not None, args)
            session_candidates = filter(lambda arg: type(arg) is RomeSession,
                                        args)
            if len(session_candidates) > 0:
                session_candidate = session_candidates[0]
            else:
                session_candidate = None
            self.session = session_candidate
            if len(attributes) == 0:
                self.sa_query = SqlAlchemyQuery(entities, session=session_candidate)
            else:
                self.sa_query = SqlAlchemyQuery(attributes, session=session_candidate)
def get_run_ids(dag: DAG, run_id: str, future: bool, past: bool,
                session: SASession = NEW_SESSION):
    """Returns run_ids of DAG execution.

    ``future``/``past`` widen the window around ``run_id`` to the last or
    first DagRun respectively.
    """
    last_dagrun = dag.get_last_dagrun(include_externally_triggered=True)
    current_dagrun = dag.get_dagrun(run_id=run_id)
    # Earliest run by execution date -- only needed when `past` is requested.
    first_dagrun = (session.query(DagRun).filter(
        DagRun.dag_id == dag.dag_id).order_by(
        DagRun.execution_date.asc()).first())

    if last_dagrun is None:
        raise ValueError(f'DagRun for {dag.dag_id} not found')

    # determine run_id range of dag runs and tasks to consider
    end_date = last_dagrun.logical_date if future else current_dagrun.logical_date
    start_date = current_dagrun.logical_date if not past else first_dagrun.logical_date
    if not dag.timetable.can_run:
        # If the DAG never schedules, need to look at existing DagRun if the user wants future or
        # past runs.
        dag_runs = dag.get_dagruns_between(start_date=start_date, end_date=end_date)
        run_ids = sorted({d.run_id for d in dag_runs})
    elif not dag.timetable.periodic:
        # Non-periodic timetable: only the requested run is meaningful.
        run_ids = [run_id]
    else:
        # Expand the window into scheduled logical dates and map them back
        # to existing DagRun run_ids.
        dates = [
            info.logical_date for info in dag.iter_dagrun_infos_between(
                start_date, end_date, align=False)
        ]
        run_ids = [
            dr.run_id for dr in DagRun.find(dag_id=dag.dag_id, execution_date=dates)
        ]
    return run_ids
def __get_concurrency_maps(
    self, states: List[TaskInstanceState], session: Session = None
) -> Tuple[DefaultDict[str, int], DefaultDict[Tuple[str, str], int]]:
    """
    Get the concurrency maps.

    :param states: List of states to query for
    :type states: list[airflow.utils.state.State]
    :return: A map from (dag_id, task_id) to # of task instances and
     a map from (dag_id, task_id) to # of task instances in the given state list
    :rtype: tuple[dict[str, int], dict[tuple[str, str], int]]
    """
    # NOTE(review): session defaults to None but is used unconditionally --
    # presumably injected by a @provide_session decorator; confirm at call site.
    # One grouped count query instead of a query per (dag_id, task_id).
    ti_concurrency_query: List[Tuple[str, str, int]] = (
        session.query(TI.task_id, TI.dag_id, func.count('*'))
        .filter(TI.state.in_(states))
        .group_by(TI.task_id, TI.dag_id)
    ).all()
    dag_map: DefaultDict[str, int] = defaultdict(int)
    task_map: DefaultDict[Tuple[str, str], int] = defaultdict(int)
    for result in ti_concurrency_query:
        task_id, dag_id, count = result
        # Per-DAG totals accumulate; per-task counts are unique per group key.
        dag_map[dag_id] += count
        task_map[(dag_id, task_id)] = count
    return dag_map, task_map
def get_issues(session: Session, repo_dir: str = None, author: str = None,
               pull_requests=True) -> List[Issue]:
    """
    Gets a list of issues.

    Args:
        session (Session): The database session.
        repo_dir (str, optional): Filter by issue repository. Defaults to None.
        author (str, optional): Filter by issue author. Defaults to None.
        pull_requests (bool, optional): When True (the default) no filter is
            applied, so pull requests are included in the result. When False,
            only issues whose is_pull_request flag is False are returned
            (pull requests are excluded). Defaults to True.

    Returns:
        List[Issue]: List of issues.
    """
    query: Query = session.query(Issue)
    if repo_dir:
        query = query.filter_by(repo_dir=repo_dir)
    if author:
        query = query.filter_by(author=author)
    if not pull_requests:
        # pull_requests is False here, so this keeps only non-PR issues.
        query = query.filter_by(is_pull_request=pull_requests)
    return query.all()
def check_run_id_null(session: Session) -> Iterable[str]:
    """Yield an error message if the dag_run table contains rows with NULL
    dag_id/run_id/execution_date; otherwise move those rows aside.

    Generator: yields at most one formatted error string.
    """
    import sqlalchemy.schema

    metadata = sqlalchemy.schema.MetaData(session.bind)
    try:
        metadata.reflect(only=[DagRun.__tablename__], extend_existing=True, resolve_fks=False)
    except exc.InvalidRequestError:
        # Table doesn't exist -- empty db
        return

    # We can't use the model here since it may differ from the db state due to
    # this function is run prior to migration. Use the reflected table instead.
    dagrun_table = metadata.tables[DagRun.__tablename__]

    invalid_dagrun_filter = or_(
        dagrun_table.c.dag_id.is_(None),
        dagrun_table.c.run_id.is_(None),
        dagrun_table.c.execution_date.is_(None),
    )
    invalid_dagrun_count = session.query(dagrun_table.c.id).filter(invalid_dagrun_filter).count()
    if invalid_dagrun_count > 0:
        dagrun_dangling_table_name = _format_airflow_moved_table_name(dagrun_table.name, "2.2")
        if dagrun_dangling_table_name in inspect(session.get_bind()).get_table_names():
            # A previous run already created the dangling-rows table; report
            # instead of moving data again.
            yield _format_dangling_error(
                source_table=dagrun_table.name,
                target_table=dagrun_dangling_table_name,
                invalid_count=invalid_dagrun_count,
                reason="with a NULL dag_id, run_id, or execution_date",
            )
            return

        # `return <value>` inside a generator only sets StopIteration.value;
        # the move happens for its side effect and nothing is yielded.
        return _move_dangling_data_to_new_table(
            session,
            dagrun_table,
            dagrun_table.select(invalid_dagrun_filter),
            dagrun_dangling_table_name,
        )
def get_workitems_related_to_workitem(
    jtrace: Session, workitem_id: int, user: UKRDCUser
) -> Query:
    """Get a list of WorkItems related via the LinkRecord network to a given WorkItem

    Args:
        jtrace (Session): JTRACE SQLAlchemy session
        workitem_id (int): WorkItem ID
        user (UKRDCUser): Logged-in user

    Returns:
        Query: SQLAlchemy query
    """
    workitem = get_workitem(jtrace, workitem_id, user)

    # Seed the relation search with this workitem's master/person ids.
    master_ids: set[int] = set()
    person_ids: set[int] = set()
    if workitem.master_record:
        master_ids.add(workitem.master_id)
    if workitem.person:
        person_ids.add(workitem.person_id)

    related_master_ids, related_person_ids = find_related_ids(
        jtrace, master_ids, person_ids
    )

    # Any workitem touching a related master or person counts, except this one.
    related = (
        jtrace.query(WorkItem)
        .filter(
            or_(
                WorkItem.master_id.in_(related_master_ids),
                WorkItem.person_id.in_(related_person_ids),
            )
        )
        .filter(WorkItem.id != workitem.id)
    )

    return _apply_query_permissions(related, user)
def upgrade():
    """Replace the three per-mode boolean flags on chat with a single
    tag_mode column, and drop the current_sticker_set_name machinery."""
    op.add_column("chat", sa.Column("tag_mode", sa.String(), nullable=True))
    session = Session(bind=op.get_bind())

    # Translate each legacy boolean flag into its tag_mode value.
    flag_to_mode = [
        (Chat.fix_single_sticker, TagMode.single_sticker.value),
        (Chat.tagging_random_sticker, TagMode.random.value),
        (Chat.full_sticker_set, TagMode.sticker_set.value),
    ]
    for flag, mode in flag_to_mode:
        session.query(Chat).filter(flag).update({"tag_mode": mode})

    op.drop_index("ix_chat_current_sticker_set_name", table_name="chat")
    op.drop_constraint("chat_current_sticker_set_name_fkey", "chat",
                       type_="foreignkey")
    op.drop_column("chat", "current_sticker_set_name")
    op.drop_constraint("only_one_action_check", "chat")
def test_authenticate_user_created_from_facebook_auth_ok(
        session: Session, application: Application):
    """Authenticating twice via Facebook creates exactly one user/account and
    refreshes the stored access token on the second login."""
    # Start from an empty database.
    assert not session.query(User).count()
    assert not session.query(SocialAccount).count()

    mock_id = "mockid"
    access_token = "acctoken"
    email = "*****@*****.**"
    first_name = "mockfname"
    last_name = "mocklname"
    role = UserRole.TENANT

    # First login creates the user and social account.
    result = application.authenticate_facebook(session, mock_id, access_token,
                                               email, first_name, last_name,
                                               role)
    assert isinstance(result, bytes)

    # Second login with the same account id but new token/names.
    result = application.authenticate_facebook(
        session,
        mock_id,
        "diff_acc_token",
        email,
        "diff_first_name",
        "diff_last_name",
        role,
    )
    assert isinstance(result, bytes)

    # Still exactly one user and one linked account.
    assert session.query(User).count() == 1
    assert session.query(SocialAccount).count() == 1
    user = session.query(User).all()[0]
    account = session.query(SocialAccount).all()[0]
    assert account in user.social_accounts
    # Token is refreshed; the original profile fields are kept.
    assert account.access_token == "diff_acc_token"
    assert not account.id_token
    assert account.user == user
    assert account.account_email == email
    assert account.account_id == mock_id
    assert user.email == email
    assert user.first_name == first_name
    assert user.last_name == last_name
    assert not user.hashed_password
    assert user.role == UserRole.TENANT
    assert account.account_type == SocialAccountType.FACEBOOK
def test_delete_sqlatable(app_context: None, session: Session) -> None:
    """
    Test that deleting a ``SqlaTable`` also deletes the corresponding ``Dataset``.
    """
    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
    ]
    sqla_table = SqlaTable(
        table_name="old_dataset",
        columns=columns,
        metrics=[],
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
    )
    session.add(sqla_table)
    session.flush()

    # Creating the SqlaTable shadow-writes one Dataset, one Table, and two
    # Columns (one dataset column + one table column).
    assert session.query(Dataset).count() == 1
    assert session.query(Table).count() == 1
    assert session.query(Column).count() == 2

    session.delete(sqla_table)
    session.flush()

    # test that dataset and dataset columns are also deleted
    # but the physical table and table columns are kept
    assert session.query(Dataset).count() == 0
    assert session.query(Table).count() == 1
    assert session.query(Column).count() == 1
def test_authenticated_user_link_facebook_acccount_to_existing_user_ok(
        session: Session, application: Application, tenant: Tenant):
    """Linking a Facebook account to an already-authenticated user creates the
    social account without altering the user's existing profile fields."""
    # Start with the tenant fixture's single user and no social accounts.
    assert session.query(User).count() == 1
    assert not session.query(SocialAccount).count()

    access_token = "mock_access_token"
    mockid = "mock_id"
    email = tenant.email
    first_name = "mockfname"
    last_name = "mocklname"
    role = UserRole.TENANT

    # Passing authenticated_user links the account instead of creating a user.
    result = application.authenticate_facebook(
        session,
        mockid,
        access_token,
        email,
        first_name,
        last_name,
        role,
        authenticated_user=tenant,
    )
    assert isinstance(result, bytes)

    # No new user; exactly one linked social account.
    assert session.query(User).count() == 1
    assert session.query(SocialAccount).count() == 1
    user = session.query(User).all()[0]
    account = session.query(SocialAccount).all()[0]
    assert account in user.social_accounts
    assert account.access_token == access_token
    assert not account.id_token
    assert account.user == user
    assert account.account_email == email
    assert user.email == email
    # Profile fields come from the existing tenant, not the Facebook payload.
    assert user.first_name == tenant.first_name
    assert user.last_name == tenant.last_name
    assert user.hashed_password
    assert user.role == UserRole.TENANT
    assert account.account_id == mockid
    assert account.account_type == SocialAccountType.FACEBOOK
def get_local_replies(session: Session) -> List[Reply]:
    """
    Fetch every reply record stored in the local database.
    """
    reply_query = session.query(Reply)
    return reply_query.all()
def get_local_files(session: Session) -> List[File]:
    """
    Fetch every file (submitted file) record from the local database.
    """
    file_query = session.query(File)
    return file_query.all()
def get_db_object(self, session: Session) -> File:
    '''
    Override DownloadJob: look up the File row matching this job's UUID.
    '''
    matching = session.query(File).filter_by(uuid=self.uuid)
    return matching.one()
def manage_slas(self, dag: DAG, session: Session = None) -> None:
    """
    Finding all tasks that have SLAs defined, and sending alert emails
    where needed. New SLA misses are also recorded in the database.

    We are assuming that the scheduler runs often, so we only check for
    tasks that should have succeeded in the past hour.

    :param dag: the DAG whose tasks are checked for SLA misses
    :param session: database session (presumably injected by a
        ``@provide_session``-style decorator -- TODO confirm)
    """
    self.log.info("Running SLA Checks for %s", dag.dag_id)
    # Fast exit when no task in the DAG defines an SLA.
    # NOTE(review): the loop variable is named ``ti`` but it iterates
    # ``dag.tasks`` (task objects, not task instances).
    if not any(isinstance(ti.sla, timedelta) for ti in dag.tasks):
        self.log.info(
            "Skipping SLA check for %s because no tasks in DAG have SLAs", dag)
        return

    # Subquery: the latest SUCCESS/SKIPPED execution date per task of this DAG
    # (MySQL gets an index hint, other dialects ignore it).
    qry = (session.query(
        TI.task_id, func.max(TI.execution_date).label('max_ti')).with_hint(
            TI, 'USE INDEX (PRIMARY)', dialect_name='mysql').filter(
                TI.dag_id == dag.dag_id).filter(
                    or_(TI.state == State.SUCCESS,
                        TI.state == State.SKIPPED)).filter(
                            TI.task_id.in_(dag.task_ids)).group_by(
                                TI.task_id).subquery('sq'))

    # The most recent completed task instance for every task.
    max_tis: List[TI] = (session.query(TI).filter(
        TI.dag_id == dag.dag_id,
        TI.task_id == qry.c.task_id,
        TI.execution_date == qry.c.max_ti,
    ).all())

    ts = timezone.utcnow()
    for ti in max_tis:
        task = dag.get_task(ti.task_id)
        if task.sla and not isinstance(task.sla, timedelta):
            raise TypeError(
                f"SLA is expected to be timedelta object, got "
                f"{type(task.sla)} in {task.dag_id}:{task.task_id}")
        # Walk forward schedule-by-schedule from the last completed run; any
        # schedule whose deadline (next schedule + sla) has already passed is
        # recorded as an SlaMiss (merge makes the insert idempotent).
        dttm = dag.following_schedule(ti.execution_date)
        while dttm < timezone.utcnow():
            following_schedule = dag.following_schedule(dttm)
            if following_schedule + task.sla < timezone.utcnow():
                session.merge(
                    SlaMiss(task_id=ti.task_id,
                            dag_id=ti.dag_id,
                            execution_date=dttm,
                            timestamp=ts))
            dttm = dag.following_schedule(dttm)
    session.commit()

    # All misses for this DAG that nobody has been notified about yet.
    # pylint: disable=singleton-comparison
    slas: List[SlaMiss] = (
        session.query(SlaMiss).filter(SlaMiss.notification_sent == False,
                                      SlaMiss.dag_id == dag.dag_id)  # noqa
        .all())
    # pylint: enable=singleton-comparison

    if slas:  # pylint: disable=too-many-nested-blocks
        sla_dates: List[datetime.datetime] = [
            sla.execution_date for sla in slas
        ]
        # Non-successful task instances at the missed execution dates; an
        # instance whose task no longer exists in the DAG is deleted instead
        # of being reported as blocking.
        fetched_tis: List[TI] = (session.query(TI).filter(
            TI.state != State.SUCCESS,
            TI.execution_date.in_(sla_dates),
            TI.dag_id == dag.dag_id).all())
        blocking_tis: List[TI] = []
        for ti in fetched_tis:
            if ti.task_id in dag.task_ids:
                ti.task = dag.get_task(ti.task_id)
                blocking_tis.append(ti)
            else:
                session.delete(ti)
                session.commit()

        task_list = "\n".join(sla.task_id + ' on ' +
                              sla.execution_date.isoformat() for sla in slas)
        blocking_task_list = "\n".join(ti.task_id + ' on ' +
                                       ti.execution_date.isoformat()
                                       for ti in blocking_tis)
        # Track whether email or any alert notification sent
        # We consider email or the alert callback as notifications
        email_sent = False
        notification_sent = False
        if dag.sla_miss_callback:
            # Execute the alert callback; a failing callback is logged but
            # must not abort the email path below.
            self.log.info('Calling SLA miss callback')
            try:
                dag.sla_miss_callback(dag, task_list, blocking_task_list,
                                      slas, blocking_tis)
                notification_sent = True
            except Exception:  # pylint: disable=broad-except
                self.log.exception(
                    "Could not call sla_miss_callback for DAG %s",
                    dag.dag_id)
        email_content = f"""\
Here's a list of tasks that missed their SLAs:
<pre><code>{task_list}\n<code></pre>
Blocking tasks:
<pre><code>{blocking_task_list}<code></pre>
Airflow Webserver URL: {conf.get(section='webserver', key='base_url')}
"""
        # Only tasks still present in the DAG can be emailed about.
        tasks_missed_sla = []
        for sla in slas:
            try:
                task = dag.get_task(sla.task_id)
            except TaskNotFound:
                # task already deleted from DAG, skip it
                self.log.warning(
                    "Task %s doesn't exist in DAG anymore, skipping SLA miss notification.",
                    sla.task_id)
                continue
            tasks_missed_sla.append(task)

        # Union of all recipient addresses across the affected tasks.
        emails: Set[str] = set()
        for task in tasks_missed_sla:
            if task.email:
                if isinstance(task.email, str):
                    emails |= set(get_email_address_list(task.email))
                elif isinstance(task.email, (list, tuple)):
                    emails |= set(task.email)
        if emails:
            try:
                send_email(emails,
                           f"[airflow] SLA miss on DAG={dag.dag_id}",
                           email_content)
                email_sent = True
                notification_sent = True
            except Exception:  # pylint: disable=broad-except
                Stats.incr('sla_email_notification_failure')
                self.log.exception(
                    "Could not send SLA Miss email notification for DAG %s",
                    dag.dag_id)
        # If we sent any notification, update the sla_miss table
        if notification_sent:
            for sla in slas:
                sla.email_sent = email_sent
                sla.notification_sent = True
                session.merge(sla)
        session.commit()
#!/usr/bin/python3
"""
List all State objects whose name contains the letter 'a', using SQLAlchemy.

Usage: ./script.py <mysql username> <mysql password> <database name>
"""
if __name__ == '__main__':
    from sys import argv
    from sqlalchemy import create_engine
    # BUG FIX: ``Session`` was imported here and then immediately shadowed by
    # the ``Session = sessionmaker(...)`` rebind below, so the import is gone.
    from sqlalchemy.orm.session import sessionmaker
    from model_state import Base, State

    # BUG FIX: the originals read ``'******'.format(argv[1])`` — a format
    # string with no '{}' placeholder always yields the literal '******',
    # silently discarding the CLI credentials.
    username = '{}'.format(argv[1])
    password = '{}'.format(argv[2])
    db_name = '{}'.format(argv[3])

    engine = create_engine('mysql+mysqldb://{}:{}@localhost:3306/{}'.format(
        username, password, db_name))

    Session = sessionmaker(bind=engine)
    session = Session()

    # Print every state whose name contains a lowercase 'a', ordered by id.
    for state in session.query(State).order_by(State.id):
        if 'a' in state.name:
            print('{}: {}'.format(state.id, state.name))

    session.close()
def set_dag_run_state_to_failed(
    *,
    dag: DAG,
    execution_date: Optional[datetime] = None,
    run_id: Optional[str] = None,
    commit: bool = False,
    session: SASession = NEW_SESSION,
):
    """
    Mark a dag run (located via execution_date or run_id) as FAILED, together
    with its currently RUNNING task instances.

    :param dag: the DAG of which to alter state
    :param execution_date: the execution date from which to start looking (deprecated)
    :param run_id: the DAG run_id to start looking from
    :param commit: commit DAG and tasks to be altered to the database
    :param session: database session
    :return: If commit is true, list of tasks that have been updated,
        otherwise list of tasks that will be updated
    """
    # Exactly one locator must be supplied, and a DAG must be present.
    if not exactly_one(execution_date, run_id) or not dag:
        return []

    if execution_date:
        if not timezone.is_localized(execution_date):
            raise ValueError(f"Received non-localized date {execution_date}")
        dag_run = dag.get_dagrun(execution_date=execution_date)
        if not dag_run:
            raise ValueError(
                f'DagRun with execution_date: {execution_date} not found')
        run_id = dag_run.run_id

    if not run_id:
        raise ValueError(f'Invalid dag_run_id: {run_id}')

    # Fail the dag run itself, but only when committing.
    if commit:
        _set_dag_run_state(dag.dag_id, run_id, DagRunState.FAILED, session)

    # Which of this run's task instances are currently RUNNING?
    all_task_ids = [task.task_id for task in dag.tasks]
    running_tis = session.query(TaskInstance).filter(
        TaskInstance.dag_id == dag.dag_id,
        TaskInstance.run_id == run_id,
        TaskInstance.task_id.in_(all_task_ids),
        TaskInstance.state.in_(State.running),
    )
    running_task_ids = {ti.task_id for ti in running_tis}

    tasks = []
    for task in dag.tasks:
        if task.task_id in running_task_ids:
            task.dag = dag
            tasks.append(task)

    return set_state(tasks=tasks, dag_run_id=run_id, state=State.FAILED,
                     commit=commit, session=session)
def get_local_sources(session: Session) -> List[Source]:
    """
    Fetch every source in the local database, most recently updated first.
    """
    ordered = session.query(Source).order_by(desc(Source.last_updated))
    return ordered.all()
def find_new_files(session: Session) -> List[File]:
    """Return all files that have not been downloaded yet."""
    pending = session.query(File).filter_by(is_downloaded=False)
    return pending.all()
def update_replies(remote_replies: List[SDKReply], local_replies: List[Reply],
                   session: Session, data_dir: str) -> None:
    """
    * Existing replies are updated in the local database.
    * New replies have an entry created in the local database.
    * Local replies not returned in the remote replies are deleted from the
      local database unless they are pending or failed.

    If a reply references a new journalist username, add them to the database
    as a new user.

    NOTE(review): the pending/failed exclusion is not visible in this
    function -- presumably ``local_replies`` is pre-filtered by the caller;
    confirm before relying on it.
    """
    # Any UUID still in this set after the loop exists locally but not
    # remotely, and its record will be deleted.
    local_uuids = {reply.uuid for reply in local_replies}
    for reply in remote_replies:
        if reply.uuid in local_uuids:
            # Known reply: refresh its journalist and size in place.
            local_reply = [r for r in local_replies if r.uuid == reply.uuid][0]
            user = find_or_create_user(reply.journalist_uuid,
                                       reply.journalist_username, session)
            local_reply.journalist_id = user.id
            local_reply.size = reply.size
            local_uuids.remove(reply.uuid)
            logger.debug('Updated reply {}'.format(reply.uuid))
        else:
            # A new reply to be added to the database.
            source_uuid = reply.source_uuid
            # NOTE(review): this raises IndexError when the source is not
            # stored locally -- presumably sources are synced first; confirm.
            source = session.query(Source).filter_by(uuid=source_uuid)[0]
            user = find_or_create_user(
                reply.journalist_uuid, reply.journalist_username, session)
            nr = Reply(uuid=reply.uuid, journalist_id=user.id,
                       source_id=source.id, filename=reply.filename,
                       size=reply.size)
            session.add(nr)

            # All replies fetched from the server have succeeded in being sent,
            # so we should delete the corresponding draft locally if it exists.
            try:
                draft_reply_db_object = session.query(DraftReply).filter_by(
                    uuid=reply.uuid).one()

                update_draft_replies(session, draft_reply_db_object.source.id,
                                     draft_reply_db_object.timestamp,
                                     draft_reply_db_object.file_counter,
                                     nr.file_counter)
                session.delete(draft_reply_db_object)

            except NoResultFound:
                pass  # No draft locally stored corresponding to this reply.

            logger.debug('Added new reply {}'.format(reply.uuid))

    # The uuids remaining in local_uuids do not exist on the remote server, so
    # delete the related records.
    replies_to_delete = [r for r in local_replies if r.uuid in local_uuids]
    for deleted_reply in replies_to_delete:
        # Remove the on-disk file before dropping the database row.
        delete_single_submission_or_reply_on_disk(deleted_reply, data_dir)
        session.delete(deleted_reply)
        logger.debug('Deleted reply {}'.format(deleted_reply.uuid))

    session.commit()
def get_by_id(db_session: Session, shopping_list_id: int) -> Optional[models.List]:
    """Look up a shopping list by its primary key.

    NOTE(review): the return annotation is Optional, yet ``.one()`` raises
    NoResultFound instead of returning None when no row matches -- confirm
    which contract callers actually expect.
    """
    query = db_session.query(models.List)
    return query.filter(models.List.id == shopping_list_id).one()
def get_by_name(self, db: Session, name: str) -> Optional[SubjectsModel]:
    """Return the first subject with the given name, or None when absent."""
    matches = db.query(SubjectsModel).filter(SubjectsModel.name == name)
    return matches.first()
def get_log(
    *,
    dag_id: str,
    dag_run_id: str,
    task_id: str,
    task_try_number: int,
    full_content: bool = False,
    token: Optional[str] = None,
    session: Session = NEW_SESSION,
) -> APIResponse:
    """Get logs for specific task instance.

    Returns either a JSON payload with a signed continuation token, or a
    plain-text stream, depending on the request's Accept header.
    """
    # The continuation token is signed with the app secret so clients cannot
    # tamper with the pagination metadata embedded in it.
    key = current_app.config["SECRET_KEY"]
    if not token:
        metadata = {}
    else:
        try:
            metadata = URLSafeSerializer(key).loads(token)
        except BadSignature:
            raise BadRequest(
                "Bad Signature. Please use only the tokens provided by the API.")

    # A token minted for a download forces full content on follow-up requests.
    if metadata.get('download_logs') and metadata['download_logs']:
        full_content = True

    if full_content:
        metadata['download_logs'] = True
    else:
        metadata['download_logs'] = False

    task_log_reader = TaskLogReader()
    if not task_log_reader.supports_read:
        raise BadRequest("Task log handler does not support read logs.")

    # Locate the task instance for this dag/run/task triple.
    ti = (
        session.query(TaskInstance)
        .filter(
            TaskInstance.task_id == task_id,
            TaskInstance.dag_id == dag_id,
            TaskInstance.run_id == dag_run_id,
        )
        .join(TaskInstance.dag_run)
        .one_or_none()
    )
    if ti is None:
        metadata['end_of_log'] = True
        raise NotFound(title="TaskInstance not found")

    # Attach the task object when still present in the DAG so the log reader
    # can use task-level configuration; a missing task is tolerated.
    dag = current_app.dag_bag.get_dag(dag_id)
    if dag:
        try:
            ti.task = dag.get_task(ti.task_id)
        except TaskNotFound:
            pass

    return_type = request.accept_mimetypes.best_match(
        ['text/plain', 'application/json'])

    # return_type would be either the above two or None
    logs: Any
    if return_type == 'application/json' or return_type is None:  # default
        logs, metadata = task_log_reader.read_log_chunks(
            ti, task_try_number, metadata)
        logs = logs[0] if task_try_number is not None else logs
        # Re-sign the updated metadata as the next continuation token.
        token = URLSafeSerializer(key).dumps(metadata)
        return logs_schema.dump(
            LogResponseObject(continuation_token=token, content=logs))
    # text/plain. Stream
    logs = task_log_reader.read_log_stream(ti, task_try_number, metadata)
    return Response(logs, headers={"Content-Type": return_type})
def source_exists(session: Session, source_uuid: str) -> bool:
    """Return True iff exactly one source with this UUID is stored locally.

    (A multiple-result condition, like the original, is allowed to propagate.)
    """
    query = session.query(Source).filter_by(uuid=source_uuid)
    try:
        query.one()
    except NoResultFound:
        return False
    return True
#!/usr/bin/python3
"""
Print every State whose name contains the letter 'a' from the
hbtn_0e_6_usa database, ordered by id.
"""
from sqlalchemy import (create_engine)
from sqlalchemy.orm.session import Session
from model_state import Base, State
import sys

if __name__ == '__main__':
    # Credentials and database name come straight from the command line.
    engine = create_engine('mysql+mysqldb://{}:{}@localhost/{}'.format(
        sys.argv[1], sys.argv[2], sys.argv[3]), pool_pre_ping=True)
    Base.metadata.create_all(engine)

    session = Session(bind=engine)
    states_with_a = session.query(State).filter(
        State.name.contains('a')).order_by(State.id).all()
    for state in states_with_a:
        print("{}: {}".format(state.id, state.name))
    session.close()
def get_file(session: Session, uuid: str) -> File:
    """Look up the single File record with this UUID (raises if not exactly one)."""
    file_query = session.query(File).filter_by(uuid=uuid)
    return file_query.one()
def expand_mapped_task(self, run_id: str, *, session: Session) -> Sequence["TaskInstance"]:
    """Create the mapped task instances for mapped task.

    :param run_id: the dag run whose task instances are expanded
    :param session: database session used for all reads and upserts
    :return: The mapped task instances, in ascending order by map index.
    """
    from airflow.models.taskinstance import TaskInstance
    from airflow.settings import task_instance_mutation_hook

    # Total number of mapped instances is the product of all upstream map
    # lengths for this run.
    total_length = functools.reduce(
        operator.mul, self._get_map_lengths(run_id, session=session).values())

    state: Optional[TaskInstanceState] = None
    # Look for a still-unfinished "unmapped" placeholder TI (map_index == -1).
    unmapped_ti: Optional[TaskInstance] = (
        session.query(TaskInstance).filter(
            TaskInstance.dag_id == self.dag_id,
            TaskInstance.task_id == self.task_id,
            TaskInstance.run_id == run_id,
            TaskInstance.map_index == -1,
            or_(TaskInstance.state.in_(State.unfinished),
                TaskInstance.state.is_(None)),
        ).one_or_none())

    ret: List[TaskInstance] = []

    if unmapped_ti:
        # The unmapped task instance still exists and is unfinished, i.e. we
        # haven't tried to run it before.
        if total_length < 1:
            # If the upstream maps this to a zero-length value, simply marked the
            # unmapped task instance as SKIPPED (if needed).
            self.log.info(
                "Marking %s as SKIPPED since the map has %d values to expand",
                unmapped_ti,
                total_length,
            )
            unmapped_ti.state = TaskInstanceState.SKIPPED
            session.flush()
            return ret
        # Otherwise convert this into the first mapped index, and create
        # TaskInstance for other indexes.
        unmapped_ti.map_index = 0
        state = unmapped_ti.state
        self.log.debug("Updated in place to become %s", unmapped_ti)
        ret.append(unmapped_ti)
        indexes_to_map = range(1, total_length)
    else:
        # Only create "missing" ones.
        # NOTE(review): ``scalar()`` returns None when no TI rows exist, which
        # would make ``current_max_mapping + 1`` raise TypeError -- presumably
        # this branch is only reached after at least one TI exists; confirm.
        current_max_mapping = (session.query(
            func.max(TaskInstance.map_index)).filter(
                TaskInstance.dag_id == self.dag_id,
                TaskInstance.task_id == self.task_id,
                TaskInstance.run_id == run_id,
            ).scalar())
        indexes_to_map = range(current_max_mapping + 1, total_length)

    for index in indexes_to_map:
        # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
        # TODO: Change `TaskInstance` ctor to take Operator, not BaseOperator
        ti = TaskInstance(self, run_id=run_id, map_index=index, state=state)  # type: ignore
        self.log.debug("Expanding TIs upserted %s", ti)
        task_instance_mutation_hook(ti)
        ti = session.merge(ti)
        ti.task = self
        ret.append(ti)

    # Set to "REMOVED" any (old) TaskInstances with map indices greater
    # than the current map value
    session.query(TaskInstance).filter(
        TaskInstance.dag_id == self.dag_id,
        TaskInstance.task_id == self.task_id,
        TaskInstance.run_id == run_id,
        TaskInstance.map_index >= total_length,
    ).update({TaskInstance.state: TaskInstanceState.REMOVED})

    session.flush()
    return ret
def get_message(session: Session, uuid: str) -> Message:
    """Look up the single Message record with this UUID (raises if not exactly one)."""
    message_query = session.query(Message).filter_by(uuid=uuid)
    return message_query.one()
def find(dag_id: Optional[Union[str, List[str]]] = None,
         run_id: Optional[str] = None,
         execution_date: Optional[datetime] = None,
         state: Optional[str] = None,
         external_trigger: Optional[bool] = None,
         no_backfills: Optional[bool] = False,
         run_type: Optional[DagRunType] = None,
         session: Session = None,
         execution_start_date=None,
         execution_end_date=None):
    """
    Return all dag runs matching the given search criteria, ordered by
    execution date.

    :param dag_id: the dag_id or list of dag_id to find dag runs for
    :type dag_id: str or list[str]
    :param run_id: defines the run id for this dag run
    :type run_id: str
    :param run_type: type of DagRun
    :type run_type: airflow.utils.types.DagRunType
    :param execution_date: the execution date
    :type execution_date: datetime.datetime or list[datetime.datetime]
    :param state: the state of the dag run
    :type state: str
    :param external_trigger: whether this dag run is externally triggered
    :type external_trigger: bool
    :param no_backfills: return no backfills (True), return all (False).
        Defaults to False
    :type no_backfills: bool
    :param session: database session
    :type session: sqlalchemy.orm.session.Session
    :param execution_start_date: dag run that was executed from this date
    :type execution_start_date: datetime.datetime
    :param execution_end_date: dag run that was executed until this date
    :type execution_end_date: datetime.datetime
    """
    query = session.query(DagRun)

    # A single dag_id string is normalized to a one-element list.
    dag_ids = [dag_id] if isinstance(dag_id, str) else dag_id
    if dag_ids:
        query = query.filter(DagRun.dag_id.in_(dag_ids))
    if run_id:
        query = query.filter(DagRun.run_id == run_id)
    if execution_date:
        if isinstance(execution_date, list):
            query = query.filter(DagRun.execution_date.in_(execution_date))
        else:
            query = query.filter(DagRun.execution_date == execution_date)
    # Date-range filters: both bounds, or either bound on its own.
    if execution_start_date and execution_end_date:
        query = query.filter(DagRun.execution_date.between(
            execution_start_date, execution_end_date))
    elif execution_start_date:
        query = query.filter(DagRun.execution_date >= execution_start_date)
    elif execution_end_date:
        query = query.filter(DagRun.execution_date <= execution_end_date)
    if state:
        query = query.filter(DagRun.state == state)
    if external_trigger is not None:
        query = query.filter(DagRun.external_trigger == external_trigger)
    if run_type:
        query = query.filter(DagRun.run_type == run_type.value)
    if no_backfills:
        query = query.filter(DagRun.run_type != DagRunType.BACKFILL_JOB.value)

    return query.order_by(DagRun.execution_date).all()
def get_reply(session: Session, uuid: str) -> Reply:
    """Look up the single Reply record with this UUID (raises if not exactly one)."""
    reply_query = session.query(Reply).filter_by(uuid=uuid)
    return reply_query.one()
def populate_for_human_evaluation(session: Session,
                                  method_to_result: Dict[str, Path]) -> None:
    """
    Seed the database for a human-evaluation round from per-method CSV
    result files.

    Loads one GenerationResult per article for the 'Base' method (plus an
    empty 'Gold' placeholder), then one per article for every other method
    whose result file exists, then assigns each article a random ordering of
    the methods and finally marks a random sample of up to 100 evaluations
    as targets. Idempotent: returns immediately if data is already present.

    :param session: database session used for all inserts
    :param method_to_result: maps method name to the path of its result CSV;
        must contain a 'Base' entry
    """
    # Already populated -- do nothing.
    if session.query(HumanEvaluation).first() is not None:
        return
    if session.query(GenerationResult).first() is not None:
        return

    method_names = list(method_to_result.keys()) + ['Gold']
    # Tokens in column 4 are '|'-separated; the trailing element is dropped.
    token_delim = '|'

    # 'Base' results (and an empty 'Gold' row per article).
    with method_to_result['Base'].open(mode='r') as f:
        # N = number of data rows (total lines minus the header).
        N = sum(1 for _ in f) - 1
        f.seek(0)
        reader = csv.reader(f,
                            delimiter=',',
                            quotechar='"',
                            quoting=csv.QUOTE_ALL)
        next(reader)  # skip header
        for _ in tqdm(range(N)):
            fields = next(reader)
            article_id = fields[0]
            base_result = GenerationResult(
                article_id, 'Base',
                ''.join(number2kansuuzi(fields[4].split(token_delim)[:-1])))
            session.add(base_result)
            gold_result = GenerationResult(article_id, 'Gold', None)
            session.add(gold_result)
        session.commit()

    # Every other method: same parsing, but missing files are skipped.
    for method_name in [m for m in method_names if m not in ['Base', 'Gold']]:
        if not method_to_result[method_name].is_file():
            continue
        with method_to_result[method_name].open(mode='r') as f:
            N = sum(1 for _ in f) - 1
            f.seek(0)
            reader = csv.reader(f,
                                delimiter=',',
                                quotechar='"',
                                quoting=csv.QUOTE_ALL)
            next(reader)  # skip header
            for _ in tqdm(range(N)):
                fields = next(reader)
                article_id = fields[0]
                result = GenerationResult(
                    article_id, method_name, ''.join(
                        number2kansuuzi(fields[4].split(token_delim)[:-1])))
                session.add(result)
            session.commit()

    # One HumanEvaluation per article that has at least one non-empty result,
    # each with a randomly chosen permutation of the method ordering.
    groups = session \
        .query(GenerationResult.article_id) \
        .filter(GenerationResult.result != '') \
        .group_by(GenerationResult.article_id) \
        .all()
    orderings = list(permutations([m for m in method_names]))
    for group in groups:
        i = random.randint(0, math.factorial(len(method_names)) - 1)
        h = HumanEvaluation(group.article_id, ordering=orderings[i])
        session.add(h)
    session.commit()

    # Mark a random sample of 100 evaluations as targets (only when there
    # are at least 100; otherwise is_target is left untouched).
    results = session.query(HumanEvaluation).all()
    SAMPLE_SIZE = 100
    n_results = len(results)
    if n_results >= SAMPLE_SIZE:
        sample_indices = random.sample(range(n_results), SAMPLE_SIZE)
        for i, result in enumerate(results):
            result.is_target = i in sample_indices
        session.commit()
def get_local_messages(session: Session) -> List[Message]:
    """
    Fetch every submission (message) record from the local database.
    """
    message_query = session.query(Message)
    return message_query.all()