Example #1
0
class AccumulateTransient(pattern_matching.Transformation):
    """ Implements the AccumulateTransient transformation, which adds a
        transient data node between nested maps whose inner tasklet writes
        its output with write-conflict resolution (WCR). The transient data
        node then acts as a local accumulator.

        Matched pattern: tasklet -> inner map exit -> outer map exit.
    """

    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))
    _outer_map_exit = nodes.MapExit(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        return [
            nxutil.node_path_graph(AccumulateTransient._tasklet,
                                   AccumulateTransient._map_exit,
                                   AccumulateTransient._outer_map_exit)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if the tasklet writes into the inner map exit through
            at least one memlet that carries a WCR function. """
        tasklet = graph.nodes()[candidate[AccumulateTransient._tasklet]]
        map_exit = graph.nodes()[candidate[AccumulateTransient._map_exit]]

        return any(memlet.wcr is not None and dest == map_exit
                   for _, _, dest, _, memlet in graph.out_edges(tasklet))

    @staticmethod
    def match_to_str(graph, candidate):
        tasklet = candidate[AccumulateTransient._tasklet]
        map_exit = candidate[AccumulateTransient._map_exit]
        outer_map_exit = candidate[AccumulateTransient._outer_map_exit]

        return ' -> '.join(
            str(node) for node in [tasklet, map_exit, outer_map_exit])

    def apply(self, sdfg):
        """ Inserts a transient array ``trans_<data>`` between the inner and
            the outer map exit and redirects the tasklet's WCR write into it.

            :param sdfg: The SDFG containing the matched state.
            :raises ValueError: If the WCR edge from the tasklet, or the
                                matching outgoing edge of the inner map exit,
                                cannot be found.
        """
        graph = sdfg.nodes()[self.state_id]
        tasklet = graph.nodes()[self.subgraph[AccumulateTransient._tasklet]]
        map_exit = graph.nodes()[self.subgraph[AccumulateTransient._map_exit]]
        outer_map_exit = graph.nodes()[self.subgraph[
            AccumulateTransient._outer_map_exit]]

        # WCR memlet from the tasklet into the inner map exit.
        # TODO: What if there's more than one?
        memlet = next((e.data for e in graph.out_edges(tasklet)
                       if e.dst == map_exit and e.data.wcr is not None), None)
        if memlet is None:
            raise ValueError('AccumulateTransient: no WCR edge from tasklet '
                             'to inner map exit')

        # Edge leaving the inner map exit that carries the same data
        edge = next((e for e in graph.out_edges(map_exit)
                     if e.data.data == memlet.data), None)
        if edge is None:
            raise ValueError('AccumulateTransient: no edge from inner map '
                             'exit carries "%s"' % memlet.data)
        out_memlet = edge.data

        dataname = memlet.data
        transient_name = 'trans_' + dataname

        # Create a transient with the same shape and dtype as the output
        sdfg.add_array(transient_name,
                       sdfg.arrays[dataname].shape,
                       sdfg.arrays[dataname].dtype,
                       transient=True)
        dnode = nodes.AccessNode(transient_name)

        # Memlet from the inner map exit into the transient
        to_data_mm = copy.deepcopy(memlet)
        to_data_mm.data = dnode.data
        to_data_mm.num_accesses = memlet.num_elements()

        # Memlet from the transient into the outer map exit
        to_exit_mm = copy.deepcopy(out_memlet)
        to_exit_mm.num_accesses = out_memlet.num_elements()

        # The tasklet now accumulates into the transient
        memlet.data = dnode.data

        # Reconnect through the transient, keeping the map-exit connectors
        graph.remove_edge(edge)
        graph.add_edge(map_exit, edge.src_conn, dnode, None, to_data_mm)
        graph.add_edge(dnode, None, outer_map_exit, edge.dst_conn, to_exit_mm)

    def modifies_graph(self):
        return True
Example #2
0
class StreamTransient(pattern_matching.Transformation):
    """ Implements the StreamTransient transformation, which adds a transient
        stream node between nested maps that lead to a stream. The transient
        then acts as a local buffer.

        Matched pattern: tasklet -> inner map exit -> outer map exit, where
        the inner map exit writes a stream into the outer map exit.
    """

    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))
    _outer_map_exit = nodes.MapExit(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        return [
            nxutil.node_path_graph(StreamTransient._tasklet,
                                   StreamTransient._map_exit,
                                   StreamTransient._outer_map_exit)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if the inner map exit writes a stream array into the
            outer map exit. """
        map_exit = graph.nodes()[candidate[StreamTransient._map_exit]]
        outer_map_exit = graph.nodes()[candidate[
            StreamTransient._outer_map_exit]]

        return any(
            dest == outer_map_exit
            and isinstance(sdfg.arrays[memlet.data], data.Stream)
            for _, _, dest, _, memlet in graph.out_edges(map_exit))

    @staticmethod
    def match_to_str(graph, candidate):
        tasklet = candidate[StreamTransient._tasklet]
        map_exit = candidate[StreamTransient._map_exit]
        outer_map_exit = candidate[StreamTransient._outer_map_exit]

        return ' -> '.join(
            str(node) for node in [tasklet, map_exit, outer_map_exit])

    def apply(self, sdfg):
        """ Inserts a transient stream ``tile_<data>`` between the inner and
            the outer map exit and redirects the tasklet's write into it.

            :param sdfg: The SDFG containing the matched state.
            :raises ValueError: If the stream edge between the map exits, or
                                the tasklet's write to that stream, cannot be
                                found.
        """
        graph = sdfg.nodes()[self.state_id]
        tasklet = graph.nodes()[self.subgraph[StreamTransient._tasklet]]
        map_exit = graph.nodes()[self.subgraph[StreamTransient._map_exit]]
        outer_map_exit = graph.nodes()[self.subgraph[
            StreamTransient._outer_map_exit]]

        # Stream edge from the inner to the outer map exit.
        # TODO: What if there's more than one?
        edge = next((e for e in graph.out_edges(map_exit)
                     if e.dst == outer_map_exit and isinstance(
                         sdfg.arrays[e.data.data], data.Stream)), None)
        if edge is None:
            raise ValueError('StreamTransient: no stream edge between inner '
                             'and outer map exit')
        memlet = edge.data

        # Tasklet output that writes into the same stream
        tasklet_memlet = next((e.data for e in graph.out_edges(tasklet)
                               if e.data.data == memlet.data), None)
        if tasklet_memlet is None:
            raise ValueError('StreamTransient: tasklet does not write to '
                             'stream "%s"' % memlet.data)

        bbox = map_exit.map.range.bounding_box_size()
        bbox_approx = [symbolic.overapproximate(dim) for dim in bbox]
        dataname = memlet.data
        stream_name = 'tile_' + dataname

        # Create the new node: temporary stream and an access node.
        # NOTE(review): only the first dimension of the over-approximated
        # inner-map bounding box is passed to add_stream (presumably as the
        # buffer size) — confirm for multi-dimensional inner maps.
        sdfg.add_stream(
            stream_name,
            sdfg.arrays[dataname].dtype,
            1,
            bbox_approx[0],
            [1],
            transient=True,
        )
        snode = nodes.AccessNode(stream_name)

        to_stream_mm = copy.deepcopy(memlet)
        to_stream_mm.data = snode.data
        tasklet_memlet.data = snode.data

        # Reconnect through the transient stream, keeping the connectors of
        # the original edge so the map exits stay consistently connected
        graph.remove_edge(edge)
        graph.add_edge(map_exit, edge.src_conn, snode, None, to_stream_mm)
        graph.add_edge(snode, None, outer_map_exit, edge.dst_conn, memlet)

    def modifies_graph(self):
        return True
Example #3
0
class Vectorization(pattern_matching.Transformation):
    """ Implements the vectorization transformation.

        Vectorization matches when all the input and output memlets of a
        tasklet inside a map access the inner-most loop variable in their last
        dimension. The transformation changes the step of the inner-most loop
        to be equal to the length of the vector and vectorizes the memlets.
    """

    vector_len = Property(desc="Vector length", dtype=int, default=4)

    _map_entry = nodes.MapEntry(nodes.Map("", [], []))
    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        return [
            nxutil.node_path_graph(Vectorization._map_entry,
                                   Vectorization._tasklet,
                                   Vectorization._map_exit)
        ]

    @staticmethod
    def _dim_components(dim):
        """ Returns the index expressions of one subset dimension as a list:
            the elements of a range tuple, or the single index itself. """
        return list(dim) if isinstance(dim, tuple) else [dim]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if at least one relevant memlet adjacent to the
            tasklet uses the inner-most map parameter in its last dimension,
            and no relevant memlet uses it in any other dimension. Stream and
            WCR memlets are ignored; already-vectorized memlets (veclen > 1)
            prevent matching. """
        map_entry = graph.nodes()[candidate[Vectorization._map_entry]]
        tasklet = graph.nodes()[candidate[Vectorization._tasklet]]
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])
        found = False

        for _, _, _, _, memlet in graph.all_edges(tasklet):

            # Cases that do not matter for vectorization
            if isinstance(sdfg.arrays[memlet.data], data.Stream):
                continue
            if memlet.wcr is not None:
                continue

            try:
                subset = memlet.subset
                veclen = memlet.veclen
            except AttributeError:
                return False

            if subset is None:
                return False

            try:
                if veclen > symbolic.pystr_to_symbolic('1'):
                    return False

                last_dim = subset.dims() - 1
                for idx, dim in enumerate(subset):
                    for component in Vectorization._dim_components(dim):
                        symbols = symbolic.pystr_to_symbolic(
                            component).free_symbols
                        if param not in symbols:
                            continue
                        if idx != last_dim:
                            # Parameter used outside the last dimension
                            return False
                        found = True
            except TypeError:  # cannot determine truth value of Relational
                return False

        return found

    @staticmethod
    def match_to_str(graph, candidate):

        map_entry = candidate[Vectorization._map_entry]
        tasklet = candidate[Vectorization._tasklet]
        map_exit = candidate[Vectorization._map_exit]

        return ' -> '.join(
            str(node) for node in [map_entry, tasklet, map_exit])

    def apply(self, sdfg):
        """ Sets the inner-most map step to ``vector_len`` and sets the
            vector length of every tasklet-adjacent memlet whose last index
            uses the inner-most map parameter. """
        graph = sdfg.nodes()[self.state_id]
        map_entry = graph.nodes()[self.subgraph[Vectorization._map_entry]]
        tasklet = graph.nodes()[self.subgraph[Vectorization._tasklet]]
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])

        vector_size = self.vector_len

        # Change the step of the inner-most dimension.
        dim_from, dim_to, _ = map_entry.map.range[-1]
        map_entry.map.range[-1] = (dim_from, dim_to, vector_size)

        # Vectorize memlets adjacent to the tasklet.
        for _, _, _, _, memlet in graph.all_edges(tasklet):
            symbols = set()
            for component in Vectorization._dim_components(memlet.subset[-1]):
                symbols.update(
                    symbolic.pystr_to_symbolic(component).free_symbols)
            if param in symbols:
                try:
                    memlet.veclen = vector_size
                except AttributeError:
                    # NOTE(review): stops early, leaving the map step already
                    # changed — kept as in the original behavior.
                    return

        # TODO: Create new map for non-vectorizable part.

    def modifies_graph(self):
        return True
Example #4
0
 def __init__(self, *args, **kwargs):
     """ Creates fresh entry, tasklet ('_') and exit node instances, resets
         ``pairs`` to None, then forwards all arguments to the base class
         constructor.
     """
     self.entry = nodes.EntryNode()
     self.tasklet = nodes.Tasklet('_')
     self.exit = nodes.ExitNode()
     self.pairs = None
     # NOTE(review): attributes are assigned before super().__init__ runs —
     # presumably so the base initializer can already see them; confirm
     # before reordering.
     super().__init__(*args, **kwargs)
Example #5
0
class MapWCRFusion(pm.Transformation):
    """ Implements the map expanded-reduce fusion transformation.
        Fuses a map with an immediately following reduction, where the array
        between the map and the reduction is not used anywhere else, and the
        reduction is divided into two maps with a WCR, denoting partial
        reduction.
    """

    _tasklet = nodes.Tasklet('_')
    _tmap_exit = nodes.MapExit(nodes.Map("", [], []))
    _in_array = nodes.AccessNode('_')
    _rmap_in_entry = nodes.MapEntry(nodes.Map("", [], []))
    _rmap_in_tasklet = nodes.Tasklet('_')
    _rmap_in_cr = nodes.MapExit(nodes.Map("", [], []))
    _rmap_out_entry = nodes.MapEntry(nodes.Map("", [], []))
    _rmap_out_exit = nodes.MapExit(nodes.Map("", [], []))
    _out_array = nodes.AccessNode('_')

    @staticmethod
    def expressions():
        return [
            # Map, then partial reduction of axes
            nxutil.node_path_graph(
                MapWCRFusion._tasklet, MapWCRFusion._tmap_exit,
                MapWCRFusion._in_array, MapWCRFusion._rmap_out_entry,
                MapWCRFusion._rmap_in_entry, MapWCRFusion._rmap_in_tasklet,
                MapWCRFusion._rmap_in_cr, MapWCRFusion._rmap_out_exit,
                MapWCRFusion._out_array)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if the intermediate array is only written by the
            first map and only read by the reduction map, the inner reduction
            map exit receives a WCR memlet, (strict) the array is accessed
            nowhere else, and the write and read subsets of the array match.
        """
        tmap_exit = graph.nodes()[candidate[MapWCRFusion._tmap_exit]]
        in_array = graph.nodes()[candidate[MapWCRFusion._in_array]]
        rmap_entry = graph.nodes()[candidate[MapWCRFusion._rmap_out_entry]]

        # Make sure that the array is only accessed by the map and the reduce
        if any(src != tmap_exit
               for src, _, _, _, _ in graph.in_edges(in_array)):
            return False
        if any(dest != rmap_entry
               for _, _, dest, _, _ in graph.out_edges(in_array)):
            return False

        # Make sure that there is a reduction in the second map
        # (only the first incoming edge of the inner map exit is inspected)
        rmap_cr = graph.nodes()[candidate[MapWCRFusion._rmap_in_cr]]
        reduce_edge = graph.in_edges(rmap_cr)[0]
        if reduce_edge.data.wcr is None:
            return False

        # (strict) Make sure that the transient is not accessed anywhere else
        # in this state or other states
        if strict:
            num_access_nodes = sum(
                1 for n in graph.nodes()
                if isinstance(n, nodes.AccessNode) and n.data == in_array.data)
            if (num_access_nodes > 1
                    or in_array.data in sdfg.shared_transients()):
                return False

        # Verify that reduction ranges match tasklet map
        tout_memlet = graph.in_edges(in_array)[0].data
        rin_memlet = graph.out_edges(in_array)[0].data
        return tout_memlet.subset == rin_memlet.subset

    @staticmethod
    def match_to_str(graph, candidate):
        tasklet = candidate[MapWCRFusion._tasklet]
        map_exit = candidate[MapWCRFusion._tmap_exit]
        reduce = candidate[MapWCRFusion._rmap_in_cr]

        return ' -> '.join(str(node) for node in [tasklet, map_exit, reduce])

    def apply(self, sdfg):
        """ Collapses the two reduction maps into one (MapCollapse), then
            fuses the first map with the collapsed map (MapFusion). """
        graph = sdfg.node(self.state_id)

        # To apply, collapse the second map and then fuse the two resulting maps
        map_collapse = MapCollapse(
            self.sdfg_id, self.state_id, {
                MapCollapse._outer_map_entry:
                self.subgraph[MapWCRFusion._rmap_out_entry],
                MapCollapse._inner_map_entry:
                self.subgraph[MapWCRFusion._rmap_in_entry]
            }, 0)
        map_entry, _ = map_collapse.apply(sdfg)

        map_fusion = MapFusion(
            self.sdfg_id, self.state_id, {
                MapFusion._first_map_exit:
                self.subgraph[MapWCRFusion._tmap_exit],
                MapFusion._second_map_entry: graph.node_id(map_entry)
            }, 0)
        map_fusion.apply(sdfg)
Example #6
0
 def __init__(self, *args, **kwargs):
     """ Creates fresh entry, tasklet ('_') and exit node instances on the
         private attributes, then forwards all arguments to the base class
         constructor.
     """
     self._entry = nodes.EntryNode()
     self._tasklet = nodes.Tasklet('_')
     self._exit = nodes.ExitNode()
     # NOTE(review): attributes are assigned before super().__init__ runs —
     # presumably so the base initializer can already see them; confirm
     # before reordering.
     super().__init__(*args, **kwargs)
Example #7
0
def add_indirection_subgraph(sdfg, graph, src, dst, memlet):
    """ Adds to ``graph`` a subgraph between ``src`` and ``dst`` that
        implements indirection without nested AST memlet objects.

        Every sympy user-function call found in ``memlet.subset`` (an access
        into another array used as an index) is replaced by a symbol
        ``index_<array>_<n>``. An "Indirection" tasklet receives each such
        index on an input connector of the same name and the indexed array on
        ``__ind_<local_name>``. It writes the looked-up value(s) through
        ``lookup`` into a new transient ``__<local_name>_value``, which is
        then connected to ``dst`` on connector ``<local_name>``.

        :param sdfg: SDFG in which the transient is created.
        :param graph: State to which the new nodes and edges are added.
        :param src: Node from which the indexed and index arrays are read.
        :param dst: Node that receives the looked-up value.
        :param memlet: ``astnodes._Memlet`` describing the indirect access.
        :raises TypeError: If ``memlet`` is not an ``astnodes._Memlet``.
    """
    if not isinstance(memlet, astnodes._Memlet):
        raise TypeError("Expected memlet to be astnodes._Memlet")

    # Scheme for multi-array indirection:
    # 1. look for all arrays and accesses, create set of arrays+indices
    #    from which the index memlets will be constructed from
    # 2. each separate array creates a memlet, of which num_accesses = len(set)
    # 3. one indirection tasklet receives them all + original array and
    #    produces the right output index/range memlet
    #########################
    # Step 1
    accesses = OrderedDict()
    newsubset = dcpy(memlet.subset)
    for dimidx, dim in enumerate(memlet.subset):
        # Range/Index disambiguation
        is_range = isinstance(dim, tuple)
        components = list(dim) if is_range else [dim]
        changed = False

        for i, r in enumerate(components):
            replaced = r
            for expr in sympy.preorder_traversal(r):
                if not symbolic.is_sympy_userfunction(expr):
                    continue
                fname = expr.func.__name__
                fn_accesses = accesses.setdefault(fname, [])

                # Replace function with symbol (memlet local name to-be);
                # identical accesses share one symbol
                if expr.args in fn_accesses:
                    aindex = fn_accesses.index(expr.args)
                else:
                    fn_accesses.append(expr.args)
                    aindex = len(fn_accesses) - 1
                toreplace = 'index_' + fname + '_' + str(aindex)

                # Substitute into the running result (not the original `r`)
                # so that every indirect access in the expression is replaced
                replaced = replaced.subs(expr, toreplace)
                changed = True
            components[i] = replaced

        if changed:
            # Range elements are tuples and cannot be modified in place, so
            # the whole dimension is replaced
            newsubset[dimidx] = (tuple(components)
                                 if is_range else components[0])
    #########################
    # Step 2
    ind_inputs = {'__ind_' + memlet.local_name}
    ind_outputs = {'lookup'}
    # Add accesses to inputs
    for arrname, arr_accesses in accesses.items():
        for i in range(len(arr_accesses)):
            ind_inputs.add('index_%s_%d' % (arrname, i))

    tasklet = nd.Tasklet("Indirection", ind_inputs, ind_outputs)

    for arrname, arr_accesses in accesses.items():
        # NOTE(review): the value is unused; the lookup is kept so that an
        # index array missing from ``memlet.otherdeps`` still fails here.
        _ = memlet.otherdeps[arrname]
        for i, access in enumerate(arr_accesses):
            # Memlet to load the indirection index
            index_memlet = Memlet(arrname, 1, sbs.Indices(list(access)), 1)
            graph.add_edge(src, None, tasklet, "index_%s_%d" % (arrname, i),
                           index_memlet)

    #########################
    # Step 3
    value_name = '__' + memlet.local_name + '_value'

    # Create new tasklet that will perform the indirection
    indirection_ast = ast.parse("lookup = {arr}[{index}]".format(
        arr='__ind_' + memlet.local_name,
        index=', '.join([symbolic.symstr(s) for s in newsubset])))
    # Conserve line number of original indirection code
    tasklet.code = ast.copy_location(indirection_ast.body[0], memlet.ast)

    # Create transient variable to trigger the indirected load
    if memlet.num_accesses == 1:
        storage = sdfg.add_scalar(value_name,
                                  memlet.data.dtype,
                                  transient=True)
    else:
        storage = sdfg.add_array(value_name,
                                 memlet.data.dtype,
                                 storage=types.StorageType.Default,
                                 transient=True,
                                 shape=memlet.bounding_box_size())
    indirect_range = sbs.Range([(0, s - 1, 1) for s in storage.shape])
    data_node = nd.AccessNode(value_name)

    # Create memlet that depends on the full array that we look up in
    full_range = sbs.Range([(0, s - 1, 1) for s in memlet.data.shape])
    full_memlet = Memlet(memlet.dataname, memlet.num_accesses, full_range,
                         memlet.veclen)
    graph.add_edge(src, None, tasklet, '__ind_' + memlet.local_name,
                   full_memlet)

    # Memlet to store the final value into the transient, and to load it into
    # the tasklet that needs it
    indirect_memlet = Memlet(value_name, memlet.num_accesses, indirect_range,
                             memlet.veclen)
    graph.add_edge(tasklet, 'lookup', data_node, None, indirect_memlet)

    value_memlet = Memlet(value_name, memlet.num_accesses, indirect_range,
                          memlet.veclen)
    graph.add_edge(data_node, None, dst, memlet.local_name, value_memlet)
Example #8
0
class MapReduceFusion(pm.Transformation):
    """ Implements the map-reduce-fusion transformation.
        Fuses a map with an immediately following reduction, where the array
        between the map and the reduction is not used anywhere else.

        The intermediate array and the reduce node are removed; the tasklet's
        memlet is redirected to the output array and given the reduction's
        WCR function and identity.
    """

    _tasklet = nodes.Tasklet('_')
    _tmap_exit = nodes.MapExit(nodes.Map("", [], []))
    _in_array = nodes.AccessNode('_')
    _reduce = nodes.Reduce('lambda: None', None)
    _out_array = nodes.AccessNode('_')

    @staticmethod
    def expressions():
        return [
            nxutil.node_path_graph(MapReduceFusion._tasklet,
                                   MapReduceFusion._tmap_exit,
                                   MapReduceFusion._in_array,
                                   MapReduceFusion._reduce,
                                   MapReduceFusion._out_array)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if the intermediate array is only written by the map
            and only read by the reduce node, the tasklet writes it directly,
            (strict) it is accessed nowhere else, any existing WCR on the
            tasklet's memlet equals the reduction's, and the write and read
            subsets of the array match. """
        tmap_exit = graph.nodes()[candidate[MapReduceFusion._tmap_exit]]
        in_array = graph.nodes()[candidate[MapReduceFusion._in_array]]
        reduce_node = graph.nodes()[candidate[MapReduceFusion._reduce]]
        tasklet = graph.nodes()[candidate[MapReduceFusion._tasklet]]

        # Make sure that the array is only accessed by the map and the reduce
        if any(src != tmap_exit
               for src, _, _, _, _ in graph.in_edges(in_array)):
            return False
        if any(dest != reduce_node
               for _, _, dest, _, _ in graph.out_edges(in_array)):
            return False

        # Memlet through which the tasklet writes the intermediate array
        tmem = next((e.data for e in graph.edges_between(tasklet, tmap_exit)
                     if e.data.data == in_array.data), None)
        if tmem is None:
            return False

        # (strict) Make sure that the transient is not accessed anywhere else
        # in this state or other states
        if strict:
            num_access_nodes = sum(
                1 for n in graph.nodes()
                if isinstance(n, nodes.AccessNode) and n.data == in_array.data)
            if (num_access_nodes > 1
                    or in_array.data in sdfg.shared_transients()):
                return False

        # If memlet already has WCR and it is different from reduce node,
        # do not match
        if tmem.wcr is not None and tmem.wcr != reduce_node.wcr:
            return False

        # Verify that reduction ranges match tasklet map
        tout_memlet = graph.in_edges(in_array)[0].data
        rin_memlet = graph.out_edges(in_array)[0].data
        return tout_memlet.subset == rin_memlet.subset

    @staticmethod
    def match_to_str(graph, candidate):
        tasklet = candidate[MapReduceFusion._tasklet]
        map_exit = candidate[MapReduceFusion._tmap_exit]
        reduce = candidate[MapReduceFusion._reduce]

        return ' -> '.join(str(node) for node in [tasklet, map_exit, reduce])

    def apply(self, sdfg):
        """ Removes the intermediate array and reduce node, and turns the
            tasklet's write into a WCR write to the output array.

            :raises RuntimeError: If no memlet into the map exit carries the
                                  intermediate array.
        """
        graph = sdfg.nodes()[self.state_id]
        tmap_exit = graph.nodes()[self.subgraph[MapReduceFusion._tmap_exit]]
        in_array = graph.nodes()[self.subgraph[MapReduceFusion._in_array]]
        reduce_node = graph.nodes()[self.subgraph[MapReduceFusion._reduce]]
        out_array = graph.nodes()[self.subgraph[MapReduceFusion._out_array]]

        # The intermediate array and the reduce node are replaced by a WCR
        # memlet on the map
        nodes_to_remove = [in_array, reduce_node]

        memlet_edge = next((edge for edge in graph.in_edges(tmap_exit)
                            if edge.data.data == in_array.data), None)
        if memlet_edge is None:
            raise RuntimeError('Reduction memlet cannot be None')

        # Find which indices should be removed from new memlet; without
        # explicit axes, every dimension of the reduction input is reduced
        input_edge = graph.in_edges(reduce_node)[0]
        axes = reduce_node.axes or list(range(input_edge.data.subset.dims()))
        array_edge = graph.out_edges(reduce_node)[0]

        # Delete relevant edges and nodes
        graph.remove_nodes_from(nodes_to_remove)

        # Filter out reduced dimensions from subset
        filtered_subset = [
            dim for i, dim in enumerate(memlet_edge.data.subset)
            if i not in axes
        ]
        if not filtered_subset:  # Output is a scalar
            filtered_subset = [0]

        # Modify edge from tasklet to map exit
        memlet_edge.data.data = out_array.data
        memlet_edge.data.wcr = reduce_node.wcr
        memlet_edge.data.wcr_identity = reduce_node.identity
        memlet_edge.data.subset = type(
            memlet_edge.data.subset)(filtered_subset)

        # Add edge from map exit to output array. The output connector name
        # is derived by dropping the first 3 characters of the input
        # connector (presumably an "IN_" prefix) and prepending "OUT_".
        graph.add_edge(
            memlet_edge.dst, 'OUT_' + memlet_edge.dst_conn[3:], array_edge.dst,
            array_edge.dst_conn,
            Memlet(array_edge.data.data, array_edge.data.num_accesses,
                   array_edge.data.subset, array_edge.data.veclen,
                   reduce_node.wcr, reduce_node.identity))
Example #9
0
def _reconcile_duplicate_memlet_edge(state, src, dst, memlet, conn, conn_attr,
                                     direction):
    """ Reconciles ``memlet`` with an existing edge between ``src`` and
        ``dst`` that carries the same data on the same connector.

        :param state: The state containing the edges.
        :param src: Source node of the candidate edge.
        :param dst: Destination node of the candidate edge.
        :param memlet: The AST memlet about to be added.
        :param conn: The connector the new edge would use.
        :param conn_attr: Name of the edge attribute ``conn`` is compared
                          against (``'dst_conn'`` or ``'src_conn'``).
        :param direction: ``'in'`` or ``'out'``; only used in the warning.
        :return: A pair ``(skip, edge_to_remove)``. ``skip`` is True when an
                 existing edge already covers (or was widened to cover)
                 ``memlet``; ``edge_to_remove`` is an existing edge that
                 ``memlet`` fully covers and should replace, or None.
    """
    for e in state.edges_between(src, dst):
        if e.data.data != memlet.dataname or conn != getattr(e, conn_attr):
            continue
        if e.data.subset.covers(memlet.subset):
            return True, None
        if memlet.subset.covers(e.data.subset):
            return False, e
        # Neither subset covers the other: widen the existing edge instead
        print('WARNING: Performing bounding-box union on', memlet.subset,
              'and', e.data.subset, '(%s)' % direction)
        e.data.subset = sbs.bounding_box_union(e.data.subset, memlet.subset)
        e.data.num_accesses += memlet.num_accesses
        return True, None
    return False, None


def _build_dataflow_graph_recurse(sdfg, state, primitives, modules, superEntry,
                                  super_exit):
    """ Recursively lowers AST primitives into dataflow nodes and memlet
        edges inside ``state``.

        :param sdfg: The SDFG that owns ``state``.
        :param state: The state to add nodes and edges to.
        :param primitives: AST primitive nodes of the current nesting level.
                           An empty list is replaced by a single empty
                           tasklet.
        :param modules: Passed to ``ModuleInliner`` when cleaning Python
                        tasklet code.
        :param superEntry: Entry node of the enclosing scope, or a falsy
                           value at the top level.
        :param super_exit: Exit node of the enclosing scope, or None at the
                           top level.
        :return: A list of ``(exit node, memlet)`` pairs whose routing must
                 be continued by the enclosing recursion level. Returns None
                 if a ``_ProgramNode`` is encountered.
        :raises ValueError: If a map/consume node has no children.
        :raises SyntaxError: If a non-Python tasklet has no external code.
        :raises TypeError: On unsupported primitive types.
        :raises RuntimeError: On unexpected structure or an unhandled WCR.
    """
    # Array of pairs (exit node, memlet)
    exit_nodes = []

    if len(primitives) == 0:
        # Inject empty tasklets into empty states
        primitives = [astnodes._EmptyTaskletNode("Empty Tasklet", None)]

    for prim in primitives:
        label = prim.name

        # Expand node to get entry and exit points
        if isinstance(prim, astnodes._MapNode):
            if len(prim.children) == 0:
                raise ValueError("Map node expected to have children")
            map_node = nd.Map(label,
                              prim.params,
                              prim.range,
                              is_async=prim.is_async)
            # Add connectors for inputs that exist as array nodes
            entry = nd.MapEntry(
                map_node,
                _get_input_symbols(prim.inputs, prim.range.free_symbols))
            exit_node = nd.MapExit(map_node)
        elif isinstance(prim, astnodes._ConsumeNode):
            if len(prim.children) == 0:
                raise ValueError("Consume node expected to have children")
            consume_node = nd.Consume(label, (prim.params[1], prim.num_pes),
                                      prim.condition)
            entry = nd.ConsumeEntry(consume_node)
            exit_node = nd.ConsumeExit(consume_node)
        elif isinstance(prim, astnodes._ReduceNode):
            rednode = nd.Reduce(prim.ast, prim.axes, prim.identity)
            state.add_node(rednode)
            entry = rednode
            exit_node = rednode
        elif isinstance(prim, astnodes._TaskletNode):
            if isinstance(prim, astnodes._EmptyTaskletNode):
                tasklet = nd.EmptyTasklet(prim.name)
            else:
                # Remove memlets from tasklet AST
                if prim.language == types.Language.Python:
                    clean_code = MemletRemover().visit(prim.ast)
                    clean_code = ModuleInliner(modules).visit(clean_code)
                else:  # Use external code from tasklet definition
                    if prim.extcode is None:
                        raise SyntaxError("Cannot define an intrinsic "
                                          "tasklet without an implementation")
                    clean_code = prim.extcode
                tasklet = nd.Tasklet(
                    prim.name,
                    set(prim.inputs.keys()),
                    set(prim.outputs.keys()),
                    code=clean_code,
                    language=prim.language,
                    code_global=prim.gcode)  # TODO: location=prim.location

            # Need to add the tasklet in case we're in an empty state, where no
            # edge will be drawn to it
            state.add_node(tasklet)
            entry = tasklet
            exit_node = tasklet

        elif isinstance(prim, astnodes._NestedSDFGNode):
            prim.sdfg.parent = state
            prim.sdfg._parent_sdfg = sdfg
            prim.sdfg.update_sdfg_list([])
            nsdfg = nd.NestedSDFG(prim.name, prim.sdfg,
                                  set(prim.inputs.keys()),
                                  set(prim.outputs.keys()))
            state.add_node(nsdfg)
            entry = nsdfg
            exit_node = nsdfg

        elif isinstance(prim, astnodes._ProgramNode):
            # NOTE(review): this returns None rather than ``exit_nodes``; the
            # recursive call below iterates the result, so reaching this from
            # a nested level would fail. Confirm it is top-level only.
            return
        elif isinstance(prim, astnodes._ControlFlowNode):
            continue
        else:
            raise TypeError("Node type not implemented: " +
                            str(prim.__class__))

        # Add incoming edges
        for varname, memlet in prim.inputs.items():
            arr = memlet.dataname
            if (prim.parent is not None
                    and memlet.dataname in prim.parent.transients):
                node = input_node_for_array(state, memlet.dataname)

                # Add incoming edge into transient as well
                # FIXME: A bit hacked?
                if arr in prim.parent.inputs:
                    astmem = prim.parent.inputs[arr]
                    _add_astmemlet_edge(sdfg, state, superEntry, None, node,
                                        None, astmem)

                    # Remove local name from incoming edge to parent
                    prim.parent.inputs[arr].local_name = None
            elif superEntry:
                node = superEntry
            else:
                node = input_node_for_array(state, memlet.dataname)

            # Destination connector inference
            # Connected to a tasklet or a nested SDFG
            dst_conn = (memlet.local_name
                        if isinstance(entry, nd.CodeNode) else None)
            # Connected to a scope as part of its range
            if str(varname).startswith('__DACEIN_'):
                dst_conn = str(varname)[9:]
            # Handle special case of consume input stream
            if (isinstance(entry, nd.ConsumeEntry)
                    and memlet.data == prim.stream):
                dst_conn = 'IN_stream'

            # If a memlet that covers this input already exists, skip
            # generating this one; otherwise replace memlet with ours
            skip_incoming_edge, remove_edge = _reconcile_duplicate_memlet_edge(
                state, node, entry, memlet, dst_conn, 'dst_conn', 'in')

            if remove_edge is not None:
                state.remove_edge(remove_edge)

            if not skip_incoming_edge:
                _add_astmemlet_edge(sdfg, state, node, None, entry, dst_conn,
                                    memlet)

        # If there are no inputs, generate a dummy edge
        if superEntry and len(prim.inputs) == 0:
            state.add_edge(superEntry, None, entry, None, EmptyMemlet())

        if len(prim.children) > 0:
            # Recurse
            inner_outputs = _build_dataflow_graph_recurse(
                sdfg, state, prim.children, modules, entry, exit_node)
            # Infer output node for each memlet. If there is no such array in
            # this primitive's outputs, it's an external array (e.g., a map
            # in a map) and the inner memlet is kept as-is.
            for i, (out_src, mem) in enumerate(inner_outputs):
                if mem.dataname in prim.outputs:
                    inner_outputs[i] = (out_src, prim.outputs[mem.dataname])
        else:
            inner_outputs = [(exit_node, mem) for mem in prim.outputs.values()]

        # Add outgoing edges
        for out_src, astmem in inner_outputs:

            data = astmem.data
            dataname = astmem.dataname

            # If WCR is not none, it needs to be handled in the code. Check for
            # this after, as we only expect it for one distinct case
            wcr_was_handled = astmem.wcr is None

            # TODO: This is convoluted. We should find a more readable
            # way of connecting the outgoing edges.

            if super_exit is None:

                # Assert that we're in a top-level node
                if not isinstance(prim.parent, (astnodes._ProgramNode,
                                                astnodes._ControlFlowNode)):
                    raise RuntimeError("Expected to be at the top node")

                # Looks hacky
                src_conn = (astmem.local_name if isinstance(
                    out_src, (nd.Tasklet, nd.NestedSDFG)) else None)

                # Here we just need to connect memlets directly to their
                # respective data nodes
                out_tgt = output_node_for_array(state, astmem.dataname)

                # If a memlet that covers this output already exists, skip
                # generating this one; otherwise replace memlet with ours
                skip_outgoing_edge, remove_edge = (
                    _reconcile_duplicate_memlet_edge(state, out_src, out_tgt,
                                                     astmem, src_conn,
                                                     'src_conn', 'out'))

                if skip_outgoing_edge:
                    continue
                if remove_edge is not None:
                    state.remove_edge(remove_edge)

                _add_astmemlet_edge(sdfg,
                                    state,
                                    out_src,
                                    src_conn,
                                    out_tgt,
                                    None,
                                    astmem,
                                    wcr=astmem.wcr,
                                    wcr_identity=astmem.wcr_identity)
                # Any WCR on this memlet was attached to the edge above
                wcr_was_handled = True

                # If the program defines another output, connect it too.
                # This refers to the case where we have streams, which
                # must define an input and output, and sometimes this output
                # is defined in pdp.outputs
                if (isinstance(out_tgt, nd.AccessNode)
                        and isinstance(out_tgt.desc(sdfg), dt.Stream)):
                    # Only the lookup itself may legitimately find nothing;
                    # errors from the edge construction must not be swallowed.
                    stream_memlet = next(
                        (v for k, v in prim.parent.outputs.items()
                         if k == out_tgt.data), None)
                    if stream_memlet is not None:
                        stream_output = output_node_for_array(
                            state, stream_memlet.dataname)
                        _add_astmemlet_edge(sdfg, state, out_tgt, None,
                                            stream_output, None, stream_memlet)

            else:  # We're in a nest

                if isinstance(prim, astnodes._ScopeNode):
                    # We're a map or a consume node, that needs to connect our
                    # exit to either an array or to the super_exit
                    if data.transient and dataname in prim.parent.transients:
                        # Connect the exit directly
                        out_tgt = output_node_for_array(state, data.dataname)
                        _add_astmemlet_edge(sdfg, state, out_src, None,
                                            out_tgt, None, astmem)
                    else:
                        # This is either a transient defined in an outer scope,
                        # or an I/O array, so redirect through the exit node
                        _add_astmemlet_edge(sdfg, state, out_src, None,
                                            super_exit, None, astmem)
                        # Instruct outer recursion layer to continue the route
                        exit_nodes.append((super_exit, astmem))
                elif isinstance(
                        prim,
                    (astnodes._TaskletNode, astnodes._NestedSDFGNode)):
                    # We're a tasklet, and need to connect either to the exit
                    # if the array is I/O or is defined in a scope further out,
                    # or directly to the transient if it's defined locally
                    if dataname in prim.parent.transients:
                        # This is a local transient variable, so connect to it
                        # directly
                        out_tgt = output_node_for_array(state, data.dataname)
                        _add_astmemlet_edge(sdfg, state, out_src,
                                            astmem.local_name, out_tgt, None,
                                            astmem)
                    else:
                        # This is an I/O array, or an outer level transient, so
                        # redirect through the exit node
                        _add_astmemlet_edge(sdfg,
                                            state,
                                            out_src,
                                            astmem.local_name,
                                            super_exit,
                                            None,
                                            astmem,
                                            wcr=astmem.wcr,
                                            wcr_identity=astmem.wcr_identity)
                        exit_nodes.append((super_exit, astmem))
                        # Any WCR was attached to the edge above
                        wcr_was_handled = True
                else:
                    # Report the primitive's type: it is what was dispatched on
                    raise TypeError("Unexpected node type: {}".format(
                        type(prim).__name__))

            if not wcr_was_handled and not isinstance(prim,
                                                      astnodes._ScopeNode):
                raise RuntimeError("Detected unhandled WCR for primitive '{}' "
                                   "of type {}. WCR is only expected for "
                                   "tasklets in a map/consume scope.".format(
                                       prim.name,
                                       type(prim).__name__))

    return exit_nodes
Exemple #10
0
class MapReduceFusion(pm.Transformation):
    """ Implements the map-reduce-fusion transformation.
        Fuses a map with an immediately following reduction, where the array
        between the map and the reduction is not used anywhere else.

        Matched patterns (list index equals ``expr_index``):
          0. tasklet -> map exit -> array -> reduction map entry ->
             reduction tasklet -> reduction map exit -> output array
          1. Same as 0, with the reduction map nested inside an outer map
          2. tasklet -> map exit -> array -> Reduce node -> output array
    """

    _tasklet = nodes.Tasklet('_')
    _tmap_exit = nodes.MapExit(nodes.Map("", [], []))
    _in_array = nodes.AccessNode('_')
    _rmap_in_entry = nodes.MapEntry(nodes.Map("", [], []))
    _rmap_in_tasklet = nodes.Tasklet('_')
    _rmap_in_cr = nodes.MapExit(nodes.Map("", [], []))
    _rmap_out_entry = nodes.MapEntry(nodes.Map("", [], []))
    _rmap_out_exit = nodes.MapExit(nodes.Map("", [], []))
    _out_array = nodes.AccessNode('_')
    _reduce = nodes.Reduce('lambda: None', None)

    @staticmethod
    def expressions():
        return [
            # Map, then reduce of all axes
            nxutil.node_path_graph(
                MapReduceFusion._tasklet, MapReduceFusion._tmap_exit,
                MapReduceFusion._in_array, MapReduceFusion._rmap_in_entry,
                MapReduceFusion._rmap_in_tasklet, MapReduceFusion._rmap_in_cr,
                MapReduceFusion._out_array),
            # Map, then partial reduction of axes
            nxutil.node_path_graph(
                MapReduceFusion._tasklet, MapReduceFusion._tmap_exit,
                MapReduceFusion._in_array, MapReduceFusion._rmap_out_entry,
                MapReduceFusion._rmap_in_entry,
                MapReduceFusion._rmap_in_tasklet, MapReduceFusion._rmap_in_cr,
                MapReduceFusion._rmap_out_exit, MapReduceFusion._out_array),
            # Map, then reduce node
            nxutil.node_path_graph(
                MapReduceFusion._tasklet, MapReduceFusion._tmap_exit,
                MapReduceFusion._in_array, MapReduceFusion._reduce,
                MapReduceFusion._out_array)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if the matched subgraph can be fused.

            Requires that the intermediate array is written only by the
            tasklet map exit and read only by the reduction (outer map entry,
            map entry or Reduce node), that a map-based reduction writes
            with a WCR, and that the first write and first read of the
            intermediate array use equal subsets.
        """
        tmap_exit = graph.nodes()[candidate[MapReduceFusion._tmap_exit]]
        in_array = graph.nodes()[candidate[MapReduceFusion._in_array]]
        if expr_index == 0:  # Reduce without outer map
            rmap_entry = graph.nodes()[candidate[
                MapReduceFusion._rmap_in_entry]]
        elif expr_index == 1:  # Reduce with outer map
            rmap_entry = graph.nodes()[candidate[
                MapReduceFusion._rmap_out_entry]]
        else:  # Reduce node
            rmap_entry = graph.nodes()[candidate[MapReduceFusion._reduce]]

        # Make sure that the array is only accessed by the map and the reduce
        if any(e.src != tmap_exit for e in graph.in_edges(in_array)):
            return False
        if any(e.dst != rmap_entry for e in graph.out_edges(in_array)):
            return False

        # Make sure that there is a reduction in the second map
        if expr_index < 2:
            rmap_cr = graph.nodes()[candidate[MapReduceFusion._rmap_in_cr]]
            reduce_edge = graph.in_edges(rmap_cr)[0]
            if reduce_edge.data.wcr is None:
                return False

        # TODO(later): Support fusion of partial reductions, check that the
        # transient is not shared with other states, and refactor
        # slice/subarray handling. (Legacy commented-out range checks were
        # removed as dead code.)

        # Verify that reduction ranges match tasklet map
        tout_memlet = graph.in_edges(in_array)[0].data
        rin_memlet = graph.out_edges(in_array)[0].data
        if tout_memlet.subset != rin_memlet.subset:
            return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        tasklet = candidate[MapReduceFusion._tasklet]
        map_exit = candidate[MapReduceFusion._tmap_exit]
        if len(candidate) == 5:  # Expression 2
            reduce = candidate[MapReduceFusion._reduce]
        else:
            reduce = candidate[MapReduceFusion._rmap_in_cr]

        return ' -> '.join(str(node) for node in [tasklet, map_exit, reduce])

    @staticmethod
    def find_memlet_map_permutation(memlet: Memlet, map: nodes.Map):
        """ For each dimension of ``memlet.subset``, returns the index of the
            first not-yet-used map parameter that symbolically equals it, or
            None if there is no such parameter.
        """
        perm = [None] * len(memlet.subset)
        indices = set()
        for i, dim in enumerate(memlet.subset):
            for j, mapdim in enumerate(map.params):
                if symbolic.pystr_to_symbolic(
                        mapdim) == dim and j not in indices:
                    perm[i] = j
                    indices.add(j)
                    break
        return perm

    @staticmethod
    def find_permutation(tasklet_map: nodes.Map, red_outer_map: nodes.Map,
                         red_inner_map: nodes.Map, tmem: Memlet):
        """ Find permutation between tasklet-exit memlet and tasklet map.

            Each tasklet map range is matched against an unused range of the
            outer reduction map first, then of the inner (CR) map.

            :return: A pair ``([outer map indices], [inner (CR) map indices])``
                     of tasklet map dimension indices.
        """
        result = [], []

        assert len(tasklet_map.range) == len(red_inner_map.range) + len(
            red_outer_map.range)

        # Match map ranges with reduce ranges
        unavailable_ranges_out = set()
        unavailable_ranges_in = set()
        for i, tmap_rng in enumerate(tasklet_map.range):
            found = False
            for j, rng in enumerate(red_outer_map.range):
                if tmap_rng == rng and j not in unavailable_ranges_out:
                    result[0].append(i)
                    unavailable_ranges_out.add(j)
                    found = True
                    break
            if found: continue
            for j, rng in enumerate(red_inner_map.range):
                if tmap_rng == rng and j not in unavailable_ranges_in:
                    result[1].append(i)
                    unavailable_ranges_in.add(j)
                    found = True
                    break
            if not found: break

        # Ensure all map variables matched with reduce variables
        assert len(result[0]) + len(result[1]) == len(tasklet_map.range)

        # Returns ([outer map indices], [inner (CR) map indices])
        return result

    @staticmethod
    def find_permutation_reduce(tasklet_map: nodes.Map,
                                reduce_node: nodes.Reduce, graph: SDFGState,
                                tmem: Memlet):
        """ Returns the tasklet map dimensions that remain after the
            reduction performed by ``reduce_node``, in output order.

            An empty result means every axis is reduced (``axes`` is None or
            contains all dimensions).
        """
        in_memlet = graph.in_edges(reduce_node)[0].data
        assert len(tasklet_map.range) == in_memlet.subset.dims()

        # Find permutation between tasklet-exit memlet and tasklet map
        tmem_perm = MapReduceFusion.find_memlet_map_permutation(
            tmem, tasklet_map)
        mapred_perm = []

        # Match map ranges with reduce ranges
        unavailable_ranges = set()
        for i, tmap_rng in enumerate(tasklet_map.range):
            found = False

            for j, in_rng in enumerate(in_memlet.subset):
                if tmap_rng == in_rng and j not in unavailable_ranges:
                    mapred_perm.append(i)
                    unavailable_ranges.add(j)
                    found = True
                    break
            if not found: break

        # Ensure all map variables matched with reduce variables
        assert len(tmem_perm) == len(tmem.subset)
        assert len(mapred_perm) == len(in_memlet.subset)

        # Prepare result from the two permutations and the reduction axes
        result = []
        for i in range(len(mapred_perm)):
            if reduce_node.axes is None or i in reduce_node.axes:
                continue
            result.append(mapred_perm[tmem_perm[i]])

        return result

    def apply(self, sdfg):
        """ Fuses the matched map and reduction in place.

            Pattern 2 (Reduce node): the tasklet's memlet is redirected to the
            output array with the Reduce node's WCR and identity, and the
            intermediate array and Reduce node are removed.
            Patterns 0/1 (reduction map): the intermediate array is replaced
            by a transient ``tmp`` array that feeds the reduction tasklet.
        """
        graph = sdfg.nodes()[self.state_id]

        def gnode(nname):
            return graph.nodes()[self.subgraph[nname]]

        expr_index = self.expr_index
        tmap_exit = gnode(MapReduceFusion._tmap_exit)
        in_array = gnode(MapReduceFusion._in_array)
        if expr_index == 1:  # Reduce with outer map
            rmap_out_entry = gnode(MapReduceFusion._rmap_out_entry)
            rmap_out_exit = gnode(MapReduceFusion._rmap_out_exit)
            rmap_in_entry = gnode(MapReduceFusion._rmap_in_entry)

        if expr_index == 2:
            rmap_cr = gnode(MapReduceFusion._reduce)
        else:
            rmap_cr = gnode(MapReduceFusion._rmap_in_cr)
            # Both map-based patterns contain the reduction tasklet, and both
            # use it below, so it must be bound for pattern 0 as well
            rmap_tasklet = gnode(MapReduceFusion._rmap_in_tasklet)
        out_array = gnode(MapReduceFusion._out_array)

        # Set nodes to remove according to the expression index
        nodes_to_remove = [in_array]
        if expr_index == 0:
            nodes_to_remove.append(gnode(MapReduceFusion._rmap_in_entry))
        elif expr_index == 1:
            nodes_to_remove.extend(
                [rmap_out_entry, rmap_in_entry, rmap_out_exit])
        else:
            nodes_to_remove.append(rmap_cr)

        # If no other edges lead to mapexit, remove it. Otherwise, keep
        # it and remove reduction incoming/outgoing edges
        if expr_index != 2 and len(graph.in_edges(tmap_exit)) == 1:
            nodes_to_remove.append(tmap_exit)

        # The tasklet -> map exit edge that writes the intermediate array
        memlet_edge = next((e for e in graph.in_edges(tmap_exit)
                            if e.data.data == in_array.data), None)
        if memlet_edge is None:
            raise RuntimeError('Reduction memlet cannot be None')

        # Pattern 0 (reduce without outer map): index order does not
        # matter, merge as-is
        if expr_index == 1:  # Reduce with outer map
            tmap = tmap_exit.map
            perm_outer, perm_inner = MapReduceFusion.find_permutation(
                tmap, rmap_out_entry.map, rmap_in_entry.map, memlet_edge.data)

            # Split tasklet map into tmap_out -> tmap_in (according to
            # reduction)
            omap = nodes.Map(
                tmap.label + '_nonreduce',
                [p for i, p in enumerate(tmap.params) if i in perm_outer],
                [r for i, r in enumerate(tmap.range) if i in perm_outer],
                tmap.schedule, tmap.unroll, tmap.is_async)
            tmap.params = [
                p for i, p in enumerate(tmap.params) if i in perm_inner
            ]
            tmap.range = [
                r for i, r in enumerate(tmap.range) if i in perm_inner
            ]
            omap_entry = nodes.MapEntry(omap)
            omap_exit = rmap_out_exit
            rmap_out_exit.map = omap

            # Reconnect graph to new map. change_edge_dest moves every
            # incoming edge of tmap_entry at once, so a single call suffices.
            tmap_entry = graph.entry_node(tmap_exit)
            tmap_in_edges = list(graph.in_edges(tmap_entry))
            nxutil.change_edge_dest(graph, tmap_entry, omap_entry)
            for e in tmap_in_edges:
                graph.add_edge(omap_entry, e.src_conn, tmap_entry, e.dst_conn,
                               copy.copy(e.data))
        elif expr_index == 2:  # Reduce node
            # Find correspondence between map indices and array outputs
            tmap = tmap_exit.map
            perm = MapReduceFusion.find_permutation_reduce(
                tmap, rmap_cr, graph, memlet_edge.data)

            output_subset = [tmap.params[d] for d in perm]
            if len(output_subset) == 0:  # Output is a scalar
                output_subset = [0]

            array_edge = graph.out_edges(rmap_cr)[0]

            # Delete relevant edges and nodes
            graph.remove_edge(memlet_edge)
            graph.remove_nodes_from(nodes_to_remove)

            # Add new edges and nodes
            #   From tasklet to map exit
            graph.add_edge(
                memlet_edge.src, memlet_edge.src_conn, memlet_edge.dst,
                memlet_edge.dst_conn,
                Memlet(out_array.data, memlet_edge.data.num_accesses,
                       subsets.Indices(output_subset), memlet_edge.data.veclen,
                       rmap_cr.wcr, rmap_cr.identity))

            #   From map exit to output array
            graph.add_edge(
                memlet_edge.dst, 'OUT_' + memlet_edge.dst_conn[3:],
                array_edge.dst, array_edge.dst_conn,
                Memlet(array_edge.data.data, array_edge.data.num_accesses,
                       array_edge.data.subset, array_edge.data.veclen,
                       rmap_cr.wcr, rmap_cr.identity))

            return

        # Remove tmp array node prior to the others, so that a new one
        # can be created in its stead (see below)
        graph.remove_node(nodes_to_remove[0])
        nodes_to_remove = nodes_to_remove[1:]

        # Create tasklet -> tmp -> tasklet connection
        # NOTE(review): the transient name 'tmp' is fixed; confirm it cannot
        # clash with an existing data container in the SDFG.
        tmp = graph.add_array(
            'tmp',
            memlet_edge.data.subset.bounding_box_size(),
            sdfg.arrays[memlet_edge.data.data].dtype,
            transient=True)
        tasklet_tmp_memlet = copy.deepcopy(memlet_edge.data)
        tasklet_tmp_memlet.data = tmp.data
        tasklet_tmp_memlet.subset = ShapeProperty.to_string(tmp.shape)

        # Modify memlet to point to output array
        memlet_edge.data.data = out_array.data

        # Recover reduction axes from CR reduce subset: dimensions indexed by
        # an '__i'-prefixed variable are the reduced ones
        reduce_cr_subset = graph.in_edges(rmap_tasklet)[0].data.subset
        reduce_axes = [
            ind for ind, crvar in enumerate(reduce_cr_subset.indices)
            if '__i' in str(crvar)
        ]

        # Modify memlet access index by filtering out reduction axes
        newindices = [
            ovar for ind, ovar in enumerate(memlet_edge.data.subset.indices)
            if ind not in reduce_axes
        ]
        if len(newindices) == 0:  # Output is a scalar
            newindices = [0]

        memlet_edge.data.subset = subsets.Indices(newindices)

        graph.remove_edge(memlet_edge)

        graph.add_edge(memlet_edge.src, memlet_edge.src_conn, tmp,
                       memlet_edge.dst_conn, tasklet_tmp_memlet)

        red_edges = list(graph.in_edges(rmap_tasklet))
        if len(red_edges) != 1:
            raise RuntimeError('CR edge must be unique')

        tmp_tasklet_memlet = copy.deepcopy(tasklet_tmp_memlet)
        graph.add_edge(tmp, None, rmap_tasklet, red_edges[0].dst_conn,
                       tmp_tasklet_memlet)

        for e in graph.edges_between(rmap_tasklet, rmap_cr):
            e.data.subset = memlet_edge.data.subset

        # Move output edges to point directly to CR node
        if expr_index == 1:
            # Set output memlet between CR node and outer reduction map to
            # contain the same subset as the one pointing to the CR node
            for e in graph.out_edges(rmap_cr):
                e.data.subset = memlet_edge.data.subset

            # NOTE(review): omap_exit *is* rmap_out_exit (assigned above), so
            # this moves edges onto the same node, and that node is also in
            # nodes_to_remove. Confirm the intended target for pattern 1.
            nxutil.change_edge_src(graph, rmap_out_exit, omap_exit)

        # Remove nodes
        graph.remove_nodes_from(nodes_to_remove)

        # For unrelated outputs, connect original output to rmap_out
        if expr_index == 1 and tmap_exit not in nodes_to_remove:
            other_out_edges = list(graph.out_edges(tmap_exit))
            for e in other_out_edges:
                graph.remove_edge(e)
                graph.add_edge(e.src, e.src_conn, omap_exit, None, e.data)
                graph.add_edge(omap_exit, None, e.dst, e.dst_conn,
                               copy.copy(e.data))

    def modifies_graph(self):
        return True
Exemple #11
0
class Vectorization(pattern_matching.Transformation):
    """ Implements the vectorization transformation.

        Vectorization matches when all the input and output memlets of a
        tasklet inside a map access the inner-most loop variable in their last
        dimension. The transformation changes the step of the inner-most loop
        to be equal to the length of the vector and vectorizes the memlets.
    """

    vector_len = Property(desc="Vector length", dtype=int, default=4)
    propagate_parent = Property(desc="Propagate vector length through "
                                "parent SDFGs",
                                dtype=bool,
                                default=False)

    _map_entry = nodes.MapEntry(nodes.Map("", [], []))
    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        """ Returns the matched pattern: map entry -> tasklet -> map exit. """
        return [
            nxutil.node_path_graph(Vectorization._map_entry,
                                   Vectorization._tasklet,
                                   Vectorization._map_exit)
        ]

    @staticmethod
    def _index_symbols(index):
        """ Returns the set of free symbols of one subset dimension.

            :param index: A single subset dimension, either a scalar index
                          expression or a tuple of expressions (e.g., a
                          range's begin/end/step).
            :return: Union of the free symbols of all expressions in `index`.
        """
        parts = index if isinstance(index, tuple) else (index, )
        symbols = set()
        for part in parts:
            symbols.update(symbolic.pystr_to_symbolic(part).free_symbols)
        return symbols

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the matched tasklet can be vectorized.

            Empty memlets and memlets to streams are ignored. Every other
            memlet adjacent to the tasklet must have no WCR, a non-None
            subset and a vector length of at most 1. The inner-most map
            parameter may appear only in the last subset dimension.

            :return: True if the above holds and at least one memlet uses the
                     inner-most map parameter in its last dimension.
        """
        map_entry = graph.nodes()[candidate[Vectorization._map_entry]]
        tasklet = graph.nodes()[candidate[Vectorization._tasklet]]
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])
        found = False

        for _src, _, _dest, _, memlet in graph.all_edges(tasklet):

            # Cases that do not matter for vectorization
            if memlet.data is None:  # Empty memlets
                continue
            if isinstance(sdfg.arrays[memlet.data], data.Stream):  # Streams
                continue

            # Vectorization cannot be applied in WCR
            if memlet.wcr is not None:
                return False

            try:
                subset = memlet.subset
                veclen = memlet.veclen
            except AttributeError:
                return False

            if subset is None:
                return False

            try:
                if veclen > symbolic.pystr_to_symbolic('1'):
                    return False

                for idx, expr in enumerate(subset):
                    if param in Vectorization._index_symbols(expr):
                        # The parameter may only index the last dimension
                        if idx != subset.dims() - 1:
                            return False
                        found = True
            except TypeError:  # cannot determine truth value of Relational
                return False

        return found

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns a human-readable description of the matched nodes. """
        map_entry = candidate[Vectorization._map_entry]
        tasklet = candidate[Vectorization._tasklet]
        map_exit = candidate[Vectorization._map_exit]

        return ' -> '.join(
            str(node) for node in [map_entry, tasklet, map_exit])

    @staticmethod
    def _propagate_to_parent(sdfg, source_edge, sink_edge, vector_size):
        """ Sets the vector length on memlet paths in the parent state.

            Finds the NestedSDFG node in `sdfg.parent` whose SDFG name matches
            `sdfg`. For each edge of that node, if its destination connector
            names the source node of `source_edge`, or its source connector
            names the destination node of `sink_edge`, the vector length is
            set on that edge's whole memlet path. Only one level of nesting
            is handled.

            :raises ValueError: If the nested SDFG node is not found.
        """
        parent_state = sdfg.parent
        parent_node = next((n for n in parent_state.nodes()
                            if isinstance(n, nodes.NestedSDFG)
                            and n.sdfg.name == sdfg.name), None)
        if parent_node is None:
            raise ValueError('Nested SDFG node for "%s" not found in parent '
                             'state' % sdfg.name)

        for pe in parent_state.all_edges(parent_node):
            if (str(pe.dst_conn) == str(source_edge.src)
                    or str(pe.src_conn) == str(sink_edge.dst)):
                for ppe in parent_state.memlet_path(pe):
                    ppe.data.veclen = vector_size

    def apply(self, sdfg):
        """ Vectorizes the matched tasklet.

            Sets the step of the inner-most map dimension to `vector_len`. It
            then sets `veclen` on the full memlet path of every non-empty
            memlet adjacent to the tasklet whose last subset dimension uses
            the inner-most map parameter. If `propagate_parent` is set and the
            SDFG has a parent, matching memlet paths in the parent state are
            updated as well.
        """
        graph = sdfg.nodes()[self.state_id]
        map_entry = graph.nodes()[self.subgraph[Vectorization._map_entry]]
        tasklet = graph.nodes()[self.subgraph[Vectorization._tasklet]]
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])

        vector_size = self.vector_len

        # Change the step of the inner-most dimension (previous step dropped)
        dim_from, dim_to, _ = map_entry.map.range[-1]
        map_entry.map.range[-1] = (dim_from, dim_to, vector_size)

        # Vectorize memlets adjacent to the tasklet.
        for edge in graph.all_edges(tasklet):
            memlet = edge.data

            if memlet.data is None:  # Empty memlets
                continue

            if param not in Vectorization._index_symbols(memlet.subset[-1]):
                continue

            # Propagate vector length inside this SDFG
            path = graph.memlet_path(edge)
            for e in path:
                e.data.veclen = vector_size

            # Propagate to the parent (TODO: handle multiple levels of nesting)
            if self.propagate_parent and sdfg.parent is not None:
                Vectorization._propagate_to_parent(sdfg, path[0], path[-1],
                                                   vector_size)
Exemple #12
0
class AccumulateTransient(pattern_matching.Transformation):
    """ Implements the AccumulateTransient transformation.

        The transformation matches a tasklet that writes to a map exit with
        write-conflict resolution (WCR), where that map exit leads to an
        outer map exit. It applies the LocalStorage transformation between
        the two map exits so that the newly-created transient data node acts
        as a local accumulator. For summation the transient is initialized to
        zero. For other WCR types a warning is emitted instead.
    """

    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))
    _outer_map_exit = nodes.MapExit(nodes.Map("", [], []))

    array = Property(
        dtype=str,
        desc="Array to create local storage for (if empty, first available)",
        default=None,
        allow_none=True)

    @staticmethod
    def expressions():
        """ Returns the matched pattern: tasklet -> map exit -> outer map exit.
        """
        return [
            nxutil.node_path_graph(AccumulateTransient._tasklet,
                                   AccumulateTransient._map_exit,
                                   AccumulateTransient._outer_map_exit)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Returns True if the tasklet has at least one output edge with a
            WCR that points directly to the inner map exit. """
        tasklet = graph.nodes()[candidate[AccumulateTransient._tasklet]]
        map_exit = graph.nodes()[candidate[AccumulateTransient._map_exit]]

        # Check if there is an accumulation output
        for _src, _, dest, _, memlet in graph.out_edges(tasklet):
            if memlet.wcr is not None and dest == map_exit:
                return True

        return False

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns a human-readable description of the matched nodes. """
        tasklet = candidate[AccumulateTransient._tasklet]
        map_exit = candidate[AccumulateTransient._map_exit]
        outer_map_exit = candidate[AccumulateTransient._outer_map_exit]

        return ' -> '.join(
            str(node) for node in [tasklet, map_exit, outer_map_exit])

    def apply(self, sdfg):
        """ Applies LocalStorage between the two map exits, then initializes
            the created transient to zero if the WCR is a summation (otherwise
            emits a warning). """
        graph = sdfg.node(self.state_id)

        # Avoid import loop
        from dace.transformation.dataflow.local_storage import LocalStorage

        local_storage_subgraph = {
            LocalStorage._node_a:
            self.subgraph[AccumulateTransient._map_exit],
            LocalStorage._node_b:
            self.subgraph[AccumulateTransient._outer_map_exit]
        }
        sdfg_id = sdfg.sdfg_list.index(sdfg)
        in_local_storage = LocalStorage(
            sdfg_id, self.state_id, local_storage_subgraph, self.expr_index)
        in_local_storage.array = self.array
        in_local_storage.apply(sdfg)

        # Initialize transient to zero in case of summation
        # TODO: Initialize transient in other WCR types
        memlet = graph.in_edges(in_local_storage._data_node)[0].data
        if detect_reduction_type(memlet.wcr) == dtypes.ReductionType.Sum:
            in_local_storage._data_node.setzero = True
        else:
            # Note the trailing space: the two literals are concatenated
            warnings.warn('AccumulateTransient did not properly initialize '
                          'newly-created transient!')
Exemple #13
0
def my_add_mapped_tasklet(
    state,
    inpdict,
    outdict,
    name: str,
    map_ranges: Dict[str, dace.subsets.Subset],
    inputs: Dict[str, dace.memlet.Memlet],
    code: str,
    outputs: Dict[str, dace.memlet.Memlet],
    schedule=dace.dtypes.ScheduleType.Default,
    unroll_map=False,
    code_global="",
    code_init="",
    code_exit="",
    location="-1",
    language=dace.dtypes.Language.Python,
    debuginfo=None,
    external_edges=True,
) -> Tuple[dace.graph.nodes.Node, dace.graph.nodes.Node,
           dace.graph.nodes.Node]:
    """ Convenience function that adds a map entry, tasklet, map exit,
        and the respective edges to external arrays.

        :param state:      State to add the nodes and edges to
        :param inpdict:    Mapping between input data names and the existing
                           access nodes to connect to the map entry. Each key
                           must be the `data` of some memlet in `inputs`.
        :param outdict:    Mapping between output data names and the existing
                           access nodes to connect from the map exit. Each key
                           must be the `data` of some memlet in `outputs`.
        :param name:       Tasklet (and wrapping map) name
        :param map_ranges: Mapping between variable names and their
                           subsets
        :param inputs:     Mapping between input local variable names and
                           their memlets
        :param code:       Code (written in `language`)
        :param outputs:    Mapping between output local variable names and
                           their memlets
        :param schedule:   Map schedule
        :param unroll_map: True if map should be unrolled in code
                           generation
        :param code_global: (optional) Global code (outside functions)
        :param code_init:  (optional) Tasklet initialization code
        :param code_exit:  (optional) Tasklet finalization code
        :param location:   Tasklet location, passed to the Tasklet node
        :param language:   Programming language in which the code is
                           written
        :param debuginfo:  Debugging information (mostly for DIODE)
        :param external_edges: Connect the access nodes given in `inpdict`
                               and `outdict` to the map with propagated
                               memlets and IN_/OUT_ connectors

        :return: tuple of (tasklet, map_entry, map_exit)
        :note: The memlets in `inputs` and `outputs` are modified in place:
               their `name` is set to the corresponding connector name.
    """
    import dace.graph.nodes as nd
    from dace.sdfg import getdebuginfo
    from dace.graph.labeling import propagate_memlet

    map_name = name + "_map"
    debuginfo = getdebuginfo(debuginfo)
    tasklet = nd.Tasklet(
        name,
        set(inputs.keys()),
        set(outputs.keys()),
        code,
        language=language,
        code_global=code_global,
        code_init=code_init,
        code_exit=code_exit,
        location=location,
        debuginfo=debuginfo,
    )
    map_obj = state._map_from_ndrange(map_name,
                                      schedule,
                                      unroll_map,
                                      map_ranges,
                                      debuginfo=debuginfo)
    map_entry = nd.MapEntry(map_obj)
    map_exit = nd.MapExit(map_obj)
    state.add_nodes_from([map_entry, tasklet, map_exit])

    # Inner input edges: map entry -> tasklet input connectors
    in_memlets = {}
    for conn, memlet in inputs.items():
        memlet.name = conn
        state.add_edge(map_entry, None, tasklet, conn, memlet)
        in_memlets[memlet.data] = memlet

    if not inputs:
        # Without inputs, link map entry and tasklet with an empty memlet
        state.add_edge(map_entry, None, tasklet, None,
                       dace.memlet.EmptyMemlet())

    if external_edges:
        for inp, inpnode in inpdict.items():
            outer_memlet = propagate_memlet(state, in_memlets[inp], map_entry,
                                            True)
            state.add_edge(inpnode, None, map_entry, "IN_" + inp, outer_memlet)

            # Route the inner edges of this data through the OUT_ connector
            for e in state.out_edges(map_entry):
                if e.data.data == inp:
                    e._src_conn = "OUT_" + inp

            map_entry.add_in_connector("IN_" + inp)
            map_entry.add_out_connector("OUT_" + inp)

    # Inner output edges: tasklet output connectors -> map exit
    out_memlets = {}
    for conn, memlet in outputs.items():
        memlet.name = conn
        state.add_edge(tasklet, conn, map_exit, None, memlet)
        out_memlets[memlet.data] = memlet

    if not outputs:
        # Without outputs, link tasklet and map exit with an empty memlet
        state.add_edge(tasklet, None, map_exit, None,
                       dace.memlet.EmptyMemlet())

    if external_edges:
        for out, outnode in outdict.items():
            outer_memlet = propagate_memlet(state, out_memlets[out], map_exit,
                                            True)
            state.add_edge(map_exit, "OUT_" + out, outnode, None, outer_memlet)

            # Route the inner edges of this data through the IN_ connector
            for e in state.in_edges(map_exit):
                if e.data.data == out:
                    e._dst_conn = "IN_" + out

            map_exit.add_in_connector("IN_" + out)
            map_exit.add_out_connector("OUT_" + out)

    return tasklet, map_entry, map_exit