def make_horizontal_loop_with_init(field: Str):
    """Build a horizontal loop whose only statement sets ``field`` to one.

    The target access uses ``no_extent`` and ``default_location``; the
    assigned value is the built-in literal ONE typed as ``default_vtype``.

    Returns:
        A ``(loop, write_access)`` pair, so callers can refer to the exact
        ``FieldAccess`` node used as the assignment target.
    """
    target = nir.FieldAccess(
        name=field,
        primary=no_extent,
        location_type=default_location,
    )
    one = nir.Literal(
        value=common.BuiltInLiteral.ONE,
        vtype=default_vtype,
        location_type=default_location,
    )
    body = nir.BlockStmt(
        declarations=[],
        statements=[nir.AssignStmt(left=target, right=one)],
    )
    loop = nir.HorizontalLoop(stmt=body, location_type=default_location)
    return loop, target
def visit_VerticalLoop(
    self,
    node: nir.VerticalLoop,
    *,
    merge_candidates: List[List[nir.HorizontalLoop]],
    **kwargs,
):
    """Fuse each group of merge candidates into a single horizontal loop.

    For every group in ``merge_candidates``, the slice of
    ``node.horizontal_loops`` from the group's first loop to its last loop
    (inclusive) is replaced, in place, by one ``nir.HorizontalLoop``. The new
    loop's block concatenates the declarations and statements of the group's
    loops in order.

    Args:
        node: Vertical loop whose ``horizontal_loops`` list is modified.
        merge_candidates: Groups of horizontal loops to fuse. Each group must
            be non-empty; its first loop's ``location_type`` is used for the
            merged loop and block.

    Returns:
        ``node`` itself, after the in-place modification.

    Raises:
        ValueError: if a group's first or last loop is not in
            ``node.horizontal_loops``.
    """

    def _position(loops, target):
        # Prefer an identity match. ``list.index`` compares with ``==``, and
        # nodes presumably compare by value, so an earlier loop with identical
        # content would otherwise be chosen and the wrong slice replaced.
        for position, loop in enumerate(loops):
            if loop is target:
                return position
        # Fall back to the equality-based lookup (raises ValueError if absent).
        return loops.index(target)

    for candidate in merge_candidates:
        declarations = []
        statements = []
        location_type = candidate[0].location_type
        # Positions are recomputed for every group because earlier
        # replacements shrink the list and shift later indices.
        first_index = _position(node.horizontal_loops, candidate[0])
        last_index = _position(node.horizontal_loops, candidate[-1])
        for loop in candidate:
            declarations += loop.stmt.declarations
            statements += loop.stmt.statements
        node.horizontal_loops[first_index : last_index + 1] = [  # noqa: E203
            nir.HorizontalLoop(
                stmt=nir.BlockStmt(
                    declarations=declarations,
                    statements=statements,
                    location_type=location_type,
                ),
                location_type=location_type,
            )
        ]
    return node
def make_horizontal_loop_with_copy(write: Str, read: Str, read_has_extent: Bool):
    """Build a horizontal loop that copies field ``read`` into field ``write``.

    The written access always uses ``no_extent`` as its primary. The read
    access uses ``with_extent`` when ``read_has_extent`` is true and
    ``no_extent`` otherwise. Both accesses live at ``default_location``.

    Returns:
        A ``(loop, write_access, read_access)`` triple, so callers can compare
        against the exact access nodes placed in the loop.
    """
    read_primary = with_extent if read_has_extent else no_extent
    target = nir.FieldAccess(
        name=write, primary=no_extent, location_type=default_location
    )
    source = nir.FieldAccess(
        name=read, primary=read_primary, location_type=default_location
    )
    copy_stmt = nir.AssignStmt(left=target, right=source)
    loop = nir.HorizontalLoop(
        stmt=nir.BlockStmt(declarations=[], statements=[copy_stmt]),
        location_type=default_location,
    )
    return loop, target, source
def make_block_stmt(stmts: List[nir.Stmt], declarations: List[nir.LocalVar]):
    """Build a ``nir.BlockStmt`` from statements and local declarations.

    The block's location type comes from the first statement. If there are no
    statements, it comes from the first declaration. If both lists are empty,
    it defaults to ``common.LocationType.Vertex``.

    Args:
        stmts: Statements of the block, in order.
        declarations: Local variables declared in the block.

    Returns:
        The constructed ``nir.BlockStmt``.
    """
    if stmts:
        location_type = stmts[0].location_type
    elif declarations:
        # A declarations-only block takes its location from its variables
        # rather than an unrelated default.
        location_type = declarations[0].location_type
    else:
        location_type = common.LocationType.Vertex
    return nir.BlockStmt(
        location_type=location_type,
        statements=stmts,
        declarations=declarations,
    )
