def visit_Symbol(self, node: Symbol, *, location_stack):
    """Lower a bare ``Symbol`` reference to its GTIR counterpart.

    ``self.symbol_table`` maps names to types, which are checked with ``issubclass``.

    - A symbol typed as ``Field`` or ``TemporaryField`` becomes a ``gtir.FieldAccess``:
      - It is subscripted by a ``gtir.LocationRef`` that uses the ``name`` of the
        outermost (first) entry of ``location_stack``.
      - Its ``location_type`` is the last element of the innermost (last) entry's
        ``chain.elements``.
    - A symbol typed as ``Location`` becomes a ``gtir.LocationRef``.

    Raises:
        AssertionError: if ``node.name`` is not in the symbol table.
        ValueError: if the symbol's type is neither a (temporary) field nor a location.
    """
    assert node.name in self.symbol_table, f"Unknown symbol '{node.name}'"
    symbol_type = self.symbol_table[node.name]

    if issubclass(symbol_type, (Field, TemporaryField)):
        return gtir.FieldAccess(
            name=node.name,
            location_type=location_stack[-1].chain.elements[-1],
            # TODO(tehrengruber): just visit the subscript symbol
            subscript=[gtir.LocationRef(name=location_stack[0].name)],
        )
    if issubclass(symbol_type, Location):
        return gtir.LocationRef(name=node.name)

    raise ValueError(
        f"Cannot lower symbol '{node.name}' of type {symbol_type}: "
        "expected a field, temporary field or location"
    )
def visit_Subscript(self, node: Subscript, *, symtable, location_stack, **kwargs):
    """Lower a subscripted field expression (``field[loc]`` / ``sparse_field[loc, nbh]``).

    The subscripted symbol must meet two conditions:
    - It is declared as an ``Argument``, ``TemporaryFieldDecl`` or
      ``TemporarySparseFieldDecl``.
    - Its ``type_`` is a (temporary) dense or sparse field.

    Every index must be a ``SymbolRef`` to a ``LocationSpecification`` or
    ``LocationComprehension``. The deduced index types must equal:
    - ``(Location[primary], Location[secondary])`` of the field's connectivity,
      for sparse fields;
    - ``(Location[loc],)``, for dense fields.

    Returns:
        A ``gtir.FieldAccess`` with one ``gtir.LocationRef`` per index and
        ``location_type`` taken from ``location_stack[-1][1]``.

    Raises:
        AssertionError: if the index count, index kinds or index types are inconsistent.
        RuntimeError: if the field type is not one of the handled field kinds.
        ValueError: if the subscripted symbol is not a field declaration.
    """
    value_decl = symtable[node.value.name]
    if isinstance(value_decl, (Argument, TemporaryFieldDecl, TemporarySparseFieldDecl)) and (
        issubclass(value_decl.type_, (Field, TemporaryField, SparseField, TemporarySparseField))
    ):
        assert len(node.indices) in [1, 2], (
            f"Expected 1 or 2 indices in subscript expression ({node}), "
            f"got {len(node.indices)}"
        )
        assert all(
            isinstance(index, SymbolRef) for index in node.indices
        ), f"Subscript indices must be symbol references ({node})"
        assert all(
            isinstance(
                symtable[index.name],  # type: ignore[union-attr]
                (LocationSpecification, LocationComprehension),
            )
            for index in node.indices
        ), f"Subscript indices must refer to location specifications or comprehensions ({node})"

        # check arguments for consistency
        # TODO: lower IRs should check this too, currently without this check they just generate invalid code
        # NOTE(review): these are `assert`s, so they disappear under `python -O`.
        expected_index_types: Tuple[BuiltInTypeMeta, ...]
        if issubclass(value_decl.type_, SparseField) or issubclass(
            value_decl.type_, TemporarySparseField
        ):
            connectivity = value_decl.type_.args[0]
            assert issubclass(
                connectivity, Connectivity
            ), f"Sparse field '{node.value.name}' is not parametrized by a connectivity"
            expected_index_types = (
                Location[connectivity.primary_location()],
                Location[connectivity.secondary_location()],
            )
        elif issubclass(value_decl.type_, Field) or issubclass(
            value_decl.type_, TemporaryField
        ):
            expected_index_types = (Location[value_decl.type_.args[0]],)
        else:
            raise RuntimeError(
                f"Invalid symbol '{node.value.name}' in subscript expression ({node})"
            )

        index_types = tuple(deduce_type(symtable, idx) for idx in node.indices)
        assert index_types == expected_index_types, (
            f"Index types {index_types} do not match expected {expected_index_types} "
            f"in subscript expression ({node})"
        )

        # TODO(tehrengruber): just visit the index symbol
        return gtir.FieldAccess(
            name=gtir.SymbolRef(node.value.name),
            subscript=[
                gtir.LocationRef(name=index.name)
                for index in cast(List[SymbolRef], node.indices)
            ],
            location_type=location_stack[-1][1],
        )

    raise ValueError(
        f"Cannot lower subscript of '{node.value.name}' ({node}): "
        "the subscripted symbol is not a field declaration"
    )
