예제 #1
0
def fini_toplevel(stmt: pgast.Query,
                  ctx: context.CompilerContextLevel) -> None:
    """Finish compiling the top-level query *stmt*.

    Steps, in order:

    1. append the check-CTE scan (see ``scan_check_ctes``);
    2. prepend the type-rewrite CTEs from ``ctx.type_ctes``;
    3. store ``ctx.argmap`` as ``stmt.argnames``;
    4. unless named parameters are in use, reference every parameter whose
       argmap entry is not marked ``used`` from an extra ``__unused_vars``
       CTE, cast to the parameter's Postgres type.
    """
    scan_check_ctes(stmt, ctx.env.check_ctes, ctx=ctx)

    # Type rewrite CTEs must come before every other CTE.
    if stmt.ctes is None:
        stmt.ctes = []
    type_rewrites = list(ctx.type_ctes.values())
    stmt.ctes[:0] = type_rewrites

    argmap = ctx.argmap
    stmt.argnames = argmap

    if ctx.env.use_named_params:
        return

    # Collect a typed reference for each positional parameter that the
    # query body never used.
    unused_targets = []
    for param in ctx.env.query_params:
        pgparam = argmap[param.name]
        if pgparam.used:
            continue
        param_ref = pgast.ParamRef(number=pgparam.index)
        pg_type = pg_types.pg_type_from_ir_typeref(param.ir_type)
        unused_targets.append(
            pgast.ResTarget(
                val=pgast.TypeCast(
                    arg=param_ref,
                    type_name=pgast.TypeName(name=pg_type),
                ),
            )
        )

    if not unused_targets:
        return

    stmt.append_cte(
        pgast.CommonTableExpr(
            name="__unused_vars",
            query=pgast.SelectStmt(target_list=unused_targets),
        )
    )
예제 #2
0
def scan_check_ctes(
    stmt: pgast.Query,
    check_ctes: List[pgast.CommonTableExpr],
    *,
    ctx: context.CompilerContextLevel,
) -> None:
    """Force Postgres to evaluate every CTE in *check_ctes*.

    Some constraints are enforced by explicit check queries rather than by
    Postgres constraints or triggers.  To stop Postgres from optimizing
    those queries away, the row count of each check CTE is added to a
    random base number, and the sum is used in the WHERE clause of an
    UPDATE of ``edgedb._dml_dummy``.  That UPDATE is appended to *stmt* as
    a CTE.  Each check CTE is also marked ``materialized``.

    Does nothing when *check_ctes* is empty.
    """
    if not check_ctes:
        return

    # Start from a large random number so that different queries should
    # try to access different "rows" of the dummy table, in case that
    # matters.
    total: pgast.BaseExpr = pgast.NumericConstant(
        val=str(random.randint(0, (1 << 60) - 1)))

    for cte in check_ctes:
        # MATERIALIZED, because otherwise Postgres might not fully
        # evaluate all of the CTE's columns when scanning it.
        cte.materialized = True
        count_query = pgast.SelectStmt(
            target_list=[
                pgast.ResTarget(
                    val=pgast.FuncCall(name=('count', ),
                                       args=[pgast.Star()]),
                ),
            ],
            from_clause=[relctx.rvar_for_rel(cte, ctx=ctx)],
        )
        total = pgast.Expr(
            kind=pgast.ExprKind.OP,
            name='+',
            lexpr=total,
            rexpr=count_query,
        )

    set_flag = pgast.UpdateTarget(
        name='flag', val=pgast.BooleanConstant(val='true'))
    dummy_table = pgast.RelRangeVar(
        relation=pgast.Relation(schemaname='edgedb', name='_dml_dummy'))
    id_matches = pgast.Expr(
        kind=pgast.ExprKind.OP,
        name='=',
        lexpr=pgast.ColumnRef(name=['id']),
        rexpr=total,
    )
    scan = pgast.UpdateStmt(
        targets=[set_flag],
        relation=dummy_table,
        where_clause=id_matches,
    )
    stmt.append_cte(
        pgast.CommonTableExpr(
            query=scan,
            name=ctx.env.aliases.get(hint='check_scan'),
        )
    )
