def get_rvar_path_var(rvar: pgast.PathRangeVar, path_id: irast.PathId,
                      aspect: str, *,
                      env: context.Environment) -> pgast.OutputVar:
    """Return ColumnRef for a given *path_id* in a given *range var*.

    :param rvar:
        Range var whose query is inspected for the path output.
    :param path_id:
        Path to look up.
    :param aspect:
        Path aspect; together with *path_id* it forms the key into
        ``rvar.query.path_outputs``.
    :param env:
        Compiler environment, forwarded to the output helpers.
    """
    # NOTE(review): no ``else`` branch or ``return`` is visible in this
    # excerpt although an OutputVar return is annotated -- the function
    # presumably continues beyond what is shown here; confirm.
    if (path_id, aspect) in rvar.query.path_outputs:
        # An output for this (path, aspect) pair is already registered
        # on the range var's query: reuse it.
        outvar = rvar.query.path_outputs[path_id, aspect]
    elif is_relation_rvar(rvar):
        # Storage info is only resolved when the path has a pointer,
        # the range var carries a typeref, and the path is neither the
        # query's own path nor the source of the query's
        # type-intersection path.
        if ((rptr := path_id.rptr()) is not None
                and rvar.typeref is not None
                and rvar.query.path_id != path_id
                and (not rvar.query.path_id.is_type_intersection_path()
                     or rvar.query.path_id.src_path() != path_id)):
            actual_rptr = irtyputils.find_actual_ptrref(
                rvar.typeref,
                rptr,
            )
            if actual_rptr is not None:
                ptr_si = pg_types.get_ptrref_storage_info(actual_rptr)
            else:
                ptr_si = None
        else:
            ptr_si = None

        outvar = _get_rel_path_output(rvar.query,
                                      path_id,
                                      ptr_info=ptr_si,
                                      aspect=aspect,
                                      env=env)
