def fini_stmt(
    stmt: pgast.Query,
    ctx: context.CompilerContextLevel,
    parent_ctx: context.CompilerContextLevel,
) -> None:
    """Finalize *stmt*; work is only done for the top-level statement.

    The context's argument map is published on the statement as
    ``argnames``.  When positional (non-named) parameters are in use,
    every query parameter that has no entry in the argument map yet is
    registered there and referenced, with a cast to its Postgres type,
    from a ``__unused_vars`` CTE appended to the top-level statement.

    :param stmt:
        The SQL statement being finalized.
    :param ctx:
        Compiler context of the statement.
    :param parent_ctx:
        Parent compiler context (not referenced in this function body).
    """
    if stmt is not ctx.toplevel_stmt:
        return

    argmap = ctx.argmap
    stmt.argnames = argmap

    if ctx.env.use_named_params:
        return

    # Adding unused parameters into a CTE
    unused_targets = []
    for param in ctx.env.query_params:
        name = param.name
        if name in argmap:
            continue

        # A decimal name N gets index N + 1; any other name gets the
        # next index after the entries already in the map.
        index = int(name) + 1 if name.isdecimal() else len(argmap) + 1
        argmap[name] = pgast.Param(index=index, required=param.required)

        param_ref = pgast.ParamRef(number=index)
        param_type = pgast.TypeName(
            name=pg_types.pg_type_from_ir_typeref(param.ir_type))
        unused_targets.append(
            pgast.ResTarget(
                val=pgast.TypeCast(arg=param_ref, type_name=param_type)))

    if unused_targets:
        unused_cte = pgast.CommonTableExpr(
            name="__unused_vars",
            query=pgast.SelectStmt(target_list=unused_targets),
        )
        ctx.toplevel_stmt.ctes.append(unused_cte)
def fini_toplevel(
    stmt: pgast.Query,
    ctx: context.CompilerContextLevel,
) -> None:
    """Apply the final whole-query touches to the top-level statement.

    In order: add the check-CTE scan, prepend the type rewrite CTEs to
    the statement's CTE list, publish the argument map as ``argnames``
    and, when positional parameters are in use, reference every
    parameter not marked as used from a ``__unused_vars`` CTE.

    :param stmt:
        The top-level SQL statement.
    :param ctx:
        Compiler context of the statement.
    """
    scan_check_ctes(stmt, ctx.env.check_ctes, ctx=ctx)

    # Type rewrites go first.
    if stmt.ctes is None:
        stmt.ctes = []
    stmt.ctes[:0] = list(ctx.type_ctes.values())

    argmap = ctx.argmap
    stmt.argnames = argmap

    if ctx.env.use_named_params:
        return

    # Adding unused parameters into a CTE
    unused_targets = []
    for param in ctx.env.query_params:
        pgparam = argmap[param.name]
        if not pgparam.used:
            param_ref = pgast.ParamRef(number=pgparam.index)
            param_type = pgast.TypeName(
                name=pg_types.pg_type_from_ir_typeref(param.ir_type))
            unused_targets.append(
                pgast.ResTarget(
                    val=pgast.TypeCast(arg=param_ref, type_name=param_type)))

    if unused_targets:
        unused_cte = pgast.CommonTableExpr(
            name="__unused_vars",
            query=pgast.SelectStmt(target_list=unused_targets),
        )
        stmt.append_cte(unused_cte)
def range_for_material_objtype(
    typeref: irast.TypeRef,
    path_id: irast.PathId,
    *,
    for_mutation: bool = False,
    include_overlays: bool = True,
    include_descendants: bool = True,
    dml_source: Optional[irast.MutatingStmt] = None,
    ctx: context.CompilerContextLevel,
) -> pgast.PathRangeVar:
    # Build a range for the material type behind *typeref*.
    #
    # Only the type-rewrite handling is visible here: if the environment
    # has a rewrite registered for the type, the rewrite is compiled once
    # into a non-materialized CTE and cached in ``ctx.type_ctes``.
    #
    # NOTE(review): the visible body ends without a ``return`` even though
    # the signature promises a PathRangeVar, and ``env``, ``relation``,
    # ``path_id``, ``include_overlays``, ``include_descendants`` and
    # ``dml_source`` are not used in the visible part -- the remainder of
    # this function appears to be missing from this view; confirm against
    # the full source.

    env = ctx.env

    # Always work with the material type when one is recorded.
    if typeref.material_type is not None:
        typeref = typeref.material_type

    relation: Union[pgast.Relation, pgast.CommonTableExpr]

    # A rewrite applies only when one is registered, the type is not
    # currently in ``pending_type_ctes``, and this is not a mutation range.
    if ((rewrite := ctx.env.type_rewrites.get(typeref.id)) is not None
            and typeref.id not in ctx.pending_type_ctes
            and not for_mutation):

        if (type_cte := ctx.type_ctes.get(typeref.id)) is None:
            with ctx.newrel() as sctx:
                # Mark the type as pending while its rewrite is compiled;
                # the condition above skips types that are pending.
                sctx.pending_type_ctes.add(typeref.id)
                sctx.pending_query = sctx.rel
                dispatch.visit(rewrite, ctx=sctx)
                type_cte = pgast.CommonTableExpr(
                    name=ctx.env.aliases.get('t'),
                    query=sctx.rel,
                    materialized=False,
                )
                # Cache so the rewrite is compiled only once per type.
                ctx.type_ctes[typeref.id] = type_cte
def process_linkprop_update(
        ir_stmt: irast.MutatingStmt, ir_expr: irast.Set,
        wrapper: pgast.Query, dml_cte: pgast.CommonTableExpr, *,
        ctx: context.CompilerContextLevel) -> None:
    """Perform link property updates to a link relation.

    An SQL UPDATE on the link relation is generated with one target per
    element of the shape of *ir_expr*, restricted to the rows whose
    ``source`` equals the identity of the statement subject taken from
    *dml_cte*.  The UPDATE is appended to the top-level statement as a
    CTE.

    :param ir_stmt:
        IR of the statement.
    :param ir_expr:
        IR of the UPDATE body element.
    :param wrapper:
        Top-level SQL query (not referenced in this function body).
    :param dml_cte:
        CTE representing the SQL UPDATE to the main
        relation of the Object.
    :param ctx:
        Compiler context.
    """
    toplevel = ctx.toplevel_stmt

    ptrref = ir_expr.rptr.ptrref
    if ptrref.material_ptr:
        ptrref = ptrref.material_ptr

    link_table = relctx.range_for_ptrref(
        ptrref, include_overlays=False, ctx=ctx)

    source_rvar = pgast.RelRangeVar(
        relation=dml_cte,
        alias=pgast.Alias(aliasname=ctx.env.aliases.get('m')),
    )

    # Match link rows against the objects touched by the main DML CTE.
    source_identity = pathctx.get_rvar_path_identity_var(
        source_rvar, ir_stmt.subject.path_id, env=ctx.env)
    source_column = astutils.get_column(link_table, 'source', nullable=False)
    same_source = astutils.new_binop(source_identity, source_column, op='=')

    assignments = []
    for prop_el, shape_op in ir_expr.shape:
        assert shape_op is qlast.ShapeOp.ASSIGN
        prop_name = prop_el.rptr.ptrref.shortname
        with ctx.new() as prop_ctx:
            prop_ctx.expr_exposed = False
            prop_value = dispatch.compile(prop_el.expr, ctx=prop_ctx)
            assignments.append(
                pgast.UpdateTarget(name=prop_name.name, val=prop_value))

    update_stmt = pgast.UpdateStmt(
        relation=link_table,
        where_clause=same_source,
        targets=assignments,
        from_clause=[source_rvar],
    )

    update_cte = pgast.CommonTableExpr(
        query=update_stmt,
        name=ctx.env.aliases.get(ptrref.shortname.name),
    )

    toplevel.ctes.append(update_cte)
def scan_check_ctes(
    stmt: pgast.Query,
    check_ctes: List[pgast.CommonTableExpr],
    *,
    ctx: context.CompilerContextLevel,
) -> None:
    """Append a CTE to *stmt* that forces every check CTE to be scanned.

    Check CTEs enforce constraints that are checked as explicit queries
    and not as Postgres constraints or triggers.  To make sure that
    Postgres can't optimize the checks away, their row counts are summed
    into the WHERE clause of an UPDATE to a dummy table.

    Nothing is done when *check_ctes* is empty.  Each check CTE is marked
    as MATERIALIZED as a side effect.

    :param stmt:
        Statement that receives the scanning CTE.
    :param check_ctes:
        The check CTEs to reference.
    :param ctx:
        Compiler context.
    """
    if not check_ctes:
        return

    # Start from a big random number, so that different queries should
    # try to access different "rows" of the table, in case that matters.
    seed = random.randint(0, (1 << 60) - 1)
    row_id: pgast.BaseExpr = pgast.NumericConstant(val=str(seed))

    for cte in check_ctes:
        # We want the CTE to be MATERIALIZED, because otherwise
        # Postgres might not fully evaluate all its columns when
        # scanning it.
        cte.materialized = True

        count_target = pgast.ResTarget(
            val=pgast.FuncCall(name=('count',), args=[pgast.Star()]),
        )
        cte_rvar = relctx.rvar_for_rel(cte, ctx=ctx)
        row_count = pgast.SelectStmt(
            target_list=[count_target],
            from_clause=[cte_rvar],
        )
        row_id = pgast.Expr(
            kind=pgast.ExprKind.OP, name='+', lexpr=row_id, rexpr=row_count)

    flag_target = pgast.UpdateTarget(
        name='flag', val=pgast.BooleanConstant(val='true'))
    dummy_table = pgast.RelRangeVar(
        relation=pgast.Relation(schemaname='edgedb', name='_dml_dummy'))
    id_matches = pgast.Expr(
        kind=pgast.ExprKind.OP,
        name='=',
        lexpr=pgast.ColumnRef(name=['id']),
        rexpr=row_id,
    )
    scan_update = pgast.UpdateStmt(
        targets=[flag_target],
        relation=dummy_table,
        where_clause=id_matches,
    )

    stmt.append_cte(
        pgast.CommonTableExpr(
            query=scan_update,
            name=ctx.env.aliases.get(hint='check_scan'),
        )
    )
def cte_for_query(
        rel: pgast.Query, *,
        env: context.Environment) -> pgast.CommonTableExpr:
    """Wrap *rel* in a CTE aliased with a fresh name based on ``rel.name``.

    :param rel:
        The query to wrap.
    :param env:
        Compiler environment whose alias generator supplies the name.

    :return:
        A CommonTableExpr node with *rel* as its query.
    """
    alias_name = env.aliases.get(rel.name)
    cte_alias = pgast.Alias(aliasname=alias_name)
    return pgast.CommonTableExpr(query=rel, alias=cte_alias)