예제 #3
0
def put_path_id_map(
    rel: pgast.Query,
    outer_path_id: irast.PathId,
    inner_path_id: irast.PathId,
) -> None:
    """Map *outer_path_id* to *inner_path_id* in ``rel.view_path_id_map``.

    Before it is stored, *inner_path_id* is passed through ``map_path_id``
    using the map's current contents.
    """
    view_map = rel.view_path_id_map
    view_map[outer_path_id] = map_path_id(inner_path_id, view_map)
예제 #4
0
파일: output.py 프로젝트: xing0713/edgedb
def aggregate_json_output(stmt: pgast.Query, ir_set: irast.Set, *,
                          env: context.Environment) -> pgast.Query:
    """Wrap *stmt* in a query that JSON-aggregates its first column.

    The returned SELECT reads from *stmt* as an aliased subquery.  It
    outputs the JSON aggregate function (from ``_get_json_func('agg')``)
    applied to *stmt*'s first output column, coalesced to the string
    ``'[]'`` when the aggregate is NULL.  If that first column is unnamed,
    it is replaced by a copy with a fresh alias so the outer query can
    reference it.

    The CTEs and ``argnames`` of *stmt* are moved to the new outer query,
    and ``stmt.ctes`` is reset to an empty list.  *ir_set* is unused.
    """
    wrapped = pgast.RangeSubselect(
        subquery=stmt,
        alias=pgast.Alias(aliasname=env.aliases.get('aggw')),
    )

    first_target = stmt.target_list[0]
    if first_target.name is None:
        first_target = pgast.ResTarget(
            name=env.aliases.get('v'),
            val=first_target.val,
        )
        stmt.target_list[0] = first_target

    json_agg = pgast.FuncCall(
        name=_get_json_func('agg', env=env),
        args=[pgast.ColumnRef(name=[first_target.name])],
    )
    agg_or_empty = pgast.CoalesceExpr(
        args=[json_agg, pgast.StringConstant(val='[]')])

    result = pgast.SelectStmt(
        target_list=[pgast.ResTarget(val=agg_or_empty)],
        from_clause=[wrapped],
    )

    # The CTEs must live on the outermost query now.
    result.ctes, stmt.ctes = stmt.ctes, []
    result.argnames = stmt.argnames

    return result
예제 #5
0
파일: output.py 프로젝트: xing0713/edgedb
def top_output_as_value(stmt: pgast.Query, ir_set: irast.Set, *,
                        env: context.Environment) -> pgast.Query:
    """Finalize output serialization on the top level.

    * JSON output not expected to have cardinality one is aggregated into
      a JSON array via ``aggregate_json_output``.
    * NATIVE output with an explicit top-level cast gets its first target
      replaced by a cast of that target's value, under a fresh alias.
    * Any other case returns *stmt* unchanged.
    """
    if (env.output_format is context.OutputFormat.JSON
            and not env.expected_cardinality_one):
        return aggregate_json_output(stmt, ir_set, env=env)

    if (env.output_format is not context.OutputFormat.NATIVE
            or env.explicit_top_cast is None):
        return stmt

    cast_expr = pgast.TypeCast(
        arg=stmt.target_list[0].val,
        type_name=pgast.TypeName(
            name=pgtypes.pg_type_from_ir_typeref(
                env.explicit_top_cast,
                persistent_tuples=True,
            ),
        ),
    )
    stmt.target_list[0] = pgast.ResTarget(
        name=env.aliases.get('v'),
        val=cast_expr,
    )
    return stmt
예제 #6
0
def put_path_rvar_if_not_exists(
        stmt: pgast.Query, path_id: irast.PathId, rvar: pgast.PathRangeVar, *,
        flavor: str='normal',
        aspect: str, env: context.Environment) -> None:
    """Register *rvar* for (*path_id*, *aspect*) unless one already is.

    Only the *flavor* rvar map of *stmt* is checked.  An existing entry is
    left untouched; otherwise the work is delegated to ``put_path_rvar``.
    """
    already_present = (path_id, aspect) in stmt.get_rvar_map(flavor)
    if already_present:
        return
    put_path_rvar(stmt, path_id, rvar, aspect=aspect, flavor=flavor, env=env)
