Example #1
0
 def _(self, node: ir.NOT):
     """Simplify a logical negation.

     Rewrites ``not (a == b)`` to ``a != b``, ``not (a != b)`` to ``a == b``
     and ``not (not x)`` to ``x`` (wrapped in ``ir.TRUTH`` when ``x`` is not
     itself a truth test). Returns ``node`` unchanged when its operand is not
     a truth test or when no rule applies.
     """
     operand = node.operand
     if not is_truth_test(operand):
         return node
     if isinstance(operand, ir.BinOp):
         op = operand.op
         if op == "==":
             return ir.BinOp(operand.left, operand.right, "!=")
         elif op == "!=":
             return ir.BinOp(operand.left, operand.right, "==")
         # >, >=, <, <= are not safe to invert if unordered operands
         # are present, particularly floating point NaNs.
         # While this started off assuming integer arithmetic, it may
         # be better to move this after typing, since some of this applies
         # equally or almost as well to floating point arithmetic.
     elif isinstance(operand, ir.NOT):
         # Remove double negation. The check must be on the operand: ``node``
         # is the ir.NOT being visited, so testing it would unwrap operands
         # that are not negations at all.
         inner = operand.operand
         if not is_truth_test(inner):
             # If the unwrapped type doesn't already export a truth test
             # we need to indicate this explicitly.
             inner = ir.TRUTH(inner)
         return inner
     return node
Example #2
0
def rewrite_pow(expr):
    """
    Strength-reduce ``base ** coeff`` for a few constant exponents.

    Rules applied (``expr.left`` is the base, ``expr.right`` the exponent):
        x ** 0    -> 1
        0 ** c    -> 0 for a non-negative constant c
        x ** 1    -> x
        x ** -1   -> 1 / x
        x ** -2   -> 1 / (x * x)
        x ** 2    -> x * x   (in-place form kept as ``*=``)
        x ** 0.5  -> sqrt(x)

    Returns the rewritten expression, or ``expr`` itself when no rule applies.

    Raises:
        CompilerError: if the base is the zero constant and the exponent is a
            negative constant.
    """
    coeff = expr.right
    base = expr.left
    if coeff == ir.Zero:
        return ir.One
    elif base == ir.Zero:
        # checking for weird errors more than anything
        if coeff.constant:
            if operator.lt(coeff.value, 0):
                # this isn't intended to catch all edge cases, just an obvious
                # one that may come up after folding
                msg = f"raises 0 to a negative power {expr}."
                raise CompilerError(msg)
            return ir.Zero
        # The exponent's sign is unknown, so leave the expression alone.
        # Previously this path fell through and returned None.
        return expr
    elif coeff == ir.One:
        return base
    elif coeff == ir.IntConst(-1):
        # The left operand here is the constant One. An in-place "/=" would name
        # it as the update target, so always emit a plain division. The
        # sqrt rule below likewise returns a non-in-place expression.
        return ir.BinOp(ir.One, base, "/")
    elif coeff == ir.IntConst(-2):
        return ir.BinOp(ir.One, ir.BinOp(base, base, "*"), "/")
    elif coeff == ir.IntConst(2):
        op = "*=" if expr.in_place else "*"
        return ir.BinOp(base, base, op)
    elif coeff == ir.FloatConst(0.5):
        return ir.Call("sqrt", (base,), ())
    else:
        return expr
Example #3
0
 def _(self, expr: ir.BinOp):
     """Fold or canonicalize a binary operation.

     Both operands are first resolved through ``self.lookup``. When both are
     constants, the operation is evaluated with the callable registered in
     ``binops`` and the result is re-wrapped with ``wrap_constant``. Otherwise
     a new BinOp is returned. A constant left addend is moved to the right.
     For non-in-place multiplication, a constant right factor is moved to the
     left.

     Raises:
         ValueError: when folding a constant shift by a non-integral or a
             negative amount.
     """
     left = self.lookup(expr.left)
     right = self.lookup(expr.right)
     if left.constant and right.constant:
         op = binops[expr.op]
         # Test the operator *symbol*. ``op`` is the callable from ``binops``
         # and never equals one of these strings, which previously disabled
         # these checks entirely.
         if expr.op in ("<<", ">>", "<<=", ">>="):
             if not isinstance(right, ir.IntConst):
                 msg = f"Cannot safely evaluate shifts by non-integral amounts: {left.value} {expr.op} {right.value}."
                 raise ValueError(msg)
             elif operator.lt(right.value, 0):
                 msg = f"Shift amount cannot be negative: {left.value} {expr.op} {right.value}"
                 raise ValueError(msg)
             # A shift by zero is well defined (x << 0 == x). It is folded like
             # any other constant shift rather than rejected.
         value = op(left.value, right.value)
         return wrap_constant(value)
     else:
         # It's not possible to always place constants on the right or left due to
         # non-commutative operators, but it's okay to standardize ordering of multiplication
         # and addition with a single constant.
         if is_addition(expr):
             if left.constant:
                 return ir.BinOp(right, left, expr.op)
         elif is_multiplication(expr):
             if right.constant:
                 # Swapping an in-place product would change its target.
                 if not expr.in_place:
                     return ir.BinOp(right, left, expr.op)
         return ir.BinOp(left, right, expr.op)
