Example #1
0
    def apply(self, graph, sdfg):
        """Tile two consecutive maps with the same tile sizes and fuse the tiles.

        Per the situation sketch below, the first map's exit
        (``self.map1_exit``) feeds intermediate buffer node(s) that lead into
        the second map's entry (``self.map2_entry``). This method:

        1. tiles the first map with ``MapTilingWithOverlap``, using the
           per-dimension difference between the two maps' bounds as the
           lower/upper overlap;
        2. tiles the second map with ``MapTiling`` using the same
           ``self.tile_sizes``;
        3. fuses the two new outer tile maps with ``MapFusion``;
        4. rewrites dimensions of the original maps whose tile size is 1 (and,
           for the first map, whose overlap is 0 on both sides) to a single
           iteration. Then, when at least one tile size is 1, it applies
           ``TrivialMapElimination`` to each original map that has a
           dimension with equal start and end.

        The graph is modified in place; nothing is returned.

        :param graph: The graph containing both maps (presumably the SDFG
                      state); queried via ``entry_node``, ``out_edges``,
                      ``in_edges`` and ``all_nodes_between``.
        :param sdfg: The SDFG passed to every sub-transformation's
                     ``apply_to``.
        """
        map1_exit = self.map1_exit
        map1_entry = graph.entry_node(map1_exit)
        map2_entry = self.map2_entry
        # Intermediate node(s) between the two maps; only used below to pick
        # a representative buffer node for MapFusion.apply_to().
        buffers = graph.all_nodes_between(map1_exit, map2_entry)
        # Situation:
        # -> map1_entry -> ... -> map1_exit -> buffers -> map2_entry -> ...

        # Per-dimension overlap of the first map beyond the second map:
        #   lower = map2 start - map1 start,  upper = map1 end - map2 end.
        # zip() pairs dimensions positionally and stops at the shorter range.
        lower_extents = tuple(b - a for a, b in zip(
            map1_entry.range.min_element(), map2_entry.range.min_element()))
        upper_extents = tuple(a - b for a, b in zip(
            map1_entry.range.max_element(), map2_entry.range.max_element()))

        # Tile the first map with overlap
        MapTilingWithOverlap.apply_to(sdfg,
                                      map_entry=map1_entry,
                                      options={
                                          'tile_sizes': self.tile_sizes,
                                          'lower_overlap': lower_extents,
                                          'upper_overlap': upper_extents
                                      })
        # The new outer tile-map exit is taken to be the destination of the
        # original exit's first out-edge. NOTE(review): this assumes tiling
        # leaves map1_exit with that exit as its (first) successor — confirm.
        tile_map1_exit = graph.out_edges(map1_exit)[0].dst
        tile_map1_entry = graph.entry_node(tile_map1_exit)
        tile_map1_entry.label = 'BufferTiling'

        # Tile the second map. NOTE(review): 'tile_trivial': True presumably
        # makes MapTiling also tile dimensions it would otherwise skip — confirm.
        MapTiling.apply_to(sdfg,
                           map_entry=map2_entry,
                           options={
                               'tile_sizes': self.tile_sizes,
                               'tile_trivial': True
                           })
        # Likewise, the new outer tile-map entry is taken to be the source of
        # the original second map entry's first in-edge.
        tile_map2_entry = graph.in_edges(map2_entry)[0].src

        # Fuse the two outer tile maps.
        # MapFusion.apply_to() needs an array node for its pattern; any one of
        # the intermediate buffers is used. next() raises StopIteration if
        # `buffers` is empty — presumably the transformation's match conditions
        # guarantee at least one buffer; confirm.
        some_buffer = next(
            iter(buffers))  # some dummy to pass to MapFusion.apply_to()
        MapFusion.apply_to(sdfg,
                           first_map_exit=tile_map1_exit,
                           array=some_buffer,
                           second_map_entry=tile_map2_entry)

        # Optimize the simple cases.
        # First map: rewrite a dimension to a single iteration (end := start,
        # step kept) when it has no overlap on either side and tile size 1.
        # NOTE(review): zip() stops at its shortest input, so if
        # self.tile_sizes has fewer entries than the map has dimensions, the
        # trailing ranges are dropped from the rebuilt list — confirm callers
        # always supply one tile size per dimension.
        map1_entry.range.ranges = [
            (r[0], r[0], r[2]) if l_ext == 0 and u_ext == 0 and ts == 1 else r
            for r, l_ext, u_ext, ts in zip(map1_entry.range.ranges,
                                           lower_extents, upper_extents,
                                           self.tile_sizes)
        ]

        # Second map: rewrite every dimension with tile size 1 the same way
        # (same zip() truncation caveat as above).
        map2_entry.range.ranges = [
            (r[0], r[0], r[2]) if ts == 1 else r
            for r, ts in zip(map2_entry.range.ranges, self.tile_sizes)
        ]

        # When some tile size is 1, apply TrivialMapElimination to each map
        # that now has a dimension with start == end.
        if any(ts == 1 for ts in self.tile_sizes):
            if any(r[0] == r[1] for r in map1_entry.map.range):
                TrivialMapElimination.apply_to(sdfg, map_entry=map1_entry)
            if any(r[0] == r[1] for r in map2_entry.map.range):
                TrivialMapElimination.apply_to(sdfg, map_entry=map2_entry)
