Example #1
class Vectorization(pattern_matching.Transformation):
    """ Implements the vectorization transformation.

        Vectorization matches when all the input and output memlets of a 
        tasklet inside a map access the inner-most loop variable in their last
        dimension. The transformation changes the step of the inner-most loop
        to be equal to the length of the vector and vectorizes the memlets.
    """

    vector_len = Property(desc="Vector length", dtype=int, default=4)
    propagate_parent = Property(desc="Propagate vector length through "
                                "parent SDFGs",
                                dtype=bool,
                                default=False)
    strided_map = Property(desc="Use strided map range (jump by vector length)"
                           " instead of modifying memlets",
                           dtype=bool,
                           default=False)

    _map_entry = nodes.MapEntry(nodes.Map("", [], []))
    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        return [
            sdutil.node_path_graph(Vectorization._map_entry,
                                   Vectorization._tasklet,
                                   Vectorization._map_exit)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        map_entry = graph.nodes()[candidate[Vectorization._map_entry]]
        tasklet = graph.nodes()[candidate[Vectorization._tasklet]]
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])
        found = False

        # Check if all edges adjacent to the tasklet use the
        # parameter in their last dimension.
        for _src, _, _dest, _, memlet in graph.all_edges(tasklet):

            # Cases that do not matter for vectorization
            if memlet.data is None:  # Empty memlets
                continue
            if isinstance(sdfg.arrays[memlet.data], data.Stream):  # Streams
                continue

            # Vectorization cannot be applied to WCR memlets
            if memlet.wcr is not None:
                return False

            try:
                subset = memlet.subset
                veclen = memlet.veclen
            except AttributeError:
                return False

            if subset is None:
                return False

            try:
                if veclen > symbolic.pystr_to_symbolic('1'):
                    return False

                for idx, expr in enumerate(subset):
                    if isinstance(expr, tuple):
                        for ex in expr:
                            ex = symbolic.pystr_to_symbolic(ex)
                            symbols = ex.free_symbols
                            if param in symbols:
                                if idx == subset.dims() - 1:
                                    found = True
                                else:
                                    return False
                    else:
                        expr = symbolic.pystr_to_symbolic(expr)
                        symbols = expr.free_symbols
                        if param in symbols:
                            if idx == subset.dims() - 1:
                                found = True
                            else:
                                return False
            except TypeError:  # cannot determine truth value of Relational
                return False

        return found

    @staticmethod
    def match_to_str(graph, candidate):

        map_entry = candidate[Vectorization._map_entry]
        tasklet = candidate[Vectorization._tasklet]
        map_exit = candidate[Vectorization._map_exit]

        return ' -> '.join(str(node) for node in [map_entry, tasklet, map_exit])

    def apply(self, sdfg):
        graph = sdfg.nodes()[self.state_id]
        map_entry = graph.nodes()[self.subgraph[Vectorization._map_entry]]
        tasklet = graph.nodes()[self.subgraph[Vectorization._tasklet]]
        map_exit = graph.nodes()[self.subgraph[Vectorization._map_exit]]
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])

        # Create new vector size.
        vector_size = self.vector_len

        # Adjust the range of the inner-most dimension.
        dim_from, dim_to, dim_step = map_entry.map.range[-1]
        if self.strided_map:
            map_entry.map.range[-1] = (dim_from, dim_to, vector_size)
        else:
            map_entry.map.range[-1] = (dim_from, (dim_to + 1) / vector_size - 1,
                                       dim_step)

        # TODO: Postamble and/or preamble non-vectorized map

        # Vectorize memlets adjacent to the tasklet.
        processed_edges = set()
        for edge in graph.all_edges(tasklet):
            _src, _, _dest, _, memlet = edge

            if memlet.data is None:  # Empty memlets
                continue

            lastindex = memlet.subset[-1]
            if isinstance(lastindex, tuple):
                symbols = set()
                for indd in lastindex:
                    symbols.update(
                        symbolic.pystr_to_symbolic(indd).free_symbols)
            else:
                symbols = symbolic.pystr_to_symbolic(
                    memlet.subset[-1]).free_symbols

            if param not in symbols:
                continue
            # Propagate the vector length inside this SDFG
            for e in graph.memlet_tree(edge):
                e.data.veclen = vector_size
                if not self.strided_map and e not in processed_edges:
                    e.data.subset.replace({param: vector_size * param})
                    processed_edges.add(e)

            # Propagate to the parent (TODO: handle multiple levels of nesting)
            if self.propagate_parent and sdfg.parent is not None:
                source_edge = graph.memlet_path(edge)[0]
                sink_edge = graph.memlet_path(edge)[-1]

                # Find the parent nested SDFG node
                parent_node = next(n for n in sdfg.parent.nodes()
                                   if isinstance(n, nodes.NestedSDFG)
                                   and n.sdfg.name == sdfg.name)

                # Continue propagating the vector length along the path that
                # arrives at source_edge or starts from sink_edge
                for pe in sdfg.parent.all_edges(parent_node):
                    if str(pe.dst_conn) == str(source_edge.src) or str(
                            pe.src_conn) == str(sink_edge.dst):
                        for ppe in sdfg.parent.memlet_tree(pe):
                            ppe.data.veclen = vector_size
                            if (not self.strided_map
                                    and ppe not in processed_edges):
                                ppe.data.subset.replace(
                                    {param: vector_size * param})
                                processed_edges.add(ppe)
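
A minimal usage sketch for the transformation above (hedged: the program, array sizes, and options here are illustrative, not part of the original source):

import dace
from dace.transformation.dataflow import Vectorization

@dace.program
def vec_add(A: dace.float32[128], B: dace.float32[128]):
    for i in dace.map[0:128]:
        with dace.tasklet:
            a << A[i]
            b >> B[i]
            b = a + 1.0

sdfg = vec_add.to_sdfg()
# Apply with vector length 4; returns the number of applied transformations.
applied = sdfg.apply_transformations(Vectorization,
                                     options={'vector_len': 4})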
Example #2
class Vectorization(transformation.Transformation):
    """ Implements the vectorization transformation.

        Vectorization matches when all the input and output memlets of a 
        tasklet inside a map access the inner-most loop variable in their last
        dimension. The transformation changes the step of the inner-most loop
        to be equal to the length of the vector and vectorizes the memlets.
    """

    vector_len = Property(desc="Vector length", dtype=int, default=4)
    propagate_parent = Property(desc="Propagate vector length through "
                                "parent SDFGs",
                                dtype=bool,
                                default=False)
    strided_map = Property(desc="Use strided map range (jump by vector length)"
                           " instead of modifying memlets",
                           dtype=bool,
                           default=True)
    preamble = Property(
        dtype=bool,
        default=None,
        allow_none=True,
        desc='Force creation or skipping a preamble map without vectors')
    postamble = Property(
        dtype=bool,
        default=None,
        allow_none=True,
        desc='Force creation or skipping a postamble map without vectors')

    _map_entry = nodes.MapEntry(nodes.Map("", [], []))
    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        return [
            sdutil.node_path_graph(Vectorization._map_entry,
                                   Vectorization._tasklet,
                                   Vectorization._map_exit)
        ]

    def can_be_applied(self, graph, candidate, expr_index, sdfg, strict=False):
        map_entry = graph.nodes()[candidate[Vectorization._map_entry]]
        tasklet = graph.nodes()[candidate[Vectorization._tasklet]]
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])
        found = False

        # Maps with a non-unit step cannot be vectorized with a strided map
        if map_entry.map.range[-1][2] != 1 and self.strided_map:
            return False

        # Check if all edges adjacent to the tasklet use the
        # parameter in their contiguous dimension.
        for e, conntype in graph.all_edges_and_connectors(tasklet):

            # Cases that do not matter for vectorization
            if e.data.data is None:  # Empty memlets
                continue
            if isinstance(sdfg.arrays[e.data.data], data.Stream):  # Streams
                continue

            # Vectorization cannot be applied to WCR memlets
            # if e.data.wcr is not None:
            #     return False

            subset = e.data.subset
            array = sdfg.arrays[e.data.data]

            # If already vectorized or a pointer, do not apply
            if isinstance(conntype, (dtypes.vector, dtypes.pointer)):
                return False

            try:
                for idx, expr in enumerate(subset):
                    if isinstance(expr, tuple):
                        for ex in expr:
                            ex = symbolic.pystr_to_symbolic(ex)
                            symbols = ex.free_symbols
                            if param in symbols:
                                if array.strides[idx] == 1:
                                    found = True
                                else:
                                    return False
                    else:
                        expr = symbolic.pystr_to_symbolic(expr)
                        symbols = expr.free_symbols
                        if param in symbols:
                            if array.strides[idx] == 1:
                                found = True
                            else:
                                return False
            except TypeError:  # cannot determine truth value of Relational
                return False

        return found

    @staticmethod
    def match_to_str(graph, candidate):

        map_entry = candidate[Vectorization._map_entry]
        tasklet = candidate[Vectorization._tasklet]
        map_exit = candidate[Vectorization._map_exit]

        return ' -> '.join(str(node) for node in [map_entry, tasklet, map_exit])

    def apply(self, sdfg: SDFG):
        graph = sdfg.nodes()[self.state_id]
        map_entry = graph.nodes()[self.subgraph[Vectorization._map_entry]]
        tasklet = graph.nodes()[self.subgraph[Vectorization._tasklet]]
        param = symbolic.pystr_to_symbolic(map_entry.map.params[-1])

        # Create new vector size.
        vector_size = self.vector_len
        dim_from, dim_to, dim_skip = map_entry.map.range[-1]

        # Determine whether to create preamble or postamble maps
        if self.preamble is not None:
            create_preamble = self.preamble
        else:
            create_preamble = not ((dim_from % vector_size == 0) == True
                                   or dim_from == 0)
        if self.postamble is not None:
            create_postamble = self.postamble
        else:
            if isinstance(dim_to, symbolic.SymExpr):
                create_postamble = (((dim_to.approx + 1) %
                                     vector_size == 0) == False)
            else:
                create_postamble = (((dim_to + 1) % vector_size == 0) == False)

        # Determine new range for vectorized map
        if self.strided_map:
            new_range = [dim_from, dim_to - vector_size + 1, vector_size]
        else:
            new_range = [
                dim_from // vector_size, ((dim_to + 1) // vector_size) - 1,
                dim_skip
            ]

        # Create preamble non-vectorized map (replacing the original map)
        if create_preamble:
            old_scope = graph.scope_subgraph(map_entry, True, True)
            new_scope: ScopeSubgraphView = replicate_scope(
                sdfg, graph, old_scope)
            new_begin = dim_from + (vector_size - (dim_from % vector_size))
            map_entry.map.range[-1] = (dim_from, new_begin - 1, dim_skip)
            # Replace map_entry with the replicated scope (so that the preamble
            # will usually come first in topological sort)
            map_entry = new_scope.entry
            tasklet = new_scope.nodes()[old_scope.nodes().index(tasklet)]
            new_range[0] = new_begin

        # Create postamble non-vectorized map
        if create_postamble:
            new_scope: ScopeSubgraphView = replicate_scope(
                sdfg, graph, graph.scope_subgraph(map_entry, True, True))
            dim_to_ex = dim_to + 1
            new_scope.entry.map.range[-1] = (dim_to_ex -
                                             (dim_to_ex % vector_size), dim_to,
                                             dim_skip)

        # Set the new range of the inner-most dimension.
        map_entry.map.range[-1] = tuple(new_range)

        # Vectorize connectors adjacent to the tasklet.
        for edge in graph.all_edges(tasklet):
            connectors = (tasklet.in_connectors
                          if edge.dst == tasklet else tasklet.out_connectors)
            conn = edge.dst_conn if edge.dst == tasklet else edge.src_conn

            if edge.data.data is None:  # Empty memlets
                continue
            desc = sdfg.arrays[edge.data.data]
            contigidx = desc.strides.index(1)

            newlist = []

            lastindex = edge.data.subset[contigidx]
            if isinstance(lastindex, tuple):
                newlist = [(rb, re, rs) for rb, re, rs in edge.data.subset]
                symbols = set()
                for indd in lastindex:
                    symbols.update(
                        symbolic.pystr_to_symbolic(indd).free_symbols)
            else:
                newlist = [(rb, rb, 1) for rb in edge.data.subset]
                symbols = symbolic.pystr_to_symbolic(lastindex).free_symbols

            oldtype = connectors[conn]
            if oldtype is None or oldtype.type is None:
                oldtype = desc.dtype

            # Vector to scalar WCR edge: change connector and continue
            if (edge.data.subset.num_elements() == 1
                    and edge.data.wcr is not None):
                connectors[conn] = dtypes.vector(oldtype, vector_size)
                continue

            if str(param) not in map(str, symbols):
                continue

            # Vectorize connector, if not already vectorized
            if isinstance(oldtype, dtypes.vector):
                continue

            connectors[conn] = dtypes.vector(oldtype, vector_size)

            # Modify memlet subset to match vector length
            if self.strided_map:
                rb = newlist[contigidx][0]
                if self.propagate_parent:
                    newlist[contigidx] = (rb / self.vector_len,
                                          rb / self.vector_len, 1)
                else:
                    newlist[contigidx] = (rb, rb + self.vector_len - 1, 1)
            else:
                rb = newlist[contigidx][0]
                if self.propagate_parent:
                    newlist[contigidx] = (rb, rb, 1)
                else:
                    newlist[contigidx] = (self.vector_len * rb,
                                          self.vector_len * rb +
                                          self.vector_len - 1, 1)
            edge.data.subset = subsets.Range(newlist)
            edge.data.volume = vector_size

        # Vector length propagation using data descriptors, recursive traversal
        # outwards
        if self.propagate_parent:
            for edge in graph.all_edges(tasklet):
                cursdfg = sdfg
                curedge = edge
                while cursdfg is not None:
                    arrname = curedge.data.data
                    dtype = cursdfg.arrays[arrname].dtype

                    # Change type and shape to vector
                    if not isinstance(dtype, dtypes.vector):
                        cursdfg.arrays[arrname].dtype = dtypes.vector(
                            dtype, vector_size)
                        new_shape = list(cursdfg.arrays[arrname].shape)
                        contigidx = cursdfg.arrays[arrname].strides.index(1)
                        new_shape[contigidx] /= vector_size
                        try:
                            new_shape[contigidx] = int(new_shape[contigidx])
                        except TypeError:
                            pass
                        cursdfg.arrays[arrname].shape = new_shape

                    propagation.propagate_memlets_sdfg(cursdfg)

                    # Find matching edge in parent
                    nsdfg = cursdfg.parent_nsdfg_node
                    if nsdfg is None:
                        break
                    tstate = cursdfg.parent
                    curedge = ([
                        e
                        for e in tstate.in_edges(nsdfg) if e.dst_conn == arrname
                    ] + [
                        e for e in tstate.out_edges(nsdfg)
                        if e.src_conn == arrname
                    ])[0]
                    cursdfg = cursdfg.parent_sdfg
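
To make the two range rewritings above concrete, here is a small pure-Python sanity check of the index arithmetic (N and the vector length are illustrative values, not from the original source):

N, vector_size = 16, 4
dim_from, dim_to = 0, N - 1

# strided_map=True: keep one element per memlet, jump by the vector length
strided = (dim_from, dim_to - vector_size + 1, vector_size)
# strided_map=False: shrink the range and widen each memlet to vector_size
packed = (dim_from // vector_size, (dim_to + 1) // vector_size - 1, 1)

assert list(range(strided[0], strided[1] + 1, strided[2])) == [0, 4, 8, 12]
assert list(range(packed[0], packed[1] + 1, packed[2])) == [0, 1, 2, 3]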
Example #3
    def __init__(self, *args, **kwargs):
        self._entry = nodes.EntryNode()
        self._tasklet = nodes.Tasklet('_')
        self._exit = nodes.ExitNode()
        super().__init__(*args, **kwargs)
Example #4
class StreamTransient(transformation.Transformation):
    """ Implements the StreamTransient transformation, which adds a transient
        and a stream node between nested maps that lead to a stream. The
        transient then acts as a local buffer.
    """

    with_buffer = Property(dtype=bool,
                           default=True,
                           desc="Use an intermediate buffer for accumulation")

    _tasklet = nodes.Tasklet('_')
    _map_exit = nodes.MapExit(nodes.Map("", [], []))
    _outer_map_exit = nodes.MapExit(nodes.Map("", [], []))

    @staticmethod
    def expressions():
        return [
            sdutil.node_path_graph(StreamTransient._tasklet,
                                   StreamTransient._map_exit,
                                   StreamTransient._outer_map_exit)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        map_exit = graph.nodes()[candidate[StreamTransient._map_exit]]
        outer_map_exit = graph.nodes()[candidate[
            StreamTransient._outer_map_exit]]

        # Check if there is a streaming output
        for _src, _, dest, _, memlet in graph.out_edges(map_exit):
            if isinstance(sdfg.arrays[memlet.data],
                          data.Stream) and dest == outer_map_exit:
                return True

        return False

    @staticmethod
    def match_to_str(graph, candidate):
        tasklet = candidate[StreamTransient._tasklet]
        map_exit = candidate[StreamTransient._map_exit]
        outer_map_exit = candidate[StreamTransient._outer_map_exit]

        return ' -> '.join(
            str(node) for node in [tasklet, map_exit, outer_map_exit])

    def apply(self, sdfg: SDFG):
        graph = sdfg.nodes()[self.state_id]
        tasklet = graph.nodes()[self.subgraph[StreamTransient._tasklet]]
        map_exit = graph.nodes()[self.subgraph[StreamTransient._map_exit]]
        outer_map_exit = graph.nodes()[self.subgraph[
            StreamTransient._outer_map_exit]]
        memlet = None
        edge = None
        for e in graph.out_edges(map_exit):
            memlet = e.data
            # TODO: What if there's more than one?
            if e.dst == outer_map_exit and isinstance(sdfg.arrays[memlet.data],
                                                      data.Stream):
                edge = e
                break
        tasklet_memlet = None
        for e in graph.out_edges(tasklet):
            tasklet_memlet = e.data
            if tasklet_memlet.data == memlet.data:
                break

        bbox = map_exit.map.range.bounding_box_size()
        bbox_approx = [symbolic.overapproximate(dim) for dim in bbox]
        dataname = memlet.data

        # Create the new node: Temporary stream and an access node
        newname, _ = sdfg.add_stream('trans_' + dataname,
                                     sdfg.arrays[memlet.data].dtype,
                                     1,
                                     bbox_approx[0], [1],
                                     transient=True,
                                     find_new_name=True)
        snode = graph.add_access(newname)

        to_stream_mm = copy.deepcopy(memlet)
        to_stream_mm.data = snode.data
        tasklet_memlet.data = snode.data

        if self.with_buffer:
            newname_arr, _ = sdfg.add_transient('strans_' + dataname,
                                                [bbox_approx[0]],
                                                sdfg.arrays[memlet.data].dtype,
                                                find_new_name=True)
            anode = graph.add_access(newname_arr)
            to_array_mm = copy.deepcopy(memlet)
            to_array_mm.data = anode.data
            graph.add_edge(snode, None, anode, None, to_array_mm)
        else:
            anode = snode

        # Reconnect, assuming one edge to the stream
        graph.remove_edge(edge)
        graph.add_edge(map_exit, edge.src_conn, snode, None, to_stream_mm)
        graph.add_edge(anode, None, outer_map_exit, edge.dst_conn, memlet)

        return
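
A hedged usage sketch (the surrounding `sdfg` is assumed to already contain the tasklet-to-stream pattern matched above; only the public transformation API is used):

from dace.transformation.dataflow import StreamTransient

# Applies to the first match; with_buffer inserts the accumulation transient.
applied = sdfg.apply_transformations(StreamTransient,
                                     options={'with_buffer': True})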
Example #5
    def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
        map_entry = self.map_entry
        map_exit = graph.exit_node(map_entry)

        sz = dace.symbol('commsize',
                         dtype=dace.int32,
                         integer=True,
                         positive=True)
        Px = dace.symbol('Px', dtype=dace.int32, integer=True, positive=True)
        Py = dace.symbol('Py', dtype=dace.int32, integer=True, positive=True)

        from dace.data import _prod

        # NOTE: Maps with a non-unit step in their ranges are currently not supported
        if len(map_entry.map.params) == 2:
            params = map_entry.map.params
            ranges = [None] * 2
            b, e, _ = map_entry.map.range[0]
            ranges[0] = (0, (e - b + 1) / Px - 1, 1)
            b, e, _ = map_entry.map.range[1]
            ranges[1] = (0, (e - b + 1) / Py - 1, 1)
            strides = [1]
        else:
            params = ['__iflat']
            sizes = map_entry.map.range.size_exact()
            total_size = _prod(sizes)
            ranges = [(0, (total_size) / sz - 1, 1)]
            strides = [_prod(sizes[i + 1:]) for i in range(len(sizes))]

        root_name = sdfg.temp_data_name()
        sdfg.add_scalar(root_name, dace.int32, transient=True)
        root_node = graph.add_access(root_name)
        root_tasklet = graph.add_tasklet('_set_root_', {}, {'__out'},
                                         '__out = 0')
        graph.add_edge(root_tasklet, '__out', root_node, None,
                       dace.Memlet.simple(root_name, '0'))

        from dace.libraries.mpi import Bcast
        from dace.libraries.pblas import BlockCyclicScatter, BlockCyclicGather

        inputs = set()
        for src, _, _, _, m in graph.in_edges(map_entry):
            if not isinstance(src, nodes.AccessNode):
                raise NotImplementedError
            desc = src.desc(sdfg)
            if not isinstance(desc, (data.Scalar, data.Array)):
                raise NotImplementedError
            if list(desc.shape) != m.src_subset.size_exact():
                # Second attempt
                # TODO: We need a solution for symbols not matching
                if str(list(desc.shape)) != str(m.src_subset.size_exact()):
                    raise NotImplementedError
            inputs.add(src)

        for inp in inputs:
            desc = inp.desc(sdfg)

            if isinstance(desc, data.Scalar):
                local_access = graph.add_access(inp.data)
                bcast_node = Bcast('_Bcast_')
                graph.add_edge(inp, None, bcast_node, '_inbuffer',
                               dace.Memlet.from_array(inp.data, desc))
                graph.add_edge(root_node, None, bcast_node, '_root',
                               dace.Memlet.simple(root_name, '0'))
                graph.add_edge(bcast_node, '_outbuffer', local_access, None,
                               dace.Memlet.from_array(inp.data, desc))
                for e in graph.edges_between(inp, map_entry):
                    graph.add_edge(local_access, None, map_entry, e.dst_conn,
                                   dace.Memlet.from_array(inp.data, desc))
                    graph.remove_edge(e)

            elif isinstance(desc, data.Array):

                local_name, local_arr = sdfg.add_temp_transient(
                    [(desc.shape[0]) // Px, (desc.shape[1]) // Py],
                    dtype=desc.dtype,
                    storage=desc.storage)
                local_access = graph.add_access(local_name)
                bsizes_name, bsizes_arr = sdfg.add_temp_transient(
                    (2, ), dtype=dace.int32)
                bsizes_access = graph.add_access(bsizes_name)
                bsizes_tasklet = nodes.Tasklet(
                    '_set_bsizes_', {}, {'__out'},
                    "__out[0] = {x}; __out[1] = {y}".format(
                        x=(desc.shape[0]) // Px, y=(desc.shape[1]) // Py))
                graph.add_edge(bsizes_tasklet, '__out', bsizes_access, None,
                               dace.Memlet.from_array(bsizes_name, bsizes_arr))
                gdesc_name, gdesc_arr = sdfg.add_temp_transient(
                    (9, ), dtype=dace.int32)
                gdesc_access = graph.add_access(gdesc_name)
                ldesc_name, ldesc_arr = sdfg.add_temp_transient(
                    (9, ), dtype=dace.int32)
                ldesc_access = graph.add_access(ldesc_name)
                scatter_node = BlockCyclicScatter('_Scatter_')
                graph.add_edge(inp, None, scatter_node, '_inbuffer',
                               dace.Memlet.from_array(inp.data, desc))
                graph.add_edge(bsizes_access, None, scatter_node,
                               '_block_sizes',
                               dace.Memlet.from_array(bsizes_name, bsizes_arr))
                graph.add_edge(scatter_node, '_outbuffer', local_access, None,
                               dace.Memlet.from_array(local_name, local_arr))
                graph.add_edge(scatter_node, '_gdescriptor', gdesc_access,
                               None,
                               dace.Memlet.from_array(gdesc_name, gdesc_arr))
                graph.add_edge(scatter_node, '_ldescriptor', ldesc_access,
                               None,
                               dace.Memlet.from_array(ldesc_name, ldesc_arr))
                for e in graph.edges_between(inp, map_entry):
                    graph.add_edge(
                        local_access, None, map_entry, e.dst_conn,
                        dace.Memlet.from_array(local_name, local_arr))
                    graph.remove_edge(e)
                for e in graph.out_edges(map_entry):
                    if e.data.data == inp.data:
                        e.data.data = local_name

            else:
                raise NotImplementedError

        outputs = set()
        for _, _, dst, _, m in graph.out_edges(map_exit):
            if not isinstance(dst, nodes.AccessNode):
                raise NotImplementedError
            desc = dst.desc(sdfg)
            if not isinstance(desc, data.Array):
                raise NotImplementedError
            try:
                if list(desc.shape) != m.dst_subset.size_exact():
                    # Second attempt
                    # TODO: We need a solution for symbols not matching
                    if str(list(desc.shape)) != str(m.dst_subset.size_exact()):
                        raise NotImplementedError
            except AttributeError:
                if list(desc.shape) != m.subset.size_exact():
                    # Second attempt
                    # TODO: We need a solution for symbols not matching
                    if str(list(desc.shape)) != str(m.subset.size_exact()):
                        raise NotImplementedError
            outputs.add(dst)

        for out in outputs:
            desc = out.desc(sdfg)
            if isinstance(desc, data.Scalar):
                raise NotImplementedError
            elif isinstance(desc, data.Array):
                local_name, local_arr = sdfg.add_temp_transient(
                    [(desc.shape[0]) // Px, (desc.shape[1]) // Py],
                    dtype=desc.dtype,
                    storage=desc.storage)
                local_access = graph.add_access(local_name)
                bsizes_name, bsizes_arr = sdfg.add_temp_transient(
                    (2, ), dtype=dace.int32)
                bsizes_access = graph.add_access(bsizes_name)
                bsizes_tasklet = nodes.Tasklet(
                    '_set_bsizes_', {}, {'__out'},
                    "__out[0] = {x}; __out[1] = {y}".format(
                        x=(desc.shape[0]) // Px, y=(desc.shape[1]) // Py))
                graph.add_edge(bsizes_tasklet, '__out', bsizes_access, None,
                               dace.Memlet.from_array(bsizes_name, bsizes_arr))
                gather_node = BlockCyclicGather('_Gather_')
                graph.add_edge(local_access, None, gather_node, '_inbuffer',
                               dace.Memlet.from_array(local_name, local_arr))
                graph.add_edge(bsizes_access, None, gather_node,
                               '_block_sizes',
                               dace.Memlet.from_array(bsizes_name, bsizes_arr))
                graph.add_edge(gather_node, '_outbuffer', out, None,
                               dace.Memlet.from_array(out.data, desc))

                for e in graph.edges_between(map_exit, out):
                    graph.add_edge(
                        map_exit, e.src_conn, local_access, None,
                        dace.Memlet.from_array(local_name, local_arr))
                    graph.remove_edge(e)
                for e in graph.in_edges(map_exit):
                    if e.data.data == out.data:
                        e.data.data = local_name
            else:
                raise NotImplementedError

        map_entry.map.params = params
        map_entry.map.range = subsets.Range(ranges)
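
The single-parameter branch above flattens an N-D map into one `__iflat` index; the stride computation can be checked in plain Python (illustrative sizes, with a local stand-in for dace.data._prod):

from functools import reduce
from operator import mul

def _prod(seq):
    return reduce(mul, seq, 1)

sizes = [4, 3, 5]                      # hypothetical 3-D map sizes
total_size = _prod(sizes)              # 60 iterations in the flattened map
strides = [_prod(sizes[i + 1:]) for i in range(len(sizes))]
assert strides == [15, 5, 1]           # row-major strides to unflatten __iflat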
Example #6
class OnTheFlyMapFusion(Transformation):
    _first_map_entry = nodes.MapEntry(nodes.Map('', [], []))
    _first_tasklet = nodes.Tasklet('')
    _first_map_exit = nodes.MapExit(nodes.Map('', [], []))
    _array_access = nodes.AccessNode('')
    _second_map_entry = nodes.MapEntry(nodes.Map('', [], []))
    _second_tasklet = nodes.Tasklet('')

    @staticmethod
    def expressions():
        return [
            sdutils.node_path_graph(OnTheFlyMapFusion._first_map_entry,
                                    OnTheFlyMapFusion._first_tasklet,
                                    OnTheFlyMapFusion._first_map_exit,
                                    OnTheFlyMapFusion._array_access,
                                    OnTheFlyMapFusion._second_map_entry,
                                    OnTheFlyMapFusion._second_tasklet)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        first_map_entry = graph.node(
            candidate[OnTheFlyMapFusion._first_map_entry])
        first_tasklet = graph.node(candidate[OnTheFlyMapFusion._first_tasklet])
        first_map_exit = graph.node(
            candidate[OnTheFlyMapFusion._first_map_exit])
        array_access = graph.node(candidate[OnTheFlyMapFusion._array_access])

        if len(first_map_exit.in_connectors) != 1:
            return False

        if (graph.in_degree(array_access) != 1
                or graph.out_degree(array_access) != 1):
            return False
        return True

    @staticmethod
    def _memlet_offsets(base_memlet, offset_memlet):
        """ Compute subset offset of `offset_memlet` relative to `base_memlet`.
        """
        def offset(base_range, offset_range):
            b0, e0, s0 = base_range
            b1, e1, s1 = offset_range
            assert e1 - e0 == b1 - b0 and s0 == s1
            return int(e1 - e0)

        return tuple(
            offset(b, o) for b, o in zip(base_memlet.subset.ranges,
                                         offset_memlet.subset.ranges))

    @staticmethod
    def _update_map_connectors(state, array_access, first_map_entry,
                               second_map_entry):
        """ Remove unused connector (of the to-be-replaced array) from second
            map entry, add new connectors to second map entry for the inputs
            used in the first map’s tasklets.
        """
        # Remove edges and connectors from the array access to the second map entry
        for edge in state.edges_between(array_access, second_map_entry):
            state.remove_edge_and_connectors(edge)
        state.remove_node(array_access)

        # Add new connectors to second map
        # TODO: implement for the general case with random naming
        for edge in state.in_edges(first_map_entry):
            if second_map_entry.add_in_connector(edge.dst_conn):
                state.add_edge(edge.src, edge.src_conn, second_map_entry,
                               edge.dst_conn, edge.data)

    @staticmethod
    def _read_offsets(state, array_name, first_map_exit, second_map_entry):
        """ Compute offsets of read accesses in second map.
        """
        # Get output memlet of first tasklet
        output_edges = state.in_edges(first_map_exit)
        assert len(output_edges) == 1
        write_memlet = output_edges[0].data

        # Find read offsets by looping over second map entry connectors
        offsets = defaultdict(list)
        for edge in state.out_edges(second_map_entry):
            if edge.data.data == array_name:
                second_map_entry.remove_out_connector(edge.src_conn)
                state.remove_edge(edge)
                offset = OnTheFlyMapFusion._memlet_offsets(
                    write_memlet, edge.data)
                offsets[offset].append(edge)

        return offsets

    @staticmethod
    def _copy_first_map_contents(state, first_map_entry, first_map_exit):
        nodes = list(
            state.all_nodes_between(first_map_entry, first_map_exit) -
            {first_map_entry})
        new_nodes = [copy.deepcopy(node) for node in nodes]
        for node in new_nodes:
            state.add_node(node)
        id_map = {
            state.node_id(old): state.node_id(new)
            for old, new in zip(nodes, new_nodes)
        }

        def map(node):
            return state.node(id_map[state.node_id(node)])

        for edge in state.edges():
            if edge.src in nodes or edge.dst in nodes:
                src = map(edge.src) if edge.src in nodes else edge.src
                dst = map(edge.dst) if edge.dst in nodes else edge.dst
                state.add_edge(src, edge.src_conn, dst, edge.dst_conn,
                               copy.deepcopy(edge.data))

        return new_nodes

    def _replicate_first_map(self, sdfg, array_access, first_map_entry,
                             first_map_exit, second_map_entry):
        """ Replicate tasklet of first map for each read access in second map.
        """
        state = sdfg.node(self.state_id)
        array_name = array_access.data
        array = sdfg.arrays[array_name]

        read_offsets = self._read_offsets(state, array_name, first_map_exit,
                                          second_map_entry)

        # Replicate first map tasklets once for each read offset access and
        # connect them to other tasklets accordingly
        for offset, edges in read_offsets.items():
            nodes = self._copy_first_map_contents(state, first_map_entry,
                                                  first_map_exit)
            tmp_name = sdfg.temp_data_name()
            sdfg.add_scalar(tmp_name, array.dtype, transient=True)
            tmp_access = state.add_access(tmp_name)

            for node in nodes:
                for edge in state.edges_between(node, first_map_exit):
                    state.add_edge(edge.src, edge.src_conn, tmp_access, None,
                                   dace.Memlet(tmp_name))
                    state.remove_edge(edge)

                for edge in state.edges_between(first_map_entry, node):
                    memlet = copy.deepcopy(edge.data)
                    memlet.subset.offset(list(offset), negative=False)
                    second_map_entry.add_out_connector(edge.src_conn)
                    state.add_edge(second_map_entry, edge.src_conn, node,
                                   edge.dst_conn, memlet)
                    state.remove_edge(edge)

            for edge in edges:
                state.add_edge(tmp_access, None, edge.dst, edge.dst_conn,
                               dace.Memlet(tmp_name))

    def apply(self, sdfg: dace.SDFG):
        state = sdfg.node(self.state_id)
        first_map_entry = state.node(self.subgraph[self._first_map_entry])
        first_tasklet = state.node(self.subgraph[self._first_tasklet])
        first_map_exit = state.node(self.subgraph[self._first_map_exit])
        array_access = state.node(self.subgraph[self._array_access])
        second_map_entry = state.node(self.subgraph[self._second_map_entry])

        self._update_map_connectors(state, array_access, first_map_entry,
                                    second_map_entry)

        self._replicate_first_map(sdfg, array_access, first_map_entry,
                                  first_map_exit, second_map_entry)

        state.remove_nodes_from(
            state.all_nodes_between(first_map_entry, first_map_exit)
            | {first_map_exit})
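
Usage follows the same pattern as the other transformations (hedged sketch; `sdfg` and the import location of OnTheFlyMapFusion are assumptions, since the class above is not part of mainline DaCe):

# from <module defining OnTheFlyMapFusion> import OnTheFlyMapFusion  # path assumed
# `sdfg` must contain two maps linked through a single-degree transient
# access node, as required by expressions() and can_be_applied() above.
applied = sdfg.apply_transformations(OnTheFlyMapFusion)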
Example #7
    def expansion(node, state: SDFGState, sdfg: SDFG):
        # Extract input and output array views (as generated by memlets)
        inputs, outputs = _get_inputs_and_outputs(sdfg, state, node)

        unique_id = "{}_{}_{}_{}".format(clean_onnx_name(node.name),
                                         sdfg.sdfg_id, sdfg.node_id(state),
                                         state.node_id(node))
        _add_ort_init_code(sdfg)

        sdfg.append_global_code(
            "OrtExecutableKernel *__ort_kernel_{};\n".format(unique_id))
        sdfg.append_global_code(
            "OrtExecutableKernelContext *__ort_context_{};\n".format(
                unique_id))

        sdfg.append_init_code("""
        {{
        // Setup for {name}
        __ort_check_status(__ort_api->CreateExecutableKernelContext("{name}", "{op_type}", &__ort_context_{name}));
        """.format(name=unique_id, op_type=node.schema.name))

        # check if ORT supports CUDA for this node
        ##########################################

        # Default: all parameters are on the CPU if we execute on the CPU
        outputs_on_host = [True for _ in range(len(outputs))]
        inputs_on_host = [True for _ in range(len(inputs))]

        actual_node_schedule = node.schedule
        if node.schedule == ScheduleType.CPU_Multicore or node.schedule == ScheduleType.Default:
            provider_index = 0
        elif node.schedule == ScheduleType.GPU_Device:
            provider_index = 1
            try:
                # the i-th entry indicates whether the i-th input/output is in host memory
                inputs_on_host, outputs_on_host = check_op(sdfg,
                                                           state,
                                                           node,
                                                           cuda=True)

            except ONNXOpValidationError as e:
                # fallback to CPU
                print("Falling back to CPU for node {}. Reason:\n{}".format(
                    node.name, str(e)))
                provider_index = 0
                actual_node_schedule = ScheduleType.Default
        else:
            raise NotImplementedError(
                "ORT expansion for schedule '{}' is not implemented".format(
                    node.schedule))

        # check if we need to insert device copies
        ##########################################

        # Maps each connector that requires a copy to the storage type that must be connected to the tasklet
        input_copy_required = defaultdict(dict)
        output_copy_required = defaultdict(dict)

        assert len(
            node.iter_outputs_in_onnx_order(state)) == len(outputs_on_host)
        assert len(
            node.iter_inputs_in_onnx_order(state)) == len(inputs_on_host)

        # check outputs
        for edge, output_on_host in zip(node.iter_outputs_in_onnx_order(state),
                                        outputs_on_host):
            # get the memlet for this output
            array = sdfg.arrays[edge.data.data]

            if output_on_host:
                is_device_mismatch = not can_access(ScheduleType.Default,
                                                    array.storage)
            else:
                is_device_mismatch = not can_access(ScheduleType.GPU_Device,
                                                    array.storage)

            if isinstance(
                    array, dt.Scalar
            ) and actual_node_schedule == ScheduleType.GPU_Device:
                # ORT kernels expect scalars to be cudaMalloced. We will copy during expansion to enforce this
                is_device_mismatch = True
                output_copy_required[edge.src_conn]['copy_to_array'] = True

            if is_device_mismatch:
                # we need to insert a copy
                output_copy_required[edge.src_conn][
                    'storage'] = StorageType.Default if output_on_host else StorageType.GPU_Global

        # check inputs (same thing again)
        for edge, input_on_host in zip(node.iter_inputs_in_onnx_order(state),
                                       inputs_on_host):
            array = sdfg.arrays[edge.data.data]

            if input_on_host:
                is_device_mismatch = not can_access(ScheduleType.Default,
                                                    array.storage)
            else:
                is_device_mismatch = not can_access(ScheduleType.GPU_Device,
                                                    array.storage)

            if isinstance(
                    array, dt.Scalar
            ) and actual_node_schedule == ScheduleType.GPU_Device:
                # ORT kernels expect scalars to be cudaMalloced. We will copy during expansion to enforce this
                is_device_mismatch = True
                input_copy_required[edge.dst_conn]['copy_to_array'] = True

            if is_device_mismatch:
                # we need to insert a copy
                input_copy_required[edge.dst_conn][
                    'storage'] = StorageType.Default if input_on_host else StorageType.GPU_Global

        # begin codegen
        ##########################################
        tasklet_setup_code = ""
        tasklet_code = ""
        tasklet_cleanup_code = ""

        reversed_onnx_dtype_map = {
            v: k
            for k, v in ONNX_DTYPES_TO_DACE_TYPE_CLASS.items()
        }

        # emit code for inputs and outputs
        ##########################################
        in_connectors = {}
        out_connectors = {}

        for edge, is_input in node.iter_edges(state):

            parameter_name = edge.dst_conn if is_input else edge.src_conn

            if len(output_copy_required) != 0 or len(input_copy_required) != 0:
                edge_connector_name = "_conn_" + parameter_name
            else:
                edge_connector_name = parameter_name

            input_output_string = "input" if is_input else "output"
            connector_dict = in_connectors if is_input else out_connectors
            memlet = edge.data
            desc = sdfg.arrays[memlet.data]
            sdfg.append_init_code("""
            // Add parameter {parameter_name}
            __ort_check_status(__ort_api->ExecutableKernelContext_Add{input_output_string}(__ort_context_{id}, ONNX_TENSOR_ELEMENT_DATA_TYPE_{type_string}));
            """.format(id=unique_id,
                       type_string=reversed_onnx_dtype_map[desc.dtype].upper(),
                       parameter_name=parameter_name,
                       input_output_string=input_output_string.capitalize()))

            ort_value_name = "ort_value_{input_output_string}_{parameter_name}".format(
                input_output_string=input_output_string,
                parameter_name=parameter_name)

            copy_to_array = (
                (parameter_name in output_copy_required
                 and 'copy_to_array' in output_copy_required[parameter_name])
                or
                (parameter_name in input_copy_required
                 and 'copy_to_array' in input_copy_required[parameter_name]))
            if desc.storage == StorageType.Default:
                mem_info = "__ort_cpu_mem_info"
            elif desc.storage == StorageType.GPU_Global:
                mem_info = "__ort_cuda_mem_info"
            elif desc.storage == StorageType.CPU_Pinned:
                mem_info = "__ort_cuda_pinned_mem_info"
            else:
                raise ValueError(
                    "Unsupported storage type {} for input to ONNX node".
                    format(desc.storage))
            if (isinstance(desc, dt.Scalar) and
                    # when copying to array, the ort value is not a scalar but an array
                    not copy_to_array):

                tasklet_setup_code += """
                OrtValue* {ort_value_name};
                __ort_check_status(__ort_api->CreateTensorWithDataAsOrtValue(
                    {mem_info},
                    &{edge_connector_name},
                    {data_size} * sizeof({ctype}),
                    nullptr,
                    0,
                    ONNX_TENSOR_ELEMENT_DATA_TYPE_{type_str},
                    &{ort_value_name}
                ));
                """.format(
                    input_output_string=input_output_string,
                    mem_info=mem_info,
                    edge_connector_name=edge_connector_name,
                    data_size=reduce(lambda x, y: x * y, desc.shape),
                    ctype=desc.dtype.ctype,
                    type_str=reversed_onnx_dtype_map[desc.dtype].upper(),
                    ort_value_name=ort_value_name)
                connector_dict[parameter_name] = None

            elif isinstance(desc, dt.Array) or copy_to_array:

                # when we copy a scalar to an array, that scalar of course has shape []
                dims = [] if copy_to_array else desc.shape

                # setup dims array
                tasklet_setup_code += """
                int64_t {input_output_string}_{parameter_name}_dims[{dims_size}] = {{{dims}}};
                """.format(input_output_string=input_output_string,
                           parameter_name=parameter_name,
                           dims_size=len(dims),
                           dims=", ".join(str(s) for s in dims))

                connector_dict[parameter_name] = dace.pointer(desc.dtype)
                data = "const_cast < void * > (reinterpret_cast < const void * > ({}))".format(
                    edge_connector_name)

                tasklet_setup_code += """
                OrtValue* {ort_value_name};
                __ort_check_status(__ort_api->CreateTensorWithDataAsOrtValue(
                    {mem_info},
                    {data},
                    {data_size} * sizeof({ctype}),
                    {input_output_string}_{parameter_name}_dims,
                    {dims_size},
                    ONNX_TENSOR_ELEMENT_DATA_TYPE_{type_str},
                    &{ort_value_name}
                ));
                """.format(
                    input_output_string=input_output_string,
                    data=data,
                    mem_info=mem_info,
                    parameter_name=parameter_name,
                    data_size=reduce(lambda x, y: x * y, desc.shape),
                    ctype=desc.dtype.ctype,
                    dims_size=len(dims),
                    type_str=reversed_onnx_dtype_map[desc.dtype].upper(),
                    ort_value_name=ort_value_name)
            else:
                raise NotImplementedError(
                    "Data-descriptor type {} not supported for ONNX nodes".
                    format(type(desc)))


            tasklet_code += "__ort_check_status(__ort_api->ExecutableKernel_Set{input_output_string_capital}(" \
                            "__ort_kernel_{unique_id}, {position}, {ort_value_name}));\n".format(
                input_output_string_capital=input_output_string.
                    capitalize(),
                ort_value_name=ort_value_name,
                unique_id=unique_id,
                position=get_position(node.schema, is_input,
                                      parameter_name))

            tasklet_cleanup_code += "__ort_api->ReleaseValue(ort_value_{input_output_string}_{parameter_name});\n".format(
                input_output_string=input_output_string,
                parameter_name=parameter_name)

        sdfg.append_init_code("// Setup attributes\n")

        for name, attr in node.schema.attributes.items():
            if hasattr(node, name):
                sdfg.append_init_code(
                    _gen_attr_init_code("__ort_context_{}".format(unique_id),
                                        node.schema.attributes[name],
                                        getattr(node, name)))

        sdfg.prepend_exit_code(
            "__ort_api->ReleaseExecutableKernelContext(__ort_context_{});\n".
            format(unique_id))
        sdfg.prepend_exit_code(
            "__ort_api->ReleaseExecutableKernel(__ort_kernel_{});\n".format(
                unique_id))

        tasklet_code += 'fprintf(stderr, "Launching {}\\n");\n'.format(
            unique_id)
        tasklet_code += "__ort_check_status(__ort_api->ExecutableKernel_Compute(__ort_kernel_{}));\n".format(
            unique_id)

        sdfg.append_init_code(
            "__ort_check_status(__ort_api->CreateExecutableKernel("
            "__ort_session, __ort_context_{id}, /*provider_index=*/{provider_index}, &__ort_kernel_{id}));\n"
            .format(provider_index=provider_index, id=unique_id))
        sdfg.append_init_code(
            "}} // end setup for context_{}".format(unique_id))

        tasklet_code = tasklet_setup_code + tasklet_code + tasklet_cleanup_code
        tasklet = nd.Tasklet('onnx_code',
                             in_connectors,
                             out_connectors,
                             tasklet_code,
                             language=dace.dtypes.Language.CPP)
        tasklet.environments = {"ONNXRuntime"}

        if len(output_copy_required) != 0 or len(input_copy_required) != 0:
            nsdfg = dace.SDFG("nested_{}".format(unique_id))
            nstate = nsdfg.add_state()
            ntasklet = deepcopy(tasklet)

            # add a prefix to connectors to prevent shadowing of array names
            ntasklet.in_connectors = {
                "_conn_" + k: v
                for k, v in tasklet.in_connectors.items()
            }
            ntasklet.out_connectors = {
                "_conn_" + k: v
                for k, v in tasklet.out_connectors.items()
            }

            nstate.add_node(ntasklet)

            for edge, is_input in node.iter_edges(state):
                parameter_name = edge.dst_conn if is_input else edge.src_conn

                memlet = edge.data
                desc = sdfg.arrays[memlet.data]

                # add the original array
                original_desc = deepcopy(desc)
                original_desc.transient = False
                nsdfg.add_datadesc(parameter_name, original_desc)
                if not (isinstance(desc, dt.Array)
                        or isinstance(desc, dt.Scalar)):
                    raise ValueError(
                        "Unsupported data type {} connected to an ONNX tasklet"
                        .format(type(desc)))

                if parameter_name not in (input_copy_required if is_input else
                                          output_copy_required):
                    if is_input:
                        access = nstate.add_read(parameter_name)
                        nstate.add_edge(access, None, ntasklet,
                                        "_conn_" + parameter_name,
                                        nsdfg.get_array_memlet(parameter_name))
                    else:
                        access = nstate.add_write(parameter_name)
                        nstate.add_edge(ntasklet, "_conn_" + parameter_name,
                                        access, None,
                                        nsdfg.get_array_memlet(parameter_name))
                    continue

                copy_options = input_copy_required[
                    parameter_name] if is_input else output_copy_required[
                        parameter_name]

                # add the copy of the descriptor
                if 'copy_to_array' in copy_options:
                    copy_desc = dt.Array(shape=[1], dtype=desc.dtype)
                else:
                    copy_desc = deepcopy(desc)

                copy_desc.transient = True
                copy_desc.storage = copy_options['storage']
                nsdfg.add_datadesc("copy_" + memlet.data, copy_desc)

                nmemlet = deepcopy(memlet)
                nmemlet.data = "copy_" + nmemlet.data
                if is_input:
                    access = nstate.add_read(parameter_name)
                    access_copy = nstate.add_access("copy_" + memlet.data)
                    nstate.add_edge(
                        access, None, access_copy, None,
                        nsdfg.get_array_memlet("copy_" + memlet.data))
                    nstate.add_edge(access_copy, None, ntasklet,
                                    "_conn_" + parameter_name, nmemlet)
                else:
                    access = nstate.add_write(parameter_name)
                    access_copy = nstate.add_access("copy_" + memlet.data)
                    nstate.add_edge(ntasklet, "_conn_" + parameter_name,
                                    access_copy, None, nmemlet)
                    nstate.add_edge(
                        access_copy, None, access, None,
                        nsdfg.get_array_memlet("copy_" + memlet.data))

            return nsdfg

        else:
            return tasklet
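
A stripped-down sketch of the staging pattern implemented above, where a transient "copy_" buffer sits between an outer array and a prefixed tasklet connector (all names, shapes, and the storage choice here are hypothetical; only core DaCe APIs are used, and this is a structural illustration rather than a runnable GPU kernel):

import dace

nsdfg = dace.SDFG('copy_wrapper')
nsdfg.add_array('X', [16], dace.float32)
nsdfg.add_array('Y', [16], dace.float32)
nsdfg.add_transient('copy_X', [16], dace.float32,
                    storage=dace.StorageType.GPU_Global)
nstate = nsdfg.add_state()

read = nstate.add_read('X')
staged = nstate.add_access('copy_X')  # device-side staging buffer
write = nstate.add_write('Y')
tasklet = nstate.add_tasklet('work', {'_conn_X'}, {'_conn_Y'},
                             '_conn_Y[0] = _conn_X[0];',
                             language=dace.Language.CPP)

# Host array -> staging copy -> tasklet, mirroring input_copy_required above
nstate.add_edge(read, None, staged, None, dace.Memlet('X[0:16]'))
nstate.add_edge(staged, None, tasklet, '_conn_X', dace.Memlet('copy_X[0:16]'))
nstate.add_edge(tasklet, '_conn_Y', write, None, dace.Memlet('Y[0:16]'))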
Example #8
class MapReduceFusion(pm.Transformation):
    """ Implements the map-reduce-fusion transformation.
        Fuses a map with an immediately following reduction, where the array
        between the map and the reduction is not used anywhere else.
    """

    no_init = Property(dtype=bool,
                       default=False,
                       desc='If enabled, does not create initialization states '
                       'for reduce nodes with identity')

    _tasklet = nodes.Tasklet('_')
    _tmap_exit = nodes.MapExit(nodes.Map("", [], []))
    _in_array = nodes.AccessNode('_')

    import dace.libraries.standard as stdlib  # Avoid import loop
    _reduce = stdlib.Reduce()

    _out_array = nodes.AccessNode('_')

    @staticmethod
    def expressions():
        return [
            sdutil.node_path_graph(MapReduceFusion._tasklet,
                                   MapReduceFusion._tmap_exit,
                                   MapReduceFusion._in_array,
                                   MapReduceFusion._reduce,
                                   MapReduceFusion._out_array)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, permissive=False):
        tmap_exit = graph.nodes()[candidate[MapReduceFusion._tmap_exit]]
        in_array = graph.nodes()[candidate[MapReduceFusion._in_array]]
        reduce_node = graph.nodes()[candidate[MapReduceFusion._reduce]]
        tasklet = graph.nodes()[candidate[MapReduceFusion._tasklet]]

        # Make sure that the array is only accessed by the map and the reduce
        if any(src != tmap_exit
               for src, _, _, _, _ in graph.in_edges(in_array)):
            return False
        if any(dest != reduce_node
               for _, _, dest, _, _ in graph.out_edges(in_array)):
            return False

        tmem = next(e for e in graph.edges_between(tasklet, tmap_exit)
                    if e.data.data == in_array.data).data

        # Make sure that the transient is not accessed anywhere else
        # in this state or other states
        if not permissive and (len([
                n for n in graph.nodes()
                if isinstance(n, nodes.AccessNode) and n.data == in_array.data
        ]) > 1 or in_array.data in sdfg.shared_transients()):
            return False

        # If memlet already has WCR and it is different from reduce node,
        # do not match
        if tmem.wcr is not None and tmem.wcr != reduce_node.wcr:
            return False

        # Verify that reduction ranges match tasklet map
        tout_memlet = graph.in_edges(in_array)[0].data
        rin_memlet = graph.out_edges(in_array)[0].data
        if tout_memlet.subset != rin_memlet.subset:
            return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        tasklet = candidate[MapReduceFusion._tasklet]
        map_exit = candidate[MapReduceFusion._tmap_exit]
        reduce = candidate[MapReduceFusion._reduce]

        return ' -> '.join(str(node) for node in [tasklet, map_exit, reduce])

    def apply(self, sdfg: SDFG):
        graph = sdfg.nodes()[self.state_id]
        tmap_exit = graph.nodes()[self.subgraph[MapReduceFusion._tmap_exit]]
        in_array = graph.nodes()[self.subgraph[MapReduceFusion._in_array]]
        reduce_node = graph.nodes()[self.subgraph[MapReduceFusion._reduce]]
        out_array = graph.nodes()[self.subgraph[MapReduceFusion._out_array]]

        # Nodes to remove: the intermediate array and the reduce node
        nodes_to_remove = [in_array, reduce_node]

        memlet_edge = None
        for edge in graph.in_edges(tmap_exit):
            if edge.data.data == in_array.data:
                memlet_edge = edge
                break
        if memlet_edge is None:
            raise RuntimeError('Could not find a memlet for the reduced '
                               'array among the map exit input edges')

        # Find which indices should be removed from new memlet
        input_edge = graph.in_edges(reduce_node)[0]
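        # reduce_node.axes being None means the reduction spans all dimensions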
        axes = reduce_node.axes or list(range(len(input_edge.data.subset)))
        array_edge = graph.out_edges(reduce_node)[0]

        # Delete relevant edges and nodes
        graph.remove_nodes_from(nodes_to_remove)

        # Delete relevant data descriptors
        for node in set(nodes_to_remove):
            if isinstance(node, nodes.AccessNode):
                # try to delete it
                try:
                    sdfg.remove_data(node.data)
                # will raise ValueError if the datadesc is used somewhere else
                except ValueError:
                    pass

        # Filter out reduced dimensions from subset
        filtered_subset = [
            dim for i, dim in enumerate(memlet_edge.data.subset)
            if i not in axes
        ]
        if len(filtered_subset) == 0:  # Output is a scalar
            filtered_subset = [(0, 0, 1)]

        # Modify edge from tasklet to map exit
        memlet_edge.data.data = out_array.data
        memlet_edge.data.wcr = reduce_node.wcr
        memlet_edge.data.subset = type(memlet_edge.data.subset)(filtered_subset)

        # Add edge from map exit to output array
        graph.add_edge(
            memlet_edge.dst, 'OUT_' + memlet_edge.dst_conn[3:], array_edge.dst,
            array_edge.dst_conn,
            Memlet.simple(array_edge.data.data,
                          array_edge.data.subset,
                          num_accesses=array_edge.data.num_accesses,
                          wcr_str=reduce_node.wcr))

        # Add initialization state as necessary
        if not self.no_init and reduce_node.identity is not None:
            init_state = sdfg.add_state_before(graph)
            init_state.add_mapped_tasklet(
                'freduce_init',
                [('o%d' % i, '%s:%s:%s' % (r[0], r[1] + 1, r[2]))
                 for i, r in enumerate(array_edge.data.subset)], {},
                'out = %s' % reduce_node.identity, {
                    'out':
                    Memlet.simple(
                        array_edge.data.data, ','.join([
                            'o%d' % i
                            for i in range(len(array_edge.data.subset))
                        ]))
                },
                external_edges=True)
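A hedged usage sketch for the transformation above, assuming it is registered under dace.transformation.dataflow as in the DaCe distribution:

from dace.transformation.dataflow import MapReduceFusion

# Apply wherever the map -> transient -> reduce pattern matches;
# 'no_init' mirrors the property defined on the class above
sdfg.apply_transformations(MapReduceFusion, options={'no_init': False})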
Example #9
class MapWCRFusion(pm.Transformation):
    """ Implements the map expanded-reduce fusion transformation.
        Fuses a map with an immediately following reduction, where the array
        between the map and the reduction is not used anywhere else, and the
        reduction is divided to two maps with a WCR, denoting partial reduction.
    """

    _tasklet = nodes.Tasklet('_')
    _tmap_exit = nodes.MapExit(nodes.Map("", [], []))
    _in_array = nodes.AccessNode('_')
    _rmap_in_entry = nodes.MapEntry(nodes.Map("", [], []))
    _rmap_in_tasklet = nodes.Tasklet('_')
    _rmap_in_cr = nodes.MapExit(nodes.Map("", [], []))
    _rmap_out_entry = nodes.MapEntry(nodes.Map("", [], []))
    _rmap_out_exit = nodes.MapExit(nodes.Map("", [], []))
    _out_array = nodes.AccessNode('_')

    @staticmethod
    def expressions():
        return [
            # Map, then partial reduction of axes
            sdutil.node_path_graph(
                MapWCRFusion._tasklet, MapWCRFusion._tmap_exit,
                MapWCRFusion._in_array, MapWCRFusion._rmap_out_entry,
                MapWCRFusion._rmap_in_entry, MapWCRFusion._rmap_in_tasklet,
                MapWCRFusion._rmap_in_cr, MapWCRFusion._rmap_out_exit,
                MapWCRFusion._out_array)
        ]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, permissive=False):
        tmap_exit = graph.nodes()[candidate[MapWCRFusion._tmap_exit]]
        in_array = graph.nodes()[candidate[MapWCRFusion._in_array]]
        rmap_entry = graph.nodes()[candidate[MapWCRFusion._rmap_out_entry]]

        # Make sure that the array is only accessed by the map and the reduce
        if any(src != tmap_exit
               for src, _, _, _, _ in graph.in_edges(in_array)):
            return False
        if any(dest != rmap_entry
               for _, _, dest, _, _ in graph.out_edges(in_array)):
            return False

        # Make sure that there is a reduction in the second map
        rmap_cr = graph.nodes()[candidate[MapWCRFusion._rmap_in_cr]]
        reduce_edge = graph.in_edges(rmap_cr)[0]
        if reduce_edge.data.wcr is None:
            return False

        # Make sure that the transient is not accessed anywhere else
        # in this state or other states
        if not permissive and (len([
                n for n in graph.nodes()
                if isinstance(n, nodes.AccessNode) and n.data == in_array.data
        ]) > 1 or in_array.data in sdfg.shared_transients()):
            return False

        # Verify that reduction ranges match tasklet map
        tout_memlet = graph.in_edges(in_array)[0].data
        rin_memlet = graph.out_edges(in_array)[0].data
        if tout_memlet.subset != rin_memlet.subset:
            return False

        return True

    @staticmethod
    def match_to_str(graph, candidate):
        tasklet = candidate[MapWCRFusion._tasklet]
        map_exit = candidate[MapWCRFusion._tmap_exit]
        reduce = candidate[MapWCRFusion._rmap_in_cr]

        return ' -> '.join(str(node) for node in [tasklet, map_exit, reduce])

    def apply(self, sdfg):
        graph = sdfg.node(self.state_id)

        # To apply, collapse the second map and then fuse the two resulting maps
        map_collapse = MapCollapse(
            self.sdfg_id, self.state_id, {
                MapCollapse._outer_map_entry:
                self.subgraph[MapWCRFusion._rmap_out_entry],
                MapCollapse._inner_map_entry:
                self.subgraph[MapWCRFusion._rmap_in_entry]
            }, 0)
        map_entry, _ = map_collapse.apply(sdfg)

        map_fusion = MapFusion(
            self.sdfg_id, self.state_id, {
                MapFusion.first_map_exit:
                self.subgraph[MapWCRFusion._tmap_exit],
                MapFusion.second_map_entry: graph.node_id(map_entry)
            }, 0)
        map_fusion.apply(sdfg)
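Since apply() above merely chains MapCollapse and MapFusion, a typical way to use the transformation is to apply it repeatedly until no partial-reduction pattern remains (a sketch, assuming the standard registration under dace.transformation.dataflow):

from dace.transformation.dataflow import MapWCRFusion

# Returns the number of times the transformation was applied
sdfg.apply_transformations_repeated(MapWCRFusion)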
Example #10
File: reduce.py  Project: thobauma/dace
    def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG, **kwargs):

        node.validate(sdfg, state)
        for edge in state.in_edges(node):
            if edge.dst_conn == '_inbuffer':
                input_edge = edge
        for edge in state.out_edges(node):
            if edge.src_conn == '_outbuffer':
                output_edge = edge
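        # Note: the two loops above assume exactly one '_inbuffer' and one
        # '_outbuffer' edge; otherwise input_edge/output_edge stay unbound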
        input_dims = input_edge.data.subset.size_exact()
        output_dims = output_edge.data.subset.size_exact()
        input_data = sdfg.arrays[input_edge.data.data]
        output_data = sdfg.arrays[output_edge.data.data]

        # Verify that data is on the GPU
        if input_data.storage is not dtypes.StorageType.GPU_Global:
            raise ValueError('Input of NCCL Reduce must reside '
                             'in global GPU memory.')
        if output_data.storage is not dtypes.StorageType.GPU_Global:
            raise ValueError('Output of NCCL Reduce must reside '
                             'in global GPU memory.')

        root = node.root
        rootstr = str(root)
        for fs in root.free_symbols:
            if fs.name in sdfg.arrays:
                sdfg.arrays[fs.name].lifetime = dtypes.AllocationLifetime.SDFG
            if fs.name in sdfg.parent_sdfg.arrays:
                sdfg.parent_sdfg.arrays[
                    fs.name].lifetime = dtypes.AllocationLifetime.SDFG
        redtype = node.reduction_type
        redtype = nutil.NCCL_SUPPORTED_OPERATIONS[redtype]
        wcr_str = str(redtype)
        wcr_str = wcr_str[wcr_str.find('.') + 1:]  # Skip "NcclReductionType."

        nccl_dtype_str = nutil.Nccl_dtypes(input_data.dtype.base_type)
        count_str = "*".join(str(e) for e in input_dims)

        if input_data.dtype.veclen > 1:
            raise NotImplementedError(
                'NCCL Reduce does not support vector data types')

        code = f"""ncclReduce(_inbuffer, _outbuffer, {count_str}, {nccl_dtype_str}, {wcr_str}, {rootstr}, __state->ncclCommunicators->at(__dace_cuda_device),  __dace_current_stream)"""
        if Config.get('compiler', 'build_type') == 'Debug':
            code = 'DACE_NCCL_CHECK(' + code + ');\n'
        else:
            code = code + ';\n'

        if Config.get_bool('debugprint'):
            code = (
                f'''printf("{str(node)}: begin;  dev,peer: %d, %d\\n", __dace_cuda_device, {rootstr});\n'''
                + code +
                f'''printf("{str(node)}: end;  dev,peer: %d, %d\\n\\n", __dace_cuda_device, {rootstr});\n'''
            )
        code += '\ncudaStreamSynchronize(__dace_current_stream);'

        tasklet = nodes.Tasklet(node.name + "_" + wcr_str,
                                node.in_connectors,
                                node.out_connectors,
                                code,
                                location=node.location,
                                language=dtypes.Language.CPP,
                                library_expansion_symbols=set(
                                    map(str, root.free_symbols)))

        return tasklet
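A hypothetical way to trigger this expansion is to select it on a Reduce library node before compiling; the implementation name 'NCCL' below is assumed for illustration, as the registration name is not shown in this excerpt:

# reduce_node is a Reduce library node placed in some SDFG state
reduce_node.implementation = 'NCCL'  # expansion name assumed in this sketch
sdfg.expand_library_nodes()          # replaces library nodes with tasklets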