def process_update_body(
    ir_stmt: irast.MutatingStmt,
    wrapper: pgast.Query,
    update_cte: pgast.CommonTableExpr,
    typeref: irast.TypeRef,
    *,
    ctx: context.CompilerContextLevel,
) -> None:
    """Compile the shape of an UpdateStmt IR into SQL DML CTEs.

    Shape elements stored in the object's own table become targets of
    the SQL UPDATE held by *update_cte*; elements stored in link tables
    are collected and passed to ``process_link_update`` once the main
    CTE has been appended to the top-level statement.  If no element
    targets the object table, the UPDATE is replaced by an equivalent
    SELECT.

    :param ir_stmt:
        IR of the statement.
    :param wrapper:
        Top-level SQL query (forwarded to ``process_link_update``).
    :param update_cte:
        CTE representing the SQL UPDATE to the main relation
        of the Object.
    :param typeref:
        The specific TypeRef of a set being updated.
    :param ctx:
        Current compiler context level.
    """
    upd_query = update_cte.query
    assert isinstance(upd_query, pgast.UpdateStmt)

    if ctx.enclosing_cte_iterator:
        pathctx.put_path_bond(upd_query, ctx.enclosing_cte_iterator.path_id)

    # (shape element, shape operation, props_only) triples for the
    # pointers that are stored in a link table.
    link_updates = []

    with ctx.newscope() as bodyctx:
        # The shape body must be compiled with the UPDATE statement as
        # the parent relation, otherwise references to the current
        # values of the updated object would not resolve correctly.
        bodyctx.parent_rel = upd_query
        bodyctx.expr_exposed = False

        for shape_el, shape_op in ir_stmt.subject.shape:
            ptrref = shape_el.rptr.ptrref
            assert isinstance(ptrref, irast.PointerRef)
            actual_ptrref = irtyputils.find_actual_ptrref(typeref, ptrref)
            new_value_ir = shape_el.expr

            column_info = pg_types.get_ptrref_storage_info(
                actual_ptrref, resolve_type=True, link_bias=False)

            stored_inline = column_info.table_type == 'ObjectType'
            if stored_inline and new_value_ir is not None:
                with bodyctx.newscope() as elctx:
                    new_value: pgast.BaseExpr

                    if irtyputils.is_tuple(shape_el.typeref):
                        # Tuple targets are compiled into a subquery
                        # that returns a single column explicitly cast
                        # into the appropriate composite type.
                        new_value = relgen.set_as_subquery(
                            shape_el,
                            as_value=True,
                            explicit_cast=column_info.column_type,
                            ctx=elctx,
                        )
                    else:
                        uncast: pgast.BaseExpr
                        is_compiled_dml = (
                            isinstance(new_value_ir, irast.MutatingStmt)
                            and new_value_ir in ctx.dml_stmts
                        )
                        if is_compiled_dml:
                            # The value is a DML statement that already
                            # has a CTE: wrap that CTE instead of
                            # compiling the statement again.
                            with elctx.substmt() as dmlctx:
                                nested_cte = ctx.dml_stmts[new_value_ir]
                                wrap_dml_cte(
                                    new_value_ir, nested_cte, ctx=dmlctx)
                                pathctx.get_path_identity_output(
                                    dmlctx.rel,
                                    new_value_ir.subject.path_id,
                                    env=dmlctx.env,
                                )
                                uncast = dmlctx.rel
                        else:
                            uncast = dispatch.compile(
                                new_value_ir, ctx=elctx)

                        column_type = pgast.TypeName(
                            name=column_info.column_type)
                        new_value = pgast.TypeCast(
                            arg=uncast, type_name=column_type)

                    if shape_op is qlast.ShapeOp.SUBTRACT:
                        # "-=" on a column: nullif(<column>, <value>).
                        new_value = pgast.FuncCall(
                            name=('nullif', ),
                            args=[
                                pgast.ColumnRef(
                                    name=(column_info.column_name, )),
                                new_value,
                            ],
                        )

                    upd_query.targets.append(
                        pgast.UpdateTarget(
                            name=column_info.column_name,
                            val=new_value,
                        ))

            props_only = is_props_only_update(shape_el, ctx=bodyctx)

            link_info = pg_types.get_ptrref_storage_info(
                actual_ptrref, resolve_type=False, link_bias=True)

            if link_info and link_info.table_type == 'link':
                link_updates.append((shape_el, shape_op, props_only))

    if not upd_query.targets:
        # Nothing is written to the object table itself, so the UPDATE
        # statement is converted into a SELECT.
        select_from: List[pgast.BaseRangeVar] = [upd_query.relation]
        select_from.extend(upd_query.from_clause)
        update_cte.query = pgast.SelectStmt(
            ctes=upd_query.ctes,
            target_list=upd_query.returning_list,
            from_clause=select_from,
            where_clause=upd_query.where_clause,
            path_namespace=upd_query.path_namespace,
            path_outputs=upd_query.path_outputs,
            path_scope=upd_query.path_scope,
            path_rvar_map=upd_query.path_rvar_map.copy(),
            view_path_id_map=upd_query.view_path_id_map.copy(),
            ptr_join_map=upd_query.ptr_join_map.copy(),
        )

    ctx.toplevel_stmt.ctes.append(update_cte)

    # Apply the updates that target link tables.
    for link_el, link_op, _ in link_updates:
        process_link_update(
            ir_stmt=ir_stmt,
            ir_set=link_el,
            props_only=False,
            wrapper=wrapper,
            dml_cte=update_cte,
            iterator=None,
            is_insert=False,
            shape_op=link_op,
            source_typeref=typeref,
            ctx=ctx,
        )
