def visit_For(self, node):
    """
    Track aliases of input buffers and rewrite SIMD loads/stores in a loop.

    Each statement of the (flattened) loop body falls into one of three
    groups:

    * an assignment to a symbol named ``in_*`` or ``_input_*`` whose
      right-hand side is not a function call: it is recorded in ``du_map``
      (defined outside this method); if the right-hand side is itself a
      symbol already in ``du_map``, that existing entry is propagated;
    * an ``_mm*`` load / set1 / broadcast assignment, or an ``_mm*`` store
      call: it is passed through ``ReplaceSymbolRef``;
    * anything else: kept unchanged.

    Afterwards every statement is visited recursively and the flattened
    result becomes the new loop body.

    Returns the same ``node`` with its body replaced.
    """
    node.body = util.flatten(list(node.body))
    rewritten = []
    for stmt in node.body:
        assigns_symbol = (
            isinstance(stmt, C.BinaryOp)
            and isinstance(stmt.op, C.Op.Assign)
            and isinstance(stmt.left, C.SymbolRef)
        )
        is_input_alias = (
            assigns_symbol
            and stmt.left.name.startswith(("in_", "_input_"))
            and not isinstance(stmt.right, C.FunctionCall)
        )

        if is_input_alias:
            source = stmt.right
            if isinstance(source, C.SymbolRef) and source.name in du_map:
                # Alias of an alias: point straight at the original definition.
                du_map[stmt.left.name] = du_map[source.name]
            else:
                du_map[stmt.left.name] = source
        else:
            is_simd_fill = (
                assigns_symbol
                and isinstance(stmt.right, C.FunctionCall)
                and "_mm" in stmt.right.func.name
                and any(tag in stmt.right.func.name
                        for tag in ("_load_", "_set1", "_broadcast"))
            )
            is_simd_store = (
                not is_simd_fill
                and isinstance(stmt, C.FunctionCall)
                and "_mm" in stmt.func.name
                and "_store" in stmt.func.name
            )
            if is_simd_fill or is_simd_store:
                stmt = ReplaceSymbolRef().visit(stmt)
        rewritten.append(stmt)

    node.body = util.flatten([self.visit(s) for s in rewritten])
    return node
def visit_For(self, node):
    """
    Forward values from SIMD stores to later SIMD loads of the same location.

    Stores (``_mm*_store*`` calls whose destination satisfies
    ``inReplaceMapSource``):

    * value is a ``SymbolRef``: the destination reference is recorded with
      ``store_in_du_map`` and the register name is added to ``self.seen``;
    * value is a ``FunctionCall``: the value is first assigned to a fresh
      register from ``self._gen_register()``, the store is re-emitted with
      that register, and the new store is recorded the same way;
    * any other value: the store is kept unchanged.  (Previously such a
      store matched neither sub-branch and was silently dropped from the
      loop body.)

    Loads (``_mm*_load*`` assignments whose source satisfies
    ``inReplaceMapSink``): when the aliased location is in the du-map and
    its register is in ``self.seen``, the load is removed and
    ``sym_map[<target name>]`` is set to that register; otherwise the load
    is kept.

    All other statements are kept.  The resulting statements are visited
    recursively and flattened into the new body.  ``sym_map`` and the
    du-map helpers are defined outside this method.

    Returns the same ``node`` with its body replaced.
    """
    node.body = util.flatten(list(node.body))
    new_body = []
    for stmt in node.body:
        if isinstance(stmt, C.FunctionCall) and "_mm" in stmt.func.name \
                and "_store" in stmt.func.name \
                and inReplaceMapSource(stmt.args[0], self.replace_map):
            value = stmt.args[1]
            if isinstance(value, C.SymbolRef):
                sym_arr_ref = extract_reference(stmt.args)
                store_in_du_map(sym_arr_ref)
                self.seen[value.name] = None
                new_body.append(stmt)
            elif isinstance(value, C.FunctionCall):
                # NOTE(review): the original additionally re-tested
                # `"_mm" in stmt.func.name` here, which is always true in
                # this branch; possibly `stmt.args[1].func.name` was meant.
                tmp = self._gen_register()
                new_body.append(
                    C.Assign(C.SymbolRef(tmp, get_simd_type()()),
                             deepcopy(value)))
                new_body.append(
                    C.FunctionCall(C.SymbolRef(stmt.func.name),
                                   [stmt.args[0], C.SymbolRef(tmp, None)]))
                # A separate call node is built for the du-map so the
                # recorded reference does not share nodes with the emitted
                # store (same as the original code).
                sym_arr_ref = extract_reference(
                    C.FunctionCall(C.SymbolRef(stmt.func.name),
                                   [stmt.args[0], C.SymbolRef(tmp, None)]).args)
                store_in_du_map(sym_arr_ref)
                self.seen[tmp] = None
            else:
                # Unrecognised value expression: never lose the store.
                new_body.append(stmt)
        elif isinstance(stmt, C.BinaryOp) and \
                isinstance(stmt.op, C.Op.Assign) and \
                isinstance(stmt.left, C.SymbolRef) and \
                isinstance(stmt.right, C.FunctionCall) and \
                "_mm" in stmt.right.func.name and \
                "_load" in stmt.right.func.name and \
                inReplaceMapSink(stmt.right.args[0], self.replace_map):
            forwarded = False
            source = get_alias(stmt.right.args, self.replace_map)
            if source is not None:
                sym_arr_ref = construct_arr_reference(
                    source, deepcopy(stmt.right.args))
                if in_du_map(sym_arr_ref):
                    reg = get_register(sym_arr_ref)
                    if str(reg.name) in self.seen:
                        # The value is still live in a register: drop the
                        # load and remember the substitution instead.
                        sym_map[stmt.left.name] = reg
                        forwarded = True
            if not forwarded:
                new_body.append(stmt)
        else:
            new_body.append(stmt)
    node.body = util.flatten([self.visit(s) for s in new_body])
    return node
