Example #1
0
def wrap_script_stmt(
    stmt: pgast.SelectStmt,
    ir_set: irast.Set,
    *,
    env: context.Environment,
) -> pgast.SelectStmt:
    """Wrap *stmt* in ``SELECT count(<first target>) FROM (stmt)``.

    *stmt* is mutated: if its first target has no name, the target is
    replaced by an equivalent one with a generated name. Its CTEs are
    moved onto the returned wrapper statement.

    :param stmt: The statement to wrap; its first target column is counted.
    :param ir_set: Not used by this function; kept for call compatibility.
    :param env: Compiler environment used to generate aliases.
    :return: The new wrapping SELECT statement.
    """
    subrvar = pgast.RangeSubselect(
        subquery=stmt, alias=pgast.Alias(aliasname=env.aliases.get('aggw')))

    stmt_res = stmt.target_list[0]

    if stmt_res.name is None:
        # The outer query refers to the column by name, so it needs one.
        stmt_res = stmt.target_list[0] = pgast.ResTarget(
            name=env.aliases.get('v'),
            val=stmt_res.val,
        )

    # NB: no trailing comma here -- the previous version had one, which
    # made count_val a 1-tuple instead of a FuncCall node.
    count_val = pgast.FuncCall(name=('count', ),
                               args=[pgast.ColumnRef(name=[stmt_res.name])])

    result = pgast.SelectStmt(target_list=[pgast.ResTarget(val=count_val)],
                              from_clause=[subrvar])

    # Move the CTEs (and argument names) from the inner statement to the
    # new outermost statement.
    result.ctes = stmt.ctes
    result.argnames = stmt.argnames
    stmt.ctes = []

    return result
Example #2
0
def top_output_as_value(
        stmt: pgast.SelectStmt,
        ir_set: irast.Set, *,
        env: context.Environment) -> pgast.SelectStmt:
    """Finalize output serialization on the top level.

    * JSON output whose expected cardinality is not one is handed to
      aggregate_json_output(), which builds a wrapping statement.
    * NATIVE output with an explicit top-level cast has its first target
      replaced by a cast of the same value, under a generated alias.
    * Any other combination returns *stmt* unchanged.

    *stmt* may be mutated in place; callers should use the returned
    statement.
    """
    output_format = env.output_format

    if output_format is context.OutputFormat.JSON:
        if not env.expected_cardinality_one:
            # For JSON we aggregate the whole thing into a JSON array.
            return aggregate_json_output(stmt, ir_set, env=env)
        return stmt

    if (output_format is not context.OutputFormat.NATIVE
            or env.explicit_top_cast is None):
        # JSON_ELEMENTS and BINARY don't require any wrapping
        return stmt

    pg_type = pgtypes.pg_type_from_ir_typeref(
        env.explicit_top_cast,
        persistent_tuples=True,
    )
    casted = pgast.TypeCast(
        arg=stmt.target_list[0].val,
        type_name=pgast.TypeName(name=pg_type),
    )
    stmt.target_list[0] = pgast.ResTarget(
        name=env.aliases.get('v'),
        val=casted,
    )
    return stmt
Example #3
0
def aggregate_json_output(stmt: pgast.SelectStmt, ir_set: irast.Set, *,
                          env: context.Environment) -> pgast.SelectStmt:
    """Wrap *stmt* so its first column is aggregated into a JSON array.

    Builds ``SELECT coalesce(<json agg>(<col>), '[]') FROM (stmt)``; the
    ``'[]'`` fallback is used whenever the aggregate yields NULL. *stmt*
    is mutated: an unnamed first target is given a generated name, and
    its CTEs are moved to the returned statement. *ir_set* is not used.
    """
    wrapped = pgast.RangeSubselect(
        subquery=stmt,
        alias=pgast.Alias(aliasname=env.aliases.get('aggw')),
    )

    first = stmt.target_list[0]
    if first.name is None:
        # The outer aggregate references the column by name.
        first = pgast.ResTarget(name=env.aliases.get('v'), val=first.val)
        stmt.target_list[0] = first

    json_agg = pgast.FuncCall(
        name=_get_json_func('agg', env=env),
        args=[pgast.ColumnRef(name=[first.name])],
    )
    empty_array = pgast.StringConstant(val='[]')
    coalesced = pgast.CoalesceExpr(args=[json_agg, empty_array])

    outer = pgast.SelectStmt(
        target_list=[pgast.ResTarget(val=coalesced)],
        from_clause=[wrapped],
    )

    # CTEs and argument names now belong to the outermost statement.
    outer.ctes = stmt.ctes
    outer.argnames = stmt.argnames
    stmt.ctes = []

    return outer
