Ejemplo n.º 1
0
def compile_operator(qlexpr: qlast.Base, op_name: str,
                     qlargs: typing.List[qlast.Base], *,
                     ctx: context.ContextLevel) -> irast.Set:
    """Compile an operator invocation into an IR set.

    Each operand in *qlargs* is compiled in its own fenced scope.
    The candidate operators named *op_name* are resolved through the
    current module aliases and matched against the inferred operand
    types. An ``irast.OperatorCall`` node is then built for the
    selected operator.

    Args:
        qlexpr: QL node of the whole operator expression; its
            ``context`` is attached to raised errors and to the
            resulting node.
        op_name: Operator name as written in the query.
        qlargs: Operand expressions, in positional order.
        ctx: Current compiler context.

    Returns:
        The set produced by ``setgen.ensure_set`` around the compiled
        ``irast.OperatorCall``.

    Raises:
        errors.QueryError: if no operator with this name exists, if an
            operand's type cannot be inferred, if no candidate matches
            the operand types, or if several candidates match. The last
            case is tolerated only in an abstract constraint inside a
            polymorphic function, where the first match is taken.
    """

    env = ctx.env
    schema = env.schema
    opers = schema.get_operators(op_name, module_aliases=ctx.modaliases)

    if opers is None:
        raise errors.QueryError(
            f'no operator matches the given name and argument types',
            context=qlexpr.context)

    # List of (inferred type, compiled IR set) pairs, one per operand.
    args = []
    for ai, qlarg in enumerate(qlargs):
        with ctx.newscope(fenced=True) as fencectx:
            # We put on a SET OF fence preemptively in case this is
            # a SET OF arg, which we don't know yet due to polymorphic
            # matching.  We will remove it if necessary in `finalize_args()`.
            arg_ir = setgen.ensure_set(dispatch.compile(qlarg, ctx=fencectx),
                                       ctx=fencectx)

            arg_ir = setgen.scoped_set(setgen.ensure_stmt(arg_ir,
                                                          ctx=fencectx),
                                       ctx=fencectx)

        arg_type = inference.infer_type(arg_ir, ctx.env)
        if arg_type is None:
            raise errors.QueryError(
                f'could not resolve the type of operand '
                f'#{ai} of {op_name}',
                context=qlarg.context)

        args.append((arg_type, arg_ir))

    # `matched` stays None unless the tuple/array special-casing below
    # produces a result; in that case the generic lookup is skipped.
    matched = None
    # Some 2-operand operators are special when their operands are
    # arrays or tuples.
    if len(args) == 2:
        coll_opers = None
        # If both of the args are arrays or tuples, potentially
        # compile the operator for them differently than for other
        # combinations.
        if args[0][0].is_tuple() and args[1][0].is_tuple():
            # Out of the candidate operators, find the ones that
            # correspond to tuples.
            coll_opers = [
                op for op in opers if all(
                    param.get_type(schema).is_tuple()
                    for param in op.get_params(schema).objects(schema))
            ]

        elif args[0][0].is_array() and args[1][0].is_array():
            # Out of the candidate operators, find the ones that
            # correspond to arrays.
            coll_opers = [
                op for op in opers if all(
                    param.get_type(schema).is_array()
                    for param in op.get_params(schema).objects(schema))
            ]

        # Proceed only if we have a special case of collection operators.
        if coll_opers:
            # Then check if they are recursive (i.e. validation must be
            # done recursively for the subtypes). We rely on the fact that
            # it is forbidden to define an operator that has both
            # recursive and non-recursive versions.
            if not coll_opers[0].get_recursive(schema):
                # The operator is non-recursive, so regular processing
                # is needed.
                matched = polyres.find_callable(coll_opers,
                                                args=args,
                                                kwargs={},
                                                ctx=ctx)

            else:
                # Ultimately the operator will be the same, regardless of the
                # specific operand types, as long as it passed validation, so
                # we just use the first operand type for the purpose of
                # finding the callable.
                matched = polyres.find_callable(coll_opers,
                                                args=[(args[0][0], args[0][1]),
                                                      (args[0][0], args[1][1])
                                                      ],
                                                kwargs={},
                                                ctx=ctx)

                # Now that we have an operator, we need to validate that it
                # can be applied to the tuple or array elements.
                submatched = validate_recursive_operator(opers,
                                                         args[0],
                                                         args[1],
                                                         ctx=ctx)

                if len(submatched) != 1:
                    # This is an error. We want the error message to
                    # reflect whether no matches were found or too
                    # many, so we preserve the submatches found for
                    # this purpose.
                    matched = submatched

    # No special handling match was necessary, find a normal match.
    if matched is None:
        matched = polyres.find_callable(opers, args=args, kwargs={}, ctx=ctx)

    # True when compiling the body of a function with polymorphic params.
    in_polymorphic_func = (ctx.env.func_params is not None
                           and ctx.env.func_params.has_polymorphic(env.schema))

    # Used below to tolerate ambiguous matches (first one wins).
    in_abstract_constraint = (
        in_polymorphic_func
        and ctx.env.parent_object_type is s_constr.Constraint)

    # Abstract operators are only acceptable inside polymorphic functions.
    if not in_polymorphic_func:
        matched = [
            call for call in matched
            if not call.func.get_is_abstract(env.schema)
        ]

    if len(matched) == 1:
        matched_call = matched[0]
    else:
        # Build a human-readable list of operand types for the error.
        if len(args) == 2:
            ltype = args[0][0].material_type(env.schema)
            rtype = args[1][0].material_type(env.schema)

            types = (f'{ltype.get_displayname(env.schema)!r} and '
                     f'{rtype.get_displayname(env.schema)!r}')
        else:
            types = ', '.join(
                repr(a[0].material_type(env.schema).get_displayname(
                    env.schema)) for a in args)

        if not matched:
            hint = ('Consider using an explicit type cast or a conversion '
                    'function.')

            if op_name == 'std::IF':
                hint = (f"The IF and ELSE result clauses must be of "
                        f"compatible types, while the condition clause must "
                        f"be 'std::bool'. {hint}")

            raise errors.QueryError(
                f'operator {str(op_name)!r} cannot be applied to '
                f'operands of type {types}',
                hint=hint,
                context=qlexpr.context)
        elif len(matched) > 1:
            if in_abstract_constraint:
                matched_call = matched[0]
            else:
                detail = ', '.join(
                    f'`{m.func.get_verbosename(ctx.env.schema)}`'
                    for m in matched)
                raise errors.QueryError(
                    f'operator {str(op_name)!r} is ambiguous for '
                    f'operands of type {types}',
                    hint=f'Possible variants: {detail}.',
                    context=qlexpr.context)

    final_args, params_typemods = finalize_args(matched_call, ctx=ctx)

    oper = matched_call.func
    assert isinstance(oper, s_oper.Operator)
    env.schema_refs.add(oper)
    oper_name = oper.get_shortname(env.schema)

    matched_params = oper.get_params(env.schema)
    rtype = matched_call.return_type

    if oper_name in {'std::UNION', 'std::IF'} and rtype.is_object_type():
        # Special case for the UNION and IF operators, instead of common
        # parent type, we return a union type.
        if oper_name == 'std::UNION':
            larg, rarg = (a.expr for a in final_args)
        else:
            # IF has three operands; the middle one is the condition.
            larg, _, rarg = (a.expr for a in final_args)

        left_type = setgen.get_set_type(larg,
                                        ctx=ctx).material_type(ctx.env.schema)
        right_type = setgen.get_set_type(rarg,
                                         ctx=ctx).material_type(ctx.env.schema)

        if left_type.issubclass(env.schema, right_type):
            rtype = right_type
        elif right_type.issubclass(env.schema, left_type):
            rtype = left_type
        else:
            # get_union_type returns an updated schema, stored back on env.
            env.schema, rtype = s_utils.get_union_type(env.schema,
                                                       [left_type, right_type])

    # Polymorphic only if some parameter AND the declared return type are.
    is_polymorphic = (any(
        p.get_type(env.schema).is_polymorphic(env.schema)
        for p in matched_params.objects(env.schema)) and oper.get_return_type(
            env.schema).is_polymorphic(env.schema))

    # Emit a direct SQL operator only for operators declared via
    # from_operator with no code/from_function, and not in polymorphic
    # function bodies.
    from_op = oper.get_from_operator(env.schema)
    sql_operator = None
    if (from_op is not None and oper.get_code(env.schema) is None
            and oper.get_from_function(env.schema) is None
            and not in_polymorphic_func):
        sql_operator = tuple(from_op)

    node = irast.OperatorCall(
        args=final_args,
        func_module_id=env.schema.get_global(s_mod.Module,
                                             oper_name.module).id,
        func_shortname=oper_name,
        func_polymorphic=is_polymorphic,
        func_sql_function=oper.get_from_function(env.schema),
        sql_operator=sql_operator,
        force_return_cast=oper.get_force_return_cast(env.schema),
        volatility=oper.get_volatility(env.schema),
        operator_kind=oper.get_operator_kind(env.schema),
        params_typemods=params_typemods,
        context=qlexpr.context,
        typeref=irtyputils.type_to_typeref(env.schema, rtype),
        typemod=oper.get_return_typemod(env.schema),
    )

    return setgen.ensure_set(node, typehint=rtype, ctx=ctx)