def visit_For(self, node):
    """
    Replace a loop that provably executes exactly once by its body.

    The loop is only considered when all three header parts have the
    simple constant form

        ``var = <const>``;  ``var < <const>``;  ``var += <const>``

    If the start value is below the bound and one increment reaches or
    passes the bound (``init < test <= init + incr``), the loop runs exactly
    one iteration, so the body is returned as a list of statements with the
    loop variable replaced by the constant start value.

    Differences from the earlier implementation:

    * a loop whose start is already at or past the bound (zero iterations)
      is left untouched instead of having its body inlined once;
    * "not matched" is tracked with ``None`` rather than ``-1``, so a
      constant ``-1`` in the header is no longer mistaken for a non-match.

    Returns either the (possibly modified) ``node`` or a list of statements
    replacing it.
    """
    node.body = util.flatten([self.visit(s) for s in node.body])

    init = test = incr = None
    if isinstance(node.init, C.BinaryOp) and \
            isinstance(node.init.op, C.Op.Assign) and \
            isinstance(node.init.left, C.SymbolRef) and \
            isinstance(node.init.right, C.Constant):
        init = node.init.right.value
    if isinstance(node.test, C.BinaryOp) and \
            isinstance(node.test.op, C.Op.Lt) and \
            isinstance(node.test.left, C.SymbolRef) and \
            isinstance(node.test.right, C.Constant):
        test = node.test.right.value
    if isinstance(node.incr, C.AugAssign) and \
            isinstance(node.incr.op, C.Op.Add) and \
            isinstance(node.incr.target, C.SymbolRef) and \
            isinstance(node.incr.value, C.Constant):
        incr = node.incr.value.value

    if init is None or test is None or incr is None:
        return node

    # Exactly one trip: the first test passes, the second one fails.
    if init < test <= init + incr:
        return [
            util.replace_symbol(node.init.left.name, C.Constant(init), s)
            for s in node.body
        ]
    return node
