Ejemplo n.º 1
0
def get_rvar_path_var(rvar: pgast.PathRangeVar, path_id: irast.PathId,
                      aspect: str, *,
                      env: context.Environment) -> pgast.OutputVar:
    """Return ColumnRef for a given *path_id* in a given *range var*.

    :param rvar:
        Range var whose underlying query is searched for the path.
    :param path_id:
        Path to look up.
    :param aspect:
        Aspect of the path; together with *path_id* it forms the key
        into ``rvar.query.path_outputs``.
    :param env:
        Compiler environment, forwarded to ``_get_rel_path_output``.
    """
    # NOTE(review): this excerpt ends after the ``elif`` branch -- no
    # ``else`` branch and no ``return`` statement are visible here, so the
    # tail of the function is presumably missing from this view.

    if (path_id, aspect) in rvar.query.path_outputs:
        # The query already has an output recorded for this path/aspect.
        outvar = rvar.query.path_outputs[path_id, aspect]
    elif is_relation_rvar(rvar):
        # Storage info is looked up only when all of the following hold:
        # the path ends in a pointer, the range var carries a typeref,
        # the query's own path differs from *path_id*, and the query's
        # path is not a type intersection whose source is *path_id*.
        if ((rptr := path_id.rptr()) is not None and rvar.typeref is not None
                and rvar.query.path_id != path_id
                and (not rvar.query.path_id.is_type_intersection_path()
                     or rvar.query.path_id.src_path() != path_id)):
            # Resolve the pointer against the range var's type.
            actual_rptr = irtyputils.find_actual_ptrref(
                rvar.typeref,
                rptr,
            )
            if actual_rptr is not None:
                ptr_si = pg_types.get_ptrref_storage_info(actual_rptr)
            else:
                ptr_si = None
        else:
            ptr_si = None

        # ptr_si may be None here; it is passed through as-is.
        outvar = _get_rel_path_output(rvar.query,
                                      path_id,
                                      ptr_info=ptr_si,
                                      aspect=aspect,
                                      env=env)
