Exemplo n.º 1
0
    def apply(self, sdfg: dace.SDFG):
        """ Expands the matched multi-dimensional map into a nest of
            one-dimensional maps.

            The original map keeps only its first parameter and range; every
            remaining parameter receives its own new map (with a Sequential
            schedule) nested inside it. Edges leaving the original map entry
            and entering the original map exit are re-routed through the new
            entries/exits.

            :param sdfg: The SDFG that contains the matched state.
        """
        state = sdfg.nodes()[self.state_id]
        outer_entry = state.nodes()[self.subgraph[MapExpansion._map_entry]]
        outer_exit = state.exit_node(outer_entry)
        outer_map = outer_entry.map

        # Build one single-parameter map for every dimension after the first
        inner_maps = []
        for prm, rng in zip(outer_map.params[1:], outer_map.range[1:]):
            inner_maps.append(
                nodes.Map(outer_map.label + '_' + str(prm), [prm],
                          subsets.Range([rng]),
                          schedule=dtypes.ScheduleType.Sequential))

        # The original map is reduced to its first dimension
        outer_map.params = [outer_map.params[0]]
        outer_map.range = subsets.Range([outer_map.range[0]])

        # Scope nodes of the new maps (all entries first, then all exits)
        entries = list(map(nodes.MapEntry, inner_maps))
        exits = list(map(nodes.MapExit, inner_maps))

        # Edge creation rules:
        # 1. If there are no edges coming from the outside, use empty memlets
        # 2. Edges with IN_* connectors replicate along the maps
        # 3. Edges for dynamic map ranges replicate until reaching range(s)
        for inner_edge in state.out_edges(outer_entry):
            state.remove_edge(inner_edge)
            state.add_memlet_path(outer_entry,
                                  *entries,
                                  inner_edge.dst,
                                  memlet=inner_edge.data,
                                  src_conn=inner_edge.src_conn,
                                  dst_conn=inner_edge.dst_conn)

        # Dynamic map ranges: drop the old edge and its connector, then
        # re-attach the input to every map whose range mentions the symbol
        for dyn_edge in dace.sdfg.dynamic_map_inputs(state, outer_entry):
            state.remove_edge(dyn_edge)
            dyn_edge.dst.remove_in_connector(dyn_edge.dst_conn)

            scope_chain = []
            for entry_node in [outer_entry] + entries:
                scope_chain.append(entry_node)
                symbol_used = any(
                    dyn_edge.dst_conn in map(str, symbolic.symlist(rng))
                    for rng in entry_node.map.range)
                if not symbol_used:
                    continue
                state.add_memlet_path(dyn_edge.src,
                                      *scope_chain,
                                      memlet=dyn_edge.data,
                                      src_conn=dyn_edge.src_conn,
                                      dst_conn=dyn_edge.dst_conn)

        # Outgoing side: innermost new exit first, original exit last
        for inner_edge in state.in_edges(outer_exit):
            state.remove_edge(inner_edge)
            state.add_memlet_path(inner_edge.src,
                                  *reversed(exits),
                                  outer_exit,
                                  memlet=inner_edge.data,
                                  src_conn=inner_edge.src_conn,
                                  dst_conn=inner_edge.dst_conn)
Exemplo n.º 2
0
    def apply(self, sdfg: dace.SDFG):
        """ Expands the matched multi-dimensional map into nested
            one-dimensional maps and consolidates the edges of the innermost
            new scope.

            The original map keeps its first parameter/range; each remaining
            parameter is moved to a new map with a Sequential schedule.

            :param sdfg: The SDFG that contains the matched state.
            :return: A list with the original map entry followed by the
                     newly created map entries, outermost first.
            :raise ValueError: If the scope of the innermost new map entry
                               cannot be found in the state.
        """
        state = sdfg.node(self.state_id)
        first_entry = self.map_entry(sdfg)
        first_exit = state.exit_node(first_entry)
        first_map = first_entry.map

        # One new sequential map per trailing (parameter, range) pair
        trailing_dims = zip(first_map.params[1:], first_map.range[1:])
        extra_maps = [
            nodes.Map(first_map.label + '_' + str(prm), [prm],
                      subsets.Range([rng]),
                      schedule=dtypes.ScheduleType.Sequential)
            for prm, rng in trailing_dims
        ]
        first_map.params = [first_map.params[0]]
        first_map.range = subsets.Range([first_map.range[0]])

        # Entry and exit nodes of the new maps
        entries = [nodes.MapEntry(m) for m in extra_maps]
        exits = [nodes.MapExit(m) for m in extra_maps]

        # Edge creation rules:
        # 1. If there are no edges coming from the outside, use empty memlets
        # 2. Edges with IN_* connectors replicate along the maps
        # 3. Edges for dynamic map ranges replicate until reaching range(s)
        for out_edge in state.out_edges(first_entry):
            state.remove_edge(out_edge)
            state.add_memlet_path(first_entry,
                                  *entries,
                                  out_edge.dst,
                                  memlet=out_edge.data,
                                  src_conn=out_edge.src_conn,
                                  dst_conn=out_edge.dst_conn)

        # Re-attach dynamic map range inputs to each map that uses them
        for dyn_edge in dace.sdfg.dynamic_map_inputs(state, first_entry):
            # The old edge and its connector go away first
            state.remove_edge(dyn_edge)
            dyn_edge.dst.remove_in_connector(dyn_edge.dst_conn)

            visited = []
            for entry_node in [first_entry] + entries:
                visited.append(entry_node)
                if any(dyn_edge.dst_conn in map(str, symbolic.symlist(rng))
                       for rng in entry_node.map.range):
                    state.add_memlet_path(dyn_edge.src,
                                          *visited,
                                          memlet=dyn_edge.data,
                                          src_conn=dyn_edge.src_conn,
                                          dst_conn=dyn_edge.dst_conn)

        # Route edges into the original exit through the new exits,
        # innermost first
        for in_edge in state.in_edges(first_exit):
            state.remove_edge(in_edge)
            state.add_memlet_path(in_edge.src,
                                  *reversed(exits),
                                  first_exit,
                                  memlet=in_edge.data,
                                  src_conn=in_edge.src_conn,
                                  dst_conn=in_edge.dst_conn)

        # Walk from the scope leaves towards the root until the scope whose
        # entry is the innermost new map entry is found
        from dace.sdfg.scope import ScopeTree
        scope = None
        pending: List[ScopeTree] = state.scope_leaves()
        while True:
            if not len(pending) > 0:
                raise ValueError('Cannot find scope in state')
            current = pending.pop()
            if current.entry == entries[-1]:
                scope = current
                break
            if current.parent is not None:
                pending.append(current.parent)

        consolidate_edges(sdfg, scope)

        return [first_entry] + entries
Exemplo n.º 3
0
    def _stripmine(self, sdfg, graph, candidate):
        """ Strip-mines one dimension of the matched map by wrapping it in a
            new map.

            The new map is created according to ``self.tiling_type``
            ('ceilrange', 'number_of_tiles', or a strided range otherwise).
            The new map's entry/exit take over the external edges of the
            original entry/exit, and new edges are created between the new
            and the original scope nodes.

            :param sdfg: The SDFG in which the map resides.
            :param graph: The state that contains the map.
            :param candidate: Mapping from pattern nodes to node indices in
                              ``graph``.
            :return: A 3-tuple of (original dimension parameter, new
                     dimension as returned by the range-creation helper,
                     new map).
            :raise ValueError: If removing the strip-mined dimension would
                               leave the original map without parameters.
        """
        # Retrieve map entry and exit nodes.
        map_entry = graph.nodes()[candidate[StripMining._map_entry]]
        map_exit = graph.exit_node(map_entry)

        # Retrieve transformation properties.
        dim_idx = self.dim_idx
        target_dim = map_entry.map.params[dim_idx]

        if self.tiling_type == 'ceilrange':
            new_dim, new_map, td_rng = self._create_ceil_range(
                sdfg, graph, map_entry)
        elif self.tiling_type == 'number_of_tiles':
            new_dim, new_map, td_rng = self._create_from_tile_numbers(
                sdfg, graph, map_entry)
        else:
            new_dim, new_map, td_rng = self._create_strided_range(
                sdfg, graph, map_entry)

        new_map_entry = nodes.MapEntry(new_map)
        new_map_exit = nodes.MapExit(new_map)

        # Use the approximation of the range end if it is a SymExpr
        td_to_new_approx = td_rng[1]
        if isinstance(td_to_new_approx, dace.symbolic.SymExpr):
            td_to_new_approx = td_to_new_approx.approx

        # Special case: If range is 1 and no prefix was specified, skip range
        if td_rng[0] == td_to_new_approx and target_dim == new_dim:
            map_entry.map.range = subsets.Range(
                [r for i, r in enumerate(map_entry.map.range) if i != dim_idx])
            map_entry.map.params = [
                p for i, p in enumerate(map_entry.map.params) if i != dim_idx
            ]
            if len(map_entry.map.params) == 0:
                raise ValueError('Strip-mining all dimensions of the map with '
                                 'empty tiles is disallowed')
        else:
            map_entry.map.range[dim_idx] = td_rng

        # The new map inherits the original schedule; the original map
        # becomes sequential
        new_map.schedule = map_entry.map.schedule
        map_entry.map.schedule = dtypes.ScheduleType.Sequential

        # Redirect external edges to the new scope nodes
        new_map_entry.in_connectors = dcpy(map_entry.in_connectors)
        sdutil.change_edge_dest(graph, map_entry, new_map_entry)
        new_map_exit.out_connectors = dcpy(map_exit.out_connectors)
        sdutil.change_edge_src(graph, map_exit, new_map_exit)

        # Create new entry edges (new_map_entry -> map_entry), one per
        # (data, IN connector, OUT connector) key
        new_in_edges = dict()
        entry_in_conn = {}
        entry_out_conn = {}
        for _src, src_conn, _dst, _, memlet in graph.out_edges(map_entry):
            if (src_conn is not None
                    and src_conn[:4] == 'OUT_' and not isinstance(
                        sdfg.arrays[memlet.data], dace.data.Scalar)):
                new_subset = calc_set_image(
                    map_entry.map.params,
                    map_entry.map.range,
                    memlet.subset,
                )
                conn = src_conn[4:]
                key = (memlet.data, 'IN_' + conn, 'OUT_' + conn)
                if key in new_in_edges:
                    # Several inner edges share the connector: union subsets
                    old_subset = new_in_edges[key].subset
                    new_in_edges[key].subset = calc_set_union(
                        old_subset, new_subset)
                else:
                    entry_in_conn['IN_' + conn] = None
                    entry_out_conn['OUT_' + conn] = None
                    new_memlet = dcpy(memlet)
                    new_memlet.subset = new_subset
                    if memlet.dynamic:
                        new_memlet.num_accesses = memlet.num_accesses
                    else:
                        new_memlet.num_accesses = new_memlet.num_elements()
                    new_in_edges[key] = new_memlet
            else:
                # Scalars, non-OUT_ connectors and connector-less edges are
                # forwarded with an unmodified copy of the memlet
                if src_conn is not None and src_conn[:4] == 'OUT_':
                    conn = src_conn[4:]
                    in_conn = 'IN_' + conn
                    out_conn = 'OUT_' + conn
                else:
                    in_conn = src_conn
                    out_conn = src_conn
                if in_conn:
                    entry_in_conn[in_conn] = None
                if out_conn:
                    entry_out_conn[out_conn] = None
                new_in_edges[(memlet.data, in_conn, out_conn)] = dcpy(memlet)
        new_map_entry.out_connectors = entry_out_conn
        map_entry.in_connectors = entry_in_conn
        for (_, in_conn, out_conn), memlet in new_in_edges.items():
            graph.add_edge(new_map_entry, out_conn, map_entry, in_conn, memlet)

        # Create new exit edges (map_exit -> new_map_exit)
        new_out_edges = dict()
        exit_in_conn = {}
        exit_out_conn = {}
        for _src, _, _dst, dst_conn, memlet in graph.in_edges(map_exit):
            if (dst_conn is not None
                    and dst_conn[:3] == 'IN_' and not isinstance(
                        sdfg.arrays[memlet.data], dace.data.Scalar)):
                new_subset = calc_set_image(
                    map_entry.map.params,
                    map_entry.map.range,
                    memlet.subset,
                )
                conn = dst_conn[3:]
                key = (memlet.data, 'IN_' + conn, 'OUT_' + conn)
                if key in new_out_edges:
                    old_subset = new_out_edges[key].subset
                    new_out_edges[key].subset = calc_set_union(
                        old_subset, new_subset)
                else:
                    exit_in_conn['IN_' + conn] = None
                    exit_out_conn['OUT_' + conn] = None
                    new_memlet = dcpy(memlet)
                    new_memlet.subset = new_subset
                    if memlet.dynamic:
                        new_memlet.num_accesses = memlet.num_accesses
                    else:
                        new_memlet.num_accesses = new_memlet.num_elements()
                    new_out_edges[key] = new_memlet
            else:
                if dst_conn is not None and dst_conn[:3] == 'IN_':
                    conn = dst_conn[3:]
                    in_conn = 'IN_' + conn
                    out_conn = 'OUT_' + conn
                else:
                    in_conn = dst_conn
                    out_conn = dst_conn
                if in_conn:
                    exit_in_conn[in_conn] = None
                if out_conn:
                    exit_out_conn[out_conn] = None
                # This must go into new_out_edges: new_in_edges has already
                # been turned into graph edges above, so storing the memlet
                # there would register the connectors without ever creating
                # the matching map_exit -> new_map_exit edge.
                new_out_edges[(memlet.data, in_conn, out_conn)] = dcpy(memlet)
        new_map_exit.in_connectors = exit_in_conn
        map_exit.out_connectors = exit_out_conn
        for (_, in_conn, out_conn), memlet in new_out_edges.items():
            graph.add_edge(map_exit, out_conn, new_map_exit, in_conn, memlet)

        # Skew if necessary
        if self.skew:
            xfh.offset_map(sdfg, graph, map_entry, dim_idx, td_rng[0])

        # Return strip-mined dimension.
        return target_dim, new_dim, new_map
Exemplo n.º 4
0
class MapReduceFusion(pm.Transformation):
    """ Implements the map-reduce-fusion transformation.
        Fuses a map with an immediately following reduction, where the array
        between the map and the reduction is not used anywhere else.

        The matched pattern is the path
        tasklet -> map exit -> transient array -> reduce node -> output
        array. On application, the transient and the reduce node are
        removed and the tasklet writes to the output array directly through
        a memlet with the reduction's write-conflict resolution (WCR).
    """

    no_init = Property(dtype=bool,
                       default=False,
                       desc='If enabled, does not create initialization states '
                       'for reduce nodes with identity')

    # Pattern nodes
    _tasklet = nodes.Tasklet('_')
    _tmap_exit = nodes.MapExit(nodes.Map("", [], []))
    _in_array = nodes.AccessNode('_')

    import dace.libraries.standard as stdlib  # Avoid import loop
    _reduce = stdlib.Reduce()

    _out_array = nodes.AccessNode('_')

    @staticmethod
    def expressions():
        """ Returns the single pattern graph matched by this
            transformation. """
        return [
            sdutil.node_path_graph(MapReduceFusion._tasklet,
                                   MapReduceFusion._tmap_exit,
                                   MapReduceFusion._in_array,
                                   MapReduceFusion._reduce,
                                   MapReduceFusion._out_array)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the matched subgraph can be fused.

            :param graph: The state in which the pattern was matched.
            :param candidate: Mapping from pattern nodes to node indices in
                              ``graph``.
            :param expr_index: Index of the matched expression (unused).
            :param sdfg: The SDFG that contains ``graph``.
            :param strict: If True, additionally requires that the
                           intermediate array is not accessed by any other
                           node in this state and is not a shared transient.
            :return: True if the transformation can be applied.
        """
        tmap_exit = graph.nodes()[candidate[MapReduceFusion._tmap_exit]]
        in_array = graph.nodes()[candidate[MapReduceFusion._in_array]]
        reduce_node = graph.nodes()[candidate[MapReduceFusion._reduce]]
        tasklet = graph.nodes()[candidate[MapReduceFusion._tasklet]]

        # Make sure that the array is only accessed by the map and the reduce
        if any(src != tmap_exit
               for src, _, _, _, _ in graph.in_edges(in_array)):
            return False
        if any(dest != reduce_node
               for _, _, dest, _, _ in graph.out_edges(in_array)):
            return False

        # Find the tasklet output that writes the intermediate array. If no
        # such edge exists, the pattern cannot be fused; report a non-match
        # instead of letting StopIteration escape.
        tmem_edge = next((e for e in graph.edges_between(tasklet, tmap_exit)
                          if e.data.data == in_array.data), None)
        if tmem_edge is None:
            return False
        tmem = tmem_edge.data

        # (strict) Make sure that the transient is not accessed anywhere else
        # in this state or other states
        if strict and (len([
                n for n in graph.nodes()
                if isinstance(n, nodes.AccessNode) and n.data == in_array.data
        ]) > 1 or in_array.data in sdfg.shared_transients()):
            return False

        # If memlet already has WCR and it is different from reduce node,
        # do not match
        if tmem.wcr is not None and tmem.wcr != reduce_node.wcr:
            return False

        # Verify that reduction ranges match tasklet map
        tout_memlet = graph.in_edges(in_array)[0].data
        rin_memlet = graph.out_edges(in_array)[0].data
        if tout_memlet.subset != rin_memlet.subset:
            return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns a human-readable string describing the match. """
        tasklet = candidate[MapReduceFusion._tasklet]
        map_exit = candidate[MapReduceFusion._tmap_exit]
        reduce = candidate[MapReduceFusion._reduce]

        return ' -> '.join(str(node) for node in [tasklet, map_exit, reduce])

    def apply(self, sdfg: SDFG):
        """ Fuses the matched map and reduction.

            Removes the intermediate array and the reduce node, rewrites the
            tasklet's output memlet to target the output array with the
            reduction WCR, and connects the map exit to the output array.
            Unless ``no_init`` is set, an initialization state is added
            before the current state when the reduce node has an identity.

            :param sdfg: The SDFG that contains the matched state.
            :raise RuntimeError: If no edge into the map exit carries the
                                 intermediate array.
        """
        graph = sdfg.nodes()[self.state_id]
        tmap_exit = graph.nodes()[self.subgraph[MapReduceFusion._tmap_exit]]
        in_array = graph.nodes()[self.subgraph[MapReduceFusion._in_array]]
        reduce_node = graph.nodes()[self.subgraph[MapReduceFusion._reduce]]
        out_array = graph.nodes()[self.subgraph[MapReduceFusion._out_array]]

        # Set nodes to remove according to the expression index
        nodes_to_remove = [in_array]
        nodes_to_remove.append(reduce_node)

        # Locate the edge inside the map that writes the intermediate array
        memlet_edge = None
        for edge in graph.in_edges(tmap_exit):
            if edge.data.data == in_array.data:
                memlet_edge = edge
                break
        if memlet_edge is None:
            raise RuntimeError('Reduction memlet cannot be None')

        # Find which indices should be removed from new memlet
        # (no axes on the reduce node means all dimensions are reduced)
        input_edge = graph.in_edges(reduce_node)[0]
        axes = reduce_node.axes or list(range(len(input_edge.data.subset)))
        array_edge = graph.out_edges(reduce_node)[0]

        # Delete relevant edges and nodes
        graph.remove_nodes_from(nodes_to_remove)

        # Filter out reduced dimensions from subset
        filtered_subset = [
            dim for i, dim in enumerate(memlet_edge.data.subset)
            if i not in axes
        ]
        if len(filtered_subset) == 0:  # Output is a scalar
            filtered_subset = [(0, 0, 1)]

        # Modify edge from tasklet to map exit
        memlet_edge.data.data = out_array.data
        memlet_edge.data.wcr = reduce_node.wcr
        memlet_edge.data.subset = type(memlet_edge.data.subset)(filtered_subset)

        # Add edge from map exit to output array
        graph.add_edge(
            memlet_edge.dst, 'OUT_' + memlet_edge.dst_conn[3:], array_edge.dst,
            array_edge.dst_conn,
            Memlet.simple(array_edge.data.data,
                          array_edge.data.subset,
                          num_accesses=array_edge.data.num_accesses,
                          wcr_str=reduce_node.wcr))

        # Add initialization state as necessary. The ``no_init`` property
        # lets the user opt out of creating it.
        if not self.no_init and reduce_node.identity is not None:
            init_state = sdfg.add_state_before(graph)
            init_state.add_mapped_tasklet(
                'freduce_init',
                [('o%d' % i, '%s:%s:%s' % (r[0], r[1] + 1, r[2]))
                 for i, r in enumerate(array_edge.data.subset)], {},
                'out = %s' % reduce_node.identity, {
                    'out':
                    Memlet.simple(
                        array_edge.data.data, ','.join([
                            'o%d' % i
                            for i in range(len(array_edge.data.subset))
                        ]))
                },
                external_edges=True)
Exemplo n.º 5
0
def merge_maps(
    graph: SDFGState,
    outer_map_entry: nd.MapEntry,
    outer_map_exit: nd.MapExit,
    inner_map_entry: nd.MapEntry,
    inner_map_exit: nd.MapExit,
    param_merge: Callable[[ParamsType, ParamsType],
                          ParamsType] = lambda p1, p2: p1 + p2,
    range_merge: Callable[[RangesType, RangesType], RangesType] = (
        lambda r1, r2: type(r1)(r1.ranges + r2.ranges))
) -> (nd.MapEntry, nd.MapExit):
    """ Merges an outer map and the map nested directly inside it into one
    map. The caller is responsible for ensuring that the merge is valid.

    :param graph: The state that contains both maps.
    :param outer_map_entry: Entry node of the outer map.
    :param outer_map_exit: Exit node of the outer map.
    :param inner_map_entry: Entry node of the inner map.
    :param inner_map_exit: Exit node of the inner map.
    :param param_merge: Combines the outer and inner parameter lists
                        (default: concatenation).
    :param range_merge: Combines the outer and inner ranges (default:
                        concatenation of their ``ranges``).
    :return: The entry and exit nodes of the merged map.
    """

    outer_map = outer_map_entry.map
    inner_map = inner_map_entry.map

    # The merged map starts out as a copy of the outer map; only parameters
    # and ranges are combined with the inner map's
    merged_map = copy.deepcopy(outer_map)
    merged_map.label = outer_map.label
    merged_map.params = param_merge(outer_map.params, inner_map.params)
    merged_map.range = range_merge(outer_map.range, inner_map.range)

    # Merged scope nodes take over the outer nodes' connectors
    merged_entry = nd.MapEntry(merged_map)
    merged_entry.in_connectors = outer_map_entry.in_connectors
    merged_entry.out_connectors = outer_map_entry.out_connectors

    merged_exit = nd.MapExit(merged_map)
    merged_exit.in_connectors = outer_map_exit.in_connectors
    merged_exit.out_connectors = outer_map_exit.out_connectors

    graph.add_nodes_from([merged_entry, merged_exit])

    # Dynamic map inputs of the inner map become inputs of the merged entry
    for dyn_edge in dynamic_map_inputs(graph, inner_map_entry):
        # The outer pass-through connector is dropped only when this
        # dynamic input is the sole edge leaving it
        same_conn_edges = list(
            graph.out_edges_by_connector(dyn_edge.src, dyn_edge.src_conn))
        sole_user = len(same_conn_edges) == 1
        base_name = dyn_edge.src_conn[4:]
        outer_in_conn = 'IN_' + base_name
        if sole_user:
            merged_entry.remove_in_connector(outer_in_conn)
            merged_entry.remove_out_connector('OUT_' + base_name)
        merged_entry.add_in_connector(
            dyn_edge.dst_conn,
            inner_map_entry.in_connectors[dyn_edge.dst_conn])
        feeding_edge = next(
            graph.in_edges_by_connector(outer_map_entry, outer_in_conn))
        graph.add_edge(feeding_edge.src, feeding_edge.src_conn, merged_entry,
                       dyn_edge.dst_conn, feeding_edge.data)
        if sole_user:
            graph.remove_edge(feeding_edge)

    # Edges leaving the inner entry now leave the merged entry
    for in_edge in graph.out_edges(inner_map_entry):
        if in_edge.src_conn is None:
            # Empty memlets keep their (absent) connector
            merged_src_conn = in_edge.src_conn
        else:
            # Use the source connector of the preceding edge on the
            # memlet path
            mpath = graph.memlet_path(in_edge)
            merged_src_conn = mpath[mpath.index(in_edge) - 1].src_conn
        graph.add_edge(merged_entry, merged_src_conn, in_edge.dst,
                       in_edge.dst_conn, in_edge.data)

    # Edges entering the inner exit now enter the merged exit
    for out_edge in graph.in_edges(inner_map_exit):
        if out_edge.dst_conn is None:
            # Empty memlets keep their (absent) connector
            merged_dst_conn = out_edge.dst_conn
        else:
            # Use the destination connector of the following edge on the
            # memlet path
            mpath = graph.memlet_path(out_edge)
            merged_dst_conn = mpath[mpath.index(out_edge) + 1].dst_conn
        graph.add_edge(out_edge.src, out_edge.src_conn, merged_exit,
                       merged_dst_conn, out_edge.data)

    # External edges of the outer scope move to the merged scope nodes
    change_edge_dest(graph, outer_map_entry, merged_entry)
    change_edge_src(graph, outer_map_exit, merged_exit)

    # The four original scope nodes are no longer needed
    obsolete = [
        outer_map_entry, outer_map_exit, inner_map_entry, inner_map_exit
    ]
    graph.remove_nodes_from(obsolete)

    return merged_entry, merged_exit
Exemplo n.º 6
0
def merge_maps(
    graph: gr.OrderedMultiDiConnectorGraph,
    outer_map_entry: nd.MapEntry,
    outer_map_exit: nd.MapExit,
    inner_map_entry: nd.MapEntry,
    inner_map_exit: nd.MapExit,
    param_merge: Callable[[ParamsType, ParamsType],
                          ParamsType] = lambda p1, p2: p1 + p2,
    range_merge: Callable[[RangesType, RangesType], RangesType] = (
        lambda r1, r2: type(r1)(r1.ranges + r2.ranges))
) -> (nd.MapEntry, nd.MapExit):
    """ Merges an outer map and the map nested directly inside it into one
    map labeled ``'merged_' + <outer label>``. The caller is responsible for
    ensuring that the merge is valid.

    :param graph: The graph that contains both maps.
    :param outer_map_entry: Entry node of the outer map.
    :param outer_map_exit: Exit node of the outer map.
    :param inner_map_entry: Entry node of the inner map.
    :param inner_map_exit: Exit node of the inner map.
    :param param_merge: Combines the outer and inner parameter lists
                        (default: concatenation).
    :param range_merge: Combines the outer and inner ranges (default:
                        concatenation of their ``ranges``).
    :return: The entry and exit nodes of the merged map.
    """

    outer_map = outer_map_entry.map
    inner_map = inner_map_entry.map

    # Start from a copy of the outer map, then combine parameters and ranges
    merged_map = copy.deepcopy(outer_map)
    merged_map.label = 'merged_' + outer_map.label
    merged_map.params = param_merge(outer_map.params, inner_map.params)
    merged_map.range = range_merge(outer_map.range, inner_map.range)

    # Merged scope nodes take over the outer nodes' connectors
    merged_entry = nd.MapEntry(merged_map)
    merged_entry.in_connectors = outer_map_entry.in_connectors
    merged_entry.out_connectors = outer_map_entry.out_connectors

    merged_exit = nd.MapExit(merged_map)
    merged_exit.in_connectors = outer_map_exit.in_connectors
    merged_exit.out_connectors = outer_map_exit.out_connectors

    graph.add_nodes_from([merged_entry, merged_exit])

    # Entry side: each outer->inner edge is paired with the first edge that
    # leaves the inner entry through the matching OUT_ connector; both are
    # replaced by one edge starting at the merged entry
    edges_after_inner_entry = graph.out_edges(inner_map_entry)
    for bridge in graph.edges_between(outer_map_entry, inner_map_entry):
        # Empty memlets have no connector on either side
        wanted_conn = (None if bridge.dst_conn is None else 'OUT_' +
                       bridge.dst_conn[3:])
        matches = [
            e for e in edges_after_inner_entry if e.src_conn == wanted_conn
        ]
        continuation = matches[0]
        graph.remove_edge(bridge)
        graph.remove_edge(continuation)
        graph.add_edge(merged_entry, bridge.src_conn, continuation.dst,
                       continuation.dst_conn, continuation.data)

    # Exit side: mirror image of the above, ending at the merged exit
    edges_before_inner_exit = graph.in_edges(inner_map_exit)
    for bridge in graph.edges_between(inner_map_exit, outer_map_exit):
        # Empty memlets have no connector on either side
        wanted_conn = (None if bridge.src_conn is None else 'IN_' +
                       bridge.src_conn[4:])
        matches = [
            e for e in edges_before_inner_exit if e.dst_conn == wanted_conn
        ]
        continuation = matches[0]
        graph.remove_edge(bridge)
        graph.remove_edge(continuation)
        graph.add_edge(continuation.src, continuation.src_conn, merged_exit,
                       bridge.dst_conn, continuation.data)

    # External edges of the outer scope move to the merged scope nodes
    change_edge_dest(graph, outer_map_entry, merged_entry)
    change_edge_src(graph, outer_map_exit, merged_exit)

    # The four original scope nodes are no longer needed
    obsolete = [
        outer_map_entry, outer_map_exit, inner_map_entry, inner_map_exit
    ]
    graph.remove_nodes_from(obsolete)

    return merged_entry, merged_exit
Exemplo n.º 7
0
class MapWCRFusion(pm.Transformation):
    """ Implements the map expanded-reduce fusion transformation.
        Matches a map that is immediately followed by a reduction which has
        been expanded into two nested maps with a write-conflict resolution
        (WCR) memlet, where the array between the first map and the
        reduction is not used anywhere else. The two reduction maps are
        collapsed and the result is fused with the first map.
    """

    # Pattern nodes, in path order: first map's tasklet and exit, the
    # intermediate array, then the outer/inner reduction maps and the output
    _tasklet = nodes.Tasklet('_')
    _tmap_exit = nodes.MapExit(nodes.Map("", [], []))
    _in_array = nodes.AccessNode('_')
    _rmap_in_entry = nodes.MapEntry(nodes.Map("", [], []))
    _rmap_in_tasklet = nodes.Tasklet('_')
    _rmap_in_cr = nodes.MapExit(nodes.Map("", [], []))
    _rmap_out_entry = nodes.MapEntry(nodes.Map("", [], []))
    _rmap_out_exit = nodes.MapExit(nodes.Map("", [], []))
    _out_array = nodes.AccessNode('_')

    @staticmethod
    def expressions():
        """ Returns the single pattern graph matched by this
            transformation. """
        # Map, then partial reduction of axes
        pattern = sdutil.node_path_graph(
            MapWCRFusion._tasklet, MapWCRFusion._tmap_exit,
            MapWCRFusion._in_array, MapWCRFusion._rmap_out_entry,
            MapWCRFusion._rmap_in_entry, MapWCRFusion._rmap_in_tasklet,
            MapWCRFusion._rmap_in_cr, MapWCRFusion._rmap_out_exit,
            MapWCRFusion._out_array)
        return [pattern]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the matched subgraph can be fused.

            :param graph: The state in which the pattern was matched.
            :param candidate: Mapping from pattern nodes to node indices in
                              ``graph``.
            :param expr_index: Index of the matched expression (unused).
            :param sdfg: The SDFG that contains ``graph``.
            :param strict: If True, additionally requires that the
                           intermediate array is not accessed by any other
                           node in this state and is not a shared transient.
            :return: True if the transformation can be applied.
        """
        state_nodes = graph.nodes
        first_exit = state_nodes()[candidate[MapWCRFusion._tmap_exit]]
        mid_array = state_nodes()[candidate[MapWCRFusion._in_array]]
        reduce_entry = state_nodes()[candidate[MapWCRFusion._rmap_out_entry]]

        # The intermediate array may only be written by the first map ...
        for producer, *_ in graph.in_edges(mid_array):
            if producer != first_exit:
                return False
        # ... and only be read by the reduction map
        for _, _, consumer, *_ in graph.out_edges(mid_array):
            if consumer != reduce_entry:
                return False

        # The inner reduction map must write through a WCR memlet
        inner_exit = state_nodes()[candidate[MapWCRFusion._rmap_in_cr]]
        wcr_edge = graph.in_edges(inner_exit)[0]
        if wcr_edge.data.wcr is None:
            return False

        # (strict) The transient must not be accessed anywhere else in this
        # state or in other states
        if strict:
            same_data_nodes = [
                n for n in graph.nodes()
                if isinstance(n, nodes.AccessNode) and n.data == mid_array.data
            ]
            if (len(same_data_nodes) > 1
                    or mid_array.data in sdfg.shared_transients()):
                return False

        # The subset written by the first map must equal the subset read by
        # the reduction
        written = graph.in_edges(mid_array)[0].data
        read = graph.out_edges(mid_array)[0].data
        return not (written.subset != read.subset)

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns a human-readable string describing the match. """
        parts = (candidate[MapWCRFusion._tasklet],
                 candidate[MapWCRFusion._tmap_exit],
                 candidate[MapWCRFusion._rmap_in_cr])
        return ' -> '.join(map(str, parts))

    def apply(self, sdfg):
        """ Applies the fusion by first collapsing the two reduction maps
            (MapCollapse) and then fusing the first map with the collapsed
            map (MapFusion).

            :param sdfg: The SDFG that contains the matched state.
        """
        state = sdfg.node(self.state_id)

        # Step 1: collapse the outer and inner reduction maps into one
        collapse_match = {
            MapCollapse._outer_map_entry:
            self.subgraph[MapWCRFusion._rmap_out_entry],
            MapCollapse._inner_map_entry:
            self.subgraph[MapWCRFusion._rmap_in_entry]
        }
        collapse = MapCollapse(self.sdfg_id, self.state_id, collapse_match, 0)
        collapsed_entry, _ = collapse.apply(sdfg)

        # Step 2: fuse the first map with the collapsed reduction map
        fusion_match = {
            MapFusion._first_map_exit: self.subgraph[MapWCRFusion._tmap_exit],
            MapFusion._second_map_entry: state.node_id(collapsed_entry)
        }
        fusion = MapFusion(self.sdfg_id, self.state_id, fusion_match, 0)
        fusion.apply(sdfg)
Exemplo n.º 8
0
class Vectorization(pattern_matching.Transformation):
    """ Implements the vectorization transformation.

        The pattern is a tasklet nested directly between a map entry and a
        map exit. A match is accepted when the data-carrying, non-stream
        memlets adjacent to the tasklet refer to the inner-most map parameter
        only in their last dimension, and at least one of them does so.
        Applying the transformation rewrites the range of the inner-most map
        dimension and sets the vector length on the affected memlets.
    """

    vector_len = Property(desc="Vector length", dtype=int, default=4)
    propagate_parent = Property(desc="Propagate vector length through "
                                "parent SDFGs",
                                dtype=bool,
                                default=False)
    strided_map = Property(desc="Use strided map range (jump by vector length)"
                           " instead of modifying memlets",
                           dtype=bool,
                           default=False)

    # Pattern nodes used to describe and look up the matched subgraph
    _map_entry = nodes.MapEntry(nodes.Map("", [], []))
    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        """ Returns the single pattern: map entry -> tasklet -> map exit. """
        pattern = sdutil.node_path_graph(Vectorization._map_entry,
                                         Vectorization._tasklet,
                                         Vectorization._map_exit)
        return [pattern]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the matched tasklet can be vectorized.

            :return: True if at least one adjacent memlet uses the inner-most
                     map parameter in its last dimension and no adjacent
                     memlet disqualifies the match; False otherwise.
        """
        map_entry = graph.nodes()[candidate[Vectorization._map_entry]]
        tasklet = graph.nodes()[candidate[Vectorization._tasklet]]
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])
        uses_param = False

        for _, _, _, _, memlet in graph.all_edges(tasklet):
            # Empty memlets and streams are irrelevant for vectorization
            if memlet.data is None:
                continue
            if isinstance(sdfg.arrays[memlet.data], data.Stream):
                continue

            # Write-conflict resolution cannot be vectorized
            if memlet.wcr is not None:
                return False

            try:
                subset = memlet.subset
                veclen = memlet.veclen
            except AttributeError:
                return False

            if subset is None:
                return False

            try:
                # Memlets that already have a vector length are rejected
                if veclen > symbolic.pystr_to_symbolic('1'):
                    return False

                for dim, entry in enumerate(subset):
                    # Treat a single expression like a one-element tuple so
                    # that both kinds of entries share the same check
                    parts = entry if isinstance(entry, tuple) else (entry, )
                    for part in parts:
                        part_symbols = symbolic.pystr_to_symbolic(
                            part).free_symbols
                        if param not in part_symbols:
                            continue
                        # The parameter may only appear in the last dimension
                        if dim != subset.dims() - 1:
                            return False
                        uses_param = True
            except TypeError:  # cannot determine truth value of Relational
                return False

        return uses_param

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the matched candidate entries joined by ' -> '. """
        pattern_nodes = (Vectorization._map_entry, Vectorization._tasklet,
                         Vectorization._map_exit)
        matched = [candidate[pattern_node] for pattern_node in pattern_nodes]
        return ' -> '.join(map(str, matched))

    def _vectorize_tree(self, tree_edges, param, vector_size,
                        processed_edges):
        """ Sets the vector length on every given edge and, unless a strided
            map is used, scales the map parameter in each edge subset once.

            :param tree_edges: Iterable of edges to update.
            :param param: Symbol of the inner-most map parameter.
            :param vector_size: Vector length to set.
            :param processed_edges: Set of edges whose subset was already
                                    rewritten; updated in place.
        """
        for tree_edge in tree_edges:
            tree_edge.data.veclen = vector_size
            if self.strided_map or tree_edge in processed_edges:
                continue
            tree_edge.data.subset.replace({param: vector_size * param})
            processed_edges.add(tree_edge)

    def apply(self, sdfg):
        """ Rewrites the inner-most map dimension and vectorizes the memlets
            adjacent to the tasklet that use the inner-most map parameter in
            their last dimension.

            :param sdfg: The SDFG that contains the matched state.
        """
        graph = sdfg.nodes()[self.state_id]
        map_entry = graph.nodes()[self.subgraph[Vectorization._map_entry]]
        tasklet = graph.nodes()[self.subgraph[Vectorization._tasklet]]
        map_exit = graph.nodes()[self.subgraph[Vectorization._map_exit]]
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])

        # Create new vector size.
        vector_size = self.vector_len

        # Change the inner-most dimension: either jump by the vector length,
        # or shrink the iteration space and keep the original step.
        dim_from, dim_to, dim_step = map_entry.map.range[-1]
        if self.strided_map:
            inner_range = (dim_from, dim_to, vector_size)
        else:
            inner_range = (dim_from, (dim_to + 1) / vector_size - 1, dim_step)
        map_entry.map.range[-1] = inner_range

        # TODO: Postamble and/or preamble non-vectorized map

        # Vectorize memlets adjacent to the tasklet.
        processed_edges = set()
        for edge in graph.all_edges(tasklet):
            _, _, _, _, memlet = edge

            if memlet.data is None:  # Empty memlets
                continue

            # Collect the free symbols of the last dimension
            last_index = memlet.subset[-1]
            if isinstance(last_index, tuple):
                symbols = set()
                for component in last_index:
                    symbols |= set(
                        symbolic.pystr_to_symbolic(component).free_symbols)
            else:
                symbols = symbolic.pystr_to_symbolic(last_index).free_symbols

            if param not in symbols:
                continue

            # propagate vector length inside this SDFG
            self._vectorize_tree(graph.memlet_tree(edge), param, vector_size,
                                 processed_edges)

            # propagate to the parent (TODO: handle multiple level of nestings)
            if not self.propagate_parent or sdfg.parent is None:
                continue

            path = graph.memlet_path(edge)
            source_edge, sink_edge = path[0], path[-1]

            # Find parent Nested SDFG node
            parent_node = next(n for n in sdfg.parent.nodes()
                               if isinstance(n, nodes.NestedSDFG)
                               and n.sdfg.name == sdfg.name)

            # continue in propagating the vector length following the
            # path that arrives to source_edge or starts from sink_edge
            for parent_edge in sdfg.parent.all_edges(parent_node):
                feeds_source = (str(parent_edge.dst_conn) == str(
                    source_edge.src))
                if not feeds_source and (str(parent_edge.src_conn) != str(
                        sink_edge.dst)):
                    continue
                self._vectorize_tree(sdfg.parent.memlet_tree(parent_edge),
                                     param, vector_size, processed_edges)
Exemplo n.º 9
0
class AccumulateTransient(pattern_matching.Transformation):
    """ Implements the AccumulateTransient transformation, which adds
        transient stream and data nodes between nested maps that lead to a
        stream. The transient data nodes then act as a local accumulator.
    """

    # Pattern nodes used to describe and look up the matched subgraph
    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))
    _outer_map_exit = nodes.MapExit(nodes.Map("", [], []))

    array = Property(
        dtype=str,
        desc="Array to create local storage for (if empty, first available)",
        default=None,
        allow_none=True)

    @staticmethod
    def expressions():
        """ Returns the single pattern:
            tasklet -> map exit -> outer map exit.
        """
        return [
            sdutil.node_path_graph(AccumulateTransient._tasklet,
                                   AccumulateTransient._map_exit,
                                   AccumulateTransient._outer_map_exit)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the tasklet has an accumulation output.

            :return: True if at least one outgoing edge of the tasklet leads
                     to the matched map exit and carries a write-conflict
                     resolution; False otherwise.
        """
        tasklet = graph.nodes()[candidate[AccumulateTransient._tasklet]]
        map_exit = graph.nodes()[candidate[AccumulateTransient._map_exit]]

        # Check if there is an accumulation output
        for _src, _, dest, _, memlet in graph.out_edges(tasklet):
            if memlet.wcr is not None and dest == map_exit:
                return True

        return False

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the matched candidate entries joined by ' -> '. """
        tasklet = candidate[AccumulateTransient._tasklet]
        map_exit = candidate[AccumulateTransient._map_exit]
        outer_map_exit = candidate[AccumulateTransient._outer_map_exit]

        return ' -> '.join(
            str(node) for node in [tasklet, map_exit, outer_map_exit])

    def apply(self, sdfg):
        """ Creates local storage between the two map exits and initializes
            it when the accumulation is a summation.

            If the `array` property is unset or empty, the data of the first
            edge with a write-conflict resolution between the two map exits
            is used. For other reduction types the new transient is left
            uninitialized and a warning is issued.

            :param sdfg: The SDFG that contains the matched state.
        """
        graph = sdfg.node(self.state_id)

        # Choose array
        array = self.array
        if array is None or len(array) == 0:
            map_exit = graph.node(self.subgraph[AccumulateTransient._map_exit])
            outer_map_exit = graph.node(
                self.subgraph[AccumulateTransient._outer_map_exit])
            array = next(e.data.data
                         for e in graph.edges_between(map_exit, outer_map_exit)
                         if e.data.wcr is not None)

        # Avoid import loop
        from dace.transformation.dataflow.local_storage import LocalStorage

        # Delegate creation of the transient to LocalStorage, applied between
        # the inner and the outer map exit
        local_storage_subgraph = {
            LocalStorage._node_a:
            self.subgraph[AccumulateTransient._map_exit],
            LocalStorage._node_b:
            self.subgraph[AccumulateTransient._outer_map_exit]
        }
        sdfg_id = sdfg.sdfg_list.index(sdfg)
        in_local_storage = LocalStorage(sdfg_id, self.state_id,
                                        local_storage_subgraph,
                                        self.expr_index)
        in_local_storage.array = array
        in_local_storage.apply(sdfg)

        # Initialize transient to zero in case of summation
        # TODO: Initialize transient in other WCR types
        memlet = graph.in_edges(in_local_storage._data_node)[0].data
        if detect_reduction_type(memlet.wcr) == dtypes.ReductionType.Sum:
            in_local_storage._data_node.setzero = True
        else:
            # The trailing space keeps the two adjacent literals from being
            # glued together into "initializenewly-created"
            warnings.warn('AccumulateTransient did not properly initialize '
                          'newly-created transient!')
Exemplo n.º 10
0
class Vectorization(pattern_matching.Transformation):
    """ Implements the vectorization transformation.

        Vectorization matches when all the input and output memlets of a
        tasklet inside a map access the inner-most loop variable only in
        their contiguous (unit-stride) dimension. The transformation changes
        the range of the inner-most loop according to the vector length,
        turns the affected tasklet connectors into vector types, and
        optionally creates non-vectorized preamble/postamble maps.
    """

    vector_len = Property(desc="Vector length", dtype=int, default=4)
    propagate_parent = Property(desc="Propagate vector length through "
                                "parent SDFGs",
                                dtype=bool,
                                default=False)
    strided_map = Property(desc="Use strided map range (jump by vector length)"
                           " instead of modifying memlets",
                           dtype=bool,
                           default=True)
    preamble = Property(
        dtype=bool,
        default=None,
        allow_none=True,
        desc='Force creation or skipping a preamble map without vectors')
    postamble = Property(
        dtype=bool,
        default=None,
        allow_none=True,
        desc='Force creation or skipping a postamble map without vectors')

    # Pattern nodes used to describe and look up the matched subgraph
    _map_entry = nodes.MapEntry(nodes.Map("", [], []))
    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        """ Returns the single pattern: map entry -> tasklet -> map exit. """
        return [
            sdutil.node_path_graph(Vectorization._map_entry,
                                   Vectorization._tasklet,
                                   Vectorization._map_exit)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the matched tasklet can be vectorized.

            :return: True if the inner-most map dimension has a step of 1,
                     no adjacent memlet disqualifies the match, and at least
                     one adjacent memlet uses the inner-most map parameter in
                     a dimension with stride 1; False otherwise.
        """
        map_entry = graph.nodes()[candidate[Vectorization._map_entry]]
        tasklet = graph.nodes()[candidate[Vectorization._tasklet]]
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])
        found = False

        # Strided maps cannot be vectorized
        if map_entry.map.range[-1][2] != 1:
            return False

        # Check if all edges, adjacent to the tasklet,
        # use the parameter in their contiguous dimension.
        for e, conntype in graph.all_edges_and_connectors(tasklet):

            # Cases that do not matter for vectorization
            if e.data.data is None:  # Empty memlets
                continue
            if isinstance(sdfg.arrays[e.data.data], data.Stream):  # Streams
                continue

            # Vectorization can not be applied in WCR
            if e.data.wcr is not None:
                return False

            subset = e.data.subset
            array = sdfg.arrays[e.data.data]

            # If already vectorized or a pointer, do not apply
            if isinstance(conntype, (dtypes.vector, dtypes.pointer)):
                return False

            try:
                for idx, expr in enumerate(subset):
                    # Treat a single expression like a one-element tuple so
                    # that both kinds of entries share the same check
                    parts = expr if isinstance(expr, tuple) else (expr, )
                    for ex in parts:
                        symbols = symbolic.pystr_to_symbolic(ex).free_symbols
                        if param not in symbols:
                            continue
                        # The parameter may only index a unit-stride dimension
                        if array.strides[idx] == 1:
                            found = True
                        else:
                            return False
            except TypeError:  # cannot determine truth value of Relational
                return False

        return found

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the matched candidate entries joined by ' -> '. """
        map_entry = candidate[Vectorization._map_entry]
        tasklet = candidate[Vectorization._tasklet]
        map_exit = candidate[Vectorization._map_exit]

        return ' -> '.join(
            str(node) for node in [map_entry, tasklet, map_exit])

    def apply(self, sdfg: SDFG):
        """ Vectorizes the matched map and tasklet.

            The steps are: decide whether preamble/postamble maps are needed
            (or follow the `preamble`/`postamble` properties when they are
            set), replicate the scope for those maps, rewrite the inner-most
            map range, turn the tasklet connectors whose memlets use the map
            parameter in their unit-stride dimension into vector types, and,
            if `propagate_parent` is set, change the data descriptors to
            vector types while walking outwards through parent SDFGs.

            :param sdfg: The SDFG that contains the matched state.
        """
        graph = sdfg.nodes()[self.state_id]
        map_entry = graph.nodes()[self.subgraph[Vectorization._map_entry]]
        tasklet = graph.nodes()[self.subgraph[Vectorization._tasklet]]
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])

        # Create new vector size.
        vector_size = self.vector_len
        dim_from, dim_to, _ = map_entry.map.range[-1]

        # Determine whether to create preamble or postamble maps. Unless
        # forced through the properties, a preamble is skipped only if the
        # range start is zero or evaluates to a multiple of the vector size,
        # and a postamble is created whenever the range end does not evaluate
        # to an aligned value.
        if self.preamble is not None:
            create_preamble = self.preamble
        else:
            create_preamble = not ((dim_from % vector_size == 0) == True
                                   or dim_from == 0)
        if self.postamble is not None:
            create_postamble = self.postamble
        else:
            if isinstance(dim_to, symbolic.SymExpr):
                create_postamble = (((dim_to.approx + 1) %
                                     vector_size == 0) == False)
            else:
                create_postamble = (((dim_to + 1) % vector_size == 0) == False)

        # Determine new range for vectorized map
        if self.strided_map:
            new_range = [dim_from, dim_to - vector_size + 1, vector_size]
        else:
            new_range = [
                dim_from // vector_size, ((dim_to + 1) // vector_size) - 1, 1
            ]

        # Create preamble non-vectorized map (replacing the original map)
        if create_preamble:
            old_scope = graph.scope_subgraph(map_entry, True, True)
            new_scope: ScopeSubgraphView = replicate_scope(
                sdfg, graph, old_scope)
            new_begin = dim_from + (vector_size - (dim_from % vector_size))
            map_entry.map.range[-1] = (dim_from, new_begin - 1, 1)
            # Replace map_entry with the replicated scope (so that the preamble
            # will usually come first in topological sort)
            map_entry = new_scope.entry
            tasklet = new_scope.nodes()[old_scope.nodes().index(tasklet)]
            new_range[0] = new_begin

        # Create postamble non-vectorized map
        if create_postamble:
            new_scope: ScopeSubgraphView = replicate_scope(
                sdfg, graph, graph.scope_subgraph(map_entry, True, True))
            dim_to_ex = dim_to + 1
            new_scope.entry.map.range[-1] = (dim_to_ex -
                                             (dim_to_ex % vector_size), dim_to,
                                             1)

        # Change the step of the inner-most dimension.
        map_entry.map.range[-1] = tuple(new_range)

        # Vectorize connectors adjacent to the tasklet.
        for edge in graph.all_edges(tasklet):
            connectors = (tasklet.in_connectors
                          if edge.dst == tasklet else tasklet.out_connectors)
            conn = edge.dst_conn if edge.dst == tasklet else edge.src_conn

            if edge.data.data is None:  # Empty memlets
                continue
            desc = sdfg.arrays[edge.data.data]
            try:
                contigidx = desc.strides.index(1)
            except ValueError:
                # No unit-stride dimension: can_be_applied only accepts the
                # map parameter in unit-stride dimensions, so this memlet
                # has nothing to vectorize.
                continue

            lastindex = edge.data.subset[contigidx]
            if isinstance(lastindex, tuple):
                newlist = [(rb, rend, rs) for rb, rend, rs in edge.data.subset]
                symbols = set()
                for indd in lastindex:
                    symbols.update(
                        symbolic.pystr_to_symbolic(indd).free_symbols)
            else:
                newlist = [(rb, rb, 1) for rb in edge.data.subset]
                symbols = symbolic.pystr_to_symbolic(lastindex).free_symbols

            # Compare by name rather than by symbol object
            if str(param) not in map(str, symbols):
                continue

            # Vectorize connector, if not already vectorized
            oldtype = connectors[conn]
            if oldtype is None or oldtype.type is None:
                oldtype = desc.dtype
            if isinstance(oldtype, dtypes.vector):
                continue

            connectors[conn] = dtypes.vector(oldtype, vector_size)

            # Modify memlet subset to match vector length
            rb = newlist[contigidx][0]
            if self.strided_map:
                if self.propagate_parent:
                    newlist[contigidx] = (rb / self.vector_len,
                                          rb / self.vector_len, 1)
                else:
                    newlist[contigidx] = (rb, rb + self.vector_len - 1, 1)
            else:
                if self.propagate_parent:
                    newlist[contigidx] = (rb, rb, 1)
                else:
                    newlist[contigidx] = (self.vector_len * rb,
                                          self.vector_len * rb +
                                          self.vector_len - 1, 1)
            edge.data.subset = subsets.Range(newlist)
            edge.data.volume = vector_size

        # Vector length propagation using data descriptors, recursive traversal
        # outwards
        if self.propagate_parent:
            for edge in graph.all_edges(tasklet):
                # Empty memlets carry no data, so there is no descriptor to
                # vectorize (looking up `arrays[None]` would raise KeyError)
                if edge.data.data is None:
                    continue

                cursdfg = sdfg
                curedge = edge
                while cursdfg is not None:
                    arrname = curedge.data.data
                    dtype = cursdfg.arrays[arrname].dtype

                    # Change type and shape to vector
                    if not isinstance(dtype, dtypes.vector):
                        cursdfg.arrays[arrname].dtype = dtypes.vector(
                            dtype, vector_size)
                        new_shape = list(cursdfg.arrays[arrname].shape)
                        contigidx = cursdfg.arrays[arrname].strides.index(1)
                        new_shape[contigidx] /= vector_size
                        try:
                            # Keep plain integers where the size is concrete
                            new_shape[contigidx] = int(new_shape[contigidx])
                        except TypeError:
                            pass
                        cursdfg.arrays[arrname].shape = new_shape

                    propagation.propagate_memlets_sdfg(cursdfg)

                    # Find matching edge in parent
                    nsdfg = cursdfg.parent_nsdfg_node
                    if nsdfg is None:
                        break
                    tstate = cursdfg.parent
                    curedge = ([
                        e for e in tstate.in_edges(nsdfg)
                        if e.dst_conn == arrname
                    ] + [
                        e for e in tstate.out_edges(nsdfg)
                        if e.src_conn == arrname
                    ])[0]
                    cursdfg = cursdfg.parent_sdfg
Exemplo n.º 11
0
class AccumulateTransient(transformation.Transformation):
    """ Implements the AccumulateTransient transformation, which adds
        transient stream and data nodes between nested maps that lead to a
        stream. The transient data nodes then act as a local accumulator.
    """

    # Pattern nodes used to describe and look up the matched subgraph
    _map_exit = nodes.MapExit(nodes.Map("", [], []))
    _outer_map_exit = nodes.MapExit(nodes.Map("", [], []))

    array = Property(
        dtype=str,
        desc="Array to create local storage for (if empty, first available)",
        default=None,
        allow_none=True)

    identity = SymbolicProperty(desc="Identity value to set",
                                default=None,
                                allow_none=True)

    @staticmethod
    def expressions():
        """ Returns the single pattern: map exit -> outer map exit. """
        return [
            sdutil.node_path_graph(AccumulateTransient._map_exit,
                                   AccumulateTransient._outer_map_exit)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether an accumulation leaves the inner map.

            :return: True if at least one edge between the two map exits
                     carries a write-conflict resolution; False otherwise.
        """
        map_exit = graph.nodes()[candidate[AccumulateTransient._map_exit]]
        outer_map_exit = graph.nodes()[candidate[
            AccumulateTransient._outer_map_exit]]

        # Check if there is an accumulation output
        for e in graph.edges_between(map_exit, outer_map_exit):
            if e.data.wcr is not None:
                return True

        return False

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns the matched candidate entries joined by ' -> '. """
        map_exit = candidate[AccumulateTransient._map_exit]
        outer_map_exit = candidate[AccumulateTransient._outer_map_exit]

        return ' -> '.join(str(node) for node in [map_exit, outer_map_exit])

    def apply(self, sdfg: SDFG):
        """ Creates the local accumulator and, if an identity value is given,
            initializes it.

            The transient is created with OutLocalStorage between the two
            map exits. If `identity` is None, a warning is issued and the
            transient is left uninitialized. Otherwise the inner map scope is
            nested into an SDFG and a state is added before it, in which a
            mapped tasklet writes `identity` to each element of the
            transient.

            :param sdfg: The SDFG that contains the matched state.
        """
        graph = sdfg.node(self.state_id)
        map_exit = graph.node(self.subgraph[AccumulateTransient._map_exit])
        outer_map_exit = graph.node(
            self.subgraph[AccumulateTransient._outer_map_exit])

        # Choose array
        array = self.array
        if array is None or len(array) == 0:
            array = next(e.data.data
                         for e in graph.edges_between(map_exit, outer_map_exit)
                         if e.data.wcr is not None)

        # Avoid import loop
        from dace.transformation.dataflow.local_storage import OutLocalStorage

        data_node: nodes.AccessNode = OutLocalStorage.apply_to(
            sdfg,
            dict(array=array),
            verify=False,
            save=False,
            node_a=map_exit,
            node_b=outer_map_exit)

        if self.identity is None:
            warnings.warn('AccumulateTransient did not properly initialize '
                          'newly-created transient!')
            return

        sdfg_state: SDFGState = sdfg.node(self.state_id)

        map_entry = sdfg_state.entry_node(map_exit)

        # Nest the whole inner map scope so that an initialization state can
        # be placed in front of it
        nested_sdfg: NestedSDFG = nest_state_subgraph(
            sdfg=sdfg,
            state=sdfg_state,
            subgraph=SubgraphView(
                sdfg_state, {map_entry, map_exit}
                | sdfg_state.all_nodes_between(map_entry, map_exit)))

        nested_sdfg_state: SDFGState = nested_sdfg.sdfg.nodes()[0]

        init_state = nested_sdfg.sdfg.add_state_before(nested_sdfg_state)

        temp_array: Array = sdfg.arrays[data_node.data]

        # One map parameter `_o<i>` per dimension of the transient. Each
        # tasklet instance writes exactly the element addressed by those
        # parameters, so the output memlet is indexed with them rather than
        # with the full range of every dimension (formatting the shape with
        # '%d' would also fail for symbolic sizes).
        init_params = ['_o%d' % i for i in range(len(temp_array.shape))]

        init_state.add_mapped_tasklet(
            name='acctrans_init',
            map_ranges={
                p: '0:%s' % symstr(d)
                for p, d in zip(init_params, temp_array.shape)
            },
            inputs={},
            code='out = %s' % self.identity,
            outputs={
                'out':
                dace.Memlet.simple(data=data_node.data,
                                   subset_str=','.join(init_params))
            },
            external_edges=True)
Exemplo n.º 12
0
    def fuse(self, sdfg, graph, map_entries, do_not_override=None, **kwargs):
        """ takes the map_entries specified and tries to fuse maps.

            all maps have to be extended into outer and inner map
            (use MapExpansion as a pre-pass)

            Arrays that don't exist outside the subgraph get pushed
            into the map and their data dimension gets cropped.
            Otherwise the original array is taken.

            For every output respective connections are crated automatically.

            :param sdfg: SDFG
            :param graph: State
            :param map_entries: Map Entries (class MapEntry) of the outer maps
                                which we want to fuse
            :param do_not_override: List of data names whose corresponding nodes
                                    are fully contained within the subgraph
                                    but should not be augmented/transformed
                                    nevertheless.
        """

        # if there are no maps, return immediately
        if len(map_entries) == 0:
            return

        do_not_override = do_not_override or []

        # get maps and map exits
        maps = [map_entry.map for map_entry in map_entries]
        map_exits = [graph.exit_node(map_entry) for map_entry in map_entries]

        # See function documentation for an explanation of these variables
        node_config = SubgraphFusion.get_adjacent_nodes(sdfg, graph,
                                                        map_entries)
        (in_nodes, intermediate_nodes, out_nodes) = node_config

        if self.debug:
            print("SubgraphFusion::In_nodes", in_nodes)
            print("SubgraphFusion::Out_nodes", out_nodes)
            print("SubgraphFusion::Intermediate_nodes", intermediate_nodes)

        # all maps are assumed to have the same params and range in order
        global_map = nodes.Map(label="outer_fused",
                               params=maps[0].params,
                               ndrange=maps[0].range)
        global_map_entry = nodes.MapEntry(global_map)
        global_map_exit = nodes.MapExit(global_map)

        schedule = map_entries[0].schedule
        global_map_entry.schedule = schedule
        graph.add_node(global_map_entry)
        graph.add_node(global_map_exit)

        # next up, for any intermediate node, find whether it only appears
        # in the subgraph or also somewhere else / as an input
        # create new transients for nodes that are in out_nodes and
        # intermediate_nodes simultaneously
        # also check which dimensions of each transient data element correspond
        # to map axes and write this information into a dict.
        node_info = self.prepare_intermediate_nodes(sdfg, graph, in_nodes, out_nodes, \
                                                    intermediate_nodes,\
                                                    map_entries, map_exits, \
                                                    do_not_override)

        (subgraph_contains_data, transients_created,
         invariant_dimensions) = node_info
        if self.debug:
            print(
                "SubgraphFusion:: {Intermediate_node: subgraph_contains_data} dict"
            )
            print(subgraph_contains_data)

        inconnectors_dict = {}
        # Dict for saving incoming nodes and their assigned connectors
        # Format: {access_node: (edge, in_conn, out_conn)}

        for map_entry, map_exit in zip(map_entries, map_exits):
            # handle inputs
            # TODO: dynamic map range -- this is fairly unrealistic in such a setting
            for edge in graph.in_edges(map_entry):
                src = edge.src
                mmt = graph.memlet_tree(edge)
                out_edges = [child.edge for child in mmt.root().children]

                if src in in_nodes:
                    in_conn = None
                    out_conn = None
                    if src in inconnectors_dict:
                        # no need to augment subset of outer edge.
                        # will do this at the end in one pass.

                        in_conn = inconnectors_dict[src][1]
                        out_conn = inconnectors_dict[src][2]

                    else:
                        next_conn = global_map_entry.next_connector()
                        in_conn = 'IN_' + next_conn
                        out_conn = 'OUT_' + next_conn
                        global_map_entry.add_in_connector(in_conn)
                        global_map_entry.add_out_connector(out_conn)

                        inconnectors_dict[src] = (edge, in_conn, out_conn)

                        # reroute in edge via global_map_entry
                        self.copy_edge(graph, edge, new_dst = global_map_entry, \
                                                        new_dst_conn = in_conn)

                    # map out edges to new map
                    for out_edge in out_edges:
                        self.copy_edge(graph, out_edge, new_src = global_map_entry, \
                                                            new_src_conn = out_conn)

                else:
                    # connect directly
                    for out_edge in out_edges:
                        mm = dcpy(out_edge.data)
                        self.copy_edge(graph,
                                       out_edge,
                                       new_src=src,
                                       new_src_conn=None,
                                       new_data=mm)

            for edge in graph.out_edges(map_entry):
                # special case: for nodes that have no data connections
                if not edge.src_conn:
                    self.copy_edge(graph, edge, new_src=global_map_entry)

            ######################################

            for edge in graph.in_edges(map_exit):
                if not edge.dst_conn:
                    # no destination connector, path ends here.
                    self.copy_edge(graph, edge, new_dst=global_map_exit)
                    continue
                # find corresponding out_edges for current edge, cannot use mmt anymore
                out_edges = [
                    oedge for oedge in graph.out_edges(map_exit)
                    if oedge.src_conn[3:] == edge.dst_conn[2:]
                ]

                # Tuple to store in/out connector port that might be created
                port_created = None

                for out_edge in out_edges:
                    dst = out_edge.dst

                    if dst in intermediate_nodes & out_nodes:

                        # create connection through global map from
                        # dst to dst_transient that was created
                        dst_transient = transients_created[dst]
                        next_conn = global_map_exit.next_connector()
                        in_conn = 'IN_' + next_conn
                        out_conn = 'OUT_' + next_conn
                        global_map_exit.add_in_connector(in_conn)
                        global_map_exit.add_out_connector(out_conn)

                        # for each transient created, create a union
                        # of outgoing memlets' subsets. this is
                        # a cheap fix to override assignments in invariant
                        # dimensions
                        union = None
                        for oe in graph.out_edges(transients_created[dst]):
                            union = subsets.union(union, oe.data.subset)
                        inner_memlet = dcpy(edge.data)
                        for i, s in enumerate(edge.data.subset):
                            if i in invariant_dimensions[dst.label]:
                                inner_memlet.subset[i] = union[i]

                        inner_memlet.other_subset = dcpy(inner_memlet.subset)

                        e_inner = graph.add_edge(dst, None, global_map_exit,
                                                 in_conn, inner_memlet)
                        mm_outer = propagate_memlet(graph, inner_memlet, global_map_entry, \
                                                    union_inner_edges = False)

                        e_outer = graph.add_edge(global_map_exit, out_conn,
                                                 dst_transient, None, mm_outer)

                        # remove edge from dst to dst_transient that was created
                        # in intermediate preparation.
                        for e in graph.out_edges(dst):
                            if e.dst == dst_transient:
                                graph.remove_edge(e)
                                break

                    # handle separately: intermediate_nodes and pure out nodes
                    # case 1: intermediate_nodes: can just redirect edge
                    if dst in intermediate_nodes:
                        self.copy_edge(graph,
                                       out_edge,
                                       new_src=edge.src,
                                       new_src_conn=edge.src_conn,
                                       new_data=dcpy(edge.data))

                    # case 2: pure out node: connect to outer array node
                    if dst in (out_nodes - intermediate_nodes):
                        if edge.dst != global_map_exit:
                            next_conn = global_map_exit.next_connector()
                            in_conn = 'IN_' + next_conn
                            out_conn = 'OUT_' + next_conn
                            global_map_exit.add_in_connector(in_conn)
                            global_map_exit.add_out_connector(out_conn)
                            self.copy_edge(graph,
                                           edge,
                                           new_dst=global_map_exit,
                                           new_dst_conn=in_conn)
                            port_created = (in_conn, out_conn)

                        else:
                            conn_nr = edge.dst_conn[3:]
                            in_conn = port_created.st
                            out_conn = port_created.nd

                        # map
                        graph.add_edge(global_map_exit, out_conn, dst, None,
                                       dcpy(out_edge.data))

            # maps are now ready to be discarded
            # all connected edges will be finally removed as well
            graph.remove_node(map_entry)
            graph.remove_node(map_exit)

        # create a mapping from data arrays to offsets
        # for later memlet adjustments later
        min_offsets = dict()

        # do one pass to augment all transient arrays
        data_intermediate = set([node.data for node in intermediate_nodes])
        for data_name in data_intermediate:
            if subgraph_contains_data[data_name]:
                all_nodes = [
                    n for n in intermediate_nodes if n.data == data_name
                ]
                in_edges = list(chain(*(graph.in_edges(n) for n in all_nodes)))

                in_edges_iter = iter(in_edges)
                in_edge = next(in_edges_iter)
                target_subset = dcpy(in_edge.data.subset)
                target_subset.pop(invariant_dimensions[data_name])
                ######
                while True:
                    try:  # executed if there are multiple in_edges
                        in_edge = next(in_edges_iter)
                        target_subset_curr = dcpy(in_edge.data.subset)
                        target_subset_curr.pop(invariant_dimensions[data_name])
                        target_subset = subsets.union(target_subset, \
                                                      target_subset_curr)
                    except StopIteration:
                        break

                min_offsets_cropped = target_subset.min_element_approx()
                # calculate the new transient array size.
                target_subset.offset(min_offsets_cropped, True)

                # re-add invariant dimensions with offset 0 and save to min_offsets
                min_offset = []
                index = 0
                for i in range(len(sdfg.data(data_name).shape)):
                    if i in invariant_dimensions[data_name]:
                        min_offset.append(0)
                    else:
                        min_offset.append(min_offsets_cropped[index])
                        index += 1

                min_offsets[data_name] = min_offset

                # determine the shape of the new array.
                new_data_shape = []
                index = 0
                for i, sz in enumerate(sdfg.data(data_name).shape):
                    if i in invariant_dimensions[data_name]:
                        new_data_shape.append(sz)
                    else:
                        new_data_shape.append(target_subset.size()[index])
                        index += 1

                new_data_strides = [
                    data._prod(new_data_shape[i + 1:])
                    for i in range(len(new_data_shape))
                ]

                new_data_totalsize = data._prod(new_data_shape)
                new_data_offset = [0] * len(new_data_shape)
                # augment.
                transient_to_transform = sdfg.data(data_name)
                transient_to_transform.shape = new_data_shape
                transient_to_transform.strides = new_data_strides
                transient_to_transform.total_size = new_data_totalsize
                transient_to_transform.offset = new_data_offset
                transient_to_transform.lifetime = dtypes.AllocationLifetime.Scope
                transient_to_transform.storage = self.transient_allocation

            else:
                # don't modify data container - array is needed outside
                # of subgraph.

                # hack: set lifetime to State if allocation has only been
                # scope so far to avoid allocation issues
                if sdfg.data(
                        data_name).lifetime == dtypes.AllocationLifetime.Scope:
                    sdfg.data(
                        data_name).lifetime = dtypes.AllocationLifetime.State

        # do one pass to adjust and the memlets of in-between transients
        for node in intermediate_nodes:
            # all incoming edges to node
            in_edges = graph.in_edges(node)
            # outgoing edges going to another fused part
            out_edges = graph.out_edges(node)

            # memlets of created transient:
            # correct data names
            if node in transients_created:
                transient_in_edges = graph.in_edges(transients_created[node])
                transient_out_edges = graph.out_edges(transients_created[node])
                for edge in chain(transient_in_edges, transient_out_edges):
                    for e in graph.memlet_tree(edge):
                        if e.data.data == node.data:
                            e.data.data += '_OUT'

            # memlets of all in between transients:
            # offset memlets if array has been augmented
            if subgraph_contains_data[node.data]:
                # get min_offset
                min_offset = min_offsets[node.data]
                # re-add invariant dimensions with offset 0
                for iedge in in_edges:
                    for edge in graph.memlet_tree(iedge):
                        if edge.data.data == node.data:
                            edge.data.subset.offset(min_offset, True)
                        elif edge.data.other_subset:
                            edge.data.other_subset.offset(min_offset, True)
                    # nested SDFG: adjust arrays connected
                    if isinstance(iedge.src, nodes.NestedSDFG):
                        nsdfg = iedge.src.sdfg
                        nested_data_name = edge.src_conn
                        self.adjust_arrays_nsdfg(sdfg, nsdfg, node.data,
                                                 nested_data_name)

                for cedge in out_edges:
                    for edge in graph.memlet_tree(cedge):
                        if edge.data.data == node.data:
                            edge.data.subset.offset(min_offset, True)
                        elif edge.data.other_subset:
                            edge.data.other_subset.offset(min_offset, True)
                        # nested SDFG: adjust arrays connected
                        if isinstance(edge.dst, nodes.NestedSDFG):
                            nsdfg = edge.dst.sdfg
                            nested_data_name = edge.dst_conn
                            self.adjust_arrays_nsdfg(sdfg, nsdfg, node.data,
                                                     nested_data_name)

                # if in_edges has several entries:
                # put other_subset into out_edges for correctness
                if len(in_edges) > 1:
                    for oedge in out_edges:
                        if oedge.dst == global_map_exit and \
                                            oedge.data.other_subset is None:
                            oedge.data.other_subset = dcpy(oedge.data.subset)
                            oedge.data.other_subset.offset(min_offset, True)

        # consolidate edges if desired
        if self.consolidate:
            consolidate_edges_scope(graph, global_map_entry)
            consolidate_edges_scope(graph, global_map_exit)

        # propagate edges adjacent to global map entry and exit
        # if desired
        if self.propagate:
            _propagate_node(graph, global_map_entry)
            _propagate_node(graph, global_map_exit)

        # create a hook for outside access to global_map
        self._global_map_entry = global_map_entry
        if self.schedule_innermaps is not None:
            for node in graph.scope_children()[global_map_entry]:
                if isinstance(node, nodes.MapEntry):
                    node.map.schedule = self.schedule_innermaps
Exemplo n.º 13
0
    def fuse(self, sdfg, graph, map_entries, do_not_override=None, **kwargs):
        """ Fuse the given outer maps into one new global map.

            All maps have to be extended into outer and inner map
            (use MapExpansion as a pre-pass). All outer maps are assumed
            to have the same parameters and range; the parameters and range
            of the first map are used for the fused map.

            Arrays that don't exist outside the subgraph get pushed
            into the map and their data dimension gets cropped.
            Otherwise the original array is taken.

            For every output, the respective connections are created
            automatically. When done, the entry node of the fused map is
            stored in ``self._global_map_entry``.

            :param sdfg: SDFG
            :param graph: State
            :param map_entries: Map Entries (class MapEntry) of the outer maps
                                which we want to fuse
            :param do_not_override: List of data names whose corresponding nodes
                                    are fully contained within the subgraph
                                    but should not be augmented/transformed
                                    nevertheless. ``None`` (the default) is
                                    treated as an empty list.
            :param kwargs: Accepted for call compatibility; not used by
                           this method.
            :raises NotImplementedError: If a node that a map exit writes to
                                         also has an incoming edge whose
                                         source is not one of the map exits.
        """

        # if there are no maps, return immediately
        if not map_entries:
            return

        # a fresh list per call; a literal ``[]`` default would be shared
        # between all invocations
        if do_not_override is None:
            do_not_override = []

        # get maps and map exits
        maps = [map_entry.map for map_entry in map_entries]
        map_exits = [graph.exit_node(map_entry) for map_entry in map_entries]

        # re-construct the map subgraph if necessary
        if not hasattr(self, 'subgraph'):
            subgraph_nodes = set()
            scope_dict = graph.scope_dict(node_to_children=True)
            for node in chain(map_entries, map_exits):
                subgraph_nodes.add(node)
                # add all border arrays
                for e in chain(graph.in_edges(node), graph.out_edges(node)):
                    subgraph_nodes.add(e.src)
                    subgraph_nodes.add(e.dst)
                try:
                    subgraph_nodes |= set(scope_dict[node])
                except KeyError:
                    # node has no entry in the scope dictionary
                    pass
            self.subgraph = SubgraphView(graph, subgraph_nodes)

        # Nodes that flow into one or several maps but no data is flowed to them from any map
        in_nodes = set()

        # Nodes into which data is flowed but that no data flows into any map from them
        out_nodes = set()

        # Nodes that act as intermediate node - data flows from a map into them and then there
        # is an outgoing path into another map
        intermediate_nodes = set()

        ### NOTE:
        #- in_nodes, out_nodes, intermediate_nodes refer to the configuration of the final fused map
        #- in_nodes and out_nodes are trivially disjoint
        #- Intermediate_nodes and out_nodes are not necessarily disjoint
        #- Intermediate_nodes and in_nodes are disjoint by design.
        #  There could be a node that has both incoming edges from a map exit
        #  and from outside, but it is just treated as intermediate_node and handled
        #  automatically.

        for map_entry, map_exit in zip(map_entries, map_exits):
            for edge in graph.in_edges(map_entry):
                in_nodes.add(edge.src)
            for edge in graph.out_edges(map_exit):
                current_node = edge.dst
                if len(graph.out_edges(current_node)) == 0:
                    out_nodes.add(current_node)
                else:
                    for dst_edge in graph.out_edges(current_node):
                        if dst_edge.dst in map_entries:
                            # add to intermediate_nodes
                            intermediate_nodes.add(current_node)

                        else:
                            # add to out_nodes
                            out_nodes.add(current_node)
                for e in graph.in_edges(current_node):
                    if e.src not in map_exits:
                        # adjacent literals are concatenated verbatim, so
                        # each fragment carries its own trailing space
                        raise NotImplementedError(
                            "Nodes between two maps to be "
                            "fused with *incoming* edges "
                            "from outside the maps are not "
                            "allowed yet.")

        # any intermediate_nodes currently in in_nodes shouldnt be there
        in_nodes -= intermediate_nodes

        if self.debug:
            print("SubgraphFusion::In_nodes", in_nodes)
            print("SubgraphFusion::Out_nodes", out_nodes)
            print("SubgraphFusion::Intermediate_nodes", intermediate_nodes)

        # all maps are assumed to have the same params and range in order
        global_map = nodes.Map(label="outer_fused",
                               params=maps[0].params,
                               ndrange=maps[0].range)
        global_map_entry = nodes.MapEntry(global_map)
        global_map_exit = nodes.MapExit(global_map)

        # the fused map takes over the schedule of the first map entry
        global_map_entry.schedule = map_entries[0].schedule
        graph.add_node(global_map_entry)
        graph.add_node(global_map_exit)

        # next up, for any intermediate node, find whether it only appears
        # in the subgraph or also somewhere else / as an input
        # create new transients for nodes that are in out_nodes and
        # intermediate_nodes simultaneously
        # also check which dimensions of each transient data element correspond
        # to map axes and write this information into a dict.
        node_info = self.prepare_intermediate_nodes(sdfg, graph, in_nodes,
                                                    out_nodes,
                                                    intermediate_nodes,
                                                    map_entries, map_exits,
                                                    do_not_override)

        (subgraph_contains_data, transients_created,
         invariant_dimensions) = node_info
        if self.debug:
            print(
                "SubgraphFusion:: {Intermediate_node: subgraph_contains_data} dict"
            )
            print(subgraph_contains_data)

        inconnectors_dict = {}
        # Dict for saving incoming nodes and their assigned connectors
        # Format: {access_node: (edge, in_conn, out_conn)}

        for map_entry, map_exit in zip(map_entries, map_exits):
            # handle inputs
            # TODO: dynamic map range -- this is fairly unrealistic in such a setting
            for edge in graph.in_edges(map_entry):
                src = edge.src
                mmt = graph.memlet_tree(edge)
                out_edges = [child.edge for child in mmt.root().children]

                if src in in_nodes:
                    if src in inconnectors_dict:
                        # no need to augment subset of outer edge.
                        # will do this at the end in one pass.

                        in_conn = inconnectors_dict[src][1]
                        out_conn = inconnectors_dict[src][2]
                        graph.remove_edge(edge)

                    else:
                        next_conn = global_map_entry.next_connector()
                        in_conn = 'IN_' + next_conn
                        out_conn = 'OUT_' + next_conn
                        global_map_entry.add_in_connector(in_conn)
                        global_map_entry.add_out_connector(out_conn)

                        inconnectors_dict[src] = (edge, in_conn, out_conn)

                        # reroute in edge via global_map_entry
                        self.redirect_edge(graph,
                                           edge,
                                           new_dst=global_map_entry,
                                           new_dst_conn=in_conn)

                    # map out edges to new map
                    for out_edge in out_edges:
                        self.redirect_edge(graph,
                                           out_edge,
                                           new_src=global_map_entry,
                                           new_src_conn=out_conn)

                else:
                    # connect directly
                    for out_edge in out_edges:
                        mm = dcpy(out_edge.data)
                        self.redirect_edge(graph,
                                           out_edge,
                                           new_src=src,
                                           new_data=mm)

                    graph.remove_edge(edge)

            for edge in graph.out_edges(map_entry):
                # special case: for nodes that have no data connections
                if not edge.src_conn:
                    self.redirect_edge(graph, edge, new_src=global_map_entry)

            ######################################

            for edge in graph.in_edges(map_exit):
                if not edge.dst_conn:
                    # no destination connector, path ends here.
                    self.redirect_edge(graph, edge, new_dst=global_map_exit)
                    continue
                # find corresponding out_edges for current edge, cannot use mmt anymore
                # ('OUT_x'[3:] and 'IN_x'[2:] both yield '_x')
                out_edges = [
                    oedge for oedge in graph.out_edges(map_exit)
                    if oedge.src_conn[3:] == edge.dst_conn[2:]
                ]

                # Tuple to store in/out connector port that might be created
                port_created = None

                for out_edge in out_edges:
                    dst = out_edge.dst

                    if dst in intermediate_nodes & out_nodes:

                        # create connection through global map from
                        # dst to dst_transient that was created
                        dst_transient = transients_created[dst]
                        next_conn = global_map_exit.next_connector()
                        in_conn = 'IN_' + next_conn
                        out_conn = 'OUT_' + next_conn
                        global_map_exit.add_in_connector(in_conn)
                        global_map_exit.add_out_connector(out_conn)

                        inner_memlet = dcpy(edge.data)
                        inner_memlet.other_subset = dcpy(edge.data.subset)

                        graph.add_edge(dst, None, global_map_exit, in_conn,
                                       inner_memlet)
                        mm_outer = propagate_memlet(graph,
                                                    inner_memlet,
                                                    global_map_entry,
                                                    union_inner_edges=False)

                        graph.add_edge(global_map_exit, out_conn,
                                       dst_transient, None, mm_outer)

                        # remove edge from dst to dst_transient that was created
                        # in intermediate preparation.
                        # the flag starts out False so that the debug check
                        # below reflects this iteration only
                        removed = False
                        for e in graph.out_edges(dst):
                            if e.dst == dst_transient:
                                graph.remove_edge(e)
                                removed = True
                                break

                        if self.debug:
                            assert removed

                    # handle separately: intermediate_nodes and pure out nodes
                    # case 1: intermediate_nodes: can just redirect edge
                    if dst in intermediate_nodes:
                        self.redirect_edge(graph,
                                           out_edge,
                                           new_src=edge.src,
                                           new_src_conn=edge.src_conn,
                                           new_data=dcpy(edge.data))

                    # case 2: pure out node: connect to outer array node
                    if dst in (out_nodes - intermediate_nodes):
                        if edge.dst != global_map_exit:
                            next_conn = global_map_exit.next_connector()
                            in_conn = 'IN_' + next_conn
                            out_conn = 'OUT_' + next_conn
                            global_map_exit.add_in_connector(in_conn)
                            global_map_exit.add_out_connector(out_conn)
                            self.redirect_edge(graph,
                                               edge,
                                               new_dst=global_map_exit,
                                               new_dst_conn=in_conn)
                            port_created = (in_conn, out_conn)

                        else:
                            # reuse the port created earlier for this edge;
                            # port_created is a plain (in_conn, out_conn)
                            # tuple and therefore has to be indexed
                            in_conn = port_created[0]
                            out_conn = port_created[1]

                        # map
                        graph.add_edge(global_map_exit, out_conn, dst, None,
                                       dcpy(out_edge.data))
                        graph.remove_edge(out_edge)

                # remove the edge if it has not been used by any pure out node
                if not port_created:
                    graph.remove_edge(edge)

            # maps are now ready to be discarded
            graph.remove_node(map_entry)
            graph.remove_node(map_exit)

            # end main loop.

        # create a mapping from data arrays to offsets
        # for later memlet adjustments later
        min_offsets = dict()

        # do one pass to augment all transient arrays
        data_intermediate = {node.data for node in intermediate_nodes}
        for data_name in data_intermediate:
            if subgraph_contains_data[data_name]:
                all_nodes = [
                    n for n in intermediate_nodes if n.data == data_name
                ]
                in_edges = list(
                    chain.from_iterable(graph.in_edges(n) for n in all_nodes))

                # union of the subsets of all incoming edges, with the
                # invariant dimensions popped from each subset
                target_subset = dcpy(in_edges[0].data.subset)
                target_subset.pop(invariant_dimensions[data_name])
                for other_edge in in_edges[1:]:
                    target_subset_curr = dcpy(other_edge.data.subset)
                    target_subset_curr.pop(invariant_dimensions[data_name])
                    target_subset = subsets.union(target_subset,
                                                  target_subset_curr)

                min_offsets_cropped = target_subset.min_element_approx()
                # calculate the new transient array size.
                target_subset.offset(min_offsets_cropped, True)

                # re-add invariant dimensions with offset 0 and save to min_offsets
                min_offset = []
                index = 0
                for i in range(len(sdfg.data(data_name).shape)):
                    if i in invariant_dimensions[data_name]:
                        min_offset.append(0)
                    else:
                        min_offset.append(min_offsets_cropped[index])
                        index += 1

                min_offsets[data_name] = min_offset

                # determine the shape of the new array.
                new_data_shape = []
                index = 0
                for i, sz in enumerate(sdfg.data(data_name).shape):
                    if i in invariant_dimensions[data_name]:
                        new_data_shape.append(sz)
                    else:
                        new_data_shape.append(target_subset.size()[index])
                        index += 1

                # stride of dimension i is the product of all trailing sizes
                new_data_strides = [
                    data._prod(new_data_shape[i + 1:])
                    for i in range(len(new_data_shape))
                ]

                new_data_totalsize = data._prod(new_data_shape)
                new_data_offset = [0] * len(new_data_shape)
                # augment.
                transient_to_transform = sdfg.data(data_name)
                transient_to_transform.shape = new_data_shape
                transient_to_transform.strides = new_data_strides
                transient_to_transform.total_size = new_data_totalsize
                transient_to_transform.offset = new_data_offset
                transient_to_transform.lifetime = dtypes.AllocationLifetime.Scope
                transient_to_transform.storage = self.transient_allocation

            else:
                # don't modify data container - array is needed outside
                # of subgraph.

                # hack: set lifetime to State if allocation has only been
                # scope so far to avoid allocation issues
                if sdfg.data(
                        data_name).lifetime == dtypes.AllocationLifetime.Scope:
                    sdfg.data(
                        data_name).lifetime = dtypes.AllocationLifetime.State

        # do one pass to adjust the memlets of in-between transients
        for node in intermediate_nodes:
            # all incoming edges to node
            in_edges = graph.in_edges(node)
            # outgoing edges going to another fused part
            inter_edges = []
            # outgoing edges that exit global map
            out_edges = []
            for e in graph.out_edges(node):
                if e.dst == global_map_exit:
                    out_edges.append(e)
                else:
                    inter_edges.append(e)

            # offset memlets where necessary
            if subgraph_contains_data[node.data]:
                # get min_offset
                min_offset = min_offsets[node.data]
                # re-add invariant dimensions with offset 0
                for iedge in in_edges:
                    for edge in graph.memlet_tree(iedge):
                        if edge.data.data == node.data:
                            edge.data.subset.offset(min_offset, True)
                        elif edge.data.other_subset:
                            edge.data.other_subset.offset(min_offset, True)

                for cedge in inter_edges:
                    for edge in graph.memlet_tree(cedge):
                        if edge.data.data == node.data:
                            edge.data.subset.offset(min_offset, True)
                        elif edge.data.other_subset:
                            edge.data.other_subset.offset(min_offset, True)

                # if in_edges has several entries:
                # put other_subset into out_edges for correctness
                if len(in_edges) > 1:
                    for oedge in out_edges:
                        oedge.data.other_subset = dcpy(oedge.data.subset)
                        oedge.data.other_subset.offset(min_offset, True)

            # also correct memlets of created transient
            if node in transients_created:
                transient_in_edges = graph.in_edges(transients_created[node])
                transient_out_edges = graph.out_edges(transients_created[node])
                for edge in chain(transient_in_edges, transient_out_edges):
                    for e in graph.memlet_tree(edge):
                        if e.data.data == node.data:
                            e.data.data += '_OUT'

        # do one last pass to correct outside memlets adjacent to global map
        for out_connector in global_map_entry.out_connectors:
            # find corresponding in_connector
            # and the in-connecting edge
            # ('OUT_x'[3:] is '_x', so this yields 'IN_x')
            in_connector = 'IN' + out_connector[3:]
            # reset per connector so that an edge found for a previous
            # connector is never modified by mistake
            in_edge = None
            for iedge in graph.in_edges(global_map_entry):
                if iedge.dst_conn == in_connector:
                    in_edge = iedge

            # find corresponding out_connector
            # and all out-connecting edges that belong to it
            # count them
            out_edge = None
            oedge_counter = 0
            for oedge in graph.out_edges(global_map_entry):
                if oedge.src_conn == out_connector:
                    out_edge = oedge
                    oedge_counter += 1

            # do memlet propagation
            # if there are several out edges, else there is no need

            if oedge_counter > 1 and in_edge is not None:
                memlet_out = propagate_memlet(dfg_state=graph,
                                              memlet=out_edge.data,
                                              scope_node=global_map_entry,
                                              union_inner_edges=True)
                # override number of accesses
                in_edge.data.volume = memlet_out.volume
                in_edge.data.subset = memlet_out.subset

        # create a hook for outside access to global_map
        self._global_map_entry = global_map_entry
Exemplo n.º 14
0
    def _stripmine(self, sdfg, graph, candidate):
        """ Strip-mines a single dimension of the matched map.

            A new map iterating over the tiles of dimension ``self.dim_idx``
            is created and placed around the matched map. The matched map
            keeps its scope but has the range of that dimension rewritten to
            cover one tile (or loses the dimension altogether in the special
            case handled below), is rescheduled as ``Sequential``, and hands
            its previous schedule to the new map. Edges that entered or left
            the matched scope are redirected to the new map's entry/exit, and
            new edges are created between the new and the matched scope
            nodes.

            :param sdfg: SDFG whose ``arrays`` describe the data referenced by
                         the memlets around the map.
            :param graph: State that contains the matched map.
            :param candidate: Match dictionary, indexed with
                              ``StripMining._map_entry``.
            :return: A 3-tuple ``(target_dim, new_dim, new_map)``: the name of
                     the strip-mined parameter, the name of the new tile
                     parameter, and the newly created map object.
            :raises ValueError: If the special case below would leave the
                                matched map without any parameter.
        """

        # Retrieve map entry and exit nodes.
        map_entry = graph.nodes()[candidate[StripMining._map_entry]]
        map_exit = graph.exit_node(map_entry)

        # Retrieve transformation properties.
        dim_idx = self.dim_idx
        new_dim_prefix = self.new_dim_prefix
        tile_size = self.tile_size
        divides_evenly = self.divides_evenly
        strided = self.strided

        # An unset (None or empty) tile stride defaults to the tile size.
        tile_stride = self.tile_stride
        if tile_stride is None or len(tile_stride) == 0:
            tile_stride = tile_size

        # Retrieve parameter and range of dimension to be strip-mined.
        target_dim = map_entry.map.params[dim_idx]
        td_from, td_to, td_step = map_entry.map.range[dim_idx]

        # Create new map. Replace by cloning map object?
        new_dim = self._find_new_dim(sdfg, graph, map_entry, new_dim_prefix,
                                     target_dim)
        nd_from = 0
        if symbolic.pystr_to_symbolic(tile_stride) == 1:
            # With a stride of one the tile index reuses the original upper
            # bound as is.
            nd_to = td_to
        else:
            nd_to = symbolic.pystr_to_symbolic(
                'int_ceil(%s + 1 - %s, %s) - 1' %
                (symbolic.symstr(td_to), symbolic.symstr(td_from),
                 tile_stride))
        nd_step = 1
        new_dim_range = (nd_from, nd_to, nd_step)
        new_map = nodes.Map(new_dim + '_' + map_entry.map.label, [new_dim],
                            subsets.Range([new_dim_range]))
        new_map_entry = nodes.MapEntry(new_map)
        new_map_exit = nodes.MapExit(new_map)

        # Change the range of the selected dimension to iterate over a single
        # tile
        if strided:
            td_from_new = symbolic.pystr_to_symbolic(new_dim)
            td_to_new_approx = td_to
            td_step = symbolic.pystr_to_symbolic(tile_size)
        else:
            td_from_new = symbolic.pystr_to_symbolic(
                '%s + %s * %s' %
                (symbolic.symstr(td_from), str(new_dim), tile_stride))
            td_to_new_exact = symbolic.pystr_to_symbolic(
                'min(%s + 1, %s + %s * %s + %s) - 1' %
                (symbolic.symstr(td_to), symbolic.symstr(td_from), tile_stride,
                 str(new_dim), tile_size))
            td_to_new_approx = symbolic.pystr_to_symbolic(
                '%s + %s * %s + %s - 1' %
                (symbolic.symstr(td_from), tile_stride, str(new_dim),
                 tile_size))
        if divides_evenly or strided:
            td_to_new = td_to_new_approx
        else:
            # ``td_to_new_exact`` is only defined on the non-strided path,
            # which is the only path that reaches this branch.
            td_to_new = dace.symbolic.SymExpr(td_to_new_exact,
                                              td_to_new_approx)
        # Special case: If range is 1 and no prefix was specified, skip range
        if td_from_new == td_to_new_approx and target_dim == new_dim:
            map_entry.map.range = subsets.Range(
                [r for i, r in enumerate(map_entry.map.range) if i != dim_idx])
            map_entry.map.params = [
                p for i, p in enumerate(map_entry.map.params) if i != dim_idx
            ]
            if len(map_entry.map.params) == 0:
                raise ValueError('Strip-mining all dimensions of the map with '
                                 'empty tiles is disallowed')
        else:
            map_entry.map.range[dim_idx] = (td_from_new, td_to_new, td_step)

        # Make internal map's schedule to "not parallel"
        new_map.schedule = map_entry.map.schedule
        map_entry.map.schedule = dtypes.ScheduleType.Sequential

        # Redirect edges
        new_map_entry.in_connectors = dcpy(map_entry.in_connectors)
        sdutil.change_edge_dest(graph, map_entry, new_map_entry)
        new_map_exit.out_connectors = dcpy(map_exit.out_connectors)
        sdutil.change_edge_src(graph, map_exit, new_map_exit)

        # Create new entry edges
        # Memlets are keyed by (data, in connector, out connector) so that
        # several inner edges sharing a connector collapse into one edge.
        new_in_edges = dict()
        entry_in_conn = {}
        entry_out_conn = {}
        for _src, src_conn, _dst, _, memlet in graph.out_edges(map_entry):
            if (src_conn is not None
                    and src_conn[:4] == 'OUT_' and not isinstance(
                        sdfg.arrays[memlet.data], dace.data.Scalar)):
                new_subset = calc_set_image(
                    map_entry.map.params,
                    map_entry.map.range,
                    memlet.subset,
                )
                conn = src_conn[4:]
                key = (memlet.data, 'IN_' + conn, 'OUT_' + conn)
                if key in new_in_edges:
                    # Same data through the same connector: widen the subset
                    old_subset = new_in_edges[key].subset
                    new_in_edges[key].subset = calc_set_union(
                        old_subset, new_subset)
                else:
                    entry_in_conn['IN_' + conn] = None
                    entry_out_conn['OUT_' + conn] = None
                    new_memlet = dcpy(memlet)
                    new_memlet.subset = new_subset
                    if memlet.dynamic:
                        new_memlet.num_accesses = memlet.num_accesses
                    else:
                        new_memlet.num_accesses = new_memlet.num_elements()
                    new_in_edges[key] = new_memlet
            else:
                # Scalars and edges without an OUT_* connector are copied
                # without recomputing their subset.
                if src_conn is not None and src_conn[:4] == 'OUT_':
                    conn = src_conn[4:]
                    in_conn = 'IN_' + conn
                    out_conn = 'OUT_' + conn
                else:
                    in_conn = src_conn
                    out_conn = src_conn
                if in_conn:
                    entry_in_conn[in_conn] = None
                if out_conn:
                    entry_out_conn[out_conn] = None
                new_in_edges[(memlet.data, in_conn, out_conn)] = dcpy(memlet)
        new_map_entry.out_connectors = entry_out_conn
        map_entry.in_connectors = entry_in_conn
        for (_, in_conn, out_conn), memlet in new_in_edges.items():
            graph.add_edge(new_map_entry, out_conn, map_entry, in_conn, memlet)

        # Create new exit edges
        new_out_edges = dict()
        exit_in_conn = {}
        exit_out_conn = {}
        for _src, _, _dst, dst_conn, memlet in graph.in_edges(map_exit):
            if (dst_conn is not None
                    and dst_conn[:3] == 'IN_' and not isinstance(
                        sdfg.arrays[memlet.data], dace.data.Scalar)):
                new_subset = calc_set_image(
                    map_entry.map.params,
                    map_entry.map.range,
                    memlet.subset,
                )
                conn = dst_conn[3:]
                key = (memlet.data, 'IN_' + conn, 'OUT_' + conn)
                if key in new_out_edges:
                    old_subset = new_out_edges[key].subset
                    new_out_edges[key].subset = calc_set_union(
                        old_subset, new_subset)
                else:
                    exit_in_conn['IN_' + conn] = None
                    exit_out_conn['OUT_' + conn] = None
                    new_memlet = dcpy(memlet)
                    new_memlet.subset = new_subset
                    if memlet.dynamic:
                        new_memlet.num_accesses = memlet.num_accesses
                    else:
                        new_memlet.num_accesses = new_memlet.num_elements()
                    new_out_edges[key] = new_memlet
            else:
                if dst_conn is not None and dst_conn[:3] == 'IN_':
                    conn = dst_conn[3:]
                    in_conn = 'IN_' + conn
                    out_conn = 'OUT_' + conn
                else:
                    # The connector of interest on an exit edge is the one on
                    # map_exit, i.e. the edge's destination connector (not
                    # the leftover ``src_conn`` of the entry-edge loop).
                    in_conn = dst_conn
                    out_conn = dst_conn
                if in_conn:
                    exit_in_conn[in_conn] = None
                if out_conn:
                    exit_out_conn[out_conn] = None
                # Must be recorded as an exit edge; the entry edges have
                # already been added to the graph above.
                new_out_edges[(memlet.data, in_conn, out_conn)] = dcpy(memlet)
        new_map_exit.in_connectors = exit_in_conn
        map_exit.out_connectors = exit_out_conn
        for (_, in_conn, out_conn), memlet in new_out_edges.items():
            graph.add_edge(map_exit, out_conn, new_map_exit, in_conn, memlet)

        # Return strip-mined dimension.
        return target_dim, new_dim, new_map
Exemplo n.º 15
0
    def expand(self, sdfg, graph, map_entries, map_base_variables=None):
        """
        Expansion into outer and inner maps for each map in a specified set.
        The resulting outer maps all have the same range and indices;
        corresponding variables and memlets are changed accordingly. The inner
        map contains the leftover dimensions. Maps whose number of parameters
        already equals the number of base variables are left untouched.

        :param sdfg: Underlying SDFG
        :param graph: Graph in which we expand
        :param map_entries: List of Map Entries (type MapEntry) that we want
                            to expand
        :param map_base_variables: Optional parameter. List of strings.
                                   If None (or empty), expand() searches for
                                   the maximal amount of equal map ranges and
                                   pushes those and their corresponding loop
                                   variables into the outer loop.
                                   If specified, expand() pushes the ranges
                                   belonging to the specified loop iteration
                                   variables into the outer loop (for instance
                                   map_base_variables = ['i','j'] assumes that
                                   all maps have common iteration indices i
                                   and j with corresponding correct ranges)
        :raises AssertionError: If ``map_base_variables`` is given and some
                                map's range for one of these variables differs
                                from the range in the first map.
        """

        # NOTE: the loop variables below are deliberately not called ``map``.
        # Shadowing the builtin made the ``map(str, ...)`` call in the
        # dynamic-range handling invoke a Map node and raise a TypeError.
        maps = [entry.map for entry in map_entries]

        if not map_base_variables:
            # find the maximal subset of variables to expand
            # greedy if there exist multiple ranges that are equal in a map

            map_base_ranges = helpers.common_map_base_ranges(maps)
            reassignments = helpers.find_reassignment(maps, map_base_ranges)

            ##### first, regroup and reassign
            # create params_dict for every map
            # first, let us define the outer iteration variable names,
            # just take the first map and their indices at common ranges
            map_base_variables = []
            first_map = maps[0]
            for rng in map_base_ranges:
                for i, param in enumerate(first_map.params):
                    if (first_map.range[i] == rng
                            and param not in map_base_variables):
                        map_base_variables.append(param)
                        break

            params_dict = {}
            if self.debug:
                print("MultiExpansion::Map_base_variables:", map_base_variables)
                print("MultiExpansion::Map_base_ranges:", map_base_ranges)
            for cur_map in maps:
                # for each map create param dict, first assign identity
                params_dict_map = {param: param for param in cur_map.params}
                # now look for the correct reassignment
                # for every element neq -1, need to change param to map_base_variables[]
                # if param already appears in own dict, do a swap
                # else we just replace it
                for i, reassignment in enumerate(reassignments[cur_map]):
                    if reassignment == -1:
                        # nothing to do
                        continue

                    current_var = cur_map.params[i]
                    current_assignment = params_dict_map[current_var]
                    target_assignment = map_base_variables[reassignment]
                    if current_assignment == target_assignment:
                        continue

                    if target_assignment in params_dict_map.values():
                        # do a swap
                        key1 = current_var
                        for key, value in params_dict_map.items():
                            if value == target_assignment:
                                key2 = key

                        # NOTE(review): the swap stores the *keys* of the two
                        # entries rather than their previous values. This is
                        # kept as in the original code; confirm it is intended
                        # once an entry no longer maps to itself.
                        params_dict_map[key1] = key2
                        params_dict_map[key2] = key1
                    else:
                        # just reassign
                        params_dict_map[current_var] = target_assignment

                # done, assign params_dict_map to the global one
                params_dict[cur_map] = params_dict_map

            for cur_map, map_entry in zip(maps, map_entries):
                map_scope = graph.scope_subgraph(map_entry)
                params_dict_map = params_dict[cur_map]
                # Rename in two passes through temporary names so that
                # swapped parameters do not overwrite each other.
                for firstp, secondp in params_dict_map.items():
                    if firstp != secondp:
                        replace(map_scope, firstp, '__' + firstp + '_fused')
                for firstp, secondp in params_dict_map.items():
                    if firstp != secondp:
                        replace(map_scope, '__' + firstp + '_fused', secondp)

                # now also replace the map variables inside maps
                for i in range(len(cur_map.params)):
                    cur_map.params[i] = params_dict_map[cur_map.params[i]]

            if self.debug:
                print("MultiExpansion::Params replaced")

        else:
            # just calculate map_base_ranges
            # do a check whether all maps correct
            map_base_ranges = []

            map0 = maps[0]
            for var in map_base_variables:
                index = map0.params.index(var)
                map_base_ranges.append(map0.range[index])

            for cur_map in maps:
                for var, rng in zip(map_base_variables, map_base_ranges):
                    assert cur_map.range[cur_map.params.index(var)] == rng

        # then expand all the maps
        for cur_map, map_entry in zip(maps, map_entries):
            if cur_map.get_param_num() == len(map_base_variables):
                # nothing to expand, continue
                continue

            map_exit = graph.exit_node(map_entry)
            # create two new maps, outer and inner
            params_outer = map_base_variables
            ranges_outer = map_base_ranges

            # Everything that is not a base variable goes to the inner map
            init_params_inner = []
            init_ranges_inner = []
            for param, rng in zip(cur_map.params, cur_map.range):
                if param in map_base_variables:
                    continue
                init_params_inner.append(param)
                init_ranges_inner.append(rng)

            params_inner = init_params_inner
            ranges_inner = subsets.Range(init_ranges_inner)
            inner_map = nodes.Map(label = cur_map.label + '_inner',
                                  params = params_inner,
                                  ndrange = ranges_inner,
                                  schedule = dtypes.ScheduleType.Sequential \
                                             if self.sequential_innermaps \
                                             else dtypes.ScheduleType.Default)

            # The existing map object becomes the outer map
            cur_map.label = cur_map.label + '_outer'
            cur_map.params = params_outer
            cur_map.range = ranges_outer

            # create new map entries and exits
            map_entry_inner = nodes.MapEntry(inner_map)
            map_exit_inner = nodes.MapExit(inner_map)

            # analogously to Map_Expansion
            for edge in graph.out_edges(map_entry):
                graph.remove_edge(edge)
                graph.add_memlet_path(map_entry,
                                      map_entry_inner,
                                      edge.dst,
                                      src_conn=edge.src_conn,
                                      memlet=edge.data,
                                      dst_conn=edge.dst_conn)

            dynamic_edges = dynamic_map_inputs(graph, map_entry)
            for edge in dynamic_edges:
                # Remove old edge and connector
                graph.remove_edge(edge)
                # Use the node's public API instead of reaching into the
                # private connector container.
                edge.dst.remove_in_connector(edge.dst_conn)

                # Propagate to each range it belongs to
                path = []
                for mapnode in [map_entry, map_entry_inner]:
                    path.append(mapnode)
                    if any(edge.dst_conn in map(str, symbolic.symlist(r))
                           for r in mapnode.map.range):
                        graph.add_memlet_path(edge.src,
                                              *path,
                                              memlet=edge.data,
                                              src_conn=edge.src_conn,
                                              dst_conn=edge.dst_conn)

            for edge in graph.in_edges(map_exit):
                graph.remove_edge(edge)
                graph.add_memlet_path(edge.src,
                                      map_exit_inner,
                                      map_exit,
                                      memlet=edge.data,
                                      src_conn=edge.src_conn,
                                      dst_conn=edge.dst_conn)
Exemplo n.º 16
0
class StreamTransient(pattern_matching.Transformation):
    """ StreamTransient transformation: inserts a transient stream (and,
        when ``with_buffer`` is set, a transient array) on a stream-typed
        edge that runs from an inner map exit to an outer map exit, so that
        the transient serves as a local buffer.
    """

    with_buffer = Property(dtype=bool,
                           default=True,
                           desc="Use an intermediate buffer for accumulation")

    # Pattern nodes: tasklet -> inner map exit -> outer map exit
    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))
    _outer_map_exit = nodes.MapExit(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        """ Returns the single pattern graph matched by this transformation.
        """
        pattern = sdutil.node_path_graph(StreamTransient._tasklet,
                                         StreamTransient._map_exit,
                                         StreamTransient._outer_map_exit)
        return [pattern]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ A match is valid if the inner map exit has at least one outgoing
            edge that carries a stream and ends at the outer map exit.
        """
        inner_exit = graph.nodes()[candidate[StreamTransient._map_exit]]
        outer_exit = graph.nodes()[candidate[
            StreamTransient._outer_map_exit]]

        return any(
            isinstance(sdfg.arrays[out_memlet.data], data.Stream)
            and target == outer_exit
            for _, _, target, _, out_memlet in graph.out_edges(inner_exit))

    @staticmethod
    def match_to_str(graph, candidate):
        """ Formats the matched candidate entries as ``a -> b -> c``.
        """
        pattern_order = (StreamTransient._tasklet, StreamTransient._map_exit,
                         StreamTransient._outer_map_exit)
        return ' -> '.join(
            str(candidate[pattern_node]) for pattern_node in pattern_order)

    def apply(self, sdfg: SDFG):
        """ Reroutes the stream edge between the two map exits through a new
            transient stream and, if ``with_buffer`` is set, a new transient
            array behind it.
        """
        state = sdfg.nodes()[self.state_id]
        tasklet = state.nodes()[self.subgraph[StreamTransient._tasklet]]
        inner_exit = state.nodes()[self.subgraph[StreamTransient._map_exit]]
        outer_exit = state.nodes()[self.subgraph[
            StreamTransient._outer_map_exit]]

        # Pick the first edge that carries a stream to the outer map exit
        # TODO: What if there's more than one?
        stream_edge = None
        memlet = None
        for out_edge in state.out_edges(inner_exit):
            memlet = out_edge.data
            leads_out = out_edge.dst == outer_exit
            if leads_out and isinstance(sdfg.arrays[memlet.data],
                                        data.Stream):
                stream_edge = out_edge
                break

        # Pick the tasklet output on the same data (the last output edge is
        # used if none matches)
        tasklet_memlet = None
        for out_edge in state.out_edges(tasklet):
            tasklet_memlet = out_edge.data
            if tasklet_memlet.data == memlet.data:
                break

        # The sizes of the new containers follow the inner map's first
        # (over-approximated) bounding-box dimension
        bbox_approx = [
            symbolic.overapproximate(extent)
            for extent in inner_exit.map.range.bounding_box_size()
        ]
        dataname = memlet.data

        # Temporary stream and its access node
        stream_name, _ = sdfg.add_stream('trans_' + dataname,
                                         sdfg.arrays[memlet.data].dtype,
                                         1,
                                         bbox_approx[0], [1],
                                         transient=True,
                                         find_new_name=True)
        stream_node = state.add_access(stream_name)

        to_stream_mm = copy.deepcopy(memlet)
        to_stream_mm.data = stream_node.data
        tasklet_memlet.data = stream_node.data

        # Without a buffer the stream itself feeds the outer map exit
        last_node = stream_node
        if self.with_buffer:
            array_name, _ = sdfg.add_transient('strans_' + dataname,
                                               [bbox_approx[0]],
                                               sdfg.arrays[memlet.data].dtype,
                                               find_new_name=True)
            last_node = state.add_access(array_name)
            to_array_mm = copy.deepcopy(memlet)
            to_array_mm.data = last_node.data
            state.add_edge(stream_node, None, last_node, None, to_array_mm)

        # Reconnect, assuming one edge to the stream
        state.remove_edge(stream_edge)
        state.add_edge(inner_exit, stream_edge.src_conn, stream_node, None,
                       to_stream_mm)
        state.add_edge(last_node, None, outer_exit, stream_edge.dst_conn,
                       memlet)

    def modifies_graph(self):
        """ This transformation changes the graph. """
        return True
Exemplo n.º 17
0
class BufferTiling(transformation.Transformation):
    """ Implements the buffer tiling transformation.

        BufferTiling tiles a buffer that is in between two maps, where the preceding map
        writes to the buffer and the succeeding map reads from it.
        It introduces additional computations in exchange for reduced memory footprint.
        Commonly used to make use of shared memory on GPUs.
    """

    # Pattern nodes: first map exit -> buffer access node -> second map entry
    _map1_exit = nodes.MapExit(nodes.Map('', [], []))
    _array = nodes.AccessNode('')
    _map2_entry = nodes.MapEntry(nodes.Map('', [], []))

    tile_sizes = ShapeProperty(dtype=tuple,
                               default=(128, 128, 128),
                               desc="Tile size per dimension")

    # Returns a list of graphs that represent the pattern
    @staticmethod
    def expressions():
        """ Returns the single pattern graph matched by this transformation.
        """
        return [
            sdutil.node_path_graph(
                BufferTiling._map1_exit,
                BufferTiling._array,
                BufferTiling._map2_entry,
            )
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks every node between the first map's exit and the second
            map's entry. Each of them must be a transient access node with
            exactly one incoming edge (from the first map exit) and one
            outgoing edge (to the second map entry), the written subset must
            cover the read subset, and its data must not be accessed by any
            other access node in this state.

            :return: True if all buffers satisfy these conditions, False
                     otherwise.
        """
        map1_exit = graph.nodes()[candidate[BufferTiling._map1_exit]]
        map2_entry = graph.nodes()[candidate[BufferTiling._map2_entry]]

        for buf in graph.all_nodes_between(map1_exit, map2_entry):
            # Check that buffers are AccessNodes.
            if not isinstance(buf, nodes.AccessNode):
                return False

            # Check that buffers are transient.
            if not sdfg.arrays[buf.data].transient:
                return False

            # Check that buffers have exactly 1 input and 1 output edge.
            if graph.in_degree(buf) != 1:
                return False
            if graph.out_degree(buf) != 1:
                return False

            # Check that buffers are next to the maps.
            if graph.in_edges(buf)[0].src != map1_exit:
                return False
            if graph.out_edges(buf)[0].dst != map2_entry:
                return False

            # Check that the data consumed is provided.
            provided = graph.in_edges(buf)[0].data.subset
            consumed = graph.out_edges(buf)[0].data.subset
            if not provided.covers(consumed):
                return False

            # Check that buffers occur only once in this state.
            # Compare data names: comparing ``n.data`` with the node ``buf``
            # itself can never match and would disable this check.
            num_occurrences = sum(
                1 for n in graph.nodes()
                if isinstance(n, nodes.AccessNode) and n.data == buf.data)
            if num_occurrences > 1:
                return False
        return True

    @staticmethod
    def match_to_str(graph, candidate):
        """ Formats the match as ``label: params -> label: params`` for the
            two maps involved.
        """
        map1_exit = graph.nodes()[candidate[BufferTiling._map1_exit]]
        map2_entry = graph.nodes()[candidate[BufferTiling._map2_entry]]

        return " -> ".join(entry.map.label + ": " + str(entry.map.params)
                           for entry in [map1_exit, map2_entry])

    def apply(self, sdfg):
        """ Tiles both maps with ``self.tile_sizes`` (the first one with an
            overlap derived from the difference of the two map ranges), fuses
            the two resulting tile maps, and finally collapses and eliminates
            map dimensions that became trivial for tile sizes of 1.
        """
        graph = sdfg.nodes()[self.state_id]
        map1_exit = graph.nodes()[self.subgraph[self._map1_exit]]
        map1_entry = graph.entry_node(map1_exit)
        map2_entry = graph.nodes()[self.subgraph[self._map2_entry]]
        buffers = graph.all_nodes_between(map1_exit, map2_entry)
        # Situation:
        # -> map1_entry -> ... -> map1_exit -> buffers -> map2_entry -> ...

        # Per dimension: how far the first map's range extends below/above
        # the second map's range.
        lower_extents = tuple(b - a for a, b in zip(
            map1_entry.range.min_element(), map2_entry.range.min_element()))
        upper_extents = tuple(a - b for a, b in zip(
            map1_entry.range.max_element(), map2_entry.range.max_element()))

        # Tile the first map with overlap
        MapTilingWithOverlap.apply_to(sdfg,
                                      map_entry=map1_entry,
                                      options={
                                          'tile_sizes': self.tile_sizes,
                                          'lower_overlap': lower_extents,
                                          'upper_overlap': upper_extents
                                      })
        # The new tile map is the scope that now follows map1_exit
        tile_map1_exit = graph.out_edges(map1_exit)[0].dst
        tile_map1_entry = graph.entry_node(tile_map1_exit)
        tile_map1_entry.label = 'BufferTiling'

        # Tile the second map
        MapTiling.apply_to(sdfg,
                           map_entry=map2_entry,
                           options={
                               'tile_sizes': self.tile_sizes,
                               'tile_trivial': True
                           })
        tile_map2_entry = graph.in_edges(map2_entry)[0].src

        # Fuse maps
        some_buffer = next(
            iter(buffers))  # some dummy to pass to MapFusion.apply_to()
        MapFusion.apply_to(sdfg,
                           first_map_exit=tile_map1_exit,
                           array=some_buffer,
                           second_map_entry=tile_map2_entry)

        # Optimize the simple cases
        # A dimension with tile size 1 (and, for the first map, no overlap)
        # is collapsed to a single iteration.
        map1_entry.range.ranges = [
            (r[0], r[0], r[2]) if l_ext == 0 and u_ext == 0 and ts == 1 else r
            for r, l_ext, u_ext, ts in zip(map1_entry.range.ranges,
                                           lower_extents, upper_extents,
                                           self.tile_sizes)
        ]

        map2_entry.range.ranges = [
            (r[0], r[0], r[2]) if ts == 1 else r
            for r, ts in zip(map2_entry.range.ranges, self.tile_sizes)
        ]

        if any(ts == 1 for ts in self.tile_sizes):
            if any(r[0] == r[1] for r in map1_entry.map.range):
                TrivialMapElimination.apply_to(sdfg, _map_entry=map1_entry)
            if any(r[0] == r[1] for r in map2_entry.map.range):
                TrivialMapElimination.apply_to(sdfg, _map_entry=map2_entry)
Exemplo n.º 18
0
class OnTheFlyMapFusion(Transformation):
    """ Fuses two maps that communicate through an intermediate array.

        The matched pattern is the node path: first map entry -> first
        tasklet -> first map exit -> array access -> second map entry ->
        second tasklet. On application, the contents of the first map are
        copied into the second map once per distinct read offset of the
        intermediate array; every copy writes into a fresh transient scalar
        that feeds the former readers of the array. The array access node and
        the first map are removed from the state afterwards.
    """
    # Placeholder nodes that describe the subgraph pattern to match.
    _first_map_entry = nodes.MapEntry(nodes.Map('', [], []))
    _first_tasklet = nodes.Tasklet('')
    _first_map_exit = nodes.MapExit(nodes.Map('', [], []))
    _array_access = nodes.AccessNode('')
    _second_map_entry = nodes.MapEntry(nodes.Map('', [], []))
    _second_tasklet = nodes.Tasklet('')

    @staticmethod
    def expressions():
        """ Return the list of pattern graphs this transformation matches
            (a single node path through both maps).
        """
        return [
            sdutils.node_path_graph(OnTheFlyMapFusion._first_map_entry,
                                    OnTheFlyMapFusion._first_tasklet,
                                    OnTheFlyMapFusion._first_map_exit,
                                    OnTheFlyMapFusion._array_access,
                                    OnTheFlyMapFusion._second_map_entry,
                                    OnTheFlyMapFusion._second_tasklet)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Check whether a matched candidate can be fused.

            The match is accepted only if the first map exit has exactly one
            input connector and the intermediate array access node has
            exactly one incoming and exactly one outgoing edge.

            :param graph: State graph containing the candidate nodes.
            :param candidate: Mapping from pattern nodes to node ids in
                              `graph`.
            :param expr_index: Index of the matched expression (unused).
            :param sdfg: The SDFG containing `graph` (unused).
            :param strict: Strict-mode flag (unused).
            :return: True if the transformation may be applied.
        """
        first_map_exit = graph.node(
            candidate[OnTheFlyMapFusion._first_map_exit])
        array_access = graph.node(candidate[OnTheFlyMapFusion._array_access])

        # The first map must write through a single connector.
        if len(first_map_exit.in_connectors) != 1:
            return False

        # The intermediate array must have a single writer edge and a single
        # reader edge.
        if (graph.in_degree(array_access) != 1
                or graph.out_degree(array_access) != 1):
            return False
        return True

    @staticmethod
    def _memlet_offsets(base_memlet, offset_memlet):
        """ Compute subset offset of `offset_memlet` relative to `base_memlet`.

            :param base_memlet: Memlet whose subset is the reference.
            :param offset_memlet: Memlet whose subset is compared against the
                                  reference.
            :return: Tuple with one integer offset per pair of ranges taken
                     from the two subsets.
            :raises AssertionError: If, in some dimension, begin and end are
                                    shifted by different amounts or the steps
                                    differ.
        """
        def range_offset(base_range, offset_range):
            b0, e0, s0 = base_range
            b1, e1, s1 = offset_range
            # Both ranges must be pure translations of each other.
            assert e1 - e0 == b1 - b0 and s0 == s1
            return int(e1 - e0)

        return tuple(
            range_offset(b, o) for b, o in zip(base_memlet.subset.ranges,
                                               offset_memlet.subset.ranges))

    @staticmethod
    def _update_map_connectors(state, array_access, first_map_entry,
                               second_map_entry):
        """ Remove unused connector (of the to-be-replaced array) from second
            map entry, add new connectors to second map entry for the inputs
            used in the first map's tasklets.

            The `array_access` node itself is removed from `state` as well.
        """
        # Remove edges and connectors from arrays access to second map entry
        for edge in state.edges_between(array_access, second_map_entry):
            state.remove_edge_and_connectors(edge)
        state.remove_node(array_access)

        # Add new connectors to second map
        # TODO: implement for the general case with random naming
        for edge in state.in_edges(first_map_entry):
            # An input edge is only added when `add_in_connector` reports
            # success (presumably False if the connector already exists --
            # verify against the node API).
            if second_map_entry.add_in_connector(edge.dst_conn):
                state.add_edge(edge.src, edge.src_conn, second_map_entry,
                               edge.dst_conn, edge.data)

    @staticmethod
    def _read_offsets(state, array_name, first_map_exit, second_map_entry):
        """ Compute offsets of read accesses in second map.

            Every outgoing edge of `second_map_entry` that carries
            `array_name` is removed from `state` (together with its out
            connector) and grouped by its offset relative to the single write
            memlet of the first map.

            :return: Dictionary mapping an offset tuple to the list of
                     removed edges that read at this offset.
        """
        # Get output memlet of first tasklet
        output_edges = state.in_edges(first_map_exit)
        assert len(output_edges) == 1
        write_memlet = output_edges[0].data

        # Find read offsets by looping over second map entry connectors
        offsets = defaultdict(list)
        for edge in state.out_edges(second_map_entry):
            if edge.data.data == array_name:
                second_map_entry.remove_out_connector(edge.src_conn)
                state.remove_edge(edge)
                offset = OnTheFlyMapFusion._memlet_offsets(
                    write_memlet, edge.data)
                offsets[offset].append(edge)

        return offsets

    @staticmethod
    def _copy_first_map_contents(state, first_map_entry, first_map_exit):
        """ Duplicate the contents of the first map inside `state`.

            All nodes reported by `state.all_nodes_between(first_map_entry,
            first_map_exit)` except the map entry are deep-copied and added
            to the state. Every edge touching one of the original nodes is
            re-created with the copied node(s) as endpoint(s) and a deep copy
            of its memlet; endpoints outside the copied set are kept as-is.

            :return: List of the newly added nodes.
        """
        old_nodes = list(
            state.all_nodes_between(first_map_entry, first_map_exit) -
            {first_map_entry})
        new_nodes = [copy.deepcopy(node) for node in old_nodes]
        for node in new_nodes:
            state.add_node(node)

        # Node-id translation from each original node to its copy.
        id_map = {
            state.node_id(old): state.node_id(new)
            for old, new in zip(old_nodes, new_nodes)
        }
        # Set for constant-time membership tests in the edge loop below.
        old_node_set = set(old_nodes)

        def copy_of(node):
            return state.node(id_map[state.node_id(node)])

        for edge in state.edges():
            src_is_old = edge.src in old_node_set
            dst_is_old = edge.dst in old_node_set
            if src_is_old or dst_is_old:
                src = copy_of(edge.src) if src_is_old else edge.src
                dst = copy_of(edge.dst) if dst_is_old else edge.dst
                state.add_edge(src, edge.src_conn, dst, edge.dst_conn,
                               copy.deepcopy(edge.data))

        return new_nodes

    def _replicate_first_map(self, sdfg, array_access, first_map_entry,
                             first_map_exit, second_map_entry):
        """ Replicate tasklet of first map for each read access in second map.

            For every distinct read offset a copy of the first map contents
            is created, its output is redirected into a new transient scalar
            and its inputs are re-attached to `second_map_entry` with the
            memlet subsets shifted by the read offset.
        """
        state = sdfg.node(self.state_id)
        array_name = array_access.data
        array = sdfg.arrays[array_name]

        read_offsets = self._read_offsets(state, array_name, first_map_exit,
                                          second_map_entry)

        # Replicate first map tasklets once for each read offset access and
        # connect them to other tasklets accordingly
        for offset, edges in read_offsets.items():
            copied_nodes = self._copy_first_map_contents(
                state, first_map_entry, first_map_exit)
            tmp_name = sdfg.temp_data_name()
            sdfg.add_scalar(tmp_name, array.dtype, transient=True)
            tmp_access = state.add_access(tmp_name)

            for node in copied_nodes:
                # Redirect writes from the first map exit to the temporary.
                for edge in state.edges_between(node, first_map_exit):
                    state.add_edge(edge.src, edge.src_conn, tmp_access, None,
                                   dace.Memlet(tmp_name))
                    state.remove_edge(edge)

                # Feed inputs from the second map entry, shifted by the
                # offset at which the second map reads the array.
                for edge in state.edges_between(first_map_entry, node):
                    memlet = copy.deepcopy(edge.data)
                    memlet.subset.offset(list(offset), negative=False)
                    second_map_entry.add_out_connector(edge.src_conn)
                    state.add_edge(second_map_entry, edge.src_conn, node,
                                   edge.dst_conn, memlet)
                    state.remove_edge(edge)

            # Former readers of the array now read the temporary scalar.
            for edge in edges:
                state.add_edge(tmp_access, None, edge.dst, edge.dst_conn,
                               dace.Memlet(tmp_name))

    def apply(self, sdfg: dace.SDFG):
        """ Apply the fusion to the matched subgraph of `sdfg`.

            Rewires the inputs of the first map to the second map, replicates
            the first map contents for every read offset and finally removes
            the first map (entry, contents and exit) from the state.
        """
        state = sdfg.node(self.state_id)
        first_map_entry = state.node(self.subgraph[self._first_map_entry])
        first_map_exit = state.node(self.subgraph[self._first_map_exit])
        array_access = state.node(self.subgraph[self._array_access])
        second_map_entry = state.node(self.subgraph[self._second_map_entry])

        self._update_map_connectors(state, array_access, first_map_entry,
                                    second_map_entry)

        self._replicate_first_map(sdfg, array_access, first_map_entry,
                                  first_map_exit, second_map_entry)

        state.remove_nodes_from(
            state.all_nodes_between(first_map_entry, first_map_exit)
            | {first_map_exit})