async def get_list(self, info: QueryInfo, with_count=False, *, _perm=None) -> QueryResultRowList:
    """Execute the SELECT described by ``info`` and return its rows.

    Each returned ``QueryResultRow`` is built from the row's first column
    (``model.id`` of ``info.from_table``) and the remaining columns, which
    follow the order of ``info.select_for_crud``.

    :param info: the parsed query: selected fields, conditions, joins,
        ordering and paging.
    :param with_count: if true, additionally run ``COUNT('1')`` over the
        filtered query (before ORDER BY / LIMIT / OFFSET are applied) and
        store the result in ``rows_count`` on the returned list.
    :param _perm: permission object, forwarded to the table's ``on_query``
        and ``on_read`` hooks.
    :return: the rows. Every callback registered by ``on_read`` is awaited
        with this list before it is returned.
    """
    # Table hooks: on_read may register callbacks to run on the final result.
    await info.from_table.on_query(info, _perm)
    when_complete = []
    await info.from_table.on_read(info, when_complete, _perm)

    model = self.mapping2model[info.from_table]

    # Selected columns: always the row id first, then the requested fields.
    select_fields = [model.id]
    for field_info in info.select_for_crud:
        select_fields.append(getattr(self.mapping2model[field_info.table], field_info.name))
    q = Query.from_(model).select(*select_fields)

    phg = self.get_placeholder_generator()

    def like_prefix_pattern(text):
        """Return a LIKE pattern matching strings that start with ``text`` literally."""
        # Escape the escape character first, then both wildcards, so user input
        # cannot inject its own pattern. Backslash is PostgreSQL's default LIKE
        # escape character.
        escaped = text.replace('\\', '\\\\').replace('%', '\\%').replace('_', '\\_')
        return escaped + '%'

    def solve_expr(c):
        """Translate a single ``ConditionExpr`` into a pypika criterion."""
        field = getattr(self.mapping2model[c.column.table], c.column.name)
        op = c.op
        if isinstance(c.value, RecordMappingField):
            # Column-to-column comparison: reference the other column directly,
            # no placeholder is generated.
            real_value = getattr(self.mapping2model[c.value.table], c.value.name)
        else:
            value = c.value
            if op in (QUERY_OP_RELATION.PREFIX, QUERY_OP_RELATION.IPREFIX):
                # Build the pattern in a local so the caller's ``c.value`` is
                # left untouched.
                value = like_prefix_pattern(value)
            real_value = phg.next(
                value,
                contains_relation=op in (QUERY_OP_RELATION.CONTAINS, QUERY_OP_RELATION.CONTAINS_ANY))

        if op == QUERY_OP_RELATION.PREFIX:
            return field.like(real_value)
        if op == QUERY_OP_RELATION.IPREFIX:
            return field.ilike(real_value)
        if op == QUERY_OP_RELATION.IS:
            return BasicCriterion(ArrayMatchingExt.is_, field, field.wrap_constant(real_value))
        if op == QUERY_OP_RELATION.IS_NOT:
            return BasicCriterion(ArrayMatchingExt.is_not, field, field.wrap_constant(real_value))
        if op == QUERY_OP_RELATION.CONTAINS_ANY:
            # &&
            return BasicCriterion(ArrayMatchingExt.contains_any, field, field.wrap_constant(real_value))
        return getattr(field, _sql_method_map[op])(real_value)

    def solve_condition(c):
        """Recursively translate a condition tree node into a pypika criterion.

        Returns ``None`` for nodes that produce no criterion (empty groups,
        negations of empty groups, unrecognized node types).
        """
        if isinstance(c, (QueryConditions, ConditionLogicExpr)):
            # Empty sub-groups yield None. Drop them so reduce() never
            # combines a criterion with None.
            items = [x for x in map(solve_condition, c.items) if x is not None]
            if not items:
                return None
            # A QueryConditions group is always AND-ed (checked first);
            # a ConditionLogicExpr is AND-ed only when its type is 'and',
            # otherwise OR-ed.
            if not isinstance(c, QueryConditions) and c.type != 'and':
                return reduce(ComplexCriterion.__or__, items)
            return reduce(ComplexCriterion.__and__, items)
        if isinstance(c, ConditionExpr):
            return solve_expr(c)
        if isinstance(c, NegatedExpr):
            inner = solve_condition(c.expr)
            return None if inner is None else ~inner
        return None

    # Joins are applied whether or not the query has WHERE conditions. They are
    # resolved before the WHERE clause so placeholders are generated in the
    # same order they appear in the SQL.
    # (Original author's note: some ORMs concatenate join conditions incorrectly.)
    if info.join:
        for ji in info.join:
            jtable = self.mapping2model[ji.table]
            on_cond = solve_condition(ji.conditions)
            if ji.limit == -1:
                q = q.inner_join(jtable).on(on_cond)
            else:
                # Restrict the join through an id subquery limited to
                # ``ji.limit`` rows. IN is used instead of "=": a scalar "="
                # subquery fails as soon as it returns more than one row.
                limited_ids = Query.from_(jtable).select(jtable.id).where(on_cond).limit(ji.limit)
                q = q.inner_join(jtable).on(jtable.id.isin(limited_ids))

    if info.conditions:
        where = solve_condition(info.conditions)
        if where is not None:
            q = q.where(where)

    ret = QueryResultRowList()

    if with_count:
        # Temporarily swap the select list for COUNT so the count runs over the
        # same FROM / JOIN / WHERE (and placeholders), before ORDER BY / LIMIT.
        saved_selects = q._selects
        q._selects = [Count('1')]
        try:
            cursor = await self.execute_sql(q.get_sql(), phg)
        finally:
            q._selects = saved_selects
        ret.rows_count = next(iter(cursor))[0]

    # Ordering and paging.
    if info.order_by:
        order_dict = {
            'default': None,
            'desc': Order.desc,
            'asc': Order.asc
        }
        for order_item in info.order_by:
            q = q.orderby(order_item.column.name, order=order_dict[order_item.order])

    if info.limit != -1:
        q = q.limit(info.limit)
    q = q.offset(info.offset)

    # Fetch results: column 0 is the id, the rest are the selected fields.
    cursor = await self.execute_sql(q.get_sql(), phg)
    for row in cursor:
        it = iter(row)
        ret.append(QueryResultRow(next(it), list(it), info, info.from_table))

    for callback in when_complete:
        await callback(ret)
    return ret
