Example #1
0
 def _get_access_collection(
     self, node: "Union[HorizontalExecutionLibraryNode, VerticalLoopLibraryNode, SDFG]"
 ) -> AccessCollector.CartesianAccessCollection:
     """Collect the cartesian accesses found in ``node``.

     Dispatches on the type of ``node``:

     * ``SDFG``: the nodes of the first state (``states()[0]``) are scanned;
       every horizontal-execution or vertical-loop library node found there
       is processed recursively and its ordered accesses are appended to a
       fresh collection.
     * ``HorizontalExecutionLibraryNode``: the result of
       ``AccessCollector.apply(node.oir_node).cartesian_accesses()`` is
       returned; it is memoized in ``self._access_collection_cache`` under
       ``id(node.oir_node)``.
     * anything else is asserted to be a ``VerticalLoopLibraryNode``; the
       SDFG of each entry in ``node.sections`` is processed recursively and
       the results are concatenated.
     """
     if isinstance(node, SDFG):
         merged = AccessCollector.CartesianAccessCollection([])
         # Only the first state of the SDFG is looked at.
         for child in node.states()[0].nodes():
             if not isinstance(
                 child, (HorizontalExecutionLibraryNode, VerticalLoopLibraryNode)
             ):
                 continue
             child_accesses = self._get_access_collection(child)
             merged._ordered_accesses.extend(child_accesses._ordered_accesses)
         return merged

     if isinstance(node, HorizontalExecutionLibraryNode):
         cache = self._access_collection_cache
         # The cache is keyed by object identity of the OIR node.
         key = id(node.oir_node)
         if key not in cache:
             cache[key] = AccessCollector.apply(node.oir_node).cartesian_accesses()
         return cache[key]

     assert isinstance(node, VerticalLoopLibraryNode)
     merged = AccessCollector.CartesianAccessCollection([])
     for _, section_sdfg in node.sections:
         section_accesses = self._get_access_collection(section_sdfg)
         merged._ordered_accesses.extend(section_accesses._ordered_accesses)
     return merged
Example #2
0
def offsets_match(left: HorizontalExecutionLibraryNode,
                  right: HorizontalExecutionLibraryNode) -> bool:
    """Return ``True`` when no offset conflict is reported between the nodes.

    The accesses of both nodes are collected from their OIR representation
    (``as_oir()``). Two checks are run on the ``ij_offsets`` of those
    accesses: writes of ``left`` against reads of ``right``
    (read-after-write), and reads of ``left`` against writes of ``right``
    (write-after-read). The results are combined with ``|`` and the function
    returns the negated truthiness of that combination.
    """
    accesses_left = AccessCollector.apply(left.as_oir())
    accesses_right = AccessCollector.apply(right.as_oir())

    raw_conflicts = read_after_write_conflicts(
        ij_offsets(accesses_left.write_offsets()),
        ij_offsets(accesses_right.read_offsets()),
    )
    war_conflicts = write_after_read_conflicts(
        ij_offsets(accesses_left.read_offsets()),
        ij_offsets(accesses_right.write_offsets()),
    )
    return not (raw_conflicts | war_conflicts)
