def get_all_distributions(limit=100):
    """
    Obtain distributions, partitioned by channels with up to ``limit``
    results for each channel
    """
    from splice.environment import Environment

    env = Environment.instance()

    dist_cte = (
        env.db.session
        .query(
            Distribution.channel_id,
            Distribution.url,
            Distribution.created_at,
            func.row_number().over(
                partition_by=Distribution.channel_id,
                order_by=Distribution.created_at.desc())
            .label('row_num')
        )
    ).cte()

    stmt = (
        env.db.session
        .query(
            dist_cte.c.channel_id,
            dist_cte.c.url,
            dist_cte.c.created_at)
        .filter(dist_cte.c.row_num <= limit)
        .order_by(dist_cte.c.created_at.desc())
    )

    rows = stmt.all()

    channels = {}
    for row in rows:
        c_dists = channels.setdefault(row.channel_id, [])
        c_dists.append({'url': row.url, 'created_at': row.created_at})

    return channels
def get_nth_unfinished_game(session, n):
    row_number_column = (
        func.row_number().over(order_by=Game.date.desc()).label('row_number'))
    query = session.query(Game)
    # query = query.filter(Foo.time_key <= time_key)
    query = query.add_column(row_number_column)
    query = query.from_self().filter(row_number_column == n)
    return query
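# A minimal usage sketch for get_nth_unfinished_game() above, assuming the
# same session and Game model (the value 3 is hypothetical). The query yields
# (Game, row_number) tuples because the counter column is selected alongside
# the entity; row_number() is 1-based, so n=1 is the most recent game.
game, rownum = get_nth_unfinished_game(session, 3).one()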
def func_window():
    """
    Not supported by MySQL (window functions were only added in MySQL 8.0).
    """
    s = select([users.c.id, func.row_number().over(order_by=users.c.name)])
    print(str(s))
    conn = get_engine().connect()
    for row in conn.execute(s).fetchall():
        print(row)
def get_by_reactions(cls):
    '''Gets all posts, ordered by their reactions.'''
    posts = cls.get_reactions().subquery()
    return db.session.query(
        Post.id,
        func.row_number().over(
            order_by=posts.c.reactions).label('sequence')).join(
                cls, posts.c.id == Post.id)
def get_upcoming_distributions(limit=100, leniency_minutes=15,
                               include_past=False):
    """
    Obtain distributions, partitioned by channels with up to ``limit``
    results for each channel

    :leniency_minutes: leniency, in minutes before the present, when
        looking for upcoming distributions
    :include_past: also return all past distributions
    """
    from splice.environment import Environment

    env = Environment.instance()

    # getting around PEP8 E712 warning. This is necessary for SQLAlchemy
    false_value = False

    dist_cte = (
        env.db.session
        .query(
            Distribution.id,
            Distribution.channel_id,
            Distribution.url,
            Distribution.created_at,
            Distribution.scheduled_start_date,
            func.row_number().over(
                partition_by=Distribution.channel_id,
                order_by=Distribution.scheduled_start_date.asc())
            .label('row_num')
        )
        .filter(Distribution.deployed == false_value))

    if not include_past:
        min_dt = datetime.utcnow() - timedelta(minutes=leniency_minutes)
        dist_cte = (
            dist_cte
            .filter(Distribution.scheduled_start_date >= min_dt))

    dist_cte = dist_cte.cte()

    stmt = (
        env.db.session
        .query(
            dist_cte.c.id,
            dist_cte.c.channel_id,
            dist_cte.c.url,
            dist_cte.c.created_at,
            dist_cte.c.scheduled_start_date)
        .filter(dist_cte.c.row_num <= limit)
        .order_by(dist_cte.c.scheduled_start_date.asc())
    )

    rows = stmt.all()

    channels = {}
    for row in rows:
        c_dists = channels.setdefault(row.channel_id, [])
        c_dists.append({'id': row.id,
                        'url': row.url,
                        'created_at': row.created_at,
                        'scheduled_at': row.scheduled_start_date})

    return channels
def get_corptx_ids(tickers, release_window, release_count,
                   limit, tick_limit, sd, ed):
    """
    Gets corp_tx ids for the given tickers with at least release_count
    earnings releases within release_window.

    params
    tickers (list[str]): corp_tx company symbols
    release_window (int): relevant earnings_release_date window
    release_count (int): min number of earnings releases during release_window
    limit (int): overall samples return limit
    tick_limit (int): per-ticker samples limit
    sd, ed: start and end transaction dates

    returns
    ids (1D np arr): matching ids
    """
    # subquery corp_txs for release_count financial releases during window
    days_from_release = CorpTx.trans_dt - Financial.earnings_release_date
    fin_count = func.count(Financial.id).label('fin_count')
    window_stmt = db.query(CorpTx.id) \
        .distinct(CorpTx.cusip_id, CorpTx.trans_dt) \
        .join(Financial, Financial.ticker == CorpTx.company_symbol) \
        .filter(CorpTx.company_symbol.in_(tickers),
                CorpTx.close_yld > 0,
                CorpTx.close_yld <= 20.0,
                days_from_release <= release_window,
                days_from_release > 0) \
        .group_by(CorpTx.id) \
        .having(fin_count == release_count) \
        .subquery('window_sq')

    # partition by row number
    rn = func.row_number() \
        .over(partition_by=CorpTx.company_symbol,
              order_by=CorpTx.id).label('rn')
    sq = db.query(CorpTx.id, rn) \
        .join(window_stmt, CorpTx.id == window_stmt.c.id) \
        .join(EquityPx,
              and_(CorpTx.company_symbol == EquityPx.ticker,
                   CorpTx.trans_dt == EquityPx.date)) \
        .join(InterestRate, CorpTx.trans_dt == InterestRate.date) \
        .join(Financial, Financial.ticker == CorpTx.company_symbol) \
        .filter(days_from_release <= release_window,
                days_from_release > 0,
                CorpTx.trans_dt <= ed,
                CorpTx.trans_dt > sd).subquery('sq')

    s = db.query(CorpTx.id) \
        .distinct(CorpTx.id, CorpTx.trans_dt) \
        .join(sq, sq.c.id == CorpTx.id) \
        .filter(sq.c.rn <= tick_limit * release_count) \
        .order_by(CorpTx.trans_dt.asc()) \
        .limit(limit)

    ids = db.execute(s).fetchall()
    return np.unique(np.array(ids).flatten())
def get_commodity_max_prices(self, commodity_id, days=30):
    session = self.get_session()
    # Reference SQL; the second, window-function form is what the ORM query
    # below implements. NOTE: the ``days`` parameter is currently unused.
    """
    select * from commoditiesMaxPrice a,
    (
        select timestamp, max_price, max(sell_demand) as max_demand from
        (
            select b.timestamp, max_price, sell_demand
            from commoditiesMaxPrice a,
            (
                select timestamp, max(sell_price) as max_price
                from commoditiesMaxPrice
                group by timestamp
            ) b
            where a.timestamp = b.timestamp and max_price = sell_price
        )
        group by timestamp, max_price
    ) b, markets m
    where sell_price = max_price
        and sell_demand = max_demand
        and a.timestamp = b.timestamp
        and a.market_id = m.id
    order by timestamp desc limit 30;

    select * from
    (
        select *, row_number() over (
            partition by timestamp
            order by sell_price desc, sell_demand desc) as num
        from commoditiesMaxPrice
        where commodity_id = 144
    ), markets
    where num = 1 and id = market_id
    order by timestamp desc limit 30;
    """
    ordered_by_loop = session.query(
        CommodityMaxPrice,
        func.row_number().over(
            partition_by=CommodityMaxPrice.timestamp,
            order_by=(CommodityMaxPrice.sell_price.desc(),
                      CommodityMaxPrice.sell_demand.desc()
                      )).label("num")).filter_by(
                          commodity_id=commodity_id).subquery()

    res = session.query(CommodityMaxPrice).select_entity_from(
        ordered_by_loop).filter(ordered_by_loop.c.num <= 1).all()

    return [{
        "price": p.sell_price,
        "demand": p.sell_demand,
        "date": p.timestamp,
        "commodity": {
            "id": p.commodity_id,
            "name": self.mapping.to_name_safe(p.commodity_id),
            "inaraLink": INARA_URL % p.commodity_id
        },
        "market": {
            "id": p.market.id,
            "system": p.market.system,
            "station": p.market.station,
        },
    } for p in res]
def posts_feed(user):
    latest = request.args.get('latest')
    top = request.args.get('top')
    cursor = request.args.get('cursor')
    items_per_page = current_app.config['ITEMS_PER_PAGE']
    nextCursor = None
    query = ''
    try:
        followed_posts = user.get_followed_posts().subquery()
        posts_reactions = Post.get_reactions().subquery()
        top_followed_posts = db.session.query(
            followed_posts,
            func.row_number().over(order_by=posts_reactions.c.reactions).
            label('sequence')).outerjoin(
                posts_reactions,
                followed_posts.c.posts_id == posts_reactions.c.id).subquery()
        top_posts = db.session.query(
            Post, top_followed_posts.c.sequence).join(
                Post, top_followed_posts.c.posts_id == Post.id).order_by(
                    top_followed_posts.c.sequence.desc())
        latest_posts = db.session.query(
            Post, followed_posts.c.posts_id).join(
                Post, Post.id == followed_posts.c.posts_id).order_by(
                    Post.created_on.desc())
    except Exception as e:
        db.session.rollback()
        print(e)
        return server_error('An unexpected error occurred, please try again.')

    if cursor == '0' and latest:
        query = latest_posts.limit(items_per_page + 1).all()
    elif cursor == '0' and top:
        query = top_posts.limit(items_per_page + 1).all()
    else:
        if latest:
            cursor = urlsafe_base64(cursor, from_base64=True)
            query = latest_posts.filter(
                Post.created_on < cursor).limit(items_per_page + 1).all()
        else:
            cursor = urlsafe_base64(cursor, from_base64=True)
            query = top_posts.filter(
                top_followed_posts.c.sequence < cursor).limit(
                    items_per_page + 1).all()

    if len(query) > items_per_page:
        nextCursor = urlsafe_base64(
            query[items_per_page - 1][0].created_on.isoformat()) \
            if latest else urlsafe_base64(
                str(query[items_per_page - 1][1]))

    return {
        'data': [post[0].to_dict(user) for post in query[:items_per_page]],
        'nextCursor': nextCursor
    }
def _query_editables(self, has_files, extensions):
    inner = (
        db.session.query(EditingRevision.id, EditingRevision.editable_id)
        # only get revisions belonging to the correct event + editable type
        .filter(
            EditingRevision.editable.has(
                and_(
                    Editable.contribution.has(
                        and_(
                            ~Contribution.is_deleted,
                            Contribution.event_id == self.event.id,
                        )
                    ),
                    Editable.type == self.editable_type,
                )
            )
        )
        # allow filtering by "is latest revision" later
        .add_columns(
            over(
                func.row_number(),
                partition_by=EditingRevision.editable_id,
                order_by=EditingRevision.created_dt.desc(),
            ).label('rownum')
        )
    ).subquery()

    revision_query = (
        db.session.query(EditingRevision.editable_id)
        .select_entity_from(inner)
        .filter(inner.c.rownum == 1)  # only latest revision
    )

    # filter by presence (or lack) of file types
    for file_type_id, present in has_files.items():
        crit = self._make_revision_file_type_filter(inner, file_type_id)
        if not present:
            crit = ~crit
        revision_query = revision_query.filter(crit)

    # filter by having files with certain extensions
    for file_type_id, exts in extensions.items():
        ext_filter = EditingRevisionFile.file.has(File.extension.in_(exts))
        revision_query = revision_query.filter(
            self._make_revision_file_type_filter(inner, file_type_id,
                                                 ext_filter)
        )

    revision_query = revision_query.subquery()
    return (Editable.query
            .join(revision_query,
                  revision_query.c.editable_id == Editable.id)
            .options(joinedload('contribution'))
            .all())
def get_signed_fields(session, doc_id: UUID):
    '''Get all signature fields that have already been filled'''
    rownum = func.row_number().over(
        partition_by=FieldUsage.field_id,
        order_by=FieldUsage.timestamp.desc()).label("row_number")
    subquery = (session.query(FieldUsage).join(
        File, isouter=True).join(Field).filter(
            Field.document_id == doc_id.bytes).filter(
                FieldUsage.fieldusage_type ==
                config.FIELD_USAGE_TYPE['filled']).filter(
                    Field.field_type ==
                    config.FIELD_TYPE['signature']).add_column(
                        rownum).with_entities(
                            File.filename, Field.field_name,
                            rownum).subquery())
    return (session.query(subquery).filter(
        subquery.c.row_number == 1).with_entities(
            subquery.c.filename, subquery.c.field_name).all())
def get_text_fields(session, doc_id: UUID):
    '''Get all non-signature fields that have been filled'''
    rownum = func.row_number().over(
        partition_by=FieldUsage.field_id,
        order_by=FieldUsage.timestamp.desc()).label("row_number")
    subquery = (session.query(FieldUsage).join(Field).filter(
        Field.document_id == doc_id.bytes).filter(
            FieldUsage.fieldusage_type ==
            config.FIELD_USAGE_TYPE['filled']).filter(
                Field.field_type !=
                config.FIELD_TYPE['signature']).add_column(
                    rownum).with_entities(
                        FieldUsage.data, Field.field_name,
                        rownum).subquery())
    return map(
        lambda x: (json.loads(x[0]).get('value'), x[1]),
        session.query(subquery).filter(
            subquery.c.row_number == 1).with_entities(
                subquery.c.data, subquery.c.field_name).all())
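# get_signed_fields() and get_text_fields() above share the same
# "latest row per group" idiom; here is a minimal generic sketch of it,
# assuming a hypothetical Reading model with sensor_id and ts columns.
rownum = func.row_number().over(
    partition_by=Reading.sensor_id,
    order_by=Reading.ts.desc()).label('row_number')
latest = session.query(Reading, rownum).subquery()
# keep only row 1 of each partition, i.e. each sensor's newest reading
newest_per_sensor = (session.query(latest)
                     .filter(latest.c.row_number == 1).all())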
def limit_groups(query, model, partition_by, order_by, limit=None, offset=0):
    """Limits the number of rows returned for each group

    This utility allows you to apply a limit/offset to grouped rows of a
    query. Note that the query will only contain the data from `model`;
    i.e. you cannot add additional entities.

    :param query: The original query, including filters, joins, etc.
    :param model: The model class for `query`
    :param partition_by: The column to group by
    :param order_by: The column to order the partitions by
    :param limit: The maximum number of rows for each partition
    :param offset: The number of rows to skip in each partition
    """
    inner = query.add_columns(over(func.row_number(),
                                   partition_by=partition_by,
                                   order_by=order_by)
                              .label('rownum')).subquery()

    query = model.query.select_entity_from(inner)
    if limit:
        return query.filter(offset < inner.c.rownum,
                            inner.c.rownum <= (limit + offset))
    else:
        return query.filter(offset < inner.c.rownum)
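# A minimal usage sketch for limit_groups() above, assuming a hypothetical
# Flask-SQLAlchemy Comment model with post_id and created_at columns:
# keep only the 3 newest comments per post.
newest_per_post = limit_groups(
    Comment.query,
    Comment,
    partition_by=Comment.post_id,
    order_by=Comment.created_at.desc(),
    limit=3,
).all()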
def _limit_using_window_function(self, query):
    """ Apply a limit using a window function

        This approach enables us to limit the number of eagerly loaded
        related entities
    """
    # Only do it when there is a limit
    if self.skip or self.limit:
        # First, add a row counter:
        query = query.add_columns(
            # for every group, count the rows with row_number().
            func.row_number().over(
                # Groups are partitioned by self._window_over_columns,
                partition_by=self._window_over_columns,
                # We have to apply the same ordering from the outside query;
                # otherwise, the numbering will be nondeterministic
                order_by=self.mongoquery.handler_sort.compile_columns()
            ).label('group_row_n')  # give it a name that we can use later
        )

        # Now, make ourselves into a subquery
        query = query.from_self()
        # Well, it turns out that subsequent joins somehow work.
        # I have no idea how, but they do.
        # Otherwise, we would have had to ban using 'joins' after 'limit'
        # in nested queries.

        # And apply the LIMIT condition using row numbers
        # These two statements simulate skip/limit using window functions
        if self.skip:
            query = query.filter(literal_column('group_row_n') > self.skip)
        if self.limit:
            query = query.filter(
                literal_column('group_row_n') <= (
                    (self.skip or 0) + self.limit))

    # Done
    return query
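# The same skip/limit-per-group pattern as _limit_using_window_function()
# above, reduced to a standalone sketch; session, User.city and User.name
# are assumptions, not part of the original code.
from sqlalchemy import func, literal_column

q = session.query(User).add_columns(
    func.row_number().over(
        partition_by=User.city,
        order_by=User.name,
    ).label('group_row_n')
).from_self()
# keep rows 3..7 of every city partition (skip=2, limit=5)
q = q.filter(literal_column('group_row_n') > 2,
             literal_column('group_row_n') <= 7)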
class CoreFixtures(object):
    # lambdas which return a tuple of ColumnElement objects.
    # must return at least two objects that should compare differently.
    # to test more varieties of "difference" additional objects can be added.
    fixtures = [
        lambda: (
            column("q"),
            column("x"),
            column("q", Integer),
            column("q", String),
        ),
        lambda: (~column("q", Boolean), ~column("p", Boolean)),
        lambda: (
            table_a.c.a.label("foo"),
            table_a.c.a.label("bar"),
            table_a.c.b.label("foo"),
        ),
        lambda: (
            _label_reference(table_a.c.a.desc()),
            _label_reference(table_a.c.a.asc()),
        ),
        lambda: (_textual_label_reference("a"), _textual_label_reference("b")),
        lambda: (
            text("select a, b from table").columns(a=Integer, b=String),
            text("select a, b, c from table").columns(
                a=Integer, b=String, c=Integer
            ),
            text("select a, b, c from table where foo=:bar").bindparams(
                bindparam("bar", type_=Integer)
            ),
            text("select a, b, c from table where foo=:foo").bindparams(
                bindparam("foo", type_=Integer)
            ),
            text("select a, b, c from table where foo=:bar").bindparams(
                bindparam("bar", type_=String)
            ),
        ),
        lambda: (
            column("q") == column("x"),
            column("q") == column("y"),
            column("z") == column("x"),
            column("z") + column("x"),
            column("z") - column("x"),
            column("x") - column("z"),
            column("z") > column("x"),
            column("x").in_([5, 7]),
            column("x").in_([10, 7, 8]),
            # note these two are mathematically equivalent but for now they
            # are considered to be different
            column("z") >= column("x"),
            column("x") <= column("z"),
            column("q").between(5, 6),
            column("q").between(5, 6, symmetric=True),
            column("q").like("somstr"),
            column("q").like("somstr", escape="\\"),
            column("q").like("somstr", escape="X"),
        ),
        lambda: (
            table_a.c.a,
            table_a.c.a._annotate({"orm": True}),
            table_a.c.a._annotate({"orm": True})._annotate({"bar": False}),
            table_a.c.a._annotate(
                {"orm": True, "parententity": MyEntity("a", table_a)}
            ),
            table_a.c.a._annotate(
                {"orm": True, "parententity": MyEntity("b", table_a)}
            ),
            table_a.c.a._annotate(
                {"orm": True, "parententity": MyEntity("b", select([table_a]))}
            ),
            table_a.c.a._annotate(
                {
                    "orm": True,
                    "parententity": MyEntity(
                        "b", select([table_a]).where(table_a.c.a == 5)
                    ),
                }
            ),
        ),
        lambda: (
            table_a,
            table_a._annotate({"orm": True}),
            table_a._annotate({"orm": True})._annotate({"bar": False}),
            table_a._annotate(
                {"orm": True, "parententity": MyEntity("a", table_a)}
            ),
            table_a._annotate(
                {"orm": True, "parententity": MyEntity("b", table_a)}
            ),
            table_a._annotate(
                {"orm": True, "parententity": MyEntity("b", select([table_a]))}
            ),
        ),
        lambda: (
            table("a", column("x"), column("y")),
            table("a", column("x"), column("y"))._annotate({"orm": True}),
            table("b", column("x"), column("y"))._annotate({"orm": True}),
        ),
        lambda: (
            cast(column("q"), Integer),
            cast(column("q"), Float),
            cast(column("p"), Integer),
        ),
        lambda: (
            bindparam("x"),
            bindparam("y"),
            bindparam("x", type_=Integer),
            bindparam("x", type_=String),
            bindparam(None),
        ),
        lambda: (_OffsetLimitParam("x"), _OffsetLimitParam("y")),
        lambda: (func.foo(), func.foo(5), func.bar()),
        lambda: (func.current_date(), func.current_time()),
        lambda: (
            func.next_value(Sequence("q")),
            func.next_value(Sequence("p")),
        ),
        lambda: (True_(), False_()),
        lambda: (Null(),),
        lambda: (ReturnTypeFromArgs("foo"), ReturnTypeFromArgs(5)),
        lambda: (FunctionElement(5), FunctionElement(5, 6)),
        lambda: (func.count(), func.not_count()),
        lambda: (func.char_length("abc"), func.char_length("def")),
        lambda: (GenericFunction("a", "b"), GenericFunction("a")),
        lambda: (CollationClause("foobar"), CollationClause("batbar")),
        lambda: (
            type_coerce(column("q", Integer), String),
            type_coerce(column("q", Integer), Float),
            type_coerce(column("z", Integer), Float),
        ),
        lambda: (table_a.c.a, table_b.c.a),
        lambda: (tuple_(1, 2), tuple_(3, 4)),
        lambda: (func.array_agg([1, 2]), func.array_agg([3, 4])),
        lambda: (
            func.percentile_cont(0.5).within_group(table_a.c.a),
            func.percentile_cont(0.5).within_group(table_a.c.b),
            func.percentile_cont(0.5).within_group(table_a.c.a, table_a.c.b),
            func.percentile_cont(0.5).within_group(
                table_a.c.a, table_a.c.b, column("q")
            ),
        ),
        lambda: (
            func.is_equal("a", "b").as_comparison(1, 2),
            func.is_equal("a", "c").as_comparison(1, 2),
            func.is_equal("a", "b").as_comparison(2, 1),
            func.is_equal("a", "b", "c").as_comparison(1, 2),
            func.foobar("a", "b").as_comparison(1, 2),
        ),
        lambda: (
            func.row_number().over(order_by=table_a.c.a),
            func.row_number().over(order_by=table_a.c.a, range_=(0, 10)),
            func.row_number().over(order_by=table_a.c.a, range_=(None, 10)),
            func.row_number().over(order_by=table_a.c.a, rows=(None, 20)),
            func.row_number().over(order_by=table_a.c.b),
            func.row_number().over(
                order_by=table_a.c.a, partition_by=table_a.c.b
            ),
        ),
        lambda: (
            func.count(1).filter(table_a.c.a == 5),
            func.count(1).filter(table_a.c.a == 10),
            func.foob(1).filter(table_a.c.a == 10),
        ),
        lambda: (
            and_(table_a.c.a == 5, table_a.c.b == table_b.c.a),
            and_(table_a.c.a == 5, table_a.c.a == table_b.c.a),
            or_(table_a.c.a == 5, table_a.c.b == table_b.c.a),
            ClauseList(table_a.c.a == 5, table_a.c.b == table_b.c.a),
            ClauseList(table_a.c.a == 5, table_a.c.b == table_a.c.a),
        ),
        lambda: (
            case(whens=[(table_a.c.a == 5, 10), (table_a.c.a == 10, 20)]),
            case(whens=[(table_a.c.a == 18, 10), (table_a.c.a == 10, 20)]),
            case(whens=[(table_a.c.a == 5, 10), (table_a.c.b == 10, 20)]),
            case(
                whens=[
                    (table_a.c.a == 5, 10),
                    (table_a.c.b == 10, 20),
                    (table_a.c.a == 9, 12),
                ]
            ),
            case(
                whens=[(table_a.c.a == 5, 10), (table_a.c.a == 10, 20)],
                else_=30,
            ),
            case({"wendy": "W", "jack": "J"}, value=table_a.c.a, else_="E"),
            case({"wendy": "W", "jack": "J"}, value=table_a.c.b, else_="E"),
            case({"wendy_w": "W", "jack": "J"}, value=table_a.c.a, else_="E"),
        ),
        lambda: (
            extract("foo", table_a.c.a),
            extract("foo", table_a.c.b),
            extract("bar", table_a.c.a),
        ),
        lambda: (
            Slice(1, 2, 5),
            Slice(1, 5, 5),
            Slice(1, 5, 10),
            Slice(2, 10, 15),
        ),
        lambda: (
            select([table_a.c.a]),
            select([table_a.c.a, table_a.c.b]),
            select([table_a.c.b, table_a.c.a]),
            select([table_a.c.b, table_a.c.a]).apply_labels(),
            select([table_a.c.a]).where(table_a.c.b == 5),
            select([table_a.c.a])
            .where(table_a.c.b == 5)
            .where(table_a.c.a == 10),
            select([table_a.c.a]).where(table_a.c.b == 5).with_for_update(),
            select([table_a.c.a])
            .where(table_a.c.b == 5)
            .with_for_update(nowait=True),
            select([table_a.c.a]).where(table_a.c.b == 5).correlate(table_b),
            select([table_a.c.a])
            .where(table_a.c.b == 5)
            .correlate_except(table_b),
        ),
        lambda: (
            future_select(table_a.c.a),
            future_select(table_a.c.a).join(
                table_b, table_a.c.a == table_b.c.a
            ),
            future_select(table_a.c.a).join_from(
                table_a, table_b, table_a.c.a == table_b.c.a
            ),
            future_select(table_a.c.a).join_from(table_a, table_b),
            future_select(table_a.c.a).join_from(table_c, table_b),
            future_select(table_a.c.a)
            .join(table_b, table_a.c.a == table_b.c.a)
            .join(table_c, table_b.c.b == table_c.c.x),
            future_select(table_a.c.a).join(table_b),
            future_select(table_a.c.a).join(table_c),
            future_select(table_a.c.a).join(
                table_b, table_a.c.a == table_b.c.b
            ),
            future_select(table_a.c.a).join(
                table_c, table_a.c.a == table_c.c.x
            ),
        ),
        lambda: (
            select([table_a.c.a]).cte(),
            select([table_a.c.a]).cte(recursive=True),
            select([table_a.c.a]).cte(name="some_cte", recursive=True),
            select([table_a.c.a]).cte(name="some_cte"),
            select([table_a.c.a]).cte(name="some_cte").alias("other_cte"),
            select([table_a.c.a])
            .cte(name="some_cte")
            .union_all(select([table_a.c.a])),
            select([table_a.c.a])
            .cte(name="some_cte")
            .union_all(select([table_a.c.b])),
            select([table_a.c.a]).lateral(),
            select([table_a.c.a]).lateral(name="bar"),
            table_a.tablesample(func.bernoulli(1)),
            table_a.tablesample(func.bernoulli(1), seed=func.random()),
            table_a.tablesample(func.bernoulli(1), seed=func.other_random()),
            table_a.tablesample(func.hoho(1)),
            table_a.tablesample(func.bernoulli(1), name="bar"),
            table_a.tablesample(
                func.bernoulli(1), name="bar", seed=func.random()
            ),
        ),
        lambda: (
            table_a.insert(),
            table_a.insert().values({})._annotate({"nocache": True}),
            table_b.insert(),
            table_b.insert().with_dialect_options(sqlite_foo="some value"),
            table_b.insert().from_select(["a", "b"], select([table_a])),
            table_b.insert().from_select(
                ["a", "b"], select([table_a]).where(table_a.c.a > 5)
            ),
            table_b.insert().from_select(["a", "b"], select([table_b])),
            table_b.insert().from_select(["c", "d"], select([table_a])),
            table_b.insert().returning(table_b.c.a),
            table_b.insert().returning(table_b.c.a, table_b.c.b),
            table_b.insert().inline(),
            table_b.insert().prefix_with("foo"),
            table_b.insert().with_hint("RUNFAST"),
            table_b.insert().values(a=5, b=10),
            table_b.insert().values(a=5),
            table_b.insert()
            .values({table_b.c.a: 5, "b": 10})
            ._annotate({"nocache": True}),
            table_b.insert().values(a=7, b=10),
            table_b.insert().values(a=5, b=10).inline(),
            table_b.insert()
            .values([{"a": 5, "b": 10}, {"a": 8, "b": 12}])
            ._annotate({"nocache": True}),
            table_b.insert()
            .values([{"a": 9, "b": 10}, {"a": 8, "b": 7}])
            ._annotate({"nocache": True}),
            table_b.insert()
            .values([(5, 10), (8, 12)])
            ._annotate({"nocache": True}),
            table_b.insert()
            .values([(5, 9), (5, 12)])
            ._annotate({"nocache": True}),
        ),
        lambda: (
            table_b.update(),
            table_b.update().where(table_b.c.a == 5),
            table_b.update().where(table_b.c.b == 5),
            table_b.update()
            .where(table_b.c.b == 5)
            .with_dialect_options(mysql_limit=10),
            table_b.update()
            .where(table_b.c.b == 5)
            .with_dialect_options(mysql_limit=10, sqlite_foo="some value"),
            table_b.update().where(table_b.c.a == 5).values(a=5, b=10),
            table_b.update().where(table_b.c.a == 5).values(a=5, b=10, c=12),
            table_b.update()
            .where(table_b.c.b == 5)
            .values(a=5, b=10)
            ._annotate({"nocache": True}),
            table_b.update().values(a=5, b=10),
            table_b.update()
            .values({"a": 5, table_b.c.b: 10})
            ._annotate({"nocache": True}),
            table_b.update().values(a=7, b=10),
            table_b.update().ordered_values(("a", 5), ("b", 10)),
            table_b.update().ordered_values(("b", 10), ("a", 5)),
            table_b.update().ordered_values((table_b.c.a, 5), ("b", 10)),
        ),
        lambda: (
            table_b.delete(),
            table_b.delete().with_dialect_options(sqlite_foo="some value"),
            table_b.delete().where(table_b.c.a == 5),
            table_b.delete().where(table_b.c.b == 5),
        ),
        lambda: (
            values(
                column("mykey", Integer),
                column("mytext", String),
                column("myint", Integer),
                name="myvalues",
            )
            .data([(1, "textA", 99), (2, "textB", 88)])
            ._annotate({"nocache": True}),
            values(
                column("mykey", Integer),
                column("mytext", String),
                column("myint", Integer),
                name="myothervalues",
            )
            .data([(1, "textA", 99), (2, "textB", 88)])
            ._annotate({"nocache": True}),
            values(
                column("mykey", Integer),
                column("mytext", String),
                column("myint", Integer),
                name="myvalues",
            )
            .data([(1, "textA", 89), (2, "textG", 88)])
            ._annotate({"nocache": True}),
            values(
                column("mykey", Integer),
                column("mynottext", String),
                column("myint", Integer),
                name="myvalues",
            )
            .data([(1, "textA", 99), (2, "textB", 88)])
            ._annotate({"nocache": True}),
            # TODO: difference in type
            # values(
            #     [
            #         column("mykey", Integer),
            #         column("mytext", Text),
            #         column("myint", Integer),
            #     ],
            #     (1, "textA", 99),
            #     (2, "textB", 88),
            #     alias_name="myvalues",
            # ),
        ),
        lambda: (
            select([table_a.c.a]),
            select([table_a.c.a]).prefix_with("foo"),
            select([table_a.c.a]).prefix_with("foo", dialect="mysql"),
            select([table_a.c.a]).prefix_with("foo", dialect="postgresql"),
            select([table_a.c.a]).prefix_with("bar"),
            select([table_a.c.a]).suffix_with("bar"),
        ),
        lambda: (
            select([table_a_2.c.a]),
            select([table_a_2_fs.c.a]),
            select([table_a_2_bs.c.a]),
        ),
        lambda: (
            select([table_a.c.a]),
            select([table_a.c.a]).with_hint(None, "some hint"),
            select([table_a.c.a]).with_hint(None, "some other hint"),
            select([table_a.c.a]).with_hint(table_a, "some hint"),
            select([table_a.c.a])
            .with_hint(table_a, "some hint")
            .with_hint(None, "some other hint"),
            select([table_a.c.a]).with_hint(table_a, "some other hint"),
            select([table_a.c.a]).with_hint(
                table_a, "some hint", dialect_name="mysql"
            ),
            select([table_a.c.a]).with_hint(
                table_a, "some hint", dialect_name="postgresql"
            ),
        ),
        lambda: (
            table_a.join(table_b, table_a.c.a == table_b.c.a),
            table_a.join(
                table_b, and_(table_a.c.a == table_b.c.a, table_a.c.b == 1)
            ),
            table_a.outerjoin(table_b, table_a.c.a == table_b.c.a),
        ),
        lambda: (
            table_a.alias("a"),
            table_a.alias("b"),
            table_a.alias(),
            table_b.alias("a"),
            select([table_a.c.a]).alias("a"),
        ),
        lambda: (
            FromGrouping(table_a.alias("a")),
            FromGrouping(table_a.alias("b")),
        ),
        lambda: (
            SelectStatementGrouping(select([table_a])),
            SelectStatementGrouping(select([table_b])),
        ),
        lambda: (
            select([table_a.c.a]).scalar_subquery(),
            select([table_a.c.a]).where(table_a.c.b == 5).scalar_subquery(),
        ),
        lambda: (
            exists().where(table_a.c.a == 5),
            exists().where(table_a.c.b == 5),
        ),
        lambda: (
            union(select([table_a.c.a]), select([table_a.c.b])),
            union(select([table_a.c.a]), select([table_a.c.b])).order_by("a"),
            union_all(select([table_a.c.a]), select([table_a.c.b])),
            union(select([table_a.c.a])),
            union(
                select([table_a.c.a]),
                select([table_a.c.b]).where(table_a.c.b > 5),
            ),
        ),
        lambda: (
            table("a", column("x"), column("y")),
            table("a", column("y"), column("x")),
            table("b", column("x"), column("y")),
            table("a", column("x"), column("y"), column("z")),
            table("a", column("x"), column("y", Integer)),
            table("a", column("q"), column("y", Integer)),
        ),
        lambda: (table_a, table_b),
    ]

    dont_compare_values_fixtures = [
        lambda: (
            # note the in_(...) all have different column names because
            # otherwise all IN expressions would compare as equivalent
            column("x").in_(random_choices(range(10), k=3)),
            column("y").in_(
                bindparam(
                    "q",
                    random_choices(range(10), k=random.randint(0, 7)),
                    expanding=True,
                )
            ),
            column("z").in_(random_choices(range(10), k=random.randint(0, 7))),
            column("x") == random.randint(1, 10),
        )
    ]

    def _complex_fixtures():
        def one():
            a1 = table_a.alias()
            a2 = table_b_like_a.alias()

            stmt = (
                select([table_a.c.a, a1.c.b, a2.c.b])
                .where(table_a.c.b == a1.c.b)
                .where(a1.c.b == a2.c.b)
                .where(a1.c.a == 5)
            )

            return stmt

        def one_diff():
            a1 = table_b_like_a.alias()
            a2 = table_a.alias()

            stmt = (
                select([table_a.c.a, a1.c.b, a2.c.b])
                .where(table_a.c.b == a1.c.b)
                .where(a1.c.b == a2.c.b)
                .where(a1.c.a == 5)
            )

            return stmt

        def two():
            inner = one().subquery()

            stmt = select([table_b.c.a, inner.c.a, inner.c.b]).select_from(
                table_b.join(inner, table_b.c.b == inner.c.b)
            )

            return stmt

        def three():
            a1 = table_a.alias()
            a2 = table_a.alias()
            ex = exists().where(table_b.c.b == a1.c.a)

            stmt = (
                select([a1.c.a, a2.c.a])
                .select_from(a1.join(a2, a1.c.b == a2.c.b))
                .where(ex)
            )
            return stmt

        return [one(), one_diff(), two(), three()]

    fixtures.append(_complex_fixtures)

    def _statements_w_context_options_fixtures():
        return [
            select([table_a])._add_context_option(opt1, True),
            select([table_a])._add_context_option(opt1, 5),
            select([table_a])
            ._add_context_option(opt1, True)
            ._add_context_option(opt2, True),
            select([table_a])
            ._add_context_option(opt1, True)
            ._add_context_option(opt2, 5),
            select([table_a])._add_context_option(opt3, True),
        ]

    fixtures.append(_statements_w_context_options_fixtures)

    def _statements_w_anonymous_col_names():
        def one():
            c = column("q")

            l = c.label(None)

            # new case as of Id810f485c5f7ed971529489b84694e02a3356d6d
            subq = select([l]).subquery()

            # this creates a ColumnClause as a proxy to the Label() that has
            # an anonymous name, so the column has one too.
            anon_col = subq.c[0]

            # then when BindParameter is created, it checks the label
            # and doesn't double up on the anonymous name which is uncachable
            return anon_col > 5

        def two():
            c = column("p")

            l = c.label(None)

            # new case as of Id810f485c5f7ed971529489b84694e02a3356d6d
            subq = select([l]).subquery()

            # this creates a ColumnClause as a proxy to the Label() that has
            # an anonymous name, so the column has one too.
            anon_col = subq.c[0]

            # then when BindParameter is created, it checks the label
            # and doesn't double up on the anonymous name which is uncachable
            return anon_col > 5

        def three():
            l1, l2 = table_a.c.a.label(None), table_a.c.b.label(None)

            stmt = select([table_a.c.a, table_a.c.b, l1, l2])

            subq = stmt.subquery()

            return select([subq]).where(subq.c[2] == 10)

        return (
            one(),
            two(),
            three(),
        )

    fixtures.append(_statements_w_anonymous_col_names)
def get_ma_values(
    self,
    span: int,
    start_date: datetime.date = None,
    end_date: datetime.date = None,
) -> List[StockPriceMA]:
    """
    Compute the exponential moving average (EMA) for the given span
    over the given period.

    Parameters
    ----------
    span: int
        number of days the moving average covers (an X-day moving average)
    start_date: datetime.date
        first day of the calculation period
    end_date: datetime.date
        last day of the calculation period

    Returns
    ----------
    ema_dtos: List[StockPriceMA]
        list of DTOs for the stockprice_MA table
    """
    ema_dtos = []
    alpha = 2 / (1 + span)  # smoothing constant

    # compute the EMA recursively via WITH RECURSIVE
    stockprice = self.session.query(
        StockPrice.company_id,
        StockPrice.date,
        literal_column(str(alpha), type_=Float).label('alpha'),
        func.row_number().over(
            partition_by=StockPrice.company_id,
            order_by=StockPrice.date,
        ).label('row_number'),
        StockPrice.close_price,
    ).cte(name='all')

    # non-recursive term (starts from the first day/row)
    ema = self.session.query(
        stockprice,
        stockprice.c.close_price.label('ema'),
    ).filter(stockprice.c.row_number == 1).cte(recursive=True, name='ema')

    lalias = aliased(ema, name="l")
    ralias = aliased(stockprice, name="r")
    ema = ema.union_all(
        # recursive term (joins one day at a time and computes the EMA
        # recursively: alpha * today's close + (1 - alpha) * yesterday's EMA)
        self.session.query(
            ralias.c.company_id,
            ralias.c.date,
            ralias.c.alpha,
            ralias.c.row_number,
            ralias.c.close_price,
            (ralias.c.alpha * ralias.c.close_price +
             (1 - ralias.c.alpha) * lalias.c.ema).label('ema')).join(
                 lalias,
                 and_(lalias.c.company_id == ralias.c.company_id,
                      lalias.c.row_number == ralias.c.row_number - 1)))

    ema_results = self.session.query(ema).all()

    for res in ema_results:
        if res.row_number >= span:
            dto = StockPriceMA()
            dto.company_id = res.company_id
            dto.date = res.date
            dto.ma_type = f'ema{span}'
            dto.ma_value = res.ema
            ema_dtos.append(dto)

    return ema_dtos
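# The recursive CTE above can be sanity-checked by replaying the same
# recurrence in plain Python for one company (a sketch; ``prices`` is an
# assumed date-ordered list of close prices, not part of the original code):
# ema[0] = close[0]; ema[t] = alpha * close[t] + (1 - alpha) * ema[t - 1]
def ema_reference(prices, span):
    alpha = 2 / (1 + span)
    emas = [prices[0]]
    for close in prices[1:]:
        emas.append(alpha * close + (1 - alpha) * emas[-1])
    return emas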
def get_ma_values(
    self,
    span: int,
    start_date: datetime.date = None,
    end_date: datetime.date = None,
) -> List[StockPriceMA]:
    """
    Compute the weighted moving average (WMA) for the given span
    over the given period.

    Parameters
    ----------
    span: int
        number of days the moving average covers (an X-day moving average)
    start_date: datetime.date
        first day of the calculation period
    end_date: datetime.date
        last day of the calculation period

    Returns
    ----------
    wma_dtos: List[StockPriceMA]
        list of DTOs for the stockprice_MA table
    """
    wma_dtos = []
    denominator = sum(range(span + 1))  # denominator for the weights

    stockprice = self.session.query(
        StockPrice.company_id,
        StockPrice.date,
        StockPrice.close_price,
        func.row_number().over(
            partition_by=StockPrice.company_id,
            order_by=StockPrice.date,
        ).label('row_number'),
    ).cte(name='all')

    # compute the weighted sum over the given span (the WMA) by combining
    # a self-join with aggregation
    lalias = aliased(stockprice, name="l")
    ralias = aliased(stockprice, name="r")
    case_statement = case([
        ((lalias.c.row_number - ralias.c.row_number) == day,
         ralias.c.close_price * (span - day) / denominator)
        for day in range(span)
    ])
    wma_results = self.session.query(
        lalias.c.company_id,
        lalias.c.date,
        func.sum(case_statement).label('wma')).join(
            ralias,
            and_(lalias.c.company_id == ralias.c.company_id,
                 lalias.c.row_number - (span - 1) <= ralias.c.row_number,
                 ralias.c.row_number <= lalias.c.row_number)).group_by(
                     lalias.c.company_id,
                     lalias.c.date,
                 ).having(func.count() == span).order_by(
                     lalias.c.company_id,
                     lalias.c.date,
                 ).all()

    for res in wma_results:
        dto = StockPriceMA()
        dto.company_id = res.company_id
        dto.date = res.date
        dto.ma_type = f'wma{span}'
        dto.ma_value = res.wma
        wma_dtos.append(dto)

    return wma_dtos
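# As with the EMA, the WMA self-join above can be cross-checked with a
# plain-Python reference (a sketch; ``prices`` is an assumed date-ordered
# list of close prices). The newest close in each window gets weight
# ``span`` and the oldest weight 1, matching the CASE weights above.
def wma_reference(prices, span):
    denominator = span * (span + 1) / 2  # == sum(range(span + 1))
    return [
        sum(p * (i + 1) for i, p in enumerate(prices[t - span + 1:t + 1]))
        / denominator
        for t in range(span - 1, len(prices))
    ]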
class CompareAndCopyTest(fixtures.TestBase):
    # lambdas which return a tuple of ColumnElement objects.
    # must return at least two objects that should compare differently.
    # to test more varieties of "difference" additional objects can be added.
    fixtures = [
        lambda: (
            column("q"),
            column("x"),
            column("q", Integer),
            column("q", String),
        ),
        lambda: (~column("q", Boolean), ~column("p", Boolean)),
        lambda: (
            table_a.c.a.label("foo"),
            table_a.c.a.label("bar"),
            table_a.c.b.label("foo"),
        ),
        lambda: (
            _label_reference(table_a.c.a.desc()),
            _label_reference(table_a.c.a.asc()),
        ),
        lambda: (_textual_label_reference("a"), _textual_label_reference("b")),
        lambda: (
            text("select a, b from table").columns(a=Integer, b=String),
            text("select a, b, c from table").columns(
                a=Integer, b=String, c=Integer
            ),
        ),
        lambda: (
            column("q") == column("x"),
            column("q") == column("y"),
            column("z") == column("x"),
        ),
        lambda: (
            cast(column("q"), Integer),
            cast(column("q"), Float),
            cast(column("p"), Integer),
        ),
        lambda: (
            bindparam("x"),
            bindparam("y"),
            bindparam("x", type_=Integer),
            bindparam("x", type_=String),
            bindparam(None),
        ),
        lambda: (_OffsetLimitParam("x"), _OffsetLimitParam("y")),
        lambda: (func.foo(), func.foo(5), func.bar()),
        lambda: (func.current_date(), func.current_time()),
        lambda: (
            func.next_value(Sequence("q")),
            func.next_value(Sequence("p")),
        ),
        lambda: (True_(), False_()),
        lambda: (Null(),),
        lambda: (ReturnTypeFromArgs("foo"), ReturnTypeFromArgs(5)),
        lambda: (FunctionElement(5), FunctionElement(5, 6)),
        lambda: (func.count(), func.not_count()),
        lambda: (func.char_length("abc"), func.char_length("def")),
        lambda: (GenericFunction("a", "b"), GenericFunction("a")),
        lambda: (CollationClause("foobar"), CollationClause("batbar")),
        lambda: (
            type_coerce(column("q", Integer), String),
            type_coerce(column("q", Integer), Float),
            type_coerce(column("z", Integer), Float),
        ),
        lambda: (table_a.c.a, table_b.c.a),
        lambda: (tuple_([1, 2]), tuple_([3, 4])),
        lambda: (func.array_agg([1, 2]), func.array_agg([3, 4])),
        lambda: (
            func.percentile_cont(0.5).within_group(table_a.c.a),
            func.percentile_cont(0.5).within_group(table_a.c.b),
            func.percentile_cont(0.5).within_group(table_a.c.a, table_a.c.b),
            func.percentile_cont(0.5).within_group(
                table_a.c.a, table_a.c.b, column("q")
            ),
        ),
        lambda: (
            func.is_equal("a", "b").as_comparison(1, 2),
            func.is_equal("a", "c").as_comparison(1, 2),
            func.is_equal("a", "b").as_comparison(2, 1),
            func.is_equal("a", "b", "c").as_comparison(1, 2),
            func.foobar("a", "b").as_comparison(1, 2),
        ),
        lambda: (
            func.row_number().over(order_by=table_a.c.a),
            func.row_number().over(order_by=table_a.c.a, range_=(0, 10)),
            func.row_number().over(order_by=table_a.c.a, range_=(None, 10)),
            func.row_number().over(order_by=table_a.c.a, rows=(None, 20)),
            func.row_number().over(order_by=table_a.c.b),
            func.row_number().over(
                order_by=table_a.c.a, partition_by=table_a.c.b
            ),
        ),
        lambda: (
            func.count(1).filter(table_a.c.a == 5),
            func.count(1).filter(table_a.c.a == 10),
            func.foob(1).filter(table_a.c.a == 10),
        ),
        lambda: (
            and_(table_a.c.a == 5, table_a.c.b == table_b.c.a),
            and_(table_a.c.a == 5, table_a.c.a == table_b.c.a),
            or_(table_a.c.a == 5, table_a.c.b == table_b.c.a),
            ClauseList(table_a.c.a == 5, table_a.c.b == table_b.c.a),
            ClauseList(table_a.c.a == 5, table_a.c.b == table_a.c.a),
        ),
        lambda: (
            case(whens=[(table_a.c.a == 5, 10), (table_a.c.a == 10, 20)]),
            case(whens=[(table_a.c.a == 18, 10), (table_a.c.a == 10, 20)]),
            case(whens=[(table_a.c.a == 5, 10), (table_a.c.b == 10, 20)]),
            case(
                whens=[
                    (table_a.c.a == 5, 10),
                    (table_a.c.b == 10, 20),
                    (table_a.c.a == 9, 12),
                ]
            ),
            case(
                whens=[(table_a.c.a == 5, 10), (table_a.c.a == 10, 20)],
                else_=30,
            ),
            case({"wendy": "W", "jack": "J"}, value=table_a.c.a, else_="E"),
            case({"wendy": "W", "jack": "J"}, value=table_a.c.b, else_="E"),
            case({"wendy_w": "W", "jack": "J"}, value=table_a.c.a, else_="E"),
        ),
        lambda: (
            extract("foo", table_a.c.a),
            extract("foo", table_a.c.b),
            extract("bar", table_a.c.a),
        ),
        lambda: (
            Slice(1, 2, 5),
            Slice(1, 5, 5),
            Slice(1, 5, 10),
            Slice(2, 10, 15),
        ),
        lambda: (
            select([table_a.c.a]),
            select([table_a.c.a, table_a.c.b]),
            select([table_a.c.b, table_a.c.a]),
            select([table_a.c.a]).where(table_a.c.b == 5),
            select([table_a.c.a])
            .where(table_a.c.b == 5)
            .where(table_a.c.a == 10),
            select([table_a.c.a]).where(table_a.c.b == 5).with_for_update(),
            select([table_a.c.a])
            .where(table_a.c.b == 5)
            .with_for_update(nowait=True),
            select([table_a.c.a]).where(table_a.c.b == 5).correlate(table_b),
            select([table_a.c.a])
            .where(table_a.c.b == 5)
            .correlate_except(table_b),
        ),
        lambda: (
            table_a.join(table_b, table_a.c.a == table_b.c.a),
            table_a.join(
                table_b, and_(table_a.c.a == table_b.c.a, table_a.c.b == 1)
            ),
            table_a.outerjoin(table_b, table_a.c.a == table_b.c.a),
        ),
        lambda: (
            table_a.alias("a"),
            table_a.alias("b"),
            table_a.alias(),
            table_b.alias("a"),
            select([table_a.c.a]).alias("a"),
        ),
        lambda: (
            FromGrouping(table_a.alias("a")),
            FromGrouping(table_a.alias("b")),
        ),
        lambda: (
            select([table_a.c.a]).as_scalar(),
            select([table_a.c.a]).where(table_a.c.b == 5).as_scalar(),
        ),
        lambda: (
            exists().where(table_a.c.a == 5),
            exists().where(table_a.c.b == 5),
        ),
        lambda: (
            union(select([table_a.c.a]), select([table_a.c.b])),
            union(select([table_a.c.a]), select([table_a.c.b])).order_by("a"),
            union_all(select([table_a.c.a]), select([table_a.c.b])),
            union(select([table_a.c.a])),
            union(
                select([table_a.c.a]),
                select([table_a.c.b]).where(table_a.c.b > 5),
            ),
        ),
        lambda: (
            table("a", column("x"), column("y")),
            table("a", column("y"), column("x")),
            table("b", column("x"), column("y")),
            table("a", column("x"), column("y"), column("z")),
            table("a", column("x"), column("y", Integer)),
            table("a", column("q"), column("y", Integer)),
        ),
        lambda: (
            Table("a", MetaData(), Column("q", Integer), Column("b", String)),
            Table("b", MetaData(), Column("q", Integer), Column("b", String)),
        ),
    ]

    @classmethod
    def setup_class(cls):
        # TODO: we need to get dialects here somehow, perhaps in test_suite?
        [
            importlib.import_module("sqlalchemy.dialects.%s" % d)
            for d in dialects.__all__
            if not d.startswith("_")
        ]

    def test_all_present(self):
        need = set(
            cls
            for cls in class_hierarchy(ClauseElement)
            if issubclass(cls, (ColumnElement, Selectable))
            and "__init__" in cls.__dict__
            and not issubclass(cls, (Annotated))
            and "orm" not in cls.__module__
            and "crud" not in cls.__module__
            and "dialects" not in cls.__module__  # TODO: dialects?
        ).difference({ColumnElement, UnaryExpression})
        for fixture in self.fixtures:
            case_a = fixture()
            for elem in case_a:
                for mro in type(elem).__mro__:
                    need.discard(mro)

        is_false(bool(need), "%d Remaining classes: %r" % (len(need), need))

    def test_compare(self):
        for fixture in self.fixtures:
            case_a = fixture()
            case_b = fixture()

            for a, b in itertools.combinations_with_replacement(
                range(len(case_a)), 2
            ):
                if a == b:
                    is_true(
                        case_a[a].compare(
                            case_b[b], arbitrary_expression=True
                        ),
                        "%r != %r" % (case_a[a], case_b[b]),
                    )
                else:
                    is_false(
                        case_a[a].compare(
                            case_b[b], arbitrary_expression=True
                        ),
                        "%r == %r" % (case_a[a], case_b[b]),
                    )

    def test_cache_key(self):
        def assert_params_append(assert_params):
            def append(param):
                if param._value_required_for_cache:
                    assert_params.append(param)
                else:
                    is_(param.value, None)

            return append

        for fixture in self.fixtures:
            case_a = fixture()
            case_b = fixture()

            for a, b in itertools.combinations_with_replacement(
                range(len(case_a)), 2
            ):
                assert_a_params = []
                assert_b_params = []

                visitors.traverse_depthfirst(
                    case_a[a],
                    {},
                    {"bindparam": assert_params_append(assert_a_params)},
                )
                visitors.traverse_depthfirst(
                    case_b[b],
                    {},
                    {"bindparam": assert_params_append(assert_b_params)},
                )
                if assert_a_params:
                    assert_raises_message(
                        NotImplementedError,
                        "bindparams collection argument required ",
                        case_a[a]._cache_key,
                    )
                if assert_b_params:
                    assert_raises_message(
                        NotImplementedError,
                        "bindparams collection argument required ",
                        case_b[b]._cache_key,
                    )

                if not assert_a_params and not assert_b_params:
                    if a == b:
                        eq_(case_a[a]._cache_key(), case_b[b]._cache_key())
                    else:
                        ne_(case_a[a]._cache_key(), case_b[b]._cache_key())

    def test_cache_key_gather_bindparams(self):
        for fixture in self.fixtures:
            case_a = fixture()
            case_b = fixture()

            # in the "bindparams" case, the cache keys for bound parameters
            # with only different values will be the same, but the params
            # themselves are gathered into a collection.
            for a, b in itertools.combinations_with_replacement(
                range(len(case_a)), 2
            ):
                a_params = {"bindparams": []}
                b_params = {"bindparams": []}
                if a == b:
                    a_key = case_a[a]._cache_key(**a_params)
                    b_key = case_b[b]._cache_key(**b_params)
                    eq_(a_key, b_key)

                    if a_params["bindparams"]:
                        for a_param, b_param in zip(
                            a_params["bindparams"], b_params["bindparams"]
                        ):
                            assert a_param.compare(b_param)
                else:
                    a_key = case_a[a]._cache_key(**a_params)
                    b_key = case_b[b]._cache_key(**b_params)

                    if a_key == b_key:
                        for a_param, b_param in zip(
                            a_params["bindparams"], b_params["bindparams"]
                        ):
                            if not a_param.compare(b_param):
                                break
                        else:
                            assert False, "Bound parameters are all the same"
                    else:
                        ne_(a_key, b_key)

                assert_a_params = []
                assert_b_params = []
                visitors.traverse_depthfirst(
                    case_a[a], {}, {"bindparam": assert_a_params.append}
                )
                visitors.traverse_depthfirst(
                    case_b[b], {}, {"bindparam": assert_b_params.append}
                )

                # note we're asserting the order of the params as well as
                # if there are dupes or not.  ordering has to be
                # deterministic and matches what a traversal would provide.
                eq_(a_params["bindparams"], assert_a_params)
                eq_(b_params["bindparams"], assert_b_params)

    def test_compare_col_identity(self):
        stmt1 = (
            select([table_a.c.a, table_b.c.b])
            .where(table_a.c.a == table_b.c.b)
            .alias()
        )
        stmt1_c = (
            select([table_a.c.a, table_b.c.b])
            .where(table_a.c.a == table_b.c.b)
            .alias()
        )

        stmt2 = union(select([table_a]), select([table_b]))

        stmt3 = select([table_b])

        equivalents = {table_a.c.a: [table_b.c.a]}

        is_false(
            stmt1.compare(stmt2, use_proxies=True, equivalents=equivalents)
        )

        is_true(
            stmt1.compare(stmt1_c, use_proxies=True, equivalents=equivalents)
        )
        is_true(
            (table_a.c.a == table_b.c.b).compare(
                stmt1.c.a == stmt1.c.b,
                use_proxies=True,
                equivalents=equivalents,
            )
        )

    def test_copy_internals(self):
        for fixture in self.fixtures:
            case_a = fixture()
            case_b = fixture()

            assert case_a[0].compare(case_b[0])

            clone = case_a[0]._clone()
            clone._copy_internals()

            assert clone.compare(case_b[0])

            stack = [clone]
            seen = {clone}
            found_elements = False
            while stack:
                obj = stack.pop(0)

                items = [
                    subelem
                    for key, elem in clone.__dict__.items()
                    if key != "_is_clone_of" and elem is not None
                    for subelem in util.to_list(elem)
                    if (
                        isinstance(subelem, (ColumnElement, ClauseList))
                        and subelem not in seen
                        and not isinstance(subelem, Immutable)
                        and subelem is not case_a[0]
                    )
                ]
                stack.extend(items)
                seen.update(items)

                if obj is not clone:
                    found_elements = True
                    # ensure the element will not compare as true
                    obj.compare = lambda other, **kw: False
                    obj.__visit_name__ = "dont_match"

            if found_elements:
                assert not clone.compare(case_b[0])
            assert case_a[0].compare(case_b[0])
def get_page_info(
    model: Type[M],
    cursor_column: Column,
    sortables: List[Column],
    cursor: Optional[Union[str, int]],
    sort: str,
    item_count: int = 25,
    surrounding_pages: int = 5,
    filter_criteria: Sequence[Any] = tuple(),
) -> PaginationInfo:
    if item_count is None:
        item_count = 25
    sort_columns: Dict[str, Column] = {c.name: c for c in sortables}
    if issubclass(model, IdMixin):
        sort_columns["id"] = model.id
    sort_column_name = sort.lstrip("+-")
    sort_direction: Any = desc if sort.startswith("-") else asc
    sort_column = sort_columns[sort_column_name]
    if "collate" in sort_column.info:
        order_by = sort_direction(sort_columns[sort_column_name].collate(
            sort_column.info["collate"]))
    else:
        order_by = sort_direction(sort_columns[sort_column_name])
    row_numbers: Any = func.row_number().over(order_by=order_by)

    query_filter: Any = and_(*filter_criteria)

    collection_size: int = (
        model.query.filter(query_filter).enable_eagerloads(False).count())

    if cursor is not None:
        # set cursor to None if the cursor row is not found
        cursor = (
            # a query with exists() would be more complex than this
            DB.session.query(cursor_column).filter(
                cursor_column == cursor
            ).scalar()  # None or the value of the cursor
        )

    item_query: Query = model.query.filter(query_filter).order_by(order_by)

    if collection_size <= item_count:
        return PaginationInfo(
            collection_size=collection_size,
            cursor_row=0,
            cursor_page=1,
            surrounding_pages=[],
            last_page=PageInfo(0, 1, 0),
            page_items_query=item_query.limit(item_count),
        )

    cursor_row: Union[int, Any] = 0
    if cursor is not None:
        # always include the cursor row
        query_filter = or_(cursor_column == cursor, and_(*filter_criteria))
        item_query = model.query.filter(query_filter).order_by(order_by)

        cursor_row_cte: CTE = (DB.session.query(
            row_numbers.label("row"),
            cursor_column,
        ).filter(query_filter).from_self(
            column("row")).filter(cursor_column == cursor).cte("cursor_row"))
        cursor_row = cursor_row_cte.c.row

    page_rows = (DB.session.query(
        cursor_column,
        row_numbers.label("row"),
        (row_numbers / item_count).label("page"),
        (row_numbers % item_count).label("modulo"),
    ).filter(query_filter).order_by(column("row").asc()).cte("pages"))

    last_page = (DB.session.query(
        row_numbers.label("row"),
        (row_numbers / item_count).label("page"),
    ).filter(query_filter).order_by(
        column("row").desc()).limit(1).cte("last-page"))

    pages = (
        DB.session.query(*page_rows.c).only_return_tuples(True).order_by(
            page_rows.c.row.asc()).filter(
                # only return page cursors
                (page_rows.c.modulo == (cursor_row % item_count))
                # but not for all pages
                & (
                    # only return the +- surrounding pages
                    ((page_rows.c.page >= (
                        (cursor_row / item_count) - surrounding_pages))
                     & (page_rows.c.page <= (
                         (cursor_row / item_count) + surrounding_pages)))
                    # also return the last 1-2 pages
                    | (page_rows.c.page >= (last_page.c.page - 1))
                )).all())

    context_pages, last_page, current_cursor_row, cursor_page = digest_pages(
        pages, cursor, surrounding_pages, collection_size)

    return PaginationInfo(
        collection_size=collection_size,
        cursor_row=current_cursor_row,
        cursor_page=cursor_page,
        surrounding_pages=context_pages,
        last_page=last_page,
        page_items_query=item_query.offset(current_cursor_row).limit(
            item_count),
    )
def show(airport_name, date=None):
    """Show a logbook for <airport_name>."""
    airport = db.session.query(Airport).filter(
        Airport.name == airport_name).first()
    if airport is None:
        print('Airport "{}" not found.'.format(airport_name))
        return

    or_args = []
    if date is not None:
        date = datetime.strptime(date, "%Y-%m-%d")
        (start, end) = date_to_timestamps(date)
        or_args = [db.between(Logbook.reftime, start, end)]

    # get all logbook entries and add device and airport infos
    logbook_query = (db.session.query(
        func.row_number().over(order_by=Logbook.reftime).label("row_number"),
        Logbook).filter(*or_args).filter(
            db.or_(Logbook.takeoff_airport_id == airport.id,
                   Logbook.landing_airport_id == airport.id)).order_by(
                       Logbook.reftime))

    # ... and finally print out the logbook
    print("--- Logbook ({}) ---".format(airport_name))

    def none_datetime_replacer(datetime_object):
        return ("--:--:--" if datetime_object is None
                else datetime_object.time())

    def none_track_replacer(track_object):
        return "--" if track_object is None else round(track_object / 10.0)

    def none_timedelta_replacer(timedelta_object):
        return "--:--:--" if timedelta_object is None else timedelta_object

    def none_registration_replacer(device_object):
        return ("[" + device_object.address + "]"
                if len(device_object.infos) == 0
                else device_object.infos[0].registration)

    def none_aircraft_replacer(device_object):
        return ("(unknown)" if len(device_object.infos) == 0
                else device_object.infos[0].aircraft)

    def airport_marker(logbook_object):
        # compare the names with != ("is not" would be an unreliable
        # identity check for strings)
        if logbook_object.takeoff_airport is not None \
                and logbook_object.takeoff_airport.name != airport.name:
            return "FROM: {}".format(logbook_object.takeoff_airport.name)
        elif logbook_object.landing_airport is not None \
                and logbook_object.landing_airport.name != airport.name:
            return "TO: {}".format(logbook_object.landing_airport.name)
        else:
            return ""

    def none_altitude_replacer(logbook_object):
        return ("?" if logbook_object.max_altitude is None
                else "{:5d}m ({:+5d}m)".format(
                    logbook_object.max_altitude,
                    logbook_object.max_altitude -
                    logbook_object.takeoff_airport.altitude))

    for [row_number, logbook] in logbook_query.all():
        print("%3d. %10s %8s (%2s) %8s (%2s) %8s %15s %8s %17s %20s" % (
            row_number,
            logbook.reftime.date(),
            none_datetime_replacer(logbook.takeoff_timestamp),
            none_track_replacer(logbook.takeoff_track),
            none_datetime_replacer(logbook.landing_timestamp),
            none_track_replacer(logbook.landing_track),
            none_timedelta_replacer(logbook.duration),
            none_altitude_replacer(logbook),
            none_registration_replacer(logbook.device),
            none_aircraft_replacer(logbook.device),
            airport_marker(logbook),
        ))
        bindparam('y')
    )
)
calc = calculate.alias()
print(select([users]).where(users.c.id > calc.c.z))

# unique_params() is used so that our calculate statement can be used twice.
calc1 = calculate.alias('c1').unique_params(x=17, y=45)
calc2 = calculate.alias('c2').unique_params(x=5, y=12)
s = select([users]).where(users.c.id.between(calc1.c.z, calc2.c.z))
print(s)
print(s.compile().params)

# window function
s = select([
    users.c.id,
    func.row_number().over(order_by=users.c.name)
])
print(s)

# Window functions are not supported in MySQL or SQLite (at the time of
# writing).
if re.match("mysql://", db) or re.match("sqlite://", db):
    pass
else:
    for row in conn.execute(s):
        print(row)

### union
from sqlalchemy.sql import union

# The select() clauses below are much more complicated than the ones in the
# original example, because we want to support MySQL, in which case the
# ORDER BY clause does not accept forms like table-name.column-name. Instead,
# an alias has to be created via label(). See more: