def visit_VerticalLoop(
    self,
    node: nir.VerticalLoop,
    *,
    merge_candidates: List[List[nir.HorizontalLoop]],
    **kwargs,
):
    """Fuse each group of mergeable horizontal loops of ``node`` into one loop.

    For every group, the declarations and statements of its loops are
    concatenated (in group order) into a single ``nir.HorizontalLoop``. The
    fused loop replaces the slice of ``node.horizontal_loops`` that runs from
    the group's first loop to its last loop.

    Args:
        node: Vertical loop whose ``horizontal_loops`` list is rewritten in
            place.
        merge_candidates: Groups of horizontal loops to fuse. Each group is
            expected to be an ordered, contiguous run of
            ``node.horizontal_loops``: every loop between the first and the
            last member is replaced by the fused loop.
        **kwargs: Unused.

    Returns:
        ``node`` itself, after its ``horizontal_loops`` list was modified.
    """

    def _position(target: nir.HorizontalLoop) -> int:
        # ``list.index`` compares with ``==``, so a structurally equal loop
        # appearing earlier in the list would be matched instead of
        # ``target``. Look the loop up by identity first, and only fall back
        # to equality (e.g. when candidates are copies of the listed loops).
        for pos, loop in enumerate(node.horizontal_loops):
            if loop is target:
                return pos
        return node.horizontal_loops.index(target)

    for candidate in merge_candidates:
        if not candidate:
            # An empty group has nothing to fuse.
            continue
        # Positions are recomputed per group because earlier fusions shrink
        # the list and shift the indices of later loops.
        first_index = _position(candidate[0])
        last_index = _position(candidate[-1])

        declarations = []
        statements = []
        for loop in candidate:
            declarations += loop.stmt.declarations
            statements += loop.stmt.statements

        node.horizontal_loops[first_index : last_index + 1] = [
            nir.HorizontalLoop(
                stmt=nir.BlockStmt(
                    declarations=declarations,
                    statements=statements,
                    location_type=candidate[0].iteration_space.location_type,
                ),
                iteration_space=node.horizontal_loops[first_index].iteration_space,
            )
        ]
    return node
def visit_HorizontalLoop(self, node: gtir.HorizontalLoop, **kwargs):
    """Lower a gtir horizontal loop into a ``nir.HorizontalLoop``.

    A fresh, empty ``nir.BlockStmt`` is created and handed to the visit of the
    loop body as ``last_block``. The visitor may add to that block; this is
    presumably how nested lowering emits helper declarations/statements (TODO
    confirm). The lowered body statement is then appended as the block's last
    statement.

    Note: ``**kwargs`` is accepted but not forwarded to the body visit.

    Returns:
        A ``nir.HorizontalLoop`` whose body is the block described above and
        whose location type is the first element of the location chain.
    """
    body = nir.BlockStmt(
        declarations=[],
        statements=[],
        location_type=node.stmt.location_type,
    )
    comprehensions = {node.location.name: node.location}
    lowered_stmt = self.visit(
        node.stmt,
        last_block=body,
        location_comprehensions=comprehensions,
    )
    body.statements.append(lowered_stmt)

    first_location = node.location.chain.elements[0]
    return nir.HorizontalLoop(stmt=body, location_type=first_location)
def visit_HorizontalLoop(self, node: gtir.HorizontalLoop, *, symtable, **kwargs):
    """Lower a gtir horizontal loop into a ``nir.HorizontalLoop``.

    The loop body is visited with a fresh ``HorizontalLoopContext``. The
    visitor is expected to collect the lowered declarations and statements
    into that context. The symbol table passed down is the caller's
    ``symtable`` extended (and overridden) by the loop's own ``symtable_``.
    Remaining keyword arguments are forwarded to the body visit.

    Returns:
        A ``nir.HorizontalLoop`` whose body holds the context's collected
        declarations/statements and whose iteration space is named after the
        loop's location.
    """
    ctx = self.HorizontalLoopContext()
    location = node.location
    self.visit(
        node.stmt,
        hloop_ctx=ctx,
        location_comprehensions={location.name: location},
        symtable={**symtable, **node.symtable_},
        **kwargs,
    )

    body = nir.BlockStmt(
        declarations=ctx.declarations,
        statements=ctx.statements,
        location_type=location.location_type,
    )
    space = nir.IterationSpace(
        name=location.name,
        location_type=location.location_type,
    )
    return nir.HorizontalLoop(stmt=body, iteration_space=space)
