def visit_NeighborReduce(self, node: gtir.NeighborReduce, *, last_block, **kwargs):
    """Lower a neighbor reduction into an accumulator variable plus a neighbor loop.

    Three things are appended to ``last_block``: a ``nir.LocalVar``
    declaration for the accumulator, an assignment initialising it from
    ``self.REDUCE_OP_INIT_VAL[node.op]``, and a ``nir.NeighborLoop`` whose
    body combines the accumulator with the visited operand using
    ``self.REDUCE_OP_TO_BINOP[node.op]``.

    Args:
        node: The reduction node being lowered.
        last_block: Block whose ``declarations`` and ``statements`` lists
            receive the generated nodes. It is keyword-only and therefore
            not part of ``kwargs`` forwarded to the nested operand visit.
        **kwargs: Visitor context. Must contain ``"location_comprehensions"``;
            a deep copy extended with ``node.neighbors`` is what the operand
            visit sees.

    Returns:
        A ``nir.VarAccess`` to the accumulator, typed with the location of
        ``node``.
    """
    # Extend a private copy so the mapping passed in by the caller stays as is.
    comprehensions = copy.deepcopy(kwargs["location_comprehensions"])
    neighbors = node.neighbors
    assert neighbors.name not in comprehensions
    comprehensions[neighbors.name] = neighbors
    kwargs["location_comprehensions"] = comprehensions

    # The loop body lives on the last element of the neighbor chain.
    inner_loc = neighbors.chain.elements[-1]
    acc_name = f"local{node.id_!s}"
    outer_loc = node.location_type

    def accumulator(location):
        # A fresh access node each time; tree nodes are never shared.
        return nir.VarAccess(name=acc_name, location_type=location)

    # TODO (kept from the original): the accumulator type is hard-coded to
    # FLOAT64 here and in the initial literal below.
    declaration = nir.LocalVar(
        name=acc_name,
        vtype=common.DataType.FLOAT64,
        location_type=outer_loc,
    )
    last_block.declarations.append(declaration)

    init_target = accumulator(outer_loc)
    init_value = nir.Literal(
        value=self.REDUCE_OP_INIT_VAL[node.op],
        location_type=outer_loc,
        vtype=common.DataType.FLOAT64,  # TODO
    )
    last_block.statements.append(
        nir.AssignStmt(left=init_target, right=init_value, location_type=outer_loc)
    )

    # acc = acc <op> operand, evaluated once per neighbor.
    update_target = accumulator(inner_loc)
    combined = nir.BinaryOp(
        left=accumulator(inner_loc),
        op=self.REDUCE_OP_TO_BINOP[node.op],
        right=self.visit(node.operand, in_neighbor_loop=True, **kwargs),
        location_type=inner_loc,
    )
    update = nir.AssignStmt(left=update_target, right=combined, location_type=inner_loc)
    loop_body = nir.BlockStmt(
        declarations=[],
        statements=[update],
        location_type=inner_loc,
    )

    loop = nir.NeighborLoop(
        neighbors=self.visit(neighbors.chain),
        body=loop_body,
        location_type=outer_loc,
    )
    last_block.statements.append(loop)

    return accumulator(outer_loc)  # TODO
def visit_AssignStmt(self, node: gtir.AssignStmt, *, hloop_ctx: "HorizontalLoopContext", **kwargs):
    """Lower an assignment and record it on the horizontal-loop context.

    The left-hand side is visited with ``**kwargs`` only, while the
    right-hand side additionally receives ``hloop_ctx``. The resulting
    ``nir.AssignStmt`` is handed to ``hloop_ctx.add_statement``.

    Args:
        node: The assignment being lowered.
        hloop_ctx: Context that collects the generated statement.
        **kwargs: Visitor context forwarded to both child visits.

    Returns:
        None. The statement is delivered through ``hloop_ctx``.
    """
    target = self.visit(node.left, **kwargs)
    value = self.visit(node.right, hloop_ctx=hloop_ctx, **kwargs)
    lowered = nir.AssignStmt(
        left=target,
        right=value,
        location_type=node.location_type,
    )
    hloop_ctx.add_statement(lowered)
def visit_NeighborAssignStmt(
    self,
    node: gtir.NeighborAssignStmt,
    *,
    symtable,
    hloop_ctx: "HorizontalLoopContext",
    **kwargs,
):
    """Lower a per-neighbor assignment into a single-statement neighbor loop.

    Builds a ``nir.NeighborLoop`` named after ``node.neighbors`` whose body
    assigns the visited right-hand side to a ``nir.FieldAccess`` on
    ``node.left``, and adds that loop to ``hloop_ctx``.

    Args:
        node: The neighbor assignment being lowered.
        symtable: Enclosing symbol table. The right-hand side is visited with
            this table merged with ``node.symtable_``; entries from
            ``node.symtable_`` win on key clashes.
        hloop_ctx: Context that collects the generated loop; also forwarded
            to the right-hand-side visit.
        **kwargs: Further visitor context forwarded to the right-hand-side
            visit.

    Returns:
        None. The loop is delivered through ``hloop_ctx``.
    """
    merged_symtable = {**symtable, **node.symtable_}
    loop_name = node.neighbors.name

    loop_var = nir.NeighborLoopVar(name=loop_name)
    connectivity_name = node.neighbors.of.name

    # The written field is addressed by the first subscript of the left-hand
    # side (primary) and by the loop variable (secondary).
    target = nir.FieldAccess(
        name=node.left.name,
        primary=node.left.subscript[0].name,
        secondary=loop_name,
        location_type=node.location_type,
    )
    value = self.visit(
        node.right,
        loop_var=loop_name,
        symtable=merged_symtable,
        hloop_ctx=hloop_ctx,
        **kwargs,
    )
    assignment = nir.AssignStmt(
        left=target,
        right=value,
        location_type=node.location_type,
    )
    loop_body = nir.BlockStmt(
        declarations=[],
        statements=[assignment],
        location_type=node.location_type,
    )
    loop = nir.NeighborLoop(
        name=loop_var,
        connectivity=connectivity_name,
        body=loop_body,
        location_type=node.location_type,
    )
    hloop_ctx.add_statement(loop)