async def update(self, info: QueryInfo, values: ValuesToWrite, *, _perm=None) -> IDList:
    """Update the rows selected by ``info`` with ``values`` and return their ids.

    The target ids are resolved first through ``get_list`` on a clone of
    ``info`` (with ``select`` cleared). A single UPDATE restricted to those
    ids is then executed.

    ``values.data_flag`` may attach a flag to a column to request an
    in-place operation instead of a plain assignment:

    - ``INCR`` / ``DECR``: ``col + val`` / ``col - val``
    - ``ARRAY_EXTEND``: ``col || val``
    - ``ARRAY_PRUNE`` / ``ARRAY_PRUNE_DISTINCT``:
      ``array(SELECT unnest(col) EXCEPT SELECT unnest(val))``
    - ``ARRAY_EXTEND_DISTINCT``: ``ARRAY(SELECT DISTINCT unnest(col || val))``

    :param info: query selecting the rows to update.
    :param values: column -> new value mapping, with per-column ``data_flag``.
    :param _perm: permission object forwarded to the hooks and to ``get_list``.
    :return: the ids of the rows selected for update. When the list is empty,
        no UPDATE is executed.
    :raises ValueError: if a column carries a truthy data flag this method
        does not know how to apply.
    """
    # Table hooks: on_update may register callbacks to run before and after
    # the UPDATE.
    await info.from_table.on_query(info, _perm)
    when_before_update, when_complete = [], []
    await info.from_table.on_update(info, values, when_before_update, when_complete, _perm)

    model = self.mapping2model[info.from_table]
    tc = self._table_cache[info.from_table]

    # Resolve the ids of the rows to update.
    qi = info.clone()
    qi.select = []
    id_lst = [row.id for row in await self.get_list(qi, _perm=_perm)]

    for callback in when_before_update:
        await callback(id_lst)

    if id_lst:
        # Flag -> builder of the SET expression from (column, placeholder).
        # TODO: ARRAY_PRUNE currently de-duplicates as well (EXCEPT semantics),
        # which is wrong; only ARRAY_PRUNE_DISTINCT should.
        flag_builders = {
            ValuesDataFlag.INCR: lambda col, val: ArithmeticExpression(Arithmetic.add, col, val),
            ValuesDataFlag.DECR: lambda col, val: ArithmeticExpression(Arithmetic.sub, col, val),
            ValuesDataFlag.ARRAY_EXTEND: lambda col, val: ArithmeticExpression(ArithmeticExt.concat, col, val),
            ValuesDataFlag.ARRAY_PRUNE: PostgresArrayDifference,
            ValuesDataFlag.ARRAY_EXTEND_DISTINCT: lambda col, val: PostgresArrayDistinct(
                ArithmeticExpression(ArithmeticExt.concat, col, val)),
            ValuesDataFlag.ARRAY_PRUNE_DISTINCT: PostgresArrayDifference,
        }

        phg = self.get_placeholder_generator()
        sql = Query().update(model)
        for k, v in values.items():
            vflag = values.data_flag.get(k)
            build = None
            if vflag:
                build = flag_builders.get(vflag)
                if build is None:
                    # Checked before phg.next() so no placeholder value is
                    # registered for a column that would not appear in SET.
                    raise ValueError(f'unsupported data flag {vflag!r} for column {k!r}')
            val = phg.next(v, left_is_array=k in tc['array_fields'], left_is_json=k in tc['json_fields'])
            sql = sql.set(k, val if build is None else build(PypikaField(k), val))
        # Note: placeholders are generated in the same order as they appear in
        # the SQL. All SET values come first, then the WHERE id list.
        sql = sql.where(model.id.isin(phg.next(id_lst)))
        await self.execute_sql(sql.get_sql(), phg)

    for callback in when_complete:
        await callback()
    return id_lst