Ejemplo n.º 2
0
def process_update_body(
    ir_stmt: irast.MutatingStmt,
    wrapper: pgast.Query,
    update_cte: pgast.CommonTableExpr,
    typeref: irast.TypeRef,
    *,
    ctx: context.CompilerContextLevel,
) -> None:
    """Compile the body of an UpdateStmt IR into SQL DML CTEs.

    Shape elements whose storage table type is ``'ObjectType'`` become
    targets of the SQL UPDATE held by *update_cte*.  Shape elements whose
    storage table type is ``'link'`` are passed to ``process_link_update``
    after *update_cte* has been appended to the top-level statement.
    If no UPDATE targets were produced, the UPDATE is replaced with an
    equivalent SELECT.

    :param ir_stmt:
        IR of the statement.
    :param wrapper:
        Top-level SQL query.
    :param update_cte:
        CTE representing the SQL UPDATE to the main relation of the Object.
    :param typeref:
        The specific TypeRef of a set being updated.
    :param ctx:
        Compiler context level used for compilation.
    """
    upd_query = update_cte.query
    assert isinstance(upd_query, pgast.UpdateStmt)

    outer_iterator = ctx.enclosing_cte_iterator
    if outer_iterator:
        pathctx.put_path_bond(upd_query, outer_iterator.path_id)

    # (shape element, shape operation, props-only flag) triples for
    # pointers that live in a link table.
    link_updates = []

    with ctx.newscope() as bodyctx:
        # Shape expressions are compiled with the UPDATE statement as
        # their parent relation, so that references to the current
        # values of the updated object resolve correctly.
        bodyctx.parent_rel = upd_query
        bodyctx.expr_exposed = False

        for element, op in ir_stmt.subject.shape:
            el_ptrref = element.rptr.ptrref
            assert isinstance(el_ptrref, irast.PointerRef)
            real_ptrref = irtyputils.find_actual_ptrref(typeref, el_ptrref)
            new_value = element.expr
            column_info = pg_types.get_ptrref_storage_info(
                real_ptrref,
                resolve_type=True,
                link_bias=False,
            )

            stored_inline = column_info.table_type == 'ObjectType'
            if stored_inline and new_value is not None:
                with bodyctx.newscope() as valctx:
                    sql_val: pgast.BaseExpr

                    if not irtyputils.is_tuple(element.typeref):
                        if (isinstance(new_value, irast.MutatingStmt)
                                and new_value in ctx.dml_stmts):
                            # The value is a DML statement that already
                            # has a CTE: wrap that CTE and use the
                            # wrapping relation as the value.
                            with valctx.substmt() as dmlctx:
                                wrap_dml_cte(
                                    new_value,
                                    ctx.dml_stmts[new_value],
                                    ctx=dmlctx,
                                )
                                pathctx.get_path_identity_output(
                                    dmlctx.rel,
                                    new_value.subject.path_id,
                                    env=dmlctx.env,
                                )
                                sql_val = dmlctx.rel
                        else:
                            sql_val = dispatch.compile(
                                new_value, ctx=valctx)

                        sql_val = pgast.TypeCast(
                            arg=sql_val,
                            type_name=pgast.TypeName(
                                name=column_info.column_type),
                        )
                    else:
                        # Tuple targets must be compiled into a subquery
                        # returning a single column that is explicitly
                        # cast into the appropriate composite type.
                        sql_val = relgen.set_as_subquery(
                            element,
                            as_value=True,
                            explicit_cast=column_info.column_type,
                            ctx=valctx,
                        )

                    if op is qlast.ShapeOp.SUBTRACT:
                        column_ref = pgast.ColumnRef(
                            name=(column_info.column_name, ))
                        sql_val = pgast.FuncCall(
                            name=('nullif', ),
                            args=[column_ref, sql_val],
                        )

                    upd_query.targets.append(
                        pgast.UpdateTarget(
                            name=column_info.column_name,
                            val=sql_val,
                        )
                    )

            props_only = is_props_only_update(element, ctx=bodyctx)

            link_info = pg_types.get_ptrref_storage_info(
                real_ptrref,
                resolve_type=False,
                link_bias=True,
            )

            if link_info and link_info.table_type == 'link':
                link_updates.append((element, op, props_only))

    if not upd_query.targets:
        # Nothing is written to the set target table itself,
        # so the UPDATE statement is turned into a SELECT.
        select_from: List[pgast.BaseRangeVar] = [
            upd_query.relation,
            *upd_query.from_clause,
        ]
        update_cte.query = pgast.SelectStmt(
            ctes=upd_query.ctes,
            target_list=upd_query.returning_list,
            from_clause=select_from,
            where_clause=upd_query.where_clause,
            path_namespace=upd_query.path_namespace,
            path_outputs=upd_query.path_outputs,
            path_scope=upd_query.path_scope,
            path_rvar_map=upd_query.path_rvar_map.copy(),
            view_path_id_map=upd_query.view_path_id_map.copy(),
            ptr_join_map=upd_query.ptr_join_map.copy(),
        )

    ctx.toplevel_stmt.ctes.append(update_cte)

    # Apply the collected updates to the link tables.
    for link_el, link_op, _ in link_updates:
        process_link_update(
            ir_stmt=ir_stmt,
            ir_set=link_el,
            props_only=False,
            wrapper=wrapper,
            dml_cte=update_cte,
            iterator=None,
            is_insert=False,
            shape_op=link_op,
            source_typeref=typeref,
            ctx=ctx,
        )
