def visit_VerticalLoop(
    self, node: nir.VerticalLoop, *, merge_candidates: List[List[nir.HorizontalLoop]], **kwargs
):
    """Fuse each candidate group of horizontal loops into a single horizontal loop.

    For every group in ``merge_candidates``, the slice of ``node.horizontal_loops``
    running from the group's first loop through its last loop (inclusive) is
    replaced by one ``nir.HorizontalLoop``. Its body concatenates, in group
    order, the declarations and the statements of every loop in the group. It
    reuses the iteration space of the loop found at the start of that slice.

    ``node`` is modified in place and returned.
    """
    for group in merge_candidates:
        block_location = group[0].iteration_space.location_type
        # NOTE(review): ``list.index`` matches by ``==``, not identity. If node
        # equality is structural, a loop equal to an earlier one would resolve
        # to the earlier position. Confirm candidates can never collide.
        start = node.horizontal_loops.index(group[0])
        stop = node.horizontal_loops.index(group[-1]) + 1

        merged_declarations = [decl for loop in group for decl in loop.stmt.declarations]
        merged_statements = [stmt for loop in group for stmt in loop.stmt.statements]

        fused_body = nir.BlockStmt(
            declarations=merged_declarations,
            statements=merged_statements,
            location_type=block_location,
        )
        fused_loop = nir.HorizontalLoop(
            stmt=fused_body,
            iteration_space=node.horizontal_loops[start].iteration_space,
        )
        # Anything between the first and last member is assumed to be part of
        # the group. The whole span collapses into the fused loop.
        node.horizontal_loops[start:stop] = [fused_loop]
    return node
def visit_NeighborReduce(self, node: gtir.NeighborReduce, *, last_block, **kwargs):
    """Lower a neighbor reduction into an accumulator local plus a neighbor loop.

    Side effects on ``last_block``, in this order:
      1. declares a FLOAT64 local named ``"local" + str(node.id_)``;
      2. appends its initialization to ``REDUCE_OP_INIT_VAL[node.op]``;
      3. appends a ``nir.NeighborLoop`` over the visited neighbor chain whose
         body does ``acc = acc <REDUCE_OP_TO_BINOP[node.op]> operand``.

    The operand is visited with ``in_neighbor_loop=True``, and the reduction's
    neighbor comprehension is added to a deep copy of
    ``kwargs["location_comprehensions"]``.

    Returns a ``nir.VarAccess`` that reads the accumulator at ``node.location_type``.
    """
    scoped = copy.deepcopy(kwargs["location_comprehensions"])
    # The neighbor name must not shadow an enclosing comprehension.
    assert node.neighbors.name not in scoped
    scoped[node.neighbors.name] = node.neighbors
    kwargs["location_comprehensions"] = scoped

    neighbor_loc = node.neighbors.chain.elements[-1]
    acc_name = "local" + str(node.id_)
    outer_loc = node.location_type

    def acc_at(location):
        # Fresh access node to the accumulator at the given location.
        return nir.VarAccess(name=acc_name, location_type=location)

    last_block.declarations.append(
        nir.LocalVar(
            name=acc_name,
            vtype=common.DataType.FLOAT64,  # TODO: accumulator type is hard-coded
            location_type=outer_loc,
        )
    )
    init_value = nir.Literal(
        value=self.REDUCE_OP_INIT_VAL[node.op],
        location_type=outer_loc,
        vtype=common.DataType.FLOAT64,  # TODO: accumulator type is hard-coded
    )
    last_block.statements.append(
        nir.AssignStmt(left=acc_at(outer_loc), right=init_value, location_type=outer_loc),
    )

    # Per-neighbor fold: acc = acc <op> operand
    fold_lhs = acc_at(neighbor_loc)
    fold_acc = acc_at(neighbor_loc)
    binop = self.REDUCE_OP_TO_BINOP[node.op]
    operand = self.visit(node.operand, in_neighbor_loop=True, **kwargs)
    fold_rhs = nir.BinaryOp(left=fold_acc, op=binop, right=operand, location_type=neighbor_loc)
    loop_body = nir.BlockStmt(
        declarations=[],
        statements=[nir.AssignStmt(left=fold_lhs, right=fold_rhs, location_type=neighbor_loc)],
        location_type=neighbor_loc,
    )

    neighbors = self.visit(node.neighbors.chain)
    last_block.statements.append(
        nir.NeighborLoop(neighbors=neighbors, body=loop_body, location_type=outer_loc)
    )
    return acc_at(outer_loc)  # TODO