Ejemplo n.º 2
0
def try_fold_associative_binop(
        opcall: irast.OperatorCall, *,
        ctx: context.ContextLevel) -> typing.Optional[irast.Set]:
    """Constant-fold a nested chain of the same infix operator.

    Rewrites ``CONST op (OTHER_CONST op X)`` into
    ``(CONST op OTHER_CONST) op X``, with the constant part pre-evaluated
    through ``ireval.evaluate``.

    Operands on either side are swapped so that the constants are found
    regardless of position. This assumes the operator is commutative as
    well as associative; callers are expected to only pass such operators.

    The rebuilt nodes copy every attribute of *opcall*, including
    ``volatility``, so the folded call keeps the metadata
    ``compile_operator`` assigned to it.

    Args:
        opcall: The outer operator call to examine.
        ctx: Current compiler context.

    Returns:
        The folded set, or ``None`` if the pattern does not apply. This
        includes calls that do not have exactly two operands and
        constants that ``ireval`` cannot evaluate.
    """

    def _rebuild(operands):
        # Clone *opcall* with new operand sets, preserving all metadata.
        return irast.OperatorCall(
            args=[irast.CallArg(expr=operand) for operand in operands],
            func_module_id=opcall.func_module_id,
            func_shortname=opcall.func_shortname,
            func_polymorphic=opcall.func_polymorphic,
            func_sql_function=opcall.func_sql_function,
            sql_operator=opcall.sql_operator,
            force_return_cast=opcall.force_return_cast,
            volatility=opcall.volatility,
            operator_kind=opcall.operator_kind,
            params_typemods=opcall.params_typemods,
            context=opcall.context,
            typeref=opcall.typeref,
            typemod=opcall.typemod,
        )

    # Only binary (infix) calls can match the pattern; bail out on
    # anything else instead of failing on the operand indexing below.
    if len(opcall.args) != 2:
        return None

    op = opcall.func_shortname
    my_const = opcall.args[0].expr
    other_binop = opcall.args[1].expr
    folded = None

    # Normalize so that the constant, if any, is on the left.
    if isinstance(other_binop.expr, irast.BaseConstant):
        my_const, other_binop = other_binop, my_const

    if (isinstance(my_const.expr, irast.BaseConstant)
            and isinstance(other_binop.expr, irast.OperatorCall)
            and other_binop.expr.func_shortname == op
            and other_binop.expr.operator_kind is ft.OperatorKind.INFIX):

        other_const = other_binop.expr.args[0].expr
        other_binop_node = other_binop.expr.args[1].expr

        # Same normalization for the inner call.
        if isinstance(other_binop_node.expr, irast.BaseConstant):
            other_binop_node, other_const = \
                other_const, other_binop_node

        if isinstance(other_const.expr, irast.BaseConstant):
            try:
                new_const = ireval.evaluate(
                    _rebuild([other_const, my_const]),
                    schema=ctx.env.schema,
                )
            except ireval.UnsupportedExpressionError:
                # Not statically evaluable: leave the expression unfolded.
                pass
            else:
                folded_binop = _rebuild([
                    setgen.ensure_set(new_const, ctx=ctx),
                    other_binop_node,
                ])
                folded = setgen.ensure_set(folded_binop, ctx=ctx)

    return folded
