Exemple #1
0
def wrap_script_stmt(
    stmt: pgast.SelectStmt,
    *,
    suppress_all_output: bool = False,
    env: context.Environment,
) -> pgast.SelectStmt:
    """Wrap *stmt* in an outer query that counts the rows it produces.

    *stmt* is used as a subquery, and the wrapper selects ``count()`` of
    its first target column.  If that target has no name, it is replaced
    by a ResTarget named with a generated alias so it can be referenced.

    With *suppress_all_output*, the counting query is wrapped once more
    in a query with an empty target list.  That query only keeps rows
    whose count column IS NULL.

    The CTEs and argument names of *stmt* move to the returned query,
    and ``stmt.ctes`` is reset to an empty list.
    """
    inner_alias = env.aliases.get('aggw')
    inner_rvar = pgast.RangeSubselect(
        subquery=stmt, alias=pgast.Alias(aliasname=inner_alias))

    first_target = stmt.target_list[0]
    if first_target.name is None:
        # Give the first output column a name so the wrapper can refer
        # to it.
        first_target = pgast.ResTarget(
            name=env.aliases.get('v'),
            val=first_target.val,
        )
        stmt.target_list[0] = first_target
        assert first_target.name is not None

    col_name = first_target.name
    counted = pgast.FuncCall(
        name=('count', ),
        args=[pgast.ColumnRef(name=[col_name])],
    )
    result = pgast.SelectStmt(
        target_list=[pgast.ResTarget(val=counted, name=col_name)],
        from_clause=[inner_rvar],
    )

    if suppress_all_output:
        outer_rvar = pgast.RangeSubselect(
            subquery=result,
            alias=pgast.Alias(aliasname=env.aliases.get('q')))
        null_check = pgast.NullTest(
            arg=pgast.ColumnRef(
                name=[outer_rvar.alias.aliasname, col_name]))
        result = pgast.SelectStmt(
            target_list=[],
            from_clause=[outer_rvar],
            where_clause=null_check,
        )

    result.ctes = stmt.ctes
    result.argnames = stmt.argnames
    stmt.ctes = []

    return result
Exemple #2
0
def rvar_for_rel(
    rel: Union[pgast.BaseRelation, pgast.CommonTableExpr],
    *,
    typeref: Optional[irast.TypeRef] = None,
    lateral: bool = False,
    colnames: Optional[List[str]] = None,
    ctx: context.CompilerContextLevel,
) -> pgast.PathRangeVar:
    """Return a range var exposing *rel* under an alias from ``ctx.env``.

    A query is wrapped in a ``RangeSubselect``, which honours *lateral*.
    Anything else (a base relation or a CTE) becomes a ``RelRangeVar``;
    *lateral* is not applied in that case.

    The alias hint is the relation's name.  For an unnamed query the
    hint falls back to ``'q'``.  *colnames* defaults to an empty list.
    """
    if colnames is None:
        colnames = []

    aliases = ctx.env.aliases

    if isinstance(rel, pgast.Query):
        return pgast.RangeSubselect(
            subquery=rel,
            alias=pgast.Alias(
                aliasname=aliases.get(rel.name or 'q'), colnames=colnames),
            lateral=lateral,
            typeref=typeref,
        )

    return pgast.RelRangeVar(
        relation=rel,
        alias=pgast.Alias(aliasname=aliases.get(rel.name), colnames=colnames),
        typeref=typeref,
    )
Exemple #3
0
def range_from_queryset(
        set_ops: typing.Sequence[typing.Tuple[str, pgast.BaseRelation]],
        objname: sn.Name, *,
        env: context.Environment) -> pgast.BaseRangeVar:
    """Build a range var covering every relation in *set_ops*.

    With a single entry, the first FROM item of that entry's query is
    returned as is.  Otherwise the entries are combined into a left-deep
    chain of ``ALL`` set operations, using the operator paired with each
    right-hand entry.  The chain is wrapped in a ``RangeSubselect``
    aliased after ``objname.name``.
    """
    if len(set_ops) <= 1:
        # Just one class table, so hand back its range var directly.
        return set_ops[0][1].from_clause[0]

    # Each step fills in the operator and right operand of the current
    # node, then opens a new node on top of it.  The last, still-empty
    # node is discarded by taking its left argument below.
    node = pgast.SelectStmt(all=True, larg=set_ops[0][1])
    for set_op, right in set_ops[1:]:
        node.op = set_op
        node.rarg = right
        node = pgast.SelectStmt(all=True, larg=node)

    return pgast.RangeSubselect(
        subquery=node.larg,
        alias=pgast.Alias(aliasname=env.aliases.get(objname.name)),
    )