Ejemplo n.º 3
0
def process_link_update(
    *,
    ir_stmt: irast.MutatingStmt,
    ir_set: irast.Set,
    props_only: bool,
    is_insert: bool,
    shape_op: qlast.ShapeOp = qlast.ShapeOp.ASSIGN,
    source_typeref: irast.TypeRef,
    wrapper: pgast.Query,
    dml_cte: pgast.CommonTableExpr,
    iterator: Optional[pgast.IteratorCTE],
    ctx: context.CompilerContextLevel,
) -> pgast.CommonTableExpr:
    """Perform updates to a link relation as part of a DML statement.

    The function appends up to three CTEs to the top-level statement:
    a data CTE with the link rows computed from *ir_set*, a DELETE CTE
    (skipped for inserts and for APPEND operations), and an INSERT CTE
    (skipped for SUBTRACT operations).

    :param ir_stmt:
        IR of the statement.
    :param ir_set:
        IR of the INSERT/UPDATE body element.
    :param props_only:
        Whether this link update only touches link properties.
        (Not read anywhere in this function's body.)
    :param is_insert:
        Whether the enclosing statement is an INSERT.  When true, no
        DELETE is generated and the INSERT has no ON CONFLICT clause.
    :param shape_op:
        Shape operation applied to the link.  APPEND skips the DELETE,
        SUBTRACT deletes only the rows matching the computed targets
        and skips the INSERT; any other value deletes all previous rows
        for the source before inserting.
    :param source_typeref:
        TypeRef against which the actual pointer ref is resolved.
    :param wrapper:
        Top-level SQL query.
    :param dml_cte:
        CTE representing the SQL INSERT or UPDATE to the main
        relation of the Object.
    :param iterator:
        IR and CTE representing the iterator range in the FOR clause of the
        EdgeQL DML statement.
    :param ctx:
        Compiler context level used for compilation.

    :return:
        The data CTE produced by ``process_link_values``.
    """
    toplevel = ctx.toplevel_stmt

    rptr = ir_set.rptr
    ptrref = rptr.ptrref
    assert isinstance(ptrref, irast.PointerRef)
    target_is_scalar = irtyputils.is_scalar(ir_set.typeref)
    path_id = ir_set.path_id

    # The links in the dml class shape have been derived,
    # but we must use the correct specialized link class for the
    # base material type.
    mptrref = irtyputils.find_actual_ptrref(source_typeref, ptrref)
    assert isinstance(mptrref, irast.PointerRef)

    # Range var over the relation of the resolved pointer; this is the
    # relation the DELETE and INSERT statements below operate on.
    target_rvar = relctx.range_for_ptrref(mptrref,
                                          for_mutation=True,
                                          only_self=True,
                                          ctx=ctx)
    assert isinstance(target_rvar, pgast.RelRangeVar)
    assert isinstance(target_rvar.relation, pgast.Relation)
    target_alias = target_rvar.alias.aliasname

    # Range var over the main DML CTE, used to obtain the identity
    # of the object being inserted/updated.
    dml_cte_rvar = pgast.RelRangeVar(
        relation=dml_cte,
        alias=pgast.Alias(aliasname=ctx.env.aliases.get('m')))

    # Values for the well-known link table columns: the pointer id
    # (as a uuid-cast string constant) and the source object identity.
    col_data = {
        'ptr_item_id':
        pgast.TypeCast(arg=pgast.StringConstant(val=str(mptrref.id)),
                       type_name=pgast.TypeName(name=('uuid', ))),
        'source':
        pathctx.get_rvar_path_identity_var(dml_cte_rvar,
                                           ir_stmt.subject.path_id,
                                           env=ctx.env)
    }

    # Turn the IR of the expression on the right side of :=
    # into a subquery returning records for the link table.
    data_cte, specified_cols = process_link_values(
        ir_stmt=ir_stmt,
        ir_expr=ir_set,
        col_data=col_data,
        dml_rvar=dml_cte_rvar,
        sources=[],
        source_typeref=source_typeref,
        target_is_scalar=target_is_scalar,
        dml_cte=dml_cte,
        iterator=iterator,
        ctx=ctx,
    )

    toplevel.ctes.append(data_cte)

    delqry: Optional[pgast.DeleteStmt]

    # SELECT <data_cte>.* FROM <data_cte>; used as the DELETE's USING
    # source for SUBTRACT and as the INSERT's source otherwise.
    data_select = pgast.SelectStmt(
        target_list=[
            pgast.ResTarget(val=pgast.ColumnRef(
                name=[data_cte.name, pgast.Star()]), ),
        ],
        from_clause=[
            pgast.RelRangeVar(relation=data_cte),
        ],
    )

    # A DELETE is needed for every non-insert operation except APPEND.
    if not is_insert and shape_op is not qlast.ShapeOp.APPEND:
        if shape_op is qlast.ShapeOp.SUBTRACT:
            data_rvar = relctx.rvar_for_rel(data_select, ctx=ctx)

            # Drop requested link records.
            # The condition matches both the source of the DML object
            # and the targets present in the computed data rows.
            delqry = pgast.DeleteStmt(
                relation=target_rvar,
                where_clause=astutils.new_binop(
                    lexpr=astutils.new_binop(
                        lexpr=col_data['source'],
                        op='=',
                        rexpr=pgast.ColumnRef(name=[target_alias, 'source'], ),
                    ),
                    op='AND',
                    rexpr=astutils.new_binop(
                        lexpr=pgast.ColumnRef(name=[target_alias, 'target'], ),
                        op='=',
                        rexpr=pgast.ColumnRef(
                            name=[data_rvar.alias.aliasname, 'target'], ),
                    ),
                ),
                using_clause=[
                    dml_cte_rvar,
                    data_rvar,
                ],
                returning_list=[
                    pgast.ResTarget(val=pgast.ColumnRef(
                        name=[target_alias, pgast.Star()], ), )
                ])
        else:
            # Drop all previous link records for this source.
            delqry = pgast.DeleteStmt(
                relation=target_rvar,
                where_clause=astutils.new_binop(
                    lexpr=col_data['source'],
                    op='=',
                    rexpr=pgast.ColumnRef(name=[target_alias, 'source'], ),
                ),
                using_clause=[dml_cte_rvar],
                returning_list=[
                    pgast.ResTarget(val=pgast.ColumnRef(
                        name=[target_alias, pgast.Star()], ), )
                ])

        # The CTE holds the same ``delqry`` object, so the later
        # modification of ``delqry.where_clause`` below is made on the
        # query already registered here.
        delcte = pgast.CommonTableExpr(
            name=ctx.env.aliases.get(hint='d'),
            query=delqry,
        )

        pathctx.put_path_value_rvar(delcte.query,
                                    path_id.ptr_path(),
                                    target_rvar,
                                    env=ctx.env)

        # Record the effect of this removal in the relation overlay
        # context to ensure that references to the link in the result
        # of this DML statement yield the expected results.
        dml_stack = get_dml_stmt_stack(ir_stmt, ctx=ctx)
        relctx.add_ptr_rel_overlay(ptrref,
                                   'except',
                                   delcte,
                                   dml_stmts=dml_stack,
                                   ctx=ctx)
        toplevel.ctes.append(delcte)
    else:
        delqry = None

    # SUBTRACT only removes rows; no INSERT is generated for it.
    if shape_op is qlast.ShapeOp.SUBTRACT:
        return data_cte

    cols = [pgast.ColumnRef(name=[col]) for col in specified_cols]
    conflict_cols = ['source', 'target', 'ptr_item_id']

    if is_insert:
        conflict_clause = None
    elif len(cols) == len(conflict_cols) and delqry is not None:
        # There are no link properties, so we can optimize the
        # link replacement operation by omitting the overlapping
        # link rows from deletion.
        filter_select = pgast.SelectStmt(
            target_list=[
                pgast.ResTarget(val=pgast.ColumnRef(name=['source']), ),
                pgast.ResTarget(val=pgast.ColumnRef(name=['target']), ),
            ],
            from_clause=[pgast.RelRangeVar(relation=data_cte)],
        )

        # Extend the DELETE condition with
        # (source, target) != ALL (<filter_select>).
        delqry.where_clause = astutils.extend_binop(
            delqry.where_clause,
            astutils.new_binop(
                lexpr=pgast.ImplicitRowExpr(args=[
                    pgast.ColumnRef(name=['source']),
                    pgast.ColumnRef(name=['target']),
                ], ),
                rexpr=pgast.SubLink(
                    type=pgast.SubLinkType.ALL,
                    expr=filter_select,
                ),
                op='!=',
            ))

        # Rows that were kept by the DELETE above are skipped on insert.
        conflict_clause = pgast.OnConflictClause(
            action='nothing',
            infer=pgast.InferClause(index_elems=[
                pgast.ColumnRef(name=[col]) for col in conflict_cols
            ]),
        )
    else:
        # Inserting rows into the link table may produce cardinality
        # constraint violations, since the INSERT into the link table
        # is executed in the snapshot where the above DELETE from
        # the link table is not visible.  Hence, we need to use
        # the ON CONFLICT clause to resolve this.
        conflict_inference = []
        conflict_exc_row = []

        # Build the conflict key columns and the matching
        # ``excluded.<col>`` references.
        for col in conflict_cols:
            conflict_inference.append(pgast.ColumnRef(name=[col]))
            conflict_exc_row.append(pgast.ColumnRef(name=['excluded', col]))

        # Select the data row whose conflict key equals that of the
        # excluded row; it supplies the values for the conflict update.
        conflict_data = pgast.SelectStmt(
            target_list=[
                pgast.ResTarget(val=pgast.ColumnRef(
                    name=[data_cte.name, pgast.Star()]))
            ],
            from_clause=[pgast.RelRangeVar(relation=data_cte)],
            where_clause=astutils.new_binop(
                lexpr=pgast.ImplicitRowExpr(args=conflict_inference),
                rexpr=pgast.ImplicitRowExpr(args=conflict_exc_row),
                op='='))

        conflict_clause = pgast.OnConflictClause(
            action='update',
            infer=pgast.InferClause(index_elems=conflict_inference),
            target_list=[
                pgast.MultiAssignRef(columns=cols, source=conflict_data)
            ])

    # INSERT INTO <link table> (<cols>) SELECT ... [ON CONFLICT ...]
    # RETURNING *
    updcte = pgast.CommonTableExpr(
        name=ctx.env.aliases.get(hint='i'),
        query=pgast.InsertStmt(
            relation=target_rvar,
            select_stmt=data_select,
            cols=cols,
            on_conflict=conflict_clause,
            returning_list=[
                pgast.ResTarget(val=pgast.ColumnRef(name=[pgast.Star()]))
            ]))

    pathctx.put_path_value_rvar(updcte.query,
                                path_id.ptr_path(),
                                target_rvar,
                                env=ctx.env)

    # Record the effect of this insertion in the relation overlay
    # context to ensure that references to the link in the result
    # of this DML statement yield the expected results.
    dml_stack = get_dml_stmt_stack(ir_stmt, ctx=ctx)
    relctx.add_ptr_rel_overlay(ptrref,
                               'union',
                               updcte,
                               dml_stmts=dml_stack,
                               ctx=ctx)
    toplevel.ctes.append(updcte)

    return data_cte
