Example #1
0
def rvar_for_rel(
    rel: Union[pgast.BaseRelation, pgast.CommonTableExpr],
    *,
    typeref: Optional[irast.TypeRef] = None,
    lateral: bool = False,
    colnames: Optional[List[str]] = None,
    ctx: context.CompilerContextLevel,
) -> pgast.PathRangeVar:
    """Wrap *rel* in a range var that can be placed in a FROM clause.

    A ``pgast.Query`` becomes a ``RangeSubselect`` (the only kind that
    honours *lateral*); any other relation or CTE becomes a
    ``RelRangeVar``.  In both cases the alias is allocated from
    ``ctx.env.aliases`` using the relation name as a hint (``'q'`` for an
    unnamed query) and *typeref* is attached to the range var.

    :param colnames:
        Optional column alias list; a fresh empty list when omitted.
    """
    cols = colnames if colnames is not None else []

    if not isinstance(rel, pgast.Query):
        return pgast.RelRangeVar(
            relation=rel,
            alias=pgast.Alias(
                aliasname=ctx.env.aliases.get(rel.name),
                colnames=cols,
            ),
            typeref=typeref,
        )

    return pgast.RangeSubselect(
        subquery=rel,
        alias=pgast.Alias(
            aliasname=ctx.env.aliases.get(rel.name or 'q'),
            colnames=cols,
        ),
        lateral=lateral,
        typeref=typeref,
    )
Example #2
0
def range_for_material_objtype(typeref: irast.TypeRef,
                               path_id: irast.PathId,
                               *,
                               include_overlays: bool = True,
                               env: context.Environment) -> pgast.BaseRangeVar:
    """Return a range var over the backend table of *typeref*.

    The material type of *typeref* is used when it has one.  Types from
    the ``schema``, ``cfg`` and ``sys`` modules are read from the
    ``edgedbss`` schema.

    When *include_overlays* is true and ``env.rel_overlays`` has entries
    for the type, the table and each overlay CTE are combined with
    ``range_from_queryset``.  A ``'replace'`` overlay discards everything
    accumulated before it.
    """
    from . import pathctx  # XXX: fix cycle

    if typeref.material_type is not None:
        typeref = typeref.material_type

    schema_name, tbl_name = common.get_objtype_backend_name(
        typeref.id, typeref.module_id, catenate=False)

    if typeref.name_hint.module in {'schema', 'cfg', 'sys'}:
        # Redirect all queries to schema tables to edgedbss
        schema_name = 'edgedbss'

    base_rvar = pgast.RangeVar(
        relation=pgast.Relation(
            schemaname=schema_name,
            name=tbl_name,
            path_id=path_id,
        ),
        alias=pgast.Alias(aliasname=env.aliases.get(typeref.name_hint.name)))

    overlays = env.rel_overlays.get(str(typeref.id))
    if not (overlays and include_overlays):
        return base_rvar

    def _bound_select(src_rvar):
        # A SELECT over src_rvar that exposes path_id and is bonded on it.
        sel = pgast.SelectStmt(from_clause=[src_rvar])
        pathctx.put_path_value_rvar(sel, path_id, src_rvar, env=env)
        pathctx.put_path_bond(sel, path_id)
        return sel

    set_ops = [('union', _bound_select(base_rvar))]

    for op, cte in overlays:
        overlay_rvar = pgast.RangeVar(
            relation=cte,
            alias=pgast.Alias(aliasname=env.aliases.get(hint=cte.name)))
        overlay_qry = _bound_select(overlay_rvar)

        if op == 'replace':
            # Everything accumulated so far is superseded by this overlay.
            set_ops = [('union', overlay_qry)]
        else:
            set_ops.append((op, overlay_qry))

    return range_from_queryset(set_ops, typeref.name_hint, env=env)
Example #3
0
def wrap_script_stmt(
    stmt: pgast.SelectStmt,
    *,
    suppress_all_output: bool = False,
    env: context.Environment,
) -> pgast.SelectStmt:
    """Wrap *stmt* into a query that counts the rows it produces.

    The first target of *stmt* is given a generated name if it has none,
    and the wrapper selects ``count(<that column>)`` under the same name.
    With *suppress_all_output* the count query is wrapped once more behind
    a ``NullTest`` filter on the count column.  NOTE(review): this relies
    on ``NullTest`` meaning IS NULL, which a count never satisfies, so no
    rows come out.

    The CTEs and argument names of *stmt* are moved to the returned query.
    """
    inner_rvar = pgast.RangeSubselect(
        subquery=stmt, alias=pgast.Alias(aliasname=env.aliases.get('aggw')))

    first_target = stmt.target_list[0]
    if first_target.name is None:
        # The outer count() must reference the column by name.
        first_target = pgast.ResTarget(
            name=env.aliases.get('v'),
            val=first_target.val,
        )
        stmt.target_list[0] = first_target
        assert first_target.name is not None

    colname = first_target.name

    result = pgast.SelectStmt(
        target_list=[
            pgast.ResTarget(
                val=pgast.FuncCall(
                    name=('count', ),
                    args=[pgast.ColumnRef(name=[colname])],
                ),
                name=colname,
            ),
        ],
        from_clause=[inner_rvar],
    )

    if suppress_all_output:
        outer_rvar = pgast.RangeSubselect(
            subquery=result, alias=pgast.Alias(aliasname=env.aliases.get('q')))
        result = pgast.SelectStmt(
            target_list=[],
            from_clause=[outer_rvar],
            where_clause=pgast.NullTest(
                arg=pgast.ColumnRef(
                    name=[outer_rvar.alias.aliasname, colname]),
            ),
        )

    # Hoist CTEs and argument names onto the new top-level statement.
    result.ctes = stmt.ctes
    result.argnames = stmt.argnames
    stmt.ctes = []

    return result
Example #4
0
def range_from_queryset(
        set_ops: Sequence[Tuple[str, pgast.SelectStmt]], objname: sn.Name, *,
        ctx: context.CompilerContextLevel) -> pgast.PathRangeVar:
    """Combine a sequence of ``(set_op, query)`` pairs into one range var.

    With a single entry, the first FROM item of that query is returned
    as-is.  Otherwise the queries are folded left to right into a chain of
    ``ALL`` set operations.  The first entry's op is ignored; each later
    entry's op joins it to everything before.  The chain is wrapped in a
    subselect aliased after *objname*.
    """
    if len(set_ops) <= 1:
        # Just one class table, so return it directly
        only_rvar = set_ops[0][1].from_clause[0]
        assert isinstance(only_rvar, pgast.PathRangeVar)
        return only_rvar

    # More than one class table: build ((q0 op1 q1) op2 q2) ...
    combined = set_ops[0][1]
    for op, rarg in set_ops[1:]:
        combined = pgast.SelectStmt(
            all=True, op=op, larg=combined, rarg=rarg)

    return pgast.RangeSubselect(
        subquery=combined,
        alias=pgast.Alias(aliasname=ctx.env.aliases.get(objname.name)))
Example #5
0
def aggregate_json_output(stmt: pgast.Query, ir_set: irast.Set, *,
                          env: context.Environment) -> pgast.Query:
    """Aggregate all rows of *stmt* into a single JSON array value.

    The returned query selects ``coalesce(<json agg>(v), '[]')`` over
    *stmt*, where ``v`` is the (possibly newly named) first target of
    *stmt*.  The CTEs and argument names of *stmt* are moved to the
    returned query.  *ir_set* is not referenced here.
    """
    inner_rvar = pgast.RangeSubselect(
        subquery=stmt, alias=pgast.Alias(aliasname=env.aliases.get('aggw')))

    first_target = stmt.target_list[0]
    if first_target.name is None:
        # The aggregate must reference the column by name.
        first_target = pgast.ResTarget(
            name=env.aliases.get('v'),
            val=first_target.val,
        )
        stmt.target_list[0] = first_target

    agg_call = pgast.FuncCall(
        name=_get_json_func('agg', env=env),
        args=[pgast.ColumnRef(name=[first_target.name])],
    )
    # The aggregate can be NULL (e.g. with no input rows); fall back to '[]'.
    json_array = pgast.CoalesceExpr(
        args=[agg_call, pgast.StringConstant(val='[]')])

    result = pgast.SelectStmt(
        target_list=[pgast.ResTarget(val=json_array)],
        from_clause=[inner_rvar],
    )

    result.ctes, stmt.ctes = stmt.ctes, []
    result.argnames = stmt.argnames

    return result
Example #6
0
def range_from_queryset(
        set_ops: typing.Sequence[typing.Tuple[str, pgast.BaseRelation]],
        objname: sn.Name, *,
        env: context.Environment) -> pgast.BaseRangeVar:
    """Combine ``(set_op, query)`` pairs into a single range var.

    A single entry yields the first FROM item of its query unchanged.
    Several entries are chained into ``ALL`` set operations, left to right,
    and wrapped in a subselect aliased after *objname*.
    """
    if len(set_ops) <= 1:
        # Just one class table, so return it directly
        return set_ops[0][1].from_clause[0]

    # Each iteration completes the current node with (op, rarg) and then
    # opens a new node whose left argument is the completed one.
    head = pgast.SelectStmt(all=True, larg=set_ops[0][1])
    for op, rarg in set_ops[1:]:
        head.op = op
        head.rarg = rarg
        head = pgast.SelectStmt(all=True, larg=head)

    # The last opened node is empty; its left argument is the full chain.
    return pgast.RangeSubselect(
        subquery=head.larg,
        alias=pgast.Alias(aliasname=env.aliases.get(objname.name)),
    )
Example #7
0
def wrap_script_stmt(
    stmt: pgast.SelectStmt,
    ir_set: irast.Set,
    *,
    env: context.Environment,
) -> pgast.SelectStmt:
    """Wrap *stmt* into a query returning the number of rows it produces.

    The first target of *stmt* is given a generated name if it has none,
    and the wrapper selects ``count(<that column>)`` from *stmt* used as
    a subselect.  The CTEs and argument names of *stmt* are moved to the
    returned query.

    :param stmt: the statement whose rows are counted.
    :param ir_set: not referenced here; kept for interface compatibility.
    :return: the wrapping ``SelectStmt``.
    """
    subrvar = pgast.RangeSubselect(
        subquery=stmt, alias=pgast.Alias(aliasname=env.aliases.get('aggw')))

    stmt_res = stmt.target_list[0]

    if stmt_res.name is None:
        # count() must reference the column by name.
        stmt_res = stmt.target_list[0] = pgast.ResTarget(
            name=env.aliases.get('v'),
            val=stmt_res.val,
        )

    # No trailing comma here: it previously turned count_val into a
    # 1-tuple, so ResTarget received a tuple instead of the FuncCall.
    count_val = pgast.FuncCall(name=('count', ),
                               args=[pgast.ColumnRef(name=[stmt_res.name])])

    result = pgast.SelectStmt(target_list=[pgast.ResTarget(val=count_val)],
                              from_clause=[subrvar])

    # Hoist CTEs and argument names onto the new top-level statement.
    result.ctes = stmt.ctes
    result.argnames = stmt.argnames
    stmt.ctes = []

    return result