def compile_insert_else_body(
        insert_stmt: pgast.InsertStmt,
        ir_stmt: irast.InsertStmt,
        on_conflict: irast.OnConflictClause,
        else_cte_rvar: Optional[Tuple[
            pgast.CommonTableExpr, pgast.PathRangeVar]],
        *,
        ctx: context.CompilerContextLevel) -> None:
    """Compile the UNLESS CONFLICT handling of an INSERT statement.

    The SQL INSERT always receives an ``ON CONFLICT DO NOTHING`` clause,
    narrowed to a specific constraint when the IR names one.  If the IR
    has an ELSE part, two CTEs are appended to the top-level statement:
    one for the ELSE select and one for the ELSE branch.

    :param insert_stmt:
        The SQL INSERT node that receives the ON CONFLICT clause.
    :param ir_stmt:
        IR of the INSERT statement.
    :param on_conflict:
        IR of the conflict clause.
    :param else_cte_rvar:
        Pre-created (CTE, range var) pair for the ELSE branch; must be
        set when ``on_conflict.else_ir`` is present.  Only the CTE is
        used here: its ``query`` is filled in with the compiled branch.
    :param ctx:
        Compiler context.
    """
    infer = None
    if on_conflict.constraint:
        constraint_name = f'"{on_conflict.constraint.id};schemaconstr"'
        infer = pgast.InferClause(conname=constraint_name)

    insert_stmt.on_conflict = pgast.OnConflictClause(
        action='nothing',
        infer=infer,
    )

    if on_conflict.else_ir:
        else_select, else_branch = on_conflict.else_ir

        subject_id = ir_stmt.subject.path_id

        # First CTE: the ELSE select, compiled in a fresh relation with
        # the insert subject path bound to it.
        with ctx.newrel() as sctx, sctx.newscope() as ictx:
            ictx.path_scope[subject_id] = ictx.rel

            merge_iterator(ctx.enclosing_cte_iterator, ictx.rel, ctx=ictx)

            pathctx.put_path_bond(ictx.rel, subject_id)
            dispatch.compile(else_select, ctx=ictx)
            # Expose the select's result under the subject path.
            ictx.rel.view_path_id_map[subject_id] = else_select.path_id

            else_select_cte = pgast.CommonTableExpr(
                query=ictx.rel,
                name=ctx.env.aliases.get('else'))
            ictx.toplevel_stmt.ctes.append(else_select_cte)

        else_select_rvar = relctx.rvar_for_rel(else_select_cte, ctx=ctx)

        # Second CTE: the ELSE branch, compiled over the select CTE above,
        # which is also installed as the enclosing iterator.
        with ctx.newrel() as sctx, sctx.newscope() as ictx:
            ictx.path_scope[subject_id] = ictx.rel

            relctx.include_rvar(ictx.rel, else_select_rvar,
                                path_id=subject_id, ctx=ictx)

            ictx.enclosing_cte_iterator = pgast.IteratorCTE(
                path_id=else_select.path_id, cte=else_select_cte,
                parent=ictx.enclosing_cte_iterator)

            dispatch.compile(else_branch, ctx=ictx)
            ictx.rel.view_path_id_map[subject_id] = else_branch.path_id

            # The branch CTE was created by the caller; only its query
            # is set here.
            assert else_cte_rvar
            else_branch_cte = else_cte_rvar[0]
            else_branch_cte.query = ictx.rel
            ictx.toplevel_stmt.ctes.append(else_branch_cte)
def get_dml_range(
    ir_stmt: Union[irast.UpdateStmt, irast.DeleteStmt],
    *,
    ctx: context.CompilerContextLevel,
) -> pgast.CommonTableExpr:
    """Create a range CTE for the given DML statement.

    The range is a SELECT over the statement subject, filtered by the
    statement's WHERE clause when there is one.

    :param ir_stmt:
        IR of the statement.
    :param ctx:
        Compiler context.

    :return:
        A CommonTableExpr node representing the range affected
        by the DML statement.
    """
    target_ir_set = ir_stmt.subject
    ir_qual_expr = ir_stmt.where
    ir_qual_card = ir_stmt.where_card

    with ctx.newscope() as scopectx, scopectx.newrel() as subctx:
        subctx.expr_exposed = False
        range_stmt = subctx.rel

        # init_stmt() has associated all top-level paths with
        # the main query, which is at the very bottom.
        # Hoist that scope to the modification range statement
        # instead.
        for path_id, stmt in ctx.path_scope.items():
            if stmt is ctx.rel or path_id == ir_stmt.subject.path_id:
                scopectx.path_scope[path_id] = range_stmt

        merge_iterator(ctx.enclosing_cte_iterator, range_stmt, ctx=subctx)

        dispatch.visit(target_ir_set, ctx=subctx)

        # Called for its effect on range_stmt; the return value is unused.
        pathctx.get_path_identity_output(
            range_stmt, target_ir_set.path_id, env=subctx.env)

        if ir_qual_expr is not None:
            range_stmt.where_clause = astutils.extend_binop(
                range_stmt.where_clause,
                clauses.compile_filter_clause(
                    ir_qual_expr, ir_qual_card, ctx=subctx))

        range_cte = pgast.CommonTableExpr(
            query=range_stmt,
            name=ctx.env.aliases.get('range'))

        return range_cte
def gen_dml_union(
    ir_stmt: irast.MutatingStmt,
    parts: DMLParts,
    *,
    ctx: context.CompilerContextLevel
) -> Tuple[pgast.CommonTableExpr, pgast.PathRangeVar]:
    """Combine all DML CTEs of a statement into a single CTE/rvar pair.

    The entries are the values of ``parts.dml_ctes`` plus
    ``parts.else_cte`` when set.  A single entry is returned as is.
    Otherwise each entry's range var is included into its own SELECT,
    the SELECTs are chained into a left-nested UNION, and a new CTE
    (with a range var over it) is created for the union.

    The resulting CTE is recorded in ``ctx.dml_stmts`` under *ir_stmt*.

    :param ir_stmt:
        IR of the DML statement.
    :param parts:
        The DML CTEs produced for the statement.
    :param ctx:
        Compiler context.

    :return:
        A (*union_cte*, *union_rvar*) tuple.
    """
    entries = list(parts.dml_ctes.values())
    if parts.else_cte:
        entries.append(parts.else_cte)

    if len(entries) == 1:
        union_cte, union_rvar = entries[0]
    else:
        components = []
        for _cte, entry_rvar in entries:
            component = pgast.SelectStmt()
            relctx.include_rvar(
                component,
                entry_rvar,
                ir_stmt.subject.path_id,
                ctx=ctx,
            )
            components.append(component)

        # Each iteration completes the current node as a UNION and then
        # wraps it in a fresh node; the last wrapper is therefore empty
        # and the finished union is its ``larg``.
        chain = pgast.SelectStmt(all=True, larg=components[0])
        for component in components[1:]:
            chain.op = 'UNION'
            chain.rarg = component
            chain = pgast.SelectStmt(all=True, larg=chain)

        union_cte = pgast.CommonTableExpr(
            query=chain.larg,
            name=ctx.env.aliases.get(hint='ma'),
        )

        union_rvar = relctx.rvar_for_rel(
            union_cte,
            typeref=ir_stmt.subject.typeref,
            ctx=ctx,
        )

    ctx.dml_stmts[ir_stmt] = union_cte

    return union_cte, union_rvar
def compile_iterator_ctes(
        iterators: Iterable[irast.Set], *,
        ctx: context.CompilerContextLevel) -> Optional[pgast.IteratorCTE]:
    """Compile each iterator set into its own CTE, chained to the last.

    Every compiled iterator is appended to the top-level statement as a
    CTE and wrapped in an IteratorCTE whose parent is the previously
    compiled (or enclosing) iterator.  Iterators whose path id is
    already present in the enclosing iterator chain are skipped.

    :param iterators:
        IR sets of the iterators to compile.
    :param ctx:
        Compiler context.

    :return:
        The innermost IteratorCTE, or the enclosing iterator from *ctx*
        unchanged when nothing was compiled.
    """
    last_iterator = ctx.enclosing_cte_iterator

    # Path ids of all iterators already enclosing this point.
    seen = set()
    node = last_iterator
    while node:
        seen.add(node.path_id)
        node = node.parent

    for iterator_set in iterators:
        iter_path_id = iterator_set.path_id

        # Because of how the IR compiler hoists iterators, we may see
        # an iterator twice. Just ignore it if we do.
        if iter_path_id in seen:
            continue

        with ctx.newrel() as sctx, sctx.newscope() as ictx:
            iter_rel = ictx.rel
            ictx.path_scope[iter_path_id] = iter_rel

            # Correlate with enclosing iterators
            merge_iterator(last_iterator, iter_rel, ctx=ictx)
            if last_iterator is not None:
                ictx.volatility_ref = pathctx.get_path_identity_var(
                    iter_rel, last_iterator.path_id, env=ictx.env)

            clauses.compile_iterator_expr(iter_rel, iterator_set, ctx=ictx)
            ictx.rel.path_id = iter_path_id
            pathctx.put_path_bond(ictx.rel, iter_path_id)

            iterator_cte = pgast.CommonTableExpr(
                query=ictx.rel,
                name=ctx.env.aliases.get('iter'),
            )
            ictx.toplevel_stmt.ctes.append(iterator_cte)

        last_iterator = pgast.IteratorCTE(
            path_id=iter_path_id,
            cte=iterator_cte,
            parent=last_iterator,
        )

    return last_iterator
