示例#1
0
 def visit_BinaryOp(self, node):
     """Unroll an assignment whose right-hand side depends on the unrolled loop.

     Nodes that are not assignments, or whose right-hand side references
     neither ``self.target_var`` nor any name in ``self.unrolled_vars``,
     are returned unchanged.

     Otherwise a list of ``self.factor`` copies of ``node`` is returned.
     In copy ``i`` every name in ``self.unrolled_vars`` is renamed to
     ``<name>_<i>`` and ``self.target_var`` is replaced by

     * ``target_var + i`` when ``self.unroll_type == 0``
     * ``self.factor * target_var + i`` when ``self.unroll_type == 1``

     Raises:
         ValueError: if ``self.unroll_type`` is neither 0 nor 1.
     """
     if not isinstance(node.op, C.Op.Assign):
         return node

     tracked = list(self.unrolled_vars) + [self.target_var]
     if not any(util.contains_symbol(node.right, var) for var in tracked):
         return node

     # A typed left-hand side is registered so later statements that use it
     # get per-copy names too (presumably this marks a declaration -- confirm).
     if hasattr(node.left, 'type') and node.left.type is not None:
         self.unrolled_vars.add(node.left.name)

     body = []
     for i in range(self.factor):
         stmt = deepcopy(node)
         for var in self.unrolled_vars:
             stmt = util.replace_symbol(
                 var, C.SymbolRef(var + "_" + str(i)), stmt)

         if self.unroll_type == 0:
             index = C.Add(C.SymbolRef(self.target_var), C.Constant(i))
         elif self.unroll_type == 1:
             index = C.Add(
                 C.Mul(C.Constant(self.factor),
                       C.SymbolRef(self.target_var)),
                 C.Constant(i))
         else:
             raise ValueError(
                 "unsupported unroll_type: {!r}".format(self.unroll_type))

         body.append(util.replace_symbol(self.target_var, index, stmt))
     return body
示例#2
0
    def visit_If(self, node):
        """Unroll an ``if`` statement that depends on the unrolled loop.

        If ``node`` references neither ``self.target_var`` nor any name in
        ``self.unrolled_vars`` it is returned unchanged.

        Otherwise a list of ``self.factor`` copies of ``node`` is returned.
        In copy ``i`` every name in ``self.unrolled_vars`` is renamed to
        ``<name>_<i>`` and ``self.target_var`` is replaced by

        * ``target_var + i`` when ``self.unroll_type == 0``
        * ``self.factor * target_var + i`` when ``self.unroll_type == 1``

        Raises:
            ValueError: if ``self.unroll_type`` is neither 0 nor 1.
        """
        tracked = list(self.unrolled_vars) + [self.target_var]
        if not any(util.contains_symbol(node, var) for var in tracked):
            return node

        body = []
        for i in range(self.factor):
            stmt = deepcopy(node)
            for var in self.unrolled_vars:
                stmt = util.replace_symbol(
                    var, C.SymbolRef(var + "_" + str(i)), stmt)

            if self.unroll_type == 0:
                index = C.Add(C.SymbolRef(self.target_var), C.Constant(i))
            elif self.unroll_type == 1:
                index = C.Add(
                    C.Mul(C.Constant(self.factor),
                          C.SymbolRef(self.target_var)),
                    C.Constant(i))
            else:
                raise ValueError(
                    "unsupported unroll_type: {!r}".format(self.unroll_type))

            body.append(util.replace_symbol(self.target_var, index, stmt))
        return body
示例#3
0
    def test_add_zero(self):
        """Adding the constant 0 on either side folds to the other operand."""
        zero_on_right = C.Add(C.SymbolRef("a"), C.Constant(0))
        self.assertEqual(ConstantFold().visit(zero_on_right),
                         C.SymbolRef("a"))

        zero_on_left = C.Add(C.Constant(0), C.SymbolRef("a"))
        self.assertEqual(ConstantFold().visit(zero_on_left),
                         C.SymbolRef("a"))