def visit_NeighborReduce(self, node: gtir.NeighborReduce, *, last_block, **kwargs):
    """Lower a neighbor reduction into a local accumulator plus a neighbor loop.

    Side effects on ``last_block``:
      1. declares a local variable ``"local" + str(node.id_)``;
      2. appends its initialisation to ``self.REDUCE_OP_INIT_VAL[node.op]``;
      3. appends a ``nir.NeighborLoop`` whose body does
         ``acc = acc <REDUCE_OP_TO_BINOP[node.op]> <lowered operand>``.

    Args:
        node: The gtir reduction to lower.
        last_block: nir block that receives the declaration, the
            initialisation and the neighbor loop.
        **kwargs: Must contain ``location_comprehensions``; it is extended
            with ``node.neighbors`` (on a copy) before the operand is visited.

    Returns:
        A ``nir.VarAccess`` to the accumulator, which replaces the reduction
        in the enclosing expression.
    """
    # Deep-copy so adding this reduction's neighbor name does not mutate the
    # caller's location_comprehensions dict.
    loc_comprehension = copy.deepcopy(kwargs["location_comprehensions"])
    # The neighbor name must not shadow an existing comprehension.
    # NOTE(review): assert is stripped under `python -O`.
    assert node.neighbors.name not in loc_comprehension
    loc_comprehension[node.neighbors.name] = node.neighbors
    kwargs["location_comprehensions"] = loc_comprehension

    # Statements inside the neighbor loop live at the last element of the chain.
    body_location = node.neighbors.chain.elements[-1]

    reduce_var_name = "local" + str(node.id_)
    last_block.declarations.append(
        nir.LocalVar(
            name=reduce_var_name,
            vtype=common.DataType.FLOAT64,  # TODO: element type is hard-coded to FLOAT64
            location_type=node.location_type,
        ))
    # Initialise the accumulator with the identity value for this reduce op.
    last_block.statements.append(
        nir.AssignStmt(
            left=nir.VarAccess(name=reduce_var_name, location_type=node.location_type),
            right=nir.Literal(
                value=self.REDUCE_OP_INIT_VAL[node.op],
                location_type=node.location_type,
                vtype=common.DataType.FLOAT64,  # TODO: element type is hard-coded to FLOAT64
            ),
            location_type=node.location_type,
        ),
    )
    # Loop body: acc = acc <binop> operand, all at the neighbor location.
    # NOTE(review): `last_block` is a named keyword parameter and so is not in
    # **kwargs; the operand visit below does not receive it. Presumably nested
    # reductions inside the operand are unsupported — confirm.
    body = nir.BlockStmt(
        declarations=[],
        statements=[
            nir.AssignStmt(
                left=nir.VarAccess(name=reduce_var_name, location_type=body_location),
                right=nir.BinaryOp(
                    left=nir.VarAccess(name=reduce_var_name, location_type=body_location),
                    op=self.REDUCE_OP_TO_BINOP[node.op],
                    right=self.visit(node.operand, in_neighbor_loop=True, **kwargs),
                    location_type=body_location,
                ),
                location_type=body_location,
            )
        ],
        location_type=body_location,
    )
    last_block.statements.append(
        nir.NeighborLoop(
            neighbors=self.visit(node.neighbors.chain),
            body=body,
            location_type=node.location_type,
        ))
    return nir.VarAccess(name=reduce_var_name, location_type=node.location_type)  # TODO
def visit_HorizontalLoop(self, node: gtir.HorizontalLoop, **kwargs):
    """Lower a gtir horizontal loop into a nir horizontal loop.

    A fresh, empty block is handed to the statement visitor as
    ``last_block``. Visitors such as ``visit_NeighborReduce`` append helper
    declarations and statements to it. The lowered statement itself is
    appended last. Extra ``kwargs`` are not forwarded.

    The resulting loop's location type is the first element of
    ``node.location.chain``.
    """
    block = nir.BlockStmt(
        declarations=[],
        statements=[],
        location_type=node.stmt.location_type,
    )
    comprehensions = {node.location.name: node.location}
    lowered = self.visit(
        node.stmt,
        last_block=block,
        location_comprehensions=comprehensions,
    )
    block.statements.append(lowered)
    outer_location = node.location.chain.elements[0]
    return nir.HorizontalLoop(stmt=block, location_type=outer_location)
def make_empty_block_stmt(location_type: common.LocationType):
    """Return a ``nir.BlockStmt`` with no declarations or statements.

    Args:
        location_type: Location type assigned to the empty block.
    """
    empty_block = nir.BlockStmt(
        location_type=location_type,
        declarations=[],
        statements=[],
    )
    return empty_block