Example #4
0
def top_output_as_config_op(
        ir_set: irast.Set,
        stmt: pgast.SelectStmt, *,
        env: context.Environment) -> pgast.Query:
    """Wrap a config command's output as a ``jsonb_build_array`` op.

    Only SYSTEM-scoped commands are handled. The result is
    ``SELECT jsonb_build_array('ADD', <scope>, <name>, <value>)`` over a
    subquery of *stmt*, where <value> is *stmt*'s first target column.
    *stmt* is mutated: an unnamed first target gets a generated name and
    its CTEs move to the returned statement.

    :raises errors.InternalServerError: if the command's scope is not
        ``qltypes.ConfigScope.SYSTEM``.
    """
    cmd = ir_set.expr
    assert isinstance(cmd, irast.ConfigCommand)

    if cmd.scope is not qltypes.ConfigScope.SYSTEM:
        raise errors.InternalServerError(
            f'CONFIGURE {cmd.scope} INSERT is not supported')

    subrvar = pgast.RangeSubselect(
        subquery=stmt,
        alias=pgast.Alias(aliasname=env.aliases.get('cfg')),
    )

    value_col = stmt.target_list[0]
    if value_col.name is None:
        # The outer query references the value column by name.
        value_col = pgast.ResTarget(
            name=env.aliases.get('v'),
            val=value_col.val,
        )
        stmt.target_list[0] = value_col

    op_array = pgast.FuncCall(
        name=('jsonb_build_array',),
        args=[
            pgast.StringConstant(val='ADD'),
            pgast.StringConstant(val=str(cmd.scope)),
            pgast.StringConstant(val=cmd.name),
            pgast.ColumnRef(name=[value_col.name]),
        ],
        null_safe=True,
        ser_safe=True,
    )

    outer = pgast.SelectStmt(
        target_list=[pgast.ResTarget(val=op_array)],
        from_clause=[subrvar],
    )

    outer.ctes = stmt.ctes
    outer.argnames = stmt.argnames
    stmt.ctes = []

    return outer
Example #5
0
def apply_volatility_ref(stmt: pgast.SelectStmt, *,
                         ctx: context.CompilerContextLevel) -> None:
    """Require every volatility reference to be non-NULL in *stmt*.

    Each callable in ``ctx.volatility_ref`` is invoked to obtain an
    expression. A negated ``NullTest`` on that expression is then merged
    into ``stmt.where_clause`` via ``astutils.extend_binop``.
    See the comment in process_set_as_subquery() for the rationale.
    """
    for make_ref in ctx.volatility_ref:
        stmt.where_clause = astutils.extend_binop(
            stmt.where_clause,
            pgast.NullTest(arg=make_ref(), negated=True),
        )
Example #6
0
def rel_join(query: pgast.SelectStmt, right_rvar: pgast.PathRangeVar, *,
             ctx: context.CompilerContextLevel) -> None:
    """Join *right_rvar* into *query* on the paths they share.

    For each path in ``right_rvar.query.path_scope`` that *query* can
    resolve (identity aspect first, then value), a join condition is built
    with ``astutils.join_condition``. The conditions are combined with
    ``astutils.extend_binop``.

    If *query* has no FROM items yet, *right_rvar* is appended and any
    condition goes into WHERE. Otherwise the first FROM item is replaced
    by a JoinExpr: 'inner' with the condition, or 'cross' if none.
    """
    condition = None

    for path_id in right_rvar.query.path_scope:
        lref = maybe_get_path_var(query, path_id, aspect='identity', ctx=ctx)
        if lref is None:
            lref = maybe_get_path_var(query, path_id, aspect='value', ctx=ctx)
            if lref is None:
                # The left side knows nothing about this path.
                continue

        rref = pathctx.get_rvar_path_identity_var(
            right_rvar, path_id, env=ctx.env)

        assert isinstance(lref, pgast.ColumnRef)
        assert isinstance(rref, pgast.ColumnRef)
        condition = astutils.extend_binop(
            condition, astutils.join_condition(lref, rref))

    if query.from_clause:
        query.from_clause[0] = pgast.JoinExpr(
            type='inner' if condition is not None else 'cross',
            larg=query.from_clause[0],
            rarg=right_rvar,
            quals=condition,
        )
    else:
        query.from_clause.append(right_rvar)
        if condition is not None:
            query.where_clause = astutils.extend_binop(
                query.where_clause, condition)