def process_insert_body(
        ir_stmt: irast.MutatingStmt,
        wrapper: pgast.SelectStmt,
        insert_cte: pgast.CommonTableExpr,
        insert_rvar: pgast.PathRangeVar, *,
        ctx: context.CompilerContextLevel) -> None:
    """Generate SQL DML CTEs from an InsertStmt IR.

    The INSERT's column list and its source SELECT are filled in from
    the shape of the statement subject.  Pointers stored in a separate
    link table are collected and processed after the main INSERT CTE
    has been appended to the top-level statement.

    :param ir_stmt:
        IR of the statement.
    :param wrapper:
        Top-level SQL query.
    :param insert_cte:
        CTE representing the SQL INSERT to the main
        relation of the Object.
    :param insert_rvar:
        Range var over *insert_cte*; registered on *wrapper* as the
        identity source of the iterator path when the statement is
        iterated.
    :param ctx:
        Compiler context.
    """
    # ``cols`` and ``values`` are kept in lockstep: every column appended
    # to one gets a matching target appended to the other.
    cols = [pgast.ColumnRef(name=['__type__'])]
    select = pgast.SelectStmt(target_list=[])
    values = select.target_list

    # The main INSERT query of this statement will always be
    # present to insert at least the `id` and `__type__`
    # properties.
    insert_stmt = insert_cte.query
    assert isinstance(insert_stmt, pgast.InsertStmt)

    insert_stmt.cols = cols
    insert_stmt.select_stmt = select

    if ir_stmt.parent_stmt is not None:
        iterator_set = ir_stmt.parent_stmt.iterator_stmt
    else:
        iterator_set = None

    iterator_cte: Optional[pgast.CommonTableExpr]
    iterator_id: Optional[pgast.BaseExpr]

    if iterator_set is not None:
        # Compile the iterator into its own CTE and make the INSERT's
        # source SELECT range over it.
        with ctx.substmt() as ictx:
            ictx.path_scope = ictx.path_scope.new_child()
            ictx.path_scope[iterator_set.path_id] = ictx.rel
            clauses.compile_iterator_expr(ictx.rel, iterator_set, ctx=ictx)
            ictx.rel.path_id = iterator_set.path_id
            pathctx.put_path_bond(ictx.rel, iterator_set.path_id)
            iterator_cte = pgast.CommonTableExpr(
                query=ictx.rel,
                name=ctx.env.aliases.get('iter'))
            ictx.toplevel_stmt.ctes.append(iterator_cte)
        iterator_rvar = relctx.rvar_for_rel(iterator_cte, ctx=ctx)
        relctx.include_rvar(select, iterator_rvar,
                            path_id=ictx.rel.path_id, ctx=ctx)
        iterator_id = pathctx.get_path_identity_var(
            select, iterator_set.path_id, env=ctx.env)
    else:
        iterator_cte = None
        iterator_id = None

    typeref = ir_stmt.subject.typeref
    if typeref.material_type is not None:
        typeref = typeref.material_type

    # Value for the `__type__` column: the material type id as a uuid.
    values.append(
        pgast.ResTarget(val=pgast.TypeCast(
            arg=pgast.StringConstant(val=str(typeref.id)),
            type_name=pgast.TypeName(name=('uuid', ))), ))

    # (shape element, props_only) pairs for pointers stored in link tables.
    external_inserts = []

    with ctx.newrel() as subctx:
        subctx.rel = select
        subctx.rel_hierarchy[select] = insert_stmt

        subctx.expr_exposed = False

        if iterator_cte is not None:
            subctx.path_scope = ctx.path_scope.new_child()
            subctx.path_scope[iterator_cte.query.path_id] = select

        # Process the Insert IR and separate links that go
        # into the main table from links that are inserted into
        # a separate link table.
        for shape_el, shape_op in ir_stmt.subject.shape:
            assert shape_op is qlast.ShapeOp.ASSIGN

            rptr = shape_el.rptr
            ptrref = rptr.ptrref
            if ptrref.material_ptr is not None:
                ptrref = ptrref.material_ptr

            # Skip pointers that have a source pointer and whose source
            # is not the statement subject itself.
            if (ptrref.source_ptr is not None and
                    rptr.source.path_id != ir_stmt.subject.path_id):
                continue

            ptr_info = pg_types.get_ptrref_storage_info(
                ptrref, resolve_type=True, link_bias=False)

            props_only = False

            # First, process all local link inserts.
            if ptr_info.table_type == 'ObjectType':
                props_only = True
                field = pgast.ColumnRef(name=[ptr_info.column_name])
                cols.append(field)

                rel = compile_insert_shape_element(
                    insert_stmt, wrapper, ir_stmt, shape_el, iterator_id,
                    ctx=ctx)

                insvalue = pathctx.get_path_value_var(
                    rel, shape_el.path_id, env=ctx.env)

                if irtyputils.is_tuple(shape_el.typeref):
                    # Tuples require an explicit cast.
                    insvalue = pgast.TypeCast(
                        arg=output.output_as_value(insvalue, env=ctx.env),
                        type_name=pgast.TypeName(
                            name=ptr_info.column_type, ),
                    )

                values.append(pgast.ResTarget(val=insvalue))

            # Second lookup, this time biased towards link-table storage.
            ptr_info = pg_types.get_ptrref_storage_info(
                ptrref, resolve_type=False, link_bias=True)

            if ptr_info and ptr_info.table_type == 'link':
                external_inserts.append((shape_el, props_only))

        if iterator_set is not None:
            # Store the iterator identity in the `__edb_token` column and
            # register that column as the identity of the iterator path.
            cols.append(pgast.ColumnRef(name=['__edb_token']))

            values.append(pgast.ResTarget(val=iterator_id))

            pathctx.put_path_identity_var(
                insert_stmt, iterator_set.path_id,
                cols[-1], force=True, env=subctx.env)

            pathctx.put_path_bond(insert_stmt, iterator_set.path_id)
            pathctx.put_path_rvar(
                wrapper,
                path_id=iterator_set.path_id,
                rvar=insert_rvar,
                aspect='identity',
                env=subctx.env,
            )

    if isinstance(ir_stmt, irast.InsertStmt) and ir_stmt.on_conflict:
        assert not insert_stmt.on_conflict
        constraint_name = f'"{ir_stmt.on_conflict.id};schemaconstr"'
        insert_stmt.on_conflict = pgast.OnConflictClause(
            action='nothing',
            infer=pgast.InferClause(conname=constraint_name),
        )

    toplevel = ctx.toplevel_stmt
    toplevel.ctes.append(insert_cte)

    # Process necessary updates to the link tables.
    # NOTE(review): ``source_typeref`` is passed here, but the
    # process_link_update definition visible in this file has no such
    # parameter -- confirm which signature is current.
    for shape_el, props_only in external_inserts:
        process_link_update(
            ir_stmt=ir_stmt,
            ir_set=shape_el,
            props_only=props_only,
            wrapper=wrapper,
            dml_cte=insert_cte,
            source_typeref=typeref,
            iterator_cte=iterator_cte,
            is_insert=True,
            ctx=ctx,
        )
def process_link_values(
        ir_stmt, ir_expr, target_tab, col_data,
        dml_rvar, sources, props_only, target_is_scalar, iterator_cte, *,
        ctx=context.CompilerContext) -> \
        typing.Tuple[pgast.CommonTableExpr, typing.List[str]]:
    """Unpack data from an update expression into a series of selects.

    :param ir_stmt:
        IR of the DML statement; its subject path is bound to the row
        query built here.
    :param ir_expr:
        IR of the INSERT/UPDATE body element.
    :param target_tab:
        The link table being updated (not referenced in this function
        body).
    :param col_data:
        Expressions used to populate well-known columns of the link
        table such as `source` and `__type__`.
    :param dml_rvar:
        Range var over the main DML CTE; included into the row query.
    :param sources:
        A list of relations which must be joined into the data query
        to resolve expressions in *col_data*.
    :param props_only:
        Whether this link update only touches link properties (not
        referenced in this function body).
    :param target_is_scalar:
        Whether the link target is an ScalarType.
    :param iterator_cte:
        CTE representing the iterator range in the FOR clause of the
        EdgeQL DML statement.
    :param ctx:
        Compiler context.
        NOTE(review): the default value is the ``CompilerContext``
        class itself, which looks like an annotation written as a
        default; the visible caller always passes ``ctx`` -- confirm.

    :return:
        A (*link_rows*, *specified_cols*) tuple: the CTE producing the
        rows for the link table and the names of the columns it yields.
    """
    with ctx.newscope() as newscope, newscope.newrel() as subrelctx:
        row_query = subrelctx.rel

        relctx.include_rvar(row_query, dml_rvar,
                            path_id=ir_stmt.subject.path_id, ctx=subrelctx)
        subrelctx.path_scope[ir_stmt.subject.path_id] = row_query

        if iterator_cte is not None:
            iterator_rvar = relctx.rvar_for_rel(
                iterator_cte, lateral=True, ctx=subrelctx)
            relctx.include_rvar(row_query, iterator_rvar,
                                path_id=iterator_cte.query.path_id,
                                ctx=subrelctx)

    # Compile the right-hand side expression into its own subrelation.
    with subrelctx.newscope() as sctx, sctx.subrel() as input_rel_ctx:
        input_rel = input_rel_ctx.rel
        if iterator_cte is not None:
            input_rel_ctx.path_scope[iterator_cte.query.path_id] = \
                row_query
        input_rel_ctx.expr_exposed = False
        input_rel_ctx.volatility_ref = pathctx.get_path_identity_var(
            row_query, ir_stmt.subject.path_id, env=input_rel_ctx.env)
        dispatch.visit(ir_expr, ctx=input_rel_ctx)
        shape_tuple = None
        if ir_expr.shape:
            shape_tuple = shapecomp.compile_shape(
                ir_expr, ir_expr.shape, ctx=input_rel_ctx)

            for element in shape_tuple.elements:
                pathctx.put_path_var_if_not_exists(
                    input_rel_ctx.rel, element.path_id, element.val,
                    aspect='value', env=input_rel_ctx.env)

    input_stmt = input_rel

    input_rvar = pgast.RangeSubselect(
        subquery=input_rel,
        lateral=True,
        alias=pgast.Alias(aliasname=ctx.env.aliases.get('val')))

    # Column name -> expression, for the columns derived from ir_expr.
    source_data: typing.Dict[str, pgast.BaseExpr] = {}

    # NOTE(review): input_stmt is reassigned here but not read again in
    # this function -- confirm whether this is leftover code.
    if input_stmt.op is not None:
        # UNION
        input_stmt = input_stmt.rarg

    path_id = ir_expr.path_id

    if shape_tuple is not None:
        # Only link property elements of the shape become columns here.
        for element in shape_tuple.elements:
            if not element.path_id.is_linkprop_path():
                continue
            colname = element.path_id.rptr_name().name
            val = pathctx.get_rvar_path_value_var(
                input_rvar, element.path_id, env=ctx.env)
            source_data.setdefault(colname, val)
    else:
        if target_is_scalar:
            target_ref = pathctx.get_rvar_path_value_var(
                input_rvar, path_id, env=ctx.env)
        else:
            target_ref = pathctx.get_rvar_path_identity_var(
                input_rvar, path_id, env=ctx.env)

        source_data['target'] = target_ref

    # Object links always need a `target` column, even when the shape
    # did not provide one.
    if not target_is_scalar and 'target' not in source_data:
        target_ref = pathctx.get_rvar_path_identity_var(
            input_rvar, path_id, env=ctx.env)
        source_data['target'] = target_ref

    specified_cols = []
    # ChainMap looks up col_data first, so its entries win over
    # same-named entries in source_data.
    for col, expr in collections.ChainMap(col_data, source_data).items():
        row_query.target_list.append(pgast.ResTarget(val=expr, name=col))
        specified_cols.append(col)

    row_query.from_clause += list(sources) + [input_rvar]

    link_rows = pgast.CommonTableExpr(
        query=row_query,
        name=ctx.env.aliases.get(hint='r'))

    return link_rows, specified_cols