Example #8
0
def table_from_ptrref(
    ptrref: irast.PointerRef,
    *,
    include_descendants: bool = True,
    for_mutation: bool = False,
    ctx: context.CompilerContextLevel,
) -> pgast.RelRangeVar:
    """Return a range var over the backend link table of *ptrref*.

    The plain ``'table'`` aspect is used when mutating or when descendants
    are excluded; otherwise the ``'inhview'`` aspect is used.  Pointers
    from the ``cfg`` and ``sys`` modules are read from the ``edgedbss``
    schema.
    """
    use_plain_table = for_mutation or not include_descendants
    schema_name, tbl_name = common.get_pointer_backend_name(
        ptrref.id,
        ptrref.name.module,
        aspect='table' if use_plain_table else 'inhview',
        catenate=False,
    )

    if ptrref.name.module in {'cfg', 'sys'}:
        # Redirect all queries to schema tables to edgedbss
        schema_name = 'edgedbss'

    # Pseudo pointers (tuple and type intersection) have no schema id.
    sobj_id = None
    if isinstance(ptrref, irast.PointerRef):
        sobj_id = ptrref.id

    return pgast.RelRangeVar(
        schema_object_id=sobj_id,
        relation=pgast.Relation(schemaname=schema_name, name=tbl_name),
        include_inherited=include_descendants,
        alias=pgast.Alias(
            aliasname=ctx.env.aliases.get(ptrref.shortname.name)),
    )
Example #9
0
def top_output_as_config_op(
        ir_set: irast.Set,
        stmt: pgast.SelectStmt, *,
        env: context.Environment) -> pgast.Query:
    """Wrap *stmt* into a query emitting a CONFIGURE ... INSERT operation.

    The result selects ``jsonb_build_array('ADD', <scope>, <name>,
    <value>)`` over *stmt*, where ``<value>`` is the (possibly newly
    named) first target of *stmt*.  The CTEs and argument names of *stmt*
    are moved to the returned query.

    :raises errors.InternalServerError:
        if the config scope is anything other than SYSTEM.
    """
    assert isinstance(ir_set.expr, irast.ConfigCommand)

    scope = ir_set.expr.scope
    if scope is not qltypes.ConfigScope.SYSTEM:
        raise errors.InternalServerError(
            f'CONFIGURE {scope} INSERT is not supported')

    subrvar = pgast.RangeSubselect(
        subquery=stmt,
        alias=pgast.Alias(aliasname=env.aliases.get('cfg')),
    )

    value_target = stmt.target_list[0]
    if value_target.name is None:
        # The outer query refers to the value by column name.
        value_target = pgast.ResTarget(
            name=env.aliases.get('v'),
            val=value_target.val,
        )
        stmt.target_list[0] = value_target

    row = pgast.RowExpr(args=[
        pgast.StringConstant(val='ADD'),
        pgast.StringConstant(val=str(scope)),
        pgast.StringConstant(val=ir_set.expr.name),
        pgast.ColumnRef(name=[value_target.name]),
    ])

    op_array = pgast.FuncCall(
        name=('jsonb_build_array',),
        args=row.args,
        null_safe=True,
        ser_safe=True,
    )

    result = pgast.SelectStmt(
        target_list=[pgast.ResTarget(val=op_array)],
        from_clause=[subrvar],
    )

    # Hoist CTEs and argument names onto the new top-level statement.
    result.ctes = stmt.ctes
    result.argnames = stmt.argnames
    stmt.ctes = []

    return result
Example #10
0
def process_linkprop_update(ir_stmt: irast.MutatingStmt, ir_expr: irast.Set,
                            wrapper: pgast.Query,
                            dml_cte: pgast.CommonTableExpr, *,
                            ctx: context.CompilerContextLevel) -> None:
    """Emit a CTE that updates link properties in a link relation.

    Builds ``UPDATE <link table> SET <prop> = <expr>, ... FROM <dml_cte>
    WHERE <subject identity> = source``.  The link table is that of the
    material pointer of *ir_expr*, and there is one assignment per shape
    element of *ir_expr*.  The CTE is appended to ``ctx.toplevel_stmt``.

    :param ir_stmt:
        IR of the statement.
    :param ir_expr:
        IR of the UPDATE body element.
    :param wrapper:
        Top-level SQL query (not referenced in this body).
    :param dml_cte:
        CTE representing the SQL UPDATE to the main relation of the Object.
    """
    toplevel = ctx.toplevel_stmt

    ptrref = ir_expr.rptr.ptrref
    if ptrref.material_ptr:
        ptrref = ptrref.material_ptr

    link_tab = relctx.range_for_ptrref(
        ptrref, include_overlays=False, ctx=ctx)

    main_rvar = pgast.RelRangeVar(
        relation=dml_cte,
        alias=pgast.Alias(aliasname=ctx.env.aliases.get('m')))

    subject_identity = pathctx.get_rvar_path_identity_var(
        main_rvar, ir_stmt.subject.path_id, env=ctx.env)
    join_cond = astutils.new_binop(
        subject_identity,
        astutils.get_column(link_tab, 'source', nullable=False),
        op='=',
    )

    assignments = []
    for prop_el, shape_op in ir_expr.shape:
        assert shape_op is qlast.ShapeOp.ASSIGN
        colname = prop_el.rptr.ptrref.shortname.name
        with ctx.new() as subctx:
            subctx.expr_exposed = False
            compiled = dispatch.compile(prop_el.expr, ctx=subctx)
            assignments.append(
                pgast.UpdateTarget(name=colname, val=compiled))

    upd_cte = pgast.CommonTableExpr(
        query=pgast.UpdateStmt(
            relation=link_tab,
            where_clause=join_cond,
            targets=assignments,
            from_clause=[main_rvar],
        ),
        name=ctx.env.aliases.get(ptrref.shortname.name),
    )

    toplevel.ctes.append(upd_cte)
Example #11
0
def cte_for_query(
        rel: pgast.Query, *,
        env: context.Environment) -> pgast.CommonTableExpr:
    """Wrap query *rel* in a CTE whose alias is derived from ``rel.name``."""
    cte_alias = pgast.Alias(aliasname=env.aliases.get(rel.name))
    return pgast.CommonTableExpr(query=rel, alias=cte_alias)
Example #12
0
def range_for_ptrref(ptrref: irast.BasePointerRef,
                     *,
                     include_overlays: bool = True,
                     only_self: bool = False,
                     env: context.Environment) -> pgast.BaseRangeVar:
    """Return a Range subclass corresponding to a given ptr step.

    The return value may potentially be a UNION of all tables
    corresponding to a set of specialized links computed from the given
    `ptrref` taking source inheritance into account.

    Each link table contributes a SELECT of its ``source`` and target
    columns.  When *include_overlays* is true, overlay CTEs registered in
    ``env.rel_overlays`` for that pointer are appended after it with
    their own set operation.

    :param ptrref:
        The pointer whose storage is being ranged over.
    :param include_overlays:
        Whether to include overlay CTEs recorded in ``env.rel_overlays``.
    :param only_self:
        If true, ignore ``ptrref.descendants`` and use only *ptrref*.
    :return:
        The range var produced by ``range_from_queryset``.
    """
    tgt_col = pgtypes.get_ptrref_storage_info(ptrref,
                                              resolve_type=False,
                                              link_bias=True).column_name

    cols = ['source', tgt_col]

    set_ops = []

    if only_self:
        ptrrefs = {ptrref}
    else:
        # NOTE(review): iteration order of this set decides the order of
        # set operations (including overlays interleaved per pointer);
        # confirm that order does not matter for 'except' overlays.
        ptrrefs = {ptrref} | ptrref.descendants

    for src_ptrref in ptrrefs:
        table = table_from_ptrref(src_ptrref, env=env)

        qry = pgast.SelectStmt()
        qry.from_clause.append(table)
        qry.rptr_rvar = table

        # Make sure all property references are pulled up properly
        for colname in cols:
            selexpr = pgast.ColumnRef(name=[table.alias.aliasname, colname])
            qry.target_list.append(pgast.ResTarget(val=selexpr, name=colname))

        set_ops.append(('union', qry))

        overlays = env.rel_overlays.get(src_ptrref.shortname)
        if overlays and include_overlays:
            for op, cte in overlays:
                rvar = pgast.RangeVar(
                    relation=cte,
                    alias=pgast.Alias(aliasname=env.aliases.get(cte.name)))

                qry = pgast.SelectStmt(
                    target_list=[
                        pgast.ResTarget(val=pgast.ColumnRef(name=[col]))
                        for col in cols
                    ],
                    from_clause=[rvar],
                )
                set_ops.append((op, qry))

    rvar = range_from_queryset(set_ops, ptrref.shortname, env=env)
    return rvar
Example #13
0
def rvar_for_rel(
        rel: pgast.BaseRelation, *,
        lateral: bool=False,
        colnames: typing.Optional[typing.List[str]]=None,
        env: context.Environment) -> pgast.BaseRangeVar:
    """Wrap *rel* in a range var suitable for a FROM clause.

    A ``pgast.Query`` becomes a ``RangeSubselect`` (honouring *lateral*);
    any other relation becomes a ``RangeVar``.  The alias is allocated from
    ``env.aliases`` with the relation name as a hint (``'q'`` for unnamed
    queries).

    :param colnames:
        Optional column aliases.  Defaults to a fresh empty list per call.
        The previous ``[]`` default was one list object shared by every
        Alias created without explicit colnames.
    """
    if colnames is None:
        colnames = []

    if isinstance(rel, pgast.Query):
        alias = env.aliases.get(rel.name or 'q')

        rvar = pgast.RangeSubselect(
            subquery=rel,
            alias=pgast.Alias(aliasname=alias, colnames=colnames),
            lateral=lateral,
        )
    else:
        alias = env.aliases.get(rel.name)

        rvar = pgast.RangeVar(
            relation=rel,
            alias=pgast.Alias(aliasname=alias, colnames=colnames)
        )

    return rvar