def visit_For(self, node):
    """
    Move qualifying ``_load`` assignments to the front of the loop body.

    After visiting the children, every assignment whose right-hand side is
    a call with ``"_load"`` in its name is inspected: if
    ``only_contains_symbol(arg, <loop variable>, self.fuse_map)`` holds for
    all of the call's arguments, the statement is placed ahead of the
    remaining statements.  Relative order is preserved within both groups.
    The statements stay inside the loop; only their position changes.

    Returns the same ``node`` with its body reordered.
    """
    node.body = util.flatten([self.visit(child) for child in node.body])
    loop_var = node.init.left.name

    leading, remaining = [], []
    for stmt in node.body:
        is_load_assign = (
            isinstance(stmt, C.BinaryOp)
            and isinstance(stmt.op, C.Op.Assign)
            and isinstance(stmt.right, C.FunctionCall)
            and "_load" in stmt.right.func.name
        )
        # A list (not a generator) so that every argument is checked, as in
        # the explicit flag loop this replaces.
        if is_load_assign and all([
                only_contains_symbol(arg, loop_var, self.fuse_map)
                for arg in stmt.right.args]):
            leading.append(stmt)
        else:
            remaining.append(stmt)

    node.body = leading + remaining
    return node
def visit(self, node):
    """
    Visit ``node`` through the parent class, then flatten its ``body``.

    Flattening lets individual visitor methods return a list of nodes in
    place of a single node.  Results without a ``body`` attribute are
    returned as they are.
    """
    visited = super().visit(node)
    if not hasattr(visited, 'body'):
        return visited
    visited.body = util.flatten(visited.body)
    return visited
def visit_For(self, node):
    """
    Unroll the loop over ``self.target_var`` by ``self.factor``.

    Children are visited first.  For the target loop the increment becomes
    ``target_var += factor`` and every body statement is expanded by
    ``UnrollStatements``.  When the loop bound equals the factor, the loop
    is removed and its statements are returned as a list with the loop
    variable replaced by the constant 0.

    Returns the ``node`` itself, or a list of statements replacing it.
    """
    node.body = [self.visit(child) for child in node.body]
    if node.init.left.name != self.target_var:
        return node

    node.incr = C.AddAssign(C.SymbolRef(self.target_var),
                            C.Constant(self.factor))
    unroller = UnrollStatements(self.target_var, self.factor)
    node.body = util.flatten([unroller.visit(child) for child in node.body])

    if node.test.right.value != self.factor:
        return node

    loop_var = node.init.left.name
    inlined = []
    for stmt in node.body:
        inlined.append(util.replace_symbol(loop_var, C.Constant(0), stmt))
    return inlined
def visit(self, node):
    """
    Insert SIMD broadcasts for scalar variables used in a statement list.

    After the parent-class visit, each statement of ``node.body`` is
    handled as follows:

    * an assignment to a typed symbol listed in ``self.variables`` and not
      yet in ``self.defs``: the statement is kept and followed by a new
      assignment that broadcasts the scalar into a fresh register; the
      register is recorded in ``self.defs`` and ``self.symbol_table``;
    * any other assignment: known scalars are replaced by their registers
      via ``replace_symbol``; if the target is SIMD-typed (directly, or
      through ``self.symbol_table``) and the right-hand side is a plain
      symbol, the right-hand side is wrapped in ``broadcast_ss``;
    * non-assignments: known scalars are replaced by their registers.

    Nodes without a ``body`` attribute are returned unchanged.
    """
    node = super().visit(node)
    if not hasattr(node, 'body'):
        return node

    def is_simd(candidate):
        # Integer or float SIMD register type.
        return isinstance(candidate, get_simd_type(ctypes.c_int())) or \
            isinstance(candidate, get_simd_type(ctypes.c_float()))

    def substitute_defs(stmt):
        # Replace every scalar that already has a broadcast register.
        for scalar_name in self.defs:
            stmt = replace_symbol(scalar_name, self.defs[scalar_name], stmt)
        return stmt

    rewritten = []
    for s in node.body:
        if not (isinstance(s, C.BinaryOp) and isinstance(s.op, C.Op.Assign)):
            rewritten.append(substitute_defs(s))
            continue

        # Anand - needs more work 27th June 2017
        target = s.left
        if isinstance(target, C.SymbolRef) and (target.type is not None) \
                and target.name in self.variables \
                and target.name not in self.defs:
            register = self._gen_register()
            broadcast = C.Assign(
                C.SymbolRef(register, get_simd_type(target.type)()),
                broadcast_ss(C.SymbolRef(target.name, None), target.type))
            rewritten.append(s)
            rewritten.append(broadcast)
            self.defs[target.name] = C.SymbolRef(register, None)
            self.symbol_table[register] = get_simd_type(target.type)()
            continue

        # `s` may be a different object after substitution, so its fields
        # are re-read from `s` below rather than from `target`.
        s = substitute_defs(s)
        if is_simd(s.left.type) and isinstance(s.right, C.SymbolRef):
            s.right = broadcast_ss(
                C.SymbolRef(s.right.name, None), s.left.type)
        elif isinstance(s.left, C.SymbolRef) \
                and s.left.name in self.symbol_table \
                and is_simd(self.symbol_table[s.left.name]) \
                and isinstance(s.right, C.SymbolRef):
            s.right = broadcast_ss(
                C.SymbolRef(s.right.name, None),
                self.symbol_table[s.left.name])
        rewritten.append(s)

    node.body = util.flatten(rewritten)
    return node