def process_link_update(
        *,
        ir_stmt: irast.MutatingStmt,
        ir_set: irast.Set,
        props_only: bool,
        is_insert: bool,
        wrapper: pgast.Query,
        dml_cte: pgast.CommonTableExpr,
        iterator_cte: typing.Optional[pgast.CommonTableExpr],
        ctx: context.CompilerContextLevel) -> pgast.CommonTableExpr:
    """Perform updates to a link relation as part of a DML statement.

    Up to three CTEs are appended to the top-level statement, in this
    order: a DELETE of the existing link rows (only when *is_insert* is
    false), the data CTE computing the new link rows, and the INSERT of
    those rows into the link table.

    :param ir_stmt:
        IR of the statement.
    :param ir_set:
        IR of the INSERT/UPDATE body element.
    :param props_only:
        Whether this link update only touches link properties.
    :param is_insert:
        Whether the enclosing statement is an INSERT.  When false,
        previous link rows are deleted first and the INSERT gets an
        ON CONFLICT clause.
    :param wrapper:
        Top-level SQL query (not referenced in this function body).
    :param dml_cte:
        CTE representing the SQL INSERT or UPDATE to the main
        relation of the Object.
    :param iterator_cte:
        CTE representing the iterator range in the FOR clause of the
        EdgeQL DML statement.
    :param ctx:
        Compiler context.

    :return:
        The data CTE that produces the rows for the link table.
    """
    toplevel = ctx.toplevel_stmt

    rptr = ir_set.rptr
    ptrref = rptr.ptrref
    assert isinstance(ptrref, irast.PointerRef)
    target_is_scalar = irtyputils.is_scalar(ptrref.dir_target)
    path_id = ir_set.path_id

    # The links in the dml class shape have been derived,
    # but we must use the correct specialized link class for the
    # base material type.
    if ptrref.material_ptr is not None:
        mptrref = ptrref.material_ptr
        assert isinstance(mptrref, irast.PointerRef)
    else:
        mptrref = ptrref

    target_rvar = relctx.range_for_ptrref(
        mptrref, include_overlays=False, only_self=True, ctx=ctx)
    assert isinstance(target_rvar, pgast.RelRangeVar)
    assert isinstance(target_rvar.relation, pgast.Relation)
    target_alias = target_rvar.alias.aliasname

    target_tab_name = (target_rvar.relation.schemaname,
                       target_rvar.relation.name)

    dml_cte_rvar = pgast.RelRangeVar(
        relation=dml_cte,
        alias=pgast.Alias(aliasname=ctx.env.aliases.get('m')))

    # Well-known link table columns that do not depend on the link value.
    col_data = {
        'ptr_item_id':
        pgast.TypeCast(arg=pgast.StringConstant(val=str(mptrref.id)),
                       type_name=pgast.TypeName(name=('uuid', ))),
        'source':
        pathctx.get_rvar_path_identity_var(dml_cte_rvar,
                                           ir_stmt.subject.path_id,
                                           env=ctx.env)
    }

    if not is_insert:
        # Drop all previous link records for this source.
        delcte = pgast.CommonTableExpr(query=pgast.DeleteStmt(
            relation=target_rvar,
            where_clause=astutils.new_binop(
                lexpr=col_data['source'],
                op='=',
                rexpr=pgast.ColumnRef(name=[target_alias, 'source'])),
            using_clause=[dml_cte_rvar],
            returning_list=[
                pgast.ResTarget(val=pgast.ColumnRef(
                    name=[target_alias, pgast.Star()]))
            ]),
            name=ctx.env.aliases.get(hint='d'))

        pathctx.put_path_value_rvar(
            delcte.query, path_id.ptr_path(), target_rvar, env=ctx.env)

        # Record the effect of this removal in the relation overlay
        # context to ensure that references to the link in the result
        # of this DML statement yield the expected results.
        dml_stack = get_dml_stmt_stack(ir_stmt, ctx=ctx)
        relctx.add_ptr_rel_overlay(
            ptrref, 'except', delcte, dml_stmts=dml_stack, ctx=ctx)

        toplevel.ctes.append(delcte)

    # Turn the IR of the expression on the right side of :=
    # into a subquery returning records for the link table.
    data_cte, specified_cols = process_link_values(
        ir_stmt, ir_set, target_tab_name, col_data, dml_cte_rvar, [],
        props_only, target_is_scalar, iterator_cte, ctx=ctx)

    toplevel.ctes.append(data_cte)

    data_select = pgast.SelectStmt(
        target_list=[
            pgast.ResTarget(val=pgast.ColumnRef(
                name=[data_cte.name, pgast.Star()]))
        ],
        from_clause=[pgast.RelRangeVar(relation=data_cte)])

    cols = [pgast.ColumnRef(name=[col]) for col in specified_cols]

    if is_insert:
        conflict_clause = None
    else:
        # Inserting rows into the link table may produce cardinality
        # constraint violations, since the INSERT into the link table
        # is executed in the snapshot where the above DELETE from
        # the link table is not visible.  Hence, we need to use
        # the ON CONFLICT clause to resolve this.
        conflict_cols = ['source', 'target', 'ptr_item_id']
        conflict_inference = []
        conflict_exc_row = []

        for col in conflict_cols:
            conflict_inference.append(pgast.ColumnRef(name=[col]))
            conflict_exc_row.append(pgast.ColumnRef(name=['excluded', col]))

        # On conflict, take the replacement values from the data CTE row
        # whose key columns equal those of the excluded row.
        conflict_data = pgast.SelectStmt(
            target_list=[
                pgast.ResTarget(val=pgast.ColumnRef(
                    name=[data_cte.name, pgast.Star()]))
            ],
            from_clause=[pgast.RelRangeVar(relation=data_cte)],
            where_clause=astutils.new_binop(
                lexpr=pgast.ImplicitRowExpr(args=conflict_inference),
                rexpr=pgast.ImplicitRowExpr(args=conflict_exc_row),
                op='='))

        conflict_clause = pgast.OnConflictClause(
            action='update',
            infer=pgast.InferClause(index_elems=conflict_inference),
            target_list=[
                pgast.MultiAssignRef(columns=cols, source=conflict_data)
            ])

    updcte = pgast.CommonTableExpr(
        name=ctx.env.aliases.get(hint='i'),
        query=pgast.InsertStmt(
            relation=target_rvar,
            select_stmt=data_select,
            cols=cols,
            on_conflict=conflict_clause,
            returning_list=[
                pgast.ResTarget(val=pgast.ColumnRef(name=[pgast.Star()]))
            ]))

    pathctx.put_path_value_rvar(
        updcte.query, path_id.ptr_path(), target_rvar, env=ctx.env)

    # Record the effect of this insertion in the relation overlay
    # context to ensure that references to the link in the result
    # of this DML statement yield the expected results.
    dml_stack = get_dml_stmt_stack(ir_stmt, ctx=ctx)
    relctx.add_ptr_rel_overlay(
        ptrref, 'union', updcte, dml_stmts=dml_stack, ctx=ctx)

    toplevel.ctes.append(updcte)

    return data_cte
def init_dml_stmt(
        ir_stmt: irast.MutatingStmt,
        dml_stmt: pgast.DMLQuery, *,
        ctx: context.CompilerContextLevel,
        parent_ctx: context.CompilerContextLevel) \
        -> typing.Tuple[pgast.Query, pgast.CommonTableExpr,
                        pgast.PathRangeVar,
                        typing.Optional[pgast.CommonTableExpr]]:
    """Prepare the common structure of the query representing a DML stmt.

    :param ir_stmt:
        IR of the statement.
    :param dml_stmt:
        SQL DML node instance.
    :param ctx:
        Compiler context of the statement; its top-level path scope is
        cleared by this function.
    :param parent_ctx:
        Parent compiler context.

    :return:
        A (*wrapper*, *dml_cte*, *dml_rvar*, *range_cte*) tuple, where
        *wrapper* is the wrapping SQL statement, *dml_cte* is the CTE
        representing the SQL DML operation in the main relation of the
        Object, *dml_rvar* is the range var over *dml_cte* that has been
        included into *wrapper*, and *range_cte* is the CTE for the
        subset affected by the statement.  *range_cte* is None for
        INSERT statements.
    """
    wrapper = ctx.rel

    clauses.init_stmt(ir_stmt, ctx, parent_ctx)

    target_ir_set = ir_stmt.subject

    dml_stmt.relation = relctx.range_for_typeref(
        ir_stmt.subject.typeref,
        ir_stmt.subject.path_id,
        include_overlays=False,
        common_parent=True,
        ctx=ctx,
    )
    pathctx.put_path_value_rvar(
        dml_stmt, target_ir_set.path_id, dml_stmt.relation, env=ctx.env)
    pathctx.put_path_source_rvar(
        dml_stmt, target_ir_set.path_id, dml_stmt.relation, env=ctx.env)
    pathctx.put_path_bond(dml_stmt, target_ir_set.path_id)

    dml_cte = pgast.CommonTableExpr(
        query=dml_stmt,
        name=ctx.env.aliases.get(hint='m'))

    range_cte = None

    if isinstance(ir_stmt, (irast.UpdateStmt, irast.DeleteStmt)):
        # UPDATE and DELETE operate over a range, so generate
        # the corresponding CTE and connect it to the DML query.
        # NOTE(review): two positional arguments are passed here, while
        # the get_dml_range definition visible in this file accepts only
        # ``ir_stmt`` positionally -- confirm which signature is current.
        range_cte = get_dml_range(ir_stmt, dml_stmt, ctx=ctx)

        range_rvar = pgast.RelRangeVar(
            relation=range_cte,
            alias=pgast.Alias(aliasname=ctx.env.aliases.get(hint='range')))

        relctx.pull_path_namespace(
            target=dml_stmt, source=range_rvar, ctx=ctx)

        # Auxiliary relations are always joined via the WHERE
        # clause due to the structure of the UPDATE/DELETE SQL statements.
        dml_stmt.where_clause = astutils.new_binop(
            lexpr=pgast.ColumnRef(
                name=[dml_stmt.relation.alias.aliasname, 'id']),
            op='=',
            rexpr=pathctx.get_rvar_path_identity_var(
                range_rvar, target_ir_set.path_id, env=ctx.env))

        # UPDATE has "FROM", while DELETE has "USING".
        if isinstance(dml_stmt, pgast.UpdateStmt):
            dml_stmt.from_clause.append(range_rvar)
        elif isinstance(dml_stmt, pgast.DeleteStmt):
            dml_stmt.using_clause.append(range_rvar)

    # Due to the fact that DML statements are structured
    # as a flat list of CTEs instead of nested range vars,
    # the top level path scope must be empty.  The necessary
    # range vars will be injected explicitly in all rels that
    # need them.
    ctx.path_scope.clear()

    # Same registrations as above, since target_ir_set is ir_stmt.subject.
    pathctx.put_path_value_rvar(
        dml_stmt, ir_stmt.subject.path_id, dml_stmt.relation, env=ctx.env)

    pathctx.put_path_source_rvar(
        dml_stmt, ir_stmt.subject.path_id, dml_stmt.relation, env=ctx.env)

    dml_rvar = pgast.RelRangeVar(
        relation=dml_cte,
        alias=pgast.Alias(aliasname=parent_ctx.env.aliases.get('d')))

    relctx.include_rvar(wrapper, dml_rvar, ir_stmt.subject.path_id, ctx=ctx)
    pathctx.put_path_bond(wrapper, ir_stmt.subject.path_id)

    return wrapper, dml_cte, dml_rvar, range_cte