Example #3
0
def nodes_extent_calculation(
    nodes: Collection[Union["VerticalLoopLibraryNode",
                            "HorizontalExecutionLibraryNode"]]
) -> Dict[str, Extent]:
    """Compute, per field name, the union of the extents of all its accesses.

    Vertical-loop nodes are expanded into every ``dace.nodes.LibraryNode``
    found recursively in their section SDFGs; any other input node is
    asserted to be a ``HorizontalExecutionLibraryNode`` and used directly.
    For each resulting node whose ``extent`` is not ``None``, every ordered
    access is converted with ``to_extent(block_extent)``, combined with
    ``Extent.zeros(2)`` via ``|``, and merged into the entry of its field.
    Nodes whose ``extent`` is ``None`` contribute nothing.
    """
    # Function-level import (presumably to avoid a circular import -- confirm).
    from gtc.dace.nodes import HorizontalExecutionLibraryNode, VerticalLoopLibraryNode

    library_nodes = []
    for candidate in nodes:
        if not isinstance(candidate, VerticalLoopLibraryNode):
            assert isinstance(candidate, HorizontalExecutionLibraryNode)
            library_nodes.append(candidate)
            continue
        for _, section_sdfg in candidate.sections:
            library_nodes.extend(
                inner
                for inner, _ in section_sdfg.all_nodes_recursive()
                if isinstance(inner, dace.nodes.LibraryNode)
            )

    extents_by_field: Dict[str, Extent] = {}
    for lib_node in library_nodes:
        accesses = AccessCollector.apply(lib_node.oir_node)
        block_extent = lib_node.extent
        if block_extent is None:
            continue
        for access in accesses.ordered_accesses():
            # NOTE(review): assumes to_extent() never returns None here -- confirm.
            offset_extent = access.to_extent(block_extent) | Extent.zeros(2)
            field_name = access.field
            # Seed the entry on first sight, then merge with the union operator.
            extents_by_field.setdefault(field_name, offset_extent)
            extents_by_field[field_name] |= offset_extent

    return extents_by_field
Example #4
0
def test_access_collector():
    """Check the offsets and ordered accesses collected from a small stencil.

    The stencil has two horizontal executions: the first assigns
    ``tmp = foo[i+1]`` and ``bar = tmp``; the second assigns
    ``baz = tmp[j+1]`` inside a mask statement whose mask reads ``mask`` at
    offset ``(-1, -1, 1)``.
    """
    plain_hexec = HorizontalExecutionFactory(body=[
        AssignStmtFactory(left__name="tmp", right__name="foo", right__offset__i=1),
        AssignStmtFactory(left__name="bar", right__name="tmp"),
    ])
    masked_assign = AssignStmtFactory(
        left__name="baz", right__name="tmp", right__offset__j=1
    )
    mask_access = FieldAccessFactory(
        name="mask",
        dtype=DataType.BOOL,
        offset__i=-1,
        offset__j=-1,
        offset__k=1,
    )
    masked_hexec = HorizontalExecutionFactory(
        body=[MaskStmtFactory(body=[masked_assign], mask=mask_access)]
    )
    stencil = StencilFactory(
        vertical_loops__0__sections__0__horizontal_executions=[
            plain_hexec,
            masked_hexec,
        ],
        declarations=[TemporaryFactory(name="tmp")],
    )

    expected_reads = {
        "tmp": {(0, 0, 0), (0, 1, 0)},
        "foo": {(1, 0, 0)},
        "mask": {(-1, -1, 1)},
    }
    expected_writes = {
        "tmp": {(0, 0, 0)},
        "bar": {(0, 0, 0)},
        "baz": {(0, 0, 0)},
    }
    expected_offsets = {
        "tmp": {(0, 0, 0), (0, 1, 0)},
        "foo": {(1, 0, 0)},
        "bar": {(0, 0, 0)},
        "baz": {(0, 0, 0)},
        "mask": {(-1, -1, 1)},
    }
    # Expected order: within each assignment the read is listed before the
    # write, and the mask read is listed before the masked body.
    expected_ordered = [
        GeneralAccess(field="foo", offset=(1, 0, 0), is_write=False),
        GeneralAccess(field="tmp", offset=(0, 0, 0), is_write=True),
        GeneralAccess(field="tmp", offset=(0, 0, 0), is_write=False),
        GeneralAccess(field="bar", offset=(0, 0, 0), is_write=True),
        GeneralAccess(field="mask", offset=(-1, -1, 1), is_write=False),
        GeneralAccess(field="tmp", offset=(0, 1, 0), is_write=False),
        GeneralAccess(field="baz", offset=(0, 0, 0), is_write=True),
    ]

    collected = AccessCollector.apply(stencil)
    assert collected.read_offsets() == expected_reads
    assert collected.write_offsets() == expected_writes
    assert collected.offsets() == expected_offsets
    assert collected.ordered_accesses() == expected_ordered
