Example #1
0
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """Check whether the matched ``in_array -> out_array`` copy is redundant.

        The in-array may only be removed when it is a transient with a single
        outgoing edge (to the out-array), both containers share storage
        location, descriptor class and shape, and the in-array's descriptor is
        referenced nowhere else in the SDFG (no other access node in any state
        and no inter-state edge mentioning it).

        :param graph: State containing the matched subgraph.
        :param candidate: Mapping from the pattern nodes
                          (``RedundantArray._in_array`` / ``_out_array``) to
                          indices into ``graph.nodes()``.
        :param expr_index: Index of the matched pattern (unused here).
        :param sdfg: SDFG owning ``graph``; used to resolve data descriptors.
        :param strict: If True, additionally require that the size of the
                       copy memlet's subset equals the in-array's shape.
        :return: True if the transformation can be applied, False otherwise.
        """
        state_nodes = graph.nodes()
        src = state_nodes[candidate[RedundantArray._in_array]]
        dst = state_nodes[candidate[RedundantArray._out_array]]

        src_desc = src.desc(sdfg)
        dst_desc = dst.desc(sdfg)

        # The source must feed exactly one consumer (the out-array itself).
        if graph.out_degree(src) != 1:
            return False

        # Only transient data can be eliminated.
        if not src_desc.transient:
            return False

        # Same storage location and same descriptor class (e.g. Stream->Stream).
        if src_desc.storage != dst_desc.storage:
            return False
        if type(src_desc) != type(dst_desc):
            return False

        # Count references to the source descriptor across the whole SDFG:
        # access nodes in every state, plus inter-state edges whose free
        # symbols mention the array name. The candidate itself counts once.
        num_refs = 0
        for state in sdfg.nodes():
            num_refs += sum(1 for node in state.nodes()
                            if isinstance(node, nodes.AccessNode)
                            and node.desc(sdfg) == src_desc)
        num_refs += sum(1 for isedge in sdfg.edges()
                        if src.data in isedge.data.free_symbols)
        if num_refs > 1:
            return False

        # Shapes must agree dimension-by-dimension so that no subset needs
        # rewriting after the in-array is removed.
        src_shape = src_desc.shape
        dst_shape = dst_desc.shape
        if len(src_shape) != len(dst_shape):
            return False
        if any(i != o for i, o in zip(src_shape, dst_shape)):
            return False

        if not strict:
            return True

        # Strict mode: the copy memlet must span the entire removed array.
        copy_edge = graph.edges_between(src, dst)[0]
        copied_sizes = copy_edge.data.subset.size()
        return not any(m != a for m, a in zip(copied_sizes, src_shape))
