def range_from_queryset(
        set_ops: typing.Sequence[typing.Tuple[str, pgast.BaseRelation]],
        scls: s_obj.Object, *,
        env: context.Environment) -> pgast.BaseRangeVar:
    """Build the range var that covers every relation in *set_ops*.

    :param set_ops:
        ``(operator, relation)`` pairs.  The operator of the first pair
        is never read; every later operator joins its relation to the
        statement accumulated so far.
    :param scls:
        Object whose ``shortname.name`` is used as the alias hint for
        the combined subselect.
    :param env:
        Environment providing the alias generator (``env.aliases``).

    :return:
        With a single pair, the first ``from_clause`` entry of its
        relation.  Otherwise a ``RangeSubselect`` wrapping the chained
        set-operation statement.
    """
    first_rel = set_ops[0][1]

    if len(set_ops) <= 1:
        # Just one class table, so return its range var directly.
        return first_rel.from_clause[0]

    # More than one class table: fold the relations left-to-right into
    # nested UNION/EXCEPT statements, each one using the previous
    # result as its left argument.
    combined = first_rel
    for op, rarg in set_ops[1:]:
        node = pgast.SelectStmt(all=True, larg=combined)
        node.op = op
        node.rarg = rarg
        combined = node

    alias = pgast.Alias(aliasname=env.aliases.get(scls.shortname.name))
    return pgast.RangeSubselect(subquery=combined, alias=alias)
def top_output_as_value(
        stmt: pgast.Query, *,
        env: context.Environment) -> pgast.Query:
    """Finalize output serialization on the top level.

    For any output format other than JSON, *stmt* is returned as is.

    For JSON, *stmt* becomes a subselect of a new outer SELECT whose
    single target is ``json_agg`` over the first output column of
    *stmt*, coalesced with ``'[]'``.  In that case *stmt* is mutated:
    its first target gets a generated name if it had none, and its CTEs
    are moved onto the returned statement.
    """
    if env.output_format != context.OutputFormat.JSON:
        return stmt

    # For JSON we just want to aggregate the whole thing
    # into a JSON array.
    wrapped = pgast.RangeSubselect(
        subquery=stmt,
        alias=pgast.Alias(aliasname=env.aliases.get('aggw')),
    )

    # The aggregate refers to the first output column by name, so make
    # sure it has one.
    first_target = stmt.target_list[0]
    if first_target.name is None:
        first_target.name = env.aliases.get('v')

    aggregated = pgast.FuncCall(
        name=('json_agg',),
        args=[pgast.ColumnRef(name=[first_target.name])],
    )

    # XXX: nullability introspection is not reliable,
    #      remove `True or` once it is.
    if True or first_target.val.nullable:
        aggregated = pgast.CoalesceExpr(
            args=[aggregated, pgast.Constant(val='[]')])

    outer = pgast.SelectStmt(
        target_list=[pgast.ResTarget(val=aggregated)],
        from_clause=[wrapped],
    )

    # Move the CTEs from the wrapped statement to the new top level.
    outer.ctes = stmt.ctes
    stmt.ctes = []

    return outer
def rvar_for_rel(
        rel: pgast.BaseRelation, *,
        lateral: bool=False,
        colnames: typing.Optional[typing.List[str]]=None,
        env: context.Environment) -> pgast.BaseRangeVar:
    """Wrap *rel* in a range var with a freshly generated alias.

    :param rel:
        The relation to wrap.  A ``pgast.Query`` becomes a
        ``RangeSubselect``; anything else becomes a ``RangeVar``.
    :param lateral:
        Passed to ``RangeSubselect`` when *rel* is a query; not used
        for the ``RangeVar`` case.
    :param colnames:
        Column names for the alias.  When omitted, a new empty list is
        created for this call.
    :param env:
        Environment providing the alias generator (``env.aliases``).

    :return:
        The range var referring to *rel*.
    """
    if colnames is None:
        # The list ends up stored on the returned Alias node, so each
        # call needs its own object: a shared default list would let a
        # mutation of one node's colnames leak into every other node
        # created with the default.
        colnames = []

    if isinstance(rel, pgast.Query):
        # Queries may be unnamed, fall back to a generic alias hint.
        alias = env.aliases.get(rel.name or 'q')
        rvar = pgast.RangeSubselect(
            subquery=rel,
            alias=pgast.Alias(aliasname=alias, colnames=colnames),
            lateral=lateral,
        )
    else:
        alias = env.aliases.get(rel.name)
        rvar = pgast.RangeVar(
            relation=rel,
            alias=pgast.Alias(aliasname=alias, colnames=colnames)
        )

    return rvar
