Exemple #1
0
    def visit_For(self, node):
        """Rewrite the statements of a loop body, then visit them.

        * Assignments to a symbol named ``in_*`` or ``_input_*`` whose
          right-hand side is not a function call are kept and recorded in
          ``du_map`` (target name -> definition).  If the right-hand side is
          a symbol already present in ``du_map``, that symbol's recorded
          definition is stored instead.
        * SIMD loads (assignments from an ``_mm*`` call whose name contains
          ``_load_``, ``_set1`` or ``_broadcast``) and SIMD stores (``_mm*``
          calls containing ``_store``) are passed through
          ``ReplaceSymbolRef``.
        * Every other statement is kept unchanged.

        The resulting statements are visited recursively and flattened.
        ``du_map`` is not defined in this method; presumably it is shared
        module-level state — TODO confirm.
        """
        node.body = util.flatten(list(node.body))

        def is_symbol_assign(stmt):
            return (isinstance(stmt, C.BinaryOp)
                    and isinstance(stmt.op, C.Op.Assign)
                    and isinstance(stmt.left, C.SymbolRef))

        def is_input_alias(stmt):
            return (is_symbol_assign(stmt)
                    and stmt.left.name.startswith(("in_", "_input_"))
                    and not isinstance(stmt.right, C.FunctionCall))

        def is_simd_load(stmt):
            if not (is_symbol_assign(stmt)
                    and isinstance(stmt.right, C.FunctionCall)):
                return False
            callee = stmt.right.func.name
            return "_mm" in callee and any(
                tag in callee for tag in ("_load_", "_set1", "_broadcast"))

        def is_simd_store(stmt):
            return (isinstance(stmt, C.FunctionCall)
                    and "_mm" in stmt.func.name
                    and "_store" in stmt.func.name)

        rewritten = []
        for stmt in node.body:
            if is_input_alias(stmt):
                rewritten.append(stmt)
                rhs = stmt.right
                # Follow one level of aliasing through an already-known symbol.
                if isinstance(rhs, C.SymbolRef) and rhs.name in du_map:
                    du_map[stmt.left.name] = du_map[rhs.name]
                else:
                    du_map[stmt.left.name] = rhs
                continue
            if is_simd_load(stmt) or is_simd_store(stmt):
                stmt = ReplaceSymbolRef().visit(stmt)
            rewritten.append(stmt)

        node.body = util.flatten([self.visit(s) for s in rewritten])
        return node
Exemple #2
0
    def visit_For(self, node):
        """Forward values stored through aliased arrays to later SIMD loads.

        The (flattened) loop body is scanned once, in order:

        * ``_mm*_store*(dst, value)`` where
          ``inReplaceMapSource(dst, self.replace_map)`` holds: the store is
          recorded with ``store_in_du_map`` and the register holding
          ``value`` is added to ``self.seen``.  When ``value`` is a call
          rather than a plain register, it is first assigned to a fresh
          register (``self._gen_register()``, type ``get_simd_type()()``) and
          the store is re-emitted from that register.  Stores of any other
          value form are kept unchanged.
        * ``x = _mm*_load*(src, ...)`` where
          ``inReplaceMapSink(src, self.replace_map)`` holds: if
          ``get_alias`` finds a source, the reference built from it is in the
          du-map, and its register name is in ``self.seen``, the load is
          dropped and ``sym_map[x]`` is set to that register.  Otherwise the
          load is kept.
        * All other statements are kept.

        The surviving statements are then visited recursively.
        """
        node.body = util.flatten(list(node.body))
        new_body = []
        for stmt in node.body:
            if isinstance(stmt, C.FunctionCall) and "_mm" in stmt.func.name \
                    and "_store" in stmt.func.name \
                    and inReplaceMapSource(stmt.args[0], self.replace_map):
                value = stmt.args[1]
                if isinstance(value, C.SymbolRef):
                    store_in_du_map(extract_reference(stmt.args))
                    self.seen[value.name] = None
                    new_body.append(stmt)
                elif isinstance(value, C.FunctionCall):
                    # Materialise the computed value in a fresh register so the
                    # recorded store always reads from a named register.
                    # NOTE(review): any call is hoisted here, not only ``_mm``
                    # intrinsics -- confirm that is intended.
                    tmp = self._gen_register()
                    new_body.append(
                        C.Assign(C.SymbolRef(tmp, get_simd_type()()),
                                 deepcopy(value)))
                    new_body.append(
                        C.FunctionCall(C.SymbolRef(stmt.func.name),
                                       [stmt.args[0], C.SymbolRef(tmp, None)]))
                    # The du-map record is built from its own (identical) call
                    # node rather than the statement emitted above.
                    recorded = C.FunctionCall(
                        C.SymbolRef(stmt.func.name),
                        [stmt.args[0], C.SymbolRef(tmp, None)])
                    store_in_du_map(extract_reference(recorded.args))
                    self.seen[tmp] = None
                else:
                    # Any other value form: keep the store as-is.  Omitting
                    # this branch would delete the store from the output.
                    new_body.append(stmt)

            elif isinstance(stmt, C.BinaryOp) and \
                    isinstance(stmt.op, C.Op.Assign) and \
                    isinstance(stmt.left, C.SymbolRef) and \
                    isinstance(stmt.right, C.FunctionCall) and \
                    "_mm" in stmt.right.func.name and \
                    "_load" in stmt.right.func.name and \
                    inReplaceMapSink(stmt.right.args[0], self.replace_map):
                forwarded = False
                source = get_alias(stmt.right.args, self.replace_map)
                if source is not None:
                    sym_arr_ref = construct_arr_reference(
                        source, deepcopy(stmt.right.args))
                    if in_du_map(sym_arr_ref):
                        reg = get_register(sym_arr_ref)
                        if str(reg.name) in self.seen:
                            # The value already lives in a register: drop the
                            # load and alias its destination to that register.
                            sym_map[stmt.left.name] = reg
                            forwarded = True
                if not forwarded:
                    new_body.append(stmt)

            else:
                new_body.append(stmt)
        node.body = util.flatten([self.visit(s) for s in new_body])
        return node
