def wrap_script_stmt(
    stmt: pgast.SelectStmt,
    ir_set: irast.Set,
    *,
    env: context.Environment,
) -> pgast.SelectStmt:
    """Wrap *stmt* in an outer SELECT that returns ``count()`` of its output.

    *stmt* becomes a range subselect of the returned statement.  The first
    target of *stmt* is given a generated name if it has none, so that the
    outer query can reference it as a column.

    Side effects on *stmt*: its first target may be replaced by a named
    ``ResTarget``, and its CTEs are moved onto the returned statement
    (``stmt.ctes`` is reset to an empty list).

    *ir_set* is accepted for interface compatibility and is not used.
    """
    subrvar = pgast.RangeSubselect(
        subquery=stmt,
        alias=pgast.Alias(aliasname=env.aliases.get('aggw')),
    )

    stmt_res = stmt.target_list[0]
    if stmt_res.name is None:
        # The outer query addresses the column by name, so it must have one.
        stmt_res = stmt.target_list[0] = pgast.ResTarget(
            name=env.aliases.get('v'),
            val=stmt_res.val,
        )

    # This must be a single expression node.  A trailing comma after the
    # call would turn it into a 1-tuple and corrupt the ResTarget below.
    count_val = pgast.FuncCall(
        name=('count',),
        args=[pgast.ColumnRef(name=[stmt_res.name])],
    )

    result = pgast.SelectStmt(
        target_list=[pgast.ResTarget(val=count_val)],
        from_clause=[subrvar],
    )

    # CTEs and argument names belong to the outermost statement.
    result.ctes = stmt.ctes
    result.argnames = stmt.argnames
    stmt.ctes = []

    return result
def top_output_as_value(
        stmt: pgast.SelectStmt,
        ir_set: irast.Set, *,
        env: context.Environment) -> pgast.SelectStmt:
    """Apply the final, top-level output serialization to *stmt*.

    * JSON output without single-cardinality expectation: delegate to
      ``aggregate_json_output()``.
    * NATIVE output with an explicit top cast: replace the first target of
      *stmt* (in place) with a type cast of its value and return *stmt*.
    * Anything else: return *stmt* untouched.
    """
    output_format = env.output_format

    if (output_format is context.OutputFormat.JSON
            and not env.expected_cardinality_one):
        # The whole result is collapsed into one JSON array.
        return aggregate_json_output(stmt, ir_set, env=env)

    if (output_format is not context.OutputFormat.NATIVE
            or env.explicit_top_cast is None):
        # JSON_ELEMENTS and BINARY don't require any wrapping
        return stmt

    # NATIVE output with an explicit cast requested for the top value.
    top_val = stmt.target_list[0].val
    pg_type = pgtypes.pg_type_from_ir_typeref(
        env.explicit_top_cast,
        persistent_tuples=True,
    )
    cast_expr = pgast.TypeCast(
        arg=top_val,
        type_name=pgast.TypeName(name=pg_type),
    )
    stmt.target_list[0] = pgast.ResTarget(
        name=env.aliases.get('v'),
        val=cast_expr,
    )
    return stmt
def aggregate_json_output(
        stmt: pgast.SelectStmt,
        ir_set: irast.Set, *,
        env: context.Environment) -> pgast.SelectStmt:
    """Wrap *stmt* in a SELECT that aggregates its output column as JSON.

    The aggregate function name comes from ``_get_json_func('agg', ...)``;
    the aggregate is wrapped in a COALESCE whose fallback is the string
    constant ``'[]'``.

    *stmt* is modified: its first target is named if it was anonymous, and
    its CTEs are transferred to the returned statement.
    """
    wrapper_alias = pgast.Alias(aliasname=env.aliases.get('aggw'))
    inner_rvar = pgast.RangeSubselect(subquery=stmt, alias=wrapper_alias)

    first_target = stmt.target_list[0]
    if first_target.name is None:
        # An anonymous column cannot be referenced from the outer query.
        first_target = pgast.ResTarget(
            name=env.aliases.get('v'),
            val=first_target.val,
        )
        stmt.target_list[0] = first_target

    json_agg = pgast.FuncCall(
        name=_get_json_func('agg', env=env),
        args=[pgast.ColumnRef(name=[first_target.name])],
    )
    empty_array = pgast.StringConstant(val='[]')
    aggregated = pgast.CoalesceExpr(args=[json_agg, empty_array])

    outer = pgast.SelectStmt(
        target_list=[pgast.ResTarget(val=aggregated)],
        from_clause=[inner_rvar],
    )

    # Hoist CTEs and argument names to the new outermost statement.
    outer.ctes = stmt.ctes
    outer.argnames = stmt.argnames
    stmt.ctes = []

    return outer