def process_link_values(
        ir_stmt, ir_expr, target_tab, tab_cols, col_data,
        dml_rvar, sources, props_only, target_is_scalar, iterator_cte, *,
        ctx: context.CompilerContext) -> \
        typing.Tuple[pgast.CommonTableExpr, typing.List[str]]:
    """Unpack data from an update expression into a series of selects.

    :param ir_stmt:
        IR of the DML statement.  Only ``ir_stmt.subject.path_id`` is
        read here.
    :param ir_expr:
        IR of the INSERT/UPDATE body element.
    :param target_tab:
        The link table being updated.  (Not read by this function.)
    :param tab_cols:
        A sequence of columns in the table being updated.
    :param col_data:
        Expressions used to populate well-known columns of the link
        table such as std::source and std::__type__.
    :param dml_rvar:
        Range var included into the row query under the path id of the
        statement subject.
    :param sources:
        A list of relations which must be joined into the data query
        to resolve expressions in *col_data*.
    :param props_only:
        Whether this link update only touches link properties.
        (Not read by this function.)
    :param target_is_scalar:
        Whether the link target is a ScalarType.
    :param iterator_cte:
        CTE representing the iterator range in the FOR clause of the
        EdgeQL DML statement, or ``None``.
    :param ctx:
        Compiler context.  Required keyword-only argument.

    :return:
        A tuple of the CTE producing the link rows and the list of
        columns from *tab_cols* for which an expression was found.
    """
    with ctx.newscope() as newscope, newscope.newrel() as subrelctx:
        row_query = subrelctx.rel

        relctx.include_rvar(row_query, dml_rvar,
                            path_id=ir_stmt.subject.path_id, ctx=subrelctx)
        subrelctx.path_scope[ir_stmt.subject.path_id] = row_query

        if iterator_cte is not None:
            iterator_rvar = dbobj.rvar_for_rel(
                iterator_cte, lateral=True, env=subrelctx.env)
            relctx.include_rvar(row_query, iterator_rvar,
                                path_id=iterator_cte.query.path_id,
                                ctx=subrelctx)

        # Compile the body element into its own sub-relation.
        with subrelctx.newscope() as sctx, sctx.subrel() as input_rel_ctx:
            input_rel = input_rel_ctx.rel
            if iterator_cte is not None:
                input_rel_ctx.path_scope[iterator_cte.query.path_id] = \
                    row_query
            input_rel_ctx.expr_exposed = False
            input_rel_ctx.shape_format = context.ShapeFormat.FLAT
            input_rel_ctx.volatility_ref = pathctx.get_path_identity_var(
                row_query, ir_stmt.subject.path_id,
                env=input_rel_ctx.env)
            dispatch.compile(ir_expr, ctx=input_rel_ctx)

    input_stmt = input_rel

    input_rvar = pgast.RangeSubselect(
        subquery=input_rel,
        lateral=True,
        alias=pgast.Alias(
            aliasname=ctx.env.aliases.get('val')
        )
    )

    # Maps link table column names to the expressions read from the
    # compiled input relation.
    source_data = {}

    if input_stmt.op is not None:
        # UNION
        input_stmt = input_stmt.rarg

    path_id = ir_expr.path_id

    output = pathctx.get_path_value_output(
        input_stmt, path_id, env=ctx.env)

    if isinstance(output, pgast.TupleVar):
        # One column per tuple element, named after the element's
        # pointer or, failing that, after the last path step.
        for element in output.elements:
            name = element.path_id.rptr_name()
            if name is None:
                name = element.path_id[-1].name
            colname = common.edgedb_name_to_pg_name(name)
            source_data.setdefault(
                colname, dbobj.get_column(input_rvar, element.name))
    else:
        if target_is_scalar:
            target_ref = pathctx.get_rvar_path_value_var(
                input_rvar, path_id, env=ctx.env)
        else:
            target_ref = pathctx.get_rvar_path_identity_var(
                input_rvar, path_id, env=ctx.env)

        source_data['std::target'] = target_ref

    # A tuple output may not have produced a target column.
    if not target_is_scalar and 'std::target' not in source_data:
        target_ref = pathctx.get_rvar_path_identity_var(
            input_rvar, path_id, env=ctx.env)
        source_data['std::target'] = target_ref

    specified_cols = []
    for col in tab_cols:
        # Explicit column data takes precedence over the values
        # computed from the input relation.
        expr = col_data.get(col)
        if expr is None:
            expr = source_data.get(col)

        if expr is not None:
            row_query.target_list.append(pgast.ResTarget(
                val=expr, name=col))
            specified_cols.append(col)

    row_query.from_clause += list(sources) + [input_rvar]

    link_rows = pgast.CommonTableExpr(
        query=row_query,
        name=ctx.env.aliases.get(hint='r')
    )

    return link_rows, specified_cols