def compile_operator(
        qlexpr: qlast.Base, op_name: str, qlargs: typing.List[qlast.Base], *,
        ctx: context.ContextLevel) -> irast.Set:
    """Compile an operator invocation into an ``irast.OperatorCall`` set.

    Each operand in *qlargs* is compiled in its own fenced scope. Candidate
    operators named *op_name* are looked up in the schema (honouring module
    aliases) and matched against the inferred operand types. The single
    matching operator is then turned into an IR node.

    :param qlexpr: The whole operator expression; its context is used for
        error reporting.
    :param op_name: The operator name as written in the query.
    :param qlargs: Operand expressions, in positional order.
    :param ctx: Compiler context.
    :returns: The set wrapping the compiled ``irast.OperatorCall``.
    :raises errors.QueryError: if no operator has this name, an operand type
        cannot be inferred, or zero / more than one operator matches the
        operand types (ambiguity is tolerated inside abstract constraints).
    """
    env = ctx.env
    schema = env.schema
    opers = schema.get_operators(op_name, module_aliases=ctx.modaliases)

    if opers is None:
        raise errors.QueryError(
            f'no operator matches the given name and argument types',
            context=qlexpr.context)

    # List of (inferred type, compiled IR set) pairs, one per operand.
    args = []
    for ai, qlarg in enumerate(qlargs):
        with ctx.newscope(fenced=True) as fencectx:
            # We put on a SET OF fence preemptively in case this is
            # a SET OF arg, which we don't know yet due to polymorphic
            # matching. We will remove it if necessary in `finalize_args()`.
            arg_ir = setgen.ensure_set(
                dispatch.compile(qlarg, ctx=fencectx),
                ctx=fencectx)

            arg_ir = setgen.scoped_set(
                setgen.ensure_stmt(arg_ir, ctx=fencectx),
                ctx=fencectx)

        arg_type = inference.infer_type(arg_ir, ctx.env)
        if arg_type is None:
            raise errors.QueryError(
                f'could not resolve the type of operand '
                f'#{ai} of {op_name}',
                context=qlarg.context)

        args.append((arg_type, arg_ir))

    matched = None
    # Some 2-operand operators are special when their operands are
    # arrays or tuples.
    if len(args) == 2:
        coll_opers = None
        # If both of the args are arrays or tuples, potentially
        # compile the operator for them differently than for other
        # combinations.
        if args[0][0].is_tuple() and args[1][0].is_tuple():
            # Out of the candidate operators, find the ones that
            # correspond to tuples.
            coll_opers = [
                op for op in opers
                if all(
                    param.get_type(schema).is_tuple()
                    for param in op.get_params(schema).objects(schema))
            ]

        elif args[0][0].is_array() and args[1][0].is_array():
            # Out of the candidate operators, find the ones that
            # correspond to arrays.
            coll_opers = [
                op for op in opers
                if all(
                    param.get_type(schema).is_array()
                    for param in op.get_params(schema).objects(schema))
            ]

        # Proceed only if we have a special case of collection operators.
        if coll_opers:
            # Then check if they are recursive (i.e. validation must be
            # done recursively for the subtypes). We rely on the fact that
            # it is forbidden to define an operator that has both
            # recursive and non-recursive versions.
            if not coll_opers[0].get_recursive(schema):
                # The operator is non-recursive, so regular processing
                # is needed.
                matched = polyres.find_callable(
                    coll_opers, args=args, kwargs={}, ctx=ctx)

            else:
                # Ultimately the operator will be the same, regardless of the
                # specific operand types, as long as it passed validation, so
                # we just use the first operand type for the purpose of
                # finding the callable.
                matched = polyres.find_callable(
                    coll_opers,
                    args=[(args[0][0], args[0][1]),
                          (args[0][0], args[1][1])],
                    kwargs={},
                    ctx=ctx)

                # Now that we have an operator, we need to validate that it
                # can be applied to the tuple or array elements.
                submatched = validate_recursive_operator(
                    opers, args[0], args[1], ctx=ctx)

                if len(submatched) != 1:
                    # This is an error. We want the error message to
                    # reflect whether no matches were found or too
                    # many, so we preserve the submatches found for
                    # this purpose.
                    matched = submatched

    # No special handling match was necessary, find a normal match.
    if matched is None:
        matched = polyres.find_callable(opers, args=args, kwargs={}, ctx=ctx)

    # Inside a polymorphic function body, abstract operator variants remain
    # viable candidates; elsewhere they are filtered out below.
    in_polymorphic_func = (
        ctx.env.func_params is not None and
        ctx.env.func_params.has_polymorphic(env.schema))

    # Inside an abstract constraint, an ambiguous match is resolved by
    # taking the first candidate instead of raising.
    in_abstract_constraint = (
        in_polymorphic_func and
        ctx.env.parent_object_type is s_constr.Constraint)

    if not in_polymorphic_func:
        matched = [call for call in matched
                   if not call.func.get_is_abstract(env.schema)]

    if len(matched) == 1:
        matched_call = matched[0]
    else:
        # Build a human-readable description of the operand types for
        # the error messages below.
        if len(args) == 2:
            ltype = args[0][0].material_type(env.schema)
            rtype = args[1][0].material_type(env.schema)

            types = (
                f'{ltype.get_displayname(env.schema)!r} and '
                f'{rtype.get_displayname(env.schema)!r}')
        else:
            types = ', '.join(
                repr(
                    a[0].material_type(env.schema).get_displayname(
                        env.schema)
                ) for a in args
            )

        if not matched:
            hint = ('Consider using an explicit type cast or a conversion '
                    'function.')

            if op_name == 'std::IF':
                hint = (f"The IF and ELSE result clauses must be of "
                        f"compatible types, while the condition clause must "
                        f"be 'std::bool'. {hint}")

            raise errors.QueryError(
                f'operator {str(op_name)!r} cannot be applied to '
                f'operands of type {types}',
                hint=hint,
                context=qlexpr.context)

        elif len(matched) > 1:
            if in_abstract_constraint:
                matched_call = matched[0]
            else:
                detail = ', '.join(
                    f'`{m.func.get_verbosename(ctx.env.schema)}`'
                    for m in matched
                )
                raise errors.QueryError(
                    f'operator {str(op_name)!r} is ambiguous for '
                    f'operands of type {types}',
                    hint=f'Possible variants: {detail}.',
                    context=qlexpr.context)

    final_args, params_typemods = finalize_args(matched_call, ctx=ctx)

    oper = matched_call.func
    assert isinstance(oper, s_oper.Operator)
    env.schema_refs.add(oper)
    oper_name = oper.get_shortname(env.schema)

    matched_params = oper.get_params(env.schema)
    rtype = matched_call.return_type

    if oper_name in {'std::UNION', 'std::IF'} and rtype.is_object_type():
        # Special case for the UNION and IF operators, instead of common
        # parent type, we return a union type.
        if oper_name == 'std::UNION':
            larg, rarg = (a.expr for a in final_args)
        else:
            # IF takes (if_true, condition, if_false); the condition is
            # not part of the result type.
            larg, _, rarg = (a.expr for a in final_args)

        left_type = setgen.get_set_type(larg, ctx=ctx).material_type(
            ctx.env.schema)
        right_type = setgen.get_set_type(rarg, ctx=ctx).material_type(
            ctx.env.schema)

        if left_type.issubclass(env.schema, right_type):
            rtype = right_type
        elif right_type.issubclass(env.schema, left_type):
            rtype = left_type
        else:
            # get_union_type returns an updated schema, which is stored
            # back into the environment.
            env.schema, rtype = s_utils.get_union_type(
                env.schema, [left_type, right_type])

    # The call is polymorphic only if some parameter AND the declared
    # return type are polymorphic.
    is_polymorphic = (
        any(p.get_type(env.schema).is_polymorphic(env.schema)
            for p in matched_params.objects(env.schema)) and
        oper.get_return_type(env.schema).is_polymorphic(env.schema)
    )

    # Map directly onto a backend operator only when the operator has no
    # custom code or function implementation and we are not compiling a
    # polymorphic function body.
    from_op = oper.get_from_operator(env.schema)
    sql_operator = None
    if (from_op is not None and oper.get_code(env.schema) is None and
            oper.get_from_function(env.schema) is None and
            not in_polymorphic_func):
        sql_operator = tuple(from_op)

    node = irast.OperatorCall(
        args=final_args,
        func_module_id=env.schema.get_global(
            s_mod.Module, oper_name.module).id,
        func_shortname=oper_name,
        func_polymorphic=is_polymorphic,
        func_sql_function=oper.get_from_function(env.schema),
        sql_operator=sql_operator,
        force_return_cast=oper.get_force_return_cast(env.schema),
        volatility=oper.get_volatility(env.schema),
        operator_kind=oper.get_operator_kind(env.schema),
        params_typemods=params_typemods,
        context=qlexpr.context,
        typeref=irtyputils.type_to_typeref(env.schema, rtype),
        typemod=oper.get_return_typemod(env.schema),
    )

    return setgen.ensure_set(node, typehint=rtype, ctx=ctx)