Exemple #4
0
def range_from_queryset(
        set_ops: Sequence[Tuple[str, pgast.SelectStmt]], objname: sn.Name, *,
        ctx: context.CompilerContextLevel) -> pgast.PathRangeVar:
    """Build a path range var covering every query in *set_ops*.

    A single entry yields the first FROM item of its query.  That item
    is asserted to be a ``PathRangeVar``.  Several entries are folded
    into a left-deep chain of ``ALL`` set operations.  The chain is
    wrapped in a ``RangeSubselect`` aliased after ``objname.name``.
    """
    if len(set_ops) <= 1:
        # Just one class table, so hand back its range var directly.
        single_rvar = set_ops[0][1].from_clause[0]
        assert isinstance(single_rvar, pgast.PathRangeVar)
        return single_rvar

    # Fill in op/rarg on the current node, then start a new node above it.
    # The topmost node stays empty and is skipped via ``.larg`` below.
    node = pgast.SelectStmt(all=True, larg=set_ops[0][1])
    for set_op, right in set_ops[1:]:
        node.op = set_op
        node.rarg = right
        node = pgast.SelectStmt(all=True, larg=node)

    return pgast.RangeSubselect(
        subquery=node.larg,
        alias=pgast.Alias(aliasname=ctx.env.aliases.get(objname.name), ))
Exemple #5
0
def aggregate_json_output(stmt: pgast.Query, ir_set: irast.Set, *,
                          env: context.Environment) -> pgast.Query:
    """Collapse the rows of *stmt* into one JSON array value.

    The first target of *stmt* gets a generated name if it has none.  It
    is then aggregated with the function returned by
    ``_get_json_func('agg', env=env)``, and COALESCE substitutes ``'[]'``
    when the aggregate yields NULL.

    The CTEs and argument names of *stmt* move to the returned query,
    and ``stmt.ctes`` is reset to an empty list.  *ir_set* is not used
    here.
    """
    wrapped = pgast.RangeSubselect(
        subquery=stmt, alias=pgast.Alias(aliasname=env.aliases.get('aggw')))

    head = stmt.target_list[0]
    if head.name is None:
        head = pgast.ResTarget(
            name=env.aliases.get('v'),
            val=head.val,
        )
        stmt.target_list[0] = head

    agg_call = pgast.FuncCall(
        name=_get_json_func('agg', env=env),
        args=[pgast.ColumnRef(name=[head.name])],
    )
    json_array = pgast.CoalesceExpr(
        args=[agg_call, pgast.StringConstant(val='[]')])

    result = pgast.SelectStmt(
        target_list=[pgast.ResTarget(val=json_array)],
        from_clause=[wrapped],
    )

    result.ctes, result.argnames = stmt.ctes, stmt.argnames
    stmt.ctes = []

    return result
Exemple #6
0
def wrap_script_stmt(
    stmt: pgast.SelectStmt,
    ir_set: irast.Set,
    *,
    env: context.Environment,
) -> pgast.SelectStmt:
    """Wrap *stmt* in a query returning the number of rows it yields.

    The first target of *stmt* gets a generated name if it has none, so
    the wrapper can reference it as ``count(<name>)``.  The CTEs and
    argument names of *stmt* move to the returned query, and
    ``stmt.ctes`` is reset to an empty list.

    :param stmt: the statement whose rows are counted.
    :param ir_set: not used by this function.
    :param env: compilation environment providing alias generation.
    :returns: a new top-level SELECT yielding the row count.
    """
    subrvar = pgast.RangeSubselect(
        subquery=stmt, alias=pgast.Alias(aliasname=env.aliases.get('aggw')))

    stmt_res = stmt.target_list[0]

    if stmt_res.name is None:
        stmt_res = stmt.target_list[0] = pgast.ResTarget(
            name=env.aliases.get('v'),
            val=stmt_res.val,
        )

    # No trailing comma after this call: one would turn count_val into a
    # 1-tuple and place a tuple, not a FuncCall node, in ResTarget.val.
    count_val = pgast.FuncCall(name=('count', ),
                               args=[pgast.ColumnRef(name=[stmt_res.name])])

    result = pgast.SelectStmt(target_list=[pgast.ResTarget(val=count_val, )],
                              from_clause=[subrvar])

    result.ctes = stmt.ctes
    result.argnames = stmt.argnames
    stmt.ctes = []

    return result