示例#4
0
 def test_recursive_fold(self):
     """A nested ``2 + -2`` folds away, leaving the plain assignment ``c = b``."""
     cancelling_pair = C.Add(C.Constant(2), C.Constant(-2))
     assignment = C.Assign(C.SymbolRef("c"),
                           C.Add(cancelling_pair, C.SymbolRef("b")))

     folded = ConstantFold().visit(assignment)

     # Compared through their string form, as the trees are rendered.
     expected = C.Assign(C.SymbolRef("c"), C.SymbolRef("b"))
     self.assertEqual(str(folded), str(expected))
示例#5
0
 def rewrite_arg(self, arg):
     """Rewrite one subscript of an array reference into a prefetch index.

     ``arg`` must be either an array-reference expression or the address-of
     (``Ref``) of one.  It is modified in place.

     Starting at the outermost subscript, the chain of ``.left`` links is
     followed while ``idx + 1 != 0`` with ``idx`` starting at ``self.dim``
     and increasing by one per step, i.e. ``-1 - self.dim`` steps for a
     negative ``self.dim``.  During the first ``self.prefetch_num_zeroes``
     steps the subscript being passed over is overwritten with constant 0.
     The subscript of the node reached is then replaced by
     ``(prefetch_loop_var + prefetch_constant) * prefetch_multiplier``.

     Returns:
         ``C.Ref(arg)`` when ``arg`` was a bare array reference, otherwise
         ``arg`` itself (which already is an address-of expression).

     Raises:
         TypeError: if ``arg`` is neither of the two supported shapes.
         AttributeError: if the walk ends on, or runs past, a node that has
             no subscript to replace.
     """
     is_array_ref = (isinstance(arg, C.BinaryOp)
                     and isinstance(arg.op, C.Op.ArrayRef))
     if isinstance(arg, C.UnaryOp) and isinstance(
             arg.op, C.Op.Ref) and isinstance(
                 arg.arg, C.BinaryOp) and isinstance(
                     arg.arg.op, C.Op.ArrayRef):
         curr_node = arg.arg
     elif is_array_ref:
         curr_node = arg
     else:
         raise TypeError(
             "rewrite_arg expects an array reference or a Ref of one, "
             "got {!r}".format(arg))

     idx = self.dim
     num_zeroes = self.prefetch_num_zeroes
     while idx + 1 != 0:
         if num_zeroes > 0:
             curr_node.right = C.Constant(0)
             num_zeroes -= 1
         curr_node = curr_node.left
         idx += 1

     # Fail loudly instead of silently attaching a new attribute to a node
     # that is not a subscript expression.
     if not hasattr(curr_node, "right"):
         raise AttributeError(
             "node selected by dim={!r} has no subscript to rewrite".format(
                 self.dim))

     curr_node.right = C.Mul(
         C.Add(C.SymbolRef(self.prefetch_loop_var),
               C.SymbolRef(self.prefetch_constant)),
         C.SymbolRef(self.prefetch_multiplier))

     if is_array_ref:
         return C.Ref(arg)
     return arg
示例#6
0
 def gen_loop_index(self, loopvars, shape):
     """Build a single flattened index expression from nested loop variables.

     The result is the last loop variable plus, for every earlier position
     ``i``, ``loopvars[i] * prod(shape[i + 1:])``.  Terms are wrapped from
     the innermost position outwards, so the first loop variable ends up in
     the outermost ``Add``.
     """
     index_expr = C.SymbolRef(loopvars[-1])
     position = len(loopvars) - 2
     while position >= 0:
         scaled = C.Mul(C.SymbolRef(loopvars[position]),
                        C.Constant(np.prod(shape[position + 1:])))
         index_expr = C.Add(scaled, index_expr)
         position -= 1
     return index_expr
示例#7
0
 def test_no_folding(self):
     """Binary operations on two symbols come back from the folder equal."""
     operators = (C.Add, C.Sub, C.Mul, C.Div)
     expressions = [
         make_op(C.SymbolRef("a"), C.SymbolRef("b")) for make_op in operators
     ]
     for expression in expressions:
         folded = ConstantFold().visit(expression)
         self.assertEqual(expression, folded)