def try_fold_associative_binop(
        opcall: irast.OperatorCall, *,
        ctx: context.ContextLevel) -> typing.Optional[irast.Set]:
    """Try to constant-fold ``CONST op (OTHER_CONST op X)``.

    The tree is rewritten as ``(CONST op OTHER_CONST) op X``, with the
    constant sub-expression evaluated at compile time.

    Operands are swapped freely at both levels to find the constants, so
    this is only valid for operators that are both associative and
    commutative. Callers are expected to enforce that; it is not checked
    here.

    :param opcall: The outer operator call to inspect.
    :param ctx: Compiler context.
    :returns: A new set holding the folded operator call, or ``None`` when
        the tree does not have the foldable shape or the constant part
        cannot be evaluated at compile time.
    """
    op = opcall.func_shortname

    def _derive_opcall(
            operands: typing.Sequence[irast.Set]) -> irast.OperatorCall:
        # Build a call to the same operator as *opcall*, with new operands.
        # Every attribute of the original call (including volatility) is
        # carried over, so the folded node describes the same operator.
        return irast.OperatorCall(
            args=[irast.CallArg(expr=operand) for operand in operands],
            func_module_id=opcall.func_module_id,
            func_shortname=op,
            func_polymorphic=opcall.func_polymorphic,
            func_sql_function=opcall.func_sql_function,
            sql_operator=opcall.sql_operator,
            force_return_cast=opcall.force_return_cast,
            volatility=opcall.volatility,
            operator_kind=opcall.operator_kind,
            params_typemods=opcall.params_typemods,
            context=opcall.context,
            typeref=opcall.typeref,
            typemod=opcall.typemod,
        )

    # Let's check if we have (CONST + (OTHER_CONST + X))
    # tree, which can be optimized to ((CONST + OTHER_CONST) + X)
    my_const = opcall.args[0].expr
    other_binop = opcall.args[1].expr

    # The constant may be on either side of the outer operator.
    if isinstance(other_binop.expr, irast.BaseConstant):
        my_const, other_binop = other_binop, my_const

    if not (isinstance(my_const.expr, irast.BaseConstant) and
            isinstance(other_binop.expr, irast.OperatorCall) and
            other_binop.expr.func_shortname == op and
            other_binop.expr.operator_kind is ft.OperatorKind.INFIX):
        return None

    other_const = other_binop.expr.args[0].expr
    other_binop_node = other_binop.expr.args[1].expr
    # Likewise, the inner constant may be on either side.
    if isinstance(other_binop_node.expr, irast.BaseConstant):
        other_binop_node, other_const = other_const, other_binop_node

    if not isinstance(other_const.expr, irast.BaseConstant):
        return None

    try:
        new_const = ireval.evaluate(
            _derive_opcall([other_const, my_const]),
            schema=ctx.env.schema,
        )
    except ireval.UnsupportedExpressionError:
        # Compile-time evaluation is best-effort: leave the tree unfolded.
        return None

    folded_binop = _derive_opcall(
        [setgen.ensure_set(new_const, ctx=ctx), other_binop_node])
    return setgen.ensure_set(folded_binop, ctx=ctx)