def process_link_update(
    *,
    ir_stmt: irast.MutatingStmt,
    ir_set: irast.Set,
    props_only: bool,
    is_insert: bool,
    shape_op: qlast.ShapeOp = qlast.ShapeOp.ASSIGN,
    source_typeref: irast.TypeRef,
    wrapper: pgast.Query,
    dml_cte: pgast.CommonTableExpr,
    iterator: Optional[pgast.IteratorCTE],
    ctx: context.CompilerContextLevel,
) -> pgast.CommonTableExpr:
    """Perform updates to a link relation as part of a DML statement.

    Up to three CTEs are appended to the top-level statement, in this
    order: the link data CTE built by ``process_link_values``, a DELETE
    CTE (only when not inserting and *shape_op* is not ``APPEND``) and
    an INSERT CTE (skipped when *shape_op* is ``SUBTRACT``).  The DELETE
    and INSERT CTEs are also registered as ``'except'`` and ``'union'``
    pointer overlays respectively.

    :param ir_stmt:
        IR of the statement.
    :param ir_set:
        IR of the INSERT/UPDATE body element.
    :param props_only:
        Whether this link update only touches link properties.
        Not referenced in the body of this function.
    :param is_insert:
        Whether the enclosing statement is an INSERT.  When true, no
        existing link rows are deleted and the INSERT into the link
        table gets no ON CONFLICT clause.
    :param shape_op:
        Shape operation of the body element.  ``APPEND`` skips the
        deletion of existing rows; ``SUBTRACT`` deletes only the rows
        matching the computed targets and returns without inserting.
    :param source_typeref:
        TypeRef used to find the actual pointer for the link.
    :param wrapper:
        Top-level SQL query.  Not referenced in the body of this
        function.
    :param dml_cte:
        CTE representing the SQL INSERT or UPDATE to the main
        relation of the Object.
    :param iterator:
        IR and CTE representing the iterator range in the FOR clause
        of the EdgeQL DML statement.
    :param ctx:
        Current compiler context level.

    :return:
        The CTE holding the computed link rows, as returned by
        ``process_link_values``.
    """
    toplevel = ctx.toplevel_stmt

    rptr = ir_set.rptr
    ptrref = rptr.ptrref
    assert isinstance(ptrref, irast.PointerRef)
    target_is_scalar = irtyputils.is_scalar(ir_set.typeref)
    path_id = ir_set.path_id

    # The links in the dml class shape have been derived,
    # but we must use the correct specialized link class for the
    # base material type.
    mptrref = irtyputils.find_actual_ptrref(source_typeref, ptrref)
    assert isinstance(mptrref, irast.PointerRef)

    # Range var over the link relation that is deleted from and
    # inserted into below.
    target_rvar = relctx.range_for_ptrref(mptrref,
                                          for_mutation=True,
                                          only_self=True,
                                          ctx=ctx)
    assert isinstance(target_rvar, pgast.RelRangeVar)
    assert isinstance(target_rvar.relation, pgast.Relation)
    target_alias = target_rvar.alias.aliasname

    # Range var over the CTE of the main INSERT/UPDATE; it supplies the
    # identity of the source object.
    dml_cte_rvar = pgast.RelRangeVar(
        relation=dml_cte,
        alias=pgast.Alias(aliasname=ctx.env.aliases.get('m')))

    # Values for the well-known link table columns.
    col_data = {
        'ptr_item_id':
        pgast.TypeCast(arg=pgast.StringConstant(val=str(mptrref.id)),
                       type_name=pgast.TypeName(name=('uuid', ))),
        'source':
        pathctx.get_rvar_path_identity_var(dml_cte_rvar,
                                           ir_stmt.subject.path_id,
                                           env=ctx.env)
    }

    # Turn the IR of the expression on the right side of :=
    # into a subquery returning records for the link table.
    data_cte, specified_cols = process_link_values(
        ir_stmt=ir_stmt,
        ir_expr=ir_set,
        col_data=col_data,
        dml_rvar=dml_cte_rvar,
        sources=[],
        source_typeref=source_typeref,
        target_is_scalar=target_is_scalar,
        dml_cte=dml_cte,
        iterator=iterator,
        ctx=ctx,
    )

    toplevel.ctes.append(data_cte)

    delqry: Optional[pgast.DeleteStmt]

    # SELECT <data_cte>.* FROM <data_cte>; used as the source of the
    # INSERT and, for SUBTRACT, as the set of rows to delete.
    data_select = pgast.SelectStmt(
        target_list=[
            pgast.ResTarget(val=pgast.ColumnRef(
                name=[data_cte.name, pgast.Star()]), ),
        ],
        from_clause=[
            pgast.RelRangeVar(relation=data_cte),
        ],
    )

    if not is_insert and shape_op is not qlast.ShapeOp.APPEND:
        if shape_op is qlast.ShapeOp.SUBTRACT:
            data_rvar = relctx.rvar_for_rel(data_select, ctx=ctx)

            # Drop requested link records.
            delqry = pgast.DeleteStmt(
                relation=target_rvar,
                where_clause=astutils.new_binop(
                    lexpr=astutils.new_binop(
                        lexpr=col_data['source'],
                        op='=',
                        rexpr=pgast.ColumnRef(
                            name=[target_alias, 'source'], ),
                    ),
                    op='AND',
                    rexpr=astutils.new_binop(
                        lexpr=pgast.ColumnRef(
                            name=[target_alias, 'target'], ),
                        op='=',
                        rexpr=pgast.ColumnRef(
                            name=[data_rvar.alias.aliasname, 'target'], ),
                    ),
                ),
                using_clause=[
                    dml_cte_rvar,
                    data_rvar,
                ],
                returning_list=[
                    pgast.ResTarget(val=pgast.ColumnRef(
                        name=[target_alias, pgast.Star()], ), )
                ])
        else:
            # Drop all previous link records for this source.
            delqry = pgast.DeleteStmt(
                relation=target_rvar,
                where_clause=astutils.new_binop(
                    lexpr=col_data['source'],
                    op='=',
                    rexpr=pgast.ColumnRef(name=[target_alias, 'source'], ),
                ),
                using_clause=[dml_cte_rvar],
                returning_list=[
                    pgast.ResTarget(val=pgast.ColumnRef(
                        name=[target_alias, pgast.Star()], ), )
                ])

        delcte = pgast.CommonTableExpr(
            name=ctx.env.aliases.get(hint='d'),
            query=delqry,
        )

        pathctx.put_path_value_rvar(delcte.query,
                                    path_id.ptr_path(),
                                    target_rvar,
                                    env=ctx.env)

        # Record the effect of this removal in the relation overlay
        # context to ensure that references to the link in the result
        # of this DML statement yield the expected results.
        dml_stack = get_dml_stmt_stack(ir_stmt, ctx=ctx)
        relctx.add_ptr_rel_overlay(ptrref,
                                   'except',
                                   delcte,
                                   dml_stmts=dml_stack,
                                   ctx=ctx)
        toplevel.ctes.append(delcte)
    else:
        delqry = None

    if shape_op is qlast.ShapeOp.SUBTRACT:
        # Subtraction only removes rows; there is nothing to insert.
        return data_cte

    cols = [pgast.ColumnRef(name=[col]) for col in specified_cols]
    conflict_cols = ['source', 'target', 'ptr_item_id']

    if is_insert:
        conflict_clause = None
    elif len(cols) == len(conflict_cols) and delqry is not None:
        # There are no link properties, so we can optimize the
        # link replacement operation by omitting the overlapping
        # link rows from deletion.
        filter_select = pgast.SelectStmt(
            target_list=[
                pgast.ResTarget(val=pgast.ColumnRef(name=['source']), ),
                pgast.ResTarget(val=pgast.ColumnRef(name=['target']), ),
            ],
            from_clause=[pgast.RelRangeVar(relation=data_cte)],
        )

        # The DELETE built above is narrowed in place: delcte.query
        # refers to this same delqry object.
        delqry.where_clause = astutils.extend_binop(
            delqry.where_clause,
            astutils.new_binop(
                lexpr=pgast.ImplicitRowExpr(args=[
                    pgast.ColumnRef(name=['source']),
                    pgast.ColumnRef(name=['target']),
                ], ),
                rexpr=pgast.SubLink(
                    type=pgast.SubLinkType.ALL,
                    expr=filter_select,
                ),
                op='!=',
            ))

        # Rows that were kept by the narrowed DELETE are skipped on
        # insertion.
        conflict_clause = pgast.OnConflictClause(
            action='nothing',
            infer=pgast.InferClause(index_elems=[
                pgast.ColumnRef(name=[col]) for col in conflict_cols
            ]),
        )
    else:
        # Inserting rows into the link table may produce cardinality
        # constraint violations, since the INSERT into the link table
        # is executed in the snapshot where the above DELETE from
        # the link table is not visible.  Hence, we need to use
        # the ON CONFLICT clause to resolve this.
        conflict_inference = []
        conflict_exc_row = []

        for col in conflict_cols:
            conflict_inference.append(pgast.ColumnRef(name=[col]))
            conflict_exc_row.append(pgast.ColumnRef(name=['excluded', col]))

        # Select the data row whose conflict columns equal those of the
        # row proposed for insertion ("excluded").
        conflict_data = pgast.SelectStmt(
            target_list=[
                pgast.ResTarget(val=pgast.ColumnRef(
                    name=[data_cte.name, pgast.Star()]))
            ],
            from_clause=[pgast.RelRangeVar(relation=data_cte)],
            where_clause=astutils.new_binop(
                lexpr=pgast.ImplicitRowExpr(args=conflict_inference),
                rexpr=pgast.ImplicitRowExpr(args=conflict_exc_row),
                op='='))

        conflict_clause = pgast.OnConflictClause(
            action='update',
            infer=pgast.InferClause(index_elems=conflict_inference),
            target_list=[
                pgast.MultiAssignRef(columns=cols, source=conflict_data)
            ])

    updcte = pgast.CommonTableExpr(
        name=ctx.env.aliases.get(hint='i'),
        query=pgast.InsertStmt(
            relation=target_rvar,
            select_stmt=data_select,
            cols=cols,
            on_conflict=conflict_clause,
            returning_list=[
                pgast.ResTarget(val=pgast.ColumnRef(name=[pgast.Star()]))
            ]))

    pathctx.put_path_value_rvar(updcte.query,
                                path_id.ptr_path(),
                                target_rvar,
                                env=ctx.env)

    # Record the effect of this insertion in the relation overlay
    # context to ensure that references to the link in the result
    # of this DML statement yield the expected results.
    dml_stack = get_dml_stmt_stack(ir_stmt, ctx=ctx)
    relctx.add_ptr_rel_overlay(ptrref,
                               'union',
                               updcte,
                               dml_stmts=dml_stack,
                               ctx=ctx)
    toplevel.ctes.append(updcte)

    return data_cte