Example #2
0
def test_applyto_pattern():
    """Apply MapFusion to an explicitly chosen MapExit/AccessNode/MapEntry triple.

    The nodes are located by hand in the simplified ``dbladd`` SDFG: the exit
    of the multiplication map, the entry of the addition map, and the access
    node of the single remaining transient array.
    """
    sdfg = dbladd.to_sdfg()
    sdfg.simplify()

    # Simplification (StateFusion) leaves a single state, so state 0 is it.
    state = sdfg.node(0)

    def find_node(node_type, accept):
        # First node of `node_type` in `state` for which `accept(node)` holds;
        # raises StopIteration if there is none.
        return next(node for node in state.nodes()
                    if isinstance(node, node_type) and accept(node))

    # Exit of the multiplication map, labelled "_Mult__map".
    mult_exit = find_node(dace.nodes.MapExit,
                          lambda node: node.label == '_Mult__map')
    # Entry of the addition map, labelled "_Add__map".
    add_entry = find_node(dace.nodes.MapEntry,
                          lambda node: node.label == '_Add__map')

    # Simplification removed every redundant array, so the first transient
    # descriptor is the only one left; fetch the access node that uses it.
    transient_name = next(name for name, desc in sdfg.arrays.items()
                          if desc.transient)
    access_node = find_node(dace.nodes.AccessNode,
                            lambda node: node.data == transient_name)

    MapFusion.apply_to(sdfg,
                       first_map_exit=mult_exit,
                       array=access_node,
                       second_map_entry=add_entry)
Example #3
0
def test_applyto_enumerate():
    """Apply MapFusion to each MapExit -> AccessNode -> MapEntry match.

    Instead of locating nodes by label, the pattern is described as a
    three-node path graph. MapFusion is applied to every match that
    ``enumerate_matches`` yields for the simplified ``dbladd`` SDFG.
    """
    sdfg = dbladd.to_sdfg()
    sdfg.simplify()

    # Path pattern: MapExit -> AccessNode -> MapEntry.
    pattern = sdutil.node_path_graph(dace.nodes.MapExit, dace.nodes.AccessNode,
                                     dace.nodes.MapEntry)

    # NOTE(review): MapFusion mutates the SDFG while enumerate_matches() is
    # still being iterated; if the matcher walks the live graph lazily this
    # could misbehave when there is more than one match — confirm.
    for match in enumerate_matches(sdfg, pattern):
        # The path's source is the first map's exit and its sink is the
        # second map's entry; the access node in between is passed as `array`.
        first_exit = match.source_nodes()[0]
        middle_array = next(node for node in match.nodes()
                            if isinstance(node, dace.nodes.AccessNode))
        second_entry = match.sink_nodes()[0]
        MapFusion.apply_to(sdfg,
                           first_map_exit=first_exit,
                           array=middle_array,
                           second_map_entry=second_entry)