Ejemplo n.º 1
0
 def visit_ScalarAccess(self, node: oir.ScalarAccess, *,
                        symtable: Dict[str, Any],
                        **kwargs: Any) -> "Union[cuir.FieldAccess, cuir.ScalarAccess]":
     """Lower an OIR scalar access to the matching CUIR access node.

     Args:
         node: The scalar access being lowered; its ``name`` and ``dtype``
             are copied to the result.
         symtable: Mapping from symbol names to their declarations.
         **kwargs: Accepted for visitor-signature compatibility; unused here.

     Returns:
         A zero-offset ``cuir.FieldAccess`` when ``symtable`` maps
         ``node.name`` to an ``oir.ScalarDecl``; otherwise a
         ``cuir.ScalarAccess``. Names missing from ``symtable`` take the
         ``cuir.ScalarAccess`` path.
     """
     # The return annotation is quoted so it is not evaluated at definition
     # time; it documents that either node type may be produced.
     declaration = symtable.get(node.name)
     if isinstance(declaration, oir.ScalarDecl):
         # NOTE(review): presumably such declarations are addressed like
         # fields downstream, hence the zero-offset field access - confirm.
         return cuir.FieldAccess(
             name=node.name,
             offset=common.CartesianOffset.zero(),
             dtype=node.dtype,
         )
     return cuir.ScalarAccess(name=node.name, dtype=node.dtype)
Ejemplo n.º 2
0
 def make_positional(self, axis: int) -> cuir.FieldAccess:
     """Return a field access that refers to the positional of ``axis``.

     The ``cuir.Positional`` for an axis is created the first time that axis
     is requested and stored in ``self.positionals``; later requests reuse
     the stored entry.

     Args:
         axis: Index into ``["i", "j", "k"]`` selecting the axis name. Note
             that Python list indexing also accepts negative indices, which
             would be cached under their own (negative) key.

     Returns:
         A zero-offset ``cuir.FieldAccess`` of dtype ``INT32`` whose name is
         the name of the axis' positional.

     Raises:
         IndexError: If ``axis`` is outside the range of the axis list.
     """
     axis_name = ["i", "j", "k"][axis]
     try:
         positional = self.positionals[axis]
     except KeyError:
         # Build the positional only on a cache miss. Passing it as the
         # default of ``dict.setdefault`` would evaluate
         # ``self.new_symbol_name`` on every call, even when the result is
         # discarded.
         positional = cuir.Positional(
             name=self.new_symbol_name(f"axis_{axis_name}_index"),
             axis_name=axis_name,
         )
         self.positionals[axis] = positional
     return cuir.FieldAccess(
         name=positional.name,
         offset=common.CartesianOffset.zero(),
         dtype=common.DataType.INT32,
     )
Ejemplo n.º 3
0
 def visit_FieldAccess(
     self,
     node: oir.FieldAccess,
     *,
     ij_caches: Dict[str, cuir.IJCacheDecl],
     k_caches: Dict[str, cuir.KCacheDecl],
     ctx: "Context",
     **kwargs: Any,
 ) -> Union[cuir.FieldAccess, cuir.IJCacheAccess, cuir.KCacheAccess]:
     """Lower an OIR field access to a CUIR field or cache access.

     The node's ``data_index`` and ``offset`` are visited first (in that
     order), forwarding the same keyword arguments this method received.

     Args:
         node: The field access being lowered.
         ij_caches: Cache declarations keyed by field name; checked first.
         k_caches: Cache declarations keyed by field name; checked second.
         ctx: Context whose ``accessed_fields`` set records the name of
             every field that is not served from a cache.
         **kwargs: Extra keyword arguments forwarded to the child visits.

     Returns:
         A ``cuir.IJCacheAccess`` or ``cuir.KCacheAccess`` named after the
         matching cache declaration, or a plain ``cuir.FieldAccess`` when
         the field is in neither cache mapping.
     """
     child_kwargs = dict(ij_caches=ij_caches, k_caches=k_caches, ctx=ctx, **kwargs)
     lowered_data_index = self.visit(node.data_index, **child_kwargs)
     lowered_offset = self.visit(node.offset, **child_kwargs)

     field_name = node.name
     # IJ caches take precedence over K caches, matching the lookup order.
     cache_kinds = (
         (ij_caches, cuir.IJCacheAccess),
         (k_caches, cuir.KCacheAccess),
     )
     for caches, access_cls in cache_kinds:
         if field_name in caches:
             return access_cls(
                 name=caches[field_name].name,
                 offset=lowered_offset,
                 dtype=node.dtype,
                 data_index=lowered_data_index,
             )

     # Only uncached accesses are recorded as accessed fields.
     ctx.accessed_fields.add(field_name)
     return cuir.FieldAccess(
         name=field_name,
         offset=lowered_offset,
         data_index=lowered_data_index,
         dtype=node.dtype,
     )
Ejemplo n.º 4
0
 def visit_FieldAccess(
     self,
     node: oir.FieldAccess,
     *,
     ij_caches: Dict[str, cuir.IJCacheDecl],
     k_caches: Dict[str, cuir.KCacheDecl],
     accessed_fields: Set[str],
     **kwargs: Any,
 ) -> Union[cuir.FieldAccess, cuir.IJCacheAccess, cuir.KCacheAccess]:
     """Lower an OIR field access to a CUIR field or cache access.

     Args:
         node: The field access being lowered; its ``offset`` and ``dtype``
             are copied unchanged to the result.
         ij_caches: Cache declarations keyed by field name; checked first.
         k_caches: Cache declarations keyed by field name; checked second.
         accessed_fields: Set that receives the name of every field that is
             not served from a cache (mutated in place).
         **kwargs: Accepted for visitor-signature compatibility; unused here.

     Returns:
         A ``cuir.IJCacheAccess`` or ``cuir.KCacheAccess`` named after the
         matching cache declaration, or a plain ``cuir.FieldAccess`` when
         the field is in neither cache mapping.
     """
     field_name = node.name
     shared_args = {"offset": node.offset, "dtype": node.dtype}

     if field_name in ij_caches:
         return cuir.IJCacheAccess(name=ij_caches[field_name].name, **shared_args)

     if field_name in k_caches:
         return cuir.KCacheAccess(name=k_caches[field_name].name, **shared_args)

     # Not cached: record the field as directly accessed.
     accessed_fields.add(field_name)
     return cuir.FieldAccess(name=field_name, **shared_args)