Exemple #3
0
    def visit_For(self, node):
        """Replace a loop that provably runs exactly once by its body.

        After visiting the body, the loop header is matched against
        ``for (i = a; i < b; i += c)`` where ``a``, ``b`` and ``c`` are
        ``C.Constant`` values and the same symbol ``i`` appears in all three
        clauses.  If the loop runs exactly once (``a < b <= a + c``), a list
        of the body statements with ``i`` replaced by ``a`` is returned.
        Otherwise the loop node is returned unchanged.

        NOTE(review): the body is assumed not to modify ``i``; this is not
        checked.
        """
        node.body = util.flatten([self.visit(s) for s in node.body])

        # ``None`` marks an unrecognised header clause.  Unlike a numeric
        # sentinel such as -1, it cannot collide with a real constant.
        var = init = test = incr = None

        if isinstance(node.init, C.BinaryOp) and \
           isinstance(node.init.op, C.Op.Assign) and \
           isinstance(node.init.left, C.SymbolRef) and \
           isinstance(node.init.right, C.Constant):
            var = node.init.left.name
            init = node.init.right.value

        if isinstance(node.test, C.BinaryOp) and \
           isinstance(node.test.op, C.Op.Lt) and \
           isinstance(node.test.left, C.SymbolRef) and \
           node.test.left.name == var and \
           isinstance(node.test.right, C.Constant):
            test = node.test.right.value

        if isinstance(node.incr, C.AugAssign) and \
           isinstance(node.incr.op, C.Op.Add) and \
           isinstance(node.incr.target, C.SymbolRef) and \
           node.incr.target.name == var and \
           isinstance(node.incr.value, C.Constant):
            incr = node.incr.value.value

        if init is None or test is None or incr is None:
            return node

        # Exactly one iteration: the first test passes (init < test) and the
        # test after one increment fails (init + incr >= test).  A loop with
        # init >= test runs zero times and must not be inlined.
        if init < test <= init + incr:
            return [util.replace_symbol(var, C.Constant(init), s)
                    for s in node.body]

        return node