Example #4
0
    def _(self, node: ir.BinOp):
        """Apply algebraic identities to a binary op with at most one constant.

        Handles the following cases:
            - powers x**0, x**1 and x**2;
            - addition or subtraction involving zero or a unary negation;
            - floor division by one, or of zero;
            - multiplication by 0, 1 or -1 on the left.

        Returns ``node`` unchanged when no identity applies.

        Raises:
            CompilerError: on division by the zero constant.
        """
        left = node.left
        right = node.right
        # if a constant expression shows up here, treat it as an error since
        # it's weirder to handle than it seems
        assert not (left.constant and right.constant)
        two = ir.IntConst(2)
        negative_one = ir.IntConst(-1)

        if is_pow(node):
            if right == ir.Zero:
                return ir.One
            elif right == ir.One:
                return left
            elif right == two:
                return ir.BinOp(left, left, "*=" if node.in_place else "*")
            # square roots shouldn't come up here, given the associative qualifier
        elif is_addition(node):
            if right == ir.Zero:
                return left
            elif equals_unary_negate(right):
                # a + (-b) -> a - b
                return ir.BinOp(left, right.operand, "-=" if node.in_place else "-")
            elif equals_unary_negate(left):
                # (-a) + b -> b - a
                assert not node.in_place
                return ir.BinOp(right, left.operand, "-")
        elif is_subtraction(node):
            if left == ir.Zero:
                return ir.UnaryOp(right, "-")
            elif right == ir.Zero:
                return left
            elif equals_unary_negate(right):
                # Todo: this is not entirely correct... as it may not be a unary node
                # need something like extract unary operand..
                # a - (-b) -> a + b
                return ir.BinOp(left, right.operand, "+=" if node.in_place else "+")
            elif equals_unary_negate(left):
                # (-a) - b == -(a + b). The previous rewrite to b + a was wrong.
                assert not node.in_place
                return ir.UnaryOp(ir.BinOp(left.operand, right, "+"), "-")
        elif is_division(node):
            if right == ir.Zero:
                msg = f"Divide by zero error in expression {node}."
                raise CompilerError(msg)
            elif node.op in ("//", "//="):
                # only safe to fold floor divide, ignore left == right since these might
                # be zero. Constant cases should be handled by the const folder.
                if left == ir.Zero or right == ir.One:
                    return left
        elif is_multiplication(node):
            if left == ir.Zero:
                return ir.Zero
            elif left == ir.One:
                return right
            elif left == negative_one:
                if equals_unary_negate(right):
                    # -(-something)) is safe in Python but possibly unsafe in a fixed width
                    # destination. Folding it should be considered safe.
                    return right.operand
                else:
                    return ir.UnaryOp(right, "-")
        # No identity applied; keep the original expression instead of
        # implicitly returning None.
        return node
Example #5
0
def _compute_iter_count(diff, step):
    """
    Build an IR expression for the number of iterations of range(0, diff, step).

    With a unit step the count is ``diff`` itself. Otherwise it is
    ``diff // step``, plus one when ``diff % step`` is non-zero (a ceiling
    division). NOTE(review): this matches the trip count only for a positive
    step and non-negative diff; confirm callers guarantee that.

    Raises:
        CompilerError: if ``step`` is the zero constant.
    """
    # Todo: may insert an extra round of constant folding here..
    if step == ir.Zero:
        msg = "Zero step loop iterator encountered."
        raise CompilerError(msg)
    elif step == ir.One:
        count = diff
    else:
        # previously referenced the undefined name ``diffs`` (NameError)
        on_false = ir.BinOp(diff, step, "//")
        modulo = ir.BinOp(diff, step, "%")
        on_true = ir.BinOp(on_false, ir.One, "+")
        count = ir.Select(predicate=modulo, on_true=on_true, on_false=on_false)
    return count