Example #14
0
def top_output_as_config_op(
        ir_set: irast.Set,
        stmt: pgast.Query, *,
        env: context.Environment) -> pgast.Query:
    """Wrap *stmt* into a query emitting a CONFIGURE SYSTEM ... INSERT op.

    The result selects ``jsonb_build_array('ADD', 'SYSTEM', <name>,
    <value>)`` over *stmt*, where ``<value>`` is the (possibly newly named)
    first target of *stmt*.  The CTEs and argument names of *stmt* are
    moved to the returned query, as the other output wrappers do.

    :raises errors.InternalServerError:
        for session-level configuration, which is not supported.
    """
    if not ir_set.expr.system:
        raise errors.InternalServerError(
            'CONFIGURE SESSION INSERT is not supported')

    subrvar = pgast.RangeSubselect(
        subquery=stmt,
        alias=pgast.Alias(
            aliasname=env.aliases.get('cfg'),
        )
    )

    stmt_res = stmt.target_list[0]

    if stmt_res.name is None:
        # The outer query refers to the value by column name.
        stmt_res = stmt.target_list[0] = pgast.ResTarget(
            name=env.aliases.get('v'),
            val=stmt_res.val,
        )

    result_row = pgast.RowExpr(
        args=[
            pgast.StringConstant(val='ADD'),
            # Only system-level config reaches this point, so the scope is
            # always SYSTEM (the former SESSION fallback was unreachable).
            pgast.StringConstant(val='SYSTEM'),
            pgast.StringConstant(val=ir_set.expr.name),
            pgast.ColumnRef(name=[stmt_res.name]),
        ]
    )

    array = pgast.FuncCall(
        name=('jsonb_build_array',),
        args=result_row.args,
        null_safe=True,
        ser_safe=True,
    )

    result = pgast.SelectStmt(
        target_list=[
            pgast.ResTarget(
                val=array,
            ),
        ],
        from_clause=[
            subrvar,
        ],
    )

    # Hoist CTEs and argument names onto the new top-level statement.
    result.ctes = stmt.ctes
    result.argnames = stmt.argnames
    stmt.ctes = []

    return result
Example #15
0
def new_external_rvar(
    *,
    rel_name: Tuple[str, ...],
    path_id: irast.PathId,
    outputs: Mapping[Tuple[irast.PathId, Tuple[str, ...]], str],
) -> pgast.RelRangeVar:
    """Build a ``RangeVar`` over an externally defined relation.

    *rel_name* is a one-element ``(table,)`` or a two-element
    ``(schema, table)`` tuple; the returned range var over that relation
    represents the *path_id* binding.  This lets the compiler be "primed"
    with relations that exist in a larger SQL expression this expression
    is embedded into.

    *outputs* maps ``(path_id, tuple-of-aspects)`` to a column name.  Each
    listed aspect of that path id is exposed from the relation as that
    column.

    :raises AssertionError: if *rel_name* has any other length.
    """
    name_len = len(rel_name)
    if name_len == 2:
        schema_name, table_name = rel_name
    elif name_len == 1:
        schema_name = None
        (table_name,) = rel_name
    else:
        raise AssertionError(f'unexpected rvar name: {rel_name}')

    rel = pgast.Relation(
        name=table_name,
        schemaname=schema_name,
        path_id=path_id,
    )
    alias = pgast.Alias(aliasname=table_name)

    # Pointer paths do not carry a typeref on the range var.
    extra = {} if path_id.is_ptr_path() else {'typeref': path_id.target}
    rvar = pgast.RelRangeVar(relation=rel, alias=alias, **extra)

    for (out_path_id, aspects), column in outputs.items():
        colref = pgast.ColumnRef(name=[column])
        for aspect in aspects:
            rel.path_outputs[out_path_id, aspect] = colref

    return rvar
Example #16
0
def wrap_dml_cte(
    ir_stmt: irast.MutatingStmt,
    dml_cte: pgast.CommonTableExpr,
    *,
    ctx: context.CompilerContextLevel,
) -> pgast.RelRangeVar:
    """Join *dml_cte* into ``ctx.rel`` as the range of the DML subject.

    A range var over *dml_cte* (alias hint ``'d'``) is included into the
    current relation for the subject path id, the relation is bonded on
    that path, and the new range var is returned.
    """
    wrapper = ctx.rel
    subject_path_id = ir_stmt.subject.path_id

    dml_alias = pgast.Alias(aliasname=ctx.env.aliases.get('d'))
    dml_rvar = pgast.RelRangeVar(relation=dml_cte, alias=dml_alias)

    relctx.include_rvar(wrapper, dml_rvar, subject_path_id, ctx=ctx)
    pathctx.put_path_bond(wrapper, subject_path_id)

    return dml_rvar
Example #17
0
def rvar_for_rel(
        rel: typing.Union[pgast.BaseRelation, pgast.CommonTableExpr], *,
        lateral: bool=False,
        colnames: typing.Optional[typing.List[str]]=None,
        ctx: context.CompilerContextLevel) -> pgast.PathRangeVar:
    """Wrap *rel* in a range var suitable for a FROM clause.

    A ``pgast.Query`` becomes a ``RangeSubselect`` (honouring *lateral*);
    any other relation or CTE becomes a ``RelRangeVar``.  The alias is
    allocated from ``ctx.env.aliases`` with the relation name as a hint
    (``'q'`` for unnamed queries).

    :param colnames:
        Optional column aliases.  Defaults to a fresh empty list per call.
        The previous ``[]`` default was one list object shared by every
        Alias created without explicit colnames.
    """
    if colnames is None:
        colnames = []

    rvar: pgast.PathRangeVar

    if isinstance(rel, pgast.Query):
        alias = ctx.env.aliases.get(rel.name or 'q')

        rvar = pgast.RangeSubselect(
            subquery=rel,
            alias=pgast.Alias(aliasname=alias, colnames=colnames),
            lateral=lateral,
        )
    else:
        alias = ctx.env.aliases.get(rel.name)

        rvar = pgast.RelRangeVar(
            relation=rel,
            alias=pgast.Alias(aliasname=alias, colnames=colnames)
        )

    return rvar
Example #18
0
def table_from_ptrref(ptrref: irast.PointerRef, *,
                      env: context.Environment) -> pgast.RangeVar:
    """Return a range var over the backend link table of *ptrref*.

    Pointers from the ``schema``, ``cfg`` and ``sys`` modules are read
    from the ``edgedbss`` schema.  The alias is allocated from
    ``env.aliases`` using the pointer's short name.
    """
    schema_name, tbl_name = common.get_pointer_backend_name(
        ptrref.id, ptrref.module_id, catenate=False)

    shortname = ptrref.shortname
    if shortname.module in {'schema', 'cfg', 'sys'}:
        # Redirect all queries to schema tables to edgedbss
        schema_name = 'edgedbss'

    return pgast.RangeVar(
        relation=pgast.Relation(schemaname=schema_name, name=tbl_name),
        alias=pgast.Alias(aliasname=env.aliases.get(shortname.name)),
    )
Example #19
0
def range_from_queryset(
    set_ops: Sequence[Tuple[str, pgast.SelectStmt]],
    objname: sn.Name,
    *,
    path_id: Optional[irast.PathId] = None,
    ctx: context.CompilerContextLevel,
) -> pgast.PathRangeVar:
    """Combine ``(set_op, query)`` pairs into a single range var.

    With one entry, the first FROM item of its query is returned directly.
    Otherwise entries are folded left to right:

    * ``'filter'`` entries wrap the accumulated query with
      ``wrap_set_op_query`` and apply ``anti_join`` against the entry's
      query on *path_id*;
    * any other op builds ``<accumulated> <op> ALL <query>``.

    The result is wrapped in a subselect aliased after *objname*.
    """
    if len(set_ops) <= 1:
        # Just one class table, so return it directly
        only_rvar = set_ops[0][1].from_clause[0]
        assert isinstance(only_rvar, pgast.PathRangeVar)
        return only_rvar

    combined = set_ops[0][1]
    for op, rarg in set_ops[1:]:
        if op != 'filter':
            combined = pgast.SelectStmt(
                op=op,
                all=True,
                larg=combined,
                rarg=rarg,
            )
            continue

        combined = wrap_set_op_query(combined, ctx=ctx)
        anti_join(combined, rarg, path_id, ctx=ctx)

    return pgast.RangeSubselect(
        subquery=combined,
        alias=pgast.Alias(aliasname=ctx.env.aliases.get(objname.name)),
    )
Example #20
0
            with ctx.newrel() as sctx:
                sctx.pending_type_ctes.add(typeref.id)
                sctx.pending_query = sctx.rel
                dispatch.visit(rewrite, ctx=sctx)
                type_cte = pgast.CommonTableExpr(
                    name=ctx.env.aliases.get('t'),
                    query=sctx.rel,
                    materialized=False,
                )
                ctx.type_ctes[typeref.id] = type_cte

        with ctx.subrel() as sctx:
            cte_rvar = pgast.RelRangeVar(
                relation=type_cte,
                typeref=typeref,
                alias=pgast.Alias(aliasname=env.aliases.get('t')))
            pathctx.put_path_id_map(sctx.rel, path_id, rewrite.path_id)
            include_rvar(sctx.rel, cte_rvar, rewrite.path_id, ctx=sctx)
            rvar = rvar_for_rel(sctx.rel, typeref=typeref, ctx=sctx)
    else:
        assert isinstance(typeref.name_hint, sn.QualName)

        table_schema_name, table_name = common.get_objtype_backend_name(
            typeref.id,
            typeref.name_hint.module,
            aspect=('table'
                    if for_mutation or not include_descendants else 'inhview'),
            catenate=False,
        )

        if typeref.name_hint.module in {'cfg', 'sys'}:
Example #21
0
def process_link_values(
        ir_stmt, ir_expr, target_tab, col_data,
        dml_rvar, sources, props_only, target_is_scalar, iterator_cte, *,
        ctx: context.CompilerContextLevel) -> \
        typing.Tuple[pgast.CommonTableExpr, typing.List[str]]:
    """Unpack data from an update expression into a series of selects.

    :param ir_stmt:
        IR of the enclosing DML statement; its subject path id binds
        *dml_rvar*.
    :param ir_expr:
        IR of the INSERT/UPDATE body element.
    :param target_tab:
        The link table being updated (not referenced in this body).
    :param col_data:
        Expressions used to populate well-known columns of the link
        table such as `source` and `__type__`.
    :param dml_rvar:
        Range var over the main DML CTE, joined into the row query.
    :param sources:
        A list of relations which must be joined into the data query
        to resolve expressions in *col_data*.
    :param props_only:
        Whether this link update only touches link properties (not
        referenced in this body).
    :param target_is_scalar:
        Whether the link target is an ScalarType.
    :param iterator_cte:
        CTE representing the iterator range in the FOR clause of the
        EdgeQL DML statement.
    :param ctx:
        Compiler context level.  This used to be written
        ``ctx=context.CompilerContext``, which made the class object the
        default value instead of annotating the parameter.
    :return:
        A ``(link_rows, specified_cols)`` pair.  *link_rows* is a CTE
        selecting one row per link, and *specified_cols* lists its column
        names in target-list order.
    """
    with ctx.newscope() as newscope, newscope.newrel() as subrelctx:
        row_query = subrelctx.rel

        relctx.include_rvar(row_query,
                            dml_rvar,
                            path_id=ir_stmt.subject.path_id,
                            ctx=subrelctx)
        subrelctx.path_scope[ir_stmt.subject.path_id] = row_query

        if iterator_cte is not None:
            iterator_rvar = relctx.rvar_for_rel(iterator_cte,
                                                lateral=True,
                                                ctx=subrelctx)
            relctx.include_rvar(row_query,
                                iterator_rvar,
                                path_id=iterator_cte.query.path_id,
                                ctx=subrelctx)

        with subrelctx.newscope() as sctx, sctx.subrel() as input_rel_ctx:
            input_rel = input_rel_ctx.rel
            if iterator_cte is not None:
                # Make the iterator path visible to the value expression
                # through the row query.
                input_rel_ctx.path_scope[iterator_cte.query.path_id] = \
                    row_query
            input_rel_ctx.expr_exposed = False
            input_rel_ctx.volatility_ref = pathctx.get_path_identity_var(
                row_query, ir_stmt.subject.path_id, env=input_rel_ctx.env)
            dispatch.visit(ir_expr, ctx=input_rel_ctx)
            shape_tuple = None
            if ir_expr.shape:
                shape_tuple = shapecomp.compile_shape(ir_expr,
                                                      ir_expr.shape,
                                                      ctx=input_rel_ctx)

                for element in shape_tuple.elements:
                    pathctx.put_path_var_if_not_exists(input_rel_ctx.rel,
                                                       element.path_id,
                                                       element.val,
                                                       aspect='value',
                                                       env=input_rel_ctx.env)

    input_rvar = pgast.RangeSubselect(
        subquery=input_rel,
        lateral=True,
        alias=pgast.Alias(aliasname=ctx.env.aliases.get('val')))

    source_data: typing.Dict[str, pgast.BaseExpr] = {}

    path_id = ir_expr.path_id

    if shape_tuple is not None:
        # Link properties are taken from the compiled shape elements.
        for element in shape_tuple.elements:
            if not element.path_id.is_linkprop_path():
                continue
            colname = element.path_id.rptr_name().name
            val = pathctx.get_rvar_path_value_var(input_rvar,
                                                  element.path_id,
                                                  env=ctx.env)
            source_data.setdefault(colname, val)
    else:
        if target_is_scalar:
            target_ref = pathctx.get_rvar_path_value_var(input_rvar,
                                                         path_id,
                                                         env=ctx.env)
        else:
            target_ref = pathctx.get_rvar_path_identity_var(input_rvar,
                                                            path_id,
                                                            env=ctx.env)

        source_data['target'] = target_ref

    if not target_is_scalar and 'target' not in source_data:
        # Object links always need the target identity, even when the
        # shape supplied only link properties.
        target_ref = pathctx.get_rvar_path_identity_var(input_rvar,
                                                        path_id,
                                                        env=ctx.env)
        source_data['target'] = target_ref

    specified_cols = []
    # col_data takes precedence over source_data for duplicate keys.
    for col, expr in collections.ChainMap(col_data, source_data).items():
        row_query.target_list.append(pgast.ResTarget(val=expr, name=col))
        specified_cols.append(col)

    row_query.from_clause += list(sources) + [input_rvar]

    link_rows = pgast.CommonTableExpr(query=row_query,
                                      name=ctx.env.aliases.get(hint='r'))

    return link_rows, specified_cols
Example #22
0
def process_link_update(
        *, ir_stmt: irast.MutatingStmt, ir_set: irast.Set, props_only: bool,
        is_insert: bool, wrapper: pgast.Query, dml_cte: pgast.CommonTableExpr,
        iterator_cte: typing.Optional[pgast.CommonTableExpr],
        ctx: context.CompilerContextLevel) -> pgast.CommonTableExpr:
    """Perform updates to a link relation as part of a DML statement.

    The following CTEs are appended, in order, to ``ctx.toplevel_stmt``:

    1. (non-INSERT only) a DELETE of the existing link rows for each
       affected source, recorded as an ``'except'`` relation overlay;
    2. the link-data CTE built by ``process_link_values``;
    3. an INSERT of those rows into the link table, recorded as a
       ``'union'`` relation overlay.  For non-INSERT statements it uses
       ON CONFLICT ... DO UPDATE.

    :param ir_stmt:
        IR of the statement.
    :param ir_set:
        IR of the INSERT/UPDATE body element.
    :param props_only:
        Whether this link update only touches link properties.
    :param is_insert:
        True when the enclosing statement is an INSERT.  When False, old
        link rows are deleted first and the link INSERT gets an ON
        CONFLICT clause.
    :param wrapper:
        Top-level SQL query.  NOTE(review): not referenced in this body.
    :param dml_cte:
        CTE representing the SQL INSERT or UPDATE to the main
        relation of the Object.
    :param iterator_cte:
        CTE representing the iterator range in the FOR clause of the
        EdgeQL DML statement.
    :param ctx:
        Compiler context level; generated CTEs go to
        ``ctx.toplevel_stmt.ctes``.
    :return:
        The link-data CTE returned by ``process_link_values``.
    """
    toplevel = ctx.toplevel_stmt

    rptr = ir_set.rptr
    ptrref = rptr.ptrref
    assert isinstance(ptrref, irast.PointerRef)
    target_is_scalar = irtyputils.is_scalar(ptrref.dir_target)
    path_id = ir_set.path_id

    # The links in the dml class shape have been derived,
    # but we must use the correct specialized link class for the
    # base material type.
    if ptrref.material_ptr is not None:
        mptrref = ptrref.material_ptr
        assert isinstance(mptrref, irast.PointerRef)
    else:
        mptrref = ptrref

    # Only this pointer's own table, without overlays, is the target of
    # the DELETE/INSERT below.
    target_rvar = relctx.range_for_ptrref(mptrref,
                                          include_overlays=False,
                                          only_self=True,
                                          ctx=ctx)
    assert isinstance(target_rvar, pgast.RelRangeVar)
    assert isinstance(target_rvar.relation, pgast.Relation)
    target_alias = target_rvar.alias.aliasname

    target_tab_name = (target_rvar.relation.schemaname,
                       target_rvar.relation.name)

    dml_cte_rvar = pgast.RelRangeVar(
        relation=dml_cte,
        alias=pgast.Alias(aliasname=ctx.env.aliases.get('m')))

    # Columns whose values are fixed for every link row produced here:
    # the pointer id (as a uuid literal) and the identity of the source
    # object taken from the main DML CTE.
    col_data = {
        'ptr_item_id':
        pgast.TypeCast(arg=pgast.StringConstant(val=str(mptrref.id)),
                       type_name=pgast.TypeName(name=('uuid', ))),
        'source':
        pathctx.get_rvar_path_identity_var(dml_cte_rvar,
                                           ir_stmt.subject.path_id,
                                           env=ctx.env)
    }

    if not is_insert:
        # Drop all previous link records for this source.
        delcte = pgast.CommonTableExpr(query=pgast.DeleteStmt(
            relation=target_rvar,
            where_clause=astutils.new_binop(
                lexpr=col_data['source'],
                op='=',
                rexpr=pgast.ColumnRef(name=[target_alias, 'source'])),
            using_clause=[dml_cte_rvar],
            returning_list=[
                pgast.ResTarget(val=pgast.ColumnRef(
                    name=[target_alias, pgast.Star()]))
            ]),
                                       name=ctx.env.aliases.get(hint='d'))

        pathctx.put_path_value_rvar(delcte.query,
                                    path_id.ptr_path(),
                                    target_rvar,
                                    env=ctx.env)

        # Record the effect of this removal in the relation overlay
        # context to ensure that references to the link in the result
        # of this DML statement yield the expected results.
        dml_stack = get_dml_stmt_stack(ir_stmt, ctx=ctx)
        relctx.add_ptr_rel_overlay(ptrref,
                                   'except',
                                   delcte,
                                   dml_stmts=dml_stack,
                                   ctx=ctx)
        toplevel.ctes.append(delcte)

    # Turn the IR of the expression on the right side of :=
    # into a subquery returning records for the link table.
    data_cte, specified_cols = process_link_values(ir_stmt,
                                                   ir_set,
                                                   target_tab_name,
                                                   col_data,
                                                   dml_cte_rvar, [],
                                                   props_only,
                                                   target_is_scalar,
                                                   iterator_cte,
                                                   ctx=ctx)

    toplevel.ctes.append(data_cte)

    # SELECT data_cte.* FROM data_cte -- the rows fed to the INSERT.
    data_select = pgast.SelectStmt(
        target_list=[
            pgast.ResTarget(val=pgast.ColumnRef(
                name=[data_cte.name, pgast.Star()]))
        ],
        from_clause=[pgast.RelRangeVar(relation=data_cte)])

    cols = [pgast.ColumnRef(name=[col]) for col in specified_cols]

    if is_insert:
        conflict_clause = None
    else:
        # Inserting rows into the link table may produce cardinality
        # constraint violations, since the INSERT into the link table
        # is executed in the snapshot where the above DELETE from
        # the link table is not visible.  Hence, we need to use
        # the ON CONFLICT clause to resolve this.
        conflict_cols = ['source', 'target', 'ptr_item_id']
        conflict_inference = []
        conflict_exc_row = []

        for col in conflict_cols:
            conflict_inference.append(pgast.ColumnRef(name=[col]))
            conflict_exc_row.append(pgast.ColumnRef(name=['excluded', col]))

        # Re-select the data row whose conflict key matches the
        # EXCLUDED row, and use it as the source of the DO UPDATE SET.
        conflict_data = pgast.SelectStmt(
            target_list=[
                pgast.ResTarget(val=pgast.ColumnRef(
                    name=[data_cte.name, pgast.Star()]))
            ],
            from_clause=[pgast.RelRangeVar(relation=data_cte)],
            where_clause=astutils.new_binop(
                lexpr=pgast.ImplicitRowExpr(args=conflict_inference),
                rexpr=pgast.ImplicitRowExpr(args=conflict_exc_row),
                op='='))

        conflict_clause = pgast.OnConflictClause(
            action='update',
            infer=pgast.InferClause(index_elems=conflict_inference),
            target_list=[
                pgast.MultiAssignRef(columns=cols, source=conflict_data)
            ])

    updcte = pgast.CommonTableExpr(
        name=ctx.env.aliases.get(hint='i'),
        query=pgast.InsertStmt(
            relation=target_rvar,
            select_stmt=data_select,
            cols=cols,
            on_conflict=conflict_clause,
            returning_list=[
                pgast.ResTarget(val=pgast.ColumnRef(name=[pgast.Star()]))
            ]))

    pathctx.put_path_value_rvar(updcte.query,
                                path_id.ptr_path(),
                                target_rvar,
                                env=ctx.env)

    # Record the effect of this insertion in the relation overlay
    # context to ensure that references to the link in the result
    # of this DML statement yield the expected results.
    dml_stack = get_dml_stmt_stack(ir_stmt, ctx=ctx)
    relctx.add_ptr_rel_overlay(ptrref,
                               'union',
                               updcte,
                               dml_stmts=dml_stack,
                               ctx=ctx)
    toplevel.ctes.append(updcte)

    return data_cte
