Example #1
def consolidate_edges_scope(
        state: SDFGState, scope_node: Union[nd.EntryNode, nd.ExitNode]) -> int:
    """
        Union scope-entering memlets relating to the same data node in a scope.
        This effectively reduces the number of connectors and allows more
        transformations to be performed, at the cost of losing the individual
        per-tasklet memlets.
        :param state: The SDFG state in which the scope to consolidate resides.
        :param scope_node: The scope node whose edges will be consolidated.
        :return: Number of edges removed.
    """
    if scope_node is None:
        return 0
    data_to_conn = {}
    consolidated = 0
    if isinstance(scope_node, nd.EntryNode):
        outer_edges = state.in_edges
        inner_edges = state.out_edges
        remove_outer_connector = scope_node.remove_in_connector
        remove_inner_connector = scope_node.remove_out_connector
        prefix, oprefix = 'IN_', 'OUT_'
    else:
        outer_edges = state.out_edges
        inner_edges = state.in_edges
        remove_outer_connector = scope_node.remove_out_connector
        remove_inner_connector = scope_node.remove_in_connector
        prefix, oprefix = 'OUT_', 'IN_'

    edges_by_connector = collections.defaultdict(list)
    connectors_to_remove = set()
    for e in inner_edges(scope_node):
        edges_by_connector[e.src_conn].append(e)
        if e.data.data not in data_to_conn:
            data_to_conn[e.data.data] = e.src_conn
        elif data_to_conn[e.data.data] != e.src_conn:  # Need to consolidate
            connectors_to_remove.add(e.src_conn)

    for conn in connectors_to_remove:
        e = edges_by_connector[conn][0]
        # Outer side of the scope - remove edge and union subsets
        target_conn = prefix + data_to_conn[e.data.data][len(oprefix):]
        conn_to_remove = prefix + conn[len(oprefix):]
        remove_outer_connector(conn_to_remove)
        out_edge = next(ed for ed in outer_edges(scope_node)
                        if ed.dst_conn == target_conn)
        edge_to_remove = next(ed for ed in outer_edges(scope_node)
                              if ed.dst_conn == conn_to_remove)
        out_edge.data.subset = sbs.union(out_edge.data.subset,
                                         edge_to_remove.data.subset)
        state.remove_edge(edge_to_remove)
        consolidated += 1
        # Inner side of the scope - remove and reconnect
        remove_inner_connector(e.src_conn)
        for e in edges_by_connector[conn]:
            e._src_conn = data_to_conn[e.data.data]

    return consolidated
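
# Hypothetical usage sketch (assumes an `sdfg` built elsewhere): consolidate
# the connectors of every map scope in every state and report how many edges
# were removed in total.
removed = 0
for st in sdfg.nodes():
    for n in st.nodes():
        if isinstance(n, (nd.EntryNode, nd.ExitNode)):
            removed += consolidate_edges_scope(st, n)
print(f'Removed {removed} redundant scope edges')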
Example #2
    def apply(self, graph: SDFGState, sdfg: SDFG):
        tasklet = self.tasklet
        map_exit = self.map_exit
        outer_map_exit = self.outer_map_exit
        memlet = None
        edge = None
        for e in graph.out_edges(map_exit):
            memlet = e.data
            # TODO: What if there's more than one?
            if e.dst == outer_map_exit and isinstance(sdfg.arrays[memlet.data],
                                                      data.Stream):
                edge = e
                break
        tasklet_memlet = None
        for e in graph.out_edges(tasklet):
            tasklet_memlet = e.data
            if tasklet_memlet.data == memlet.data:
                break

        bbox = map_exit.map.range.bounding_box_size()
        bbox_approx = [symbolic.overapproximate(dim) for dim in bbox]
        dataname = memlet.data

        # Create the new node: Temporary stream and an access node
        newname, _ = sdfg.add_stream('trans_' + dataname,
                                     sdfg.arrays[memlet.data].dtype,
                                     bbox_approx[0],
                                     storage=sdfg.arrays[memlet.data].storage,
                                     transient=True,
                                     find_new_name=True)
        snode = graph.add_access(newname)

        to_stream_mm = copy.deepcopy(memlet)
        to_stream_mm.data = snode.data
        tasklet_memlet.data = snode.data

        if self.with_buffer:
            newname_arr, _ = sdfg.add_transient('strans_' + dataname,
                                                [bbox_approx[0]],
                                                sdfg.arrays[memlet.data].dtype,
                                                find_new_name=True)
            anode = graph.add_access(newname_arr)
            to_array_mm = copy.deepcopy(memlet)
            to_array_mm.data = anode.data
            graph.add_edge(snode, None, anode, None, to_array_mm)
        else:
            anode = snode

        # Reconnect, assuming one edge to the stream
        graph.remove_edge(edge)
        graph.add_edge(map_exit, edge.src_conn, snode, None, to_stream_mm)
        graph.add_edge(anode, None, outer_map_exit, edge.dst_conn, memlet)

        return
Example #3
    def apply(self, graph: SDFGState, sdfg: SDFG):
        if self.expr_index == 0:
            map_entry = self.map_entry
            nsdfg_node = helpers.nest_state_subgraph(
                sdfg,
                graph,
                graph.scope_subgraph(map_entry),
                full_data=self.fullcopy)
        else:
            cnode = self.reduce
            nsdfg_node = helpers.nest_state_subgraph(sdfg,
                                                     graph,
                                                     SubgraphView(
                                                         graph, [cnode]),
                                                     full_data=self.fullcopy)

        # Avoiding import loops
        from dace.transformation.interstate import GPUTransformSDFG
        transformation = GPUTransformSDFG(sdfg, 0, -1, {}, 0)
        transformation.register_trans = self.register_trans
        transformation.sequential_innermaps = self.sequential_innermaps
        transformation.toplevel_trans = self.toplevel_trans

        transformation.apply(nsdfg_node.sdfg, nsdfg_node.sdfg)

        # Inline back as necessary
        sdfg.simplify()
Example #4
def is_parallel(state: SDFGState, node: Optional[nd.Node] = None) -> bool:
    """
    Returns True if a node or state is contained within a parallel
    section.
    :param state: The state to test.
    :param node: An optional node in the state to test. If None, only checks
                 state.
    :return: True if the state or node is located within a map scope that
             is scheduled to run in parallel, False otherwise.
    """
    if node is not None:
        sdict = state.scope_dict()
        curnode = node
        while curnode is not None:
            curnode = sdict[curnode]
            if curnode is None:  # Reached the top-level scope
                break
            if curnode.schedule != dtypes.ScheduleType.Sequential:
                return True
    if state.parent.parent is not None:
        # Find nested SDFG node and continue recursion
        nsdfg_node = next(
            n for n in state.parent.parent
            if isinstance(n, nd.NestedSDFG) and n.sdfg == state.parent)
        return is_parallel(state.parent.parent, nsdfg_node)

    return False
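
# Hypothetical usage sketch (assumes `state` comes from an SDFG built
# elsewhere): report whether the state as a whole already runs in parallel.
if is_parallel(state):
    print(f'{state} is (transitively) inside a parallel map scope')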
Example #5
    def apply(self, state: SDFGState, sdfg: SDFG):
        map_entry = self.map_entry
        current_map = map_entry.map

        # Expand the innermost map if multidimensional
        if len(current_map.params) > 1:
            ext, rem = dace.transformation.helpers.extract_map_dims(
                sdfg, map_entry, list(range(len(current_map.params) - 1)))
            map_entry = rem
            current_map = map_entry.map

        subgraph = state.scope_subgraph(map_entry)

        # Set the schedule
        current_map.schedule = dace.dtypes.ScheduleType.SVE_Map

        # Infer all connector types and apply them
        inferred = infer_types.infer_connector_types(sdfg, state, subgraph)
        infer_types.apply_connector_types(inferred)

        # Infer vector connectors and AccessNodes and apply them
        vector_inference.infer_vectors(
            sdfg,
            state,
            map_entry,
            self.vec_len,
            flags=vector_inference.VectorInferenceFlags.Allow_Stride,
            apply=True)
Example #6
def is_array_stream_view(sdfg: SDFG, dfg: SDFGState, node: nd.AccessNode):
    """ Test whether a stream is directly connected to an array. """

    # Test all memlet paths from the array. If the path goes directly
    # to/from a stream, construct a stream array view
    all_source_paths = []
    source_paths = []
    all_sink_paths = []
    sink_paths = []
    for e in dfg.in_edges(node):
        src_node = dfg.memlet_path(e)[0].src
        # Append empty path to differentiate between a copy and an array-view
        if isinstance(src_node, nd.CodeNode):
            all_source_paths.append(None)
        # Append path from source node
        if isinstance(src_node, nd.AccessNode) and isinstance(
                src_node.desc(sdfg), dt.Array):
            source_paths.append(src_node)
    for e in dfg.out_edges(node):
        sink_node = dfg.memlet_path(e)[-1].dst

        # Append empty path to differentiate between a copy and an array-view
        if isinstance(sink_node, nd.CodeNode):
            all_sink_paths.append(None)
        # Append path to sink node
        if isinstance(sink_node, nd.AccessNode) and isinstance(
                sink_node.desc(sdfg), dt.Array):
            sink_paths.append(sink_node)

    all_sink_paths.extend(sink_paths)
    all_source_paths.extend(source_paths)

    # Special case: stream can be represented as a view of an array
    if ((len(all_source_paths) > 0 and len(sink_paths) == 1)
            or (len(all_sink_paths) > 0 and len(source_paths) == 1)):
        # TODO: What about a source path?
        arrnode = sink_paths[0]
        # Only works if the stream itself is not an array of streams
        if list(node.desc(sdfg).shape) == [1]:
            node.desc(sdfg).sink = arrnode.data  # For memlet generation
            arrnode.desc(
                sdfg).src = node.data  # TODO: Move src/sink to node, not array
            return True
    return False
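
# Hypothetical usage sketch (assumes `sdfg` and `state` exist, and that `dt`
# refers to dace.data as in the code above): collect the stream access nodes
# that can be emitted as views of an adjacent array.
stream_views = [
    n for n in state.nodes()
    if isinstance(n, nd.AccessNode) and isinstance(n.desc(sdfg), dt.Stream)
    and is_array_stream_view(sdfg, state, n)
]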
Example #7
    def memlets_intersect(graph_a: SDFGState, group_a: List[nodes.AccessNode],
                          inputs_a: bool, graph_b: SDFGState,
                          group_b: List[nodes.AccessNode],
                          inputs_b: bool) -> bool:
        """
        Performs an all-pairs check for subset intersection between two
        groups of nodes. If the groups intersect or the result is
        indeterminate, returns True as a precaution.
        :param graph_a: The graph in which the first set of nodes reside.
        :param group_a: The first set of nodes to check.
        :param inputs_a: If True, checks inputs of the first group.
        :param graph_b: The graph in which the second set of nodes reside.
        :param group_b: The second set of nodes to check.
        :param inputs_b: If True, checks inputs of the second group.
        :return: True if the subsets intersect or the result is indeterminate.
        """
        # Set traversal functions
        src_subset = lambda e: (e.data.src_subset if e.data.src_subset is
                                not None else e.data.dst_subset)
        dst_subset = lambda e: (e.data.dst_subset if e.data.dst_subset is
                                not None else e.data.src_subset)
        if inputs_a:
            edges_a = [e for n in group_a for e in graph_a.out_edges(n)]
            subset_a = src_subset
        else:
            edges_a = [e for n in group_a for e in graph_a.in_edges(n)]
            subset_a = dst_subset
        if inputs_b:
            edges_b = [e for n in group_b for e in graph_b.out_edges(n)]
            subset_b = src_subset
        else:
            edges_b = [e for n in group_b for e in graph_b.in_edges(n)]
            subset_b = dst_subset

        # Simple all-pairs check
        for ea in edges_a:
            for eb in edges_b:
                result = subsets.intersects(subset_a(ea), subset_b(eb))
                if result is True or result is None:
                    return True
        return False
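
# Hypothetical usage sketch (assumes two states and two groups of access nodes
# gathered elsewhere): treat the states as dependent whenever the subsets
# written into the first group may overlap the subsets read from the second.
dependent = StateFusion.memlets_intersect(state_a, writes_a, False,
                                           state_b, reads_b, True)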
Example #8
    def _check_all_paths(self, first_state: SDFGState, second_state: SDFGState,
                         match_nodes: Dict[nodes.AccessNode, nodes.AccessNode],
                         nodes_first: List[nodes.AccessNode],
                         nodes_second: List[nodes.AccessNode], first_read: bool,
                         second_read: bool) -> bool:
        for node_a in nodes_first:
            succ_a = first_state.successors(node_a)
            for node_b in nodes_second:
                if all(self.has_path(first_state, second_state, match_nodes, sa, node_b) for sa in succ_a):
                    return True
        # Path not found, check memlets
        if StateFusion.memlets_intersect(first_state, nodes_first, first_read,
                                         second_state, nodes_second, second_read):
            return False
        return True
Example #9
    def apply(self, graph: SDFGState, sdfg: SDFG):
        in_array = self.in_array
        med_array = self.med_array
        out_array = self.out_array

        # Modify all edges that point to in_array to point to out_array
        for in_edge in graph.in_edges(in_array):

            # Make all memlets that write to in_array write to out_array instead
            tree = graph.memlet_tree(in_edge)
            for te in tree:
                if te.data.data == in_array.data:
                    te.data.data = out_array.data

            # Redirect the edge to point to out_array instead
            graph.remove_edge(in_edge)
            graph.add_edge(in_edge.src, in_edge.src_conn, out_array, None,
                           in_edge.data)

        graph.remove_node(med_array)
        graph.remove_node(in_array)
Example #10
def dynamic_map_inputs(state: SDFGState,
                       map_entry: nd.MapEntry) -> List[gr.MultiConnectorEdge]:
    """
    For a given map entry node, returns a list of dynamic-range input edges.
    :param state: The state in which the map entry node resides.
    :param map_entry: The given node.
    :return: A list of edges in state whose destination is map entry and denote
             dynamic-range input memlets.
    """
    return [
        e for e in state.in_edges(map_entry)
        if e.dst_conn and not e.dst_conn.startswith('IN_')
    ]
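
# Hypothetical usage sketch (assumes `state` and a `map_entry` node exist):
# list the connectors that receive dynamic-range inputs for this map.
for e in dynamic_map_inputs(state, map_entry):
    print(f'dynamic map input: {e.dst_conn} <- {e.data}')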
Example #11
def remove_edge_and_dangling_path(state: SDFGState, edge: MultiConnectorEdge):
    """
    Removes an edge and all of its parent edges in a memlet path, cleaning
    dangling connectors and isolated nodes resulting from the removal.
    :param state: The state in which the edge exists.
    :param edge: The edge to remove.
    """
    mtree = state.memlet_tree(edge)
    inwards = (isinstance(edge.src, nd.EntryNode)
               or isinstance(edge.dst, nd.EntryNode))

    # Traverse tree upwards, removing edges and connectors as necessary
    curedge = mtree
    while curedge is not None:
        e = curedge.edge
        state.remove_edge(e)
        if inwards:
            neighbors = [
                neighbor for neighbor in state.out_edges(e.src)
                if e.src_conn == neighbor.src_conn
            ]
        else:
            neighbors = [
                neighbor for neighbor in state.in_edges(e.dst)
                if e.dst_conn == neighbor.dst_conn
            ]
        if len(neighbors) > 0:  # There are still edges connected, leave as-is
            break

        # Remove connector and matching outer connector
        if inwards:
            if e.src_conn:
                e.src.remove_out_connector(e.src_conn)
                e.src.remove_in_connector('IN' + e.src_conn[3:])
        else:
            if e.dst_conn:
                e.dst.remove_in_connector(e.dst_conn)
                e.dst.remove_out_connector('OUT' + e.dst_conn[2:])

        # Continue traversing upwards
        curedge = curedge.parent
    else:
        # Check if an isolated node has been created at the root and remove it
        root_edge = mtree.root().edge
        root_node: nd.Node = root_edge.src if inwards else root_edge.dst
        if state.degree(root_node) == 0:
            state.remove_node(root_node)
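
# Hypothetical usage sketch (assumes `state` and an AccessNode `anode` chosen
# elsewhere): remove every write into `anode`, cleaning up the memlet paths
# and any connectors or nodes left dangling by the removal.
for e in list(state.in_edges(anode)):
    remove_edge_and_dangling_path(state, e)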
Example #12
    def apply(self, graph: SDFGState,
              sdfg: SDFG) -> Tuple[nodes.MapEntry, nodes.MapExit]:
        """
        Collapses two maps into one.
        :param sdfg: The SDFG to apply the transformation to.
        :return: A 2-tuple of the new map entry and exit nodes.
        """
        # Extract the parameters and ranges of the inner/outer maps.
        outer_map_entry = self.outer_map_entry
        inner_map_entry = self.inner_map_entry
        inner_map_exit = graph.exit_node(inner_map_entry)
        outer_map_exit = graph.exit_node(outer_map_entry)

        return sdutil.merge_maps(graph, outer_map_entry, outer_map_exit,
                                 inner_map_entry, inner_map_exit)
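
# Hypothetical usage sketch (assumes `sdfg` and the two adjacent map-entry
# nodes were located elsewhere): the transformation's apply_to helper, as also
# used in Example #27, can invoke the method above by naming the pattern nodes.
from dace.transformation.dataflow.map_collapse import MapCollapse
MapCollapse.apply_to(sdfg,
                     outer_map_entry=outer_map_entry,
                     inner_map_entry=inner_map_entry)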
Example #13
    def apply(self, graph: SDFGState, sdfg: SDFG):
        first_map_entry = graph.entry_node(self.first_map_exit)

        intermediate_dnodes = set()
        for _, _, node, _, _ in graph.out_edges(self.first_map_exit):
            if not isinstance(node, nds.AccessNode):
                continue

            intermediate_dnodes.add(node)

        self._update_in_connectors(graph, intermediate_dnodes)
        self._replicate_first_map(sdfg, graph, first_map_entry,
                                  intermediate_dnodes)

        graph.remove_nodes_from(
            graph.all_nodes_between(first_map_entry, self.first_map_exit)
            | {first_map_entry, self.first_map_exit})

        for node in graph.nodes():
            if not isinstance(node, nds.AccessNode):
                continue

            if graph.in_degree(node) == 0 and graph.out_degree(node) == 0:
                graph.remove_node(node)
Example #14
    def on_node_end(self, sdfg: SDFG, state: SDFGState, node: nodes.AccessNode,
                    outer_stream: CodeIOStream, inner_stream: CodeIOStream,
                    global_stream: CodeIOStream):
        from dace.codegen.dispatcher import DefinedType  # Avoid import loop

        if is_devicelevel_gpu(sdfg, state, node) or is_devicelevel_fpga(
                sdfg, state, node):
            # Only run on host code
            return

        desc = node.desc(sdfg)

        # Obtain a pointer for arrays and scalars
        ptrname = cpp.ptr(node.data, desc, sdfg, self.codegen)
        defined_type, _ = self.codegen.dispatcher.defined_vars.get(ptrname)
        if defined_type == DefinedType.Scalar:
            ptrname = '&' + ptrname

        # Create UUID
        state_id = sdfg.node_id(state)
        node_id = state.node_id(node)
        uuid = f'{sdfg.sdfg_id}_{state_id}_{node_id}'

        # Get optional pre/postamble for instrumenting device data
        preamble, postamble = '', ''
        if desc.storage == dtypes.StorageType.GPU_Global:
            self._setup_gpu_runtime(sdfg, global_stream)
            preamble, postamble, ptrname = self._generate_copy_to_host(
                node, desc, ptrname)

        # Encode runtime shape and strides
        shape = ', '.join(cpp.sym2cpp(s) for s in desc.shape)
        strides = ', '.join(cpp.sym2cpp(s) for s in desc.strides)

        # Write code
        inner_stream.write(preamble, sdfg, state_id, node_id)
        inner_stream.write(
            f'__state->serializer->save({ptrname}, {cpp.sym2cpp(desc.total_size - desc.start_offset)}, '
            f'"{node.data}", "{uuid}", {shape}, {strides});\n', sdfg, state_id,
            node_id)
        inner_stream.write(postamble, sdfg, state_id, node_id)
Example #15
def infer_tasklet_connectors(sdfg: SDFG, state: SDFGState, node: Tasklet,
                             inferred: TypeInferenceDict):
    """ Infers the connectors in a Tasklet using its code and type inference. """

    if node.code.language != dtypes.Language.Python:
        raise NotImplementedError(
            'Tasklet inference for languages other than Python is not supported')

    if any(inferred[(node, conn, True)].type is None
           for conn in node.in_connectors):
        raise TypeError('Cannot infer output connectors of tasklet "%s", '
                        'not all input connectors have types' % str(node))

    # Avoid import loop
    from dace.codegen.tools.type_inference import infer_types

    # Get symbols defined at beginning of node
    syms = state.symbols_defined_at(node)

    in_syms = {}
    for conn in node.in_connectors:
        if inferred[(node, conn, True)].type is not None:
            # Let the inferred dictionary win if it has some type information
            in_syms[conn] = inferred[(node, conn, True)]
        else:
            in_syms[conn] = node.in_connectors[conn]

    syms.update(in_syms)

    # Infer all types in tasklet
    new_syms = infer_types(node.code.code, syms)

    for cname in node.out_connectors:
        if inferred[(node, cname, False)].type is None:
            if cname not in new_syms:
                raise TypeError('Cannot infer type of tasklet %s output '
                                '"%s", please specify manually.' %
                                (node.label, cname))
            inferred[(node, cname, False)] = new_syms[cname]
Example #16
    def apply(self, graph: SDFGState, sdfg: SDFG):
        tile_strides = self.tile_sizes
        if self.strides is not None and len(self.strides) == len(tile_strides):
            tile_strides = self.strides

        # Retrieve map entry and exit nodes.
        map_entry = self.map_entry
        from dace.transformation.dataflow.map_collapse import MapCollapse
        from dace.transformation.dataflow.strip_mining import StripMining
        stripmine_subgraph = {
            StripMining.map_entry: self.subgraph[MapTiling.map_entry]
        }
        sdfg_id = sdfg.sdfg_id
        last_map_entry = None
        removed_maps = 0

        original_schedule = map_entry.schedule

        for dim_idx in range(len(map_entry.map.params)):
            if dim_idx >= len(self.tile_sizes):
                tile_size = symbolic.pystr_to_symbolic(self.tile_sizes[-1])
                tile_stride = symbolic.pystr_to_symbolic(tile_strides[-1])
            else:
                tile_size = symbolic.pystr_to_symbolic(
                    self.tile_sizes[dim_idx])
                tile_stride = symbolic.pystr_to_symbolic(tile_strides[dim_idx])

            # handle offsets
            if self.tile_offset and dim_idx >= len(self.tile_offset):
                offset = self.tile_offset[-1]
            elif self.tile_offset:
                offset = self.tile_offset[dim_idx]
            else:
                offset = 0

            dim_idx -= removed_maps
            # If tile size is trivial, skip strip-mining map dimension
            if tile_size == map_entry.map.range.size()[dim_idx]:
                continue

            stripmine = StripMining(sdfg, sdfg_id, self.state_id,
                                    stripmine_subgraph, self.expr_index)

            # Special case: Tile size of 1 should be omitted from inner map
            if tile_size == 1 and tile_stride == 1 and self.tile_trivial == False:
                stripmine.dim_idx = dim_idx
                stripmine.new_dim_prefix = ''
                stripmine.tile_size = str(tile_size)
                stripmine.tile_stride = str(tile_stride)
                stripmine.divides_evenly = True
                stripmine.tile_offset = str(offset)
                stripmine.apply(graph, sdfg)
                removed_maps += 1
            else:
                stripmine.dim_idx = dim_idx
                stripmine.new_dim_prefix = self.prefix
                stripmine.tile_size = str(tile_size)
                stripmine.tile_stride = str(tile_stride)
                stripmine.divides_evenly = self.divides_evenly
                stripmine.tile_offset = str(offset)
                stripmine.apply(graph, sdfg)

            # apply to the new map the schedule of the original one
            map_entry.schedule = original_schedule

            if last_map_entry:
                new_map_entry = graph.in_edges(map_entry)[0].src
                mapcollapse_subgraph = {
                    MapCollapse.outer_map_entry: graph.node_id(last_map_entry),
                    MapCollapse.inner_map_entry: graph.node_id(new_map_entry)
                }
                mapcollapse = MapCollapse(sdfg, sdfg_id, self.state_id,
                                          mapcollapse_subgraph, 0)
                mapcollapse.apply(graph, sdfg)
            last_map_entry = graph.in_edges(map_entry)[0].src
        return last_map_entry
Example #17
def get_view_edge(state: SDFGState,
                  view: nd.AccessNode) -> gr.MultiConnectorEdge[mm.Memlet]:
    """
    Given a view access node, returns the 
    incoming/outgoing edge which points to the viewed access node.
    See the ruleset in the documentation of ``dace.data.View``.

    :param state: The state in which the view resides.
    :param view: The view access node.
    :return: An edge pointing to the viewed data or None if view is invalid.
    :see: ``dace.data.View``
    """

    in_edges = state.in_edges(view)
    out_edges = state.out_edges(view)

    # Invalid case: No data to view
    if len(in_edges) == 0 or len(out_edges) == 0:
        return None

    # If there is one edge (in/out) that leads (via memlet path) to an access
    # node, and the other side (out/in) has a different number of edges.
    if len(in_edges) == 1 and len(out_edges) != 1:
        return in_edges[0]
    if len(out_edges) == 1 and len(in_edges) != 1:
        return out_edges[0]
    if len(out_edges) == len(in_edges) and len(out_edges) != 1:
        return None

    in_edge = in_edges[0]
    out_edge = out_edges[0]

    # If there is one incoming and one outgoing edge, and one leads to a code
    # node, the one that leads to an access node is the viewed data.
    inmpath = state.memlet_path(in_edge)
    outmpath = state.memlet_path(out_edge)
    src_is_data, dst_is_data = False, False
    if isinstance(inmpath[0].src, nd.AccessNode):
        src_is_data = True
    if isinstance(outmpath[-1].dst, nd.AccessNode):
        dst_is_data = True

    if src_is_data and not dst_is_data:
        return in_edge
    if not src_is_data and dst_is_data:
        return out_edge
    if not src_is_data and not dst_is_data:
        return None

    # If both sides lead to access nodes, if one memlet's data points to the
    # view it cannot point to the viewed node.
    if in_edge.data.data == view.data and out_edge.data.data != view.data:
        return out_edge
    if in_edge.data.data != view.data and out_edge.data.data == view.data:
        return in_edge
    if in_edge.data.data == view.data and out_edge.data.data == view.data:
        return None

    # If both memlets' data are the respective access nodes, the access
    # node at the highest scope is the one that is viewed.
    if isinstance(in_edge.src, nd.EntryNode):
        return in_edge
    if isinstance(out_edge.dst, nd.ExitNode):
        return out_edge

    # If both access nodes reside in the same scope, the input data is viewed.
    return in_edge
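
# Hypothetical usage sketch (assumes `sdfg` and `state` exist, and that `dt`
# refers to dace.data): map every view access node in the state to the edge
# pointing at the data it views.
view_edges = {
    n: get_view_edge(state, n)
    for n in state.nodes()
    if isinstance(n, nd.AccessNode) and isinstance(n.desc(sdfg), dt.View)
}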
Example #18
def infer_node_connectors(sdfg: SDFG, state: SDFGState, node: nodes.Node,
                          inferred: TypeInferenceDict):
    """ Infers the connectors of a node and updates `inferred` accordingly. """

    # Try to infer input connector type from node type or previous edges
    for e in state.in_edges(node):
        cname = e.dst_conn
        if cname is None:
            continue

        scalar = (e.data.subset and e.data.subset.num_elements() == 1)
        if e.data.data is not None:
            allocated_as_scalar = (sdfg.arrays[e.data.data].storage
                                   is not dtypes.StorageType.GPU_Global)
        else:
            allocated_as_scalar = True

        if inferred[(node, cname, True)].type is None:
            # If nested SDFG, try to use internal array type
            if isinstance(node, nodes.NestedSDFG):
                scalar = (isinstance(node.sdfg.arrays[cname], data.Scalar)
                          and allocated_as_scalar)
                dtype = node.sdfg.arrays[cname].dtype
                ctype = (dtype if scalar else dtypes.pointer(dtype))
            elif e.data.data is not None:  # Obtain type from memlet
                scalar |= isinstance(sdfg.arrays[e.data.data], data.Scalar)
                if isinstance(node, nodes.LibraryNode):
                    scalar &= allocated_as_scalar
                dtype = sdfg.arrays[e.data.data].dtype
                ctype = (dtype if scalar else dtypes.pointer(dtype))
            else:  # Code->Code
                src_edge = state.memlet_path(e)[0]
                sconn = src_edge.src.out_connectors[src_edge.src_conn]
                if sconn.type is None:
                    raise TypeError('Ambiguous or uninferable type in'
                                    ' connector "%s" of node "%s"' %
                                    (sconn, src_edge.src))
                ctype = sconn
            inferred[(node, cname, True)] = ctype

    # Try to infer outputs from output edges
    for e in state.out_edges(node):
        cname = e.src_conn
        if cname is None:
            continue

        scalar = (e.data.subset and e.data.subset.num_elements() == 1
                  and (not e.data.dynamic or
                       (e.data.dynamic and e.data.wcr is not None)))
        if e.data.data is not None:
            allocated_as_scalar = (sdfg.arrays[e.data.data].storage
                                   is not dtypes.StorageType.GPU_Global)
        else:
            allocated_as_scalar = True

        if inferred[(node, cname, False)].type is None:
            # If nested SDFG, try to use internal array type
            if isinstance(node, nodes.NestedSDFG):
                scalar = (isinstance(node.sdfg.arrays[cname], data.Scalar)
                          and allocated_as_scalar)
                dtype = node.sdfg.arrays[cname].dtype
                ctype = (dtype if scalar else dtypes.pointer(dtype))
            elif e.data.data is not None:  # Obtain type from memlet
                scalar |= isinstance(sdfg.arrays[e.data.data], data.Scalar)
                if isinstance(node, nodes.LibraryNode):
                    scalar &= allocated_as_scalar
                dtype = sdfg.arrays[e.data.data].dtype
                ctype = (dtype if scalar else dtypes.pointer(dtype))
            else:
                continue
            inferred[(node, cname, False)] = ctype

    # Let the node infer other output types on its own
    if isinstance(node, nodes.Tasklet):
        infer_tasklet_connectors(sdfg, state, node, inferred)
    elif isinstance(node, nodes.NestedSDFG):
        infer_connector_types(node.sdfg, inferred=inferred)

    # If there are any remaining uninferable connectors, fail
    for e in state.out_edges(node):
        cname = e.src_conn
        if cname and inferred[(node, cname, False)].type is None:
            raise TypeError('Ambiguous or uninferable type in'
                            ' connector "%s" of node "%s"' % (cname, node))
Example #19
def merge_maps(
    graph: SDFGState,
    outer_map_entry: nd.MapEntry,
    outer_map_exit: nd.MapExit,
    inner_map_entry: nd.MapEntry,
    inner_map_exit: nd.MapExit,
    param_merge: Callable[[ParamsType, ParamsType],
                          ParamsType] = lambda p1, p2: p1 + p2,
    range_merge: Callable[[RangesType, RangesType],
                          RangesType] = lambda r1, r2: type(r1)
    (r1.ranges + r2.ranges)
) -> (nd.MapEntry, nd.MapExit):
    """ Merges two maps (their entries and exits). It is assumed that the
    operation is valid. """

    outer_map = outer_map_entry.map
    inner_map = inner_map_entry.map

    # Create merged map by inheriting attributes from outer map and using
    # the merge functions for parameters and ranges.
    merged_map = copy.deepcopy(outer_map)
    merged_map.label = outer_map.label
    merged_map.params = param_merge(outer_map.params, inner_map.params)
    merged_map.range = range_merge(outer_map.range, inner_map.range)

    merged_entry = nd.MapEntry(merged_map)
    merged_entry.in_connectors = outer_map_entry.in_connectors
    merged_entry.out_connectors = outer_map_entry.out_connectors

    merged_exit = nd.MapExit(merged_map)
    merged_exit.in_connectors = outer_map_exit.in_connectors
    merged_exit.out_connectors = outer_map_exit.out_connectors

    graph.add_nodes_from([merged_entry, merged_exit])

    # Handle the case of dynamic map inputs in the inner map
    inner_dynamic_map_inputs = dynamic_map_inputs(graph, inner_map_entry)
    for edge in inner_dynamic_map_inputs:
        remove_conn = (len(
            list(graph.out_edges_by_connector(edge.src, edge.src_conn))) == 1)
        conn_to_remove = edge.src_conn[4:]
        if remove_conn:
            merged_entry.remove_in_connector('IN_' + conn_to_remove)
            merged_entry.remove_out_connector('OUT_' + conn_to_remove)
        merged_entry.add_in_connector(
            edge.dst_conn, inner_map_entry.in_connectors[edge.dst_conn])
        outer_edge = next(
            graph.in_edges_by_connector(outer_map_entry,
                                        'IN_' + conn_to_remove))
        graph.add_edge(outer_edge.src, outer_edge.src_conn, merged_entry,
                       edge.dst_conn, outer_edge.data)
        if remove_conn:
            graph.remove_edge(outer_edge)

    # Redirect inner in edges.
    for edge in graph.out_edges(inner_map_entry):
        if edge.src_conn is None:  # Empty memlets
            graph.add_edge(merged_entry, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)
            continue

        # Get memlet path and edge
        path = graph.memlet_path(edge)
        ind = path.index(edge)
        # Add an edge directly from the previous source connector to the
        # destination
        graph.add_edge(merged_entry, path[ind - 1].src_conn, edge.dst,
                       edge.dst_conn, edge.data)

    # Redirect inner out edges.
    for edge in graph.in_edges(inner_map_exit):
        if edge.dst_conn is None:  # Empty memlets
            graph.add_edge(edge.src, edge.src_conn, merged_exit, edge.dst_conn,
                           edge.data)
            continue

        # Get memlet path and edge
        path = graph.memlet_path(edge)
        ind = path.index(edge)
        # Add an edge directly from the source to the next destination
        # connector
        graph.add_edge(edge.src, edge.src_conn, merged_exit,
                       path[ind + 1].dst_conn, edge.data)

    # Redirect outer edges.
    change_edge_dest(graph, outer_map_entry, merged_entry)
    change_edge_src(graph, outer_map_exit, merged_exit)

    # Clean-up
    graph.remove_nodes_from(
        [outer_map_entry, outer_map_exit, inner_map_entry, inner_map_exit])

    return merged_entry, merged_exit
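
# Hypothetical usage sketch (assumes the four scope nodes were located
# elsewhere, e.g. as in Example #12): collapse the nested map into its parent
# using the default parameter- and range-merging behavior.
new_entry, new_exit = merge_maps(state, outer_map_entry, outer_map_exit,
                                 inner_map_entry, inner_map_exit)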
Example #20
def top_level_nodes(state: SDFGState):
    return state.scope_children()[None]
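
# Hypothetical usage sketch (assumes `state` exists): list the access nodes
# that live outside of any map scope.
toplevel_data = [
    n for n in top_level_nodes(state) if isinstance(n, nd.AccessNode)
]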
Example #21
    def can_be_applied(cls,
                       state: SDFGState,
                       candidate,
                       expr_index,
                       sdfg: SDFG,
                       strict=False) -> bool:
        map_entry = state.node(candidate[cls.map_entry])
        map_exit = state.exit_node(map_entry)
        current_map = map_entry.map
        subgraph = state.scope_subgraph(map_entry)
        subgraph_contents = state.scope_subgraph(map_entry,
                                                 include_entry=False,
                                                 include_exit=False)

        # Prevent infinite repeats
        if current_map.schedule == dace.dtypes.ScheduleType.SVE_Map:
            return False

        # Infer all connector types for later checks (without modifying the graph)
        inferred = infer_types.infer_connector_types(sdfg, state, subgraph)

        ########################
        # Ensure only Tasklets and AccessNodes are within the map
        for node, _ in subgraph_contents.all_nodes_recursive():
            if not isinstance(node, (nodes.Tasklet, nodes.AccessNode)):
                return False

        ########################
        # Check for unsupported datatypes on the connectors (including on the Map itself)
        bit_widths = set()
        for node, _ in subgraph.all_nodes_recursive():
            for conn in node.in_connectors:
                t = inferred[(node, conn, True)]
                bit_widths.add(util.get_base_type(t).bytes)
                if not t.type in sve.util.TYPE_TO_SVE:
                    return False
            for conn in node.out_connectors:
                t = inferred[(node, conn, False)]
                bit_widths.add(util.get_base_type(t).bytes)
                if not t.type in sve.util.TYPE_TO_SVE:
                    return False

        # Multiple different bit widths occurring (messes up the predicates)
        if len(bit_widths) > 1:
            return False

        ########################
        # Check for unsupported memlets
        param_name = current_map.params[-1]
        for e, _ in subgraph.all_edges_recursive():
            # Check for unsupported strides
            # The only unsupported strides are the ones containing the innermost
            # loop param because they are not constant during a vector step
            param_sym = symbolic.symbol(current_map.params[-1])

            if param_sym in e.data.get_stride(sdfg,
                                              map_entry.map).free_symbols:
                return False

            # Check for unsupported WCR
            if e.data.wcr is not None:
                # Unsupported reduction type
                reduction_type = dace.frontend.operations.detect_reduction_type(
                    e.data.wcr)
                if reduction_type not in sve.util.REDUCTION_TYPE_TO_SVE:
                    return False

                # Param in memlet during WCR is not supported
                if param_name in e.data.subset.free_symbols and e.data.wcr_nonatomic:
                    return False

                # vreduce is not supported
                dst_node = state.memlet_path(e)[-1].dst
                if isinstance(dst_node, nodes.Tasklet):
                    if isinstance(dst_node.in_connectors[e.dst_conn],
                                  dtypes.vector):
                        return False
                elif isinstance(dst_node, nodes.AccessNode):
                    desc = dst_node.desc(sdfg)
                    if isinstance(desc, data.Scalar) and isinstance(
                            desc.dtype, dtypes.vector):
                        return False

        ########################
        # Check for invalid copies in the subgraph
        for node, _ in subgraph.all_nodes_recursive():
            if not isinstance(node, nodes.Tasklet):
                continue

            for e in state.in_edges(node):
                # Check for valid copies from other tasklets and/or streams
                if e.data.data is not None:
                    src_node = state.memlet_path(e)[0].src
                    if not isinstance(src_node,
                                      (nodes.Tasklet, nodes.AccessNode)):
                        # Make sure we only have Code->Code copies and from arrays
                        return False

                    if isinstance(src_node, nodes.AccessNode):
                        src_desc = src_node.desc(sdfg)
                        if isinstance(src_desc, dace.data.Stream):
                            # Stream pops are not implemented
                            return False

        # Run the vector inference algorithm to check if vectorization is feasible
        try:
            inf_graph = vector_inference.infer_vectors(
                sdfg,
                state,
                map_entry,
                util.SVE_LEN,
                flags=vector_inference.VectorInferenceFlags.Allow_Stride,
                apply=False)
        except vector_inference.VectorInferenceException as ex:
            print(f'UserWarning: Vector inference failed! {ex}')
            return False

        return True
Example #22
    def apply(self, graph: SDFGState, sdfg: SDFG):
        map_entry = self.map_entry

        # Avoiding import loops
        from dace.transformation.dataflow.strip_mining import StripMining
        from dace.transformation.dataflow.local_storage import InLocalStorage, OutLocalStorage, LocalStorage

        rangeexpr = str(map_entry.map.range.num_elements())

        stripmine_subgraph = {
            StripMining.map_entry: self.subgraph[MPITransformMap.map_entry]
        }
        sdfg_id = sdfg.sdfg_id
        stripmine = StripMining(sdfg, sdfg_id, self.state_id,
                                stripmine_subgraph, self.expr_index)
        stripmine.dim_idx = -1
        stripmine.new_dim_prefix = "mpi"
        stripmine.tile_size = "(" + rangeexpr + "/__dace_comm_size)"
        stripmine.divides_evenly = True
        stripmine.apply(graph, sdfg)

        # Find all in-edges that lead to the map entry
        outer_map = None
        edges = [
            e for e in graph.in_edges(map_entry)
            if isinstance(e.src, nodes.EntryNode)
        ]

        outer_map = edges[0].src

        # Add MPI schedule attribute to outer map
        outer_map.map._schedule = dtypes.ScheduleType.MPI

        # Now create a transient for each array
        for e in edges:
            in_local_storage_subgraph = {
                LocalStorage.node_a: graph.node_id(outer_map),
                LocalStorage.node_b: self.subgraph[MPITransformMap.map_entry]
            }
            sdfg_id = sdfg.sdfg_id
            in_local_storage = InLocalStorage(sdfg, sdfg_id, self.state_id,
                                              in_local_storage_subgraph,
                                              self.expr_index)
            in_local_storage.array = e.data.data
            in_local_storage.apply(graph, sdfg)

        # Transform OutLocalStorage for each output of the MPI map
        in_map_exit = graph.exit_node(map_entry)
        out_map_exit = graph.exit_node(outer_map)

        for e in graph.out_edges(out_map_exit):
            name = e.data.data
            outlocalstorage_subgraph = {
                LocalStorage.node_a: graph.node_id(in_map_exit),
                LocalStorage.node_b: graph.node_id(out_map_exit)
            }
            sdfg_id = sdfg.sdfg_id
            outlocalstorage = OutLocalStorage(sdfg, sdfg_id, self.state_id,
                                              outlocalstorage_subgraph,
                                              self.expr_index)
            outlocalstorage.array = name
            outlocalstorage.apply(graph, sdfg)
Example #23
    def apply(self, graph: SDFGState, sdfg: SDFG):
        """
            This method applies the mapfusion transformation.
            Other than the removal of the second map entry node (SME), and the first
            map exit (FME) node, it has the following side effects:

            1.  Any transient adjacent to both FME and SME with degree = 2 will be removed.
                The tasklets that use/produce it shall be connected directly with a
                scalar/new transient (if the dataflow is more than a single scalar)

            2.  If this transient is adjacent to FME and SME and has other
                uses, it will be adjacent to the new map exit post fusion.
                Tasklet-> Tasklet edges will ALSO be added as mentioned above.

            3.  If an access node is adjacent to FME but not SME, it will be
                adjacent to new map exit post fusion.

            4.  If an access node is adjacent to SME but not FME, it will be
                adjacent to the new map entry node post fusion.

        """
        first_exit = self.first_map_exit
        first_entry = graph.entry_node(first_exit)
        second_entry = self.second_map_entry
        second_exit = graph.exit_node(second_entry)

        intermediate_nodes = set()
        for _, _, dst, _, _ in graph.out_edges(first_exit):
            intermediate_nodes.add(dst)
            assert isinstance(dst, nodes.AccessNode)

        # Check if an access node refers to non transient memory, or transient
        # is used at another location (cannot erase)
        do_not_erase = set()
        for node in intermediate_nodes:
            if sdfg.arrays[node.data].transient is False:
                do_not_erase.add(node)
            else:
                for edge in graph.in_edges(node):
                    if edge.src != first_exit:
                        do_not_erase.add(node)
                        break
                else:
                    for edge in graph.out_edges(node):
                        if edge.dst != second_entry:
                            do_not_erase.add(node)
                            break

        # Find permutation between first and second scopes
        perm = self.find_permutation(first_entry.map, second_entry.map)
        params_dict = {}
        for index, param in enumerate(first_entry.map.params):
            params_dict[param] = second_entry.map.params[perm[index]]

        # Replaces (in memlets and tasklet) the second scope map
        # indices with the permuted first map indices.
        # This works in two passes to avoid problems when e.g., exchanging two
        # parameters (instead of replacing (j,i) and (i,j) to (j,j) and then
        # i,i).
        second_scope = graph.scope_subgraph(second_entry)
        for firstp, secondp in params_dict.items():
            if firstp != secondp:
                replace(second_scope, secondp, '__' + secondp + '_fused')
        for firstp, secondp in params_dict.items():
            if firstp != secondp:
                replace(second_scope, '__' + secondp + '_fused', firstp)

        # Isolate First exit node
        ############################
        edges_to_remove = set()
        nodes_to_remove = set()
        for edge in graph.in_edges(first_exit):
            tree = graph.memlet_tree(edge)
            access_node = tree.root().edge.dst
            if access_node not in do_not_erase:
                out_edges = [
                    e for e in graph.out_edges(access_node)
                    if e.dst == second_entry
                ]
                # In this transformation, there can only be one edge to the
                # second map
                assert len(out_edges) == 1

                # Get source connector to the second map
                connector = out_edges[0].dst_conn[3:]

                new_dsts = []
                # Look at the second map entry out-edges to get the new
                # destinations
                for e in graph.out_edges(second_entry):
                    if e.src_conn[4:] == connector:
                        new_dsts.append(e)
                if not new_dsts:  # Access node is not used in the second map
                    nodes_to_remove.add(access_node)
                    continue

                # Add a transient scalar/array
                self.fuse_nodes(sdfg, graph, edge, new_dsts[0].dst,
                                new_dsts[0].dst_conn, new_dsts[1:])

                edges_to_remove.add(edge)

                # Remove transient node between the two maps
                nodes_to_remove.add(access_node)
            else:  # The case where intermediate array node cannot be removed
                # Node will become an output of the second map exit
                out_e = tree.parent.edge
                conn = second_exit.next_connector()
                graph.add_edge(
                    second_exit,
                    'OUT_' + conn,
                    out_e.dst,
                    out_e.dst_conn,
                    dcpy(out_e.data),
                )
                second_exit.add_out_connector('OUT_' + conn)

                graph.add_edge(edge.src, edge.src_conn, second_exit,
                               'IN_' + conn, dcpy(edge.data))
                second_exit.add_in_connector('IN_' + conn)

                edges_to_remove.add(out_e)
                edges_to_remove.add(edge)

                # If the second map needs this node, link the connector
                # that generated this to the place where it is needed, with a
                # temp transient/scalar for memlet to be generated
                for out_e in graph.out_edges(second_entry):
                    second_memlet_path = graph.memlet_path(out_e)
                    source_node = second_memlet_path[0].src
                    if source_node == access_node:
                        self.fuse_nodes(sdfg, graph, edge, out_e.dst,
                                        out_e.dst_conn)

        ###
        # First scope exit is isolated and can now be safely removed
        for e in edges_to_remove:
            graph.remove_edge(e)
        graph.remove_nodes_from(nodes_to_remove)
        graph.remove_node(first_exit)

        # Isolate second_entry node
        ###########################
        for edge in graph.in_edges(second_entry):
            tree = graph.memlet_tree(edge)
            access_node = tree.root().edge.src
            if access_node in intermediate_nodes:
                # Already handled above, can be safely removed
                graph.remove_edge(edge)
                continue

            # This is an external input to the second map which will now go
            # through the first map.
            conn = first_entry.next_connector()
            graph.add_edge(edge.src, edge.src_conn, first_entry, 'IN_' + conn,
                           dcpy(edge.data))
            first_entry.add_in_connector('IN_' + conn)
            graph.remove_edge(edge)
            for out_enode in tree.children:
                out_e = out_enode.edge
                graph.add_edge(
                    first_entry,
                    'OUT_' + conn,
                    out_e.dst,
                    out_e.dst_conn,
                    dcpy(out_e.data),
                )
                graph.remove_edge(out_e)
            first_entry.add_out_connector('OUT_' + conn)

        ###
        # Second node is isolated and can now be safely removed
        graph.remove_node(second_entry)

        # Fix scope exit to point to the right map
        second_exit.map = first_entry.map
Example #24
def _loop_from_structure(
        sdfg: SDFG, guard: SDFGState, enter_edge: Edge[InterstateEdge],
        leave_edge: Edge[InterstateEdge],
        back_edges: List[Edge[InterstateEdge]],
        dispatch_state: Callable[[SDFGState],
                                 str]) -> Union[ForScope, WhileScope]:
    """ 
    Helper method that constructs the correct structured loop construct from a
    set of states. Can construct for or while loops.
    """

    body = GeneralBlock(dispatch_state, [], [])

    guard_inedges = sdfg.in_edges(guard)
    increment_edges = [e for e in guard_inedges if e in back_edges]
    init_edges = [e for e in guard_inedges if e not in back_edges]

    # If no back edge found (or more than one, indicating a "continue"
    # statement), disregard
    if len(increment_edges) > 1 or len(increment_edges) == 0:
        return None
    increment_edge = increment_edges[0]

    # Mark increment edge to be ignored in body
    body.edges_to_ignore.append(increment_edge)

    # Outgoing edges must be a negation of each other
    if enter_edge.data.condition_sympy() != (sp.Not(
            leave_edge.data.condition_sympy())):
        return None

    # Body of guard state must be empty
    if not guard.is_empty():
        return None

    if not increment_edge.data.is_unconditional():
        return None
    if len(enter_edge.data.assignments) > 0:
        return None

    condition = enter_edge.data.condition

    # Detect whether this loop is a for loop:
    # All incoming edges to the guard must set the same variable
    itvars = None
    for iedge in guard_inedges:
        if itvars is None:
            itvars = set(iedge.data.assignments.keys())
        else:
            itvars &= iedge.data.assignments.keys()
    if itvars and len(itvars) == 1:
        itvar = next(iter(itvars))
        init = init_edges[0].data.assignments[itvar]

        # Check that all init edges are the same and that increment edge only
        # increments
        if (all(e.data.assignments[itvar] == init for e in init_edges)
                and len(increment_edge.data.assignments) == 1):
            update = increment_edge.data.assignments[itvar]
            return ForScope(dispatch_state, itvar, guard, init, condition,
                            update, body)

    # Otherwise, it is a while loop
    return WhileScope(dispatch_state, guard, condition, body)
Example #25
def trace_nested_access(
        node: nd.AccessNode, state: SDFGState,
        sdfg: SDFG) -> List[Tuple[nd.AccessNode, SDFGState, SDFG]]:
    """
    Given an AccessNode in a nested SDFG, trace the accessed memory
    back to the outermost scope in which it is defined.

    :param node: An access node.
    :param state: State in which the access node is located.
    :param sdfg: SDFG in which the access node is located.
    :return: A list of scopes ((input_node, output_node), (memlet_read, memlet_write), state, sdfg) in which
             the given data is accessed, from outermost scope to innermost
             scope.
    """
    curr_sdfg = sdfg
    curr_read = None
    memlet_read = None
    for m in state.out_edges(node):
        if not m.data.is_empty():
            curr_read = node
            memlet_read = m.data
            break

    curr_write = None
    memlet_write = None
    for m in state.in_edges(node):
        if not m.data.is_empty():
            curr_write = node
            memlet_write = m.data
            break

    trace = [((curr_read, curr_write), (memlet_read, memlet_write), state, sdfg)]

    while curr_sdfg.parent is not None:
        curr_state = curr_sdfg.parent

        # Find the nested SDFG containing ourself in the parent state
        nested_sdfg = curr_sdfg.parent_nsdfg_node

        if curr_read is not None:
            for e in curr_state.in_edges(nested_sdfg):
                if e.dst_conn == curr_read.data:
                    # See if the input to this connector traces back to an
                    # access node. If not, just give up here
                    n = find_input_arraynode(curr_state, e)
                    if isinstance(n, nd.AccessNode):
                        curr_read = n
                        memlet_read = e.data
                        break
            else:
                curr_read = None
                memlet_read = None
        if curr_write is not None:
            for e in curr_state.out_edges(nested_sdfg):
                if e.src_conn == curr_write.data:
                    # See if the output of this connector traces back to an
                    # access node. If not, just give up here
                    n = find_output_arraynode(curr_state, e)
                    if isinstance(n, nd.AccessNode):
                        curr_write = n
                        memlet_write = e.data
                        break
            else:
                curr_write = None
                memlet_write = None
        if curr_read is not None or curr_write is not None:
            trace.append(((curr_read, curr_write), (memlet_read, memlet_write),
                          curr_state, curr_state.parent))
        else:
            break
        curr_sdfg = curr_state.parent  # Recurse
    return list(reversed(trace))
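
# Hypothetical usage sketch (assumes `node`, `state` and `sdfg` identify an
# access node inside a nested SDFG): print each scope in which the data is
# accessed, from the outermost scope inwards.
for (read, write), (m_read, m_write), st, outer_sdfg in trace_nested_access(node, state, sdfg):
    print(st, read, write, m_read, m_write)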
Example #26
    def apply(self, graph: SDFGState, sdfg: SDFG):
        node_a = self.node_a
        node_b = self.node_b
        prefix = self.prefix

        # Determine direction of new memlet
        scope_dict = graph.scope_dict()
        propagate_forward = sd.scope_contains_scope(scope_dict, node_a, node_b)

        array = self.array
        if array is None or len(array) == 0:
            array = next(e.data.data
                         for e in graph.edges_between(node_a, node_b)
                         if e.data.data is not None and e.data.wcr is None)

        original_edge = None
        invariant_memlet = None
        for edge in graph.edges_between(node_a, node_b):
            if array == edge.data.data:
                original_edge = edge
                invariant_memlet = edge.data
                break
        if invariant_memlet is None:
            for edge in graph.edges_between(node_a, node_b):
                original_edge = edge
                invariant_memlet = edge.data
                warnings.warn('Array %s not found! Using array %s instead.' %
                              (array, invariant_memlet.data))
                array = invariant_memlet.data
                break
        if invariant_memlet is None:
            raise NameError('Array %s not found!' % array)
        if self.create_array:
            # Add transient array
            new_data, _ = sdfg.add_transient(
                name=prefix + invariant_memlet.data,
                shape=[
                    symbolic.overapproximate(r).simplify()
                    for r in invariant_memlet.bounding_box_size()
                ],
                dtype=sdfg.arrays[invariant_memlet.data].dtype,
                find_new_name=True)

        else:
            new_data = prefix + invariant_memlet.data
        data_node = nodes.AccessNode(new_data)
        # Store as fields so that other transformations can use them
        self._local_name = new_data
        self._data_node = data_node

        to_data_mm = copy.deepcopy(invariant_memlet)
        from_data_mm = copy.deepcopy(invariant_memlet)
        offset = subsets.Indices([r[0] for r in invariant_memlet.subset])

        # Reconnect, assuming one edge to the access node
        graph.remove_edge(original_edge)
        if propagate_forward:
            graph.add_edge(node_a, original_edge.src_conn, data_node, None,
                           to_data_mm)
            new_edge = graph.add_edge(data_node, None, node_b,
                                      original_edge.dst_conn, from_data_mm)
        else:
            new_edge = graph.add_edge(node_a, original_edge.src_conn,
                                      data_node, None, to_data_mm)
            graph.add_edge(data_node, None, node_b, original_edge.dst_conn,
                           from_data_mm)

        # Offset all edges in the memlet tree (including the new edge)
        for edge in graph.memlet_tree(new_edge):
            edge.data.subset.offset(offset, True)
            edge.data.data = new_data

        return data_node
Example #27
    def apply(self, graph: SDFGState, sdfg: SDFG):
        map_exit = self.map_exit
        outer_map_exit = self.outer_map_exit

        # Choose array
        array = self.array
        if array is None or len(array) == 0:
            array = next(e.data.data
                         for e in graph.edges_between(map_exit, outer_map_exit)
                         if e.data.wcr is not None)

        # Avoid import loop
        from dace.transformation.dataflow.local_storage import OutLocalStorage

        data_node: nodes.AccessNode = OutLocalStorage.apply_to(
            sdfg,
            dict(array=array),
            verify=False,
            save=False,
            node_a=map_exit,
            node_b=outer_map_exit)

        if self.identity is None:
            warnings.warn('AccumulateTransient did not properly initialize '
                          'newly-created transient!')
            return

        sdfg_state: SDFGState = sdfg.node(self.state_id)

        map_entry = sdfg_state.entry_node(map_exit)

        nested_sdfg: NestedSDFG = nest_state_subgraph(
            sdfg=sdfg,
            state=sdfg_state,
            subgraph=SubgraphView(
                sdfg_state, {map_entry, map_exit}
                | sdfg_state.all_nodes_between(map_entry, map_exit)))

        nested_sdfg_state: SDFGState = nested_sdfg.sdfg.nodes()[0]

        init_state = nested_sdfg.sdfg.add_state_before(nested_sdfg_state)

        temp_array: Array = sdfg.arrays[data_node.data]

        init_state.add_mapped_tasklet(
            name='acctrans_init',
            map_ranges={
                '_o%d' % i: '0:%s' % symstr(d)
                for i, d in enumerate(temp_array.shape)
            },
            inputs={},
            code='out = %s' % self.identity,
            outputs={
                'out':
                dace.Memlet.simple(data=data_node.data,
                                   subset_str=','.join([
                                       '_o%d' % i
                                       for i, _ in enumerate(temp_array.shape)
                                   ]))
            },
            external_edges=True)
Example #28
    def apply(self, graph: SDFGState, sdfg: SDFG):
        import dace.libraries.blas as blas

        transpose_a = self.transpose_a
        _at = self.at
        transpose_b = self.transpose_b
        _bt = self.bt
        a_times_b = self.a_times_b

        for src, src_conn, _, _, memlet in graph.in_edges(transpose_a):
            graph.add_edge(src, src_conn, a_times_b, '_b', memlet)
        graph.remove_node(transpose_a)
        for src, src_conn, _, _, memlet in graph.in_edges(transpose_b):
            graph.add_edge(src, src_conn, a_times_b, '_a', memlet)
        graph.remove_node(transpose_b)
        graph.remove_node(_at)
        graph.remove_node(_bt)

        for _, _, dst, dst_conn, memlet in graph.out_edges(a_times_b):
            subset = dcpy(memlet.subset)
            subset.squeeze()
            size = subset.size()
            shape = [size[1], size[0]]
            break
        tmp_name, tmp_arr = sdfg.add_temp_transient(shape, a_times_b.dtype)
        tmp_acc = graph.add_access(tmp_name)
        transpose_c = blas.Transpose('_Transpose_', a_times_b.dtype)
        for edge in graph.out_edges(a_times_b):
            _, _, dst, dst_conn, memlet = edge
            graph.remove_edge(edge)
            graph.add_edge(transpose_c, '_out', dst, dst_conn, memlet)
        graph.add_edge(a_times_b, '_c', tmp_acc, None, dace.Memlet.from_array(tmp_name, tmp_arr))
        graph.add_edge(tmp_acc, None, transpose_c, '_inp', dace.Memlet.from_array(tmp_name, tmp_arr))