def process_insert_body(
        ir_stmt: irast.MutatingStmt,
        wrapper: pgast.Query,
        insert_cte: pgast.CommonTableExpr,
        insert_rvar: pgast.PathRangeVar, *,
        ctx: context.CompilerContextLevel) -> None:
    """Generate SQL DML CTEs from an InsertStmt IR.

    The INSERT's column list and its source SELECT are filled in from
    the shape of the statement subject.  Pointers stored in a separate
    link table are processed after the main INSERT CTE has been
    appended to the top-level statement; shape elements set aside as
    parent link properties are compiled against *wrapper* at the end.

    :param ir_stmt:
        IR of the statement.
    :param wrapper:
        Top-level SQL query.
    :param insert_cte:
        CTE representing the SQL INSERT to the main
        relation of the Object.
    :param insert_rvar:
        Range var over *insert_cte*; registered on *wrapper* as the
        value source for the sources of parent link properties.
    :param ctx:
        Compiler context.
    """
    # ``cols`` and ``values`` are kept in lockstep: every column appended
    # to one gets a matching target appended to the other.
    cols = [pgast.ColumnRef(name=['__type__'])]
    select = pgast.SelectStmt(target_list=[])
    values = select.target_list

    # The main INSERT query of this statement will always be
    # present to insert at least the `id` and `__type__`
    # properties.
    insert_stmt = insert_cte.query
    assert isinstance(insert_stmt, pgast.InsertStmt)

    insert_stmt.cols = cols
    insert_stmt.select_stmt = select

    if ir_stmt.parent_stmt is not None:
        iterator_set = ir_stmt.parent_stmt.iterator_stmt
    else:
        iterator_set = None

    if iterator_set is not None:
        # Compile the iterator into its own CTE and make the INSERT's
        # source SELECT range over it.
        with ctx.substmt() as ictx:
            ictx.path_scope = ictx.path_scope.new_child()
            ictx.path_scope[iterator_set.path_id] = ictx.rel
            clauses.compile_iterator_expr(ictx.rel, iterator_set, ctx=ictx)
            ictx.rel.path_id = iterator_set.path_id
            pathctx.put_path_bond(ictx.rel, iterator_set.path_id)
            iterator_cte = pgast.CommonTableExpr(
                query=ictx.rel,
                name=ctx.env.aliases.get('iter'))
            ictx.toplevel_stmt.ctes.append(iterator_cte)
        iterator_rvar = relctx.rvar_for_rel(iterator_cte, ctx=ctx)
        relctx.include_rvar(select, iterator_rvar,
                            path_id=ictx.rel.path_id, ctx=ctx)
        iterator_id = pathctx.get_path_identity_var(
            select, iterator_set.path_id, env=ctx.env)
    else:
        iterator_cte = None
        iterator_id = None

    typeref = ir_stmt.subject.typeref
    if typeref.material_type is not None:
        typeref = typeref.material_type

    # Value for the `__type__` column: the material type id as a uuid.
    values.append(
        pgast.ResTarget(val=pgast.TypeCast(
            arg=pgast.StringConstant(val=str(typeref.id)),
            type_name=pgast.TypeName(name=('uuid', ))), ))

    # (shape element, props_only) pairs for pointers stored in link tables.
    external_inserts = []
    # Shape elements set aside by the parent_ptr check in the loop below.
    parent_link_props = []

    with ctx.newrel() as subctx:
        subctx.rel = select
        subctx.rel_hierarchy[select] = insert_stmt

        subctx.expr_exposed = False

        if iterator_cte is not None:
            subctx.path_scope = ctx.path_scope.new_child()
            subctx.path_scope[iterator_cte.query.path_id] = select

        # Process the Insert IR and separate links that go
        # into the main table from links that are inserted into
        # a separate link table.
        for shape_el in ir_stmt.subject.shape:
            rptr = shape_el.rptr
            ptrref = rptr.ptrref
            if ptrref.material_ptr is not None:
                ptrref = ptrref.material_ptr

            # Pointers that have a parent pointer and whose source is not
            # the statement subject are deferred to the end.
            if (ptrref.parent_ptr is not None and
                    rptr.source.path_id != ir_stmt.subject.path_id):
                parent_link_props.append(shape_el)
                continue

            ptr_info = pg_types.get_ptrref_storage_info(
                ptrref, resolve_type=True, link_bias=False)

            props_only = False

            # First, process all local link inserts.
            if ptr_info.table_type == 'ObjectType':
                props_only = True
                field = pgast.ColumnRef(name=[ptr_info.column_name])
                cols.append(field)

                insvalue = insert_value_for_shape_element(
                    insert_stmt, wrapper, ir_stmt, shape_el, iterator_id,
                    ptr_info=ptr_info, ctx=subctx)

                values.append(pgast.ResTarget(val=insvalue))

            # Second lookup, this time biased towards link-table storage.
            ptr_info = pg_types.get_ptrref_storage_info(
                ptrref, resolve_type=False, link_bias=True)

            if ptr_info and ptr_info.table_type == 'link':
                external_inserts.append((shape_el, props_only))

        if iterator_cte is not None:
            # Store the iterator identity in the `__edb_token` column and
            # register that column as the identity of the iterator path.
            cols.append(pgast.ColumnRef(name=['__edb_token']))

            values.append(pgast.ResTarget(val=iterator_id))

            pathctx.put_path_identity_var(
                insert_stmt, iterator_set.path_id,
                cols[-1], force=True, env=subctx.env)

            pathctx.put_path_bond(insert_stmt, iterator_set.path_id)

    toplevel = ctx.toplevel_stmt
    toplevel.ctes.append(insert_cte)

    # Process necessary updates to the link tables.
    for shape_el, props_only in external_inserts:
        process_link_update(
            ir_stmt=ir_stmt, ir_set=shape_el, props_only=props_only,
            wrapper=wrapper, dml_cte=insert_cte,
            iterator_cte=iterator_cte, is_insert=True, ctx=ctx)

    if parent_link_props:
        # Compile the deferred elements against the wrapper and publish
        # them as a named tuple standing for the subject's value.
        prop_elements = []

        with ctx.newscope() as scopectx:
            scopectx.rel = wrapper

            for shape_el in parent_link_props:
                rptr = shape_el.rptr
                scopectx.path_scope[rptr.source.path_id] = wrapper
                pathctx.put_path_rvar_if_not_exists(
                    wrapper, rptr.source.path_id, insert_rvar,
                    aspect='value', env=scopectx.env)
                dispatch.visit(shape_el, ctx=scopectx)
                tuple_el = astutils.tuple_element_for_shape_el(
                    shape_el, None, ctx=scopectx)
                prop_elements.append(tuple_el)

        valtuple = pgast.TupleVar(elements=prop_elements, named=True)
        pathctx.put_path_value_var(
            wrapper, ir_stmt.subject.path_id,
            valtuple, force=True, env=ctx.env)
def compile_ConfigSet(
    op: irast.ConfigSet,
    *,
    ctx: context.CompilerContextLevel,
) -> pgast.BaseExpr:
    """Build the SQL AST for a configuration "set" operation.

    The generated node depends on ``op.scope`` and on whether
    ``op.backend_setting`` is set:

    * backend setting, INSTANCE scope: an ``AlterSystem`` node;
    * backend setting, DATABASE scope: a script statement calling
      ``edgedb._alter_current_database_set``;
    * backend setting, SESSION scope: a script statement calling
      ``pg_catalog.set_config``;
    * INSTANCE scope: a SELECT of a ``jsonb_build_array`` call;
    * SESSION or GLOBAL scope: an upsert into ``_edgecon_state``; for
      GLOBAL the upsert becomes the ``ins`` CTE of a SELECT returning
      a ``jsonb_build_array`` call;
    * DATABASE scope: an upsert into ``edgedb._db_config``.

    A backend setting with any other scope is handled by the generic
    (non-backend) branches.

    :param op: IR of the configuration operation.
    :param ctx: Compiler context.
    :return: The SQL AST node implementing the operation.
    :raises errors.UnsupportedBackendFeatureError:
        For an INSTANCE-scoped backend setting when the backend runtime
        parameters report no config file access.
    :raises AssertionError: If ``op.scope`` is not a handled scope.
    """
    scopes = qltypes.ConfigScope
    scope = op.scope
    val = _compile_config_value(op, ctx=ctx)

    def as_text(node: pgast.BaseExpr) -> pgast.BaseExpr:
        # String constants are passed through; anything else is cast.
        if isinstance(node, pgast.StringConstant):
            return node
        return pgast.TypeCast(
            arg=node,
            type_name=pgast.TypeName(name=('text', )),
        )

    def silent_call(fcall: pgast.FuncCall) -> pgast.BaseExpr:
        # Wrap ``SELECT fcall`` as a script statement without output.
        return output.wrap_script_stmt(
            pgast.SelectStmt(target_list=[pgast.ResTarget(val=fcall)]),
            suppress_all_output=True,
            env=ctx.env,
        )

    def set_report() -> pgast.FuncCall:
        # jsonb_build_array('SET', <scope>, <name>, <value>)
        row = pgast.RowExpr(args=[
            pgast.StringConstant(val='SET'),
            pgast.StringConstant(val=str(op.scope)),
            pgast.StringConstant(val=op.name),
            val,
        ])
        return pgast.FuncCall(
            name=('jsonb_build_array', ),
            args=row.args,
            null_safe=True,
            ser_safe=True,
        )

    def upsert(relation, row_values, col_names, key_names):
        # INSERT ... ON CONFLICT (<keys>) DO UPDATE SET (value) = (val)
        return pgast.InsertStmt(
            relation=pgast.RelRangeVar(relation=relation),
            select_stmt=pgast.SelectStmt(values=[
                pgast.ImplicitRowExpr(args=row_values),
            ]),
            cols=[pgast.ColumnRef(name=[col]) for col in col_names],
            on_conflict=pgast.OnConflictClause(
                action='update',
                infer=pgast.InferClause(index_elems=[
                    pgast.ColumnRef(name=[key]) for key in key_names
                ]),
                target_list=[
                    pgast.MultiAssignRef(
                        columns=[pgast.ColumnRef(name=['value'])],
                        source=pgast.RowExpr(args=[val]),
                    ),
                ],
            ),
        )

    if op.backend_setting:
        if scope is scopes.INSTANCE:
            if not ctx.env.backend_runtime_params.has_configfile_access:
                raise errors.UnsupportedBackendFeatureError(
                    "configuring backend parameters via CONFIGURE INSTANCE"
                    " is not supported by the current backend")
            return pgast.AlterSystem(name=op.backend_setting, value=val)

        if scope is scopes.DATABASE:
            return silent_call(pgast.FuncCall(
                name=('edgedb', '_alter_current_database_set'),
                args=[
                    pgast.StringConstant(val=op.backend_setting),
                    as_text(val),
                ],
            ))

        if scope is scopes.SESSION:
            return silent_call(pgast.FuncCall(
                name=('pg_catalog', 'set_config'),
                args=[
                    pgast.StringConstant(val=op.backend_setting),
                    as_text(val),
                    pgast.BooleanConstant(val='false'),
                ],
            ))

    if scope is scopes.INSTANCE:
        return pgast.SelectStmt(
            target_list=[pgast.ResTarget(val=set_report())],
        )

    if scope in (scopes.SESSION, scopes.GLOBAL):
        is_global = scope is scopes.GLOBAL
        state_upsert = upsert(
            pgast.Relation(name='_edgecon_state'),
            [
                pgast.StringConstant(val=op.name),
                val,
                pgast.StringConstant(val='G' if is_global else 'C'),
            ],
            ('name', 'value', 'type'),
            ('name', 'type'),
        )
        if not is_global:
            return state_upsert
        return pgast.SelectStmt(
            ctes=[pgast.CommonTableExpr(name='ins', query=state_upsert)],
            target_list=[pgast.ResTarget(val=set_report())],
        )

    if scope is scopes.DATABASE:
        return upsert(
            pgast.Relation(name='_db_config', schemaname='edgedb'),
            [pgast.StringConstant(val=op.name), val],
            ('name', 'value'),
            ('name', ),
        )

    raise AssertionError(f'unexpected configuration scope: {op.scope}')
def init_dml_stmt(
    ir_stmt: irast.MutatingStmt,
    *,
    ctx: context.CompilerContextLevel,
    parent_ctx: context.CompilerContextLevel,
) -> DMLParts:
    """Set up the CTE scaffolding shared by every DML statement.

    :param ir_stmt: IR of the DML statement.
    :param ctx: Compiler context of the statement being compiled.
    :param parent_ctx: Compiler context of the enclosing statement.

    :return:
        A DMLParts tuple with the map from each target type to its
        ``(dml_cte, dml_rvar)`` pair, the range CTE (UPDATE/DELETE only,
        otherwise None), and the "else" CTE pair (only for an INSERT
        whose ``on_conflict[1]`` is set, otherwise None).
    """
    clauses.init_stmt(ir_stmt, ctx, parent_ctx)

    operates_on_range = isinstance(
        ir_stmt, (irast.UpdateStmt, irast.DeleteStmt))

    range_cte: Optional[pgast.CommonTableExpr] = None
    range_rvar: Optional[pgast.RelRangeVar] = None
    if operates_on_range:
        # UPDATE and DELETE operate over a range, so generate
        # the corresponding CTE and connect it to the DML statements.
        range_cte = get_dml_range(ir_stmt, ctx=ctx)
        range_alias = ctx.env.aliases.get(hint='range')
        range_rvar = pgast.RelRangeVar(
            relation=range_cte,
            alias=pgast.Alias(aliasname=range_alias),
        )

    subject = ir_stmt.subject
    root_typeref = subject.typeref.material_type or subject.typeref

    # The statement's own type comes first; UPDATE/DELETE additionally
    # cover union components and all descendants.
    target_typerefs = [root_typeref]
    if operates_on_range:
        for member in root_typeref.union or ():
            member = member.material_type or member
            target_typerefs.append(member)
            target_typerefs.extend(
                irtyputils.get_typeref_descendants(member))

        target_typerefs.extend(
            irtyputils.get_typeref_descendants(root_typeref))

    dml_map = {
        typeref: gen_dml_cte(
            ir_stmt,
            range_rvar=range_rvar,
            typeref=typeref,
            ctx=ctx,
        )
        for typeref in target_typerefs
    }

    else_cte = None
    wants_else = (
        isinstance(ir_stmt, irast.InsertStmt)
        and ir_stmt.on_conflict
        and ir_stmt.on_conflict[1] is not None
    )
    if wants_else:
        else_rel = pgast.CommonTableExpr(
            query=pgast.SelectStmt(),
            name=ctx.env.aliases.get(hint='m'),
        )
        else_cte = (else_rel, relctx.rvar_for_rel(else_rel, ctx=ctx))

    pathctx.put_path_bond(ctx.rel, subject.path_id)

    if ctx.enclosing_cte_iterator:
        pathctx.put_path_bond(ctx.rel, ctx.enclosing_cte_iterator.path_id)

    return DMLParts(
        dml_ctes=dml_map,
        range_cte=range_cte,
        else_cte=else_cte,
    )
