def visit_ScalarAccess(
    self, node: oir.ScalarAccess, *, symtable: Dict[str, Any], **kwargs: Any
) -> Union[cuir.ScalarAccess, cuir.FieldAccess]:
    """Lower an OIR scalar access to its CUIR counterpart.

    Names that ``symtable`` maps to an ``oir.ScalarDecl`` are lowered to a
    zero-offset ``cuir.FieldAccess``. Presumably this is because scalar
    declarations are exposed to the generated code like fields (confirm
    against the code generator). Every other name becomes a plain
    ``cuir.ScalarAccess``.

    Args:
        node: The OIR scalar access being translated.
        symtable: Mapping from symbol names to their declarations; names not
            present are treated like any non-``ScalarDecl`` entry.
        **kwargs: Additional visitor context; not used by this method.

    Returns:
        A ``cuir.FieldAccess`` (zero offset) when the name is declared by an
        ``oir.ScalarDecl``, otherwise a ``cuir.ScalarAccess``. Both carry
        ``node.dtype``. The return annotation lists both types because both
        branches are reachable.
    """
    if isinstance(symtable.get(node.name), oir.ScalarDecl):
        return cuir.FieldAccess(
            name=node.name,
            offset=common.CartesianOffset.zero(),
            dtype=node.dtype,
        )
    return cuir.ScalarAccess(name=node.name, dtype=node.dtype)
def make_positional(self, axis: int) -> cuir.FieldAccess:
    """Return an INT32, zero-offset field access to the index of ``axis``.

    One ``cuir.Positional`` is created per axis and cached in
    ``self.positionals``; later calls for the same axis reuse the cached
    positional and therefore refer to the same symbol name.

    Args:
        axis: Axis number, used to index ``("i", "j", "k")``.

    Returns:
        A ``cuir.FieldAccess`` naming the cached positional for ``axis``.

    Raises:
        IndexError: If ``axis`` is not a valid index into ``("i", "j", "k")``.
    """
    axis_name = ("i", "j", "k")[axis]
    if axis not in self.positionals:
        # Build the positional only on a cache miss. The previous
        # ``dict.setdefault(axis, cuir.Positional(...))`` evaluated its
        # default eagerly, so ``new_symbol_name`` was called (and a discarded
        # Positional built) on every call, even for already-cached axes.
        self.positionals[axis] = cuir.Positional(
            name=self.new_symbol_name(f"axis_{axis_name}_index"),
            axis_name=axis_name,
        )
    positional = self.positionals[axis]
    return cuir.FieldAccess(
        name=positional.name,
        offset=common.CartesianOffset.zero(),
        dtype=common.DataType.INT32,
    )
def visit_FieldAccess(
    self,
    node: oir.FieldAccess,
    *,
    ij_caches: Dict[str, cuir.IJCacheDecl],
    k_caches: Dict[str, cuir.KCacheDecl],
    ctx: "Context",
    **kwargs: Any,
) -> Union[cuir.FieldAccess, cuir.IJCacheAccess, cuir.KCacheAccess]:
    """Lower an OIR field access, redirecting cached fields to their cache.

    ``node.data_index`` and then ``node.offset`` are lowered with the same
    visitor context. If ``node.name`` has an IJ cache declaration, the result
    is a ``cuir.IJCacheAccess`` using that cache's name. Otherwise, if it has
    a K cache declaration, the result is a ``cuir.KCacheAccess``. Any other
    field is recorded in ``ctx.accessed_fields`` and lowered to a plain
    ``cuir.FieldAccess``.
    """
    visit_context = dict(ij_caches=ij_caches, k_caches=k_caches, ctx=ctx, **kwargs)
    lowered_index = self.visit(node.data_index, **visit_context)
    lowered_offset = self.visit(node.offset, **visit_context)

    # Checked in order, so an IJ cache wins if a name is in both mappings.
    cache_kinds = (
        (ij_caches, cuir.IJCacheAccess),
        (k_caches, cuir.KCacheAccess),
    )
    for caches, access_type in cache_kinds:
        if node.name in caches:
            return access_type(
                name=caches[node.name].name,
                offset=lowered_offset,
                dtype=node.dtype,
                data_index=lowered_index,
            )

    # Cached fields return above; only uncached ones are recorded here.
    ctx.accessed_fields.add(node.name)
    return cuir.FieldAccess(
        name=node.name,
        offset=lowered_offset,
        data_index=lowered_index,
        dtype=node.dtype,
    )
def visit_FieldAccess(
    self,
    node: oir.FieldAccess,
    *,
    ij_caches: Dict[str, cuir.IJCacheDecl],
    k_caches: Dict[str, cuir.KCacheDecl],
    accessed_fields: Set[str],
    **kwargs: Any,
) -> Union[cuir.FieldAccess, cuir.IJCacheAccess, cuir.KCacheAccess]:
    """Lower an OIR field access, redirecting cached fields to their cache.

    An IJ cache declaration for ``node.name`` takes precedence and yields a
    ``cuir.IJCacheAccess``. Otherwise a K cache declaration yields a
    ``cuir.KCacheAccess``. Fields with neither are added to
    ``accessed_fields`` and lowered to a ``cuir.FieldAccess``. The OIR offset
    is passed through unchanged in every case.
    """
    field_name = node.name
    if field_name in ij_caches:
        access_type, target_name = cuir.IJCacheAccess, ij_caches[field_name].name
    elif field_name in k_caches:
        access_type, target_name = cuir.KCacheAccess, k_caches[field_name].name
    else:
        # Only uncached fields are reported back to the caller.
        accessed_fields.add(field_name)
        access_type, target_name = cuir.FieldAccess, field_name
    return access_type(name=target_name, offset=node.offset, dtype=node.dtype)