Example #6
0
def p_AddExpr(p):
    """AddExpr : AddExpr '+' MulExpr
               | AddExpr '-' MulExpr
               | MulExpr"""
    # The docstring above is the yacc grammar rule (PLY convention); do not
    # reword it. p[0] receives the node built for this production.
    arity = len(p)
    if arity == 2:
        # AddExpr : MulExpr -- forward the operand node unchanged.
        p[0] = p[1]
    elif arity == 4:
        symbol = p[2]
        if symbol == '+':
            # AddExpr : AddExpr '+' MulExpr
            p[0] = IR.BinOp(IR.BinOpType.PLUS, p[1], p[3])
        elif symbol == '-':
            # AddExpr : AddExpr '-' MulExpr
            p[0] = IR.BinOp(IR.BinOpType.MINUS, p[1], p[3])
Example #7
0
 def visit_AugAssign(self, node: ast.AugAssign):
     """Lower an augmented assignment ``target op= value``.

     Appends ``ir.Assign(target, ir.BinOp(target, value, in_place_op), pos)``
     to ``self.body``. The target is lowered before the value. The in-place
     operator symbol comes from ``binary_in_place_ops`` keyed by the AST
     operator type, and is ``None`` if that type is not listed. Returns None.
     """
     lowered_target = self.visit(node.target)
     lowered_value = self.visit(node.value)
     in_place_op = binary_in_place_ops.get(type(node.op))
     position = extract_positional_info(node)
     update = ir.BinOp(lowered_target, lowered_value, in_place_op)
     self.body.append(ir.Assign(lowered_target, update, position))
Example #8
0
def p_MulExpr(p):
    """MulExpr : MulExpr '*' UnExpr
               | MulExpr '/' UnExpr
               | MulExpr '%' UnExpr
               | UnExpr"""
    # The docstring above is the yacc grammar rule (PLY convention); do not
    # reword it. p[0] receives the node built for this production.
    arity = len(p)
    if arity == 2:
        # MulExpr : UnExpr -- forward the operand node unchanged.
        p[0] = p[1]
    elif arity == 4:
        # MulExpr : MulExpr ('*' | '/' | '%') UnExpr
        symbol = p[2]
        if symbol == '*':
            kind = IR.BinOpType.MULT
        elif symbol == '/':
            kind = IR.BinOpType.DIV
        elif symbol == '%':
            kind = IR.BinOpType.MOD
        else:
            return
        p[0] = IR.BinOp(kind, p[1], p[3])
 def visit_bin_op(self, node, target, *args):
     """Lower a binary-operator node to a list of IR instructions.

     ``target`` selects the lowering mode:

     * ``ir.EXPR``: the result is stored into ``args[0]``.
       ``+ - * %`` and ``<`` emit an ``ir.BinOp``. ``&&`` and ``||`` use
       short-circuit jumps and store the constant 1 or 0.
     * ``ir.COND``: jump to ``args[0]`` when the expression is true and to
       ``args[1]`` when it is false. Only ``<``, ``&&`` and ``||`` support it.

     Operands come from ``self.reorder_for_bin_op`` (arithmetic and ``<``)
     or from recursive ``self.visit`` calls (``&&`` / ``||``). Labels and
     temporaries are allocated with ``self.get_id()``.

     Raises:
         AssertionError: for an unsupported operator, or for an arithmetic
             operator requested in ``ir.COND`` mode.
     """
     if node.op in ['+', '-', '*', '%'] or (node.op == '<'
                                            and target == ir.EXPR):
         assert target == ir.EXPR
         lhs = self.get_id()
         rhs = self.get_id()
         arg1, arg2 = self.reorder_for_bin_op(node.left, node.right, lhs,
                                              rhs)
         return [
             arg1,
             arg2,
             ir.BinOp(node.op, lhs, rhs, args[0]),
         ]
     if node.op == '<':
         assert target == ir.COND
         lhs = self.get_id()
         rhs = self.get_id()
         arg1, arg2 = self.reorder_for_bin_op(node.left, node.right, lhs,
                                              rhs)
         return [
             arg1,
             arg2,
             ir.CJumpLess(lhs, rhs, args[0], args[1]),
         ]
     if node.op == '&&':
         if target == ir.COND:
             # left false -> whole expression false; otherwise test right
             lbl_second_arg = self.get_id()
             return [
                 self.visit(node.left, ir.COND, lbl_second_arg, args[1]),
                 ir.Label(lbl_second_arg),
                 self.visit(node.right, ir.COND, args[0], args[1]),
             ]
         else:
             # materialize the short-circuit result as 1/0 into args[0]
             lbl_second_arg = self.get_id()
             lbl_false = self.get_id()
             lbl_true = self.get_id()
             lbl_end = self.get_id()
             return [
                 self.visit(node.left, ir.COND, lbl_second_arg, lbl_false),
                 ir.Label(lbl_second_arg),
                 self.visit(node.right, ir.COND, lbl_true, lbl_false),
                 ir.Label(lbl_true),
                 ir.Const(1, args[0]),
                 ir.Jump(lbl_end),
                 ir.Label(lbl_false),
                 ir.Const(0, args[0]),
                 ir.Label(lbl_end),
             ]
     if node.op == '||':
         if target == ir.COND:
             # left true -> whole expression true; otherwise test right
             lbl_second_arg = self.get_id()
             return [
                 self.visit(node.left, ir.COND, args[0], lbl_second_arg),
                 ir.Label(lbl_second_arg),
                 self.visit(node.right, ir.COND, args[0], args[1]),
             ]
         else:
             # materialize the short-circuit result as 1/0 into args[0]
             lbl_second_arg = self.get_id()
             lbl_false = self.get_id()
             lbl_true = self.get_id()
             lbl_end = self.get_id()
             return [
                 self.visit(node.left, ir.COND, lbl_true, lbl_second_arg),
                 ir.Label(lbl_second_arg),
                 self.visit(node.right, ir.COND, lbl_true, lbl_false),
                 ir.Label(lbl_false),
                 ir.Const(0, args[0]),
                 ir.Jump(lbl_end),
                 ir.Label(lbl_true),
                 ir.Const(1, args[0]),
                 ir.Label(lbl_end),
             ]
     # An explicit raise survives ``python -O``. A bare ``assert False`` would
     # be stripped, and unsupported operators would silently return None.
     raise AssertionError(f"visit_bin_op: unsupported operator {node.op!r}")