Exemple #4
0
    def visit_For(self, node):
        """Move qualifying ``_load`` assignments to the front of the loop body.

        After visiting the body, each assignment whose right-hand side is a
        call with ``_load`` in its name, and every argument of which
        satisfies ``only_contains_symbol(arg, <loop index>, self.fuse_map)``,
        is moved ahead of the remaining statements.  Relative order is kept
        within both groups, and all statements stay inside the loop.

        NOTE(review): no dependence check is made against statements that
        precede a moved load (e.g. stores to the same location) -- confirm
        this is safe for the generated code.
        """
        node.body = util.flatten([self.visit(s) for s in node.body])
        loop_var = node.init.left.name

        def is_movable_load(stmt):
            return (isinstance(stmt, C.BinaryOp)
                    and isinstance(stmt.op, C.Op.Assign)
                    and isinstance(stmt.right, C.FunctionCall)
                    and "_load" in stmt.right.func.name
                    and all(only_contains_symbol(arg, loop_var, self.fuse_map)
                            for arg in stmt.right.args))

        pre_stmts = []
        new_body = []
        for stmt in node.body:
            if is_movable_load(stmt):
                pre_stmts.append(stmt)
            else:
                new_body.append(stmt)

        node.body = pre_stmts + new_body
        return node
Exemple #5
0
 def visit(self, node):
     """
     Visit ``node`` with the parent visitor, then flatten its ``body``.

     This lets a visitor replace one statement with a list of statements:
     the ``body`` of the visited node is passed through ``util.flatten``.
     """
     visited = super().visit(node)
     if not hasattr(visited, 'body'):
         return visited
     visited.body = util.flatten(visited.body)
     return visited
Exemple #6
0
 def visit_For(self, node):
     """Unroll the loop whose index is ``self.target_var`` by ``self.factor``.

     Child statements are visited first.  For the target loop the increment
     becomes ``target_var += factor`` and every body statement is expanded
     by ``UnrollStatements``.  When the loop's bound
     (``node.test.right.value``) equals the factor, a list of the unrolled
     body statements with the index replaced by ``0`` is returned instead of
     the loop; otherwise the loop node is returned.
     """
     node.body = [self.visit(stmt) for stmt in node.body]
     if node.init.left.name != self.target_var:
         return node

     node.incr = C.AddAssign(C.SymbolRef(self.target_var),
                             C.Constant(self.factor))
     unroller = UnrollStatements(self.target_var, self.factor)
     node.body = util.flatten([unroller.visit(stmt) for stmt in node.body])

     if node.test.right.value == self.factor:
         # NOTE(review): substituting 0 assumes the loop starts at 0, so a
         # bound equal to the factor leaves one unrolled iteration -- confirm.
         return [
             util.replace_symbol(node.init.left.name, C.Constant(0), stmt)
             for stmt in node.body
         ]
     return node
Exemple #7
0
    def visit(self, node):
        """Visit ``node``; if it has a ``body``, broadcast scalar definitions.

        For each statement ``s`` in the visited node's body:

        * If ``s`` is an assignment to a ``SymbolRef`` with a non-None
          ``type`` whose name is in ``self.variables`` and not yet in
          ``self.defs``: ``s`` is kept, and right after it a new assignment
          ``y = broadcast_ss(<name>, <type>)`` into a fresh register ``y``
          (declared with ``get_simd_type(<type>)()``) is emitted.  ``y`` is
          recorded in ``self.defs`` and ``self.symbol_table``.
        * Any other assignment first has every name in ``self.defs`` replaced
          by its register.  Then, if its left-hand side has a
          ``get_simd_type(c_int)``/``get_simd_type(c_float)`` type (read from
          ``s.left.type`` or, for a ``SymbolRef``, from
          ``self.symbol_table``) and its right-hand side is a plain
          ``SymbolRef``, the right-hand side is wrapped in ``broadcast_ss``.
        * Non-assignment statements only get the ``self.defs`` substitution.

        The rewritten body is flattened and the node is returned.
        """
        node = super().visit(node)
        if hasattr(node, 'body'):
            # [collector.visit(s) for s in node.body]
            newbody = []
            for s in node.body:
                if isinstance(s, C.BinaryOp) and isinstance(s.op, C.Op.Assign):
                    # Anand - needs more work 27th June 2017
                    if isinstance(s.left, C.SymbolRef) and (s.left.type is not None) and s.left.name in self.variables \
                         and s.left.name not in self.defs:
                        # First definition of a tracked variable: keep it and
                        # add a broadcast copy into a fresh register.
                        y = self._gen_register()

                        new_stmt = C.Assign(
                            C.SymbolRef(y,
                                        get_simd_type(s.left.type)()),
                            broadcast_ss(C.SymbolRef(s.left.name, None),
                                         s.left.type))
                        newbody.append(s)
                        newbody.append(new_stmt)
                        # Later statements refer to the register instead of
                        # the scalar (see the replace_symbol loops below).
                        self.defs[s.left.name] = C.SymbolRef(y, None)
                        self.symbol_table[y] = get_simd_type(s.left.type)()
                    else:
                        for i in self.defs:
                            s = replace_symbol(i, self.defs[i], s)

                        # NOTE(review): ``s.left.type`` is read here without
                        # checking that ``s.left`` is a SymbolRef; confirm
                        # every assignment target has a ``type`` attribute.
                        if (isinstance(s.left.type,
                                       get_simd_type(ctypes.c_int()))
                                or isinstance(
                                    s.left.type, get_simd_type(
                                        ctypes.c_float()))) and isinstance(
                                            s.right, C.SymbolRef):
                            s.right = broadcast_ss(
                                C.SymbolRef(s.right.name, None), s.left.type)

                        # Same broadcast, with the left-hand type taken from
                        # the symbol table instead of the node.
                        elif isinstance(s.left, C.SymbolRef) and s.left.name in self.symbol_table and\
                             (isinstance(self.symbol_table[s.left.name], get_simd_type(ctypes.c_int())) or isinstance(self.symbol_table[s.left.name], get_simd_type(ctypes.c_float()))) and isinstance(s.right, C.SymbolRef):
                            s.right = broadcast_ss(
                                C.SymbolRef(s.right.name, None),
                                self.symbol_table[s.left.name])

                        newbody.append(s)

                else:

                    # Non-assignments: only substitute known registers.
                    for i in self.defs:
                        s = replace_symbol(i, self.defs[i], s)

                    newbody.append(s)
            node.body = util.flatten(newbody)
        return node