def top_output_as_config_op(
        ir_set: irast.Set,
        stmt: pgast.SelectStmt, *,
        env: context.Environment) -> pgast.Query:
    """Wrap *stmt* so that it yields a ``jsonb_build_array`` config record.

    The array is built from the constant ``'ADD'``, the config scope, the
    config name and the value column produced by *stmt*.

    Only ``ConfigScope.SYSTEM`` is handled; any other scope raises
    ``errors.InternalServerError``.  *stmt* is modified: its first target is
    named if necessary and its CTEs are moved to the returned statement.
    """
    assert isinstance(ir_set.expr, irast.ConfigCommand)
    config_op = ir_set.expr

    if config_op.scope is not qltypes.ConfigScope.SYSTEM:
        raise errors.InternalServerError(
            f'CONFIGURE {ir_set.expr.scope} INSERT is not supported')

    cfg_alias = env.aliases.get('cfg')
    inner_rvar = pgast.RangeSubselect(
        subquery=stmt,
        alias=pgast.Alias(aliasname=cfg_alias),
    )

    value_target = stmt.target_list[0]
    if value_target.name is None:
        # The wrapper refers to the value column by name.
        value_target = pgast.ResTarget(
            name=env.aliases.get('v'),
            val=value_target.val,
        )
        stmt.target_list[0] = value_target

    row_items = [
        pgast.StringConstant(val='ADD'),
        pgast.StringConstant(val=str(config_op.scope)),
        pgast.StringConstant(val=config_op.name),
        pgast.ColumnRef(name=[value_target.name]),
    ]
    result_row = pgast.RowExpr(args=row_items)

    json_array = pgast.FuncCall(
        name=('jsonb_build_array',),
        args=result_row.args,
        null_safe=True,
        ser_safe=True,
    )

    outer = pgast.SelectStmt(
        target_list=[pgast.ResTarget(val=json_array)],
        from_clause=[inner_rvar],
    )

    # CTEs and argument names live on the outermost statement.
    outer.ctes = stmt.ctes
    outer.argnames = stmt.argnames
    stmt.ctes = []

    return outer
def apply_volatility_ref(
        stmt: pgast.SelectStmt, *,
        ctx: context.CompilerContextLevel) -> None:
    """Add a negated NULL test for every volatility ref to *stmt*'s WHERE.

    Each entry of ``ctx.volatility_ref`` is called to obtain an expression;
    a negated ``NullTest`` on it is merged into ``stmt.where_clause`` via
    ``astutils.extend_binop``.  *stmt* is modified in place.
    """
    for make_ref in ctx.volatility_ref:
        # Apply the volatility reference.
        # See the comment in process_set_as_subquery().
        existing_where = stmt.where_clause
        not_null_test = pgast.NullTest(arg=make_ref(), negated=True)
        stmt.where_clause = astutils.extend_binop(
            existing_where, not_null_test)
def rel_join(
        query: pgast.SelectStmt,
        right_rvar: pgast.PathRangeVar, *,
        ctx: context.CompilerContextLevel) -> None:
    """Attach *right_rvar* to the FROM clause of *query*.

    A join condition is accumulated over every path in
    ``right_rvar.query.path_scope`` for which *query* has an identity var
    (or, failing that, a value var).

    * If *query* has no FROM entries yet, *right_rvar* is appended and the
      condition (if any) is merged into the WHERE clause.
    * Otherwise the first FROM entry is replaced by a ``JoinExpr`` of it and
      *right_rvar*: ``'inner'`` when there is a condition, ``'cross'`` when
      there is none.
    """
    join_cond = None

    for path_id in right_rvar.query.path_scope:
        left_ref = maybe_get_path_var(
            query, path_id, aspect='identity', ctx=ctx)
        if left_ref is None:
            # No identity var on the left; try the value aspect instead.
            left_ref = maybe_get_path_var(
                query, path_id, aspect='value', ctx=ctx)
            if left_ref is None:
                # The left side does not know this path at all.
                continue

        right_ref = pathctx.get_rvar_path_identity_var(
            right_rvar, path_id, env=ctx.env)

        assert isinstance(left_ref, pgast.ColumnRef)
        assert isinstance(right_ref, pgast.ColumnRef)
        join_cond = astutils.extend_binop(
            join_cond, astutils.join_condition(left_ref, right_ref))

    kind = 'cross' if join_cond is None else 'inner'

    if query.from_clause:
        query.from_clause[0] = pgast.JoinExpr(
            type=kind,
            larg=query.from_clause[0],
            rarg=right_rvar,
            quals=join_cond,
        )
        return

    # First range in this query: nothing to join against yet.
    query.from_clause.append(right_rvar)
    if join_cond is not None:
        query.where_clause = astutils.extend_binop(
            query.where_clause, join_cond)