Ejemplo n.º 3
0
def compile_operator(
        qlexpr: qlast.Base, op_name: str, qlargs: List[qlast.Base], *,
        ctx: context.ContextLevel) -> irast.Set:
    """Compile an operator invocation into an IR set.

    Each operand in *qlargs* is compiled in its own fenced scope.
    Operands listed in ``CONDITIONAL_OPS`` for this operator are
    compiled with ``in_conditional`` set.

    If the resolved operator is a derivative of another operator,
    matching is done against the origin operator. The derivative's
    parameter typemods and name are then used for the resulting node.

    Args:
        qlexpr: QL node of the whole operator expression; its
            ``context`` is attached to raised errors and to the
            resulting node.
        op_name: Operator name as written in the query.
        qlargs: Operand expressions, in positional order.
        ctx: Current compiler context.

    Returns:
        The set produced by ``setgen.ensure_set`` around the compiled
        ``irast.OperatorCall``.

    Raises:
        errors.QueryError: if no operator with this name exists, an
            operand's type cannot be inferred, no candidate matches the
            operand types, or several candidates match. The last case is
            tolerated only in an abstract constraint inside a polymorphic
            function, where the first match is taken.
        errors.InternalServerError: if a derived operator has several
            forms or its origin operator cannot be found.
    """

    env = ctx.env
    schema = env.schema
    opers = schema.get_operators(op_name, module_aliases=ctx.modaliases)

    if opers is None:
        raise errors.QueryError(
            'no operator matches the given name and argument types',
            context=qlexpr.context)

    fq_op_name = next(iter(opers)).get_shortname(ctx.env.schema)
    # Operand positions that must be compiled as conditionally evaluated.
    conditional_args = CONDITIONAL_OPS.get(fq_op_name)

    arg_ctxs = {}
    args = []
    for ai, qlarg in enumerate(qlargs):
        with ctx.newscope(fenced=True) as fencectx:
            fencectx.path_log = []
            # We put on a SET OF fence preemptively in case this is
            # a SET OF arg, which we don't know yet due to polymorphic
            # matching.  We will remove it if necessary in `finalize_args()`.
            if conditional_args and ai in conditional_args:
                fencectx.in_conditional = qlexpr.context

            arg_ir = setgen.ensure_set(
                dispatch.compile(qlarg, ctx=fencectx),
                ctx=fencectx)

            arg_ir = setgen.scoped_set(
                setgen.ensure_stmt(arg_ir, ctx=fencectx),
                ctx=fencectx)

            arg_ctxs[arg_ir] = fencectx

        arg_type = inference.infer_type(arg_ir, ctx.env)
        if arg_type is None:
            raise errors.QueryError(
                f'could not resolve the type of operand '
                f'#{ai} of {op_name}',
                context=qlarg.context)

        args.append((arg_type, arg_ir))

    # Check if the operator is a derived operator, and if so,
    # find the origins.
    origin_op = opers[0].get_derivative_of(env.schema)
    derivative_op: Optional[s_oper.Operator]
    if origin_op is not None:
        # If this is a derived operator, there should be
        # exactly one form of it.  This is enforced at the DDL
        # level, but check again to be sure.
        # Errors point at the whole expression: `qlarg` is only a
        # leftover loop variable here (and unbound when there are no
        # operands).
        if len(opers) > 1:
            raise errors.InternalServerError(
                f'more than one derived operator of the same name: {op_name}',
                context=qlexpr.context)

        derivative_op = opers[0]
        opers = schema.get_operators(origin_op)
        if not opers:
            raise errors.InternalServerError(
                f'cannot find the origin operator for {op_name}',
                context=qlexpr.context)
        actual_typemods = [
            param.get_typemod(schema)
            for param in derivative_op.get_params(schema).objects(schema)
        ]
    else:
        derivative_op = None
        actual_typemods = []

    matched = None
    # Some 2-operand operators are special when their operands are
    # arrays or tuples.
    if len(args) == 2:
        coll_opers = None
        # If both of the args are arrays or tuples, potentially
        # compile the operator for them differently than for other
        # combinations.
        if args[0][0].is_tuple(env.schema) and args[1][0].is_tuple(env.schema):
            # Out of the candidate operators, find the ones that
            # correspond to tuples.
            coll_opers = [
                op for op in opers
                if all(
                    param.get_type(schema).is_tuple(schema)
                    for param in op.get_params(schema).objects(schema)
                )
            ]

        elif args[0][0].is_array() and args[1][0].is_array():
            # Out of the candidate operators, find the ones that
            # correspond to arrays.
            coll_opers = [
                op for op in opers
                if all(
                    param.get_type(schema).is_array()
                    for param in op.get_params(schema).objects(schema)
                )
            ]

        # Proceed only if we have a special case of collection operators.
        if coll_opers:
            # Then check if they are recursive (i.e. validation must be
            # done recursively for the subtypes). We rely on the fact that
            # it is forbidden to define an operator that has both
            # recursive and non-recursive versions.
            if not coll_opers[0].get_recursive(schema):
                # The operator is non-recursive, so regular processing
                # is needed.
                matched = polyres.find_callable(
                    coll_opers, args=args, kwargs={}, ctx=ctx)

            else:
                # The recursive operators are usually defined as
                # being polymorphic on all parameters, and so this has
                # a side-effect of forcing both operands to be of
                # the same type (via casting) before the operator is
                # applied.  This might seem suboptmial, since there might
                # be a more specific operator for the types of the
                # elements, but the current version of Postgres
                # actually requires tuples and arrays to be of the
                # same type in comparison, so this behavior is actually
                # what we want.
                matched = polyres.find_callable(
                    coll_opers,
                    args=args,
                    kwargs={},
                    ctx=ctx,
                )

                # Now that we have an operator, we need to validate that it
                # can be applied to the tuple or array elements.
                submatched = validate_recursive_operator(
                    opers, args[0], args[1], ctx=ctx)

                if len(submatched) != 1:
                    # This is an error. We want the error message to
                    # reflect whether no matches were found or too
                    # many, so we preserve the submatches found for
                    # this purpose.
                    matched = submatched

    # No special handling match was necessary, find a normal match.
    if matched is None:
        matched = polyres.find_callable(opers, args=args, kwargs={}, ctx=ctx)

    in_polymorphic_func = (
        ctx.env.options.func_params is not None and
        ctx.env.options.func_params.has_polymorphic(env.schema)
    )

    in_abstract_constraint = (
        in_polymorphic_func and
        ctx.env.options.schema_object_context is s_constr.Constraint
    )

    # Abstract operators are only acceptable inside polymorphic functions.
    if not in_polymorphic_func:
        matched = [call for call in matched
                   if not call.func.get_abstract(env.schema)]

    if len(matched) == 1:
        matched_call = matched[0]
    else:
        # Build a human-readable list of operand types for the error.
        # NOTE: ltype/rtype are only bound for exactly two operands.
        if len(args) == 2:
            ltype = schemactx.get_material_type(args[0][0], ctx=ctx)
            rtype = schemactx.get_material_type(args[1][0], ctx=ctx)

            types = (
                f'{ltype.get_displayname(env.schema)!r} and '
                f'{rtype.get_displayname(env.schema)!r}')
        else:
            types = ', '.join(
                repr(
                    schemactx.get_material_type(
                        a[0], ctx=ctx).get_displayname(env.schema)
                ) for a in args
            )

        if not matched:
            hint = ('Consider using an explicit type cast or a conversion '
                    'function.')

            if op_name == 'std::IF':
                hint = (f"The IF and ELSE result clauses must be of "
                        f"compatible types, while the condition clause must "
                        f"be 'std::bool'. {hint}")
            elif op_name == '+' and len(args) == 2:
                # Only the binary form can be a misspelled concatenation;
                # a failing prefix `+` must not touch the unbound
                # ltype/rtype.
                str_t = cast(s_scalars.ScalarType,
                             env.schema.get('std::str'))
                bytes_t = cast(s_scalars.ScalarType,
                               env.schema.get('std::bytes'))
                if (
                    (ltype.issubclass(env.schema, str_t) and
                        rtype.issubclass(env.schema, str_t)) or
                    (ltype.issubclass(env.schema, bytes_t) and
                        rtype.issubclass(env.schema, bytes_t)) or
                    (ltype.is_array() and rtype.is_array())
                ):
                    hint = 'Consider using the "++" operator for concatenation'

            raise errors.QueryError(
                f'operator {str(op_name)!r} cannot be applied to '
                f'operands of type {types}',
                hint=hint,
                context=qlexpr.context)
        elif len(matched) > 1:
            if in_abstract_constraint:
                matched_call = matched[0]
            else:
                detail = ', '.join(
                    f'`{m.func.get_verbosename(ctx.env.schema)}`'
                    for m in matched
                )
                raise errors.QueryError(
                    f'operator {str(op_name)!r} is ambiguous for '
                    f'operands of type {types}',
                    hint=f'Possible variants: {detail}.',
                    context=qlexpr.context)

    oper = matched_call.func
    assert isinstance(oper, s_oper.Operator)
    env.add_schema_ref(oper, expr=qlexpr)
    oper_name = oper.get_shortname(env.schema)
    str_oper_name = str(oper_name)

    matched_params = oper.get_params(env.schema)
    rtype = matched_call.return_type

    # Polymorphic only if some parameter AND the return type are.
    is_polymorphic = (
        any(p.get_type(env.schema).is_polymorphic(env.schema)
            for p in matched_params.objects(env.schema)) and
        rtype.is_polymorphic(env.schema)
    )

    final_args, params_typemods = finalize_args(
        matched_call,
        arg_ctxs=arg_ctxs,
        actual_typemods=actual_typemods,
        is_polymorphic=is_polymorphic,
        ctx=ctx,
    )

    if str_oper_name in {'std::UNION', 'std::IF'} and rtype.is_object_type():
        # Special case for the UNION and IF operators, instead of common
        # parent type, we return a union type.
        if str_oper_name == 'std::UNION':
            larg, rarg = (a.expr for a in final_args)
        else:
            # IF has three operands; the middle one is the condition.
            larg, _, rarg = (a.expr for a in final_args)

        left_type = schemactx.get_material_type(
            setgen.get_set_type(larg, ctx=ctx),
            ctx=ctx,
        )
        right_type = schemactx.get_material_type(
            setgen.get_set_type(rarg, ctx=ctx),
            ctx=ctx,
        )

        if left_type.issubclass(env.schema, right_type):
            rtype = right_type
        elif right_type.issubclass(env.schema, left_type):
            rtype = left_type
        else:
            assert isinstance(left_type, s_types.InheritingType)
            assert isinstance(right_type, s_types.InheritingType)
            rtype = schemactx.get_union_type([left_type, right_type], ctx=ctx)

    # Emit a direct SQL operator only for operators declared via
    # from_operator with no code/from_function, and not in polymorphic
    # function bodies.
    from_op = oper.get_from_operator(env.schema)
    sql_operator = None
    if (from_op is not None and oper.get_code(env.schema) is None and
            oper.get_from_function(env.schema) is None and
            not in_polymorphic_func):
        sql_operator = tuple(from_op)

    # For derived operators, the node is named after the derivative while
    # recording the origin operator it was matched through.
    origin_name: Optional[sn.QualName]
    origin_module_id: Optional[uuid.UUID]
    if derivative_op is not None:
        origin_name = oper_name
        origin_module_id = env.schema.get_global(
            s_mod.Module, origin_name.module).id
        oper_name = derivative_op.get_shortname(env.schema)
    else:
        origin_name = None
        origin_module_id = None

    node = irast.OperatorCall(
        args=final_args,
        func_shortname=oper_name,
        func_polymorphic=is_polymorphic,
        origin_name=origin_name,
        origin_module_id=origin_module_id,
        func_sql_function=oper.get_from_function(env.schema),
        sql_operator=sql_operator,
        force_return_cast=oper.get_force_return_cast(env.schema),
        volatility=oper.get_volatility(env.schema),
        operator_kind=oper.get_operator_kind(env.schema),
        params_typemods=params_typemods,
        context=qlexpr.context,
        typeref=typegen.type_to_typeref(rtype, env=env),
        typemod=oper.get_return_typemod(env.schema),
    )

    return setgen.ensure_set(node, typehint=rtype, ctx=ctx)