Ejemplo n.º 1
0
    async def delete(self, info: QueryInfo, *, _perm=None) -> IDList:
        """Delete every row of ``info.from_table`` matched by ``info``.

        The table's ``on_delete`` hook receives two lists it may append
        callbacks to: callbacks in the first are awaited with the list of
        matched ids before the DELETE runs; callbacks in the second are
        awaited without arguments afterwards.

        :param info: query describing which rows to delete.
        :param _perm: permission object forwarded to the hook and to
            ``get_list``.
        :return: ids of the matched rows.
        """
        model = self.mapping2model[info.from_table]
        before_hooks, after_hooks = [], []
        await info.from_table.on_delete(info, before_hooks, after_hooks, _perm)

        # Resolve the matching rows first; only their ids are needed, so no
        # extra columns are selected.
        id_query = info.clone()
        id_query.select = []
        matched_rows = await self.get_list(id_query, _perm=_perm)
        matched_ids = [row.id for row in matched_rows]

        for hook in before_hooks:
            await hook(matched_ids)

        # Skip the statement entirely when nothing matched.
        if matched_ids:
            placeholders = self.get_placeholder_generator()
            id_filter = model.id.isin(placeholders.next(matched_ids))
            statement = Query().from_(model).delete().where(id_filter)
            await self.execute_sql(statement.get_sql(), placeholders)

        for hook in after_hooks:
            await hook()

        return matched_ids
Ejemplo n.º 2
0
 async def execute_select(self, query: Query, custom_fields: Optional[list] = None) -> list:
     """Run ``query`` and build one ``self.model`` instance per result row.

     :param query: query whose ``get_sql()`` text is executed.
     :param custom_fields: extra row keys whose values are copied onto each
         instance as attributes of the same name.
     :return: the instances, after ``_execute_prefetch_queries`` has run.
     """
     _, rows = await self.db.execute_query(query.get_sql())
     extra_fields = custom_fields or ()
     instances = []
     for row in rows:
         obj: "Model" = self.model._init_from_db(**row)
         for name in extra_fields:
             setattr(obj, name, row[name])
         instances.append(obj)
     await self._execute_prefetch_queries(instances)
     return instances
Ejemplo n.º 3
0
    def query_single_val(self, q: Query) -> Any:
        """Execute ``q`` and return a single value via ``fetchval()``.

        Parameters
        ----------
        q : Query
            Query object; it is rendered with ``get_sql()`` before execution.

        Returns
        -------
        Any
            Whatever the cursor's ``fetchval()`` returns (presumably the
            first column of the first row, as in pyodbc -- confirm for the
            driver in use).
        """
        sql_text = q.get_sql()
        executed = self.cursor.execute(sql_text)
        return executed.fetchval()
Ejemplo n.º 4
0
 async def execute_select(self,
                          query: Query,
                          custom_fields: Optional[list] = None) -> list:
     """Run ``query`` and build one ``self.model`` instance per result row.

     When ``self.select_related_idx`` is non-empty, each row is treated as a
     sequence of column groups. The second element of its first entry is
     the number of leading columns that belong to ``self.model``. Every
     following entry ``(model, width, attr_name, parent_model)`` covers the
     next ``width`` columns. Each of those column names is split on ``"."``
     and its second piece is used as the keyword. The resulting object is
     set as ``attr_name`` on every previously built object that is an
     instance of ``parent_model``.

     :param query: query whose ``get_sql()`` text is executed.
     :param custom_fields: extra row keys copied onto each root instance.
     :return: the root instances, after prefetch queries have run.
     """
     _, raw_results = await self.db.execute_query(query.get_sql())
     instance_list = []
     for row in raw_results:
         if not self.select_related_idx:
             instance = self.model._init_from_db(**row)
         else:
             _, start, _, _ = self.select_related_idx[0]
             row_dict = dict(row)
             columns = list(row_dict.keys())
             cells = list(row_dict.values())
             root: "Model" = self.model._init_from_db(
                 **dict(zip(columns[:start], cells[:start])))
             built = [root]
             for rel_model, width, attr_name, parent_cls in self.select_related_idx[1:]:
                 stop = start + width
                 rel_kwargs = {
                     name.split(".")[1]: cell
                     for name, cell in zip(columns[start:stop], cells[start:stop])
                 }
                 rel_obj = rel_model._init_from_db(**rel_kwargs)
                 # Attach the related object to every already-built parent.
                 for candidate in built:
                     if isinstance(candidate, parent_cls):
                         setattr(candidate, attr_name, rel_obj)
                 built.append(rel_obj)
                 start = stop
             instance = root
         if custom_fields:
             for extra in custom_fields:
                 setattr(instance, extra, row[extra])
         instance_list.append(instance)
     await self._execute_prefetch_queries(instance_list)
     return instance_list