예제 #7
0
def fini_stmt(stmt: pgast.Query, ctx: context.CompilerContextLevel,
              parent_ctx: context.CompilerContextLevel) -> None:
    """Finalize *stmt*; only the top-level statement of *ctx* is touched.

    For the top-level statement this:

    1. prepends the type-rewrite CTEs from ``ctx.type_ctes``;
    2. stores ``ctx.argmap`` as ``stmt.argnames``;
    3. unless named parameters are in use, adds an argmap entry for every
       query parameter not already in it, and references those parameters
       (cast to their Postgres types) from an ``__unused_vars`` CTE.

    *parent_ctx* is not used.
    """
    if stmt is not ctx.toplevel_stmt:
        return

    # Type rewrites go first.  ``ctes`` may be None, so normalize it
    # before the slice assignment (which would otherwise raise TypeError).
    if stmt.ctes is None:
        stmt.ctes = []
    stmt.ctes[:0] = list(ctx.type_ctes.values())

    argmap = ctx.argmap
    stmt.argnames = argmap

    if ctx.env.use_named_params:
        return

    # Register and reference parameters missing from the argmap.
    targets = []
    for param in ctx.env.query_params:
        if param.name in argmap:
            continue
        # Decimal names map to ``int(name) + 1``; any other name gets the
        # next index after the entries already in the argmap.
        if param.name.isdecimal():
            idx = int(param.name) + 1
        else:
            idx = len(argmap) + 1
        argmap[param.name] = pgast.Param(
            index=idx,
            required=param.required,
        )
        targets.append(
            pgast.ResTarget(val=pgast.TypeCast(
                arg=pgast.ParamRef(number=idx),
                type_name=pgast.TypeName(
                    name=pg_types.pg_type_from_ir_typeref(
                        param.ir_type)))))

    if targets:
        stmt.ctes.append(
            pgast.CommonTableExpr(
                name="__unused_vars",
                query=pgast.SelectStmt(target_list=targets)))
예제 #8
0
def has_rvar(
        stmt: pgast.Query, rvar: pgast.PathRangeVar, *,
        env: context.Environment) -> bool:
    """Return True if *rvar* is registered in *stmt* for any path.

    The ``'normal'`` map is searched before the ``'packed'`` map, stopping
    at the first hit.  *env* is unused.
    """
    for flavor in ('normal', 'packed'):
        registered = set(stmt.get_rvar_map(flavor).values())
        if rvar in registered:
            return True
    return False
예제 #9
0
def put_path_rvar(
        stmt: pgast.Query, path_id: irast.PathId, rvar: pgast.PathRangeVar, *,
        flavor: str='normal',
        aspect: str, env: context.Environment) -> None:
    """Register *rvar* for (*path_id*, *aspect*) in *stmt*'s *flavor* map.

    If *path_id* is masked in the query behind *rvar*, it is added to
    ``stmt.path_id_mask`` as well.  Range vars whose ``query`` raises
    NotImplementedError are registered without that propagation.
    """
    assert isinstance(path_id, irast.PathId)
    rvar_map = stmt.get_rvar_map(flavor)
    rvar_map[path_id, aspect] = rvar

    # Masked paths (paths that are only behind a fence below) are normally
    # not exposed in a query namespace.  When the masked path is the
    # *main* path of a set it must still be exposed, but no further than
    # the immediate parent query.
    try:
        subquery = rvar.query
    except NotImplementedError:
        return

    if path_id in subquery.path_id_mask:
        stmt.path_id_mask.add(path_id)
예제 #10
0
파일: pathctx.py 프로젝트: xing0713/edgedb
def put_path_var(rel: pgast.Query,
                 path_id: irast.PathId,
                 var: pgast.Base,
                 *,
                 aspect: str,
                 force: bool = False,
                 env: context.Environment) -> None:
    """Store *var* as the *aspect* of *path_id* in ``rel.path_namespace``.

    :raises KeyError: if an entry for (*path_id*, *aspect*) already
        exists and *force* is false.
    """
    key = (path_id, aspect)
    namespace = rel.path_namespace
    if key in namespace and not force:
        raise KeyError(f'{aspect} of {path_id} is already present in {rel}')
    namespace[key] = var