def visit_HorizontalLoop(self, node: gtir.HorizontalLoop, **kwargs):
    """Lower a gtir horizontal loop into a ``nir.HorizontalLoop`` with a single block body.

    The loop's statement is visited with a fresh, empty ``nir.BlockStmt``
    passed as ``last_block``, and with a location-comprehension mapping that
    contains only the loop's own location. The visited result is appended to
    that block last, after anything the visit itself appended to ``last_block``.

    NOTE(review): the incoming ``**kwargs`` are not forwarded to the inner
    visit. Confirm this is intentional.
    """
    loop_body = nir.BlockStmt(
        declarations=[], statements=[], location_type=node.stmt.location_type
    )
    comprehensions = {node.location.name: node.location}
    lowered_stmt = self.visit(
        node.stmt, last_block=loop_body, location_comprehensions=comprehensions
    )
    loop_body.statements.append(lowered_stmt)

    iteration_location = node.location.chain.elements[0]
    return nir.HorizontalLoop(stmt=loop_body, location_type=iteration_location)
def visit_NeighborAssignStmt(
    self,
    node: gtir.NeighborAssignStmt,
    *,
    symtable,
    hloop_ctx: "HorizontalLoopContext",
    **kwargs,
):
    """Emit a neighbor loop that assigns ``node.right`` into a neighbor-indexed field.

    The left-hand side becomes a ``nir.FieldAccess`` on ``node.left.name``.
    Its primary index is ``node.left.subscript[0].name`` and its secondary
    index is the neighbor loop variable ``node.neighbors.name``. The
    right-hand side is visited with ``loop_var`` set to that loop variable,
    with the statement's own symbols layered over ``symtable``, and with
    ``hloop_ctx`` forwarded, so any statements that visit registers on
    ``hloop_ctx`` come before the loop added here.

    Returns None. The ``nir.NeighborLoop`` is registered through
    ``hloop_ctx.add_statement``.
    """
    scoped_symtable = {**symtable, **node.symtable_}
    loop_var_name = node.neighbors.name
    loc = node.location_type

    loop_var = nir.NeighborLoopVar(name=loop_var_name)
    connectivity_name = node.neighbors.of.name
    target = nir.FieldAccess(
        name=node.left.name,
        primary=node.left.subscript[0].name,
        secondary=node.neighbors.name,
        location_type=loc,
    )
    value = self.visit(
        node.right,
        loop_var=loop_var_name,
        symtable=scoped_symtable,
        hloop_ctx=hloop_ctx,
        **kwargs,
    )

    assignment = nir.AssignStmt(left=target, right=value, location_type=loc)
    loop_body = nir.BlockStmt(declarations=[], statements=[assignment], location_type=loc)
    hloop_ctx.add_statement(
        nir.NeighborLoop(
            name=loop_var,
            connectivity=connectivity_name,
            body=loop_body,
            location_type=loc,
        )
    )
def visit_HorizontalLoop(self, node: gtir.HorizontalLoop, *, symtable, **kwargs):
    """Lower a gtir horizontal loop by collecting its body through a loop context.

    A fresh ``self.HorizontalLoopContext()`` is handed to the visit of
    ``node.stmt`` as ``hloop_ctx``. The visit also receives the loop's location
    comprehension, the loop's symbols layered over ``symtable``, and all
    remaining kwargs. The declarations and statements accumulated on the
    context become the body of the returned ``nir.HorizontalLoop``. Its
    iteration space is named after, and typed by, ``node.location``.
    """
    ctx = self.HorizontalLoopContext()
    self.visit(
        node.stmt,
        hloop_ctx=ctx,
        location_comprehensions={node.location.name: node.location},
        symtable={**symtable, **node.symtable_},
        **kwargs,
    )

    loc_name = node.location.name
    loc_type = node.location.location_type
    body = nir.BlockStmt(
        declarations=ctx.declarations,
        statements=ctx.statements,
        location_type=loc_type,
    )
    space = nir.IterationSpace(name=loc_name, location_type=loc_type)
    return nir.HorizontalLoop(stmt=body, iteration_space=space)