Example #23
0
def init_dml_stmt(
        ir_stmt: irast.MutatingStmt, dml_stmt: pgast.DMLQuery, *,
        ctx: context.CompilerContextLevel,
        parent_ctx: context.CompilerContextLevel) \
        -> typing.Tuple[pgast.Query, pgast.CommonTableExpr,
                        pgast.PathRangeVar,
                        typing.Optional[pgast.CommonTableExpr]]:
    """Prepare the common structure of the query representing a DML stmt.

    :param ir_stmt:
        IR of the statement.
    :param dml_stmt:
        SQL DML node instance.
    :param ctx:
        Compiler context of the DML statement; its ``rel`` becomes the
        wrapper and its path scope is cleared.
    :param parent_ctx:
        Enclosing compiler context; passed to ``clauses.init_stmt`` and
        used to allocate the alias of the DML range var.

    :return:
        A (*wrapper*, *dml_cte*, *dml_rvar*, *range_cte*) tuple, where
        *wrapper* is the wrapping SQL statement, *dml_cte* is the CTE
        representing the SQL DML operation in the main relation of the
        Object, *dml_rvar* is the range var over *dml_cte* that has been
        included into *wrapper*, and *range_cte* is the CTE for the subset
        affected by the statement.  *range_cte* is None unless the
        statement is an UPDATE or DELETE.
    """
    wrapper = ctx.rel

    clauses.init_stmt(ir_stmt, ctx, parent_ctx)

    target_ir_set = ir_stmt.subject

    dml_stmt.relation = relctx.range_for_typeref(
        ir_stmt.subject.typeref,
        ir_stmt.subject.path_id,
        include_overlays=False,
        common_parent=True,
        ctx=ctx,
    )
    pathctx.put_path_value_rvar(dml_stmt,
                                target_ir_set.path_id,
                                dml_stmt.relation,
                                env=ctx.env)
    pathctx.put_path_source_rvar(dml_stmt,
                                 target_ir_set.path_id,
                                 dml_stmt.relation,
                                 env=ctx.env)
    pathctx.put_path_bond(dml_stmt, target_ir_set.path_id)

    dml_cte = pgast.CommonTableExpr(query=dml_stmt,
                                    name=ctx.env.aliases.get(hint='m'))

    range_cte = None

    if isinstance(ir_stmt, (irast.UpdateStmt, irast.DeleteStmt)):
        # UPDATE and DELETE operate over a range, so generate
        # the corresponding CTE and connect it to the DML query.
        range_cte = get_dml_range(ir_stmt, dml_stmt, ctx=ctx)

        range_rvar = pgast.RelRangeVar(
            relation=range_cte,
            alias=pgast.Alias(aliasname=ctx.env.aliases.get(hint='range')))

        relctx.pull_path_namespace(target=dml_stmt, source=range_rvar, ctx=ctx)

        # Auxiliary relations are always joined via the WHERE
        # clause due to the structure of the UPDATE/DELETE SQL statements.
        dml_stmt.where_clause = astutils.new_binop(
            lexpr=pgast.ColumnRef(
                name=[dml_stmt.relation.alias.aliasname, 'id']),
            op='=',
            rexpr=pathctx.get_rvar_path_identity_var(range_rvar,
                                                     target_ir_set.path_id,
                                                     env=ctx.env))

        # UPDATE has "FROM", while DELETE has "USING".
        if isinstance(dml_stmt, pgast.UpdateStmt):
            dml_stmt.from_clause.append(range_rvar)
        elif isinstance(dml_stmt, pgast.DeleteStmt):
            dml_stmt.using_clause.append(range_rvar)

    # Due to the fact that DML statements are structured
    # as a flat list of CTEs instead of nested range vars,
    # the top level path scope must be empty.  The necessary
    # range vars will be injected explicitly in all rels that
    # need them.
    ctx.path_scope.clear()

    # NOTE(review): these repeat the value/source registrations made
    # above for the same path id (target_ir_set is ir_stmt.subject);
    # presumably kept so they take effect after pull_path_namespace —
    # confirm before removing.
    pathctx.put_path_value_rvar(dml_stmt,
                                ir_stmt.subject.path_id,
                                dml_stmt.relation,
                                env=ctx.env)

    pathctx.put_path_source_rvar(dml_stmt,
                                 ir_stmt.subject.path_id,
                                 dml_stmt.relation,
                                 env=ctx.env)

    # Expose the DML CTE to the wrapper under the subject's path id.
    dml_rvar = pgast.RelRangeVar(
        relation=dml_cte,
        alias=pgast.Alias(aliasname=parent_ctx.env.aliases.get('d')))

    relctx.include_rvar(wrapper, dml_rvar, ir_stmt.subject.path_id, ctx=ctx)

    pathctx.put_path_bond(wrapper, ir_stmt.subject.path_id)

    return wrapper, dml_cte, dml_rvar, range_cte