Exemple #8
0
 def visit(self, node):
     """Cache SIMD loads in registers and sink cached stores after the body.

     After the parent visitor runs, for a node with a ``body``:

     * A single ``VectorLoadCollector`` is shared across the statements.
       Each key in ``collector.loads`` not seen before gets a fresh register,
       initialised with ``func(load_node)`` just before the current
       statement and recorded in ``self.sym``.
     * A ``_mm*_store*`` whose target's ``codegen()`` matches a cached key is
       rewritten into an assignment to that key's register, and the real
       store is re-emitted after the body (``store_epi32`` / ``store_ps``).
     * Every statement then has each cached load expression replaced by its
       register via ``VectorLoadReplacer``.

     Raises:
         NotImplementedError: if a sunk store intrinsic name contains neither
             ``epi32`` nor ``ps``.
     """
     node = super().visit(node)
     if not hasattr(node, 'body'):
         return node

     new_body = []
     seen = {}    # load key -> (register name, load argument, load function)
     stores = []  # (store target, register name, store intrinsic name)
     collector = VectorLoadCollector()
     for s in node.body:
         collector.visit(s)
         for key in collector.loads.keys():
             if key not in seen:
                 reg = self._gen_register()
                 load_node, _, func = collector.loads[key]
                 seen[key] = (reg, load_node, func)
                 self.sym[reg] = get_simd_type()()
                 new_body.append(
                     C.Assign(
                         C.SymbolRef(reg, get_simd_type()()),
                         C.FunctionCall(C.SymbolRef(func), [load_node])))

         if isinstance(s, C.FunctionCall) and "_mm" in s.func.name \
                 and "_store" in s.func.name:
             target_key = s.args[0].codegen()
             if target_key in seen:
                 cached_reg = seen[target_key][0]
                 # Write into the cached register now; the memory store is
                 # re-emitted after the body.
                 stores.append((s.args[0], cached_reg, s.func.name))
                 s = C.Assign(C.SymbolRef(cached_reg), s.args[1])

         for reg, load_node, func in seen.values():
             replacer = VectorLoadReplacer(
                 C.FunctionCall(C.SymbolRef(func), [load_node]).codegen(),
                 C.SymbolRef(reg))
             s = replacer.visit(s)
         new_body.append(s)

     for target, value, name in stores:
         if "epi32" in name:
             new_body.append(store_epi32(target, C.SymbolRef(value)))
         elif "ps" in name:
             new_body.append(store_ps(target, C.SymbolRef(value)))
         else:
             # Explicit raise: an ``assert`` would vanish under ``python -O``
             # and silently drop the store.
             raise NotImplementedError(
                 "unsupported store intrinsic: {}".format(name))
     node.body = util.flatten(new_body)
     return node