Example #2
0
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """Check whether the matched ``in_array -> out_array`` copy can be removed.

        The in-array is a candidate for removal when it is a transient with a
        single outgoing edge, the two containers are compatible (storage,
        location, descriptor type, with special handling for Views), the
        in-array's descriptor is not referenced elsewhere in the SDFG, and
        every memlet writing into the in-array covers the subset that is
        copied out of it.

        :param graph: State containing the matched subgraph.
        :param candidate: Mapping from the pattern nodes
                          (``RedundantArray.in_array`` / ``out_array``) to
                          indices into ``graph.nodes()``.
        :param expr_index: Index of the matched pattern (unused here).
        :param sdfg: SDFG owning ``graph``; used to resolve data descriptors.
        :param strict: If True, apply the extra conservative checks: full
                       coverage of the removed array, the library-node
                       check, and the data-race check on other accesses to
                       the output array in the same state.
        :return: True if the transformation can be applied, False otherwise.
        """
        in_array = graph.nodes()[candidate[RedundantArray.in_array]]
        out_array = graph.nodes()[candidate[RedundantArray.out_array]]

        in_desc = in_array.desc(sdfg)
        out_desc = out_array.desc(sdfg)

        # Ensure out degree is one (only one target, which is out_array)
        if graph.out_degree(in_array) != 1:
            return False

        # Make sure that the candidate is a transient variable
        if not in_desc.transient:
            return False

        # 1. Get edge e1 and extract the subset of array A (in_array); the
        # subset of array B is not needed here.
        # _validate_subsets raises NotImplementedError for memlets it cannot
        # handle. Report "not applicable" instead of letting the exception
        # escape, consistent with how the e2/e3 edges are handled below.
        e1 = graph.edges_between(in_array, out_array)[0]
        try:
            a1_subset, _ = _validate_subsets(e1, sdfg.arrays)
        except NotImplementedError:
            return False

        if strict:
            # In strict mode, make sure the memlet covers the removed array
            # (unit-size dimensions are ignored on both sides).
            subset = copy.deepcopy(e1.data.subset)
            subset.squeeze()
            shape = [sz for sz in in_desc.shape if sz != 1]
            if any(m != a for m, a in zip(subset.size(), shape)):
                return False

            # NOTE: Library node check
            # The transformation must not apply in strict mode if in_array is
            # not a view, is output of a library node, and an access or a view
            # of out_desc is also input to the same library node.
            # The reason is that the application of the transformation will lead
            # to out_desc being both input and output of the library node.
            # We do not know if this is safe.

            # First find the true out_desc (in case out_array is a view).
            true_out_desc = out_desc
            if isinstance(out_desc, data.View):
                e = sdutil.get_view_edge(graph, out_array)
                if not e:
                    return False
                true_out_desc = sdfg.arrays[e.src.data]

            if not isinstance(in_desc, data.View):

                # Collect edges that may come from a library node producing
                # in_array, either directly or through a View in between.
                edges_to_check = []
                for a in graph.in_edges(in_array):
                    if isinstance(a.src, nodes.LibraryNode):
                        edges_to_check.append(a)
                    elif (isinstance(a.src, nodes.AccessNode)
                          and isinstance(sdfg.arrays[a.src.data], data.View)):
                        for b in graph.in_edges(a.src):
                            edges_to_check.append(graph.memlet_path(b)[0])

                for a in edges_to_check:
                    if isinstance(a.src, nodes.LibraryNode):
                        for b in graph.in_edges(a.src):
                            if isinstance(b.src, nodes.AccessNode):
                                desc = sdfg.arrays[b.src.data]
                                if isinstance(desc, data.View):
                                    # Resolve the view to the data it views.
                                    e = sdutil.get_view_edge(graph, b.src)
                                    if not e:
                                        return False
                                    desc = sdfg.arrays[e.src.data]
                                # Checked for plain accesses as well as views:
                                # per the note above, an access *or* a view
                                # of out_desc feeding the library node must
                                # be rejected.
                                if desc is true_out_desc:
                                    return False

            # In strict mode, check if the state has two or more access nodes
            # for the output array. Definitely one of them (out_array) is a
            # write access. Therefore, there might be a RW, WR, or WW dependency.
            accesses = [
                n for n in graph.nodes() if isinstance(n, nodes.AccessNode)
                and n.desc(sdfg) == out_desc and n is not out_array
            ]
            if len(accesses) > 0:
                # We need to ensure that a data race will not happen if we
                # remove in_array.
                # First, we simplify the graph
                G = helpers.simplify_state(graph)
                # Loop over the accesses
                for a in accesses:
                    # Fall back to the full state graph when a node is not
                    # present in the simplified graph.
                    try:
                        has_bward_path = nx.has_path(G, a, out_array)
                    except NodeNotFound:
                        has_bward_path = nx.has_path(graph.nx, a, out_array)
                    try:
                        has_fward_path = nx.has_path(G, out_array, a)
                    except NodeNotFound:
                        has_fward_path = nx.has_path(graph.nx, out_array, a)
                    # If there is no path between the access nodes (disconnected
                    # components), then it is definitely possible to have data
                    # races. Abort.
                    if not (has_bward_path or has_fward_path):
                        return False
                    # If a reaches out_array (backward path), then out_array
                    # must not be a direct successor of a in the simplified
                    # graph.
                    # NOTE(review): if `a` was absent from G (the NodeNotFound
                    # fallback above), G.successors(a) raises NetworkXError —
                    # confirm helpers.simplify_state keeps all access nodes.
                    if has_bward_path and out_array in G.successors(a):
                        return False

        # Make sure that both arrays are using the same storage location
        # and are of the same type (e.g., Stream->Stream)
        if in_desc.storage != out_desc.storage:
            return False
        if in_desc.location != out_desc.location:
            return False
        if type(in_desc) != type(out_desc):
            if isinstance(in_desc, data.View):
                # Case View -> Access
                # If the View points to the Access and has the same shape,
                # it can be removed, unless there is a reduction!
                e = sdutil.get_view_edge(graph, in_array)
                if e and e.dst is out_array and in_desc.shape == out_desc.shape:
                    from dace.libraries.standard import Reduce
                    for in_edge in graph.in_edges(in_array):
                        if isinstance(in_edge.src, Reduce):
                            return False
                    return True
                return False
            elif isinstance(out_desc, data.View):
                # Case Access -> View
                # If the View points to the Access (and has a different shape?)
                # then we should (probably) not remove the Access.
                e = sdutil.get_view_edge(graph, out_array)
                if e and e.src is in_array and in_desc.shape != out_desc.shape:
                    return False
                # Check that the View's immediate successors are Accesses.
                # Otherwise, the application of the transformation will result
                # in an ambiguous View.
                view_successors_desc = [
                    e.dst.desc(sdfg)
                    if isinstance(e.dst, nodes.AccessNode) else None
                    for e in graph.out_edges(out_array)
                ]
                if any(not desc or isinstance(desc, data.View)
                       for desc in view_successors_desc):
                    return False
            else:
                # Something else, for example, Stream
                return False
        else:
            # Two views connected to each other
            if isinstance(in_desc, data.View):
                return True

        # Find occurrences in this and other states
        occurrences = []
        for state in sdfg.nodes():
            occurrences.extend([
                n for n in state.nodes()
                if isinstance(n, nodes.AccessNode) and n.desc(sdfg) == in_desc
            ])
        for isedge in sdfg.edges():
            if in_array.data in isedge.data.free_symbols:
                occurrences.append(isedge)

        if len(occurrences) > 1:
            return False

        # 2. Iterate over the e2 edges
        for e2 in graph.in_edges(in_array):
            # 2-a. Extract/validate subsets for array A and others
            try:
                _, a2_subset = _validate_subsets(e2, sdfg.arrays)
            except NotImplementedError:
                return False
            # 2-b. Check whether a2_subset covers a1_subset
            if not a2_subset.covers(a1_subset):
                return False
            # 2-c. Validate subsets in memlet tree
            # (should not be needed for valid SDGs)
            path = graph.memlet_tree(e2)
            for e3 in path:
                if e3 is not e2:
                    try:
                        _validate_subsets(e3,
                                          sdfg.arrays,
                                          dst_name=in_array.data)
                    except NotImplementedError:
                        return False

        return True