예제 #11
0
def put_path_rvar(stmt: pgast.Query, path_id: irast.PathId,
                  rvar: pgast.PathRangeVar, *, aspect: str,
                  env: context.Environment) -> None:
    """Register *rvar* as the range var for (*path_id*, *aspect*) in *stmt*.

    If *path_id* is masked in the query behind *rvar*, it is also added to
    ``stmt.path_id_mask``.  Range vars without a usable ``query``
    (attribute missing, or raising NotImplementedError) are registered
    without that propagation.
    """
    assert isinstance(path_id, irast.PathId)
    stmt.path_rvar_map[path_id, aspect] = rvar

    # Normally, masked paths (i.e paths that are only behind a fence below),
    # will not be exposed in a query namespace.  However, when the masked
    # path is the *main* path of a set, it must still be exposed, but no
    # further than the immediate parent query.
    #
    # ``hasattr`` only swallows AttributeError, but ``rvar.query`` can also
    # raise NotImplementedError for range vars without an underlying query,
    # so catch both explicitly and read the attribute only once.
    try:
        query = rvar.query
    except (AttributeError, NotImplementedError):
        return

    if path_id in query.path_id_mask:
        stmt.path_id_mask.add(path_id)
예제 #12
0
def maybe_get_path_rvar(
        stmt: pgast.Query, path_id: irast.PathId, *, aspect: str,
        flavor: str='normal',
        env: context.Environment) -> Optional[pgast.PathRangeVar]:
    """Look up the range var for (*path_id*, *aspect*), or return None.

    ``env.external_rvars`` takes precedence.  Otherwise *stmt*'s *flavor*
    rvar map (if it has one) is consulted.  For the ``'identity'`` aspect,
    that map's ``'value'`` entry is used as a fallback.
    """
    key = (path_id, aspect)
    found = env.external_rvars.get(key)
    rvar_map = stmt.maybe_get_rvar_map(flavor)

    if rvar_map is None or found is not None:
        return found

    if rvar_map:
        found = rvar_map.get(key)
    if found is None and aspect == 'identity':
        found = rvar_map.get((path_id, 'value'))
    return found
예제 #13
0
def rel_join(query: pgast.Query, right_rvar: pgast.BaseRangeVar, *,
             ctx: context.CompilerContextLevel) -> None:
    """Join *right_rvar* into *query*.

    For each path id in ``right_rvar.path_scope`` that *query* exposes
    (identity aspect preferred, value aspect as fallback), a condition
    from ``astutils.join_condition`` is built against *right_rvar*'s
    identity var.  The conditions are combined with
    ``astutils.extend_binop``.

    If *query* has no FROM items, *right_rvar* becomes its FROM item and
    the combined condition (if any) is added to the WHERE clause.
    Otherwise the first FROM item is replaced by a JoinExpr of it with
    *right_rvar*: ``'inner'`` with the condition, or ``'cross'`` when
    there is none.
    """
    join_cond = None

    for path_id in right_rvar.path_scope:
        left_ref = maybe_get_path_var(
            query, path_id, aspect='identity', ctx=ctx)
        if left_ref is None:
            left_ref = maybe_get_path_var(
                query, path_id, aspect='value', ctx=ctx)
        if left_ref is not None:
            right_ref = pathctx.get_rvar_path_identity_var(
                right_rvar, path_id, env=ctx.env)
            join_cond = astutils.extend_binop(
                join_cond, astutils.join_condition(left_ref, right_ref))

    if query.from_clause:
        query.from_clause[0] = pgast.JoinExpr(
            type='cross' if join_cond is None else 'inner',
            larg=query.from_clause[0],
            rarg=right_rvar,
            quals=join_cond,
        )
    else:
        query.from_clause.append(right_rvar)
        if join_cond is not None:
            query.where_clause = astutils.extend_binop(
                query.where_clause, join_cond)