示例#8
0
    def visit_For(self, node):
        """Insert prefetch calls after matching vector loads of one loop.

        The loop body is visited and flattened first.  If the loop variable
        is not ``self.enclosing_loop_var`` the node is returned as is.

        Otherwise, for up to ``self.prefetch_count`` statements of the form
        ``x = <fn>(...)`` where the function name contains ``"_mm"`` and one
        of ``"_load_"``, ``"_set1"`` or ``"_broadcast"``, and whose first
        argument satisfies ``check_name(arg, self.prefetch_field)``:

        * a copy of that argument is rewritten with ``self.rewrite_arg``;
        * a call to ``prefetch_symbol_table[self.cacheline_hint]`` on
          ``rewritten + prefetch_offset_var`` is emitted, followed by
          ``prefetch_offset_var = prefetch_offset_var + self.prefetch_offset``.

        The two statements go right after the load when this loop is
        ``self.prefetch_dest_loop``; otherwise they are appended to the
        class-level list ``HoistPrefetch.escape_body``.  When at least one
        prefetch was emitted, an ``int prefetch_offset_var = 0`` declaration
        is appended to ``InitPrefetcher.init_body``.
        """

        def is_vector_load(stmt):
            # An assignment whose right-hand side calls an "_mm" intrinsic
            # with a load/set1/broadcast flavoured name.
            if not (isinstance(stmt, C.BinaryOp)
                    and isinstance(stmt.op, C.Op.Assign)):
                return False
            if not isinstance(stmt.right, C.FunctionCall):
                return False
            func_name = stmt.right.func.name
            return "_mm" in func_name and (
                "_load_" in func_name or "_set1" in func_name
                or "_broadcast" in func_name)

        node.body = util.flatten([self.visit(s) for s in node.body])
        if node.init.left.name != self.enclosing_loop_var:
            return node

        new_body = []
        added_code = False
        prefetch_count = self.prefetch_count
        for stmt in node.body:
            new_body.append(stmt)
            if not (prefetch_count > 0 and is_vector_load(stmt)):
                continue
            if not check_name(stmt.right.args[0], self.prefetch_field):
                continue

            # Work on a copy so the load itself keeps its original address.
            new_array_ref = self.rewrite_arg(deepcopy(stmt.right.args[0]))
            prefetch_count -= 1

            where_to_add = new_body
            if node.init.left.name != self.prefetch_dest_loop:
                where_to_add = HoistPrefetch.escape_body
            added_code = True

            where_to_add.append(
                C.FunctionCall(
                    C.SymbolRef(prefetch_symbol_table[self.cacheline_hint]),
                    [
                        C.Add(new_array_ref,
                              C.SymbolRef("prefetch_offset_var"))
                    ]))
            where_to_add.append(
                C.Assign(
                    C.SymbolRef("prefetch_offset_var"),
                    C.Add(C.SymbolRef("prefetch_offset_var"),
                          C.Constant(self.prefetch_offset))))

        if added_code:
            InitPrefetcher.init_body.append(
                C.Assign(
                    C.SymbolRef("prefetch_offset_var", ctypes.c_int()),
                    C.Constant(0)))
        node.body = new_body
        return node
示例#9
0
 def block_loop(self, node):
     """Split ``node`` into a blocked outer loop and a bounded inner loop.

     A new ``For`` whose variable name is the original name written twice
     is inserted at the front of ``self.nest``.  It runs over the original
     bounds in steps of ``self.block_factor``; its body is the placeholder
     ``[None]``.  ``node`` is then modified in place to start at the outer
     variable and stop at ``fmin(outer + block_factor, original_bound)``.
     """
     inner_symbol = node.init.left
     outer_name = inner_symbol.name
     outer_name += outer_name

     # Capture the original bounds before node is rewritten below.
     lower_bound = node.init.right
     upper_bound = node.test.right

     outer_init = C.Assign(C.SymbolRef(outer_name, inner_symbol.type),
                           lower_bound)
     outer_test = C.Lt(C.SymbolRef(outer_name), upper_bound)
     outer_step = C.AddAssign(C.SymbolRef(outer_name),
                              C.Constant(self.block_factor))
     self.nest.insert(0, C.For(outer_init, outer_test, outer_step, [None]))

     node.init.right = C.SymbolRef(outer_name)
     block_end = C.Add(C.SymbolRef(outer_name),
                       C.Constant(self.block_factor))
     node.test.right = C.FunctionCall(C.SymbolRef("fmin"),
                                      [block_end, upper_bound])