Exemple #7
0
def top_output_as_config_op(
        ir_set: irast.Set,
        stmt: pgast.SelectStmt, *,
        env: context.Environment) -> pgast.Query:
    """Turn *stmt* into a query yielding a JSON-encoded CONFIGURE ADD op.

    The output is ``jsonb_build_array('ADD', <scope>, <setting name>,
    <value>)``, where ``<value>`` is the first target column of *stmt*.
    That column is named with a generated alias if necessary.  The CTEs
    and argument names of *stmt* move to the returned query.

    :raises errors.InternalServerError: if the config command is not
        SYSTEM-scoped.
    """
    assert isinstance(ir_set.expr, irast.ConfigCommand)
    cmd = ir_set.expr

    if cmd.scope is not qltypes.ConfigScope.SYSTEM:
        raise errors.InternalServerError(
            f'CONFIGURE {cmd.scope} INSERT is not supported')

    subrvar = pgast.RangeSubselect(
        subquery=stmt,
        alias=pgast.Alias(aliasname=env.aliases.get('cfg')),
    )

    value_target = stmt.target_list[0]
    if value_target.name is None:
        value_target = pgast.ResTarget(
            name=env.aliases.get('v'),
            val=value_target.val,
        )
        stmt.target_list[0] = value_target

    op_row = pgast.RowExpr(args=[
        pgast.StringConstant(val='ADD'),
        pgast.StringConstant(val=str(cmd.scope)),
        pgast.StringConstant(val=cmd.name),
        pgast.ColumnRef(name=[value_target.name]),
    ])
    json_array = pgast.FuncCall(
        name=('jsonb_build_array',),
        args=op_row.args,
        null_safe=True,
        ser_safe=True,
    )

    result = pgast.SelectStmt(
        target_list=[pgast.ResTarget(val=json_array)],
        from_clause=[subrvar],
    )

    result.ctes = stmt.ctes
    result.argnames = stmt.argnames
    stmt.ctes = []

    return result
Exemple #8
0
def top_output_as_config_op(
        ir_set: irast.Set,
        stmt: pgast.Query, *,
        env: context.Environment) -> pgast.Query:
    """Turn *stmt* into a query yielding a JSON-encoded CONFIGURE ADD op.

    The output is ``jsonb_build_array('ADD', 'SYSTEM', <setting name>,
    <value>)``, where ``<value>`` is the first target column of *stmt*.
    That column is named with a generated alias if necessary.

    The returned query replaces *stmt* as the top-level statement.  The
    CTEs and argument names of *stmt* are therefore moved onto it, and
    ``stmt.ctes`` is reset to an empty list.  This matches the other
    output wrappers (e.g. ``aggregate_json_output``).

    :raises errors.InternalServerError: for non-SYSTEM (session) commands.
    """
    if not ir_set.expr.system:
        raise errors.InternalServerError(
            'CONFIGURE SESSION INSERT is not supported')

    alias = env.aliases.get('cfg')
    subrvar = pgast.RangeSubselect(
        subquery=stmt,
        alias=pgast.Alias(
            aliasname=alias,
        )
    )

    stmt_res = stmt.target_list[0]

    if stmt_res.name is None:
        stmt_res = stmt.target_list[0] = pgast.ResTarget(
            name=env.aliases.get('v'),
            val=stmt_res.val,
        )

    result_row = pgast.RowExpr(
        args=[
            pgast.StringConstant(val='ADD'),
            # Only SYSTEM-scoped commands get past the guard above, so
            # the scope label is a constant here.
            pgast.StringConstant(val='SYSTEM'),
            pgast.StringConstant(val=ir_set.expr.name),
            pgast.ColumnRef(name=[stmt_res.name]),
        ]
    )

    array = pgast.FuncCall(
        name=('jsonb_build_array',),
        args=result_row.args,
        null_safe=True,
        ser_safe=True,
    )

    result = pgast.SelectStmt(
        target_list=[
            pgast.ResTarget(
                val=array,
            ),
        ],
        from_clause=[
            subrvar,
        ],
    )

    # Hoist CTEs and argument names so they belong to the new top-level
    # statement rather than being left on the wrapped subquery.
    result.ctes = stmt.ctes
    result.argnames = stmt.argnames
    stmt.ctes = []

    return result
Exemple #9
0
def rvar_for_rel(
        rel: pgast.BaseRelation, *,
        lateral: bool = False,
        colnames: typing.Optional[typing.List[str]] = None,
        env: context.Environment) -> pgast.BaseRangeVar:
    """Return a range var exposing *rel* under an alias from ``env``.

    Queries are wrapped in a ``RangeSubselect`` honouring *lateral*.
    Other relations become a ``RangeVar``, and *lateral* is not applied
    to them.  The alias hint is ``rel.name``, with ``'q'`` used for an
    unnamed query.

    :param colnames: column aliases for the range var.  Defaults to a
        new empty list on every call.  The previous ``=[]`` default was
        one list object shared by every Alias created with the default.
    """
    if colnames is None:
        colnames = []

    if isinstance(rel, pgast.Query):
        alias = env.aliases.get(rel.name or 'q')

        rvar = pgast.RangeSubselect(
            subquery=rel,
            alias=pgast.Alias(aliasname=alias, colnames=colnames),
            lateral=lateral,
        )
    else:
        alias = env.aliases.get(rel.name)

        rvar = pgast.RangeVar(
            relation=rel,
            alias=pgast.Alias(aliasname=alias, colnames=colnames)
        )

    return rvar