Example #10
0
def make_single_index_loop(header: ir.ForLoop, symbols):
    """
    Rewrite a loop over one or more iterables as a loop over a single counter.

    Each iterable's (start, stop, step) interval is computed with
    ``IntervalBuilder``, and the intervals are merged by
    ``_find_shared_interval``. The returned ``ir.ForLoop`` iterates a fresh
    counter over ``ir.AffineSeq(start, stop, step)``. Its body first assigns
    each original target, either the counter-derived index (for
    ``ir.AffineSeq`` iterables) or ``iterable[index]``. The original loop
    body follows.

    This assumes (with runtime verification if necessary)
    that 'stop - start' will not overflow.

    References:
        LIVINSKII et. al, Random Testing for C and C++ Compilers with YARPGen
        Dietz et. al, Understanding Integer Overflow in C/C++
        Bachmann et. al, Chains of Recurrences - a method to expedite the evaluation of closed-form functions
        https://gcc.gnu.org/onlinedocs/gcc/Integer-Overflow-Builtins.html
        https://developercommunity.visualstudio.com/t/please-implement-integer-overflow-detection/409051
        https://numpy.org/doc/stable/user/building.html
    """

    by_iterable = {}
    intervals = set()
    interval_from_iterable = IntervalBuilder()
    for _, iterable in unpack_iterated(header.target, header.iterable):
        interval = interval_from_iterable(iterable)
        by_iterable[iterable] = interval
        intervals.add(interval)

    loop_start, loop_stop, loop_step = _find_shared_interval(intervals)
    loop_expr = ir.AffineSeq(loop_start, loop_stop, loop_step)
    # Todo: this needs a default setting to avoid excessive casts
    loop_counter = symbols.make_unique_name_like("i", type_=tr.Int32)
    body = []
    pos = header.pos

    for target, iterable in unpack_iterated(header.target, header.iterable):
        (start, _, step) = by_iterable[iterable]
        # NOTE: a per-iterable step may legitimately differ from loop_step.
        # When steps differ, _find_shared_interval returns a normalized
        # (0, count, 1) interval, handled by the else-branch below. The former
        # ``assert step == loop_step`` made that branch unreachable and
        # rejected such loops.
        assert (start == loop_start) or (loop_start == ir.Zero)
        if step == loop_step:
            if start == loop_start:
                index = loop_counter
            else:
                # shared interval was rebased to start at zero
                assert loop_start == ir.Zero
                index = ir.BinOp(loop_counter, start, "+")
        else:
            # loop counter must be normalized: index = start + step * i
            assert loop_start == ir.Zero
            assert loop_step == ir.One
            index = ir.BinOp(step, loop_counter, "*")
            if start != ir.Zero:
                index = ir.BinOp(start, index, "+")

        value = index if isinstance(iterable, ir.AffineSeq) else ir.Subscript(iterable, index)
        assign = ir.Assign(target, value, pos)
        body.append(assign)

    # Todo: this doesn't hoist initial setup
    body.extend(header.body)
    repl = ir.ForLoop(loop_counter, loop_expr, body, pos)
    return repl