def process_link_update(
    *,
    ir_stmt: irast.MutatingStmt,
    ir_set: irast.Set,
    props_only: bool,
    is_insert: bool,
    shape_op: qlast.ShapeOp = qlast.ShapeOp.ASSIGN,
    source_typeref: irast.TypeRef,
    wrapper: pgast.Query,
    dml_cte: pgast.CommonTableExpr,
    iterator_cte: Optional[pgast.CommonTableExpr],
    ctx: context.CompilerContextLevel,
) -> pgast.CommonTableExpr:
    """Perform updates to a link relation as part of a DML statement.

    The generated CTEs (link data, optional DELETE, optional INSERT) are
    appended to ``ctx.toplevel_stmt.ctes`` in that order, and the DELETE
    and INSERT are registered as 'except' / 'union' pointer overlays.

    :param ir_stmt:
        IR of the statement.
    :param ir_set:
        IR of the INSERT/UPDATE body element.
    :param props_only:
        Whether this link update only touches link properties.
    :param is_insert:
        Whether the enclosing statement is an INSERT.  When true, no
        existing link rows are deleted and the generated INSERT has
        no ON CONFLICT clause.
    :param shape_op:
        The shape operation.  With APPEND no rows are deleted; with
        SUBTRACT (on a non-insert) only the rows matching the new data
        are deleted; in either case SUBTRACT returns before any INSERT
        is generated.  Any other value on a non-insert deletes all
        previous link rows of the source.
    :param source_typeref:
        Typeref used to select the PointerRef whose ``out_source`` has
        the same id: either the material pointer itself or one of its
        descendants.
    :param wrapper:
        Top-level SQL query.  NOTE(review): not referenced in the body
        of this function.
    :param dml_cte:
        CTE representing the SQL INSERT or UPDATE to the main
        relation of the Object.
    :param iterator_cte:
        CTE representing the iterator range in the FOR clause of the
        EdgeQL DML statement.
    :param ctx:
        Compiler context.

    :return:
        The link data CTE produced by ``process_link_values``.

    :raises errors.InternalServerError:
        If neither the material pointer nor any of its descendants
        originates from *source_typeref*.
    """
    toplevel = ctx.toplevel_stmt

    rptr = ir_set.rptr
    ptrref = rptr.ptrref
    assert isinstance(ptrref, irast.PointerRef)
    target_is_scalar = irtyputils.is_scalar(ir_set.typeref)
    path_id = ir_set.path_id

    # The links in the dml class shape have been derived,
    # but we must use the correct specialized link class for the
    # base material type.
    if ptrref.material_ptr is not None:
        mptrref = ptrref.material_ptr
    else:
        mptrref = ptrref

    if mptrref.out_source.id != source_typeref.id:
        # Look for the descendant pointer that belongs to the source
        # type; the for/else raises when none is found.
        for descendant in mptrref.descendants:
            if descendant.out_source.id == source_typeref.id:
                mptrref = descendant
                break
        else:
            raise errors.InternalServerError(
                'missing PointerRef descriptor for source typeref')

    assert isinstance(mptrref, irast.PointerRef)

    target_rvar = relctx.range_for_ptrref(mptrref, for_mutation=True, only_self=True, ctx=ctx)
    assert isinstance(target_rvar, pgast.RelRangeVar)
    assert isinstance(target_rvar.relation, pgast.Relation)
    target_alias = target_rvar.alias.aliasname

    target_tab_name = (target_rvar.relation.schemaname, target_rvar.relation.name)

    dml_cte_rvar = pgast.RelRangeVar(
        relation=dml_cte,
        alias=pgast.Alias(aliasname=ctx.env.aliases.get('m')))

    # Values for the well-known link table columns.
    col_data = {
        'ptr_item_id': pgast.TypeCast(arg=pgast.StringConstant(val=str(mptrref.id)), type_name=pgast.TypeName(name=('uuid', ))),
        'source': pathctx.get_rvar_path_identity_var(dml_cte_rvar, ir_stmt.subject.path_id, env=ctx.env)
    }

    # Turn the IR of the expression on the right side of :=
    # into a subquery returning records for the link table.
    data_cte, specified_cols = process_link_values(
        ir_stmt=ir_stmt,
        ir_expr=ir_set,
        target_tab=target_tab_name,
        col_data=col_data,
        dml_rvar=dml_cte_rvar,
        sources=[],
        props_only=props_only,
        target_is_scalar=target_is_scalar,
        iterator_cte=iterator_cte,
        ctx=ctx,
    )

    toplevel.ctes.append(data_cte)

    delqry: Optional[pgast.DeleteStmt]

    # SELECT data_cte.* FROM data_cte; used as the source of the INSERT
    # below and, for SUBTRACT, as the set of rows to delete.
    data_select = pgast.SelectStmt(
        target_list=[
            pgast.ResTarget(val=pgast.ColumnRef(
                name=[data_cte.name, pgast.Star()]), ),
        ],
        from_clause=[
            pgast.RelRangeVar(relation=data_cte),
        ],
    )

    if not is_insert and shape_op is not qlast.ShapeOp.APPEND:
        if shape_op is qlast.ShapeOp.SUBTRACT:
            data_rvar = relctx.rvar_for_rel(data_select, ctx=ctx)

            # Drop requested link records.
            delqry = pgast.DeleteStmt(
                relation=target_rvar,
                where_clause=astutils.new_binop(
                    lexpr=astutils.new_binop(
                        lexpr=col_data['source'],
                        op='=',
                        rexpr=pgast.ColumnRef(name=[target_alias, 'source'], ),
                    ),
                    op='AND',
                    rexpr=astutils.new_binop(
                        lexpr=pgast.ColumnRef(name=[target_alias, 'target'], ),
                        op='=',
                        rexpr=pgast.ColumnRef(
                            name=[data_rvar.alias.aliasname, 'target'], ),
                    ),
                ),
                using_clause=[
                    dml_cte_rvar,
                    data_rvar,
                ],
                returning_list=[
                    pgast.ResTarget(val=pgast.ColumnRef(
                        name=[target_alias, pgast.Star()], ), )
                ])
        else:
            # Drop all previous link records for this source.
            delqry = pgast.DeleteStmt(
                relation=target_rvar,
                where_clause=astutils.new_binop(
                    lexpr=col_data['source'],
                    op='=',
                    rexpr=pgast.ColumnRef(name=[target_alias, 'source'], ),
                ),
                using_clause=[dml_cte_rvar],
                returning_list=[
                    pgast.ResTarget(val=pgast.ColumnRef(
                        name=[target_alias, pgast.Star()], ), )
                ])

        delcte = pgast.CommonTableExpr(
            name=ctx.env.aliases.get(hint='d'),
            query=delqry,
        )

        pathctx.put_path_value_rvar(delcte.query, path_id.ptr_path(), target_rvar, env=ctx.env)

        # Record the effect of this removal in the relation overlay
        # context to ensure that references to the link in the result
        # of this DML statement yield the expected results.
        dml_stack = get_dml_stmt_stack(ir_stmt, ctx=ctx)
        relctx.add_ptr_rel_overlay(ptrref, 'except', delcte, dml_stmts=dml_stack, ctx=ctx)
        toplevel.ctes.append(delcte)
    else:
        delqry = None

    if shape_op is qlast.ShapeOp.SUBTRACT:
        # Nothing to insert when subtracting.
        return data_cte

    cols = [pgast.ColumnRef(name=[col]) for col in specified_cols]
    conflict_cols = ['source', 'target', 'ptr_item_id']

    if is_insert:
        conflict_clause = None
    elif len(cols) == len(conflict_cols) and delqry is not None:
        # There are no link properties, so we can optimize the
        # link replacement operation by omitting the overlapping
        # link rows from deletion.
        filter_select = pgast.SelectStmt(
            target_list=[
                pgast.ResTarget(val=pgast.ColumnRef(name=['source']), ),
                pgast.ResTarget(val=pgast.ColumnRef(name=['target']), ),
            ],
            from_clause=[pgast.RelRangeVar(relation=data_cte)],
        )

        # delqry is already the query of delcte, so this in-place change
        # narrows the DELETE registered above.
        delqry.where_clause = astutils.extend_binop(
            delqry.where_clause,
            astutils.new_binop(
                lexpr=pgast.ImplicitRowExpr(args=[
                    pgast.ColumnRef(name=['source']),
                    pgast.ColumnRef(name=['target']),
                ], ),
                rexpr=pgast.SubLink(
                    type=pgast.SubLinkType.ALL,
                    expr=filter_select,
                ),
                op='!=',
            ))

        conflict_clause = pgast.OnConflictClause(
            action='nothing',
            infer=pgast.InferClause(index_elems=[
                pgast.ColumnRef(name=[col]) for col in conflict_cols
            ]),
        )
    else:
        # Inserting rows into the link table may produce cardinality
        # constraint violations, since the INSERT into the link table
        # is executed in the snapshot where the above DELETE from
        # the link table is not visible.  Hence, we need to use
        # the ON CONFLICT clause to resolve this.
        conflict_inference = []
        conflict_exc_row = []

        for col in conflict_cols:
            conflict_inference.append(pgast.ColumnRef(name=[col]))
            conflict_exc_row.append(pgast.ColumnRef(name=['excluded', col]))

        conflict_data = pgast.SelectStmt(
            target_list=[
                pgast.ResTarget(val=pgast.ColumnRef(
                    name=[data_cte.name, pgast.Star()]))
            ],
            from_clause=[pgast.RelRangeVar(relation=data_cte)],
            where_clause=astutils.new_binop(
                lexpr=pgast.ImplicitRowExpr(args=conflict_inference),
                rexpr=pgast.ImplicitRowExpr(args=conflict_exc_row),
                op='='))

        conflict_clause = pgast.OnConflictClause(
            action='update',
            infer=pgast.InferClause(index_elems=conflict_inference),
            target_list=[
                pgast.MultiAssignRef(columns=cols, source=conflict_data)
            ])

    updcte = pgast.CommonTableExpr(
        name=ctx.env.aliases.get(hint='i'),
        query=pgast.InsertStmt(
            relation=target_rvar,
            select_stmt=data_select,
            cols=cols,
            on_conflict=conflict_clause,
            returning_list=[
                pgast.ResTarget(val=pgast.ColumnRef(name=[pgast.Star()]))
            ]))

    pathctx.put_path_value_rvar(updcte.query, path_id.ptr_path(), target_rvar, env=ctx.env)

    # Record the effect of this insertion in the relation overlay
    # context to ensure that references to the link in the result
    # of this DML statement yield the expected results.
    dml_stack = get_dml_stmt_stack(ir_stmt, ctx=ctx)
    relctx.add_ptr_rel_overlay(ptrref, 'union', updcte, dml_stmts=dml_stack, ctx=ctx)
    toplevel.ctes.append(updcte)

    return data_cte