示例#10
0
 def rewrite_arg(self, arg):
     """Add ``self.prefetch_constant`` to one subscript of an array reference.

     ``arg`` must be either an array-reference expression or the address-of
     (``Ref``) of one.  It is modified in place.

     Starting at the outermost subscript, the chain of ``.left`` links is
     followed while ``idx + 1 != 0`` with ``idx`` starting at ``self.dim``
     and increasing by one per step, i.e. ``-1 - self.dim`` steps for a
     negative ``self.dim``.  The subscript of the node reached becomes
     ``old_subscript + self.prefetch_constant``.

     Returns:
         ``C.Ref(arg)`` when ``arg`` was a bare array reference, otherwise
         ``arg`` itself (which already is an address-of expression).

     Raises:
         TypeError: if ``arg`` is neither of the two supported shapes.
     """
     is_array_ref = (isinstance(arg, C.BinaryOp)
                     and isinstance(arg.op, C.Op.ArrayRef))
     if isinstance(arg, C.UnaryOp) and isinstance(
             arg.op, C.Op.Ref) and isinstance(
                 arg.arg, C.BinaryOp) and isinstance(
                     arg.arg.op, C.Op.ArrayRef):
         curr_node = arg.arg
     elif is_array_ref:
         curr_node = arg
     else:
         raise TypeError(
             "rewrite_arg expects an array reference or a Ref of one, "
             "got {!r}".format(arg))

     idx = self.dim
     while idx + 1 != 0:
         curr_node = curr_node.left
         idx += 1

     curr_node.right = C.Add(curr_node.right,
                             C.Constant(self.prefetch_constant))

     if is_array_ref:
         return C.Ref(arg)
     return arg
示例#11
0
 def test_add_constants(self):
     """Two constant operands of an addition fold into their sum."""
     folded = ConstantFold().visit(C.Add(C.Constant(20), C.Constant(10)))
     self.assertEqual(folded, C.Constant(30))
示例#12
0
    def visit_For(self, node):
        """Collect per-offset copies of a loop body without jamming them.

        Works through the class-level mapping
        ``UnrollStatementsNoJam.new_body`` (offset -> list of statements),
        whose entries for offsets ``1 .. self.factor - 1`` are reset to
        empty lists on entry.

        Every statement of ``node.body`` is visited and the result kept in
        the loop.  In addition:

        * for a non-``For`` result, a copy for each offset ``i`` in
          ``1 .. factor - 1`` is appended to ``new_body[i]`` with
          ``self.target_var`` replaced by ``target_var + i``
          (``unroll_type == 0``) or ``factor * target_var + i``
          (``unroll_type == 1``);
        * for a ``For`` result, the statements collected for each offset
          while visiting it are wrapped in a new loop over the same
          variable (start 0, same bound and step values) and appended to
          the per-offset lists as they were before that statement was
          visited.

        Raises:
            ValueError: if a statement has to be copied and
                ``self.unroll_type`` is neither 0 nor 1.
        """

        def shifted_index(offset):
            # Replacement expression for target_var in the copy at `offset`.
            if self.unroll_type == 0:
                return C.Add(C.SymbolRef(self.target_var),
                             C.Constant(offset))
            if self.unroll_type == 1:
                return C.Add(
                    C.Mul(C.Constant(self.factor),
                          C.SymbolRef(self.target_var)),
                    C.Constant(offset))
            raise ValueError(
                "unsupported unroll_type: {!r}".format(self.unroll_type))

        for j in range(1, self.factor):
            UnrollStatementsNoJam.new_body[j] = []

        newbody = []
        for s in node.body:
            # Snapshot taken before visiting: a nested For resets and refills
            # new_body, and its output is merged back into this snapshot.
            temp = deepcopy(UnrollStatementsNoJam.new_body)

            t = self.visit(s)
            # NOTE(review): the visited statement is kept without any
            # target_var rewrite.  For unroll_type 1 that leaves the offset-0
            # copy indexed by target_var rather than factor * target_var --
            # confirm this is intended.  A discarded offset-0 rewrite of a
            # throwaway copy used to be computed here; it was dropped on the
            # assumption that util.replace_symbol only touches the tree it is
            # given.
            newbody.append(t)

            if not isinstance(t, C.For):
                for i in range(1, self.factor):
                    shifted = util.replace_symbol(
                        self.target_var, shifted_index(i), deepcopy(t))
                    UnrollStatementsNoJam.new_body.setdefault(
                        i, []).append(shifted)
            else:
                var = t.init.left.name
                for j in range(1, self.factor):
                    temp[j].append(
                        C.For(
                            C.Assign(C.SymbolRef(var, ctypes.c_int()),
                                     C.Constant(0)),
                            C.Lt(C.SymbolRef(var),
                                 C.Constant(t.test.right.value)),
                            C.AddAssign(C.SymbolRef(var),
                                        C.Constant(t.incr.value.value)),
                            UnrollStatementsNoJam.new_body[j]))

                UnrollStatementsNoJam.new_body = deepcopy(temp)

        node.body = newbody
        return node
