예제 #1
0
    def visit_Symbol(self, node: Symbol, *, location_stack):
        """Lower a bare symbol to its GTIR counterpart.

        The symbol's entry in ``self.symbol_table`` decides the result:

        * a ``Field`` / ``TemporaryField`` type yields a ``gtir.FieldAccess``.
          It is subscripted by a ``gtir.LocationRef`` to the name of the first
          ``location_stack`` entry. Its ``location_type`` is the last element of
          the chain of the last ``location_stack`` entry.
        * a ``Location`` type yields a ``gtir.LocationRef`` to the symbol itself.

        Raises:
            ValueError: if the symbol's type is neither field-like nor a location.
        """
        assert node.name in self.symbol_table
        symbol_type = self.symbol_table[node.name]

        if issubclass(symbol_type, (Field, TemporaryField)):
            # TODO(tehrengruber): just visit the subscript symbol
            innermost_location = location_stack[-1].chain.elements[-1]
            return gtir.FieldAccess(
                name=node.name,
                location_type=innermost_location,
                subscript=[gtir.LocationRef(name=location_stack[0].name)],
            )

        if issubclass(symbol_type, Location):
            return gtir.LocationRef(name=node.name)

        raise ValueError()
예제 #2
0
    def visit_Subscript(self, node: Subscript, *, symtable, location_stack, **kwargs):
        """Lower a subscripted field access to a ``gtir.FieldAccess``.

        Handles only values declared as ``Argument``, ``TemporaryFieldDecl`` or
        ``TemporarySparseFieldDecl`` whose ``type_`` is a ``Field``,
        ``TemporaryField``, ``SparseField`` or ``TemporarySparseField``.

        The subscript must have one or two indices. Each index must be a
        ``SymbolRef`` naming a ``LocationSpecification`` or
        ``LocationComprehension``.

        The deduced index types must equal the field's expected locations:
        ``(Location[args[0]],)`` for a dense field, or the connectivity's
        primary and secondary locations for a sparse field.

        Raises:
            AssertionError: if the indices are malformed or their types do not
                match the field's locations.
            RuntimeError: defensive; the field type is none of the handled kinds.
            ValueError: if the subscripted symbol is not a field declaration.
        """
        value_decl = symtable[node.value.name]
        if isinstance(value_decl, (Argument, TemporaryFieldDecl, TemporarySparseFieldDecl)) and (
            issubclass(value_decl.type_, (Field, TemporaryField, SparseField, TemporarySparseField))
        ):
            assert len(node.indices) in (1, 2), (
                f"Expected one or two indices in subscript expression ({node}), "
                f"got {len(node.indices)}"
            )
            assert all(isinstance(index, SymbolRef) for index in node.indices), (
                f"All indices in subscript expression ({node}) must be symbol references"
            )
            assert all(
                isinstance(
                    symtable[index.name],  # type: ignore[union-attr]
                    (LocationSpecification, LocationComprehension),
                )
                for index in node.indices
            ), (
                f"All indices in subscript expression ({node}) must refer to a location "
                "specification or a location comprehension"
            )

            # check arguments for consistency
            # TODO: lower IRs should check this too, currently without this check they just generate invalid code
            # NOTE(review): this check is an `assert`, so it is skipped under `python -O`.
            expected_index_types: Tuple[BuiltInTypeMeta, ...]
            if issubclass(value_decl.type_, (SparseField, TemporarySparseField)):
                # A sparse field is indexed by the connectivity's primary and
                # secondary locations, in that order.
                connectivity = value_decl.type_.args[0]
                assert issubclass(connectivity, Connectivity)
                expected_index_types = (
                    Location[connectivity.primary_location()],
                    Location[connectivity.secondary_location()],
                )
            elif issubclass(value_decl.type_, (Field, TemporaryField)):
                expected_index_types = (Location[value_decl.type_.args[0]],)
            else:
                raise RuntimeError(
                    f"Invalid symbol '{node.value.name}' in subscript expression ({node})"
                )

            index_types = tuple(deduce_type(symtable, idx) for idx in node.indices)
            assert index_types == expected_index_types, (
                f"Invalid indices in subscript expression ({node}): expected index types "
                f"{expected_index_types}, got {index_types}"
            )

            # TODO(tehrengruber): just visit the index symbol
            return gtir.FieldAccess(
                name=gtir.SymbolRef(node.value.name),
                subscript=[
                    gtir.LocationRef(name=index.name)
                    for index in cast(List[SymbolRef], node.indices)
                ],
                # Item 1 of the last location_stack entry is the access's location type.
                location_type=location_stack[-1][1],
            )

        raise ValueError(
            f"Invalid symbol '{node.value.name}' in subscript expression ({node}): "
            "only fields can be subscripted"
        )