def visit(self, node):
    """
    Promote SIMD loads (and matching stores) in a statement list to registers.

    After the parent-class visit, for every statement of ``node.body``:

    1. ``VectorLoadCollector`` gathers the loads seen so far; each load not
       yet promoted gets a fresh register from ``self._gen_register()``,
       is recorded in ``self.sym``, and an assignment loading it into that
       register is emitted.
    2. An ``_mm*_store*`` call whose destination (compared by generated
       code) has been promoted is turned into an assignment to the
       register; the real store is remembered for later.
    3. Every promoted load expression inside the statement is replaced by
       its register via ``VectorLoadReplacer``.

    The deferred stores are appended after all statements, using
    ``store_epi32`` or ``store_ps`` depending on the original intrinsic
    name.

    Raises:
        AssertionError: a deferred store uses an intrinsic whose name
            contains neither ``"epi32"`` nor ``"ps"``.  (This used to be
            ``assert (false)``, which raised ``NameError`` because
            ``false`` is undefined, and under ``-O`` would have dropped
            the store silently.)

    Nodes without a ``body`` attribute are returned unchanged.
    """
    node = super().visit(node)
    if not hasattr(node, 'body'):
        return node

    new_body = []
    seen = {}      # load key -> (register name, load argument, intrinsic name)
    stores = []    # (destination, register name, store intrinsic name)
    collector = VectorLoadCollector()
    for s in node.body:
        collector.visit(s)
        for stmt in collector.loads.keys():
            if stmt not in seen:
                reg = self._gen_register()
                load_node, _number, func = collector.loads[stmt]
                seen[stmt] = (reg, load_node, func)
                self.sym[reg] = get_simd_type()()
                new_body.append(
                    C.Assign(
                        C.SymbolRef(reg, get_simd_type()()),
                        C.FunctionCall(C.SymbolRef(func), [load_node])))

        if isinstance(s, C.FunctionCall) and "_mm" in s.func.name \
                and "_store" in s.func.name:
            target_code = s.args[0].codegen()
            if target_code in seen:
                reg = seen[target_code][0]
                stores.append((s.args[0], reg, s.func.name))
                # Keep the value in the register; the store is re-emitted
                # once, after the whole statement list.
                s = C.Assign(C.SymbolRef(reg), s.args[1])

        for reg, load_node, func in seen.values():
            replacer = VectorLoadReplacer(
                C.FunctionCall(C.SymbolRef(func), [load_node]).codegen(),
                C.SymbolRef(reg))
            s = replacer.visit(s)
        new_body.append(s)

    for target, value, name in stores:
        if "epi32" in name:
            new_body.append(store_epi32(target, C.SymbolRef(value)))
        elif "ps" in name:
            new_body.append(store_ps(target, C.SymbolRef(value)))
        else:
            raise AssertionError(
                "unsupported SIMD store intrinsic: {!r}".format(name))
    node.body = util.flatten(new_body)
    return node