示例#13
0
    def visit_AugAssign(self, node):
        """Unroll an augmented assignment that depends on the unrolled loop.

        If ``node.value`` references neither ``self.target_var`` nor any name
        in ``self.unrolled_vars`` the node is returned unchanged.

        For a plain symbol target, ``self._get_name(target.name)`` is added
        to ``self.unrolled_vars`` and a list of ``self.factor`` copies is
        returned.  In copy ``i`` every name in ``self.unrolled_vars`` is
        renamed to ``<name>_<i>`` and ``self.target_var`` is replaced by
        ``target_var + i`` (``unroll_type == 0``) or
        ``factor * target_var + i`` (``unroll_type == 1``).

        Raises:
            AssertionError: for an array-reference target, which is disabled
                by an ``assert``.
            NotImplementedError: for any other kind of target.
            ValueError: if ``self.unroll_type`` is neither 0 nor 1.
        """
        tracked = list(self.unrolled_vars) + [self.target_var]
        if not any(util.contains_symbol(node.value, var) for var in tracked):
            return node

        target = node.target
        if isinstance(target, C.SymbolRef):
            self.unrolled_vars.add(self._get_name(target.name))
        elif isinstance(target, C.BinaryOp) and isinstance(
                target.op, C.Op.ArrayRef):
            # Array-reference targets are switched off with an assert.  When
            # assertions are disabled (python -O) execution continues into
            # the same unrolling loop as for symbol targets, but without
            # registering a new unrolled variable.
            assert False
        else:
            raise NotImplementedError()

        body = []
        for i in range(self.factor):
            stmt = deepcopy(node)
            for var in self.unrolled_vars:
                stmt = util.replace_symbol(
                    var, C.SymbolRef(var + "_" + str(i)), stmt)

            if self.unroll_type == 0:
                index = C.Add(C.SymbolRef(self.target_var), C.Constant(i))
            elif self.unroll_type == 1:
                index = C.Add(
                    C.Mul(C.Constant(self.factor),
                          C.SymbolRef(self.target_var)),
                    C.Constant(i))
            else:
                raise ValueError(
                    "unsupported unroll_type: {!r}".format(self.unroll_type))

            body.append(util.replace_symbol(self.target_var, index, stmt))
        return body