Exemple #10
0
def range_from_queryset(
    set_ops: Sequence[Tuple[str, pgast.SelectStmt]],
    objname: sn.Name,
    *,
    path_id: Optional[irast.PathId] = None,
    ctx: context.CompilerContextLevel,
) -> pgast.PathRangeVar:
    """Combine the queries in *set_ops* into a single path range var.

    A single entry yields the first FROM item of its query, asserted to
    be a ``PathRangeVar``.  With several entries they are folded left to
    right:

    * ``'filter'`` entries wrap the accumulated query with
      ``wrap_set_op_query`` and then apply ``anti_join`` against the
      entry (using *path_id*);
    * every other operator builds an ``ALL`` set operation node.

    The result is wrapped in a ``RangeSubselect`` aliased after
    ``objname.name``.
    """
    if len(set_ops) <= 1:
        # Just one class table, so hand back its range var directly.
        single_rvar = set_ops[0][1].from_clause[0]
        assert isinstance(single_rvar, pgast.PathRangeVar)
        return single_rvar

    combined = set_ops[0][1]
    for set_op, right in set_ops[1:]:
        if set_op != 'filter':
            combined = pgast.SelectStmt(
                op=set_op,
                all=True,
                larg=combined,
                rarg=right,
            )
            continue

        combined = wrap_set_op_query(combined, ctx=ctx)
        anti_join(combined, right, path_id, ctx=ctx)

    return pgast.RangeSubselect(
        subquery=combined,
        alias=pgast.Alias(aliasname=ctx.env.aliases.get(objname.name), ))
Exemple #11
0
def rvar_for_rel(
        rel: typing.Union[pgast.BaseRelation, pgast.CommonTableExpr], *,
        lateral: bool = False,
        colnames: typing.Optional[typing.List[str]] = None,
        ctx: context.CompilerContextLevel) -> pgast.PathRangeVar:
    """Return a range var exposing *rel* under an alias from ``ctx.env``.

    Queries are wrapped in a ``RangeSubselect`` honouring *lateral*.
    Base relations and CTEs become a ``RelRangeVar``, and *lateral* is
    not applied to them.  The alias hint is ``rel.name``, with ``'q'``
    used for an unnamed query.

    :param colnames: column aliases for the range var.  Defaults to a
        new empty list on every call.  A ``=[]`` default would be one
        list object shared by every Alias created with the default.
    """
    if colnames is None:
        colnames = []

    rvar: pgast.PathRangeVar

    if isinstance(rel, pgast.Query):
        alias = ctx.env.aliases.get(rel.name or 'q')

        rvar = pgast.RangeSubselect(
            subquery=rel,
            alias=pgast.Alias(aliasname=alias, colnames=colnames),
            lateral=lateral,
        )
    else:
        alias = ctx.env.aliases.get(rel.name)

        rvar = pgast.RelRangeVar(
            relation=rel,
            alias=pgast.Alias(aliasname=alias, colnames=colnames)
        )

    return rvar