Ejemplo n.º 4
0
def process_link_values(
    *,
    ir_stmt: irast.MutatingStmt,
    ir_expr: irast.Set,
    col_data: Mapping[str, pgast.BaseExpr],
    dml_rvar: pgast.PathRangeVar,
    sources: Iterable[pgast.BaseRangeVar],
    source_typeref: irast.TypeRef,
    target_is_scalar: bool,
    dml_cte: pgast.CommonTableExpr,
    iterator: Optional[pgast.IteratorCTE],
    ctx: context.CompilerContextLevel,
) -> Tuple[pgast.CommonTableExpr, List[str]]:
    """Unpack data from an update expression into a series of selects.

    :param ir_stmt:
        IR of the enclosing DML statement; its subject path is used to
        join the computed rows with *dml_rvar*.
    :param ir_expr:
        IR of the INSERT/UPDATE body element.
    :param col_data:
        Expressions used to populate well-known columns of the link
        table such as `source` and `__type__`.
    :param dml_rvar:
        Range var included into the row query under the subject path
        of *ir_stmt*.
    :param sources:
        A list of relations which must be joined into the data query
        to resolve expressions in *col_data*.
    :param source_typeref:
        TypeRef against which link property pointers are resolved.
    :param target_is_scalar:
        Whether the link target is a ScalarType.
    :param dml_cte:
        CTE of the enclosing DML statement; it is installed as the
        enclosing iterator CTE while *ir_expr* is compiled.
    :param iterator:
        IR and CTE representing the iterator range in the FOR clause of the
        EdgeQL DML statement.
    :param ctx:
        Compiler context level used for compilation.

    :return:
        A tuple of the CTE producing the link rows and the list of
        column names emitted by that CTE's query.
    """
    # Remember how many DML statements were known before compiling the
    # expression, to detect nested DML afterwards.
    old_dml_count = len(ctx.dml_stmts)
    with ctx.newscope() as newscope, newscope.newrel() as subrelctx:
        # The DML CTE acts as a pseudo-iterator for anything compiled
        # inside the link expression.
        subrelctx.enclosing_cte_iterator = pgast.IteratorCTE(
            path_id=ir_stmt.subject.path_id,
            cte=dml_cte,
            parent=iterator,
            is_dml_pseudo_iterator=True)
        row_query = subrelctx.rel

        relctx.include_rvar(row_query,
                            dml_rvar,
                            path_id=ir_stmt.subject.path_id,
                            ctx=subrelctx)
        subrelctx.path_scope[ir_stmt.subject.path_id] = row_query

        merge_iterator(iterator, row_query, ctx=subrelctx)

        # Compile the link expression itself into a separate sub-relation.
        with subrelctx.newscope() as sctx, sctx.subrel() as input_rel_ctx:
            input_rel = input_rel_ctx.rel
            input_rel_ctx.expr_exposed = False
            input_rel_ctx.volatility_ref = pathctx.get_path_identity_var(
                row_query, ir_stmt.subject.path_id, env=input_rel_ctx.env)
            dispatch.visit(ir_expr, ctx=input_rel_ctx)

            if (isinstance(ir_expr.expr, irast.Stmt)
                    and ir_expr.expr.iterator_stmt is not None):
                # The link value is computed by a FOR expression,
                # check if the statement is a DML statement, and if so,
                # pull the iterator scope so that link property expressions
                # have the correct context.
                inner_iterator_cte = None
                inner_iterator_path_id = ir_expr.expr.iterator_stmt.path_id
                # Look for a top-level CTE whose query has the
                # iterator's path id; the first match is used.
                for cte in input_rel_ctx.toplevel_stmt.ctes:
                    if cte.query.path_id == inner_iterator_path_id:
                        inner_iterator_cte = cte
                        break
                if inner_iterator_cte is not None:
                    inner_iterator_rvar = relctx.rvar_for_rel(
                        inner_iterator_cte, lateral=True, ctx=subrelctx)

                    relctx.include_rvar(
                        input_rel,
                        inner_iterator_rvar,
                        path_id=inner_iterator_path_id,
                        ctx=subrelctx,
                    )

                    input_rel_ctx.path_scope[inner_iterator_path_id] = (
                        input_rel)

            shape_tuple = None
            if ir_expr.shape:
                # Compile the shape elements of the link expression and
                # register each element's value on the input relation
                # unless one is already present.
                shape_tuple = shapecomp.compile_shape(
                    ir_expr,
                    [expr for expr, _ in ir_expr.shape],
                    ctx=input_rel_ctx,
                )

                for element in shape_tuple.elements:
                    pathctx.put_path_var_if_not_exists(input_rel_ctx.rel,
                                                       element.path_id,
                                                       element.val,
                                                       aspect='value',
                                                       env=input_rel_ctx.env)

    input_stmt: pgast.Query = input_rel

    # Wrap the compiled expression in a LATERAL subselect.
    input_rvar = pgast.RangeSubselect(
        subquery=input_rel,
        lateral=True,
        alias=pgast.Alias(aliasname=ctx.env.aliases.get('val')))

    if len(ctx.dml_stmts) > old_dml_count:
        # If there were any nested inserts, we need to join them in.
        pathctx.put_rvar_path_bond(input_rvar, ir_stmt.subject.path_id)
    relctx.include_rvar(row_query,
                        input_rvar,
                        path_id=ir_stmt.subject.path_id,
                        ctx=ctx)

    # Column values derived from the compiled expression.
    source_data: Dict[str, pgast.BaseExpr] = {}

    if isinstance(input_stmt, pgast.SelectStmt) and input_stmt.op is not None:
        # UNION
        # NOTE(review): ``input_stmt`` is not read again after this
        # reassignment within this function.
        input_stmt = input_stmt.rarg

    path_id = ir_expr.path_id

    if shape_tuple is not None:
        # Only link property elements of the shape contribute columns.
        for element in shape_tuple.elements:
            if not element.path_id.is_linkprop_path():
                continue
            val = pathctx.get_rvar_path_value_var(input_rvar,
                                                  element.path_id,
                                                  env=ctx.env)
            rptr = element.path_id.rptr()
            assert isinstance(rptr, irast.PointerRef)
            actual_rptr = irtyputils.find_actual_ptrref(source_typeref, rptr)
            ptr_info = pg_types.get_ptrref_storage_info(actual_rptr)
            # The first value seen for a given column name is kept.
            source_data.setdefault(ptr_info.column_name, val)
    else:
        # No shape: the only computed column is the link target, taken
        # as a value for scalar targets and as an identity otherwise.
        if target_is_scalar:
            target_ref = pathctx.get_rvar_path_value_var(input_rvar,
                                                         path_id,
                                                         env=ctx.env)
        else:
            target_ref = pathctx.get_rvar_path_identity_var(input_rvar,
                                                            path_id,
                                                            env=ctx.env)

        source_data['target'] = target_ref

    # For non-scalar targets make sure a 'target' column is present
    # even when the shape did not supply one.
    if not target_is_scalar and 'target' not in source_data:
        target_ref = pathctx.get_rvar_path_identity_var(input_rvar,
                                                        path_id,
                                                        env=ctx.env)
        source_data['target'] = target_ref

    # Emit one output column per entry; on duplicate keys the ChainMap
    # yields the value from col_data rather than source_data.
    specified_cols = []
    for col, expr in collections.ChainMap(col_data, source_data).items():
        row_query.target_list.append(pgast.ResTarget(val=expr, name=col))
        specified_cols.append(col)

    row_query.from_clause += list(sources)

    link_rows = pgast.CommonTableExpr(query=row_query,
                                      name=ctx.env.aliases.get(hint='r'))

    return link_rows, specified_cols