Exemple #9
0
 def visit_For(self, node):
     """Unroll the loop whose index is ``self.target_var`` by ``self.factor``.

     ``self.unroll_type`` selects how the loop header is adjusted:

     * ``0``: the increment becomes ``target_var += factor``;
     * ``1``: the bound ``node.test.right.value`` is divided by the factor
       and must be divisible by it.

     The body is then expanded by ``UnrollStatements`` and flattened.

     Raises:
         ValueError: for an unknown ``unroll_type``, or when
             ``unroll_type == 1`` and the bound is not divisible by the
             factor.
     """
     node.body = [self.visit(s) for s in node.body]
     if node.init.left.name == self.target_var:
         if self.unroll_type == 0:
             node.incr = C.AddAssign(C.SymbolRef(self.target_var),
                                     C.Constant(self.factor))
         elif self.unroll_type == 1:
             bound = node.test.right.value
             # Explicit check rather than ``assert`` so it also holds under
             # ``python -O`` (otherwise iterations would be silently lost).
             if bound % self.factor != 0:
                 raise ValueError(
                     "loop bound {} is not divisible by unroll factor {}"
                     .format(bound, self.factor))
             node.test.right.value = bound // self.factor
         else:
             raise ValueError(
                 "unsupported unroll_type: {!r}".format(self.unroll_type))
         visitor = UnrollStatements(self.target_var, self.factor,
                                    self.unroll_type)
         node.body = util.flatten([visitor.visit(s) for s in node.body])
     return node
Exemple #10
0
    def visit_For(self, node):
        """Insert software prefetches after matching SIMD loads.

        Only the loop whose index is ``self.enclosing_loop_var`` is
        rewritten.  For each of the first ``self.prefetch_count`` statements
        of the form ``x = _mm*(...)`` (callee containing ``_load_``,
        ``_set1`` or ``_broadcast``) whose first argument ``ref`` satisfies
        ``check_name(ref, self.prefetch_field)``, two statements are emitted:

        * a call to ``prefetch_symbol_table[self.cacheline_hint]`` on
          ``self.rewrite_arg(<copy of ref>) + prefetch_offset_var``;
        * ``prefetch_offset_var = prefetch_offset_var + self.prefetch_offset``.

        They go directly after the load when this loop is
        ``self.prefetch_dest_loop``, and are appended to
        ``HoistPrefetch.escape_body`` otherwise.  If anything was emitted, an
        ``int`` assignment ``prefetch_offset_var = 0`` is appended to
        ``InitPrefetcher.init_body``.
        """
        node.body = util.flatten([self.visit(s) for s in node.body])
        if node.init.left.name != self.enclosing_loop_var:
            return node

        def is_simd_load(stmt):
            if not (isinstance(stmt, C.BinaryOp)
                    and isinstance(stmt.op, C.Op.Assign)
                    and isinstance(stmt.right, C.FunctionCall)):
                return False
            name = stmt.right.func.name
            return "_mm" in name and any(
                tag in name for tag in ("_load_", "_set1", "_broadcast"))

        # Loop-invariant: unless this loop is the destination loop, the
        # prefetch code is collected in HoistPrefetch.escape_body instead.
        hoist = node.init.left.name != self.prefetch_dest_loop
        new_body = []
        added_code = False
        remaining = self.prefetch_count
        for stmt in node.body:
            new_body.append(stmt)
            if not (remaining > 0 and is_simd_load(stmt)
                    and check_name(stmt.right.args[0], self.prefetch_field)):
                continue
            prefetch_ref = self.rewrite_arg(deepcopy(stmt.right.args[0]))
            remaining -= 1
            where_to_add = HoistPrefetch.escape_body if hoist else new_body
            added_code = True
            where_to_add.append(
                C.FunctionCall(
                    C.SymbolRef(prefetch_symbol_table[self.cacheline_hint]),
                    [C.Add(prefetch_ref, C.SymbolRef("prefetch_offset_var"))]))
            where_to_add.append(
                C.Assign(
                    C.SymbolRef("prefetch_offset_var"),
                    C.Add(C.SymbolRef("prefetch_offset_var"),
                          C.Constant(self.prefetch_offset))))

        if added_code:
            InitPrefetcher.init_body.append(
                C.Assign(
                    C.SymbolRef("prefetch_offset_var", ctypes.c_int()),
                    C.Constant(0)))
        node.body = new_body
        return node