def visit_HorizontalLoop(
    self,
    node: nir.HorizontalLoop,
    *,
    merge_groups: Dict[int, List[List[nir.NeighborLoop]]],
    **kwargs,
):
    """Merge groups of neighbor loops in ``node``'s body into each group's first loop.

    ``merge_groups`` is keyed by ``id(horizontal_loop)``. Each value lists the
    groups of neighbor loops for that horizontal loop. The first loop of a
    group (its "head") absorbs the bodies of the other members.

    The body statements are walked in order:

    * A head is replaced by a deep copy of itself, so the head in the input
      tree is not mutated. That copy becomes the merge target of the other
      members of its group.
    * Every other member has its body renamed to the target's loop variable
      via ``RenameSymbol.apply``. The renamed declarations and statements are
      appended to the target's body, and the member itself is dropped.
    * An ``AssignStmt`` directly preceding a non-head member is treated as
      that member's reduction initialization. It is removed from its place
      and re-emitted, in original order, immediately before the target loop.
    * All other statements are kept unchanged.

    Returns a new ``nir.HorizontalLoop`` with the same iteration space,
    declarations and block location type.
    """
    assert id(node) in merge_groups
    groups: List[List[nir.NeighborLoop]] = merge_groups[id(node)]

    # id(head loop) -> its group (first group wins, matching list.index semantics).
    group_by_head: Dict[int, List[nir.NeighborLoop]] = {}
    for group in groups:
        group_by_head.setdefault(id(group[0]), group)

    # id(non-head member) -> the (copied) target loop it is merged into.
    targets: Dict[int, nir.NeighborLoop] = {}
    # id(target loop) -> init statements of merged members to emit before it.
    targets_init: Dict[int, List[nir.AssignStmt]] = {}

    body_statements = node.stmt.statements
    num_stmts = len(body_statements)
    stmt_statements = []

    for i, hl_stmt in enumerate(body_statements):
        next_stmt = body_statements[i + 1] if i < num_stmts - 1 else None
        if isinstance(hl_stmt, nir.NeighborLoop):
            if id(hl_stmt) in group_by_head:
                current_group = group_by_head[id(hl_stmt)]
                # Check identity BEFORE copying. A deep copy is a new object,
                # so comparing the copy's id with hl_stmt's id would always fail.
                assert current_group[0] is hl_stmt
                # Copy so that extending the target body below leaves the input intact.
                target_n_loop = copy.deepcopy(hl_stmt)
                for other_n_loop in current_group[1:]:
                    targets[id(other_n_loop)] = target_n_loop
                stmt_statements.append(target_n_loop)
            else:
                # Non-head member: fold its (renamed) body into its target.
                assert id(hl_stmt) in targets
                target_n_loop = targets[id(hl_stmt)]
                other_body: nir.BlockStmt = RenameSymbol.apply(
                    hl_stmt.body, hl_stmt.name, target_n_loop.name
                )
                target_n_loop.body.declarations.extend(other_body.declarations)
                target_n_loop.body.statements.extend(other_body.statements)
        elif (
            isinstance(hl_stmt, nir.AssignStmt)
            and isinstance(next_stmt, nir.NeighborLoop)
            and id(next_stmt) in targets
        ):
            # Reduction init of a loop being merged away: defer it so it can be
            # placed in front of the target loop.
            targets_init.setdefault(id(targets[id(next_stmt)]), []).append(hl_stmt)
        else:
            stmt_statements.append(hl_stmt)

    # Single linear pass: emit each target's deferred inits (in source order)
    # right before it.
    merged_statements = []
    for stmt in stmt_statements:
        merged_statements.extend(targets_init.get(id(stmt), ()))
        merged_statements.append(stmt)

    return nir.HorizontalLoop(
        iteration_space=node.iteration_space,
        stmt=nir.BlockStmt(
            declarations=node.stmt.declarations,
            statements=merged_statements,
            # Every other BlockStmt built in this file passes location_type. Keep the original block's.
            location_type=node.stmt.location_type,
        ),
    )