def visit_SymbolRef(self, node: SymbolRef, *, symtable, location_stack, **kwargs):
    """Lower an unsubscripted ``SymbolRef`` to GTIR.

    These lowerings apply when the declaration is an ``Argument``,
    ``TemporaryFieldDecl`` or ``TemporarySparseFieldDecl``:

    - A (temporary) dense or sparse field type becomes a ``gtir.FieldAccess``:
      - ``location_type`` is ``location_stack[-1][1]``.
      - It is subscripted by a ``gtir.LocationRef`` named ``location_stack[0][0]``.
      - ``location_stack`` entries look like ``(name, location_type)`` pairs,
        judging from this indexing; confirm against the caller.
    - A ``Connectivity`` type becomes a ``gtir.SymbolRef``.

    Raises:
        AssertionError: if ``node.name`` is not in ``symtable``.
        ValueError: for any other kind of declaration or declared type.
    """
    assert node.name in symtable, f"Unknown symbol '{node.name}'"
    decl = symtable[node.name]
    # todo: Argument and TemporaryFieldDecl should have same base class
    # TODO: SparseField's should have slice syntax
    if isinstance(decl, (Argument, TemporaryFieldDecl, TemporarySparseFieldDecl)):
        if issubclass(
            decl.type_,
            (Field, SparseField, TemporaryField, TemporarySparseField),
        ):
            return gtir.FieldAccess(
                name=node.name,
                location_type=location_stack[-1][1],
                # TODO(tehrengruber): just visit the subscript symbol
                subscript=[gtir.LocationRef(name=location_stack[0][0])],
            )
        elif issubclass(decl.type_, Connectivity):
            return gtir.SymbolRef(node.name)

    raise ValueError(
        f"Cannot lower reference to '{node.name}' (declaration: {type(decl).__name__}): "
        "expected a field or connectivity declaration"
    )
def visit_SubscriptMultiple(self, node: SubscriptMultiple, *, location_stack):
    """Lower a multi-index subscript on a (temporary) field to ``gtir.FieldAccess``.

    ``self.symbol_table`` maps names to types, which are checked with ``issubclass``.
    Each index must be a ``Symbol`` whose type is a ``Location``. The access has:
    - one ``gtir.LocationRef`` per index;
    - a ``location_type`` equal to the last element of the innermost (last)
      ``location_stack`` entry's ``chain.elements``.

    Raises:
        AssertionError: if the subscripted name is unknown or an index is not a location symbol.
        ValueError: if the subscripted symbol is not a ``Field``/``TemporaryField``.
    """
    assert node.value.name in self.symbol_table, f"Unknown symbol '{node.value.name}'"
    value_type = self.symbol_table[node.value.name]

    if issubclass(value_type, (Field, TemporaryField)):
        assert all(
            isinstance(index, Symbol) and issubclass(self.symbol_table[index.name], Location)
            for index in node.indices
        ), f"All indices of the subscript on '{node.value.name}' must be location symbols"
        # TODO(tehrengruber): just visit the index symbol
        return gtir.FieldAccess(
            name=node.value.name,
            subscript=[
                gtir.LocationRef(name=index.name) for index in cast(List[Symbol], node.indices)
            ],
            location_type=location_stack[-1].chain.elements[-1],
        )

    raise ValueError(
        f"Cannot lower subscript of symbol '{node.value.name}' of type {value_type}: "
        "expected a field or temporary field"
    )
def visit_SubscriptCall(
    self,
    node: SubscriptCall,
    symtable,
    location_stack,
    inside_sparse_assign,
    neighbor_vector_access_expr_location_name,
    **kwargs,
):
    """Lower a call whose single argument is a list literal to ``gtir.NeighborVectorAccess``.

    Only allowed while ``inside_sparse_assign`` is truthy. Each element of the
    list literal is visited with the remaining ``kwargs`` plus ``symtable`` and
    ``location_stack``. The resulting access refers to
    ``neighbor_vector_access_expr_location_name``.
    """
    # TODO(tehrengruber): rework after NeighborVectorAccess in GTIR redesign
    assert inside_sparse_assign
    assert len(node.args) == 1
    list_arg = node.args[0]
    assert isinstance(list_arg, List_)

    # The callee is evaluated as a constant expression; its value is not used
    # below. The call is kept in case evaluation itself rejects invalid callees
    # (not verified here).
    evaluate_const_expr(symtable, node.func)

    visit_kwargs = dict(kwargs, symtable=symtable, location_stack=location_stack)
    element_exprs = self.visit(list_arg.elts, **visit_kwargs)
    location_ref = gtir.LocationRef(name=neighbor_vector_access_expr_location_name)
    return gtir.NeighborVectorAccess(
        exprs=element_exprs,
        location_ref=location_ref,
        # NOTE(review): the original author marked this location type as wrong;
        # it uses the outermost location_stack entry. Confirm the intended location.
        location_type=location_stack[0][1],
    )