Exemple #11
0
 def visit_For(self, node):
     """Insert software prefetches after matching SIMD loads.

     Only the loop whose index is ``self.enclosing_loop_var`` is rewritten.
     After each of the first ``self.prefetch_count`` statements of the form
     ``x = _mm*(...)`` (callee containing ``_load_``, ``_set1`` or
     ``_broadcast``) whose first argument ``ref`` satisfies
     ``check_name(ref, self.prefetch_field)``, a call to
     ``prefetch_symbol_table[self.cacheline_hint]`` on
     ``self.rewrite_arg(<copy of ref>)`` is inserted.
     """
     node.body = util.flatten([self.visit(s) for s in node.body])
     if node.init.left.name != self.enclosing_loop_var:
         return node

     def is_simd_load(stmt):
         if not (isinstance(stmt, C.BinaryOp)
                 and isinstance(stmt.op, C.Op.Assign)
                 and isinstance(stmt.right, C.FunctionCall)):
             return False
         name = stmt.right.func.name
         return "_mm" in name and any(
             tag in name for tag in ("_load_", "_set1", "_broadcast"))

     remaining = self.prefetch_count
     new_body = []
     for stmt in node.body:
         new_body.append(stmt)
         if remaining > 0 and is_simd_load(stmt) and \
                 check_name(stmt.right.args[0], self.prefetch_field):
             prefetch_ref = self.rewrite_arg(deepcopy(stmt.right.args[0]))
             remaining -= 1
             new_body.append(
                 C.FunctionCall(
                     C.SymbolRef(prefetch_symbol_table[self.cacheline_hint]),
                     [prefetch_ref]))
     node.body = new_body
     return node
Exemple #12
0
 def visit_For(self, node):
     """
     Converts iteration expressions into RangeDim semantic nodes.

     * ``for t in eachindex(x)`` is expanded into a nest of
       ``for v_d in range_dim(x, d)`` loops, one per entry of
       ``self.connections[0].mapping.shape``.  Inside the innermost loop the
       original target ``t`` is replaced by the tuple ``(v_0, v_1, ...)``,
       and the resulting nest is visited again.
     * ``for ... in enumerate_dim(...)`` / ``range_dim(...)`` has its body
       visited and flattened and is wrapped in a ``RangeDim`` node.

     A plain-name loop target is added to ``self.index_vars``.

     Raises:
         NotImplementedError: for any other iteration expression (including
             calls through an attribute such as ``np.ndindex(...)``), or for
             ``eachindex`` over a mapping with an empty shape.
     """
     target = node.target
     if isinstance(target, ast.Name):
         self.index_vars.add(target.id)
     _range = node.iter
     # Only calls to a bare name are recognised; other callees would raise
     # AttributeError on ``.id`` instead of the intended error below.
     callee = None
     if isinstance(_range, ast.Call) and isinstance(_range.func, ast.Name):
         callee = _range.func.id

     if callee == "eachindex":
         loopvars = [
             self._gen_unique_variable()
             for _ in self.connections[0].mapping.shape
         ]
         if not loopvars:
             raise NotImplementedError(ast.dump(node))
         nodes = [
             ast.For(
                 ast.Name(var, ast.Store()),
                 ast.Call(
                     ast.Name("range_dim", ast.Load()),
                     [_range.args[0], ast.Num(dim)], []), [], [])
             for dim, var in enumerate(loopvars)
         ]
         index_expr = ast.Tuple(
             [ast.Name(var, ast.Load()) for var in loopvars], ast.Load())
         nodes[-1].body = [
             util.replace_name(node.target, index_expr, s)
             for s in node.body
         ]
         # Nest each loop inside the previous one.
         for outer, inner in zip(nodes, nodes[1:]):
             outer.body.append(inner)
         return self.visit(nodes[0])
     elif callee in ["enumerate_dim", "range_dim"]:
         node.body = [self.visit(s) for s in node.body]
         node.body = util.flatten(node.body)
         return RangeDim(node, self.connections[0].mapping,
                         self.connections[0].source)
     else:
         raise NotImplementedError(ast.dump(node))