예제 #14
0
def get_path_output_or_null(
        rel: pgast.Query, path_id: irast.PathId, *,
        disable_output_fusion: bool=False,
        aspect: str, env: context.Environment) -> \
        Tuple[pgast.OutputVar, bool]:
    """Return an output of *rel* for *path_id*/*aspect*, or a NULL column.

    *path_id* is first translated through ``rel.view_path_id_map``.  Then,
    in order:

    1. an existing or derivable output for *aspect* is returned;
    2. otherwise, an output for the less specific aspect (from
       ``get_less_specific_aspect``) is looked up and, if found, also
       registered under *aspect*;
    3. otherwise, a ``NULL`` target under a fresh alias is appended to
       ``rel.target_list`` and registered under *aspect*.

    :returns: ``(ref, is_null)``, where *is_null* is True only in case 3.
    """
    path_id = map_path_id(path_id, rel.view_path_id_map)

    ref = maybe_get_path_output(
        rel, path_id,
        disable_output_fusion=disable_output_fusion,
        aspect=aspect, env=env)
    if ref is not None:
        return ref, False

    alt_aspect = get_less_specific_aspect(path_id, aspect)
    if alt_aspect is not None:
        alt_key = (path_id, alt_aspect)
        preexisting = None
        if disable_output_fusion:
            # With output fusion disabled we must not reuse an existing
            # column, so hide any current output for the alternate aspect
            # during the lookup.
            preexisting = rel.path_outputs.pop(alt_key, None)
        try:
            ref = maybe_get_path_output(
                rel, path_id,
                disable_output_fusion=disable_output_fusion,
                aspect=alt_aspect, env=env)
        finally:
            if disable_output_fusion:
                # Put path_outputs[alt_key] back to whatever it was before,
                # even if the lookup raised: drop anything the lookup
                # added when there was no prior entry, else restore it.
                if preexisting is None:
                    rel.path_outputs.pop(alt_key, None)
                else:
                    rel.path_outputs[alt_key] = preexisting

        if ref is not None:
            _put_path_output_var(rel, path_id, aspect, ref, env=env)
            return ref, False

    alias = env.aliases.get('null')
    restarget = pgast.ResTarget(
        name=alias,
        val=pgast.NullConstant())

    rel.target_list.append(restarget)

    ref = pgast.ColumnRef(name=[alias], nullable=True)
    _put_path_output_var(rel, path_id, aspect, ref, env=env)

    return ref, True
예제 #15
0
def semi_join(stmt: pgast.Query, ir_set: irast.Set,
              src_rvar: pgast.BaseRangeVar, *,
              ctx: context.CompilerContextLevel) -> pgast.BaseRangeVar:
    """Join an IR Set using semi-join.

    Steps:

    * Create a root range var for *ir_set*, using the typeref from
      ``ctx.join_target_type_filter`` when one is registered for it.
    * Choose the "far" path id: ``ir_set.path_id.src_path()`` for an
      inbound pointer whose storage ``table_type`` is ``'ObjectType'``,
      otherwise ``ir_set.path_id``.
    * For pointers whose ``table_type`` is not ``'ObjectType'``, include a
      pointer range var in ``ctx.rel`` first.
    * Make ``ctx.rel`` output the far path's identity, and add
      ``<target identity> IN (ctx.rel)`` to *stmt*'s WHERE clause.

    :returns: the range var created for *ir_set*.
    """
    ptr = ir_set.rptr

    # Target set range.
    target_typeref = ctx.join_target_type_filter.get(ir_set, ir_set.typeref)
    set_rvar = new_root_rvar(ir_set, typeref=target_typeref, ctx=ctx)

    ptrref = ptr.ptrref
    storage = pg_types.get_ptrref_storage_info(ptrref, resolve_type=False)

    if storage.table_type != 'ObjectType':
        far_pid = ir_set.path_id
        # Link range.
        link_rvar = new_pointer_rvar(ptr, src_rvar=src_rvar, ctx=ctx)
        include_rvar(ctx.rel, link_rvar,
                     path_id=ir_set.path_id.ptr_path(), ctx=ctx)
    elif irtyputils.is_inbound_ptrref(ptrref):
        far_pid = ir_set.path_id.src_path()
    else:
        far_pid = ir_set.path_id

    target_ref = pathctx.get_rvar_path_identity_var(
        set_rvar, far_pid, env=ctx.env)

    pathctx.get_path_identity_output(ctx.rel, far_pid, env=ctx.env)

    membership = astutils.new_binop(target_ref, ctx.rel, 'IN')
    stmt.where_clause = astutils.extend_binop(stmt.where_clause, membership)

    return set_rvar
예제 #16
0
def fini_stmt(stmt: pgast.Query, ctx: context.CompilerContextLevel,
              parent_ctx: context.CompilerContextLevel) -> None:
    """Finalize a compiled statement.

    When *stmt* is the top-level statement of *ctx*, its ``argnames`` are
    set to ``ctx.argmap``.
    """
    if stmt is ctx.toplevel_stmt:
        # Record the argument map on the top-level statement.
        stmt.argnames = ctx.argmap