Example #7
0
def wrap_script_stmt(
    stmt: pgast.SelectStmt,
    *,
    suppress_all_output: bool = False,
    env: context.Environment,
) -> pgast.SelectStmt:
    """Wrap *stmt* in ``SELECT count(<first target>) FROM (stmt)``.

    The count column reuses the name of *stmt*'s first target. If that
    target is unnamed, it is first replaced by one with a generated name.

    With *suppress_all_output*, the count query is wrapped once more in a
    SELECT with an empty target list, filtered by ``<count column> IS
    NULL``. Presumably this makes the statement execute while returning
    no rows.

    *stmt*'s CTEs and argument names are moved to the returned statement.
    """
    inner_rvar = pgast.RangeSubselect(
        subquery=stmt,
        alias=pgast.Alias(aliasname=env.aliases.get('aggw')),
    )

    col = stmt.target_list[0]
    if col.name is None:
        col = pgast.ResTarget(name=env.aliases.get('v'), val=col.val)
        stmt.target_list[0] = col
        assert col.name is not None
    col_name = col.name

    wrapper = pgast.SelectStmt(
        target_list=[
            pgast.ResTarget(
                val=pgast.FuncCall(
                    name=('count', ),
                    args=[pgast.ColumnRef(name=[col_name])],
                ),
                name=col_name,
            ),
        ],
        from_clause=[inner_rvar],
    )

    if suppress_all_output:
        outer_rvar = pgast.RangeSubselect(
            subquery=wrapper,
            alias=pgast.Alias(aliasname=env.aliases.get('q')),
        )
        wrapper = pgast.SelectStmt(
            target_list=[],
            from_clause=[outer_rvar],
            where_clause=pgast.NullTest(
                arg=pgast.ColumnRef(
                    name=[outer_rvar.alias.aliasname, col_name],
                ),
            ),
        )

    wrapper.ctes = stmt.ctes
    wrapper.argnames = stmt.argnames
    stmt.ctes = []

    return wrapper
Example #8
0
def anti_join(
    lhs: pgast.SelectStmt,
    rhs: pgast.SelectStmt,
    path_id: Optional[irast.PathId],
    *,
    ctx: context.CompilerContextLevel,
) -> None:
    """Filter elements out of the LHS that appear on the RHS.

    With a *path_id*, ``<LHS identity of path_id> NOT IN (rhs)`` is added
    to the LHS WHERE clause, after making *rhs* output that path's
    identity. Without one, ``NOT EXISTS (rhs)`` is added instead.
    """
    cond_expr: pgast.BaseExpr
    if not path_id:
        # No path we care about. Just check existence.
        cond_expr = pgast.SubLink(type=pgast.SubLinkType.NOT_EXISTS, expr=rhs)
    else:
        # Anti-join the LHS identity against the identities the RHS emits.
        lhs_ident = pathctx.get_path_identity_var(
            lhs, path_id=path_id, env=ctx.env)
        pathctx.get_path_identity_output(rhs, path_id=path_id, env=ctx.env)
        cond_expr = astutils.new_binop(lhs_ident, rhs, 'NOT IN')

    lhs.where_clause = astutils.extend_binop(lhs.where_clause, cond_expr)
Example #9
0
def semi_join(stmt: pgast.SelectStmt, ir_set: irast.Set,
              src_rvar: pgast.PathRangeVar, *,
              ctx: context.CompilerContextLevel) -> pgast.PathRangeVar:
    """Join an IR Set using semi-join.

    Creates a root range var for *ir_set* and restricts *stmt* with
    ``<identity of far path in that rvar> IN (ctx.rel)``. When the
    pointer's storage info does not report an 'ObjectType' table, a
    pointer range var is also included into ``ctx.rel``.

    :return: The newly created range var for *ir_set*.
    """
    rptr = ir_set.rptr
    assert rptr is not None

    # Target set range.
    set_rvar = new_root_rvar(ir_set, ctx=ctx)

    ptrref = rptr.ptrref
    ptr_info = pg_types.get_ptrref_storage_info(
        ptrref, resolve_type=False, allow_missing=True)

    far_pid = ir_set.path_id
    if ptr_info and ptr_info.table_type == 'ObjectType':
        # Presumably the pointer lives in the object table itself; for an
        # inbound pointer, the far side is the source path.
        if irtyputils.is_inbound_ptrref(ptrref):
            src_pid = ir_set.path_id.src_path()
            assert src_pid is not None
            far_pid = src_pid
    else:
        # Link range.
        link_rvar = new_pointer_rvar(rptr, src_rvar=src_rvar, ctx=ctx)
        include_rvar(
            ctx.rel,
            link_rvar,
            path_id=ir_set.path_id.ptr_path(),
            ctx=ctx,
        )

    target_ident = pathctx.get_rvar_path_identity_var(
        set_rvar, far_pid, env=ctx.env)

    # Make ctx.rel output the far path's identity so it can serve as the
    # right-hand side of IN.
    pathctx.get_path_identity_output(ctx.rel, far_pid, env=ctx.env)

    stmt.where_clause = astutils.extend_binop(
        stmt.where_clause,
        astutils.new_binop(target_ident, ctx.rel, 'IN'),
    )

    return set_rvar
    return set_rvar