def visit_For(self, node):
    """
    Unroll the loop over ``self.target_var`` by ``self.factor``.

    Children are visited first.  For the target loop the header is adjusted
    according to ``self.unroll_type``:

    * ``0``: the increment becomes ``target_var += factor``;
    * ``1``: the loop bound is divided by ``factor`` (it must be an exact
      multiple).

    The body statements are then expanded by ``UnrollStatements``.

    Raises:
        AssertionError: the bound is not a multiple of ``factor`` for
            ``unroll_type == 1``, or ``unroll_type`` is neither 0 nor 1.
            These are raised explicitly (rather than with ``assert``) so
            the checks still run under ``python -O``.

    Returns the same ``node``.
    """
    node.body = [self.visit(s) for s in node.body]
    if node.init.left.name != self.target_var:
        return node

    if self.unroll_type == 0:
        node.incr = C.AddAssign(C.SymbolRef(self.target_var),
                                C.Constant(self.factor))
    elif self.unroll_type == 1:
        bound = node.test.right.value
        if bound % self.factor != 0:
            raise AssertionError(
                "loop bound {!r} is not a multiple of unroll factor {!r}"
                .format(bound, self.factor))
        node.test.right.value = bound // self.factor
    else:
        raise AssertionError(
            "unknown unroll_type: {!r}".format(self.unroll_type))

    visitor = UnrollStatements(self.target_var, self.factor,
                               self.unroll_type)
    node.body = util.flatten([visitor.visit(s) for s in node.body])
    return node
def visit_For(self, node):
    """
    Insert software prefetches after matching SIMD loads in one loop.

    Only the loop whose variable equals ``self.enclosing_loop_var`` is
    rewritten.  For up to ``self.prefetch_count`` load / set1 / broadcast
    assignments whose first argument satisfies
    ``check_name(arg, self.prefetch_field)``, two statements are emitted:
    a prefetch call (looked up in ``prefetch_symbol_table`` by
    ``self.cacheline_hint``) on the rewritten address plus
    ``prefetch_offset_var``, and an increment of ``prefetch_offset_var`` by
    ``self.prefetch_offset``.

    They go right after the load, unless this loop is not
    ``self.prefetch_dest_loop``, in which case they are appended to the
    class-level list ``HoistPrefetch.escape_body``.  If anything was
    emitted, a zero-initialisation of ``prefetch_offset_var`` is appended
    to the class-level list ``InitPrefetcher.init_body``.

    Returns the same ``node``.
    """
    node.body = util.flatten([self.visit(s) for s in node.body])
    if node.init.left.name == self.enclosing_loop_var:
        new_body = []
        added_code = False
        # Local budget so self.prefetch_count itself is never modified.
        prefetch_count = self.prefetch_count
        for stmt in node.body:
            new_body.append(stmt)
            if prefetch_count > 0 and isinstance(stmt, C.BinaryOp) and isinstance(stmt.op, C.Op.Assign) and \
                    isinstance(stmt.right, C.FunctionCall) and "_mm" in stmt.right.func.name \
                    and ("_load_" in stmt.right.func.name or "_set1" in stmt.right.func.name or "_broadcast" in stmt.right.func.name):
                # The result is discarded; looks like leftover debugging.
                ast.dump(stmt.right.args[0])
                if check_name(stmt.right.args[0], self.prefetch_field):
                    # Work on a copy so the load's own argument is untouched.
                    array_ref = deepcopy(stmt.right.args[0])
                    new_array_ref = self.rewrite_arg(array_ref)
                    where_to_add = new_body
                    prefetch_count -= 1
                    if node.init.left.name != self.prefetch_dest_loop:
                        # Shared class-level list, not an instance attribute.
                        where_to_add = HoistPrefetch.escape_body
                    added_code = True
                    where_to_add.append(
                        C.FunctionCall(
                            C.SymbolRef(prefetch_symbol_table[
                                self.cacheline_hint]), [
                                    C.Add(new_array_ref, C.SymbolRef("prefetch_offset_var"))
                                ]))
                    where_to_add.append(
                        C.Assign(
                            C.SymbolRef("prefetch_offset_var"),
                            C.Add(C.SymbolRef("prefetch_offset_var"), C.Constant(self.prefetch_offset))))
        if added_code:
            # Declaration/initialisation of the offset variable; also goes
            # to a shared class-level list.
            InitPrefetcher.init_body.append(
                C.Assign(
                    C.SymbolRef("prefetch_offset_var", ctypes.c_int()),
                    C.Constant(0)))
        node.body = new_body
    return node