Ejemplo n.º 5
0
 async def execute_explain(self, query: Query) -> Any:
     """Run ``query`` prefixed with ``self.EXPLAIN_PREFIX``.

     :return: the element at index 1 of the ``execute_query`` result.
     """
     pieces = [self.EXPLAIN_PREFIX, query.get_sql()]
     result = await self.db.execute_query(" ".join(pieces))
     return result[1]
Ejemplo n.º 6
0
    async def get_list(self,
                       info: QueryInfo,
                       with_count=False,
                       *,
                       _perm=None) -> QueryResultRowList:
        """Select the rows of ``info.from_table`` described by ``info``.

        Runs the table's ``on_query`` and ``on_read`` hooks. It then builds a
        SELECT whose columns are the id followed by ``info.select_for_crud``,
        and applies the conditions, joins, ordering, limit and offset from
        ``info``. Each result row is wrapped in a ``QueryResultRow``.

        :param info: query description.
        :param with_count: when true, first run ``COUNT('1')`` over the
            filtered query (before ordering/limit/offset are applied) and
            store the result in ``rows_count`` of the returned list.
        :param _perm: permission object forwarded to the hooks.
        :return: the result rows. Callbacks that ``on_read`` appended to
            ``when_complete`` are awaited with this list before returning.
        """
        # hooks
        await info.from_table.on_query(info, _perm)
        when_complete = []
        await info.from_table.on_read(info, when_complete, _perm)

        model = self.mapping2model[info.from_table]

        # Selected columns: the id always comes first so each result row can
        # be split into (id, remaining values) below.
        select_fields = [model.id]
        for i in info.select_for_crud:
            select_fields.append(getattr(self.mapping2model[i.table], i.name))

        q = Query().from_(model).select(*select_fields)
        phg = self.get_placeholder_generator()

        # Build the WHERE clause (and joins).
        # NOTE(review): joins are only applied when ``info.conditions`` is
        # truthy -- kept as-is; confirm this is intended.
        if info.conditions:

            def combine_and(items):
                # ``None`` means "no constraint" (TRUE); it can be dropped
                # from a conjunction instead of becoming a None operand.
                items = [x for x in items if x is not None]
                if items:
                    return reduce(ComplexCriterion.__and__, items)
                return None

            def combine_or(items):
                # A ``None`` ("no constraint") operand makes the whole
                # disjunction unconstrained.
                if not items or any(x is None for x in items):
                    return None
                return reduce(ComplexCriterion.__or__, items)

            def solve_expr(c):
                # Translate a single ConditionExpr into a criterion.
                field = getattr(self.mapping2model[c.column.table],
                                c.column.name)

                if isinstance(c.value, RecordMappingField):
                    # Column-to-column comparison: no placeholder needed.
                    real_value = getattr(self.mapping2model[c.value.table],
                                         c.value.name)
                else:
                    contains_relation = c.op in (
                        QUERY_OP_RELATION.CONTAINS,
                        QUERY_OP_RELATION.CONTAINS_ANY)
                    value = c.value
                    if c.op in (QUERY_OP_RELATION.PREFIX,
                                QUERY_OP_RELATION.IPREFIX):
                        # Prefix match: strip user-supplied '%' and append a
                        # single trailing one. Work on a local copy so the
                        # caller's condition object is not mutated.
                        # TODO: a better safety mechanism against LIKE abuse
                        # ('_' is still a single-character wildcard).
                        value = value.replace('%', '') + '%'
                    real_value = phg.next(
                        value, contains_relation=contains_relation)

                if c.op == QUERY_OP_RELATION.PREFIX:
                    return field.like(real_value)
                if c.op == QUERY_OP_RELATION.IPREFIX:
                    return field.ilike(real_value)
                if c.op == QUERY_OP_RELATION.IS:
                    return BasicCriterion(ArrayMatchingExt.is_, field,
                                          field.wrap_constant(real_value))
                if c.op == QUERY_OP_RELATION.IS_NOT:
                    return BasicCriterion(ArrayMatchingExt.is_not, field,
                                          field.wrap_constant(real_value))
                if c.op == QUERY_OP_RELATION.CONTAINS_ANY:
                    # SQL operator: &&
                    return BasicCriterion(ArrayMatchingExt.contains_any,
                                          field,
                                          field.wrap_constant(real_value))
                return getattr(field, _sql_method_map[c.op])(real_value)

            def solve_condition(c):
                # Returns a criterion, or None when ``c`` imposes no constraint.
                if isinstance(c, QueryConditions):
                    # A condition list is always AND-ed. The explicit reduce is
                    # used because (original note) some ORMs build incorrect
                    # SQL when combining join conditions.
                    return combine_and([solve_condition(x) for x in c.items])

                if isinstance(c, ConditionLogicExpr):
                    items = [solve_condition(x) for x in c.items]
                    if c.type == 'and':
                        return combine_and(items)
                    return combine_or(items)

                if isinstance(c, ConditionExpr):
                    return solve_expr(c)

                if isinstance(c, NegatedExpr):
                    return ~solve_condition(c.expr)

                return None

            if info.join:
                for ji in info.join:
                    jtable = self.mapping2model[ji.table]
                    where = solve_condition(ji.conditions)

                    if ji.limit == -1:
                        q = q.inner_join(jtable).on(where)
                    else:
                        # Join only the ids returned by a limited sub-query.
                        q = q.inner_join(jtable).on(
                            jtable.id == Query.from_(jtable).select(
                                jtable.id).where(where).limit(ji.limit))

            where_cond = solve_condition(info.conditions)
            if where_cond:
                q = q.where(where_cond)

        ret = QueryResultRowList()

        # count
        if with_count:
            # Temporarily swap the select list for COUNT, reusing the same
            # WHERE/JOIN clauses and placeholder values.
            bak = q._selects
            q._selects = [Count('1')]
            cursor = await self.execute_sql(q.get_sql(), phg)
            ret.rows_count = next(iter(cursor))[0]
            q._selects = bak

        # Ordering, limit and offset
        if info.order_by:
            order_dict = {
                'default': None,
                'desc': Order.desc,
                'asc': Order.asc
            }
            for i in info.order_by:
                q = q.orderby(i.column.name, order=order_dict[i.order])
        if info.limit != -1:
            q = q.limit(info.limit)
        q = q.offset(info.offset)

        # Query results: first value of each row is the id, the rest follow
        # the order of ``info.select_for_crud``.
        cursor = await self.execute_sql(q.get_sql(), phg)

        for row in cursor:
            it = iter(row)
            ret.append(
                QueryResultRow(next(it), list(it), info, info.from_table))

        for callback in when_complete:
            await callback(ret)

        return ret