예제 #3
0
    def visit_SymbolRef(self, node: SymbolRef, *, symtable, location_stack, **kwargs):
        """Lower a reference to a declared symbol to GTIR.

        Only ``Argument``, ``TemporaryFieldDecl`` and ``TemporarySparseFieldDecl``
        declarations are accepted.

        * A (temporary, sparse) field type becomes a ``gtir.FieldAccess``. It is
          subscripted by the location named in item 0 of the first
          ``location_stack`` entry, and typed by item 1 of the last entry.
        * A ``Connectivity`` type becomes a ``gtir.SymbolRef``.

        Raises:
            ValueError: for any other declaration kind or type.
        """
        assert node.name in symtable
        decl = symtable[node.name]

        # todo: Argument and TemporaryFieldDecl should have same base class
        # TODO: SparseField's should have slice syntax
        if not isinstance(decl, (Argument, TemporaryFieldDecl, TemporarySparseFieldDecl)):
            raise ValueError()

        field_types = (Field, SparseField, TemporaryField, TemporarySparseField)
        if issubclass(decl.type_, field_types):
            # TODO(tehrengruber): just visit the subscript symbol
            return gtir.FieldAccess(
                name=node.name,
                location_type=location_stack[-1][1],
                subscript=[gtir.LocationRef(name=location_stack[0][0])],
            )

        if issubclass(decl.type_, Connectivity):
            return gtir.SymbolRef(node.name)

        raise ValueError()
예제 #4
0
    def visit_SubscriptMultiple(self, node: SubscriptMultiple, *,
                                location_stack):
        """Lower a multi-index field subscript to a ``gtir.FieldAccess``.

        The subscripted value must have a ``Field`` / ``TemporaryField`` type
        in ``self.symbol_table``. Every index must be a ``Symbol`` whose table
        entry is a ``Location``; each index becomes one ``gtir.LocationRef``.

        The access's ``location_type`` is the last element of the chain of the
        last ``location_stack`` entry.

        Raises:
            ValueError: if the subscripted value is not field-like.
        """
        value_name = node.value.name
        assert value_name in self.symbol_table
        value_type = self.symbol_table[value_name]

        if not issubclass(value_type, (Field, TemporaryField)):
            raise ValueError()

        assert all(
            isinstance(idx, Symbol)
            and issubclass(self.symbol_table[idx.name], Location)
            for idx in node.indices)
        location_symbols = cast(List[Symbol], node.indices)

        # TODO(tehrengruber): just visit the index symbol
        return gtir.FieldAccess(
            name=value_name,
            subscript=[gtir.LocationRef(name=idx.name) for idx in location_symbols],
            location_type=location_stack[-1].chain.elements[-1],
        )
예제 #5
0
    def visit_SubscriptCall(
        self,
        node: SubscriptCall,
        symtable,
        location_stack,
        inside_sparse_assign,
        neighbor_vector_access_expr_location_name,
        **kwargs,
    ):
        """Lower a call with a single list-literal argument to a ``gtir.NeighborVectorAccess``.

        Only valid inside a sparse assignment (``inside_sparse_assign`` must be
        truthy). The list elements are visited with ``symtable`` and
        ``location_stack`` forwarded, and become the access's ``exprs``. The
        access's ``location_ref`` names ``neighbor_vector_access_expr_location_name``.
        """
        # TODO(tehrengruber): rework after NeighborVectorAccess in GTIR redesign
        assert inside_sparse_assign
        assert len(node.args) == 1
        list_arg = node.args[0]
        assert isinstance(list_arg, List_)

        # The evaluated value is not used. NOTE(review): presumably called so that
        # a non-constant callee is rejected — confirm against evaluate_const_expr.
        evaluate_const_expr(symtable, node.func)

        child_kwargs = {**kwargs, "symtable": symtable, "location_stack": location_stack}
        element_exprs = self.visit(list_arg.elts, **child_kwargs)
        return gtir.NeighborVectorAccess(
            exprs=element_exprs,
            location_ref=gtir.LocationRef(name=neighbor_vector_access_expr_location_name),
            # NOTE(review): the original author marked this value as wrong. It takes
            # item 1 of the FIRST location_stack entry.
            location_type=location_stack[0][1],  # wrong
        )