def visit_For(self, node):
    """
    Insert a software prefetch after matching SIMD loads in one loop.

    Only the loop whose variable equals ``self.enclosing_loop_var`` is
    rewritten.  For up to ``self.prefetch_count`` load / set1 / broadcast
    assignments whose first argument satisfies
    ``check_name(arg, self.prefetch_field)``, a prefetch call (looked up in
    ``prefetch_symbol_table`` by ``self.cacheline_hint``) on the rewritten
    address is placed directly after the load.

    Returns the same ``node``.
    """
    node.body = util.flatten([self.visit(child) for child in node.body])
    if node.init.left.name != self.enclosing_loop_var:
        return node

    remaining = self.prefetch_count
    with_prefetch = []
    for stmt in node.body:
        with_prefetch.append(stmt)
        if not remaining > 0:
            continue
        if not (isinstance(stmt, C.BinaryOp)
                and isinstance(stmt.op, C.Op.Assign)
                and isinstance(stmt.right, C.FunctionCall)):
            continue
        intrinsic = stmt.right.func.name
        if "_mm" not in intrinsic:
            continue
        if not ("_load_" in intrinsic or "_set1" in intrinsic
                or "_broadcast" in intrinsic):
            continue

        address = stmt.right.args[0]
        # Result unused; kept because the original performs this call.
        ast.dump(address)
        if check_name(address, self.prefetch_field):
            prefetch_target = self.rewrite_arg(deepcopy(address))
            remaining -= 1
            prefetch_fn = C.SymbolRef(
                prefetch_symbol_table[self.cacheline_hint])
            with_prefetch.append(C.FunctionCall(prefetch_fn, [prefetch_target]))

    node.body = with_prefetch
    return node
def visit_For(self, node):
    """
    Convert iteration expressions into RangeDim semantic nodes.

    * ``for i in eachindex(x)`` is expanded into one nested
      ``for v in range_dim(x, k)`` loop per entry of
      ``self.connections[0].mapping.shape``; inside the innermost loop the
      original target is replaced by the tuple of the new loop variables.
      The outermost new loop is then visited again.
    * ``for i in enumerate_dim(...)`` / ``range_dim(...)`` has its body
      visited and flattened and is wrapped in a ``RangeDim``.

    A loop target that is a plain name is added to ``self.index_vars``.

    Raises:
        NotImplementedError: the iterable is anything else.  This now also
            covers calls whose callee is not a plain name (e.g. a method
            call), which previously failed with ``AttributeError`` on
            ``func.id``.
    """
    import sys  # local: only needed to choose the literal node class

    index = node.target
    if isinstance(index, ast.Name):
        self.index_vars.add(index.id)

    _range = node.iter
    callee = None
    if isinstance(_range, ast.Call) and isinstance(_range.func, ast.Name):
        callee = _range.func.id

    if callee == "eachindex":
        # ast.Num is a deprecated alias of ast.Constant (removed in Python
        # 3.14); from 3.8 on, ast.Num(x) already produces an ast.Constant.
        make_literal = ast.Constant if sys.version_info >= (3, 8) else ast.Num

        loopvars = [
            self._gen_unique_variable()
            for _ in self.connections[0].mapping.shape
        ]
        nodes = [
            ast.For(
                ast.Name(var, ast.Store()),
                ast.Call(
                    ast.Name("range_dim", ast.Load()),
                    [_range.args[0], make_literal(dim_index)], []),
                [], [])
            for dim_index, var in enumerate(loopvars)
        ]
        index_expr = ast.Tuple(
            [ast.Name(var, ast.Load()) for var in loopvars], ast.Load())
        nodes[-1].body = [
            util.replace_name(node.target, index_expr, s) for s in node.body
        ]
        # Nest from the inside out: each loop becomes the last statement of
        # the loop before it.
        for i in reversed(range(1, len(nodes))):
            nodes[i - 1].body.append(nodes[i])
        return self.visit(nodes[0])

    if callee in ("enumerate_dim", "range_dim"):
        node.body = util.flatten([self.visit(s) for s in node.body])
        return RangeDim(node, self.connections[0].mapping,
                        self.connections[0].source)

    raise NotImplementedError(ast.dump(node))
