예제 #1
0
def transform_basic_comparison(builder: IRBuilder, op: str, left: Value,
                               right: Value, line: int) -> Value:
    """Emit IR for a single binary comparison ``left <op> right``.

    When both operands have an int rtype and ``op`` is a key of
    ``int_comparison_op_mapping``, the comparison is delegated to
    ``builder.compare_tagged``. Otherwise it is emitted with
    ``builder.binary_op``. The negative operators ``'is not'`` and
    ``'not in'`` are lowered as their positive form (``'is'`` / ``'in'``)
    followed by a ``'not'`` unary op.

    Args:
        builder: IR builder that receives the generated ops.
        op: Python comparison operator spelling, e.g. ``'<'`` or ``'not in'``.
        left: Already-evaluated left operand.
        right: Already-evaluated right operand.
        line: Source line attached to the generated ops.

    Returns:
        The value holding the comparison result.
    """
    both_int = is_int_rprimitive(left.type) and is_int_rprimitive(right.type)
    if both_int and op in int_comparison_op_mapping:
        return builder.compare_tagged(left, right, op, line)

    # Negative operators map to their positive counterpart; all others
    # are emitted as-is.
    positive_op = {'is not': 'is', 'not in': 'in'}.get(op)
    if positive_op is None:
        return builder.binary_op(left, right, op, line)

    positive_result = builder.binary_op(left, right, positive_op, line)
    return builder.unary_op(positive_result, 'not', line)
예제 #2
0
def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
    """Build IR for a comparison expression, which may be chained.

    Membership tests against a tuple or list display with fewer than 16
    items and no starred items are expanded into equality tests, so the
    container is never built:

        x in ()/[]            -> False
        x not in ()/[]        -> True
        x in (y,)/[y]         -> x == y
        x not in (y,)/[y]     -> x != y
        x in (a, b, ...)      -> x == a or x == b or ...
        x not in (a, b, ...)  -> x != a and x != b and ...

    In all of these cases the left operand is evaluated exactly once,
    including for an empty display. The AST node ``e`` is not modified.

    A single ``x is None`` / ``x is not None`` comparison goes through
    ``translate_is_none``. A single int-vs-int comparison whose operator is
    in ``int_borrow_friendly_op`` goes through ``compare_tagged`` with
    borrowed operands. Everything else uses the generic chained-comparison
    lowering at the bottom.
    """
    first_op = e.operators[0]
    if (first_op in ('in', 'not in') and len(e.operators) == 1
            and isinstance(e.operands[1], (TupleExpr, ListExpr))):
        # Imported locally because only this rewrite needs it.
        from mypy.nodes import StarExpr

        items = e.operands[1].items
        n_items = len(items)
        # A starred item (`x in (*a, b)`) is not a single value and cannot be
        # compared on its own, so such displays use the generic path below.
        # 16 is arbitrarily chosen to limit code size.
        if (n_items < 16
                and not any(isinstance(item, StarExpr) for item in items)):
            negated = first_op == 'not in'
            cmp_op = '!=' if negated else '=='
            # Evaluate the left operand once, up front. The expansion below
            # compares this same value against several items, and the
            # expression may have side effects.
            lhs = builder.accept(e.operands[0])

            if n_items == 0:
                return builder.true() if negated else builder.false()

            def compare_item(i: int) -> Value:
                # `lhs == items[i]`, or `lhs != items[i]` when negated.
                return transform_basic_comparison(
                    builder, cmp_op, lhs, builder.accept(items[i]), e.line)

            if n_items == 1:
                return compare_item(0)

            # NOTE(review): items are evaluated lazily, and later items are
            # skipped once the result is known. They are matched with `==`/`!=`
            # only. Python evaluates the whole display first and also treats
            # `x is item` as a match. Confirm this divergence is acceptable.
            bin_op = 'and' if negated else 'or'
            result_type = builder.node_type(e)

            def chain(i: int) -> Value:
                # Combine comparisons i .. n_items - 1 with `bin_op`.
                if i == n_items - 1:
                    return compare_item(i)
                return builder.builder.shortcircuit_helper(
                    bin_op, result_type, lambda: compare_item(i),
                    lambda: chain(i + 1), e.line)

            return chain(0)

    if len(e.operators) == 1:
        # Special-case some common simple comparisons.
        if first_op in ('is', 'is not'):
            right_expr = e.operands[1]
            if (isinstance(right_expr, NameExpr)
                    and right_expr.fullname == 'builtins.None'):
                # `x is None` / `x is not None`.
                return translate_is_none(builder,
                                         e.operands[0],
                                         negated=first_op != 'is')
        left_expr = e.operands[0]
        if is_int_rprimitive(builder.node_type(left_expr)):
            right_expr = e.operands[1]
            if is_int_rprimitive(builder.node_type(right_expr)):
                if first_op in int_borrow_friendly_op:
                    # The right operand is always accepted with borrowing.
                    # The left one only when `is_borrow_friendly_expr`
                    # allows it for the right operand.
                    borrow_left = is_borrow_friendly_expr(builder, right_expr)
                    left = builder.accept(left_expr, can_borrow=borrow_left)
                    right = builder.accept(right_expr, can_borrow=True)
                    return builder.compare_tagged(left, right, first_op,
                                                  e.line)

    # TODO: Don't produce an expression when used in conditional context
    # All of the trickiness here is due to support for chained conditionals
    # (`e1 < e2 > e3`, etc). `e1 < e2 > e3` is approximately equivalent to
    # `e1 < e2 and e2 > e3` except that `e2` is only evaluated once.
    expr_type = builder.node_type(e)

    # go(i, prev) generates code for `ei opi e{i+1} op{i+1} ... en`,
    # assuming that prev contains the value of `ei`.
    def go(i: int, prev: Value) -> Value:
        if i == len(e.operators) - 1:
            return transform_basic_comparison(
                builder, e.operators[i], prev,
                builder.accept(e.operands[i + 1]), e.line)

        next_value = builder.accept(e.operands[i + 1])
        return builder.builder.shortcircuit_helper(
            'and', expr_type, lambda: transform_basic_comparison(
                builder, e.operators[i], prev, next_value, e.line),
            lambda: go(i + 1, next_value), e.line)

    return go(0, builder.accept(e.operands[0]))