def init_dml_stmt(
    ir_stmt: irast.MutatingStmt,
    *,
    ctx: context.CompilerContextLevel,
    parent_ctx: context.CompilerContextLevel,
) -> DMLParts:
    """Set up the CTE scaffolding shared by every DML statement.

    Besides building the per-type DML CTEs, this includes the union of
    all of them in ``ctx.rel`` under the subject path and records that
    union CTE in ``ctx.dml_stmts[ir_stmt]``.

    :param ir_stmt: IR of the DML statement.
    :param ctx: Compiler context of the statement being compiled.
    :param parent_ctx: Compiler context of the enclosing statement.

    :return:
        A DMLParts tuple with the map from each target type to its
        ``(dml_cte, dml_rvar)`` pair, the range CTE (UPDATE/DELETE only,
        otherwise None), and the CTE standing for all DML CTEs combined.
    """
    clauses.init_stmt(ir_stmt, ctx, parent_ctx)

    operates_on_range = isinstance(
        ir_stmt, (irast.UpdateStmt, irast.DeleteStmt))

    range_cte: Optional[pgast.CommonTableExpr] = None
    range_rvar: Optional[pgast.RelRangeVar] = None
    if operates_on_range:
        # UPDATE and DELETE operate over a range, so generate
        # the corresponding CTE and connect it to the DML statements.
        range_cte = get_dml_range(ir_stmt, ctx=ctx)
        range_alias = ctx.env.aliases.get(hint='range')
        range_rvar = pgast.RelRangeVar(
            relation=range_cte,
            alias=pgast.Alias(aliasname=range_alias),
        )

    subject = ir_stmt.subject
    root_typeref = subject.typeref.material_type or subject.typeref

    # The statement's own type comes first; UPDATE/DELETE additionally
    # cover union components and all descendants.
    target_typerefs = [root_typeref]
    if operates_on_range:
        for member in root_typeref.union or ():
            member = member.material_type or member
            target_typerefs.append(member)
            target_typerefs.extend(member.descendants or ())

        target_typerefs.extend(root_typeref.descendants or ())

    dml_map = {
        typeref: gen_dml_cte(
            ir_stmt,
            range_rvar=range_rvar,
            typeref=typeref,
            ctx=ctx,
        )
        for typeref in target_typerefs
    }

    if len(dml_map) == 1:
        # A single target type: its CTE already stands for the whole set.
        (union_cte, union_rvar), = dml_map.values()
    else:
        # One SELECT per DML CTE, each exposing the subject path.
        branches = []
        for _, branch_rvar in dml_map.values():
            branch = pgast.SelectStmt()
            relctx.include_rvar(
                branch,
                branch_rvar,
                subject.path_id,
                ctx=ctx,
            )
            branches.append(branch)

        # Left-nested chain: ((b0 UNION ALL b1) UNION ALL b2) ...
        first_branch, second_branch, *more_branches = branches
        union_qry = pgast.SelectStmt(all=True, larg=first_branch)
        union_qry.op = 'UNION'
        union_qry.rarg = second_branch
        for branch in more_branches:
            union_qry = pgast.SelectStmt(all=True, larg=union_qry)
            union_qry.op = 'UNION'
            union_qry.rarg = branch

        union_cte = pgast.CommonTableExpr(
            query=union_qry,
            name=ctx.env.aliases.get(hint='ma'),
        )
        union_rvar = relctx.rvar_for_rel(
            union_cte,
            typeref=subject.typeref,
            ctx=ctx,
        )

    relctx.include_rvar(ctx.rel, union_rvar, subject.path_id, ctx=ctx)
    pathctx.put_path_bond(ctx.rel, subject.path_id)

    ctx.dml_stmts[ir_stmt] = union_cte

    return DMLParts(
        dml_ctes=dml_map,
        range_cte=range_cte,
        union_cte=union_cte,
    )
def compile_GroupStmt(stmt: irast.GroupStmt, *, ctx: context.CompilerContextLevel) -> pgast.Query:
    """Compile a GROUP statement IR into a SQL query.

    The compilation proceeds in stages:

    1. the subject and the BY expressions are compiled into a "group"
       CTE whose group identity is a ``first_value()`` window call
       partitioned by the BY expressions;
    2. a "group values" CTE selects rows DISTINCT on that group identity
       and carries the BY paths;
    3. the result expression, its filter and its sort expressions are
       compiled into a subquery with both CTEs attached;
    4. the outer query includes that subquery laterally and receives
       the output columns, ORDER BY, OFFSET and LIMIT.

    :param stmt: IR of the GROUP statement.  Its ``subject.path_scope``
        is reset to None as a side effect.
    :param ctx: Compiler context of the enclosing statement.
    :return: The relation of the sub-statement context created for
        this statement, after ``clauses.fini_stmt`` has been applied.
    """
    parent_ctx = ctx
    with parent_ctx.substmt() as ctx:
        clauses.init_stmt(stmt, ctx=ctx, parent_ctx=parent_ctx)

        group_path_id = stmt.group_path_id

        # Process the GROUP .. BY part into a subquery.
        with ctx.subrel() as gctx:
            gctx.expr_exposed = False
            gquery = gctx.rel
            pathctx.put_path_bond(gquery, group_path_id)
            relctx.update_scope(stmt.subject, gquery, ctx=gctx)
            stmt.subject.path_scope = None
            clauses.compile_output(stmt.subject, ctx=gctx)
            subj_rvar = pathctx.get_path_rvar(gquery, stmt.subject.path_id, aspect='value', env=gctx.env)
            relctx.ensure_bond_for_expr(stmt.subject, subj_rvar.query, ctx=gctx)

            # Path ids of the BY expressions.
            group_paths = set()

            # Compiled BY expressions, later the window partition clause.
            part_clause: List[pgast.BaseExpr] = []

            for ir_gexpr in stmt.groupby:
                with gctx.new() as subctx:
                    partexpr = dispatch.compile(ir_gexpr, ctx=subctx)

                part_clause.append(partexpr)
                group_paths.add(ir_gexpr.path_id)

            # Since we will be computing arbitrary expressions
            # based on the grouped sets, it is more efficient
            # to compute the "group bond" as a small unique
            # value than it is to use GROUP BY and aggregate
            # actual id values into an array.
            #
            # To achieve this we use the first_value() window
            # function while using the GROUP BY clause as
            # a partition clause.  We use the id of the first
            # object in each partition if GROUP BY input is
            # a ObjectType, otherwise we generate the id using
            # row_number().
            if stmt.subject.path_id.is_objtype_path():
                first_val = pathctx.get_path_identity_var(gquery, stmt.subject.path_id, env=ctx.env)
            else:
                with ctx.subrel() as subctx:
                    # Wrap the group query so that the partition
                    # expressions can be referenced as its output columns.
                    wrapper = subctx.rel

                    gquery_rvar = relctx.rvar_for_rel(gquery, ctx=subctx)
                    wrapper.from_clause = [gquery_rvar]
                    relctx.pull_path_namespace(target=wrapper, source=gquery_rvar, ctx=subctx)

                    new_part_clause: List[pgast.BaseExpr] = []

                    for i, expr in enumerate(part_clause):
                        path_id = stmt.groupby[i].path_id
                        pathctx.put_path_value_var(gquery, path_id, expr, force=True, env=ctx.env)
                        output_ref = pathctx.get_path_value_output(gquery, path_id, env=ctx.env)
                        assert isinstance(output_ref, pgast.ColumnRef)
                        new_part_clause.append(
                            astutils.get_column(gquery_rvar, output_ref))

                    part_clause = new_part_clause

                    first_val = pathctx.get_rvar_path_identity_var(
                        gquery_rvar, stmt.subject.path_id, env=ctx.env)

                    # From here on the wrapper is the group query.
                    gquery = wrapper
                    pathctx.put_path_bond(gquery, group_path_id)

            group_id = pgast.FuncCall(
                name=('first_value', ),
                args=[first_val],
                over=pgast.WindowDef(partition_clause=part_clause))

            pathctx.put_path_identity_var(gquery, group_path_id, group_id, env=ctx.env)

            pathctx.put_path_value_var(gquery, group_path_id, group_id, env=ctx.env)

        group_cte = pgast.CommonTableExpr(query=gquery, name=ctx.env.aliases.get('g'))

        group_cte_rvar = relctx.rvar_for_rel(group_cte, ctx=ctx)

        # Generate another subquery containing distinct values of
        # path expressions in BY.
        with ctx.subrel() as gvctx:
            gvquery = gvctx.rel
            relctx.include_rvar(gvquery, group_cte_rvar, path_id=group_path_id, ctx=gvctx)

            pathctx.put_path_bond(gvquery, group_path_id)

            for group_set in stmt.groupby:
                dispatch.visit(group_set, ctx=gvctx)
                path_id = group_set.path_id
                if path_id.is_objtype_path():
                    pathctx.put_path_bond(gvquery, path_id)

            gvquery.distinct_clause = [
                pathctx.get_path_identity_var(gvquery, group_path_id, env=ctx.env)
            ]

            # Keep only the BY paths and the group path in the
            # group-values rel ...
            for path_id, aspect in list(gvquery.path_rvar_map):
                if path_id not in group_paths and path_id != group_path_id:
                    gvquery.path_rvar_map.pop((path_id, aspect))

            # ... and remove the BY paths from the group rel.
            for path_id, aspect in list(gquery.path_rvar_map):
                if path_id in group_paths:
                    gquery.path_rvar_map.pop((path_id, aspect))
                    gquery.path_namespace.pop((path_id, aspect), None)
                    gquery.path_outputs.pop((path_id, aspect), None)

        groupval_cte = pgast.CommonTableExpr(query=gvquery, name=ctx.env.aliases.get('gv'))

        groupval_cte_rvar = relctx.rvar_for_rel(groupval_cte, ctx=ctx)

        o_stmt = stmt.result.expr
        assert isinstance(o_stmt, irast.SelectStmt)

        # process the result expression;
        with ctx.subrel() as selctx:
            selquery = selctx.rel
            outer_id = stmt.result.path_id
            inner_id = o_stmt.result.path_id

            relctx.include_specific_rvar(selquery, groupval_cte_rvar, group_path_id, aspects=['identity'], ctx=ctx)

            for path_id in group_paths:
                selctx.path_scope[path_id] = selquery
                pathctx.put_path_rvar(selquery, path_id, groupval_cte_rvar, aspect='value', env=ctx.env)

            # Copy before mutating so that only this context level sees
            # the new entry.
            selctx.group_by_rels = selctx.group_by_rels.copy()
            selctx.group_by_rels[group_path_id, stmt.subject.path_id] = \
                group_cte

            selquery.view_path_id_map = {outer_id: inner_id}

            selquery.ctes.append(group_cte)

            # Aliases of the sort expressions added to the target list.
            sortoutputs = []

            selquery.ctes.append(groupval_cte)

            clauses.compile_output(o_stmt.result, ctx=selctx)

            # The WHERE clause
            if o_stmt.where is not None:
                selquery.where_clause = astutils.extend_binop(
                    selquery.where_clause,
                    clauses.compile_filter_clause(o_stmt.where, o_stmt.where_card, ctx=selctx))

            for ir_sortexpr in o_stmt.orderby:
                alias = ctx.env.aliases.get('s')
                sexpr = dispatch.compile(ir_sortexpr.expr, ctx=selctx)
                selquery.target_list.append(
                    pgast.ResTarget(val=sexpr, name=alias))
                sortoutputs.append(alias)

        if not gvquery.target_list:
            # No values were pulled from the group-values rel,
            # we must remove the DISTINCT clause to prevent
            # a syntax error.
            gvquery.distinct_clause[:] = []

        query = ctx.rel
        result_rvar = relctx.rvar_for_rel(selquery, lateral=True, ctx=ctx)
        relctx.include_rvar(query, result_rvar, path_id=outer_id, ctx=ctx)

        # Expose every non-sort output of the inner query in the outer
        # query, naming anonymous targets first.
        for rt in selquery.target_list:
            if rt.name is None:
                rt.name = ctx.env.aliases.get('v')
            if rt.name not in sortoutputs:
                query.target_list.append(
                    pgast.ResTarget(val=astutils.get_column(
                        result_rvar, rt.name), name=rt.name))

        for i, ir_oexpr in enumerate(o_stmt.orderby):
            sort_ref = astutils.get_column(result_rvar, sortoutputs[i])
            sortexpr = pgast.SortBy(node=sort_ref, dir=ir_oexpr.direction, nulls=ir_oexpr.nones_order)
            query.sort_clause.append(sortexpr)

        # The OFFSET clause
        if o_stmt.offset:
            with ctx.new() as ctx1:
                ctx1.expr_exposed = False
                query.limit_offset = dispatch.compile(o_stmt.offset, ctx=ctx1)

        # The LIMIT clause
        if o_stmt.limit:
            with ctx.new() as ctx1:
                ctx1.expr_exposed = False
                query.limit_count = dispatch.compile(o_stmt.limit, ctx=ctx1)

        clauses.fini_stmt(query, ctx, parent_ctx)

    return query