Ejemplo n.º 7
0
    async def update(self,
                     info: QueryInfo,
                     values: ValuesToWrite,
                     *,
                     _perm=None) -> IDList:
        """Update every row of ``info.from_table`` matched by ``info``.

        ``on_query`` is called with ``info``. ``on_update`` may append
        callbacks to two lists: ``when_before_update`` callbacks are awaited
        with the matched ids before the UPDATE, and ``when_complete``
        callbacks are awaited without arguments afterwards.

        Each key of ``values`` becomes one SET clause. ``values.data_flag``
        may mark a key as an in-place operation (increment, decrement, array
        extend/prune, ...) instead of a plain assignment.

        :param info: query describing which rows to update.
        :param values: column -> new value (or operand for flagged columns).
        :param _perm: permission object forwarded to the hooks and ``get_list``.
        :return: ids of the matched rows.
        :raises ValueError: if a column carries a data flag this method cannot
            translate into SQL. Previously such a column was silently skipped
            although its placeholder had already been allocated.
        """
        # hooks
        # NOTE(review): get_list() below also calls on_query on the clone, so
        # the hook runs twice per update -- confirm this is intended.
        await info.from_table.on_query(info, _perm)
        when_before_update, when_complete = [], []
        await info.from_table.on_update(info, values, when_before_update,
                                        when_complete, _perm)

        model = self.mapping2model[info.from_table]
        tc = self._table_cache[info.from_table]
        qi = info.clone()
        qi.select = []
        lst = await self.get_list(qi, _perm=_perm)
        id_lst = [x.id for x in lst]

        for i in when_before_update:
            await i(id_lst)

        if id_lst:

            def set_value(column, val, vflag):
                # Right-hand side of ``SET column = ...`` for one column.
                if not vflag:
                    return val
                col = PypikaField(column)
                if vflag == ValuesDataFlag.INCR:
                    # column + val
                    return ArithmeticExpression(Arithmetic.add, col, val)
                if vflag == ValuesDataFlag.DECR:
                    # column - val
                    return ArithmeticExpression(Arithmetic.sub, col, val)
                if vflag == ValuesDataFlag.ARRAY_EXTEND:
                    # column || val
                    return ArithmeticExpression(ArithmeticExt.concat, col, val)
                if vflag in (ValuesDataFlag.ARRAY_PRUNE,
                             ValuesDataFlag.ARRAY_PRUNE_DISTINCT):
                    # array(SELECT unnest(column) EXCEPT SELECT unnest(val))
                    # TODO: plain PRUNE also de-duplicates today, which is wrong.
                    return PostgresArrayDifference(col, val)
                if vflag == ValuesDataFlag.ARRAY_EXTEND_DISTINCT:
                    # ARRAY(SELECT DISTINCT unnest(column || val))
                    return PostgresArrayDistinct(
                        ArithmeticExpression(ArithmeticExt.concat, col, val))
                # Failing loudly here avoids emitting SQL whose placeholders no
                # longer line up with the values registered in ``phg``.
                raise ValueError(
                    f'unsupported data flag {vflag!r} for column {column!r}')

            phg = self.get_placeholder_generator()
            sql = Query().update(model)
            for k, v in values.items():
                vflag = values.data_flag.get(k)
                val = phg.next(v,
                               left_is_array=k in tc['array_fields'],
                               left_is_json=k in tc['json_fields'])
                sql = sql.set(k, set_value(k, val, vflag))

            # Note: the generated SQL must keep the same order as the values
            # registered with phg (SET clauses first, then the id filter).
            sql = sql.where(model.id.isin(phg.next(id_lst)))
            await self.execute_sql(sql.get_sql(), phg)

        for i in when_complete:
            await i()

        return id_lst
Ejemplo n.º 8
0
 def set_df_equiptype(self):
     """Load the ``EquipType`` table into ``self.df_equiptype``.

     The DataFrame is indexed by its ``Model`` column, which is also kept as
     a regular column (``drop=False``).
     """
     equip = T('EquipType')
     select_all = Query().from_(equip).select(equip.star)
     frame = pd.read_sql(sql=select_all.get_sql(), con=self.engine)
     self.df_equiptype = frame.set_index('Model', drop=False)