Exemple #13
0
 def visit_For(self, node):
     """Lift statements that do not depend on this loop out of it.

     After visiting the body, a loop whose index is ``_neuron_index_0`` is
     returned unchanged.  For any other loop, statements are scanned in order;
     a statement that mentions neither the loop index nor any name in
     ``deps`` is moved:

     * ``_mm*_store*`` calls go after the loop;
     * assignments from a call whose name contains ``_load`` go before it;
     * assignments to a ``SymbolRef`` with a non-None ``type`` go before it.

     Every other statement stays in the loop.  If it assigns to a
     ``SymbolRef`` with a non-None ``type``, that name is added to ``deps``
     so later statements using it stay in the loop as well.

     NOTE(review): only typed assignments are tracked in ``deps``; an
     untyped reassignment inside the loop does not block hoisting of later
     readers.  Sunk stores also run even if the loop executes zero times.
     Confirm both are safe for the generated code.

     Returns:
         The unchanged loop node for ``_neuron_index_0``; otherwise the list
         ``pre_stmts + [node] + post_stmts``.
     """
     node.body = util.flatten([self.visit(s) for s in node.body])
     if node.init.left.name == "_neuron_index_0":
         #Don't lift out of outer most loop
         return node
     pre_stmts = []
     new_body = []
     post_stmts = []
     loop_var = node.init.left.name
     # Names defined inside the loop; statements using them cannot move.
     deps = set()
     for stmt in node.body:
         # print(astor.dump_tree(stmt))
         if isinstance(stmt, C.FunctionCall) and "_mm" in stmt.func.name and \
             "_store" in stmt.func.name and \
             not util.contains_symbol(stmt, loop_var) and \
             not any(util.contains_symbol(stmt, dep) for dep in deps):
             # Invariant SIMD store: sink after the loop.
             post_stmts.append(stmt)
         elif isinstance(stmt, C.BinaryOp) and isinstance(stmt.op, C.Op.Assign) and \
                 isinstance(stmt.right, C.FunctionCall) and "_load" in stmt.right.func.name and \
                 not util.contains_symbol(stmt, loop_var) and \
                 not any(util.contains_symbol(stmt, dep) for dep in deps):
             # Invariant load: hoist before the loop.
             pre_stmts.append(stmt)
         elif isinstance(stmt, C.BinaryOp) and \
              isinstance(stmt.op, C.Op.Assign) and \
              isinstance(stmt.left, C.SymbolRef) and \
              stmt.left.type is not None and \
                 not util.contains_symbol(stmt, loop_var) and \
                 not any(util.contains_symbol(stmt, dep) for dep in deps):
             # Invariant typed assignment: hoist before the loop.
             pre_stmts.append(stmt)
         else:
             new_body.append(stmt)
             if isinstance(stmt, C.BinaryOp) and \
                isinstance(stmt.op, C.Op.Assign) and \
                isinstance(stmt.left, C.SymbolRef) and \
                stmt.left.type is not None:
                 deps.add(stmt.left.name)
     node.body = new_body
     return pre_stmts + [node] + post_stmts
Exemple #14
0
 def visit_For(self, node):
     """Re-nest tiled loops collected while visiting a selected outer loop.

     The rewrite applies when either:

     * ``node.iter`` is a ``range(...)`` call, ``self.direction`` is
       ``"forward"`` and the target is ``_neuron_index_1_outer``; or
     * ``self.direction`` is ``"backward"`` or ``"update_internal"`` and the
       target is ``_neuron_index_0``.

     For each statement of such a loop, after it is visited, any loops left
     in ``self.tiled_loops`` are chained (each becomes the sole body of the
     previous one).  The visited statement becomes the innermost body, the
     outermost tiled loop takes the statement's place, and
     ``self.tiled_loops`` is reset.  Other loops only have their bodies
     visited and flattened.
     """
     # FIXME: This should no longer happen implicitly, instead the user
     # should use swap loops to lift tiled loops
     # NOTE(review): the ``range(...)`` guard applies only to the forward
     # case; the backward/update_internal case matches any iterator.
     # Confirm that is intended.
     is_range_loop = (isinstance(node.iter, ast.Call)
                      and node.iter.func.id == "range")
     lift_here = (is_range_loop and self.direction == "forward"
                  and node.target.id == "_neuron_index_1_outer")
     if not lift_here:
         lift_here = (self.direction in ["backward", "update_internal"]
                      and node.target.id == "_neuron_index_0")

     if not lift_here:
         node.body = util.flatten([self.visit(s) for s in node.body])
         return node

     nested_body = []
     for statement in node.body:
         visited = self.visit(statement)
         if not self.tiled_loops:
             nested_body.append(visited)
             continue
         for outer, inner in zip(self.tiled_loops, self.tiled_loops[1:]):
             outer.body = [inner]
         self.tiled_loops[-1].body = [visited]
         nested_body.append(self.tiled_loops[0])
         self.tiled_loops = []
     node.body = nested_body
     return node