def visit_For(self, node):
    """
    Hoist loop-invariant statements out of a loop.

    The loop named ``_neuron_index_0`` is never modified.  For any other
    loop, a statement is moved out when it neither mentions the loop
    variable nor any name in ``deps`` (typed symbols assigned by statements
    that stayed inside the loop so far):

    * ``_mm*_store*`` calls are placed after the loop;
    * assignments from a ``*_load*`` call and assignments to a typed
      symbol are placed before the loop.

    Returns ``node`` unchanged for the outermost loop, otherwise the list
    ``hoisted-before + [node] + hoisted-after``.
    """
    node.body = util.flatten([self.visit(child) for child in node.body])
    loop_var = node.init.left.name
    if loop_var == "_neuron_index_0":
        # Don't lift out of outer most loop
        return node

    before, kept, after = [], [], []
    deps = set()

    def is_invariant(stmt):
        # Independent of the loop variable and of everything in `deps`.
        return not util.contains_symbol(stmt, loop_var) and \
            not any(util.contains_symbol(stmt, dep) for dep in deps)

    for stmt in node.body:
        is_assign = isinstance(stmt, C.BinaryOp) and \
            isinstance(stmt.op, C.Op.Assign)
        is_store = isinstance(stmt, C.FunctionCall) and \
            "_mm" in stmt.func.name and "_store" in stmt.func.name
        is_load = is_assign and isinstance(stmt.right, C.FunctionCall) and \
            "_load" in stmt.right.func.name
        is_typed_assign = is_assign and \
            isinstance(stmt.left, C.SymbolRef) and \
            stmt.left.type is not None

        if is_store and is_invariant(stmt):
            after.append(stmt)
        elif is_load and is_invariant(stmt):
            before.append(stmt)
        elif is_typed_assign and is_invariant(stmt):
            before.append(stmt)
        else:
            kept.append(stmt)
            if is_typed_assign:
                # Later statements that use this name must stay inside too.
                deps.add(stmt.left.name)

    node.body = kept
    return before + [node] + after
def visit_For(self, node):
    """
    Wrap the statements of selected loops in the pending tiled loops.

    For a selected loop (see the condition below), each body statement is
    visited; if that visit left entries in ``self.tiled_loops``, those
    loops are chained (each one becomes the sole body statement of the
    previous one), the visited statement becomes the body of the innermost
    one, the outermost one replaces the statement, and
    ``self.tiled_loops`` is reset.  Otherwise the visited statement is
    kept as is.

    All other loops just have their body visited and flattened.

    Returns the same ``node``.
    """
    # FIXME: This should no longer happen implicitly, instead the user
    # should use swap loops to lift tiled loops
    # NOTE(review): `and` binds tighter than `or`, so this condition reads
    # as (is-range-call and forward-case) or (backward-case); the
    # `range(...)` check does not guard the backward/update_internal
    # alternative.  If `A and B and (C or D)` was intended, the last two
    # groups need a shared pair of parentheses -- confirm before changing.
    if isinstance(node.iter, ast.Call) and node.iter.func.id == "range" and \
            (self.direction == "forward" and node.target.id == "_neuron_index_1_outer") or \
            (self.direction in ["backward", "update_internal"] and node.target.id == "_neuron_index_0"):
        new_body = []
        for statement in node.body:
            result = self.visit(statement)
            if len(self.tiled_loops) > 0:
                # Chain the pending loops outermost-first.
                curr_loop = self.tiled_loops[0]
                new_body.append(curr_loop)
                for loop in self.tiled_loops[1:]:
                    curr_loop.body = [loop]
                    curr_loop = loop
                # The visited statement ends up inside the innermost loop.
                curr_loop.body = [result]
                self.tiled_loops = []
            else:
                new_body.append(result)
        node.body = new_body
        return node
    node.body = util.flatten([self.visit(s) for s in node.body])
    return node