def compile_operator(
        qlexpr: qlast.Base, op_name: str, qlargs: List[qlast.Base], *,
        ctx: context.ContextLevel) -> irast.Set:
    """Compile an operator invocation into an ``irast.OperatorCall`` set.

    Each operand in *qlargs* is compiled in its own fenced scope. Candidate
    operators named *op_name* are looked up in the schema (honouring module
    aliases) and matched against the inferred operand types. The single
    matching operator is then turned into an IR node. Derived operators are
    resolved against their origin operator, while the IR keeps the derived
    name.

    :param qlexpr: The whole operator expression; its context is used for
        error reporting and it is recorded as the schema-reference site.
    :param op_name: The operator name as written in the query.
    :param qlargs: Operand expressions, in positional order.
    :param ctx: Compiler context.
    :returns: The set wrapping the compiled ``irast.OperatorCall``.
    :raises errors.QueryError: if no operator has this name, an operand type
        cannot be inferred, or zero / more than one operator matches the
        operand types (ambiguity is tolerated inside abstract constraints).
    :raises errors.InternalServerError: if a derived operator has several
        forms or its origin operator cannot be found.
    """
    env = ctx.env
    schema = env.schema
    opers = schema.get_operators(op_name, module_aliases=ctx.modaliases)

    if opers is None:
        raise errors.QueryError(
            f'no operator matches the given name and argument types',
            context=qlexpr.context)

    # Operand positions that are only conditionally evaluated for this
    # operator (if any) get compiled with `in_conditional` set.
    fq_op_name = next(iter(opers)).get_shortname(ctx.env.schema)
    conditional_args = CONDITIONAL_OPS.get(fq_op_name)

    # Maps each compiled operand to the fenced context it was compiled in,
    # for use by finalize_args().
    arg_ctxs = {}
    # List of (inferred type, compiled IR set) pairs, one per operand.
    args = []
    for ai, qlarg in enumerate(qlargs):
        with ctx.newscope(fenced=True) as fencectx:
            fencectx.path_log = []
            # We put on a SET OF fence preemptively in case this is
            # a SET OF arg, which we don't know yet due to polymorphic
            # matching. We will remove it if necessary in `finalize_args()`.
            if conditional_args and ai in conditional_args:
                fencectx.in_conditional = qlexpr.context

            arg_ir = setgen.ensure_set(
                dispatch.compile(qlarg, ctx=fencectx),
                ctx=fencectx)

            arg_ir = setgen.scoped_set(
                setgen.ensure_stmt(arg_ir, ctx=fencectx),
                ctx=fencectx)

            arg_ctxs[arg_ir] = fencectx

        arg_type = inference.infer_type(arg_ir, ctx.env)
        if arg_type is None:
            raise errors.QueryError(
                f'could not resolve the type of operand '
                f'#{ai} of {op_name}',
                context=qlarg.context)

        args.append((arg_type, arg_ir))

    # Check if the operator is a derived operator, and if so,
    # find the origins.
    origin_op = opers[0].get_derivative_of(env.schema)
    derivative_op: Optional[s_oper.Operator]
    if origin_op is not None:
        # If this is a derived operator, there should be
        # exactly one form of it. This is enforced at the DDL
        # level, but check again to be sure.
        # Errors here concern the operator as a whole, so they are reported
        # at the expression (not at whichever operand the loop ended on).
        if len(opers) > 1:
            raise errors.InternalServerError(
                f'more than one derived operator of the same name: {op_name}',
                context=qlexpr.context)

        derivative_op = opers[0]
        opers = schema.get_operators(origin_op)
        if not opers:
            raise errors.InternalServerError(
                f'cannot find the origin operator for {op_name}',
                context=qlexpr.context)
        # The derived operator's own parameter typemods are what the
        # caller declared, so they are passed on to finalize_args().
        actual_typemods = [
            param.get_typemod(schema)
            for param in derivative_op.get_params(schema).objects(schema)
        ]

    else:
        derivative_op = None
        actual_typemods = []

    matched = None
    # Some 2-operand operators are special when their operands are
    # arrays or tuples.
    if len(args) == 2:
        coll_opers = None
        # If both of the args are arrays or tuples, potentially
        # compile the operator for them differently than for other
        # combinations.
        if args[0][0].is_tuple(env.schema) and args[1][0].is_tuple(env.schema):
            # Out of the candidate operators, find the ones that
            # correspond to tuples.
            coll_opers = [
                op for op in opers
                if all(
                    param.get_type(schema).is_tuple(schema)
                    for param in op.get_params(schema).objects(schema)
                )
            ]

        elif args[0][0].is_array() and args[1][0].is_array():
            # Out of the candidate operators, find the ones that
            # correspond to arrays.
            coll_opers = [
                op for op in opers
                if all(
                    param.get_type(schema).is_array()
                    for param in op.get_params(schema).objects(schema)
                )
            ]

        # Proceed only if we have a special case of collection operators.
        if coll_opers:
            # Then check if they are recursive (i.e. validation must be
            # done recursively for the subtypes). We rely on the fact that
            # it is forbidden to define an operator that has both
            # recursive and non-recursive versions.
            if not coll_opers[0].get_recursive(schema):
                # The operator is non-recursive, so regular processing
                # is needed.
                matched = polyres.find_callable(
                    coll_opers, args=args, kwargs={}, ctx=ctx)

            else:
                # The recursive operators are usually defined as
                # being polymorphic on all parameters, and so this has
                # a side-effect of forcing both operands to be of
                # the same type (via casting) before the operator is
                # applied. This might seem suboptimal, since there might
                # be a more specific operator for the types of the
                # elements, but the current version of Postgres
                # actually requires tuples and arrays to be of the
                # same type in comparison, so this behavior is actually
                # what we want.
                matched = polyres.find_callable(
                    coll_opers,
                    args=args,
                    kwargs={},
                    ctx=ctx,
                )

                # Now that we have an operator, we need to validate that it
                # can be applied to the tuple or array elements.
                submatched = validate_recursive_operator(
                    opers, args[0], args[1], ctx=ctx)

                if len(submatched) != 1:
                    # This is an error. We want the error message to
                    # reflect whether no matches were found or too
                    # many, so we preserve the submatches found for
                    # this purpose.
                    matched = submatched

    # No special handling match was necessary, find a normal match.
    if matched is None:
        matched = polyres.find_callable(opers, args=args, kwargs={}, ctx=ctx)

    # Inside a polymorphic function body, abstract operator variants remain
    # viable candidates; elsewhere they are filtered out below.
    in_polymorphic_func = (
        ctx.env.options.func_params is not None and
        ctx.env.options.func_params.has_polymorphic(env.schema)
    )

    # Inside an abstract constraint, an ambiguous match is resolved by
    # taking the first candidate instead of raising.
    in_abstract_constraint = (
        in_polymorphic_func and
        ctx.env.options.schema_object_context is s_constr.Constraint
    )

    if not in_polymorphic_func:
        matched = [call for call in matched
                   if not call.func.get_abstract(env.schema)]

    if len(matched) == 1:
        matched_call = matched[0]
    else:
        # Build a human-readable description of the operand types for
        # the error messages below.
        if len(args) == 2:
            ltype = schemactx.get_material_type(args[0][0], ctx=ctx)
            rtype = schemactx.get_material_type(args[1][0], ctx=ctx)

            types = (
                f'{ltype.get_displayname(env.schema)!r} and '
                f'{rtype.get_displayname(env.schema)!r}')
        else:
            types = ', '.join(
                repr(
                    schemactx.get_material_type(
                        a[0], ctx=ctx).get_displayname(env.schema)
                ) for a in args
            )

        if not matched:
            hint = ('Consider using an explicit type cast or a conversion '
                    'function.')

            if op_name == 'std::IF':
                hint = (f"The IF and ELSE result clauses must be of "
                        f"compatible types, while the condition clause must "
                        f"be 'std::bool'. {hint}")
            elif op_name == '+' and len(args) == 2:
                # `ltype`/`rtype` only exist for binary operators; the
                # length check keeps a failed prefix `+` from hitting an
                # unbound local instead of the QueryError below.
                str_t = cast(s_scalars.ScalarType,
                             env.schema.get('std::str'))
                bytes_t = cast(s_scalars.ScalarType,
                               env.schema.get('std::bytes'))
                if (
                    (ltype.issubclass(env.schema, str_t) and
                        rtype.issubclass(env.schema, str_t)) or
                    (ltype.issubclass(env.schema, bytes_t) and
                        rtype.issubclass(env.schema, bytes_t)) or
                    (ltype.is_array() and rtype.is_array())
                ):
                    hint = 'Consider using the "++" operator for concatenation'

            raise errors.QueryError(
                f'operator {str(op_name)!r} cannot be applied to '
                f'operands of type {types}',
                hint=hint,
                context=qlexpr.context)

        elif len(matched) > 1:
            if in_abstract_constraint:
                matched_call = matched[0]
            else:
                detail = ', '.join(
                    f'`{m.func.get_verbosename(ctx.env.schema)}`'
                    for m in matched
                )
                raise errors.QueryError(
                    f'operator {str(op_name)!r} is ambiguous for '
                    f'operands of type {types}',
                    hint=f'Possible variants: {detail}.',
                    context=qlexpr.context)

    oper = matched_call.func
    assert isinstance(oper, s_oper.Operator)
    env.add_schema_ref(oper, expr=qlexpr)
    oper_name = oper.get_shortname(env.schema)
    str_oper_name = str(oper_name)

    matched_params = oper.get_params(env.schema)
    rtype = matched_call.return_type
    # The call is polymorphic only if some parameter AND the resolved
    # return type are polymorphic.
    is_polymorphic = (
        any(p.get_type(env.schema).is_polymorphic(env.schema)
            for p in matched_params.objects(env.schema)) and
        rtype.is_polymorphic(env.schema)
    )

    final_args, params_typemods = finalize_args(
        matched_call,
        arg_ctxs=arg_ctxs,
        actual_typemods=actual_typemods,
        is_polymorphic=is_polymorphic,
        ctx=ctx,
    )

    if str_oper_name in {'std::UNION', 'std::IF'} and rtype.is_object_type():
        # Special case for the UNION and IF operators, instead of common
        # parent type, we return a union type.
        if str_oper_name == 'std::UNION':
            larg, rarg = (a.expr for a in final_args)
        else:
            # IF takes (if_true, condition, if_false); the condition is
            # not part of the result type.
            larg, _, rarg = (a.expr for a in final_args)

        left_type = schemactx.get_material_type(
            setgen.get_set_type(larg, ctx=ctx),
            ctx=ctx,
        )
        right_type = schemactx.get_material_type(
            setgen.get_set_type(rarg, ctx=ctx),
            ctx=ctx,
        )

        if left_type.issubclass(env.schema, right_type):
            rtype = right_type
        elif right_type.issubclass(env.schema, left_type):
            rtype = left_type
        else:
            assert isinstance(left_type, s_types.InheritingType)
            assert isinstance(right_type, s_types.InheritingType)
            rtype = schemactx.get_union_type([left_type, right_type], ctx=ctx)

    # Map directly onto a backend operator only when the operator has no
    # custom code or function implementation and we are not compiling a
    # polymorphic function body.
    from_op = oper.get_from_operator(env.schema)
    sql_operator = None
    if (from_op is not None and oper.get_code(env.schema) is None and
            oper.get_from_function(env.schema) is None and
            not in_polymorphic_func):
        sql_operator = tuple(from_op)

    # For a derived operator, the IR carries the derived name as the call
    # name and records the resolved operator as its origin.
    origin_name: Optional[sn.QualName]
    origin_module_id: Optional[uuid.UUID]
    if derivative_op is not None:
        origin_name = oper_name
        origin_module_id = env.schema.get_global(
            s_mod.Module, origin_name.module).id
        oper_name = derivative_op.get_shortname(env.schema)
    else:
        origin_name = None
        origin_module_id = None

    node = irast.OperatorCall(
        args=final_args,
        func_shortname=oper_name,
        func_polymorphic=is_polymorphic,
        origin_name=origin_name,
        origin_module_id=origin_module_id,
        func_sql_function=oper.get_from_function(env.schema),
        sql_operator=sql_operator,
        force_return_cast=oper.get_force_return_cast(env.schema),
        volatility=oper.get_volatility(env.schema),
        operator_kind=oper.get_operator_kind(env.schema),
        params_typemods=params_typemods,
        context=qlexpr.context,
        typeref=typegen.type_to_typeref(rtype, env=env),
        typemod=oper.get_return_typemod(env.schema),
    )

    return setgen.ensure_set(node, typehint=rtype, ctx=ctx)