Esempio n. 1
0
def range_from_queryset(
        set_ops: typing.Sequence[typing.Tuple[str, pgast.BaseRelation]],
        scls: s_obj.Object, *,
        env: context.Environment) -> pgast.BaseRangeVar:
    """Return a range var that covers every relation in *set_ops*.

    *set_ops* is a sequence of ``(op, relation)`` pairs.

    With a single pair, the first FROM item of that relation is returned
    as-is.

    With several pairs, the relations are folded left to right into a
    nested set-operation query (e.g. UNION/EXCEPT):
    ``((r0 op1 r1) op2 r2) ...``. The operator paired with the first
    relation is not used. The combined query is returned wrapped in a
    subselect aliased after ``scls.shortname.name``.
    """
    first_rel = set_ops[0][1]

    if len(set_ops) <= 1:
        # Just one class table: reference its FROM item directly.
        return first_rel.from_clause[0]

    # Each step makes a new node whose left operand is everything
    # accumulated so far and whose right operand is the next relation.
    combined = first_rel
    for set_op, right in set_ops[1:]:
        node = pgast.SelectStmt(all=True, larg=combined)
        node.op, node.rarg = set_op, right
        combined = node

    return pgast.RangeSubselect(
        subquery=combined,
        alias=pgast.Alias(
            aliasname=env.aliases.get(scls.shortname.name)
        )
    )
Esempio n. 2
0
def top_output_as_value(stmt: pgast.Query, *,
                        env: context.Environment) -> pgast.Query:
    """Finalize output serialization on the top level.

    For any output format other than JSON, *stmt* is returned unchanged.

    For JSON output, *stmt* is wrapped as roughly
    ``SELECT coalesce(json_agg(<col>), '[]') FROM (<stmt>) AS <alias>``,
    so the whole result becomes a single JSON array. In doing so:

    * the first target of *stmt* gets a generated name if it had none;
    * the CTEs of *stmt* are moved onto the wrapper query.
    """
    if env.output_format != context.OutputFormat.JSON:
        return stmt

    inner_rvar = pgast.RangeSubselect(
        subquery=stmt,
        alias=pgast.Alias(aliasname=env.aliases.get('aggw')))

    first_target = stmt.target_list[0]
    if first_target.name is None:
        # The aggregate below refers to this column by name.
        first_target.name = env.aliases.get('v')

    aggregated = pgast.FuncCall(
        name=('json_agg', ),
        args=[pgast.ColumnRef(name=[first_target.name])])

    # XXX: nullability introspection is not reliable, so the aggregate is
    #      always wrapped. Once ``first_target.val.nullable`` can be
    #      trusted, make this wrapping conditional on it. (json_agg over
    #      zero input rows yields NULL, hence the '[]' fallback.)
    aggregated = pgast.CoalesceExpr(
        args=[aggregated, pgast.Constant(val='[]')])

    wrapper = pgast.SelectStmt(
        target_list=[pgast.ResTarget(val=aggregated)],
        from_clause=[inner_rvar])

    # Move the CTEs from the inner statement to the outermost query.
    wrapper.ctes, stmt.ctes = stmt.ctes, []

    return wrapper
Esempio n. 3
0
def rvar_for_rel(
        rel: pgast.BaseRelation, *,
        lateral: bool=False,
        colnames: typing.Optional[typing.List[str]]=None,
        env: context.Environment) -> pgast.BaseRangeVar:
    """Wrap *rel* in an aliased range var suitable for a FROM clause.

    A ``pgast.Query`` becomes a ``RangeSubselect``; any other relation
    becomes a plain ``RangeVar``. The alias name is generated from
    ``rel.name``, or from ``'q'`` for a query without a name.

    :param rel:
        The relation to reference.
    :param lateral:
        Value passed as ``lateral`` to the ``RangeSubselect``. It is only
        used when *rel* is a query and is ignored for plain relations.
    :param colnames:
        Column aliases for the range var. ``None`` (the default) is
        treated as an empty list.
    :param env:
        Compilation environment that supplies the alias generator.
    :return:
        The new ``RangeSubselect`` or ``RangeVar``.
    """
    if colnames is None:
        # Create a fresh list on every call. The previous ``[]`` default
        # was evaluated once, so the same list object ended up in every
        # Alias created without explicit colnames.
        colnames = []

    if isinstance(rel, pgast.Query):
        alias = env.aliases.get(rel.name or 'q')
        return pgast.RangeSubselect(
            subquery=rel,
            alias=pgast.Alias(aliasname=alias, colnames=colnames),
            lateral=lateral,
        )

    alias = env.aliases.get(rel.name)
    return pgast.RangeVar(
        relation=rel,
        alias=pgast.Alias(aliasname=alias, colnames=colnames)
    )