Example #5
0
def get_access_collection(
    node: Union[dace.SDFG, "HorizontalExecutionLibraryNode", "VerticalLoopLibraryNode"]
):
    """Collect the accesses found in ``node``.

    * ``dace.SDFG``: horizontal-execution and vertical-loop library nodes in
      the first state (``states()[0]``) are processed recursively and their
      ordered accesses concatenated into a new
      ``AccessCollector.CartesianAccessCollection``.
    * ``HorizontalExecutionLibraryNode``: returns
      ``AccessCollector.apply(node.oir_node)`` unchanged.
    * anything else is asserted to be a ``VerticalLoopLibraryNode``; the SDFG
      of each entry in ``node.sections`` is processed recursively and the
      results are concatenated.
    """
    # Function-level import (presumably to avoid a circular import -- confirm).
    from gtc.dace.nodes import HorizontalExecutionLibraryNode, VerticalLoopLibraryNode

    if isinstance(node, dace.SDFG):
        # Only the first state of the SDFG is looked at.
        children = [
            child
            for child in node.states()[0].nodes()
            if isinstance(child, (HorizontalExecutionLibraryNode, VerticalLoopLibraryNode))
        ]
    elif isinstance(node, HorizontalExecutionLibraryNode):
        # NOTE(review): this branch returns whatever AccessCollector.apply()
        # yields, not an explicitly built CartesianAccessCollection -- confirm
        # the two are compatible for callers.
        return AccessCollector.apply(node.oir_node)
    else:
        assert isinstance(node, VerticalLoopLibraryNode)
        children = [section_sdfg for _, section_sdfg in node.sections]

    merged = AccessCollector.CartesianAccessCollection([])
    for child in children:
        merged._ordered_accesses.extend(get_access_collection(child)._ordered_accesses)
    return merged
Example #6
0
 def _get_access_collection(
     self, node: "Union[HorizontalExecutionLibraryNode, VerticalLoopLibraryNode, SDFG]"
 ) -> AccessCollector.GeneralAccessCollection:
     """Collect the accesses found in ``node``.

     Dispatches on the type of ``node``:

     * ``SDFG``: every node yielded by ``all_nodes_recursive()`` that is a
       horizontal-execution or vertical-loop library node is processed
       recursively; the ordered accesses are appended to a fresh collection.
     * ``HorizontalExecutionLibraryNode``: the result of
       ``AccessCollector.apply(node.oir_node)`` is returned; it is memoized in
       ``self._access_collection_cache`` under ``id(node.oir_node)``.
     * anything else is asserted to be a ``VerticalLoopLibraryNode``; the
       SDFG of each entry in ``node.sections`` is processed recursively and
       the results are concatenated.
     """
     if isinstance(node, SDFG):
         merged = AccessCollector.GeneralAccessCollection([])
         for inner, _ in node.all_nodes_recursive():
             if not isinstance(
                 inner, (HorizontalExecutionLibraryNode, VerticalLoopLibraryNode)
             ):
                 continue
             inner_accesses = self._get_access_collection(inner)
             merged._ordered_accesses.extend(inner_accesses._ordered_accesses)
         return merged

     if isinstance(node, HorizontalExecutionLibraryNode):
         cache = self._access_collection_cache
         # The cache is keyed by object identity of the OIR node.
         key = id(node.oir_node)
         if key not in cache:
             cache[key] = AccessCollector.apply(node.oir_node)
         return cache[key]

     assert isinstance(node, VerticalLoopLibraryNode)
     merged = AccessCollector.GeneralAccessCollection([])
     for _, section_sdfg in node.sections:
         section_accesses = self._get_access_collection(section_sdfg)
         merged._ordered_accesses.extend(section_accesses._ordered_accesses)
     return merged
