def index(c: config.Config, db: sa.orm.Session, name: str): for s in get_sources(c, name): if s is None: click.echo( message=f'Source name={name} not found in config.', err=True, ) exit(1) source_model, _ = crud.get_or_create(db, models.Source, name=s.name) reindex = models.Reindex(source=source_model) db.add(reindex) db.add(source_model) db.commit() buckets = [ crud.upsert_object(db, b, {}, reindex, source_model) for b in sources.get_module(s.type).scan_for_buckets(s) ] bucketnames = ' '.join([b.name for b in buckets]) click.echo( f'reindexing {len(buckets)} bucket(s) from source={s.name}: {bucketnames}' ) nodes = [] for bucket in buckets: nodes += crud.get_nodes(db, bucket) with click.progressbar(crud.index_source(db, s, reindex), length=len(nodes)) as bar: for b in bar: pass
def test_relationship(ssn: sa.orm.Session): """ Test getting historical relationship values through an InstanceHistoryProxy """ # Prepare ssn.add(User(id=1, name='John', age=18)) ssn.add(User(id=2, name='Jack', age=18)) ssn.add(Article(id=1, title='Python', author_id=1)) ssn.commit() # Users john = ssn.query(User).get(1) jack = ssn.query(User).get(2) # Article article: Article = ssn.query(Article).get(1) old_article: Article = InstanceHistoryProxy(article) # noqa assert article.author == john # load it assert old_article.author == john # works # Modify article.author = jack assert old_article.author == john # still works # Flush ssn.flush() assert old_article.author == john # still works
def _insert_default_assets(session: sa.orm.Session): assets = [ # currencies ("USD", "US Dolar", True, "currency"), ("EUR", "Euros", True, "currency"), ("JPY", "Japanese Yen", True, "currency"), ("CNY", "Chinese Yuan", True, "currency"), ("CHF", "Swiss Franc", True, "currency"), ("BRL", "Brazilian Real", True, "currency"), ("BTC", "Bitcoin", True, "currency"), ("ETH", "Ethereum", True, "currency"), ("XMR", "Monero", True, "currency"), ("ADA", "Cardano", True, "currency"), ("USDT", "Tether", True, "currency"), ] for asset_item in assets: if isinstance(asset_item, tuple): asset_item = { k: v for k, v in zip( ("name", "description", "is_active", "type"), asset_item ) } asset_db = models.Asset(**asset_item) session.add(asset_db) session.commit()
def sample_users_articles(ssn: sa.orm.Session): """ Create a database with sample users and articles """ # Make some users ssn.add_all([ User(name='John', login='******', passwd='1', age=18, articles=[ Article(title='Jam'), Article(title='Jeep'), ]), User(name='Mark', login='******', passwd='2', age=20, articles=[ Article(title='Map'), Article(title='Mop'), ]), User(name='Nick', login='******', passwd='3', age=25, articles=[ Article(title='Nap'), Article(title='Nil'), ]), User(name='Kate', login='******', passwd='4', age=30, articles=[]), User(name='Cary', login='******', passwd='5', age=35, articles=[]), ]) ssn.commit()
def process_triggers(session: sqlalchemy.orm.Session, characters: List[Character], area: models.Area): """ Find triggers for an area, print them, and potentially resolve them. """ triggers = session.query(models.Trigger).filter( (~models.Trigger.resolved) & (models.Trigger.area_id == area.id) & (models.Trigger.party_lvl_min <= min(c.level for c in characters)) & (models.Trigger.party_size >= len(characters)) ).all() events = [t.event for t in triggers] for event in events: print('Event: {}'.format(event.name)) print(event.description) while 1: try: resolved_bool = prompt_bool('Resolved ?', True) except MalformedBoolException: continue break if resolved_bool: for t in event.triggers: t.resolved = True # TODO(jalex): Does this commit the triggers attached to the events? session.add_all(events) session.commit()
def harvest_users_from_tweets(session: sqlalchemy.orm.Session, FLUSH_LIMIT=10, startTweet=None): """This iterates through tweets from the database and harvests user data from them""" users = 0 lastTweetId = None tweetIter = tweets_with_other_data_generator(session) try: while True: tweet = next(tweetIter) user = update_or_create_user_from_tweet(tweet, session) users += 1 lastTweetId = tweet.tweetID if users % FLUSH_LIMIT == 0: print('flushing at %s users' % users) session.commit() except StopIteration: print("%s users created or updated" % users) session.commit() finally: print("Last processed tweet %s" % lastTweetId) # session.commit() session.close()
def process_task(config: ConfigHolder, task: Task, session: sqlalchemy.orm.Session): try: task.status = Task.STATUS_OPTIONS.PROCESSING session.commit() for subtask in tqdm.tqdm(task.children, desc=f"Task {task.id}: Processing subtasks"): if subtask.status not in [ Task.STATUS_OPTIONS.ERROR, Task.STATUS_OPTIONS.INTERRUPTED, Task.STATUS_OPTIONS.FINISHED ]: subconfig = ConfigHolder(subtask) if config.local: subconfig.local = True process_task(subconfig, subtask, session) to_process = [job for job in task.jobs if job.solution is None] process_jobs(to_process, config, session) task.status = Task.STATUS_OPTIONS.FINISHED except Exception as e: print(e) to_process = [job for job in task.jobs if job.solution is None] if str(e).lower( ) != "Backend does not support on_message callback".lower( ) and to_process: task.status = Task.STATUS_OPTIONS.ERROR task.error_message = str(e) if is_debug_env(): raise e else: task.status = Task.STATUS_OPTIONS.FINISHED finally: session.commit()
def bulk_insert_chunk(data_df: dt.Frame, table: sqlalchemy.Table, session: sqlalchemy.orm.Session, chunksize: int = 1000000, start_index: int = 0) -> None: """ Executes a bulk insert statement from `data_df` to `table` with approximately `chunksize` rows per `session.bulk_insert_mappings` call, starting from `start_index` of `data_df`. :param data_df: The datatable to write from. :param table: The SQLAlchemy table to write to. :param session: The SQLAlchemy session object to write to. :param chunksize: How many rows, approximately, to insert per iteration. :param start_index: The table row index to start writing from. If using id column values from the databse, this should be id - 1 to match the 0 based indexing of Python vs the 1 based indexing of SQL tables. :return: None, write to database. """ filter_df = data_df[start_index:, :] nchunks = math.ceil((filter_df.shape[0] + 1) / chunksize) chunk_array = np.array_split(np.arange(filter_df.shape[0]), nchunks) index_tuple_list = [(int(np.min(x)), int(np.max(x))) for x in chunk_array] for idx in tqdm(index_tuple_list, colour='magenta'): df = filter_df[idx[0]:(idx[1] + 1), :].to_pandas() row_dict = create_records(df) session.bulk_insert_mappings(table, row_dict, render_nulls=True) session.commit()
async def upload_avatar(file: UploadFile = File(...), current_user: User = Depends(get_current_user), db: sa.orm.Session = Depends(get_db)): """ Uploads avatar and saves them :param file: File object :param current_user: Current authentificated user :param db: Session instance :return: Upload info """ path = os.path.join('static/user_avatars', file.filename) with open(path, 'wb+') as buffer: shutil.copyfileobj(file.file, buffer) if current_user.avatar: if os.path.exists(current_user.avatar): os.remove(current_user.avatar) try: current_user.avatar = path db.add(current_user) db.commit() except Exception as e: print(e) db.rollback() return {'filename': file.filename, 'path': path}
def _insert_default_coa(session: sa.orm.Session): # noqa: C901 # the testing chart of accounts coa = """ root # 1 Assets # 2 Cash # 3 Receivables # 4 Inventory # 5 Liabilities # 6 Payables # 7 Shares Issued # 8 Retained Earnings # 9 Income # 10 Trade # 11 Interest # 12 Expenses # 13 Fees # 14 Broker # 15 Administration # 16 Tax # 17 Other # 18 """ coa = [line for line in coa.splitlines() if line.strip() != ""] coa = [line.split("#")[0].rstrip() for line in coa] def _get_level(coa, line): line_str = coa[line] level = len(line_str) - len(line_str.lstrip()) - 4 level = level // 4 return level def _insert_next(coa, line, parent_id, curr_level, last_id): while line < len(coa) - 1: line += 1 now_level = _get_level(coa, line) name = coa[line].strip() if now_level == curr_level: # sibling account last_id += 1 acc = models.Account( id=last_id, name=name, parent_id=parent_id ) session.add(acc) elif now_level == curr_level + 1: # child line -= 1 line, last_id = _insert_next( coa=coa, line=line, parent_id=last_id, curr_level=now_level, last_id=last_id, ) elif now_level < curr_level: # go back one level return line - 1, last_id return line, last_id root = models.Account(id=1, name=coa[0].strip(), parent_id=None) session.add(root) _insert_next(coa=coa, line=0, parent_id=1, curr_level=1, last_id=1) session.commit()
def process_jobs(jobs: List[TaskJobs], config: ConfigHolder, session: sqlalchemy.orm.Session): if not jobs: return processbar = tqdm.tqdm(total=len(jobs), desc=f"Task {jobs[0].task_id}: Process jobs") on_message = OnMessageCB(progressbar=processbar) # ToDo: To speed up solving time, maybe use bulksolve slice_size, slice_amount = _get_slicing(jobs, config.slice_size) slices = [(i * slice_size, (i + 1) * slice_size) for i in range(slice_amount - 1)] if slice_amount > 0: slices.append(tuple([(slice_amount - 1) * slice_size, len(jobs)])) solver_args = config.solver_args if "time_limit" in solver_args: time_limit = solver_args["time_limit"] else: time_limit = 900 if hasattr(config, "local") and config.local: for job in jobs: sol = solve( job.graph, config.solver, solver_config=config.solver_args, solve_config={ "start_solution": (None if job.prev_job is None else job.prev_job.solution.order), "time_limit": (time_limit if job.prev_job is None else time_limit - job.prev_job.solution.runtime) }) job.solution = sol processbar.update() session.commit() else: for start, end in slices: results = group( solve. s(job.graph, config.solver, solver_config=config.solver_args, solve_config={ "start_solution": (None if job.prev_job is None else job. prev_job.solution.order), "time_limit": ( time_limit if job.prev_job is None else time_limit - job.prev_job.solution.runtime) }) for job in jobs[start:end])().get(on_message=on_message) for job, result in zip(jobs[start:end], results): result.graph = job.graph if job.prev_job is not None: result.runtime = float(result.runtime) + float( job.prev_job.solution.runtime) job.solution = result if session: session.commit()
def init_characters(session: sqlalchemy.orm.Session) -> CharacterGroups: """ Initialize the party's characters. Returns: Character groups. """ characters = [] while 1: name = input('Character name: ') if not name: break chars = session.query( models.Character).filter(models.Character.name == name).all() if not chars: try: new_char = prompt_bool('New character?', False) except MalformedBoolException: continue if not new_char: continue player_name = input('Player name: ') char = models.Character(name=name, player_name=player_name) else: try: idx = prompt_choices( chars, lambda i, char: ' {} - {} ({})'.format( i, char.name, char.player_name)) except MalformedInputException: continue char = chars[idx] default_lvl = char.level or 1 try: char.level = int( input('Level: [{}] '.format(default_lvl)) or default_lvl) except ValueError: print('Supply integer level.') continue characters.append(char) character_groups = init_locations(session, characters) session.add_all(x for sl in character_groups for x in sl) session.commit() return character_groups
def transaction_cm(session: sa.orm.Session) -> tp.ContextManager[None]: """Provide a transactional scope around a series of operations.""" try: yield session.commit() except Exception: session.rollback() raise finally: session.close()
def test_property(ssn: sa.orm.Session): """ Test getting historical @property values through an InstanceHistoryProxy """ # Prepare ssn.add(User(id=1, name='John', age=18)) ssn.commit() # Load user: User = ssn.query(User).get(1) old_user: User = InstanceHistoryProxy(user) # noqa # @property access assert user.age_in_100_years == 118 assert old_user.age_in_100_years == 118 # Modify user.age = 20 assert old_user.age_in_100_years == 118 # still good
def create_user( user: CreateUserSchema, db: sa.orm.Session = get_db, ) -> UserSchema: """ Create new user. """ user = user.dict(exclude_unset=True) user["password"] = hash_password(user["password"]) user = User(**user) db.add(user) try: db.commit() except sa.exc.IntegrityError: db.rollback() raise HTTPException( status_code=400, detail="A user with this email already exists.", ) return user
def update_user( updated_user: UpdateUserSchema, user_id: int = Path(...), db: sa.orm.Session = get_db, ) -> UserSchema: """ Update a user. """ user = db.get(User, user_id) if user is None: raise HTTPException(status_code=404, detail="User not found") updated_user = updated_user.dict(exclude_unset=True) try: updated_user["password"] = hash_password(updated_user["password"]) except KeyError: pass user.update(updated_user) db.commit() return user
def create_product( product: CreateProductSchema, db: sa.orm.Session = get_db, ) -> ProductSchema: """ Create new Product. """ product = product.dict(exclude_unset=True) if "image" in product: product["image"] = b64decode(product["image"]) product = Product(**product) db.add(product) try: db.commit() except sa.exc.IntegrityError: db.rollback() raise HTTPException( status_code=400, detail="A product with that name already exists.", ) return product
def test_columns(ssn: sa.orm.Session): """ Simple test of InstanceHistoryProxy with columns """ # Prepare ssn.add(User(id=1, name='John', age=18)) ssn.commit() # Check initial state user: User = ssn.query(User).get(1) old_user: User = InstanceHistoryProxy(user) # noqa def old_user_is_correct(): assert old_user.id == 1 assert old_user.name == 'John' assert old_user.age == 18 # Modify user.id = 1000 user.name = 'CHANGED' user.age = 1800 old_user_is_correct() # still good # Flush ssn.flush() old_user_is_correct() # still good
def process_task(config: ConfigHolder, task: Task, session: sqlalchemy.orm.Session): try: task.status = Task.STATUS_OPTIONS.PROCESSING session.commit() for subtask in tqdm.tqdm(task.children, desc=f"Task {task.id}: Processing subtasks"): if subtask.status not in [Task.STATUS_OPTIONS.ERROR, Task.STATUS_OPTIONS.INTERRUPTED, Task.STATUS_OPTIONS.PROCESSING, Task.STATUS_OPTIONS.FINISHED]: process_task(ConfigHolder(subtask), subtask, session) to_process = [job for job in task.jobs if job.solution is None] process_jobs(to_process, config, session) task.status = Task.STATUS_OPTIONS.FINISHED session.commit() except Exception as e: print(e) task.status = Task.STATUS_OPTIONS.ERROR task.error_message = str(e) session.commit()
def insert_package_graph(session: sqlalchemy.orm.Session, task_data: Dict) -> None: link_ids = [] for task_dep in task_data.get("dependencies", []): add_new_package_version(session, task_dep) session.commit() parent_package_id = get_package_version_id_query(session, task_dep).first() for dep in task_dep.get("dependencies", []): # is fully qualified semver for npm (or file: or github: url), semver for yarn name, version = dep.rsplit("@", 1) child_package_id = get_package_version_id_query( session, dict(name=name, version=version)).first() link_id = get_package_version_link_id_query( session, (parent_package_id, child_package_id)).one_or_none() if not link_id: session.add( PackageLink( child_package_id=child_package_id, parent_package_id=parent_package_id, )) session.commit() link_id = get_package_version_link_id_query( session, (parent_package_id, child_package_id)).first() link_ids.append(link_id) session.add( PackageGraph( root_package_version_id=get_package_version_id_query( session, task_data["root"]).first() if task_data["root"] else None, link_ids=link_ids, package_manager="yarn" if "yarn" in task_data["command"] else "npm", package_manager_version=None, )) session.commit()
def clear(session: sa.orm.Session): "Just delete entries for all tables" for table in models.Base.metadata.tables.values(): session.query(table).delete() session.commit()
def try_to_commit(session: sa.orm.Session): try: session.commit() except sa.exc.IntegrityError: session.rollback() raise HTTPException(status_code=422, detail=tb.format_exc(limit=0))
def insert_package_audit(session: sqlalchemy.orm.Session, task_data: Dict) -> None: is_yarn_cmd = bool("yarn" in task_data["command"]) # NB: yarn has .advisory and .resolution # the same advisory JSON (from the npm DB) is # at .advisories{k, v} for npm and .advisories[].advisory for yarn advisories = ((item.get("advisory", None) for item in task_data.get("advisories", [])) if is_yarn_cmd else task_data.get("advisories", dict()).values()) non_null_advisories = (adv for adv in advisories if adv) for advisory in non_null_advisories: advisory_fields = extract_nested_fields( advisory, { "package_name": ["module_name"], "npm_advisory_id": ["id"], "vulnerable_versions": ["vulnerable_versions"], "patched_versions": ["patched_versions"], "created": ["created"], "updated": ["updated"], "url": ["url"], "severity": ["severity"], "cves": ["cves"], "cwe": ["cwe"], "exploitability": ["metadata", "exploitability"], "title": ["title"], }, ) advisory_fields["cwe"] = int(advisory_fields["cwe"].lower().replace( "cwe-", "")) advisory_fields["language"] = "node" advisory_fields["vulnerable_package_version_ids"] = [] get_node_advisory_id_query( session, advisory_fields).one_or_none() or session.add( Advisory(**advisory_fields)) session.commit() # TODO: update other advisory fields too impacted_versions = set( finding.get("version", None) for finding in advisory.get("findings", []) if finding.get("version", None)) db_advisory = (session.query( Advisory.id, Advisory.vulnerable_package_version_ids).filter_by( language="node", url=advisory["url"]).first()) impacted_version_package_ids = list( vid for result in session.query(PackageVersion.id).filter( PackageVersion.name == advisory_fields["package_name"], PackageVersion.version.in_(impacted_versions), ).all() for vid in result) if len(impacted_versions) != len(impacted_version_package_ids): log.warning( f"missing package versions for {advisory_fields['package_name']!r}" f" in the db or misparsed audit output version:" f" {impacted_versions} {impacted_version_package_ids}") if db_advisory.vulnerable_package_version_ids is None: session.query(Advisory.id).filter_by(id=db_advisory.id).update( dict(vulnerable_package_version_ids=list())) # TODO: lock the row? vpvids = set( list( session.query(Advisory).filter_by( id=db_advisory.id).first().vulnerable_package_version_ids)) vpvids.update(set(impacted_version_package_ids)) session.query(Advisory.id).filter_by(id=db_advisory.id).update( dict(vulnerable_package_version_ids=sorted(vpvids))) session.commit()
def insert_npmsio_data(session: sqlalchemy.orm.Session, source: Generator[Dict[str, Any], None, None]) -> None: for line in source: fields = extract_nested_fields( line, { "package_name": ["collected", "metadata", "name"], "package_version": ["collected", "metadata", "version"], "analyzed_at": ["analyzedAt" ], # e.g. "2019-11-27T19:31:42.541Z" # overall score from .score.final on the interval [0, 1] "score": ["score", "final"], # score components on the interval [0, 1] "quality": ["score", "detail", "quality"], "popularity": ["score", "detail", "popularity"], "maintenance": ["score", "detail", "maintenance"], # score subcomponent/detail fields from .evaluation.<component>.<subcomponent> # generally frequencies and subscores are decimals between [0, 1] # or counts of downloads, stars, etc. # acceleration is signed (+/-) "branding": ["evaluation", "quality", "branding"], "carefulness": ["evaluation", "quality", "carefulness"], "health": ["evaluation", "quality", "health"], "tests": ["evaluation", "quality", "tests"], "community_interest": ["evaluation", "popularity", "communityInterest"], "dependents_count": ["evaluation", "popularity", "dependentsCount"], "downloads_acceleration": [ "evaluation", "popularity", "downloadsAcceleration", ], "downloads_count": ["evaluation", "popularity", "downloadsCount"], "commits_frequency": ["evaluation", "maintenance", "commitsFrequency"], "issues_distribution": [ "evaluation", "maintenance", "issuesDistribution", ], "open_issues": ["evaluation", "maintenance", "openIssues"], "releases_frequency": [ "evaluation", "maintenance", "releasesFrequency", ], }, ) fields[ "source_url"] = f"https://api.npms.io/v2/package/{fields['package_name']}" # only insert new rows if (session.query(NPMSIOScore.id).filter_by( package_name=fields["package_name"], package_version=fields["package_version"], analyzed_at=fields["analyzed_at"], ).one_or_none()): log.debug( f"skipping inserting npms.io score for {fields['package_name']}@{fields['package_version']}" f" analyzed at {fields['analyzed_at']}") else: session.add(NPMSIOScore(**fields)) session.commit() log.info( f"added npms.io score for {fields['package_name']}@{fields['package_version']}" f" analyzed at {fields['analyzed_at']}")
def insert_npm_registry_data( session: sqlalchemy.orm.Session, source: Generator[Dict[str, Any], None, None]) -> None: for line in source: # save version specific data for version, version_data in line["versions"].items(): fields = extract_nested_fields( version_data, { "package_name": ["name"], "package_version": ["version"], "shasum": ["dist", "shasum"], "tarball": ["dist", "tarball"], "git_head": ["gitHead"], "repository_type": ["repository", "type"], "repository_url": ["repository", "url"], "description": ["description"], "url": ["url"], "license_type": ["license"], "keywords": ["keywords"], "has_shrinkwrap": ["_hasShrinkwrap"], "bugs_url": ["bugs", "url"], "bugs_email": ["bugs", "email"], "author_name": ["author", "name"], "author_email": ["author", "email"], "author_url": ["author", "url"], "maintainers": ["maintainers"], "contributors": ["contributors"], "publisher_name": ["_npmUser", "name"], "publisher_email": ["_npmUser", "email"], "publisher_node_version": ["_nodeVersion"], "publisher_npm_version": ["_npmVersion"], }, ) # license can we a string e.g. 'MIT' # or dict e.g. {'type': 'MIT', 'url': 'https://github.com/jonschlinkert/micromatch/blob/master/LICENSE'} fields["license_url"] = None if isinstance(fields["license_type"], dict): fields["license_url"] = fields["license_type"].get("url", None) fields["license_type"] = fields["license_type"].get( "type", None) # looking at you [email protected].{3,4} with: # [{"name": "StrongLoop", "url": "http://strongloop.com/license/"}, "MIT"], if not ((isinstance(fields["license_type"], str) or fields["license_type"] is None) and (isinstance(fields["license_url"], str) or fields["license_url"] is None)): log.warning( f"skipping weird license format {fields['license_type']}") fields["license_url"] = None fields["license_type"] = None # published_at .time[<version>] e.g. '2014-05-23T21:21:04.170Z' (not from # the version info object) # where time: an object mapping versions to the time published, along with created and modified timestamps fields["published_at"] = get_in(line, ["time", version]) fields["package_modified_at"] = get_in(line, ["time", "modified"]) fields[ "source_url"] = f"https://registry.npmjs.org/{fields['package_name']}" if (session.query(NPMRegistryEntry.id).filter_by( package_name=fields["package_name"], package_version=fields["package_version"], shasum=fields["shasum"], tarball=fields["tarball"], ).one_or_none()): log.debug( f"skipping inserting npm registry entry for {fields['package_name']}@{fields['package_version']}" f" from {fields['tarball']} with sha {fields['shasum']}") else: session.add(NPMRegistryEntry(**fields)) session.commit() log.info( f"added npm registry entry for {fields['package_name']}@{fields['package_version']}" f" from {fields['tarball']} with sha {fields['shasum']}")
def test_does_not_lose_history(ssn: sa.orm.Session): """ Extensive test of InstanceHistoryProxy with query counters and lazy loads """ assert ssn.autoflush == False, 'this test relies on Session.autoflush=False' engine = ssn.get_bind() # Prepare ssn.add(User(id=1, name='John', age=18)) ssn.add(Article(id=1, title='Python', author_id=1)) ssn.commit() # === Test 1: ModelHistoryProxy does not lose history when flushing a session ssn.expunge_all( ) # got to reset; otherwise, the session might reuse loaded objects user = ssn.query(User).get(1) with ExpectedQueryCounter(engine, 0, 'Expected no queries here'): old_user_hist = InstanceHistoryProxy(user) # issues no queries # Modify user.name = 'CHANGED' # History works assert old_user_hist.name == 'John' # Flush ssn.flush() # History is NOT broken! assert old_user_hist.name == 'John' # Change another column after flush; history is still NOT broken! user.age = 1800 assert old_user_hist.age == 18 # correct # Undo ssn.rollback() # === Test 1: ModelHistoryProxy does not lose history when lazyloading a column ssn.expunge_all( ) # got to reset; otherwise, the session might reuse loaded objects user = ssn.query(User).options(load_only('name')).get(1) with ExpectedQueryCounter(engine, 0, 'Expected no queries here'): old_user_hist = InstanceHistoryProxy(user) # issues no queries user.name = 'CHANGED' assert old_user_hist.name == 'John' # Load a column with ExpectedQueryCounter(engine, 1, 'Expected 1 lazyload query'): user.age # get an unloaded column # History is NOT broken! assert old_user_hist.name == 'John' # === Test 2: ModelHistoryProxy does not lose history when lazyloading a one-to-many relationship ssn.expunge_all( ) # got to reset; otherwise, the session might reuse loaded objects user = ssn.query(User).get(1) with ExpectedQueryCounter(engine, 0, 'Expected no queries here'): old_user_hist = InstanceHistoryProxy(user) user.name = 'CHANGED' assert old_user_hist.name == 'John' # History works # Load a relationship with ExpectedQueryCounter(engine, 1, 'Expected 1 lazyload query'): list(user.articles) # History is NOT broken! assert old_user_hist.name == 'John' # === Test 3: ModelHistoryProxy does not lose history when lazyloading a one-to-one relationship ssn.expunge_all( ) # got to reset; otherwise, the session might reuse loaded objects article = ssn.query(Article).get(1) with ExpectedQueryCounter(engine, 0, 'Expected no queries here'): old_article_hist = InstanceHistoryProxy(article) article.title = 'CHANGED' assert old_article_hist.title == 'Python' # works # Load a relationship with ExpectedQueryCounter(engine, 1, 'Expected 1 lazyload query'): article.author # History is NOT broken! assert old_article_hist.title == 'Python' # works
def process_jobs(jobs: List[TaskJobs], config: ConfigHolder, session: sqlalchemy.orm.Session): if not jobs: return processbar = tqdm.tqdm(total=len(jobs), desc=f"Task {jobs[0].task_id}: Process jobs") def _on_message(body): if body["status"] in ['SUCCESS', 'FAILURE']: if body["status"] == 'FAILURE': print("Found an error:", body) try: processbar.update() except AttributeError: pass # ToDo: To speed up solving time, maybe use bulksolve slice_size, slice_amount = _get_slicing(jobs, config.slice_size) slices = [(i*slice_size, (i+1)*slice_size) for i in range(slice_amount-1)] if slice_amount > 0: slices.append(tuple([(slice_amount-1)*slice_size, len(jobs)])) solver_args = config.solver_args if "time_limit" in solver_args: time_limit = solver_args["time_limit"] else: time_limit = 900 if config.local: for job in jobs: sol = solve( job.graph, config.solver, solver_config=config.solver_args, solve_config={ "start_solution":(None if job.prev_job is None else job.prev_job.solution.order), "time_limit":(time_limit if job.prev_job is None else time_limit - job.prev_job.solution.runtime) } ) job.solution = sol processbar.update() session.commit() else: for start, end in slices: results = group(solve.s( job.graph, config.solver, solver_config=config.solver_args, solve_config={ "start_solution":(None if job.prev_job is None else job.prev_job.solution.order), "time_limit":(time_limit if job.prev_job is None else time_limit - job.prev_job.solution.runtime) } ) for job in jobs[start:end])().get(on_message=_on_message) # This is just for local testing #results = [solve( # job.graph, # config.solver, # solver_config=config.solver_args, # solve_config={"start_solution":(None if job.prev_job is None else job.prev_job.solution.order)} # ) # for job in jobs[start:end]] for job, result in zip(jobs[start:end], results): result.graph = job.graph if job.prev_job is not None: result.runtime = result.runtime + job.prev_job.runtime job.solution = result if session: session.commit()