Esempio n. 4
0
def process_link_values(
        ir_stmt, ir_expr, target_tab, tab_cols, col_data,
        dml_rvar, sources, props_only, target_is_scalar, iterator_cte, *,
        ctx: context.CompilerContext) -> \
        typing.Tuple[pgast.CommonTableExpr, typing.List[str]]:
    """Unpack data from an update expression into a series of selects.

    Builds a row query that joins together:

    * the DML subject (*dml_rvar*);
    * the optional FOR iterator (*iterator_cte*);
    * the extra *sources*;
    * a LATERAL subselect compiled from *ir_expr*.

    From these, the row query selects a value for every column of
    *tab_cols* that can be populated.

    :param ir_stmt:
        IR of the enclosing INSERT/UPDATE statement. The path id of its
        subject is bound to the row query.
    :param ir_expr:
        IR of the INSERT/UPDATE body element.
    :param target_tab:
        The link table being updated. (Not referenced by this function.)
    :param tab_cols:
        A sequence of columns in the table being updated.
    :param col_data:
        Expressions used to populate well-known columns of the link
        table such as std::source and std::__type__. These take
        precedence over values derived from *ir_expr*.
    :param dml_rvar:
        Range var for the DML statement subject. It is included into
        the row query under the subject's path id.
    :param sources:
        A list of relations which must be joined into the data query
        to resolve expressions in *col_data*.
    :param props_only:
        Whether this link update only touches link properties.
        (Not referenced by this function.)
    :param target_is_scalar:
        Whether the link target is a ScalarType. Scalar targets are
        populated from the path value; other targets from the path
        identity.
    :param iterator_cte:
        CTE representing the iterator range in the FOR clause of the
        EdgeQL DML statement, or ``None``.
    :param ctx:
        Compiler context. A new scope derived from it is used to build
        the row query.
    :return:
        A ``(link_rows, specified_cols)`` tuple:

        * ``link_rows`` is a CTE wrapping the row query;
        * ``specified_cols`` lists the *tab_cols* entries that the row
          query selects, in *tab_cols* order.
    """
    with ctx.newscope() as newscope, newscope.newrel() as subrelctx:
        row_query = subrelctx.rel

        relctx.include_rvar(row_query,
                            dml_rvar,
                            path_id=ir_stmt.subject.path_id,
                            ctx=subrelctx)
        subrelctx.path_scope[ir_stmt.subject.path_id] = row_query

        if iterator_cte is not None:
            # Join the FOR iterator into the row query as a LATERAL rvar.
            iterator_rvar = dbobj.rvar_for_rel(iterator_cte,
                                               lateral=True,
                                               env=subrelctx.env)
            relctx.include_rvar(row_query,
                                iterator_rvar,
                                path_id=iterator_cte.query.path_id,
                                ctx=subrelctx)

        # Compile the value expression into its own subquery.
        with subrelctx.newscope() as sctx, sctx.subrel() as input_rel_ctx:
            input_rel = input_rel_ctx.rel
            if iterator_cte is not None:
                # Resolve the iterator path through the row query.
                input_rel_ctx.path_scope[iterator_cte.query.path_id] = \
                    row_query
            input_rel_ctx.expr_exposed = False
            input_rel_ctx.shape_format = context.ShapeFormat.FLAT
            input_rel_ctx.volatility_ref = pathctx.get_path_identity_var(
                row_query, ir_stmt.subject.path_id, env=input_rel_ctx.env)
            dispatch.compile(ir_expr, ctx=input_rel_ctx)

    input_rvar = pgast.RangeSubselect(
        subquery=input_rel,
        lateral=True,
        alias=pgast.Alias(aliasname=ctx.env.aliases.get('val')))

    # For a set-operation (UNION) query, look the output up on its
    # right operand.
    input_stmt = input_rel
    if input_stmt.op is not None:
        input_stmt = input_stmt.rarg

    path_id = ir_expr.path_id
    output = pathctx.get_path_value_output(input_stmt, path_id, env=ctx.env)

    source_data = {}

    if isinstance(output, pgast.TupleVar):
        # Tuple output: map each element to the link table column named
        # after its pointer name (or the last path step if there is none).
        for element in output.elements:
            name = element.path_id.rptr_name()
            if name is None:
                name = element.path_id[-1].name
            colname = common.edgedb_name_to_pg_name(name)
            source_data.setdefault(colname,
                                   dbobj.get_column(input_rvar, element.name))
    elif target_is_scalar:
        source_data['std::target'] = pathctx.get_rvar_path_value_var(
            input_rvar, path_id, env=ctx.env)

    if not target_is_scalar and 'std::target' not in source_data:
        # Non-scalar targets are referenced through the path identity.
        source_data['std::target'] = pathctx.get_rvar_path_identity_var(
            input_rvar, path_id, env=ctx.env)

    specified_cols = []
    for col in tab_cols:
        # Explicit col_data wins over values derived from the expression.
        expr = col_data.get(col)
        if expr is None:
            expr = source_data.get(col)

        if expr is not None:
            row_query.target_list.append(pgast.ResTarget(val=expr, name=col))
            specified_cols.append(col)

    row_query.from_clause += list(sources) + [input_rvar]

    link_rows = pgast.CommonTableExpr(query=row_query,
                                      name=ctx.env.aliases.get(hint='r'))

    return link_rows, specified_cols