def visit_For(self, node):
    """
    Prepend the collected prefetch initialisation to one loop.

    Children are visited first.  If this loop's variable equals
    ``self.prefetch_init_loop``, the statements in the class-level list
    ``InitPrefetcher.init_body`` are placed in front of the loop body.

    Returns the same ``node``.
    """
    node.body = util.flatten([self.visit(child) for child in node.body])
    if node.init.left.name != self.prefetch_init_loop:
        return node
    node.body = InitPrefetcher.init_body + node.body
    return node
def visit_For(self, node):
    """
    Prepend the hoisted prefetch statements to their destination loop.

    Children are visited first.  If this loop's variable equals
    ``self.prefetch_dest_loop``, the statements in the class-level list
    ``HoistPrefetch.escape_body`` are placed in front of the loop body.

    Returns the same ``node``.
    """
    node.body = util.flatten([self.visit(child) for child in node.body])
    is_destination = node.init.left.name == self.prefetch_dest_loop
    if is_destination:
        hoisted = HoistPrefetch.escape_body
        node.body = hoisted + node.body
    return node
def visit_For(self, node):
    """
    Unroll the loop over ``self.unroll_var`` without jamming the copies.

    Children are visited first.  For the target loop the header is adjusted
    according to ``self.unroll_type``:

    * ``0``: the increment becomes ``var += factor``;
    * ``1``: the loop bound is divided by ``factor`` (it must be an exact
      multiple).

    The class-level dict ``UnrollStatementsNoJam.new_body`` is then reset
    to one empty list per extra copy (keys ``1 .. factor - 1``), the loop
    is run through ``UnrollStatementsNoJam``, and the statements found in
    those lists afterwards are appended to the loop body copy by copy
    (all of copy 1, then all of copy 2, ...).  The body is flattened last.

    Raises:
        AssertionError: the bound is not a multiple of ``factor`` for
            ``unroll_type == 1``, or ``unroll_type`` is neither 0 nor 1.
            These are raised explicitly (rather than with ``assert``) so
            the checks still run under ``python -O``.

    Returns the loop node produced by the ``UnrollStatementsNoJam`` visit
    for the target loop, or the original ``node`` for any other loop.
    """
    node.body = [self.visit(s) for s in node.body]
    if node.init.left.name != self.unroll_var:
        return node

    var = node.init.left.name
    factor, unroll_type = self.unroll_factor, self.unroll_type
    if unroll_type == 0:
        node.incr = C.AddAssign(C.SymbolRef(var), C.Constant(factor))
    elif unroll_type == 1:
        bound = node.test.right.value
        if bound % factor != 0:
            raise AssertionError(
                "loop bound {!r} is not a multiple of unroll factor {!r}"
                .format(bound, factor))
        node.test.right.value = bound // factor
    else:
        raise AssertionError("unknown unroll_type: {!r}".format(unroll_type))

    # Shared class-level storage: must be reset before every run.
    UnrollStatementsNoJam.new_body = {i: [] for i in range(1, factor)}
    visitor = UnrollStatementsNoJam(self.unroll_var, self.unroll_factor,
                                    self.unroll_type)
    node = visitor.visit(node)
    for i in range(1, factor):
        node.body.extend(UnrollStatementsNoJam.new_body[i])
    node.body = util.flatten(node.body)
    return node
def visit(self, node):
    """
    Visit ``node`` through the parent class and flatten its ``body``.

    Results that have no ``body`` attribute are returned untouched.
    """
    result = super().visit(node)
    if hasattr(result, "body"):
        flattened = util.flatten(result.body)
        result.body = flattened
    return result
def visit_FunctionDecl(self, node):
    """
    Visit every statement of the function definition and flatten the result.

    Flattening allows a visited statement to be replaced by a list of
    statements.  Returns the same ``node`` with ``defn`` replaced.
    """
    visited = []
    for stmt in node.defn:
        visited.append(self.visit(stmt))
    node.defn = util.flatten(visited)
    return node