示例#14
0
    def visit_RangeDim(self, node):
        """Lower a ``range_dim(...)`` loop into a C ``For`` loop.

        The dimension is read from the second argument of the ``range_dim``
        call, the step from ``node.mapping.get_step(dim)`` and the trip count
        from ``len(node.mapping.shape[dim])``.  The statements of
        ``node.child_for.body`` are visited first.

        Three shapes of output are produced:

        * Tiled dimension (see ``_range_dim_is_tiled``), except for an
          ``LRNEnsemble`` with fewer than ``SIMDWIDTH`` iterations: an empty
          ``<var>_outer`` loop over ``length // SIMDWIDTH`` is appended to
          ``self.tiled_loops`` and a ``<var>_inner`` loop holding the body is
          returned.  The inner bound is ``length`` in the forward direction
          when ``length < SIMDWIDTH``, otherwise ``SIMDWIDTH``.
        * Tiled dimension of a short ``LRNEnsemble``: the body's
          ``<var>_outer`` / ``<var>_inner`` references are rewritten in terms
          of the single loop variable and a plain loop over ``length`` is
          returned.
        * Anything else: the loop variable is scaled by ``step`` inside the
          body and a plain loop over ``length`` is returned.

        Raises:
            NotImplementedError: if the loop iterator is not a call to
                ``range_dim``.
        """
        loop_iter = node.child_for.iter
        # Validate the iterator before reading its arguments so unsupported
        # loops raise NotImplementedError rather than AttributeError.
        if not (isinstance(loop_iter, ast.Call)
                and getattr(loop_iter.func, "id", None) == "range_dim"):
            raise NotImplementedError()

        dim = loop_iter.args[1].n
        step = node.mapping.get_step(dim)
        length = len(node.mapping.shape[dim])
        loop_var = node.child_for.target.id
        simd_width = latte.config.SIMDWIDTH

        body = [self.visit(s) for s in node.child_for.body]

        # FIXME: This check does not cover general cases
        # ANAND - special casing for LRN, needs refactoring
        short_lrn = isinstance(
            self.ensemble,
            latte.ensemble.LRNEnsemble) and length < simd_width
        tiled = self._range_dim_is_tiled(dim)

        if tiled and not short_lrn:
            self.tiled_loops.append(
                self._range_dim_loop(loop_var + "_outer",
                                     length // simd_width, []))
            if self.direction == "forward" and length < simd_width:
                inner_bound = length
            else:
                inner_bound = simd_width
            return self._range_dim_loop(loop_var + "_inner", inner_bound,
                                        body)

        if tiled:
            # Short LRN loop: express the outer/inner tile indices through
            # the single loop variable.  The three passes run in this order.
            offset_name = "_input_offset_{}_inner".format(dim + 1)
            body = [
                UpdateInputIndices(
                    loop_var + "_outer",
                    C.Div(
                        C.Add(C.SymbolRef(loop_var),
                              C.SymbolRef(offset_name)),
                        C.Constant(simd_width))).visit(s) for s in body
            ]
            body = [
                UpdateInputIndices(offset_name, C.Constant(0)).visit(s)
                for s in body
            ]
            body = [
                UpdateInputIndices(
                    loop_var + "_inner",
                    C.Mod(
                        C.Add(C.SymbolRef(loop_var),
                              C.SymbolRef(offset_name)),
                        C.Constant(simd_width))).visit(s) for s in body
            ]
        else:
            body = [
                UpdateInputIndices(
                    loop_var,
                    C.Mul(C.SymbolRef(loop_var),
                          C.Constant(step))).visit(s) for s in body
            ]
        return self._range_dim_loop(loop_var, length, body)

    def _range_dim_is_tiled(self, dim):
        """Return whether ``dim`` is listed in the ensemble's tiling info for
        the current direction ("inputs" forward, "grad_inputs" otherwise)."""
        tiling_info = self.ensemble.tiling_info
        return (
            self.direction == "forward" and "inputs" in tiling_info
            and any(dim == x[0] for x in tiling_info["inputs"])
        ) or (self.direction in ["backward", "update_internal"]
              and "grad_inputs" in tiling_info
              and any(dim == x[0] for x in tiling_info["grad_inputs"]))

    def _range_dim_loop(self, var_name, bound, body):
        """Build ``for (int var_name = 0; var_name < bound; var_name += 1)``
        around ``body``."""
        return C.For(
            C.Assign(C.SymbolRef(var_name, ctypes.c_int()), C.Constant(0)),
            C.Lt(C.SymbolRef(var_name), C.Constant(bound)),
            C.AddAssign(C.SymbolRef(var_name), C.Constant(1)),
            body,
        )