Exemple #15
0
 def visit_For(self, node):
     """Visit the loop body; for the ``self.prefetch_init_loop`` loop, also
     prepend the statements held in ``InitPrefetcher.init_body``."""
     visited = util.flatten([self.visit(stmt) for stmt in node.body])
     node.body = visited
     if node.init.left.name != self.prefetch_init_loop:
         return node
     node.body = InitPrefetcher.init_body + visited
     return node
Exemple #16
0
 def visit_For(self, node):
     """Visit the loop body; for the ``self.prefetch_dest_loop`` loop, also
     prepend the statements held in ``HoistPrefetch.escape_body``."""
     visited = util.flatten([self.visit(stmt) for stmt in node.body])
     node.body = visited
     if node.init.left.name != self.prefetch_dest_loop:
         return node
     node.body = HoistPrefetch.escape_body + visited
     return node
Exemple #17
0
    def visit_For(self, node):
        """Unroll (without jamming) the loop whose index is ``self.unroll_var``.

        The loop header is adjusted according to ``self.unroll_type``:

        * ``0``: the increment becomes ``var += self.unroll_factor``;
        * ``1``: the bound ``node.test.right.value`` is divided by the factor
          and must be divisible by it.

        ``UnrollStatementsNoJam.new_body`` is reset to one empty list per
        extra copy (keys ``1 .. factor-1``).  An ``UnrollStatementsNoJam``
        visitor is then run over the loop.  Presumably it fills those lists;
        their contents are appended, in key order, after the loop body,
        which is then flattened.

        Raises:
            ValueError: for an unknown ``unroll_type``, or when
                ``unroll_type == 1`` and the bound is not divisible by the
                factor.
        """
        node.body = [self.visit(s) for s in node.body]
        if node.init.left.name != self.unroll_var:
            return node

        var = node.init.left.name
        factor, unroll_type = self.unroll_factor, self.unroll_type
        if unroll_type == 0:
            node.incr = C.AddAssign(C.SymbolRef(var), C.Constant(factor))
        elif unroll_type == 1:
            bound = node.test.right.value
            # Explicit check rather than ``assert`` so it also holds under
            # ``python -O`` (otherwise iterations would be silently lost).
            if bound % factor != 0:
                raise ValueError(
                    "loop bound {} is not divisible by unroll factor {}"
                    .format(bound, factor))
            node.test.right.value = bound // factor
        else:
            raise ValueError(
                "unsupported unroll_type: {!r}".format(unroll_type))

        # Class-level scratch space shared with UnrollStatementsNoJam.
        UnrollStatementsNoJam.new_body = {i: [] for i in range(1, factor)}
        visitor = UnrollStatementsNoJam(self.unroll_var,
                                        self.unroll_factor,
                                        self.unroll_type)
        node = visitor.visit(node)
        for i in range(1, factor):
            node.body.extend(UnrollStatementsNoJam.new_body[i])

        node.body = util.flatten(node.body)
        return node
Exemple #18
0
 def visit(self, node):
     """Run the parent visitor, then flatten ``body`` if the result has one."""
     visited = super().visit(node)
     if not hasattr(visited, "body"):
         return visited
     visited.body = util.flatten(visited.body)
     return visited
Exemple #19
0
 def visit_FunctionDecl(self, node):
     """Visit each statement in ``node.defn`` and store the flattened
     results back on the node."""
     node.defn = util.flatten([self.visit(stmt) for stmt in node.defn])
     return node