Example #7
0
def nodes_extent_calculation(
    nodes: Collection[Union["VerticalLoopLibraryNode", "HorizontalExecutionLibraryNode"]]
) -> Dict[str, Tuple[Tuple[int, int], Tuple[int, int]]]:
    """Compute, per field name, the bounding i/j range touched by all accesses.

    Vertical-loop nodes are expanded into every ``dace.nodes.LibraryNode``
    found recursively in their section SDFGs; any other input node is
    asserted to be a ``HorizontalExecutionLibraryNode`` and used directly.
    For each resulting node whose ``iteration_space`` is not ``None``, the
    start/end offsets of its i and j intervals are shifted by each access
    offset, and the per-field minimum of the starts and maximum of the ends
    are accumulated.

    In the returned mapping each field maps to
    ``((-i_start, i_end), (-j_start, j_end))``; the accumulated lower bounds
    are returned with their sign flipped.
    """
    # Function-level import (presumably to avoid a circular import -- confirm).
    from gtc.dace.nodes import HorizontalExecutionLibraryNode, VerticalLoopLibraryNode

    library_nodes = []
    for candidate in nodes:
        if not isinstance(candidate, VerticalLoopLibraryNode):
            assert isinstance(candidate, HorizontalExecutionLibraryNode)
            library_nodes.append(candidate)
            continue
        for _, section_sdfg in candidate.sections:
            library_nodes.extend(
                inner
                for inner, _ in section_sdfg.all_nodes_recursive()
                if isinstance(inner, dace.nodes.LibraryNode)
            )

    bounds_by_field: Dict[str, Tuple[Tuple[int, int], ...]] = {}
    for lib_node in library_nodes:
        accesses = AccessCollector.apply(lib_node.oir_node)
        space = lib_node.iteration_space
        if space is None:
            continue
        for field_name, field_offsets in accesses.offsets().items():
            for offset in field_offsets:
                i_bounds = (
                    space.i_interval.start.offset + offset[0],
                    space.i_interval.end.offset + offset[0],
                )
                j_bounds = (
                    space.j_interval.start.offset + offset[1],
                    space.j_interval.end.offset + offset[1],
                )
                shifted = (i_bounds, j_bounds)
                # A field seen for the first time is merged with itself,
                # which leaves ``shifted`` unchanged.
                known = bounds_by_field.get(field_name, shifted)
                bounds_by_field[field_name] = tuple(
                    (min(old[0], new[0]), max(old[1], new[1]))
                    for old, new in zip(known, shifted)
                )

    return {
        field_name: ((-i_lo, i_hi), (-j_lo, j_hi))
        for field_name, ((i_lo, i_hi), (j_lo, j_hi)) in bounds_by_field.items()
    }
Example #8
0
def test_stencil_extents_region(mask, offset, access_extent):
    """Check block and access extents around a horizontally restricted write.

    The stencil writes ``tmp`` from ``input``, then writes ``tmp`` again from
    ``input`` shifted by ``offset`` in i inside a horizontal restriction with
    the given ``mask``, and finally reads ``tmp`` at i+1 into ``output``.
    The block extent of the second horizontal execution must be
    ``((0, 1), (0, 0))``. The extent of the ``input`` access inside the
    restriction must equal ``access_extent``, or be ``None`` when
    ``access_extent`` is ``None``.
    """
    unrestricted_write = HorizontalExecutionFactory(
        body=[AssignStmtFactory(left__name="tmp", right__name="input")]
    )
    restricted_write = HorizontalExecutionFactory(body=[
        HorizontalRestrictionFactory(
            mask=mask,
            body=[
                AssignStmtFactory(
                    left__name="tmp", right__name="input", right__offset__i=offset
                )
            ],
        ),
    ])
    shifted_read = HorizontalExecutionFactory(body=[
        AssignStmtFactory(left__name="output", right__name="tmp", right__offset__i=1)
    ])
    stencil = StencilFactory(
        vertical_loops__0__sections__0__horizontal_executions=[
            unrestricted_write,
            restricted_write,
            shifted_read,
        ],
        declarations=[TemporaryFactory(name="tmp")],
    )

    block_extents = compute_horizontal_block_extents(stencil)
    hexecs = stencil.vertical_loops[0].sections[0].horizontal_executions
    restriction_accesses = AccessCollector.apply(hexecs[1].body[0])
    # First access to "input" inside the horizontal restriction.
    input_access = next(
        access
        for access in restriction_accesses.ordered_accesses()
        if access.field == "input"
    )

    expected_block_extent = ((0, 1), (0, 0))
    assert block_extents[id(hexecs[1])] == expected_block_extent

    computed_extent = input_access.to_extent(Extent(expected_block_extent))
    if access_extent is None:
        assert computed_extent is None
    else:
        assert computed_extent == access_extent