Example #11
0
def _find_shared_interval(intervals):
    """
    Find one (start, stop, step) interval that a loop over several iterables
    can share.

    ``intervals`` holds (start, stop, step) triples, one per iterated object.
    A stop of ``None`` marks an unbounded iterable (e.g. ``enumerate``) that
    does not limit the loop. ``intervals`` is iterated more than once, so it
    must be a re-iterable collection.

    - One distinct step and one start: ``(start, min(stops), step)``.
    - One distinct step and one stop: ``(0, stop - max(starts), step)``.
    - Otherwise: ``(0, count, 1)``, where ``count`` is the smallest trip
      count among the iterables.

    Raises:
        CompilerError: if no interval declares a stop, or (from
            ``_compute_iter_count``) if a step is the zero constant.
    """
    starts = set()
    stops = set()
    steps = set()
    for start, stop, step in intervals:
        starts.add(start)
        stops.add(stop)
        steps.add(step)

    # enumerate doesn't declare a bound, so it shows up as None
    stops.discard(None)
    if not stops:
        msg = "Cannot compute a loop bound: none of the iterated objects declares a stop."
        raise CompilerError(msg)

    simplify_min_max = MinMaxSimplifier()
    simplify_expr = arithmetic_folding()

    if len(steps) == 1:
        # If there's only one step size, we can
        # avoid explicitly computing iteration count
        step = simplify_expr(steps.pop())
        if len(starts) == 1:
            start = starts.pop()
            if len(stops) == 1:
                # previously the simplified stop was computed and discarded
                stop = simplify_expr(stops.pop())
            else:
                stop = ir.MinReduction({simplify_expr(s) for s in stops})
                stop = simplify_min_max(stop)
            return start, stop, step
        elif len(stops) == 1:
            stop = simplify_expr(stops.pop())
            start = ir.MaxReduction({simplify_expr(s) for s in starts})
            start = simplify_min_max(start)
            # was ``simplify_arith``, which is not defined here
            diff = simplify_expr(ir.BinOp(stop, start, "-"))
            return ir.Zero, diff, step

    # Collect extents (stop - start) by step to minimize the division and
    # modulo ops required. Unbounded iterables cannot limit the trip count.
    by_step = defaultdict(set)
    for start, stop, step in intervals:
        if stop is None:
            continue
        diff = simplify_expr(ir.BinOp(stop, start, "-"))
        by_step[step].add(diff)

    # Invert: which steps share exactly the same set of extents. This must
    # fill by_diff; writing into by_step while iterating it raises
    # RuntimeError (dictionary changed size during iteration).
    by_diff = defaultdict(set)
    for step, diffs in by_step.items():
        by_diff[frozenset(diffs)].add(step)

    # Combine steps if possible. For identical extents, the largest step gives
    # the fewest iterations (assuming positive steps).
    for shared_diffs, shared_steps in by_diff.items():
        if len(shared_steps) != 1:
            # remove combinable entries
            for step in shared_steps:
                by_step.pop(step)
            combined_step = simplify_min_max(ir.MaxReduction(shared_steps))
            by_step[combined_step].update(shared_diffs)

    # Reduce each step's extents to their minimum. This uses the same
    # MinReduction / MinMaxSimplifier pair as the fast path above.
    by_step_refined = {}
    for step, diffs in by_step.items():
        by_step_refined[step] = simplify_min_max(ir.MinReduction(diffs))

    # Now compute explicit counts, since we don't use explicit dependency testing
    # this isn't a big deal
    counts = set()
    for step, diff in by_step_refined.items():
        count = simplify_expr(_compute_iter_count(diff, step))
        counts.add(count)
    count = simplify_min_max(ir.MinReduction(counts))
    return ir.Zero, count, ir.One
Example #12
0
 def visit_BinOp(self, node: ast.BinOp) -> ir.BinOp:
     """Lower a Python binary expression to ``ir.BinOp``.

     The operator symbol is looked up by AST operator type in ``binary_ops``.
     An unlisted type yields ``None`` as the op. Both operands are lowered
     with ``self.visit``, left before right.
     """
     symbol = binary_ops.get(type(node.op))
     lhs, rhs = self.visit(node.left), self.visit(node.right)
     return ir.BinOp(lhs, rhs, symbol)