def visit_NeighborReduce(self, node: gtir.NeighborReduce, *, hloop_ctx: "HorizontalLoopContext", **kwargs):
    """Lower a neighbor reduction into an accumulator local plus a neighbor loop.

    The connectivity is resolved from ``kwargs["symtable"]`` by
    ``node.neighbors.of.name``. On ``hloop_ctx``, in order, this:
      1. declares a FLOAT64 local named ``"local" + str(id(node))``;
      2. adds its initialization to ``REDUCE_OP_INIT_VAL[node.op]``;
      3. adds a ``nir.NeighborLoop`` over that connectivity that folds the
         visited operand into the accumulator.

    When ``REDUCE_OP_TO_BINOP[node.op]`` yields the sentinel
    ``BuiltInLiteral.MAX_VALUE`` or ``BuiltInLiteral.MIN_VALUE``, the fold is a
    native MAX or MIN call respectively. Otherwise it is a ``nir.BinaryOp``.

    NOTE(review): ``hloop_ctx`` is not forwarded to the operand visit, so a
    nested reduction inside the operand would not receive it. Confirm this is
    intended.

    Returns a ``nir.VarAccess`` that reads the accumulator at ``node.location_type``.
    """
    conn: gtir.Connectivity = kwargs["symtable"][node.neighbors.of.name]
    acc_name = "local" + str(id(node))
    outer_loc = node.location_type

    hloop_ctx.add_declaration(
        nir.LocalVar(
            name=acc_name,
            vtype=common.DataType.FLOAT64,  # TODO: accumulator type is hard-coded
            location_type=outer_loc,
        )
    )
    init_literal = nir.Literal(
        value=self.REDUCE_OP_INIT_VAL[node.op],
        location_type=outer_loc,
        vtype=common.DataType.FLOAT64,  # TODO: accumulator type is hard-coded
    )
    hloop_ctx.add_statement(
        nir.AssignStmt(
            left=nir.VarAccess(name=acc_name, location_type=outer_loc),
            right=init_literal,
            location_type=outer_loc,
        ),
    )

    inner_loc = conn.secondary
    op = self.REDUCE_OP_TO_BINOP[node.op]
    acc_read = nir.VarAccess(name=acc_name, location_type=inner_loc)
    operand = self.visit(node.operand, in_neighbor_loop=True, **kwargs)

    is_extremum = op in (common.BuiltInLiteral.MIN_VALUE, common.BuiltInLiteral.MAX_VALUE)
    if not is_extremum:
        fold = nir.BinaryOp(left=acc_read, op=op, right=operand, location_type=inner_loc)
    else:
        if op == common.BuiltInLiteral.MAX_VALUE:
            native = common.NativeFunction.MAX
        else:
            native = common.NativeFunction.MIN
        fold = nir.NativeFuncCall(func=native, args=[acc_read, operand], location_type=inner_loc)

    update = nir.AssignStmt(
        left=nir.VarAccess(name=acc_name, location_type=inner_loc),
        right=fold,
        location_type=inner_loc,
    )
    hloop_ctx.add_statement(
        nir.NeighborLoop(
            name=nir.NeighborLoopVar(name=node.neighbors.name),
            connectivity=conn.name,
            body=nir.BlockStmt(declarations=[], statements=[update], location_type=inner_loc),
            location_type=outer_loc,
        )
    )
    return nir.VarAccess(name=acc_name, location_type=outer_loc)  # TODO