def wrap_script_stmt(
    stmt: pgast.SelectStmt,
    *,
    suppress_all_output: bool = False,
    env: context.Environment,
) -> pgast.SelectStmt:
    """Wrap *stmt* in a SELECT that returns ``count()`` of its output column.

    The count column keeps the name of the wrapped column.  When
    *suppress_all_output* is true, the counting query is wrapped once more
    in a SELECT with an empty target list, filtered by a ``NullTest`` on the
    count column.

    *stmt* is modified: its first target is named if it was anonymous, and
    its CTEs are moved to the returned statement.
    """
    inner_rvar = pgast.RangeSubselect(
        subquery=stmt,
        alias=pgast.Alias(aliasname=env.aliases.get('aggw')),
    )

    first_target = stmt.target_list[0]
    if first_target.name is None:
        # The outer query needs a column name to refer to.
        first_target = pgast.ResTarget(
            name=env.aliases.get('v'),
            val=first_target.val,
        )
        stmt.target_list[0] = first_target
    assert first_target.name is not None

    counter = pgast.FuncCall(
        name=('count', ),
        args=[pgast.ColumnRef(name=[first_target.name])],
    )
    counted = pgast.ResTarget(val=counter, name=first_target.name)
    outer = pgast.SelectStmt(
        target_list=[counted],
        from_clause=[inner_rvar],
    )

    if suppress_all_output:
        quiet_rvar = pgast.RangeSubselect(
            subquery=outer,
            alias=pgast.Alias(aliasname=env.aliases.get('q')),
        )
        count_ref = pgast.ColumnRef(
            name=[quiet_rvar.alias.aliasname, first_target.name],
        )
        # NOTE(review): presumably count() is never NULL, so this filter
        # drops every row -- confirm against NullTest's default polarity.
        outer = pgast.SelectStmt(
            target_list=[],
            from_clause=[quiet_rvar],
            where_clause=pgast.NullTest(arg=count_ref),
        )

    # CTEs and argument names belong to the outermost statement.
    outer.ctes = stmt.ctes
    outer.argnames = stmt.argnames
    stmt.ctes = []

    return outer
def anti_join(
    lhs: pgast.SelectStmt,
    rhs: pgast.SelectStmt,
    path_id: Optional[irast.PathId],
    *,
    ctx: context.CompilerContextLevel,
) -> None:
    """Filter elements out of the LHS that appear on the RHS.

    With a truthy *path_id*, the identity var of that path on *lhs* is
    tested with ``NOT IN`` against *rhs*.  Without one, the filter is a
    ``NOT EXISTS`` sublink over *rhs*.  The resulting condition is merged
    into ``lhs.where_clause`` in place.
    """
    exclusion: pgast.BaseExpr

    if not path_id:
        # No path we care about: only test whether RHS yields anything.
        exclusion = pgast.SubLink(
            type=pgast.SubLinkType.NOT_EXISTS, expr=rhs)
    else:
        # Take the identity from the LHS and anti-join it against the RHS.
        lhs_identity = pathctx.get_path_identity_var(
            lhs, path_id=path_id, env=ctx.env)
        # Return value unused; presumably this makes RHS output the
        # identity column that the NOT IN compares against.
        pathctx.get_path_identity_output(
            rhs, path_id=path_id, env=ctx.env)
        exclusion = astutils.new_binop(lhs_identity, rhs, 'NOT IN')

    lhs.where_clause = astutils.extend_binop(lhs.where_clause, exclusion)
def semi_join(
        stmt: pgast.SelectStmt,
        ir_set: irast.Set, src_rvar: pgast.PathRangeVar, *,
        ctx: context.CompilerContextLevel) -> pgast.PathRangeVar:
    """Join an IR Set using semi-join.

    A fresh root range var is created for *ir_set*, and an ``IN`` condition
    comparing its identity var against ``ctx.rel`` is merged into
    ``stmt.where_clause``.  The new range var is returned.

    *ir_set* must have a pointer (``ir_set.rptr``); this is asserted.
    """
    rptr = ir_set.rptr
    assert rptr is not None

    # Target set range.
    set_rvar = new_root_rvar(ir_set, ctx=ctx)

    ptrref = rptr.ptrref
    # allow_missing=True: ptr_info may be falsy, which is handled below.
    ptr_info = pg_types.get_ptrref_storage_info(
        ptrref, resolve_type=False, allow_missing=True)

    # Pick the path whose identity is compared on both sides of the IN.
    if ptr_info and ptr_info.table_type == 'ObjectType':
        if irtyputils.is_inbound_ptrref(ptrref):
            # Inbound pointer: compare on the source path instead.
            far_pid = ir_set.path_id.src_path()
            assert far_pid is not None
        else:
            far_pid = ir_set.path_id
    else:
        far_pid = ir_set.path_id
        # Link range.
        # NOTE(review): the flattened source does not show whether the two
        # statements below belong to this ``else`` or follow the whole
        # ``if``/``else``; confirm against the original layout.
        map_rvar = new_pointer_rvar(rptr, src_rvar=src_rvar, ctx=ctx)
        include_rvar(
            ctx.rel, map_rvar,
            path_id=ir_set.path_id.ptr_path(), ctx=ctx)

    tgt_ref = pathctx.get_rvar_path_identity_var(
        set_rvar, far_pid, env=ctx.env)

    # Return value unused; presumably ensures ctx.rel outputs the identity
    # column that the IN comparison below relies on.
    pathctx.get_path_identity_output(
        ctx.rel, far_pid, env=ctx.env)

    cond = astutils.new_binop(tgt_ref, ctx.rel, 'IN')
    stmt.where_clause = astutils.extend_binop(
        stmt.where_clause, cond)

    return set_rvar