Example #24
0
def process_link_update(
    *,
    ir_stmt: irast.MutatingStmt,
    ir_set: irast.Set,
    props_only: bool,
    is_insert: bool,
    shape_op: qlast.ShapeOp = qlast.ShapeOp.ASSIGN,
    source_typeref: irast.TypeRef,
    wrapper: pgast.Query,
    dml_cte: pgast.CommonTableExpr,
    iterator_cte: Optional[pgast.CommonTableExpr],
    ctx: context.CompilerContextLevel,
) -> pgast.CommonTableExpr:
    """Perform updates to a link relation as part of a DML statement.

    The new link rows are computed into a data CTE.  Unless the
    statement is an INSERT or *shape_op* is APPEND, a DELETE CTE then
    removes existing link rows (all rows of the source for ASSIGN, only
    the rows whose target appears in the data for SUBTRACT).  Unless
    *shape_op* is SUBTRACT, an INSERT CTE writes the new rows.  Every
    emitted CTE is appended to the top-level statement, and the DELETE
    and INSERT CTEs are registered as relation overlays of the link.

    :param ir_stmt:
        IR of the statement.
    :param ir_set:
        IR of the INSERT/UPDATE body element.
    :param props_only:
        Whether this link update only touches link properties.
    :param is_insert:
        Whether the enclosing statement is an INSERT.  In that case no
        existing link rows are deleted and no ON CONFLICT clause is added.
    :param shape_op:
        The shape operation applied to the link (ASSIGN, APPEND or
        SUBTRACT).
    :param source_typeref:
        Type of the object being modified; used to select the PointerRef
        specialized for this source type, whose table is written.
    :param wrapper:
        Top-level SQL query.  NOTE(review): not referenced in this
        function body.
    :param dml_cte:
        CTE representing the SQL INSERT or UPDATE to the main
        relation of the Object.
    :param iterator_cte:
        CTE representing the iterator range in the FOR clause of the
        EdgeQL DML statement.
    :param ctx:
        Compiler context.

    :return:
        The CTE producing the link rows computed from *ir_set*.
    """
    toplevel = ctx.toplevel_stmt

    rptr = ir_set.rptr
    ptrref = rptr.ptrref
    assert isinstance(ptrref, irast.PointerRef)
    target_is_scalar = irtyputils.is_scalar(ir_set.typeref)
    path_id = ir_set.path_id

    # The links in the dml class shape have been derived,
    # but we must use the correct specialized link class for the
    # base material type.
    if ptrref.material_ptr is not None:
        mptrref = ptrref.material_ptr
    else:
        mptrref = ptrref

    # If the material pointer is not defined directly on the source
    # type, pick the descendant pointer whose source is that type.
    if mptrref.out_source.id != source_typeref.id:
        for descendant in mptrref.descendants:
            if descendant.out_source.id == source_typeref.id:
                mptrref = descendant
                break
        else:
            raise errors.InternalServerError(
                'missing PointerRef descriptor for source typeref')

    assert isinstance(mptrref, irast.PointerRef)

    # Range var over the link table itself (no overlays, no descendants).
    target_rvar = relctx.range_for_ptrref(mptrref,
                                          for_mutation=True,
                                          only_self=True,
                                          ctx=ctx)
    assert isinstance(target_rvar, pgast.RelRangeVar)
    assert isinstance(target_rvar.relation, pgast.Relation)
    target_alias = target_rvar.alias.aliasname

    target_tab_name = (target_rvar.relation.schemaname,
                       target_rvar.relation.name)

    dml_cte_rvar = pgast.RelRangeVar(
        relation=dml_cte,
        alias=pgast.Alias(aliasname=ctx.env.aliases.get('m')))

    # Columns shared by every new link row: the id of the pointer and
    # the identity of the object produced by the DML CTE.
    col_data = {
        'ptr_item_id':
        pgast.TypeCast(arg=pgast.StringConstant(val=str(mptrref.id)),
                       type_name=pgast.TypeName(name=('uuid', ))),
        'source':
        pathctx.get_rvar_path_identity_var(dml_cte_rvar,
                                           ir_stmt.subject.path_id,
                                           env=ctx.env)
    }

    # Turn the IR of the expression on the right side of :=
    # into a subquery returning records for the link table.
    data_cte, specified_cols = process_link_values(
        ir_stmt=ir_stmt,
        ir_expr=ir_set,
        target_tab=target_tab_name,
        col_data=col_data,
        dml_rvar=dml_cte_rvar,
        sources=[],
        props_only=props_only,
        target_is_scalar=target_is_scalar,
        iterator_cte=iterator_cte,
        ctx=ctx,
    )

    toplevel.ctes.append(data_cte)

    delqry: Optional[pgast.DeleteStmt]

    # SELECT * FROM <data_cte>; used as the source of the INSERT below
    # and, for SUBTRACT, as the set of targets to delete.
    data_select = pgast.SelectStmt(
        target_list=[
            pgast.ResTarget(val=pgast.ColumnRef(
                name=[data_cte.name, pgast.Star()]), ),
        ],
        from_clause=[
            pgast.RelRangeVar(relation=data_cte),
        ],
    )

    if not is_insert and shape_op is not qlast.ShapeOp.APPEND:
        if shape_op is qlast.ShapeOp.SUBTRACT:
            data_rvar = relctx.rvar_for_rel(data_select, ctx=ctx)

            # Drop requested link records.
            delqry = pgast.DeleteStmt(
                relation=target_rvar,
                where_clause=astutils.new_binop(
                    lexpr=astutils.new_binop(
                        lexpr=col_data['source'],
                        op='=',
                        rexpr=pgast.ColumnRef(name=[target_alias, 'source'], ),
                    ),
                    op='AND',
                    rexpr=astutils.new_binop(
                        lexpr=pgast.ColumnRef(name=[target_alias, 'target'], ),
                        op='=',
                        rexpr=pgast.ColumnRef(
                            name=[data_rvar.alias.aliasname, 'target'], ),
                    ),
                ),
                using_clause=[
                    dml_cte_rvar,
                    data_rvar,
                ],
                returning_list=[
                    pgast.ResTarget(val=pgast.ColumnRef(
                        name=[target_alias, pgast.Star()], ), )
                ])
        else:
            # Drop all previous link records for this source.
            delqry = pgast.DeleteStmt(
                relation=target_rvar,
                where_clause=astutils.new_binop(
                    lexpr=col_data['source'],
                    op='=',
                    rexpr=pgast.ColumnRef(name=[target_alias, 'source'], ),
                ),
                using_clause=[dml_cte_rvar],
                returning_list=[
                    pgast.ResTarget(val=pgast.ColumnRef(
                        name=[target_alias, pgast.Star()], ), )
                ])

        delcte = pgast.CommonTableExpr(
            name=ctx.env.aliases.get(hint='d'),
            query=delqry,
        )

        pathctx.put_path_value_rvar(delcte.query,
                                    path_id.ptr_path(),
                                    target_rvar,
                                    env=ctx.env)

        # Record the effect of this removal in the relation overlay
        # context to ensure that references to the link in the result
        # of this DML statement yield the expected results.
        dml_stack = get_dml_stmt_stack(ir_stmt, ctx=ctx)
        relctx.add_ptr_rel_overlay(ptrref,
                                   'except',
                                   delcte,
                                   dml_stmts=dml_stack,
                                   ctx=ctx)
        toplevel.ctes.append(delcte)
    else:
        delqry = None

    # SUBTRACT only removes rows; there is nothing to insert.
    if shape_op is qlast.ShapeOp.SUBTRACT:
        return data_cte

    # Columns of the link table populated by data_cte, in order.
    cols = [pgast.ColumnRef(name=[col]) for col in specified_cols]
    conflict_cols = ['source', 'target', 'ptr_item_id']

    if is_insert:
        conflict_clause = None
    elif len(cols) == len(conflict_cols) and delqry is not None:
        # There are no link properties, so we can optimize the
        # link replacement operation by omitting the overlapping
        # link rows from deletion.
        filter_select = pgast.SelectStmt(
            target_list=[
                pgast.ResTarget(val=pgast.ColumnRef(name=['source']), ),
                pgast.ResTarget(val=pgast.ColumnRef(name=['target']), ),
            ],
            from_clause=[pgast.RelRangeVar(relation=data_cte)],
        )

        # Add "AND (source, target) != ALL (SELECT source, target
        # FROM data_cte)" to the DELETE built above.
        delqry.where_clause = astutils.extend_binop(
            delqry.where_clause,
            astutils.new_binop(
                lexpr=pgast.ImplicitRowExpr(args=[
                    pgast.ColumnRef(name=['source']),
                    pgast.ColumnRef(name=['target']),
                ], ),
                rexpr=pgast.SubLink(
                    type=pgast.SubLinkType.ALL,
                    expr=filter_select,
                ),
                op='!=',
            ))

        # The kept rows already exist, so conflicting inserts are skipped.
        conflict_clause = pgast.OnConflictClause(
            action='nothing',
            infer=pgast.InferClause(index_elems=[
                pgast.ColumnRef(name=[col]) for col in conflict_cols
            ]),
        )
    else:
        # Inserting rows into the link table may produce cardinality
        # constraint violations, since the INSERT into the link table
        # is executed in the snapshot where the above DELETE from
        # the link table is not visible.  Hence, we need to use
        # the ON CONFLICT clause to resolve this.
        conflict_inference = []
        conflict_exc_row = []

        for col in conflict_cols:
            conflict_inference.append(pgast.ColumnRef(name=[col]))
            conflict_exc_row.append(pgast.ColumnRef(name=['excluded', col]))

        # On conflict, re-read the matching row of data_cte and assign
        # all specified columns from it.
        conflict_data = pgast.SelectStmt(
            target_list=[
                pgast.ResTarget(val=pgast.ColumnRef(
                    name=[data_cte.name, pgast.Star()]))
            ],
            from_clause=[pgast.RelRangeVar(relation=data_cte)],
            where_clause=astutils.new_binop(
                lexpr=pgast.ImplicitRowExpr(args=conflict_inference),
                rexpr=pgast.ImplicitRowExpr(args=conflict_exc_row),
                op='='))

        conflict_clause = pgast.OnConflictClause(
            action='update',
            infer=pgast.InferClause(index_elems=conflict_inference),
            target_list=[
                pgast.MultiAssignRef(columns=cols, source=conflict_data)
            ])

    updcte = pgast.CommonTableExpr(
        name=ctx.env.aliases.get(hint='i'),
        query=pgast.InsertStmt(
            relation=target_rvar,
            select_stmt=data_select,
            cols=cols,
            on_conflict=conflict_clause,
            returning_list=[
                pgast.ResTarget(val=pgast.ColumnRef(name=[pgast.Star()]))
            ]))

    pathctx.put_path_value_rvar(updcte.query,
                                path_id.ptr_path(),
                                target_rvar,
                                env=ctx.env)

    # Record the effect of this insertion in the relation overlay
    # context to ensure that references to the link in the result
    # of this DML statement yield the expected results.
    dml_stack = get_dml_stmt_stack(ir_stmt, ctx=ctx)
    relctx.add_ptr_rel_overlay(ptrref,
                               'union',
                               updcte,
                               dml_stmts=dml_stack,
                               ctx=ctx)
    toplevel.ctes.append(updcte)

    return data_cte
Example #25
0
def range_for_ptrref(
    ptrref: irast.BasePointerRef,
    *,
    dml_source: Optional[irast.MutatingStmt] = None,
    for_mutation: bool = False,
    only_self: bool = False,
    ctx: context.CompilerContextLevel,
) -> pgast.PathRangeVar:
    """Return a Range subclass corresponding to a given ptr step.

    The return value may potentially be a UNION of all tables
    corresponding to a set of specialized links computed from the given
    `ptrref` taking source inheritance into account.

    :param ptrref:
        The pointer to range over.  Union pointers are expanded into
        their components.
    :param dml_source:
        Passed through to the relation overlay lookup.
    :param for_mutation:
        When true, relation overlays are not applied; also passed on to
        ``table_from_ptrref``.
    :param only_self:
        When true, a union pointer with more than one component is
        rejected with InternalServerError.

    :return:
        A range var built by ``range_from_queryset`` over the per-table
        queries; the table queries name their columns ``source`` and
        ``target``.
    """

    output_cols = ('source', 'target')

    set_ops = []

    if ptrref.union_components:
        refs = ptrref.union_components
        if only_self and len(refs) > 1:
            raise errors.InternalServerError('unexpected union link')
    else:
        refs = {ptrref}
        assert isinstance(ptrref, irast.PointerRef), \
            "expected regular PointerRef"

    for src_ptrref in refs:
        assert isinstance(src_ptrref, irast.PointerRef), \
            "expected regular PointerRef"

        # Most references to inline links are dispatched to a separate
        # code path (_new_inline_pointer_rvar) by new_pointer_rvar,
        # but when we have union pointers, some might be inline.  We
        # always use the link table if it exists (because this range
        # needs to contain any link properties, for one reason.)
        ptr_info = pg_types.get_ptrref_storage_info(
            src_ptrref,
            resolve_type=False,
            link_bias=True,
        )
        if not ptr_info:
            assert ptrref.union_components
            ptr_info = pg_types.get_ptrref_storage_info(
                src_ptrref,
                resolve_type=False,
                link_bias=False,
            )

        # Link tables key rows by 'source'; inline pointers live on the
        # object table, keyed by 'id'.
        cols = [
            'source' if ptr_info.table_type == 'link' else 'id',
            ptr_info.column_name,
        ]

        table = table_from_ptrref(
            src_ptrref,
            ptr_info,
            include_descendants=not ptrref.union_is_concrete,
            for_mutation=for_mutation,
            ctx=ctx,
        )

        qry = pgast.SelectStmt()
        qry.from_clause.append(table)

        # Make sure all property references are pulled up properly
        for colname, output_colname in zip(cols, output_cols):
            selexpr = pgast.ColumnRef(name=[table.alias.aliasname, colname])
            qry.target_list.append(
                pgast.ResTarget(val=selexpr, name=output_colname))

        set_ops.append(('union', qry))

        overlays = get_ptr_rel_overlays(src_ptrref,
                                        dml_source=dml_source,
                                        ctx=ctx)
        if overlays and not for_mutation:
            for op, cte in overlays:
                rvar = pgast.RelRangeVar(
                    relation=cte,
                    alias=pgast.Alias(aliasname=ctx.env.aliases.get(cte.name)))

                qry = pgast.SelectStmt(
                    target_list=[
                        pgast.ResTarget(val=pgast.ColumnRef(name=[col]))
                        for col in cols
                    ],
                    from_clause=[rvar],
                )
                set_ops.append((op, qry))

    return range_from_queryset(set_ops, ptrref.shortname, ctx=ctx)