def visit_HorizontalLoop(
    self,
    node: nir.HorizontalLoop,
    *,
    merge_groups: Dict[int, List[List[nir.NeighborLoop]]],
    **kwargs,
):
    """Merge groups of neighbor loops inside ``node`` into one loop per group.

    ``merge_groups[id(node)]`` lists the groups for this horizontal loop. The
    first loop of a group (its *head*) is deep-copied. Each other loop of the
    group has its body appended to that copy, after renaming the loop's symbol
    to the head's name via ``RenameSymbol.apply``.

    An ``AssignStmt`` directly preceding a merged (non-head) neighbor loop is
    treated as that loop's reduction initialization. Such statements are moved
    in front of the head copy, keeping their original relative order. All
    other statements are kept in order.

    Args:
        node: Horizontal loop whose body statements are scanned.
        merge_groups: Mapping from ``id()`` of a horizontal loop to its groups
            of neighbor loops. Every neighbor loop in ``node``'s body is
            expected to belong to one group; this is asserted.
        **kwargs: Unused.

    Returns:
        A new ``nir.HorizontalLoop`` with the same iteration space, and the
        same body declarations and location type as ``node``.
    """
    assert id(node) in merge_groups
    groups = merge_groups[id(node)]

    # id(head loop) -> its group. The head is the loop the rest of the group
    # is merged into. ``setdefault`` keeps the first group for a given head.
    group_of_head: Dict[int, List[nir.NeighborLoop]] = {}
    for group in groups:
        group_of_head.setdefault(id(group[0]), group)
    # id(non-head loop) -> deep-copied head loop it is merged into.
    targets: Dict[int, nir.NeighborLoop] = {}
    # id(deep-copied head loop) -> reduction init statements of the loops
    # merged into it, in source order.
    targets_init: Dict[int, List[nir.AssignStmt]] = {}

    hl_statements = node.stmt.statements
    stmt_statements = []
    for i, hl_stmt in enumerate(hl_statements):
        next_stmt = hl_statements[i + 1] if i + 1 < len(hl_statements) else None
        if isinstance(hl_stmt, nir.NeighborLoop):
            current_group = group_of_head.get(id(hl_stmt))
            if current_group is not None:
                # Head loop: work on a deep copy so that appending the other
                # bodies below does not mutate the head loop of the input
                # tree. (A copy never has the same id() as the original, so
                # identity must be checked on the original, as the dict
                # lookup above does.)
                target_n_loop = copy.deepcopy(hl_stmt)
                for other_n_loop in current_group[1:]:
                    targets[id(other_n_loop)] = target_n_loop
                stmt_statements.append(target_n_loop)
            else:
                # Non-head loop: its head appears earlier, so it is already
                # registered in ``targets``. Fold its body into that head.
                assert id(hl_stmt) in targets
                target_n_loop = targets[id(hl_stmt)]
                other_body: nir.BlockStmt = RenameSymbol.apply(
                    hl_stmt.body, hl_stmt.name, target_n_loop.name
                )
                target_n_loop.body.declarations.extend(other_body.declarations)
                target_n_loop.body.statements.extend(other_body.statements)
        elif (
            isinstance(hl_stmt, nir.AssignStmt)
            and isinstance(next_stmt, nir.NeighborLoop)
            and id(next_stmt) in targets
        ):
            # Initialization of a reduce neighbor loop that is being merged:
            # defer it so it can be placed in front of the head loop.
            targets_init.setdefault(id(targets[id(next_stmt)]), []).append(hl_stmt)
        else:
            # Any other statement passes through unchanged.
            stmt_statements.append(hl_stmt)

    # Put the deferred init statements right in front of the head loop they
    # belong to. Rebuilding the list keeps their source order, whereas
    # repeated ``insert`` at a fixed index would reverse them.
    reordered = []
    for stmt in stmt_statements:
        reordered.extend(targets_init.get(id(stmt), ()))
        reordered.append(stmt)

    return nir.HorizontalLoop(
        iteration_space=node.iteration_space,
        stmt=nir.BlockStmt(
            declarations=node.stmt.declarations,
            statements=reordered,
            # Every other BlockStmt construction passes a location_type; keep
            # the one of the original body.
            location_type=node.stmt.location_type,
        ),
    )