Exemple #12
0
        for op, cte, cte_path_id in overlays:
            rvar = pgast.RelRangeVar(
                relation=cte,
                typeref=typeref,
                alias=pgast.Alias(aliasname=env.aliases.get(hint=cte.name)))

            qry = pgast.SelectStmt(from_clause=[rvar], )

            pathctx.put_path_value_rvar(qry, cte_path_id, rvar, env=env)
            if path_id.is_objtype_path():
                pathctx.put_path_source_rvar(qry, cte_path_id, rvar, env=env)
            pathctx.put_path_bond(qry, cte_path_id)
            pathctx.put_path_id_map(qry, path_id, cte_path_id)

            qry_rvar = pgast.RangeSubselect(
                subquery=qry,
                alias=pgast.Alias(aliasname=env.aliases.get(hint=cte.name)))

            qry2 = pgast.SelectStmt(from_clause=[qry_rvar])
            pathctx.put_path_value_rvar(qry2, path_id, qry_rvar, env=env)
            if path_id.is_objtype_path():
                pathctx.put_path_source_rvar(qry2, path_id, qry_rvar, env=env)
            pathctx.put_path_bond(qry2, path_id)

            if op == 'replace':
                op = 'union'
                set_ops = []
            set_ops.append((op, qry2))

        rvar = range_from_queryset(set_ops,
                                   typeref.name_hint,
Exemple #13
0
def process_link_values(
        ir_stmt, ir_expr, target_tab, col_data,
        dml_rvar, sources, props_only, target_is_scalar, iterator_cte, *,
        ctx: context.CompilerContextLevel) -> \
        typing.Tuple[pgast.CommonTableExpr, typing.List[str]]:
    """Unpack data from an update expression into a series of selects.

    :param ir_stmt:
        The DML statement being compiled.  Its subject path is bound to
        *dml_rvar* in the generated row query.
    :param ir_expr:
        IR of the INSERT/UPDATE body element.
    :param target_tab:
        The link table being updated.  (Not referenced in this function.)
    :param col_data:
        Expressions used to populate well-known columns of the link
        table such as `source` and `__type__`.
    :param dml_rvar:
        Range var of the DML subject, included into the row query.
    :param sources:
        A list of relations which must be joined into the data query
        to resolve expressions in *col_data*.
    :param props_only:
        Whether this link update only touches link properties.
        (Not referenced in this function.)
    :param target_is_scalar:
        Whether the link target is a ScalarType.
    :param iterator_cte:
        CTE representing the iterator range in the FOR clause of the
        EdgeQL DML statement.
    :param ctx:
        The current compiler context level.  This was previously written
        ``ctx=context.CompilerContext``, which made the context class
        itself the default value instead of annotating the parameter.
    :returns:
        A ``(link_rows, specified_cols)`` pair: a CTE wrapping the row
        query, and the names of the target columns it populates.
    """
    with ctx.newscope() as newscope, newscope.newrel() as subrelctx:
        row_query = subrelctx.rel

        relctx.include_rvar(row_query,
                            dml_rvar,
                            path_id=ir_stmt.subject.path_id,
                            ctx=subrelctx)
        subrelctx.path_scope[ir_stmt.subject.path_id] = row_query

        if iterator_cte is not None:
            iterator_rvar = relctx.rvar_for_rel(iterator_cte,
                                                lateral=True,
                                                ctx=subrelctx)
            relctx.include_rvar(row_query,
                                iterator_rvar,
                                path_id=iterator_cte.query.path_id,
                                ctx=subrelctx)

        with subrelctx.newscope() as sctx, sctx.subrel() as input_rel_ctx:
            input_rel = input_rel_ctx.rel
            if iterator_cte is not None:
                # Resolve the FOR iterator path against the outer row
                # query rather than re-joining it in the subquery.
                input_rel_ctx.path_scope[iterator_cte.query.path_id] = \
                    row_query
            input_rel_ctx.expr_exposed = False
            input_rel_ctx.volatility_ref = pathctx.get_path_identity_var(
                row_query, ir_stmt.subject.path_id, env=input_rel_ctx.env)
            dispatch.visit(ir_expr, ctx=input_rel_ctx)
            shape_tuple = None
            if ir_expr.shape:
                shape_tuple = shapecomp.compile_shape(ir_expr,
                                                      ir_expr.shape,
                                                      ctx=input_rel_ctx)

                for element in shape_tuple.elements:
                    pathctx.put_path_var_if_not_exists(input_rel_ctx.rel,
                                                       element.path_id,
                                                       element.val,
                                                       aspect='value',
                                                       env=input_rel_ctx.env)

    # The value subquery is joined laterally so it can reference the
    # subject (and iterator) range vars of the row query.
    input_rvar = pgast.RangeSubselect(
        subquery=input_rel,
        lateral=True,
        alias=pgast.Alias(aliasname=ctx.env.aliases.get('val')))

    source_data: typing.Dict[str, pgast.BaseExpr] = {}

    path_id = ir_expr.path_id

    if shape_tuple is not None:
        # With a shape, only link-property elements become columns.
        for element in shape_tuple.elements:
            if not element.path_id.is_linkprop_path():
                continue
            colname = element.path_id.rptr_name().name
            val = pathctx.get_rvar_path_value_var(input_rvar,
                                                  element.path_id,
                                                  env=ctx.env)
            source_data.setdefault(colname, val)
    else:
        if target_is_scalar:
            target_ref = pathctx.get_rvar_path_value_var(input_rvar,
                                                         path_id,
                                                         env=ctx.env)
        else:
            target_ref = pathctx.get_rvar_path_identity_var(input_rvar,
                                                            path_id,
                                                            env=ctx.env)

        source_data['target'] = target_ref

    if not target_is_scalar and 'target' not in source_data:
        target_ref = pathctx.get_rvar_path_identity_var(input_rvar,
                                                        path_id,
                                                        env=ctx.env)
        source_data['target'] = target_ref

    # On a name collision, the value from col_data takes precedence over
    # source_data (ChainMap looks up maps in order).
    specified_cols = []
    for col, expr in collections.ChainMap(col_data, source_data).items():
        row_query.target_list.append(pgast.ResTarget(val=expr, name=col))
        specified_cols.append(col)

    row_query.from_clause += list(sources) + [input_rvar]

    link_rows = pgast.CommonTableExpr(query=row_query,
                                      name=ctx.env.aliases.get(hint='r'))

    return link_rows, specified_cols
Exemple #14
0
def process_link_values(
    *,
    ir_stmt: irast.MutatingStmt,
    ir_expr: irast.Set,
    target_tab: Tuple[str, ...],
    col_data: Mapping[str, pgast.BaseExpr],
    dml_rvar: pgast.PathRangeVar,
    sources: Iterable[pgast.BaseRangeVar],
    props_only: bool,
    target_is_scalar: bool,
    iterator_cte: Optional[pgast.CommonTableExpr],
    ctx: context.CompilerContextLevel,
) -> Tuple[pgast.CommonTableExpr, List[str]]:
    """Unpack data from an update expression into a series of selects.

    The function builds a row query that joins the DML subject
    (*dml_rvar*), the optional FOR iterator CTE, *sources*, and a lateral
    subquery computing *ir_expr*.  Its target list has one column per key
    of *col_data* plus the link target and link-property values.

    :param ir_stmt:
        The DML statement.  Its subject path is bound to *dml_rvar* in
        the row query.
    :param ir_expr:
        IR of the INSERT/UPDATE body element.
    :param target_tab:
        The link table being updated.  NOTE(review): not referenced in
        this function body.
    :param col_data:
        Expressions used to populate well-known columns of the link
        table such as `source` and `__type__`.
    :param dml_rvar:
        Range var of the DML subject, included into the row query.
    :param sources:
        A list of relations which must be joined into the data query
        to resolve expressions in *col_data*.
    :param props_only:
        Whether this link update only touches link properties.
        NOTE(review): not referenced in this function body.
    :param target_is_scalar:
        Whether the link target is a ScalarType.
    :param iterator_cte:
        CTE representing the iterator range in the FOR clause of the
        EdgeQL DML statement.
    :param ctx:
        The current compiler context level.
    :return:
        A ``(link_rows, specified_cols)`` pair: a CTE wrapping the row
        query, and the names of the target columns it populates.
    """
    with ctx.newscope() as newscope, newscope.newrel() as subrelctx:
        row_query = subrelctx.rel

        relctx.include_rvar(row_query,
                            dml_rvar,
                            path_id=ir_stmt.subject.path_id,
                            ctx=subrelctx)
        subrelctx.path_scope[ir_stmt.subject.path_id] = row_query

        if iterator_cte is not None:
            iterator_rvar = relctx.rvar_for_rel(iterator_cte,
                                                lateral=True,
                                                ctx=subrelctx)
            relctx.include_rvar(row_query,
                                iterator_rvar,
                                path_id=iterator_cte.query.path_id,
                                ctx=subrelctx)

        with subrelctx.newscope() as sctx, sctx.subrel() as input_rel_ctx:
            input_rel = input_rel_ctx.rel
            if iterator_cte is not None:
                # Resolve the outer FOR iterator path against the row
                # query instead of joining it again in the subquery.
                input_rel_ctx.path_scope[iterator_cte.query.path_id] = \
                    row_query
            input_rel_ctx.expr_exposed = False
            input_rel_ctx.volatility_ref = pathctx.get_path_identity_var(
                row_query, ir_stmt.subject.path_id, env=input_rel_ctx.env)
            dispatch.visit(ir_expr, ctx=input_rel_ctx)
            if (isinstance(ir_expr.expr, irast.Stmt)
                    and ir_expr.expr.iterator_stmt is not None):
                # The link value is computed by a FOR expression.
                # Check if the statement is a DML statement and, if so,
                # pull the iterator scope so that link property
                # expressions have the correct context.
                inner_iterator_cte = None
                inner_iterator_path_id = ir_expr.expr.iterator_stmt.path_id
                # Find the CTE (if any) that materializes the inner
                # iterator among the top-level statement's CTEs.
                for cte in input_rel_ctx.toplevel_stmt.ctes:
                    if cte.query.path_id == inner_iterator_path_id:
                        inner_iterator_cte = cte
                        break
                if inner_iterator_cte is not None:
                    # Make the inner iterator path resolve to the same
                    # range var as the link target's identity.
                    target_rvar = pathctx.get_path_rvar(
                        input_rel,
                        ir_expr.path_id,
                        aspect='identity',
                        env=input_rel_ctx.env,
                    )

                    pathctx.put_path_rvar(
                        input_rel,
                        inner_iterator_path_id,
                        rvar=target_rvar,
                        aspect='identity',
                        env=input_rel_ctx.env,
                    )

                    inner_iterator_rvar = relctx.rvar_for_rel(
                        inner_iterator_cte, lateral=True, ctx=subrelctx)

                    relctx.include_rvar(
                        input_rel,
                        inner_iterator_rvar,
                        path_id=inner_iterator_path_id,
                        ctx=subrelctx,
                    )

                    input_rel_ctx.path_scope[inner_iterator_path_id] = (
                        input_rel)

            shape_tuple = None
            if ir_expr.shape:
                shape_tuple = shapecomp.compile_shape(
                    ir_expr,
                    [expr for expr, _ in ir_expr.shape],
                    ctx=input_rel_ctx,
                )

                # Expose every shape element's value on the subquery so
                # it can be read back through input_rvar below.
                for element in shape_tuple.elements:
                    pathctx.put_path_var_if_not_exists(input_rel_ctx.rel,
                                                       element.path_id,
                                                       element.val,
                                                       aspect='value',
                                                       env=input_rel_ctx.env)

    # NOTE(review): input_stmt is computed here (and narrowed to the
    # right arm of a set operation below) but never read afterwards;
    # it looks like dead code.
    input_stmt: pgast.Query = input_rel

    # Lateral, so the value subquery can reference the subject and
    # iterator range vars of the row query.
    input_rvar = pgast.RangeSubselect(
        subquery=input_rel,
        lateral=True,
        alias=pgast.Alias(aliasname=ctx.env.aliases.get('val')))

    source_data: Dict[str, pgast.BaseExpr] = {}

    if isinstance(input_stmt, pgast.SelectStmt) and input_stmt.op is not None:
        # UNION
        input_stmt = input_stmt.rarg

    path_id = ir_expr.path_id

    if shape_tuple is not None:
        # With a shape, only link-property elements become columns,
        # named after the property.
        for element in shape_tuple.elements:
            if not element.path_id.is_linkprop_path():
                continue
            rptr_name = element.path_id.rptr_name()
            assert rptr_name is not None
            colname = rptr_name.name
            val = pathctx.get_rvar_path_value_var(input_rvar,
                                                  element.path_id,
                                                  env=ctx.env)
            source_data.setdefault(colname, val)
    else:
        # Without a shape, the whole expression is the link target:
        # its value for scalar targets, its identity for object targets.
        if target_is_scalar:
            target_ref = pathctx.get_rvar_path_value_var(input_rvar,
                                                         path_id,
                                                         env=ctx.env)
        else:
            target_ref = pathctx.get_rvar_path_identity_var(input_rvar,
                                                            path_id,
                                                            env=ctx.env)

        source_data['target'] = target_ref

    # Object links always need a target column, even when the shape
    # only supplied link properties.
    if not target_is_scalar and 'target' not in source_data:
        target_ref = pathctx.get_rvar_path_identity_var(input_rvar,
                                                        path_id,
                                                        env=ctx.env)
        source_data['target'] = target_ref

    # On a name collision, the value from col_data takes precedence over
    # source_data (ChainMap looks up maps in order).
    specified_cols = []
    for col, expr in collections.ChainMap(col_data, source_data).items():
        row_query.target_list.append(pgast.ResTarget(val=expr, name=col))
        specified_cols.append(col)

    row_query.from_clause += list(sources) + [input_rvar]

    link_rows = pgast.CommonTableExpr(query=row_query,
                                      name=ctx.env.aliases.get(hint='r'))

    return link_rows, specified_cols
Exemple #15
0
def range_for_material_objtype(
        typeref: irast.TypeRef,
        path_id: irast.PathId,
        *,
        include_overlays: bool = True,
        include_descendants: bool = True,
        dml_source: Optional[irast.MutatingStmt] = None,
        ctx: context.CompilerContextLevel) -> pgast.PathRangeVar:
    """Return a range var over the backend table of *typeref*.

    If *typeref* has a material type, that type is used.  When there are
    type-relation overlays (from ``get_type_rel_overlays``) and
    *include_overlays* is set, the table and each overlay CTE are
    combined through ``range_from_queryset``.  In that set-op list, a
    ``'replace'`` overlay discards all earlier entries and continues as
    a ``'union'``.  Otherwise the plain table range var is returned.

    :param include_descendants: passed as ``include_inherited`` to the
        table range var.
    :param dml_source: forwarded to ``get_type_rel_overlays``.
    """
    env = ctx.env

    if typeref.material_type is not None:
        typeref = typeref.material_type

    schema_name, table_name = common.get_objtype_backend_name(
        typeref.id, typeref.module_id, catenate=False)

    if typeref.name_hint.module in {'cfg', 'sys'}:
        # Point cfg/sys types at the 'edgedbss' backend schema.
        schema_name = 'edgedbss'

    table_rvar: pgast.PathRangeVar = pgast.RelRangeVar(
        relation=pgast.Relation(
            schemaname=schema_name,
            name=table_name,
            path_id=path_id,
        ),
        typeref=typeref,
        include_inherited=include_descendants,
        alias=pgast.Alias(aliasname=env.aliases.get(typeref.name_hint.name)))

    overlays = get_type_rel_overlays(typeref, dml_source=dml_source, ctx=ctx)
    if not overlays or not include_overlays:
        return table_rvar

    def expose(qry, pid, rvar):
        # Register *rvar* as the value rvar of *pid* in *qry*.  For object
        # paths it is also registered as the source rvar.  Then bond
        # *pid* to *qry*.
        pathctx.put_path_value_rvar(qry, pid, rvar, env=env)
        if path_id.is_objtype_path():
            pathctx.put_path_source_rvar(qry, pid, rvar, env=env)
        pathctx.put_path_bond(qry, pid)

    base_qry = pgast.SelectStmt()
    base_qry.from_clause.append(table_rvar)
    expose(base_qry, path_id, table_rvar)

    set_ops = [('union', base_qry)]

    for op, cte, cte_path_id in overlays:
        cte_rvar = pgast.RelRangeVar(
            relation=cte,
            typeref=typeref,
            alias=pgast.Alias(aliasname=env.aliases.get(hint=cte.name)))

        cte_qry = pgast.SelectStmt(from_clause=[cte_rvar], )
        expose(cte_qry, cte_path_id, cte_rvar)
        # Let references to path_id inside this branch resolve to the
        # overlay's own path id.
        cte_qry.view_path_id_map[path_id] = cte_path_id

        wrapper_rvar = pgast.RangeSubselect(
            subquery=cte_qry,
            alias=pgast.Alias(aliasname=env.aliases.get(hint=cte.name)))

        wrapper_qry = pgast.SelectStmt(from_clause=[wrapper_rvar])
        expose(wrapper_qry, path_id, wrapper_rvar)

        if op == 'replace':
            # A replacing overlay supersedes everything collected so far.
            set_ops = [('union', wrapper_qry)]
        else:
            set_ops.append((op, wrapper_qry))

    return range_from_queryset(set_ops, typeref.name_hint, ctx=ctx)
Exemple #16
0
def _compile_grouping_value(
        stmt: irast.GroupStmt, used_args: AbstractSet[str], *,
        ctx: context.CompilerContextLevel) -> pgast.BaseExpr:
    '''Build the value of the grouping binding: the keys actually grouped on.

    The value is an array of key names as produced by
    ``desugar_group.key_name``.  With exactly one used key it is a
    constant one-element array.  Otherwise it is a SELECT that calls
    ``grouping()`` over the key values and decodes the bitmask into
    names, dropping the NULL placeholders for keys not grouped on.
    '''
    assert stmt.grouping_binding
    grouprel = ctx.rel

    # A single key is always the one grouped on, so no runtime check is
    # needed: emit a constant one-element array.
    if len(used_args) == 1:
        (only_arg,) = used_args
        return pgast.ArrayExpr(elements=[
            pgast.StringConstant(val=desugar_group.key_name(only_arg)),
        ])

    using = {k: stmt.using[k] for k in used_args}
    key_count = len(using)

    key_values = [
        pathctx.get_path_var(
            grouprel, alias_set.path_id, aspect='value', env=ctx.env)
        for alias_set, _ in using.values()
    ]

    # grouping(<keys...>) yields a bitmask; select it into a named column
    # of a subquery so the CASE expressions below can reference it.
    grouping_alias = ctx.env.aliases.get('g')
    subq = pgast.SelectStmt(target_list=[
        pgast.ResTarget(
            name=grouping_alias,
            val=pgast.FuncCall(name=('grouping', ), args=key_values)),
    ])
    q = pgast.SelectStmt(from_clause=[
        pgast.RangeSubselect(
            subquery=subq, alias=pgast.Alias(aliasname=ctx.env.aliases.get()))
    ])

    grouping_ref = pgast.ColumnRef(name=(grouping_alias, ))

    def key_case(pos: int, key: str) -> pgast.BaseExpr:
        # grouping() gives its first argument the most significant bit, so
        # key *pos* maps to bit (key_count - pos - 1).  A 0 bit means the
        # key is part of the current grouping:
        #   CASE (g & <mask>) WHEN 0 THEN '<name>' ELSE NULL END
        bit = 1 << (key_count - pos - 1)
        return pgast.CaseExpr(
            arg=pgast.Expr(kind=pgast.ExprKind.OP,
                           name='&',
                           lexpr=grouping_ref,
                           rexpr=pgast.LiteralExpr(expr=str(bit))),
            args=[
                pgast.CaseWhen(
                    expr=pgast.LiteralExpr(expr='0'),
                    result=pgast.StringConstant(
                        val=desugar_group.key_name(key)))
            ],
            defresult=pgast.NullConstant())

    els: List[pgast.BaseExpr] = [
        key_case(pos, key) for pos, key in enumerate(using)
    ]

    # ARRAY[...] with one CASE per key, then strip the NULL entries.
    q.target_list.append(pgast.ResTarget(val=pgast.FuncCall(
        name=('array_remove', ),
        args=[pgast.ArrayExpr(elements=els), pgast.NullConstant()])))

    return q