Example #26
0
def range_for_material_objtype(
        typeref: irast.TypeRef,
        path_id: irast.PathId,
        *,
        include_overlays: bool = True,
        include_descendants: bool = True,
        dml_source: Optional[irast.MutatingStmt] = None,
        ctx: context.CompilerContextLevel) -> pgast.PathRangeVar:
    """Return a range var over the backend table of an object type.

    :param typeref:
        The object type; its material type is used when it has one.
    :param path_id:
        Path id the resulting range is registered under.
    :param include_overlays:
        When true and DML relation overlays exist for the type, the
        table is combined with the overlay CTEs into a set operation.
    :param include_descendants:
        Stored as ``include_inherited`` on the table range var.
    :param dml_source:
        Passed through to the relation overlay lookup.

    :return:
        The plain table range var, or the range produced by
        ``range_from_queryset`` over the table and its overlays.
    """
    env = ctx.env

    if typeref.material_type is not None:
        typeref = typeref.material_type

    table_schema_name, table_name = common.get_objtype_backend_name(
        typeref.id, typeref.module_id, catenate=False)

    if typeref.name_hint.module in {'cfg', 'sys'}:
        # Redirect queries against the cfg and sys modules to edgedbss
        table_schema_name = 'edgedbss'

    relation = pgast.Relation(
        schemaname=table_schema_name,
        name=table_name,
        path_id=path_id,
    )

    rvar: pgast.PathRangeVar = pgast.RelRangeVar(
        relation=relation,
        typeref=typeref,
        include_inherited=include_descendants,
        alias=pgast.Alias(aliasname=env.aliases.get(typeref.name_hint.name)))

    overlays = get_type_rel_overlays(typeref, dml_source=dml_source, ctx=ctx)
    if not (overlays and include_overlays):
        return rvar

    is_objtype = path_id.is_objtype_path()
    set_ops = []

    # The base table itself is the first member of the set operation.
    base_qry = pgast.SelectStmt()
    base_qry.from_clause.append(rvar)
    pathctx.put_path_value_rvar(base_qry, path_id, rvar, env=env)
    if is_objtype:
        pathctx.put_path_source_rvar(base_qry, path_id, rvar, env=env)
    pathctx.put_path_bond(base_qry, path_id)

    set_ops.append(('union', base_qry))

    for op, cte, cte_path_id in overlays:
        cte_rvar = pgast.RelRangeVar(
            relation=cte,
            typeref=typeref,
            alias=pgast.Alias(aliasname=env.aliases.get(hint=cte.name)))

        # Expose the overlay CTE under its own path id...
        cte_qry = pgast.SelectStmt(from_clause=[cte_rvar], )

        pathctx.put_path_value_rvar(cte_qry, cte_path_id, cte_rvar, env=env)
        if is_objtype:
            pathctx.put_path_source_rvar(
                cte_qry, cte_path_id, cte_rvar, env=env)
        pathctx.put_path_bond(cte_qry, cte_path_id)

        # ...and map the range's path id onto it.
        cte_qry.view_path_id_map[path_id] = cte_path_id

        sub_rvar = pgast.RangeSubselect(
            subquery=cte_qry,
            alias=pgast.Alias(aliasname=env.aliases.get(hint=cte.name)))

        # Wrap it so that the overlay is addressable by *path_id*.
        wrap_qry = pgast.SelectStmt(from_clause=[sub_rvar])
        pathctx.put_path_value_rvar(wrap_qry, path_id, sub_rvar, env=env)
        if is_objtype:
            pathctx.put_path_source_rvar(
                wrap_qry, path_id, sub_rvar, env=env)
        pathctx.put_path_bond(wrap_qry, path_id)

        if op == 'replace':
            # A 'replace' overlay discards everything collected so far.
            op = 'union'
            set_ops = []
        set_ops.append((op, wrap_qry))

    return range_from_queryset(set_ops, typeref.name_hint, ctx=ctx)
Example #27
0
def range_for_ptrref(
    ptrref: irast.BasePointerRef,
    *,
    dml_source: Optional[irast.MutatingStmt] = None,
    for_mutation: bool = False,
    only_self: bool = False,
    ctx: context.CompilerContextLevel,
) -> pgast.PathRangeVar:
    """Return a Range subclass corresponding to a given ptr step.

    The return value may potentially be a UNION of all tables
    corresponding to a set of specialized links computed from the given
    `ptrref` taking source inheritance into account.

    :param ptrref:
        The pointer to range over.  Union pointers are expanded into
        their components.
    :param dml_source:
        Passed through to the relation overlay lookup.
    :param for_mutation:
        When true, relation overlays are not applied; also passed on to
        ``table_from_ptrref``.
    :param only_self:
        When true, a union pointer with more than one component is
        rejected with InternalServerError.

    :return:
        A range var built by ``range_from_queryset`` over the per-table
        queries, selecting ``source`` and the link target column.
    """
    # Link-biased storage info: the target column of the link table.
    tgt_col = pg_types.get_ptrref_storage_info(ptrref,
                                               resolve_type=False,
                                               link_bias=True).column_name

    cols = ['source', tgt_col]

    set_ops = []

    if ptrref.union_components:
        refs = ptrref.union_components
        if only_self and len(refs) > 1:
            raise errors.InternalServerError('unexpected union link')
    else:
        refs = {ptrref}
        assert isinstance(ptrref, irast.PointerRef), \
            "expected regular PointerRef"

    for src_ptrref in refs:
        assert isinstance(src_ptrref, irast.PointerRef), \
            "expected regular PointerRef"
        table = table_from_ptrref(
            src_ptrref,
            include_descendants=not ptrref.union_is_concrete,
            for_mutation=for_mutation,
            ctx=ctx,
        )

        qry = pgast.SelectStmt()
        qry.from_clause.append(table)

        # Make sure all property references are pulled up properly
        for colname in cols:
            selexpr = pgast.ColumnRef(name=[table.alias.aliasname, colname])
            qry.target_list.append(pgast.ResTarget(val=selexpr, name=colname))

        set_ops.append(('union', qry))

        overlays = get_ptr_rel_overlays(src_ptrref,
                                        dml_source=dml_source,
                                        ctx=ctx)
        if overlays and not for_mutation:
            for op, cte in overlays:
                rvar = pgast.RelRangeVar(
                    relation=cte,
                    alias=pgast.Alias(aliasname=ctx.env.aliases.get(cte.name)))

                qry = pgast.SelectStmt(
                    target_list=[
                        pgast.ResTarget(val=pgast.ColumnRef(name=[col]))
                        for col in cols
                    ],
                    from_clause=[rvar],
                )
                set_ops.append((op, qry))

    return range_from_queryset(set_ops, ptrref.shortname, ctx=ctx)