def visit_AssignStmt(self, node: gtir.AssignStmt, **kwargs):
    """Lower an assignment into the equivalent ``nir.AssignStmt``.

    Args:
        node: The assignment being lowered.
        **kwargs: Visitor context forwarded unchanged to both child visits.

    Returns:
        A ``nir.AssignStmt`` built from the visited left and right children,
        carrying the location type of ``node``.
    """
    target = self.visit(node.left, **kwargs)
    value = self.visit(node.right, **kwargs)
    return nir.AssignStmt(
        left=target,
        right=value,
        location_type=node.location_type,
    )
def visit_NeighborReduce(self, node: gtir.NeighborReduce, *, hloop_ctx: "HorizontalLoopContext", **kwargs):
    """Lower a neighbor reduction into an accumulator variable plus a neighbor loop.

    Registers on ``hloop_ctx``: a ``nir.LocalVar`` declaration for the
    accumulator, an assignment initialising it from
    ``self.REDUCE_OP_INIT_VAL[node.op]``, and a ``nir.NeighborLoop`` over the
    connectivity looked up in ``kwargs["symtable"]``. The loop body updates
    the accumulator either with a ``nir.NativeFuncCall`` (MIN/MAX) or with a
    ``nir.BinaryOp``, depending on ``self.REDUCE_OP_TO_BINOP[node.op]``.

    Args:
        node: The reduction node being lowered.
        hloop_ctx: Context collecting the generated declaration and
            statements. It is keyword-only and therefore not part of
            ``kwargs`` forwarded to the nested operand visit.
        **kwargs: Visitor context. Must contain ``"symtable"`` with an entry
            for ``node.neighbors.of.name``.

    Returns:
        A ``nir.VarAccess`` to the accumulator, typed with the location of
        ``node``.
    """
    connectivity: gtir.Connectivity = kwargs["symtable"][node.neighbors.of.name]

    # NOTE(review): the name is derived from id(node), so it presumably varies
    # between runs - confirm that generated names need not be reproducible.
    acc_name = f"local{id(node)!s}"
    outer_loc = node.location_type

    def accumulator(location):
        # A fresh access node each time; tree nodes are never shared.
        return nir.VarAccess(name=acc_name, location_type=location)

    # TODO (kept from the original): the accumulator type is hard-coded to
    # FLOAT64 here and in the initial literal below.
    hloop_ctx.add_declaration(
        nir.LocalVar(
            name=acc_name,
            vtype=common.DataType.FLOAT64,
            location_type=outer_loc,
        )
    )

    init_target = accumulator(outer_loc)
    init_value = nir.Literal(
        value=self.REDUCE_OP_INIT_VAL[node.op],
        location_type=outer_loc,
        vtype=common.DataType.FLOAT64,  # TODO
    )
    hloop_ctx.add_statement(
        nir.AssignStmt(left=init_target, right=init_value, location_type=outer_loc)
    )

    # The loop body lives on the secondary location of the connectivity.
    inner_loc = connectivity.secondary
    op = self.REDUCE_OP_TO_BINOP[node.op]
    uses_native_func = (
        op == common.BuiltInLiteral.MIN_VALUE or op == common.BuiltInLiteral.MAX_VALUE
    )

    if not uses_native_func:
        # acc = acc <op> operand
        combined = nir.BinaryOp(
            left=accumulator(inner_loc),
            op=op,
            right=self.visit(node.operand, in_neighbor_loop=True, **kwargs),
            location_type=inner_loc,
        )
    else:
        # acc = min(acc, operand) or acc = max(acc, operand)
        if op == common.BuiltInLiteral.MAX_VALUE:
            native_func = common.NativeFunction.MAX
        else:
            native_func = common.NativeFunction.MIN
        combined = nir.NativeFuncCall(
            func=native_func,
            args=[
                accumulator(inner_loc),
                self.visit(node.operand, in_neighbor_loop=True, **kwargs),
            ],
            location_type=inner_loc,
        )

    update = nir.AssignStmt(
        left=accumulator(inner_loc),
        right=combined,
        location_type=inner_loc,
    )
    loop_body = nir.BlockStmt(
        declarations=[],
        statements=[update],
        location_type=inner_loc,
    )
    hloop_ctx.add_statement(
        nir.NeighborLoop(
            name=nir.NeighborLoopVar(name=node.neighbors.name),
            connectivity=connectivity.name,
            body=loop_body,
            location_type=outer_loc,
        )
    )

    return accumulator(outer_loc)  # TODO