def gen_dml_cte(
    ir_stmt: irast.MutatingStmt,
    *,
    range_rvar: Optional[pgast.RelRangeVar],
    typeref: irast.TypeRef,
    ctx: context.CompilerContextLevel,
) -> Tuple[pgast.CommonTableExpr, pgast.PathRangeVar]:
    """Create the CTE holding the SQL DML statement for one target type.

    :param ir_stmt: IR of the INSERT, UPDATE or DELETE statement.
    :param range_rvar: Range var over the rows to modify, or None.
        When given, it is joined to the target relation through the
        WHERE clause and added to FROM (UPDATE) or USING (DELETE).
    :param typeref: Type whose relation the statement operates on.
    :param ctx: Compiler context; its ``path_scope`` is cleared.

    :return: The DML CTE and a range var over it.

    :raises AssertionError: If *ir_stmt* is not an INSERT, UPDATE
        or DELETE statement.
    """
    subject_path_id = ir_stmt.subject.path_id

    stmt_classes = (
        (irast.InsertStmt, pgast.InsertStmt),
        (irast.UpdateStmt, pgast.UpdateStmt),
        (irast.DeleteStmt, pgast.DeleteStmt),
    )

    dml_stmt: pgast.Query
    for ir_cls, sql_cls in stmt_classes:
        if isinstance(ir_stmt, ir_cls):
            dml_stmt = sql_cls()
            break
    else:
        raise AssertionError(f'unexpected DML IR: {ir_stmt!r}')

    dml_stmt.relation = relctx.range_for_typeref(
        typeref,
        subject_path_id,
        for_mutation=True,
        common_parent=True,
        ctx=ctx,
    )

    def bind_target_rvars() -> None:
        # Register the target relation as both the value and the source
        # range var of the subject path.
        for put_rvar in (
            pathctx.put_path_value_rvar,
            pathctx.put_path_source_rvar,
        ):
            put_rvar(
                dml_stmt, subject_path_id, dml_stmt.relation, env=ctx.env)

    bind_target_rvars()
    pathctx.put_path_bond(dml_stmt, subject_path_id)

    dml_cte = pgast.CommonTableExpr(
        query=dml_stmt,
        name=ctx.env.aliases.get(hint='m'),
    )

    if range_rvar is not None:
        relctx.pull_path_namespace(
            target=dml_stmt, source=range_rvar, ctx=ctx)

        # Auxiliary relations are always joined via the WHERE
        # clause due to the structure of the UPDATE/DELETE SQL statements.
        own_id = pgast.ColumnRef(
            name=[dml_stmt.relation.alias.aliasname, 'id'])
        range_id = pathctx.get_rvar_path_identity_var(
            range_rvar, subject_path_id, env=ctx.env)
        dml_stmt.where_clause = astutils.new_binop(
            lexpr=own_id, op='=', rexpr=range_id)

        # UPDATE has "FROM", while DELETE has "USING".
        if isinstance(dml_stmt, pgast.UpdateStmt):
            dml_stmt.from_clause.append(range_rvar)
        elif isinstance(dml_stmt, pgast.DeleteStmt):
            dml_stmt.using_clause.append(range_rvar)

    # Due to the fact that DML statements are structured
    # as a flat list of CTEs instead of nested range vars,
    # the top level path scope must be empty.  The necessary
    # range vars will be injected explicitly in all rels that
    # need them.
    ctx.path_scope.clear()

    bind_target_rvars()

    dml_rvar = relctx.rvar_for_rel(dml_cte, typeref=typeref, ctx=ctx)

    return dml_cte, dml_rvar
def process_link_values(
    *,
    ir_stmt: irast.MutatingStmt,
    ir_expr: irast.Set,
    target_tab: Tuple[str, ...],
    col_data: Mapping[str, pgast.BaseExpr],
    dml_rvar: pgast.PathRangeVar,
    sources: Iterable[pgast.BaseRangeVar],
    props_only: bool,
    target_is_scalar: bool,
    iterator_cte: Optional[pgast.CommonTableExpr],
    ctx: context.CompilerContextLevel,
) -> Tuple[pgast.CommonTableExpr, List[str]]:
    """Unpack data from an update expression into a series of selects.

    :param ir_stmt:
        IR of the DML statement; its subject path is bound to
        *dml_rvar* in the generated row query.
    :param ir_expr:
        IR of the INSERT/UPDATE body element.
    :param target_tab:
        The link table being updated.  NOTE(review): not referenced
        in the body of this function.
    :param col_data:
        Expressions used to populate well-known columns of the link
        table such as `source` and `__type__`.
    :param dml_rvar:
        Range var over the DML CTE of the enclosing statement.
    :param sources:
        A list of relations which must be joined into the data query
        to resolve expressions in *col_data*.
    :param props_only:
        Whether this link update only touches link properties.
        NOTE(review): not referenced in the body of this function.
    :param target_is_scalar:
        Whether the link target is an ScalarType.
    :param iterator_cte:
        CTE representing the iterator range in the FOR clause of the
        EdgeQL DML statement.
    :param ctx:
        Compiler context.

    :return:
        A CTE selecting the link rows, and the names of the columns
        it outputs, in target list order.
    """
    with ctx.newscope() as newscope, newscope.newrel() as subrelctx:
        # The query that will produce one row per link record.
        row_query = subrelctx.rel

        relctx.include_rvar(row_query, dml_rvar, path_id=ir_stmt.subject.path_id, ctx=subrelctx)
        subrelctx.path_scope[ir_stmt.subject.path_id] = row_query

        if iterator_cte is not None:
            iterator_rvar = relctx.rvar_for_rel(iterator_cte, lateral=True, ctx=subrelctx)
            relctx.include_rvar(row_query, iterator_rvar, path_id=iterator_cte.query.path_id, ctx=subrelctx)

    with subrelctx.newscope() as sctx, sctx.subrel() as input_rel_ctx:
        # The subquery computing the right-hand side expression.
        input_rel = input_rel_ctx.rel
        if iterator_cte is not None:
            input_rel_ctx.path_scope[iterator_cte.query.path_id] = \
                row_query
        input_rel_ctx.expr_exposed = False
        input_rel_ctx.volatility_ref = pathctx.get_path_identity_var(
            row_query, ir_stmt.subject.path_id, env=input_rel_ctx.env)
        dispatch.visit(ir_expr, ctx=input_rel_ctx)

        if (isinstance(ir_expr.expr, irast.Stmt) and ir_expr.expr.iterator_stmt is not None):
            # The link value is computed by a FOR expression,
            # check if the statement is a DML statement, and if so,
            # pull the iterator scope so that link property expressions
            # have the correct context.
            inner_iterator_cte = None
            inner_iterator_path_id = ir_expr.expr.iterator_stmt.path_id
            for cte in input_rel_ctx.toplevel_stmt.ctes:
                if cte.query.path_id == inner_iterator_path_id:
                    inner_iterator_cte = cte
                    break

            if inner_iterator_cte is not None:
                target_rvar = pathctx.get_path_rvar(
                    input_rel,
                    ir_expr.path_id,
                    aspect='identity',
                    env=input_rel_ctx.env,
                )

                pathctx.put_path_rvar(
                    input_rel,
                    inner_iterator_path_id,
                    rvar=target_rvar,
                    aspect='identity',
                    env=input_rel_ctx.env,
                )

                inner_iterator_rvar = relctx.rvar_for_rel(
                    inner_iterator_cte, lateral=True, ctx=subrelctx)

                relctx.include_rvar(
                    input_rel,
                    inner_iterator_rvar,
                    path_id=inner_iterator_path_id,
                    ctx=subrelctx,
                )

                input_rel_ctx.path_scope[inner_iterator_path_id] = (
                    input_rel)

        shape_tuple = None
        if ir_expr.shape:
            shape_tuple = shapecomp.compile_shape(
                ir_expr,
                [expr for expr, _ in ir_expr.shape],
                ctx=input_rel_ctx,
            )

            for element in shape_tuple.elements:
                pathctx.put_path_var_if_not_exists(input_rel_ctx.rel, element.path_id, element.val, aspect='value', env=input_rel_ctx.env)

    input_stmt: pgast.Query = input_rel

    input_rvar = pgast.RangeSubselect(
        subquery=input_rel,
        lateral=True,
        alias=pgast.Alias(aliasname=ctx.env.aliases.get('val')))

    # Columns computed from the input subquery, keyed by column name.
    source_data: Dict[str, pgast.BaseExpr] = {}

    # NOTE(review): input_stmt is not read again after this
    # reassignment within this function.
    if isinstance(input_stmt, pgast.SelectStmt) and input_stmt.op is not None:
        # UNION
        input_stmt = input_stmt.rarg

    path_id = ir_expr.path_id

    if shape_tuple is not None:
        # Each link property in the shape becomes a column named after
        # the property.
        for element in shape_tuple.elements:
            if not element.path_id.is_linkprop_path():
                continue
            rptr_name = element.path_id.rptr_name()
            assert rptr_name is not None
            colname = rptr_name.name
            val = pathctx.get_rvar_path_value_var(input_rvar, element.path_id, env=ctx.env)
            source_data.setdefault(colname, val)
    else:
        if target_is_scalar:
            target_ref = pathctx.get_rvar_path_value_var(input_rvar, path_id, env=ctx.env)
        else:
            target_ref = pathctx.get_rvar_path_identity_var(input_rvar, path_id, env=ctx.env)

        source_data['target'] = target_ref

    # Make sure non-scalar links always get a 'target' column, even if
    # the shape did not provide one.
    if not target_is_scalar and 'target' not in source_data:
        target_ref = pathctx.get_rvar_path_identity_var(input_rvar, path_id, env=ctx.env)
        source_data['target'] = target_ref

    specified_cols = []
    # col_data is the first map in the chain, so its entries take
    # precedence over same-named entries of source_data.
    for col, expr in collections.ChainMap(col_data, source_data).items():
        row_query.target_list.append(pgast.ResTarget(val=expr, name=col))
        specified_cols.append(col)

    row_query.from_clause += list(sources) + [input_rvar]

    link_rows = pgast.CommonTableExpr(query=row_query, name=ctx.env.aliases.get(hint='r'))

    return link_rows, specified_cols