Example #28
0
def process_link_values(
    *,
    ir_stmt: irast.MutatingStmt,
    ir_expr: irast.Set,
    target_tab: Tuple[str, ...],
    col_data: Mapping[str, pgast.BaseExpr],
    dml_rvar: pgast.PathRangeVar,
    sources: Iterable[pgast.BaseRangeVar],
    props_only: bool,
    target_is_scalar: bool,
    iterator_cte: Optional[pgast.CommonTableExpr],
    ctx: context.CompilerContextLevel,
) -> Tuple[pgast.CommonTableExpr, List[str]]:
    """Unpack data from an update expression into a series of selects.

    :param ir_stmt:
        IR of the DML statement; its subject path is bound to *dml_rvar*
        in the generated row query.
    :param ir_expr:
        IR of the INSERT/UPDATE body element.
    :param target_tab:
        The link table being updated.
    :param col_data:
        Expressions used to populate well-known columns of the link
        table such as `source` and `ptr_item_id`.  These take precedence
        over same-named columns computed from *ir_expr*.
    :param dml_rvar:
        Range var over the DML CTE of the main relation.
    :param sources:
        A list of relations which must be joined into the data query
        to resolve expressions in *col_data*.
    :param props_only:
        Whether this link update only touches link properties.
    :param target_is_scalar:
        Whether the link target is a ScalarType.
    :param iterator_cte:
        CTE representing the iterator range in the FOR clause of the
        EdgeQL DML statement.

    :return:
        A (*link_rows*, *specified_cols*) tuple, where *link_rows* is a
        CTE selecting the link table rows and *specified_cols* lists the
        names of its output columns, in order.
    """
    with ctx.newscope() as newscope, newscope.newrel() as subrelctx:
        row_query = subrelctx.rel

        relctx.include_rvar(row_query,
                            dml_rvar,
                            path_id=ir_stmt.subject.path_id,
                            ctx=subrelctx)
        subrelctx.path_scope[ir_stmt.subject.path_id] = row_query

        if iterator_cte is not None:
            iterator_rvar = relctx.rvar_for_rel(iterator_cte,
                                                lateral=True,
                                                ctx=subrelctx)
            relctx.include_rvar(row_query,
                                iterator_rvar,
                                path_id=iterator_cte.query.path_id,
                                ctx=subrelctx)

        with subrelctx.newscope() as sctx, sctx.subrel() as input_rel_ctx:
            input_rel = input_rel_ctx.rel
            if iterator_cte is not None:
                input_rel_ctx.path_scope[iterator_cte.query.path_id] = \
                    row_query
            input_rel_ctx.expr_exposed = False
            input_rel_ctx.volatility_ref = pathctx.get_path_identity_var(
                row_query, ir_stmt.subject.path_id, env=input_rel_ctx.env)
            dispatch.visit(ir_expr, ctx=input_rel_ctx)
            if (isinstance(ir_expr.expr, irast.Stmt)
                    and ir_expr.expr.iterator_stmt is not None):
                # The link value is computed by a FOR expression,
                # check if the statement is a DML statement, and if so,
                # pull the iterator scope so that link property expressions
                # have the correct context.
                inner_iterator_cte = None
                inner_iterator_path_id = ir_expr.expr.iterator_stmt.path_id
                for cte in input_rel_ctx.toplevel_stmt.ctes:
                    if cte.query.path_id == inner_iterator_path_id:
                        inner_iterator_cte = cte
                        break
                if inner_iterator_cte is not None:
                    target_rvar = pathctx.get_path_rvar(
                        input_rel,
                        ir_expr.path_id,
                        aspect='identity',
                        env=input_rel_ctx.env,
                    )

                    pathctx.put_path_rvar(
                        input_rel,
                        inner_iterator_path_id,
                        rvar=target_rvar,
                        aspect='identity',
                        env=input_rel_ctx.env,
                    )

                    inner_iterator_rvar = relctx.rvar_for_rel(
                        inner_iterator_cte, lateral=True, ctx=subrelctx)

                    relctx.include_rvar(
                        input_rel,
                        inner_iterator_rvar,
                        path_id=inner_iterator_path_id,
                        ctx=subrelctx,
                    )

                    input_rel_ctx.path_scope[inner_iterator_path_id] = (
                        input_rel)

            shape_tuple = None
            if ir_expr.shape:
                shape_tuple = shapecomp.compile_shape(
                    ir_expr,
                    [expr for expr, _ in ir_expr.shape],
                    ctx=input_rel_ctx,
                )

                for element in shape_tuple.elements:
                    pathctx.put_path_var_if_not_exists(input_rel_ctx.rel,
                                                       element.path_id,
                                                       element.val,
                                                       aspect='value',
                                                       env=input_rel_ctx.env)

    input_rvar = pgast.RangeSubselect(
        subquery=input_rel,
        lateral=True,
        alias=pgast.Alias(aliasname=ctx.env.aliases.get('val')))

    source_data: Dict[str, pgast.BaseExpr] = {}

    path_id = ir_expr.path_id

    if shape_tuple is not None:
        # A shape on the value carries link properties; each becomes a
        # column named after the property.
        for element in shape_tuple.elements:
            if not element.path_id.is_linkprop_path():
                continue
            rptr_name = element.path_id.rptr_name()
            assert rptr_name is not None
            colname = rptr_name.name
            val = pathctx.get_rvar_path_value_var(input_rvar,
                                                  element.path_id,
                                                  env=ctx.env)
            source_data.setdefault(colname, val)
    else:
        if target_is_scalar:
            target_ref = pathctx.get_rvar_path_value_var(input_rvar,
                                                         path_id,
                                                         env=ctx.env)
        else:
            target_ref = pathctx.get_rvar_path_identity_var(input_rvar,
                                                            path_id,
                                                            env=ctx.env)

        source_data['target'] = target_ref

    # Object links always need a target, even when the shape did not
    # supply one.
    if not target_is_scalar and 'target' not in source_data:
        target_ref = pathctx.get_rvar_path_identity_var(input_rvar,
                                                        path_id,
                                                        env=ctx.env)
        source_data['target'] = target_ref

    # ChainMap lookups prefer col_data over source_data.
    specified_cols = []
    for col, expr in collections.ChainMap(col_data, source_data).items():
        row_query.target_list.append(pgast.ResTarget(val=expr, name=col))
        specified_cols.append(col)

    row_query.from_clause += list(sources) + [input_rvar]

    link_rows = pgast.CommonTableExpr(query=row_query,
                                      name=ctx.env.aliases.get(hint='r'))

    return link_rows, specified_cols
Example #29
0
def array_as_json_object(
    expr: pgast.BaseExpr,
    *,
    styperef: irast.TypeRef,
    env: context.Environment,
) -> pgast.BaseExpr:
    """Return an SQL expression serializing an array value as JSON.

    Arrays of tuples are unnested and each element is serialized as a
    JSON object (named tuples) or a JSON array (unnamed tuples); the
    results are aggregated and COALESCEd with ``'[]'``.  Any other array
    is passed directly to the function returned by
    ``_get_json_func('to')``.

    :param expr:
        SQL expression producing the array value.
    :param styperef:
        Type of the array; its first subtype is the element type.
    :param env:
        Compilation environment (alias allocation, JSON function choice).
    """
    el_type = styperef.subtypes[0]

    if not irtyputils.is_tuple(el_type):
        return pgast.FuncCall(
            name=_get_json_func('to', env=env), args=[expr],
            null_safe=True, ser_safe=True)

    coldeflist = []
    json_args: List[pgast.BaseExpr] = []
    is_named = any(st.element_name for st in el_type.subtypes)
    # Column definition list is only allowed for functions
    # returning "record", i.e. an anonymous tuple, which
    # would not be the case for schema-persistent tuple types.
    needs_coldeflist = not irtyputils.is_persistent_tuple(el_type)

    for i, st in enumerate(el_type.subtypes):
        if is_named:
            colname = st.element_name
            # Named tuples are built as objects: emit the key first.
            json_args.append(pgast.StringConstant(val=colname))
        else:
            colname = str(i)

        val: pgast.BaseExpr = pgast.ColumnRef(name=[colname])
        if irtyputils.is_collection(st):
            val = coll_as_json_object(val, styperef=st, env=env)

        json_args.append(val)

        if needs_coldeflist:
            coldeflist.append(
                pgast.ColumnDef(
                    name=colname,
                    typename=pgast.TypeName(
                        name=pgtypes.pg_type_from_ir_typeref(st)
                    )
                )
            )

    if is_named:
        json_func = _get_json_func('build_object', env=env)
    else:
        json_func = _get_json_func('build_array', env=env)

    return pgast.SelectStmt(
        target_list=[
            pgast.ResTarget(
                val=pgast.CoalesceExpr(
                    args=[
                        pgast.FuncCall(
                            name=_get_json_func('agg', env=env),
                            args=[
                                pgast.FuncCall(
                                    name=json_func,
                                    args=json_args,
                                )
                            ]
                        ),
                        # The aggregate is NULL when there are no rows.
                        pgast.StringConstant(val='[]'),
                    ]
                ),
                ser_safe=True,
            )
        ],
        from_clause=[
            pgast.RangeFunction(
                alias=pgast.Alias(
                    aliasname=env.aliases.get('q'),
                ),
                is_rowsfrom=True,
                functions=[
                    pgast.FuncCall(
                        name=('unnest',),
                        args=[expr],
                        coldeflist=coldeflist,
                    )
                ]
            )
        ]
    )
Example #30
0
def init_dml_stmt(
    ir_stmt: irast.MutatingStmt,
    *,
    ctx: context.CompilerContextLevel,
    parent_ctx: context.CompilerContextLevel,
) -> DMLParts:
    """Prepare the common structure of the query representing a DML stmt.

    A separate DML CTE is generated for every affected object type.  For
    UPDATE and DELETE that is the (material) subject type plus any union
    components and descendants; for other statements it is just the
    subject type.  When more than one DML CTE is produced, their range
    vars are combined with ``UNION ALL`` into one extra CTE.  The combined
    (or single) range var is then included into ``ctx.rel`` under the
    subject path.

    :param ir_stmt:
        IR of the DML statement.
    :param ctx:
        Compiler context of the statement being compiled; ``ctx.rel``
        receives the combined range var and ``ctx.dml_stmts`` records the
        resulting CTE for *ir_stmt*.
    :param parent_ctx:
        Enclosing compiler context, forwarded to ``clauses.init_stmt``.

    :return:
        A DMLParts tuple with the per-type map of DML CTEs and range vars,
        the common range CTE for UPDATE/DELETE statements (``None``
        otherwise), and the CTE representing the combined DML result (the
        single DML CTE itself when only one type is affected).
    """
    clauses.init_stmt(ir_stmt, ctx, parent_ctx)

    subject = ir_stmt.subject
    # Only UPDATE and DELETE operate over an existing range of objects.
    over_range = isinstance(ir_stmt, (irast.UpdateStmt, irast.DeleteStmt))

    range_cte: Optional[pgast.CommonTableExpr] = None
    range_rvar: Optional[pgast.RelRangeVar] = None

    if over_range:
        # Compute the range once and connect it to every per-type DML
        # statement generated below.
        range_cte = get_dml_range(ir_stmt, ctx=ctx)
        range_alias = ctx.env.aliases.get(hint='range')
        range_rvar = pgast.RelRangeVar(
            relation=range_cte,
            alias=pgast.Alias(aliasname=range_alias),
        )

    top_typeref = subject.typeref.material_type or subject.typeref
    typerefs = [top_typeref]

    if over_range:
        # Ranged DML must also reach union components (and their
        # descendants) as well as descendants of the subject type itself.
        for member in top_typeref.union or ():
            member = member.material_type or member
            typerefs.append(member)
            typerefs.extend(member.descendants or ())

        typerefs.extend(top_typeref.descendants or ())

    dml_map = {}
    for typeref in typerefs:
        cte, rvar = gen_dml_cte(
            ir_stmt,
            range_rvar=range_rvar,
            typeref=typeref,
            ctx=ctx,
        )
        dml_map[typeref] = (cte, rvar)

    if len(dml_map) > 1:
        # Several object types are affected: wrap each DML range var into
        # its own SELECT and chain them into a left-deep UNION ALL.
        branches = []
        for _, rvar in dml_map.values():
            branch = pgast.SelectStmt()
            relctx.include_rvar(branch, rvar, subject.path_id, ctx=ctx)
            branches.append(branch)

        union_qry = branches[0]
        for branch in branches[1:]:
            union_qry = pgast.SelectStmt(
                all=True,
                op='UNION',
                larg=union_qry,
                rarg=branch,
            )

        union_cte = pgast.CommonTableExpr(
            query=union_qry,
            name=ctx.env.aliases.get(hint='ma'),
        )
        union_rvar = relctx.rvar_for_rel(
            union_cte,
            typeref=subject.typeref,
            ctx=ctx,
        )
    else:
        # Exactly one type: its DML CTE doubles as the combined result.
        [(union_cte, union_rvar)] = dml_map.values()

    relctx.include_rvar(ctx.rel, union_rvar, subject.path_id, ctx=ctx)
    pathctx.put_path_bond(ctx.rel, subject.path_id)

    ctx.dml_stmts[ir_stmt] = union_cte

    return DMLParts(dml_ctes=dml_map, range_cte=range_cte, union_cte=union_cte)