def process_link_values(
    *,
    ir_stmt: irast.MutatingStmt,
    ir_expr: irast.Set,
    col_data: Mapping[str, pgast.BaseExpr],
    dml_rvar: pgast.PathRangeVar,
    sources: Iterable[pgast.BaseRangeVar],
    source_typeref: irast.TypeRef,
    target_is_scalar: bool,
    dml_cte: pgast.CommonTableExpr,
    iterator: Optional[pgast.IteratorCTE],
    ctx: context.CompilerContextLevel,
) -> Tuple[pgast.CommonTableExpr, List[str]]:
    """Unpack data from an update expression into a series of selects.

    :param ir_stmt:
        IR of the DML statement; its subject path is used to join
        *dml_rvar* and the computed values into the row query.
    :param ir_expr:
        IR of the INSERT/UPDATE body element.
    :param col_data:
        Expressions used to populate well-known columns of the link
        table such as `source` and `ptr_item_id`.
    :param dml_rvar:
        Range var over the CTE of the main DML statement; included
        in the row query.
    :param sources:
        A list of relations which must be joined into the data query
        to resolve expressions in *col_data*.
    :param source_typeref:
        TypeRef used to find the actual pointers of link properties.
    :param target_is_scalar:
        Whether the link target is a ScalarType.
    :param dml_cte:
        CTE of the main DML statement; used as the CTE of the
        pseudo-iterator under which *ir_expr* is compiled.
    :param iterator:
        IR and CTE representing the iterator range in the FOR clause
        of the EdgeQL DML statement.
    :param ctx:
        Current compiler context level.

    :return:
        A tuple of the CTE producing the link rows and the list of
        column names that CTE outputs.
    """
    # Number of DML statements known before compiling the value; used
    # below to detect DML nested in *ir_expr*.
    old_dml_count = len(ctx.dml_stmts)
    with ctx.newscope() as newscope, newscope.newrel() as subrelctx:
        # Compile the value with the DML CTE acting as a pseudo-iterator
        # over the subject path.
        subrelctx.enclosing_cte_iterator = pgast.IteratorCTE(
            path_id=ir_stmt.subject.path_id,
            cte=dml_cte,
            parent=iterator,
            is_dml_pseudo_iterator=True)
        row_query = subrelctx.rel

        relctx.include_rvar(row_query,
                            dml_rvar,
                            path_id=ir_stmt.subject.path_id,
                            ctx=subrelctx)
        subrelctx.path_scope[ir_stmt.subject.path_id] = row_query

        merge_iterator(iterator, row_query, ctx=subrelctx)

        with subrelctx.newscope() as sctx, sctx.subrel() as input_rel_ctx:
            input_rel = input_rel_ctx.rel
            input_rel_ctx.expr_exposed = False
            input_rel_ctx.volatility_ref = pathctx.get_path_identity_var(
                row_query, ir_stmt.subject.path_id, env=input_rel_ctx.env)
            dispatch.visit(ir_expr, ctx=input_rel_ctx)

            if (isinstance(ir_expr.expr, irast.Stmt)
                    and ir_expr.expr.iterator_stmt is not None):
                # The link value is computed by a FOR expression,
                # check if the statement is a DML statement, and if so,
                # pull the iterator scope so that link property expressions
                # have the correct context.
                inner_iterator_cte = None
                inner_iterator_path_id = ir_expr.expr.iterator_stmt.path_id
                # Look for a top-level CTE whose query has the path id
                # of the inner iterator.
                for cte in input_rel_ctx.toplevel_stmt.ctes:
                    if cte.query.path_id == inner_iterator_path_id:
                        inner_iterator_cte = cte
                        break
                if inner_iterator_cte is not None:
                    inner_iterator_rvar = relctx.rvar_for_rel(
                        inner_iterator_cte, lateral=True, ctx=subrelctx)

                    relctx.include_rvar(
                        input_rel,
                        inner_iterator_rvar,
                        path_id=inner_iterator_path_id,
                        ctx=subrelctx,
                    )

                    input_rel_ctx.path_scope[inner_iterator_path_id] = (
                        input_rel)

            shape_tuple = None
            if ir_expr.shape:
                shape_tuple = shapecomp.compile_shape(
                    ir_expr,
                    [expr for expr, _ in ir_expr.shape],
                    ctx=input_rel_ctx,
                )

                # Register the compiled shape elements as value vars of
                # the input relation unless already present.
                for element in shape_tuple.elements:
                    pathctx.put_path_var_if_not_exists(input_rel_ctx.rel,
                                                       element.path_id,
                                                       element.val,
                                                       aspect='value',
                                                       env=input_rel_ctx.env)

    input_stmt: pgast.Query = input_rel

    # Wrap the compiled value relation into a LATERAL subselect.
    input_rvar = pgast.RangeSubselect(
        subquery=input_rel,
        lateral=True,
        alias=pgast.Alias(aliasname=ctx.env.aliases.get('val')))

    if len(ctx.dml_stmts) > old_dml_count:
        # If there were any nested inserts, we need to join them in.
        pathctx.put_rvar_path_bond(input_rvar, ir_stmt.subject.path_id)
    relctx.include_rvar(row_query,
                        input_rvar,
                        path_id=ir_stmt.subject.path_id,
                        ctx=ctx)

    source_data: Dict[str, pgast.BaseExpr] = {}

    if isinstance(input_stmt, pgast.SelectStmt) and input_stmt.op is not None:
        # UNION
        # NOTE(review): input_stmt is not read again after this
        # reassignment in the visible code.
        input_stmt = input_stmt.rarg

    path_id = ir_expr.path_id

    if shape_tuple is not None:
        # Only link property elements of the shape contribute columns.
        for element in shape_tuple.elements:
            if not element.path_id.is_linkprop_path():
                continue
            val = pathctx.get_rvar_path_value_var(input_rvar,
                                                  element.path_id,
                                                  env=ctx.env)
            rptr = element.path_id.rptr()
            assert isinstance(rptr, irast.PointerRef)
            actual_rptr = irtyputils.find_actual_ptrref(source_typeref, rptr)
            ptr_info = pg_types.get_ptrref_storage_info(actual_rptr)
            # setdefault: the first value seen for a column wins.
            source_data.setdefault(ptr_info.column_name, val)
    else:
        if target_is_scalar:
            target_ref = pathctx.get_rvar_path_value_var(input_rvar,
                                                         path_id,
                                                         env=ctx.env)
        else:
            target_ref = pathctx.get_rvar_path_identity_var(input_rvar,
                                                            path_id,
                                                            env=ctx.env)

        source_data['target'] = target_ref

    if not target_is_scalar and 'target' not in source_data:
        # The shape did not supply a `target` column: use the identity
        # of the link target.
        target_ref = pathctx.get_rvar_path_identity_var(input_rvar,
                                                        path_id,
                                                        env=ctx.env)
        source_data['target'] = target_ref

    specified_cols = []
    # In the ChainMap, entries of col_data take precedence over
    # entries of source_data with the same key.
    for col, expr in collections.ChainMap(col_data, source_data).items():
        row_query.target_list.append(pgast.ResTarget(val=expr, name=col))
        specified_cols.append(col)

    row_query.from_clause += list(sources)

    link_rows = pgast.CommonTableExpr(query=row_query,
                                      name=ctx.env.aliases.get(hint='r'))

    return link_rows, specified_cols