Example #1
    def _components(
            subgraph: gr.SubgraphView) -> List[Tuple[nodes.Node, nodes.Node]]:
        """
        Returns the list of non-array components in this subgraph.
        Each element of the list is a 2-tuple (input node, output node) of
        the component.
        """
        graph = (subgraph
                 if isinstance(subgraph, sd.SDFGState) else subgraph.graph)
        schildren = subgraph.scope_children()
        ns = [(n, graph.exit_node(n)) if isinstance(n, nodes.EntryNode) else
              (n, n) for n in schildren[None]
              if isinstance(n, (nodes.CodeNode, nodes.EntryNode))]

        return ns
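A minimal usage sketch for the helper above (the surrounding names are illustrative, not from the original source): given an SDFGState, it yields one (input node, output node) pair per outermost code node or map scope.

# Hypothetical usage, assuming `state` is an SDFGState and `_components`
# is reachable from this scope.
for in_node, out_node in _components(state):
    # A map scope yields its (MapEntry, MapExit) pair; a plain code node
    # appears as both elements of the tuple.
    print(in_node, '->', out_node)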
Example #2
    def test_simple_program(self):
        @dace.program
        def multiply(a: dace.float32[N]):
            a *= 2
            a *= 3

        sdfg = multiply.to_sdfg(strict=True)
        for state in sdfg.nodes():
            if any(isinstance(node, Tasklet) for node in state.nodes()):
                break
        else:
            raise KeyError('State with tasklet not found')

        tasklet_nodes = [n for n in state.nodes() if isinstance(n, Tasklet)]
        with self.assertRaises(ValueError):
            nest_state_subgraph(sdfg, state,
                                SubgraphView(state, tasklet_nodes))

        nest_state_subgraph(sdfg, state, SubgraphView(state,
                                                      [tasklet_nodes[0]]))
        sdfg.validate()
        nest_state_subgraph(sdfg, state, SubgraphView(state,
                                                      [tasklet_nodes[1]]))
        sdfg.validate()
Example #3
def test_p1():

    N.set(20)
    M.set(30)
    O.set(50)
    P.set(40)
    Q.set(42)
    R.set(25)

    sdfg = subgraph_fusion_parallel.to_sdfg()
    sdfg.coarsen_dataflow()
    state = sdfg.nodes()[0]

    A = np.random.rand(N.get()).astype(np.float64)
    B = np.random.rand(M.get()).astype(np.float64)
    C = np.random.rand(O.get()).astype(np.float64)
    D = np.random.rand(M.get()).astype(np.float64)
    E = np.random.rand(N.get()).astype(np.float64)
    F = np.random.rand(P.get()).astype(np.float64)
    G = np.random.rand(M.get()).astype(np.float64)
    H = np.random.rand(P.get()).astype(np.float64)
    I = np.random.rand(N.get()).astype(np.float64)
    J = np.random.rand(R.get()).astype(np.float64)
    X = np.random.rand(N.get()).astype(np.float64)
    Y = np.random.rand(M.get()).astype(np.float64)
    Z = np.random.rand(P.get()).astype(np.float64)

    csdfg = sdfg.compile()
    csdfg(A=A, B=B, C=C, D=D, E=E, F=F, G=G, H=H, I=I, J=J, X=X, Y=Y, Z=Z,
          N=N, M=M, O=O, P=P, R=R, Q=Q)
    del csdfg

    subgraph = SubgraphView(state, state.nodes())

    me = MultiExpansion(subgraph)
    assert me.can_be_applied(sdfg, subgraph)
    me.apply(sdfg)

    sf = SubgraphFusion(subgraph)
    assert sf.can_be_applied(sdfg, subgraph)
    sf.apply(sdfg)

    csdfg = sdfg.compile()
    csdfg(A=A, B=B, C=C, D=D, E=E, F=F, G=G, H=H, I=I, J=J, X=X, Y=Y, Z=Z,
          N=N, M=M, O=O, P=P, R=R, Q=Q)
    print("PASS")
Example #4
def subgraph_from_maps(sdfg, graph, map_entries, scope_children=None):
    """
    Given a list of map entries in a single graph,
    return a subgraph view that includes all nodes
    inside these maps, the map entries and exits themselves,
    and all directly adjacent nodes.
    """
    if not scope_children:
        scope_children = graph.scope_children()
    nodes = set()
    for map_entry in map_entries:
        nodes |= set(scope_children[map_entry])
        nodes |= set(e.dst for e in graph.out_edges(graph.exit_node(map_entry)))
        nodes |= set(e.src for e in graph.in_edges(map_entry))
        nodes.add(map_entry)

    return SubgraphView(graph, list(nodes))
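A hedged sketch of how such a view is typically consumed elsewhere in this collection (identifiers are illustrative, and the exact transformation constructor varies between DaCe versions):

# Build the enclosing subgraph view from the outermost map entries of a state
# and attempt subgraph fusion on it.
from dace.sdfg import nodes

map_entries = [n for n in state.nodes()
               if isinstance(n, nodes.MapEntry) and state.entry_node(n) is None]
subgraph = subgraph_from_maps(sdfg, state, map_entries)
sf = SubgraphFusion(subgraph)
if sf.can_be_applied(sdfg, subgraph):
    sf.apply(sdfg)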
Example #5
def test_p1():

    N.set(20)
    M.set(30)
    O.set(50)
    P.set(40)
    Q.set(42)
    R.set(25)

    sdfg = test_program.to_sdfg()
    sdfg.apply_strict_transformations()
    state = sdfg.nodes()[0]

    A = np.random.rand(N.get()).astype(np.float64)
    B = np.random.rand(M.get()).astype(np.float64)
    C = np.random.rand(O.get()).astype(np.float64)
    D = np.random.rand(M.get()).astype(np.float64)
    E = np.random.rand(N.get()).astype(np.float64)
    F = np.random.rand(P.get()).astype(np.float64)
    G = np.random.rand(M.get()).astype(np.float64)
    H = np.random.rand(P.get()).astype(np.float64)
    I = np.random.rand(N.get()).astype(np.float64)
    J = np.random.rand(R.get()).astype(np.float64)
    X = np.random.rand(N.get()).astype(np.float64)
    Y = np.random.rand(M.get()).astype(np.float64)
    Z = np.random.rand(P.get()).astype(np.float64)

    csdfg = sdfg.compile()
    csdfg(A=A, B=B, C=C, D=D, E=E, F=F, G=G, H=H, I=I, J=J, X=X, Y=Y, Z=Z,
          N=N, M=M, O=O, P=P, R=R, Q=Q)

    subgraph = SubgraphView(state, [node for node in state.nodes()])
    expansion = MultiExpansion()
    fusion = SubgraphFusion()

    assert MultiExpansion.match(sdfg, subgraph)
    expansion.apply(sdfg, subgraph)

    assert SubgraphFusion.match(sdfg, subgraph)
    fusion.apply(sdfg, subgraph)

    csdfg = sdfg.compile()
    csdfg(A=A, B=B, C=C, D=D, E=E, F=F, G=G, H=H, I=I, J=J, X=X, Y=Y, Z=Z,
          N=N, M=M, O=O, P=P, R=R, Q=Q)
    print("PASS")
Example #6
    def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool:
        graph = subgraph.graph
        if self.allow_expansion:
            subgraph_fusion = SubgraphFusion(subgraph)
            if subgraph_fusion.can_be_applied(sdfg, subgraph):
                # try w/o copy first
                return True

            expansion = MultiExpansion(subgraph)
            expansion.permutation_only = not self.expansion_split
            if expansion.can_be_applied(sdfg, subgraph):
                # deepcopy
                graph_indices = [
                    i for (i, n) in enumerate(graph.nodes()) if n in subgraph
                ]
                sdfg_copy = copy.deepcopy(sdfg)
                graph_copy = sdfg_copy.nodes()[sdfg.nodes().index(graph)]
                subgraph_copy = SubgraphView(
                    graph_copy, [graph_copy.nodes()[i] for i in graph_indices])
                expansion.sdfg_id = sdfg_copy.sdfg_id

                ##sdfg_copy.apply_transformations(MultiExpansion, states=[graph])
                #expansion = MultiExpansion(subgraph_copy)
                expansion.apply(sdfg_copy)

                subgraph_fusion = SubgraphFusion(subgraph_copy)
                if subgraph_fusion.can_be_applied(sdfg_copy, subgraph_copy):
                    return True

                stencil_tiling = StencilTiling(subgraph_copy)
                if self.allow_tiling and stencil_tiling.can_be_applied(
                        sdfg_copy, subgraph_copy):
                    return True

        else:
            subgraph_fusion = SubgraphFusion(subgraph)
            if subgraph_fusion.can_be_applied(sdfg, subgraph):
                return True

        if self.allow_tiling:
            stencil_tiling = StencilTiling(subgraph)
            if stencil_tiling.can_be_applied(sdfg, subgraph):
                return True

        return False
Example #7
def test_offsets_array():
    sdfg = dace.SDFG('mapfission_offsets2')
    sdfg.add_array('A', [20], dace.float64)
    sdfg.add_array('interim', [1], dace.float64, transient=True)
    state = sdfg.add_state()
    me, mx = state.add_map('outer', dict(i='10:20'))

    t1 = state.add_tasklet('addone', {'a'}, {'b'}, 'b = a + 1')
    interim = state.add_access('interim')
    t2 = state.add_tasklet('addtwo', {'a'}, {'b'}, 'b = a + 2')

    aread = state.add_read('A')
    awrite = state.add_write('A')
    state.add_memlet_path(aread, me, t1, dst_conn='a', memlet=dace.Memlet.simple('A', 'i'))
    state.add_edge(t1, 'b', interim, None, dace.Memlet.simple('interim', '0'))
    state.add_edge(interim, None, t2, 'a', dace.Memlet.simple('interim', '0'))
    state.add_memlet_path(t2, mx, awrite, src_conn='b', memlet=dace.Memlet.simple('A', 'i'))

    sdfg.apply_transformations(MapFission)

    dace.propagate_memlets_sdfg(sdfg)
    sdfg.validate()

    # Test
    A = np.random.rand(20)
    expected = A.copy()
    expected[10:] += 3
    A_cpy = A.copy()
    csdfg = sdfg.compile()
    csdfg(A=A_cpy)
    del csdfg
    print(np.linalg.norm(A_cpy))
    print(np.linalg.norm(expected))
    assert np.allclose(A_cpy, expected)

    subgraph = SubgraphView(sdfg.nodes()[0], sdfg.nodes()[0].nodes())
    sf = SubgraphFusion()
    sf.setup_match(subgraph)
    assert sf.can_be_applied(sdfg, subgraph)
    fusion(sdfg, sdfg.nodes()[0], None)
    A_cpy = A.copy()
    csdfg = sdfg.compile()
    csdfg(A=A_cpy)
    assert np.allclose(A_cpy, expected)
Example #8
 def get_actions(actions, graph, match):
     subgraph_node_ids = match.subgraph.values()
     subgraph_nodes = [graph.nodes()[nid] for nid in subgraph_node_ids]
     for node in subgraph_nodes:
         version = 0
         while (node, type(match).__name__, match.expr_index,
                version) in actions.keys():
             version += 1
         actions[(node, type(match).__name__, match.expr_index,
                  version)] = match
     subgraph = SubgraphView(graph, subgraph_nodes)
     for edge in subgraph.edges():
         version = 0
         while (edge, type(match).__name__, match.expr_index,
                version) in actions.keys():
             version += 1
         actions[(edge, type(match).__name__, match.expr_index,
                  version)] = match
     return actions
Example #9
def state_fission(sdfg: SDFG, subgraph: graph.SubgraphView) -> SDFGState:
    '''
    Given a subgraph, adds a new SDFG state before the state that contains it,
    removes the subgraph from the original state, and connects the two states.
    :param sdfg: The SDFG that contains the subgraph's state.
    :param subgraph: The subgraph to remove.
    :return: The newly created SDFG state.
    '''

    state: SDFGState = subgraph.graph
    newstate = sdfg.add_state_before(state)

    # Save edges before removing nodes
    orig_edges = subgraph.edges()

    # Mark boundary access nodes to keep after fission
    nodes_to_remove = set(subgraph.nodes())
    boundary_nodes = [
        n for n in subgraph.nodes()
        if len(state.out_edges(n)) > len(subgraph.out_edges(n))
    ] + [
        n for n in subgraph.nodes()
        if len(state.in_edges(n)) > len(subgraph.in_edges(n))
    ]

    # Make dictionary of nodes to add to new state
    new_nodes = {n: n for n in subgraph.nodes()}
    new_nodes.update({b: copy.deepcopy(b) for b in boundary_nodes})

    nodes_to_remove -= set(boundary_nodes)
    state.remove_nodes_from(nodes_to_remove)

    for n in new_nodes.values():
        if isinstance(n, nodes.NestedSDFG):
            # Set the new parent state
            n.sdfg.parent = newstate

    newstate.add_nodes_from(new_nodes.values())

    for e in orig_edges:
        newstate.add_edge(new_nodes[e.src], e.src_conn, new_nodes[e.dst],
                          e.dst_conn, e.data)

    return newstate
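A short usage sketch under the same API (the node selection and names below are hypothetical): split an intermediate access node and its producer into a new predecessor state.

# `tmp_node` is an AccessNode in `state`; the tasklet writing it moves with it.
producer = next(e.src for e in state.in_edges(tmp_node))
new_state = state_fission(sdfg, SubgraphView(state, [producer, tmp_node]))
sdfg.validate()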
Example #10
def _test_quantitatively(sdfg, graph):
    A = np.random.rand(N.get(), M.get(), O.get()).astype(np.float64)
    B = np.random.rand(N.get(), M.get(), O.get()).astype(np.float64)
    C1 = np.zeros([N.get(), M.get(), O.get()], dtype=np.float64)
    C2 = np.zeros([N.get(), M.get(), O.get()], dtype=np.float64)

    sdfg.validate()
    csdfg = sdfg.compile()
    csdfg(A=A, B=B, C=C1, N=N, M=M, O=O)
    del csdfg

    subgraph = SubgraphView(graph, graph.nodes())
    sf = SubgraphFusion(subgraph)
    assert sf.can_be_applied(sdfg, subgraph)

    fusion(sdfg, graph)
    csdfg = sdfg.compile()
    csdfg(A=A, B=B, C=C2, N=N, M=M, O=O)
    del csdfg

    assert np.allclose(C1, C2)
    print('PASS')
Example #11
def _test_quantitatively(sdfg, graph):

    A = np.random.rand(N.get()).astype(np.float64)
    B = np.random.rand(M.get()).astype(np.float64)
    C = np.random.rand(O.get()).astype(np.float64)
    out1_base = np.ndarray((N.get(), M.get()), np.float64)
    out2_base = np.ndarray((1), np.float64)
    out3_base = np.ndarray((N.get(), M.get(), O.get()), np.float64)
    out1 = np.ndarray((N.get(), M.get()), np.float64)
    out2 = np.ndarray((1), np.float64)
    out3 = np.ndarray((N.get(), M.get(), O.get()), np.float64)
    csdfg = sdfg.compile()
    csdfg(A=A,
          B=B,
          C=C,
          out1=out1_base,
          out2=out2_base,
          out3=out3_base,
          N=N,
          M=M,
          O=O)
    del csdfg

    expand_reduce(sdfg, graph)
    expand_maps(sdfg, graph)
    subgraph = SubgraphView(graph, [node for node in graph.nodes()])
    sf = SubgraphFusion()
    sf.setup_match(subgraph)
    assert sf.can_be_applied(sdfg, subgraph)
    fusion(sdfg, graph)
    sdfg.validate()
    csdfg = sdfg.compile()
    csdfg(A=A, B=B, C=C, out1=out1, out2=out2, out3=out3, N=N, M=M, O=O)
    del csdfg

    assert np.allclose(out1, out1_base)
    assert np.allclose(out2, out2_base)
    assert np.allclose(out3, out3_base)
    print('PASS')
Example #12
def test_quantitatively(sdfg):
    graph = sdfg.nodes()[0]
    A = np.random.rand(N.get()).astype(np.float64)
    B = np.random.rand(N.get()).astype(np.float64)
    C1 = np.random.rand(N.get()).astype(np.float64)
    C2 = np.random.rand(N.get()).astype(np.float64)
    D1 = np.random.rand(N.get()).astype(np.float64)
    D2 = np.random.rand(N.get()).astype(np.float64)

    csdfg = sdfg.compile()
    csdfg(A=A, B=B, C=C1, D=D1, N=N)

    subgraph = SubgraphView(graph, [node for node in graph.nodes()])
    assert MultiExpansion.can_be_applied(sdfg, subgraph)
    MultiExpansion(subgraph).apply(sdfg)
    assert SubgraphFusion.can_be_applied(sdfg, subgraph)
    SubgraphFusion(subgraph).apply(sdfg)

    csdfg = sdfg.compile()
    csdfg(A=A, B=B, C=C2, D=D2, N=N)

    assert np.allclose(C1, C2)
    assert np.allclose(D1, D2)
Example #13
def test_p1():
    sdfg = disjoint_test_1.to_sdfg()
    sdfg.simplify()
    state = sdfg.nodes()[0]
    assert len(sdfg.nodes()) == 1
    A = np.random.rand(M.get(), 2).astype(np.float64)
    A1 = A.copy()
    A2 = A.copy()

    csdfg = sdfg.compile()
    csdfg(A=A1, N=N, M=M)
    del csdfg

    subgraph = SubgraphView(state, state.nodes())
    sf = SubgraphFusion(subgraph)
    assert sf.can_be_applied(sdfg, subgraph)
    sf.apply(sdfg)

    csdfg = sdfg.compile()
    csdfg(A=A2, M=M)
    del csdfg

    assert np.allclose(A1, A2)
Example #14
def test_quantitatively(sdfg, graph):

    A = np.random.rand(N.get()).astype(np.float64)
    B = np.random.rand(M.get()).astype(np.float64)
    C = np.random.rand(O.get()).astype(np.float64)
    out1_base = np.ndarray((N.get(), M.get()), np.float64)
    out2_base = np.ndarray((1), np.float64)
    out3_base = np.ndarray((N.get(), M.get(), O.get()), np.float64)
    out1 = np.ndarray((N.get(), M.get()), np.float64)
    out2 = np.ndarray((1), np.float64)
    out3 = np.ndarray((N.get(), M.get(), O.get()), np.float64)
    csdfg = sdfg.compile()
    csdfg(A=A,
          B=B,
          C=C,
          out1=out1_base,
          out2=out2_base,
          out3=out3_base,
          N=N,
          M=M,
          O=O)

    expand_reduce(sdfg, graph)
    expand_maps(sdfg, graph)
    sgf = SubgraphFusion()
    matcher = sgf.match(sdfg,
                        SubgraphView(graph, [node for node in graph.nodes()]))
    assert matcher
    fusion(sdfg, graph)
    sdfg.validate()
    csdfg = sdfg.compile()
    csdfg(A=A, B=B, C=C, out1=out1, out2=out2, out3=out3, N=N, M=M, O=O)

    assert np.allclose(out1, out1_base)
    assert np.allclose(out2, out2_base)
    assert np.allclose(out3, out3_base)
    print('PASS')
Example #15
    def match(sdfg: SDFG, subgraph: SubgraphView) -> bool:
        ### get the outermost (highest-scope) maps of the subgraph
        # grab the first node and see whether all nodes are in the same graph
        # (or in nested SDFGs therein)

        graph = subgraph.graph

        for node in subgraph.nodes():
            if node not in graph.nodes():
                return False

        # next, get all the maps
        maps = helpers.get_highest_scope_maps(sdfg, graph, subgraph)
        brng = helpers.common_map_base_ranges(maps)

        # if one map or fewer was found -> fail
        if len(maps) <= 1:
            return False

        # see whether they have common parameters; if not -> fail
        if len(brng) == 0:
            return False

        return True
Example #16
 def test_simple_sdfg_program(self):
     sdfg, state, t, me, mx = create_sdfg()
     nest_state_subgraph(sdfg, state, SubgraphView(state, state.nodes()))
     sdfg.validate()
Example #17
    def apply(self, sdfg):
        state: SDFGState = sdfg.nodes()[self.state_id]
        nsdfg_node = state.nodes()[self.subgraph[InlineSDFG._nested_sdfg]]
        nsdfg: SDFG = nsdfg_node.sdfg
        nstate: SDFGState = nsdfg.nodes()[0]

        nsdfg_scope_entry = state.entry_node(nsdfg_node)
        nsdfg_scope_exit = (state.exit_node(nsdfg_scope_entry)
                            if nsdfg_scope_entry is not None else None)

        #######################################################
        # Collect and update top-level SDFG metadata

        # Global/init/exit code
        for loc, code in nsdfg.global_code.items():
            sdfg.append_global_code(code.code, loc)
        for loc, code in nsdfg.init_code.items():
            sdfg.append_init_code(code.code, loc)
        for loc, code in nsdfg.exit_code.items():
            sdfg.append_exit_code(code.code, loc)

        # Constants
        for cstname, cstval in nsdfg.constants.items():
            if cstname in sdfg.constants:
                if cstval != sdfg.constants[cstname]:
                    warnings.warn('Constant value mismatch for "%s" while '
                                  'inlining SDFG. Inner = %s != %s = outer' %
                                  (cstname, cstval, sdfg.constants[cstname]))
            else:
                sdfg.add_constant(cstname, cstval)

        # Find original source/destination edges (there is only one edge per
        # connector, according to match)
        inputs: Dict[str, MultiConnectorEdge] = {}
        outputs: Dict[str, MultiConnectorEdge] = {}
        input_set: Dict[str, str] = {}
        output_set: Dict[str, str] = {}
        for e in state.in_edges(nsdfg_node):
            inputs[e.dst_conn] = e
            input_set[e.data.data] = e.dst_conn
        for e in state.out_edges(nsdfg_node):
            outputs[e.src_conn] = e
            output_set[e.data.data] = e.src_conn

        # All transients become transients of the parent (if data already
        # exists, find new name)
        # Mapping from nested transient name to top-level name
        transients: Dict[str, str] = {}
        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode):
                datadesc = nsdfg.arrays[node.data]
                if node.data not in transients and datadesc.transient:
                    name = sdfg.add_datadesc('%s_%s' %
                                             (nsdfg.label, node.data),
                                             datadesc,
                                             find_new_name=True)
                    transients[node.data] = name

        # All transients of edges between code nodes are also added to parent
        for edge in nstate.edges():
            if (isinstance(edge.src, nodes.CodeNode)
                    and isinstance(edge.dst, nodes.CodeNode)):
                datadesc = nsdfg.arrays[edge.data.data]
                if edge.data.data not in transients and datadesc.transient:
                    name = sdfg.add_datadesc('%s_%s' %
                                             (nsdfg.label, edge.data.data),
                                             datadesc,
                                             find_new_name=True)
                    transients[edge.data.data] = name

        # Collect nodes to add to top-level graph
        new_incoming_edges: Dict[nodes.Node, MultiConnectorEdge] = {}
        new_outgoing_edges: Dict[nodes.Node, MultiConnectorEdge] = {}

        source_accesses = set()
        sink_accesses = set()
        for node in nstate.source_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients):
                new_incoming_edges[node] = inputs[node.data]
                source_accesses.add(node)
        for node in nstate.sink_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients):
                new_outgoing_edges[node] = outputs[node.data]
                sink_accesses.add(node)

        #######################################################
        # Add nested SDFG into top-level SDFG

        # Add nested nodes into original state
        subgraph = SubgraphView(nstate, [
            n for n in nstate.nodes()
            if n not in (source_accesses | sink_accesses)
        ])
        state.add_nodes_from(subgraph.nodes())
        for edge in subgraph.edges():
            state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

        #######################################################
        # Replace data on inlined SDFG nodes/edges

        # Replace symbols using invocation symbol mapping
        # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace(symname, '__dacesym_' + symname)
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace('__dacesym_' + symname, symvalue)

        # Replace data names with their top-level counterparts
        repldict = {}
        repldict.update(transients)
        repldict.update({
            k: v.data.data
            for k, v in itertools.chain(inputs.items(), outputs.items())
        })
        for node in subgraph.nodes():
            if isinstance(node, nodes.AccessNode) and node.data in repldict:
                node.data = repldict[node.data]
        for edge in subgraph.edges():
            if edge.data.data in repldict:
                edge.data.data = repldict[edge.data.data]

        #######################################################
        # Reconnect inlined SDFG

        # If a source/sink node is one of the inputs/outputs, reconnect it,
        # replacing memlets in outgoing/incoming paths
        modified_edges = set()
        modified_edges |= self._modify_memlet_path(new_incoming_edges, nstate,
                                                   state, True)
        modified_edges |= self._modify_memlet_path(new_outgoing_edges, nstate,
                                                   state, False)

        # Modify all other internal edges pertaining to input/output nodes
        for node in subgraph.nodes():
            if isinstance(node, nodes.AccessNode):
                if node.data in input_set or node.data in output_set:
                    if node.data in input_set:
                        outer_edge = inputs[input_set[node.data]]
                    else:
                        outer_edge = outputs[output_set[node.data]]

                    for edge in state.all_edges(node):
                        if (edge not in modified_edges
                                and edge.data.data == node.data):
                            for e in state.memlet_tree(edge):
                                if e.data.data == node.data:
                                    e._data = helpers.unsqueeze_memlet(
                                        e.data, outer_edge.data)

        # If source/sink node is not connected to a source/destination access
        # node, and the nested SDFG is in a scope, connect to scope with empty
        # memlets
        if nsdfg_scope_entry is not None:
            for node in subgraph.nodes():
                if state.in_degree(node) == 0:
                    state.add_edge(nsdfg_scope_entry, None, node, None,
                                   Memlet())
                if state.out_degree(node) == 0:
                    state.add_edge(node, None, nsdfg_scope_exit, None,
                                   Memlet())

        # Replace nested SDFG parents with new SDFG
        for node in nstate.nodes():
            if isinstance(node, nodes.NestedSDFG):
                node.sdfg.parent = state
                node.sdfg.parent_sdfg = sdfg
                node.sdfg.parent_nsdfg_node = node

        # Remove all unused external inputs/output memlet paths, as well as
        # resulting isolated nodes
        removed_in_edges = self._remove_edge_path(state,
                                                  inputs,
                                                  set(inputs.keys()) -
                                                  source_accesses,
                                                  reverse=True)
        removed_out_edges = self._remove_edge_path(state,
                                                   outputs,
                                                   set(outputs.keys()) -
                                                   sink_accesses,
                                                   reverse=False)

        # Re-add in/out edges to first/last nodes in subgraph
        order = [
            x for x in nx.topological_sort(nstate._nx)
            if isinstance(x, nodes.AccessNode)
        ]
        for edge in removed_in_edges:
            # Find first access node that refers to this edge
            node = next(n for n in order if n.data == edge.data.data)
            state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn,
                           edge.data)
        for edge in removed_out_edges:
            # Find last access node that refers to this edge
            node = next(n for n in reversed(order) if n.data == edge.data.data)
            state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

        #######################################################
        # Remove nested SDFG node
        state.remove_node(nsdfg_node)
Example #18
 def subgraph_view(self, sdfg: SDFG) -> SubgraphView:
     graph = sdfg.sdfg_list[self.sdfg_id]
     if self.state_id != -1:
         graph = graph.node(self.state_id)
     return SubgraphView(graph, [graph.node(idx) for idx in self.subgraph])
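A brief sketch of how the stored indices map back to graph objects (illustrative, based on the method above): sdfg_id indexes sdfg.sdfg_list, state_id selects a state within that SDFG (or -1 for a multi-state subgraph), and self.subgraph holds node indices inside the chosen graph.

# Hypothetical round trip for a subgraph transformation instance `xform`.
view = xform.subgraph_view(sdfg)
assert all(view.graph.node_id(n) in xform.subgraph for n in view.nodes())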
Example #19
def get_transformations(sdfg_json, selected_elements):
    # We lazy import DaCe, not to break cyclic imports, but to avoid any large
    # delays when booting in daemon mode.
    from dace.transformation.optimizer import SDFGOptimizer
    from dace.sdfg.graph import SubgraphView

    old_meta = utils.disable_save_metadata()

    loaded = utils.load_sdfg_from_json(sdfg_json)
    if loaded['error'] is not None:
        return loaded['error']
    sdfg = loaded['sdfg']

    optimizer = SDFGOptimizer(sdfg)
    matches = optimizer.get_pattern_matches()

    transformations = []
    docstrings = {}
    for transformation in matches:
        transformations.append(transformation.to_json())
        docstrings[type(transformation).__name__] = transformation.__doc__

    selected_states = [
        utils.sdfg_find_state_from_element(sdfg, n) for n in selected_elements
        if n['type'] == 'state'
    ]
    selected_nodes = [
        utils.sdfg_find_node_from_element(sdfg, n) for n in selected_elements
        if n['type'] == 'node'
    ]
    selected_sdfg_ids = list(set(elem['sdfgId'] for elem in selected_elements))
    selected_sdfg = sdfg
    if len(selected_sdfg_ids) > 1:
        return {
            'transformations': transformations,
            'docstrings': docstrings,
            'warnings': 'More than one SDFG selected, ignoring subgraph',
        }
    elif len(selected_sdfg_ids) == 1:
        selected_sdfg = sdfg.sdfg_list[selected_sdfg_ids[0]]

    subgraph = None
    if len(selected_states) > 0:
        subgraph = SubgraphView(selected_sdfg, selected_states)
    else:
        violated = False
        state = None
        for node in selected_nodes:
            if state is None:
                state = node.state
            elif state != node.state:
                violated = True
                break
        if not violated and state is not None:
            subgraph = SubgraphView(state, selected_nodes)

    if subgraph is not None:
        extensions = SubgraphTransformation.extensions()
        for xform in extensions:
            xform_data = extensions[xform]
            if ('singlestate' in xform_data and xform_data['singlestate']
                    and len(selected_states) > 0):
                continue
            xform_obj = xform(subgraph)
            if xform_obj.can_be_applied(selected_sdfg, subgraph):
                transformations.append(xform_obj.to_json())
                docstrings[xform.__name__] = xform_obj.__doc__

    utils.restore_save_metadata(old_meta)
    return {
        'transformations': transformations,
        'docstrings': docstrings,
    }
Example #20
    def apply(self, sdfg: SDFG):
        graph = sdfg.node(self.state_id)
        map_exit = graph.node(self.subgraph[AccumulateTransient.map_exit])
        outer_map_exit = graph.node(
            self.subgraph[AccumulateTransient.outer_map_exit])

        # Avoid import loop
        from dace.transformation.dataflow.local_storage import OutLocalStorage

        array_identity_dict = self.array_identity_dict

        # Choose array
        array = self.array
        if array is not None and len(array) != 0:
            array_identity_dict[array] = self.identity
        elif ((array is None or len(array) == 0)
              and len(array_identity_dict) == 0):
            array = next(e.data.data
                         for e in graph.edges_between(map_exit, outer_map_exit)
                         if e.data.wcr is not None)
            array_identity_dict[array] = self.identity

        transients: Dict[str, Any] = {}
        for array, identity in array_identity_dict.items():
            data_node: nodes.AccessNode = OutLocalStorage.apply_to(
                sdfg,
                dict(array=array, prefix=self.prefix),
                verify=False,
                save=False,
                node_a=map_exit,
                node_b=outer_map_exit)

            transients[data_node.data] = identity

            if identity is None:
                warnings.warn(
                    'AccumulateTransient did not properly initialize '
                    'newly-created transient!')
                return

        sdfg_state: SDFGState = sdfg.node(self.state_id)

        map_entry = sdfg_state.entry_node(map_exit)

        nested_sdfg: nodes.NestedSDFG = nest_state_subgraph(
            sdfg=sdfg,
            state=sdfg_state,
            subgraph=SubgraphView(
                sdfg_state, {map_entry, map_exit}
                | sdfg_state.all_nodes_between(map_entry, map_exit)))

        nested_sdfg_state: SDFGState = nested_sdfg.sdfg.nodes()[0]

        init_state = nested_sdfg.sdfg.add_state_before(nested_sdfg_state)

        for data_name, identity in transients.items():
            temp_array: Array = sdfg.arrays[data_name]

            init_state.add_mapped_tasklet(
                name='acctrans_init',
                map_ranges={
                    '_o%d' % i: '0:%s' % symbolic.symstr(d)
                    for i, d in enumerate(temp_array.shape)
                },
                inputs={},
                code='out = %s' % identity,
                outputs={
                    'out':
                    dace.Memlet.simple(
                        data=data_name,
                        subset_str=','.join([
                            '_o%d' % i for i, _ in enumerate(temp_array.shape)
                        ]))
                },
                external_edges=True)

        # TODO: use trivial map elimination here, once it is merged, to remove the map if it has trivial ranges

        return nested_sdfg
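A hedged usage note: in practice this transformation is usually driven through the pattern-matching API; the option name below follows the `identity` property read in the method above, and the value is illustrative.

# Apply AccumulateTransient across the SDFG with an explicit identity value.
sdfg.apply_transformations(AccumulateTransient, options={'identity': 0})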
Example #21
def nest_state_subgraph(sdfg: SDFG,
                        state: SDFGState,
                        subgraph: SubgraphView,
                        name: Optional[str] = None,
                        full_data: bool = False) -> nodes.NestedSDFG:
    """ Turns a state subgraph into a nested SDFG. Operates in-place.
        :param sdfg: The SDFG containing the state subgraph.
        :param state: The state containing the subgraph.
        :param subgraph: Subgraph to nest.
        :param name: An optional name for the nested SDFG.
        :param full_data: If True, nests entire input/output data.
        :return: The nested SDFG node.
        :raise KeyError: Some or all nodes in the subgraph are not located in
                         this state, or the state does not belong to the given
                         SDFG.
        :raise ValueError: The subgraph is contained in more than one scope.
    """
    if state.parent != sdfg:
        raise KeyError('State does not belong to given SDFG')
    if subgraph.graph != state:
        raise KeyError('Subgraph does not belong to given state')

    # Find the top-level scope
    scope_tree = state.scope_tree()
    scope_dict = state.scope_dict()
    scope_dict_children = state.scope_dict(True)
    top_scopenode = -1  # Initialized to -1 since "None" already means top-level

    for node in subgraph.nodes():
        if node not in scope_dict:
            raise KeyError('Node not found in state')

        # If scope entry/exit, ensure entire scope is in subgraph
        if isinstance(node, nodes.EntryNode):
            scope_nodes = scope_dict_children[node]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (entry)')
        elif isinstance(node, nodes.ExitNode):
            entry = state.entry_node(node)
            scope_nodes = scope_dict_children[entry] + [entry]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (exit)')

        scope_node = scope_dict[node]
        if scope_node not in subgraph.nodes():
            if top_scopenode != -1 and top_scopenode != scope_node:
                raise ValueError(
                    'Subgraph is contained in more than one scope')
            top_scopenode = scope_node

    scope = scope_tree[top_scopenode]
    ###

    # Collect inputs and outputs of the nested SDFG
    inputs: List[MultiConnectorEdge] = []
    outputs: List[MultiConnectorEdge] = []
    for node in subgraph.source_nodes():
        inputs.extend(state.in_edges(node))
    for node in subgraph.sink_nodes():
        outputs.extend(state.out_edges(node))

    # Collect transients not used outside of subgraph (will be removed from
    # the top-level graph)
    data_in_subgraph = set(n.data for n in subgraph.nodes()
                           if isinstance(n, nodes.AccessNode))
    # Find other occurrences in SDFG
    other_nodes = set(
        n.data for s in sdfg.nodes() for n in s.nodes()
        if isinstance(n, nodes.AccessNode) and n not in subgraph.nodes())
    subgraph_transients = set()
    for data in data_in_subgraph:
        datadesc = sdfg.arrays[data]
        if datadesc.transient and data not in other_nodes:
            subgraph_transients.add(data)

    # All transients of edges between code nodes are also added to nested graph
    for edge in subgraph.edges():
        if (isinstance(edge.src, nodes.CodeNode)
                and isinstance(edge.dst, nodes.CodeNode)):
            subgraph_transients.add(edge.data.data)

    # Collect data used in access nodes within subgraph (will be referenced in
    # full upon nesting)
    input_arrays = set()
    output_arrays = set()
    for node in subgraph.nodes():
        if (isinstance(node, nodes.AccessNode)
                and node.data not in subgraph_transients):
            if state.out_degree(node) > 0:
                input_arrays.add(node.data)
            if state.in_degree(node) > 0:
                output_arrays.add(node.data)

    # Create the nested SDFG
    nsdfg = SDFG(name or 'nested_' + state.label)

    # Transients are added to the nested graph as-is
    for name in subgraph_transients:
        nsdfg.add_datadesc(name, sdfg.arrays[name])

    # Input/output data that are not source/sink nodes are added to the graph
    # as non-transients
    for name in (input_arrays | output_arrays):
        datadesc = copy.deepcopy(sdfg.arrays[name])
        datadesc.transient = False
        nsdfg.add_datadesc(name, datadesc)

    # Connected source/sink nodes outside subgraph become global data
    # descriptors in nested SDFG
    input_names = []
    output_names = []
    for edge in inputs:
        if edge.data.data is None:  # Skip edges with an empty memlet
            continue
        name = '__in_' + edge.data.data
        datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
        datadesc.transient = False
        if not full_data:
            datadesc.shape = edge.data.subset.size()
        input_names.append(
            nsdfg.add_datadesc(name, datadesc, find_new_name=True))
    for edge in outputs:
        if edge.data.data is None:  # Skip edges with an empty memlet
            continue
        name = '__out_' + edge.data.data
        datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
        datadesc.transient = False
        if not full_data:
            datadesc.shape = edge.data.subset.size()
        output_names.append(
            nsdfg.add_datadesc(name, datadesc, find_new_name=True))
    ###################

    # Add scope symbols to the nested SDFG
    for v in scope.defined_vars:
        if v in sdfg.symbols:
            sym = sdfg.symbols[v]
            nsdfg.add_symbol(v, sym.dtype)

    # Create nested state
    nstate = nsdfg.add_state()

    # Add subgraph nodes and edges to nested state
    nstate.add_nodes_from(subgraph.nodes())
    for e in subgraph.edges():
        nstate.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, e.data)

    # Modify nested SDFG parents in subgraph
    for node in subgraph.nodes():
        if isinstance(node, nodes.NestedSDFG):
            node.sdfg.parent = nstate
            node.sdfg.parent_sdfg = nsdfg

    # Add access nodes and edges as necessary
    edges_to_offset = []
    for name, edge in zip(input_names, inputs):
        node = nstate.add_read(name)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = name
        edges_to_offset.append((edge,
                                nstate.add_edge(node, None, edge.dst,
                                                edge.dst_conn, new_edge)))
    for name, edge in zip(output_names, outputs):
        node = nstate.add_write(name)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = name
        edges_to_offset.append((edge,
                                nstate.add_edge(edge.src, edge.src_conn, node,
                                                None, new_edge)))

    # Offset memlet paths inside nested SDFG according to subsets
    for original_edge, new_edge in edges_to_offset:
        for edge in nstate.memlet_tree(new_edge):
            edge.data.data = new_edge.data.data
            if not full_data:
                edge.data.subset.offset(original_edge.data.subset, True)

    # Add nested SDFG node to the input state
    nested_sdfg = state.add_nested_sdfg(nsdfg, None,
                                        set(input_names) | input_arrays,
                                        set(output_names) | output_arrays)

    # Reconnect memlets to nested SDFG
    for name, edge in zip(input_names, inputs):
        if full_data:
            data = Memlet.from_array(edge.data.data,
                                     sdfg.arrays[edge.data.data])
        else:
            data = edge.data
        state.add_edge(edge.src, edge.src_conn, nested_sdfg, name, data)
    for name, edge in zip(output_names, outputs):
        if full_data:
            data = Memlet.from_array(edge.data.data,
                                     sdfg.arrays[edge.data.data])
        else:
            data = edge.data
        state.add_edge(nested_sdfg, name, edge.dst, edge.dst_conn, data)

    # Connect access nodes to internal input/output data as necessary
    entry = scope.entry
    exit = scope.exit
    for name in input_arrays:
        node = state.add_read(name)
        if entry is not None:
            state.add_nedge(entry, node, EmptyMemlet())
        state.add_edge(node, None, nested_sdfg, name,
                       Memlet.from_array(name, sdfg.arrays[name]))
    for name in output_arrays:
        node = state.add_write(name)
        if exit is not None:
            state.add_nedge(node, exit, EmptyMemlet())
        state.add_edge(nested_sdfg, name, node, None,
                       Memlet.from_array(name, sdfg.arrays[name]))

    # Remove subgraph nodes from graph
    state.remove_nodes_from(subgraph.nodes())

    # Remove subgraph transients from top-level graph
    for transient in subgraph_transients:
        del sdfg.arrays[transient]

    return nested_sdfg
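A compact usage sketch in the style of the tests in this collection (the map entry/exit names are illustrative): nesting a complete map scope, which is the pattern used by AccumulateTransient above.

# me/mx are a MapEntry/MapExit pair inside `state`.
scope_nodes = {me, mx} | state.all_nodes_between(me, mx)
nsdfg_node = nest_state_subgraph(sdfg, state, SubgraphView(state, scope_nodes))
sdfg.validate()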
Example #22
    def can_be_applied(sdfg: SDFG, subgraph: SubgraphView) -> bool:
        '''
        Fusible if
        1. Maps have the same access sets and ranges in order
        2. Any nodes in between two maps are AccessNodes only, without WCR
           There is at most one AccessNode only on a path between two maps,
           no other nodes are allowed
        3. The exiting memlets' subsets to an intermediate edge must cover
           the respective incoming memlets' subset into the next map
        '''
        # get graph
        graph = subgraph.graph
        for node in subgraph.nodes():
            if node not in graph.nodes():
                return False

        # next, get all the maps
        map_entries = helpers.get_highest_scope_maps(sdfg, graph, subgraph)
        map_exits = [graph.exit_node(map_entry) for map_entry in map_entries]
        maps = [map_entry.map for map_entry in map_entries]

        # 1. check whether all map ranges and indices are the same
        if len(maps) <= 1:
            return False
        base_map = maps[0]
        for map in maps:
            if map.get_param_num() != base_map.get_param_num():
                return False
            if not all(
                [p1 == p2 for (p1, p2) in zip(map.params, base_map.params)]):
                return False
            if not map.range == base_map.range:
                return False

        # 1.1 check whether all map entries have the same schedule
        schedule = map_entries[0].schedule
        if not all([entry.schedule == schedule for entry in map_entries]):
            return False

        # 2. check intermediate feasibility
        # see map_fusion.py for similar checks
        # we are being more relaxed here

        # 2.1 do some preparation work first:
        # calculate all out_nodes and intermediate_nodes
        # definition see in apply()
        intermediate_nodes = set()
        out_nodes = set()
        for map_entry, map_exit in zip(map_entries, map_exits):
            for edge in graph.out_edges(map_exit):
                current_node = edge.dst
                if len(graph.out_edges(current_node)) == 0:
                    out_nodes.add(current_node)
                else:
                    for dst_edge in graph.out_edges(current_node):
                        if dst_edge.dst in map_entries:
                            intermediate_nodes.add(current_node)
                        else:
                            out_nodes.add(current_node)

        # 2.2 topological feasibility:
        # For each intermediate and out node: must never reach any map
        # entry if it is not connected to map entry immediately
        visited = set()

        # for memoization purposes
        def visit_descendants(graph, node, visited, map_entries):
            # if we have already been at this node
            if node in visited:
                return True
            # not necessary to add if there aren't any other in connections
            if len(graph.in_edges(node)) > 1:
                visited.add(node)
            for oedge in graph.out_edges(node):
                if not visit_descendants(graph, oedge.dst, visited,
                                         map_entries):
                    return False
            return True

        for node in intermediate_nodes | out_nodes:
            # these nodes must not lead to a map entry
            nodes_to_check = set()
            for oedge in graph.out_edges(node):
                if oedge.dst not in map_entries:
                    nodes_to_check.add(oedge.dst)

            for forbidden_node in nodes_to_check:
                if not visit_descendants(graph, forbidden_node, visited,
                                         map_entries):
                    return False

        # 2.3 memlet feasibility
        # For each intermediate node, look at whether inner adjacent
        # memlets of the exiting map cover inner adjacent memlets
        # of the next entering map.
        # We also check for any WCRs on the fly.

        for node in intermediate_nodes:
            upper_subsets = set()
            lower_subsets = set()
            # First, determine which dimensions of the memlet ranges
            # change with the map; we do not need to care about the other dimensions.
            total_dims = len(sdfg.data(node.data).shape)
            dims_to_discard = SubgraphFusion.get_invariant_dimensions(
                sdfg, graph, map_entries, map_exits, node)

            # find upper_subsets
            for in_edge in graph.in_edges(node):
                # first check for WCRs
                if in_edge.data.wcr:
                    return False
                if in_edge.src in map_exits:
                    edge = graph.memlet_path(in_edge)[-2]
                    subset_to_add = dcpy(edge.data.subset\
                                         if edge.data.data == node.data\
                                         else edge.data.other_subset)
                    subset_to_add.pop(dims_to_discard)
                    upper_subsets.add(subset_to_add)
                else:
                    raise NotImplementedError("Nodes between two maps to be"
                                              "fused with *incoming* edges"
                                              "from outside the maps are not"
                                              "allowed yet.")

            # find lower_subsets
            for out_edge in graph.out_edges(node):
                if out_edge.dst in map_entries:
                    # cannot use the memlet tree here as there could be more
                    # than one succeeding map. Do it manually
                    for oedge in graph.out_edges(out_edge.dst):
                        if oedge.src_conn[3:] == out_edge.dst_conn[2:]:
                            subset_to_add = dcpy(oedge.data.subset \
                                                 if oedge.data.data == node.data \
                                                 else oedge.data.other_subset)
                            subset_to_add.pop(dims_to_discard)
                            lower_subsets.add(subset_to_add)

            upper_iter = iter(upper_subsets)
            union_upper = next(upper_iter)

            # TODO: add this check at a later point
            # We assume that upper_subsets for each data array
            # are contiguous
            # or do the full check if possible (intersection needed)
            '''
            # check whether subsets in upper_subsets are adjacent.
            # this is a requriement for the current implementation
            #try:
            # O(n^2*|dims|) but very small amount of subsets anyway
            try:
                for dim in range(total_dims - len(dims_to_discard)):
                    ordered_list = [(-1,-1,-1)]
                    for upper_subset in upper_subsets:
                        lo = upper_subset[dim][0]
                        hi = upper_subset[dim][1]
                        for idx,element in enumerate(ordered_list):
                            if element[0] <= lo and element[1] >= hi:
                                break
                            if element[0] > lo:
                                ordered_list.insert(idx, (lo,hi))
                    ordered_list.pop(0)


                    highest = ordered_list[0][1]
                    for i in range(len(ordered_list)):
                        if i < len(ordered_list)-1:
                            current_range = ordered_list[i]
                            if current_range[1] > highest:
                                hightest = current_range[1]
                            next_range = ordered_list[i+1]
                            if highest < next_range[0] - 1:
                                return False
            except TypeError:
                #return False
            '''
            # FORNOW: just emit a warning if unsure
            for lower_subset in lower_subsets:
                covers = False
                for upper_subset in upper_subsets:
                    if upper_subset.covers(lower_subset):
                        covers = True
                        break
                if not covers:
                    warnings.warn(
                        f"WARNING: For node {node}, please check assure that"
                        "incoming memlets cover outgoing ones. Ambiguous check (WIP)."
                    )

            # now take union of upper subsets
            for subs in upper_iter:
                union_upper = subsets.union(union_upper, subs)
                if not union_upper:
                    # something went wrong using union -- we'd rather abort
                    return False

            # finally check coverage
            for lower_subset in lower_subsets:
                if not union_upper.covers(lower_subset):
                    return False

        return True
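To make the three conditions concrete, a minimal fusible pattern might look like the sketch below (program and symbol names are illustrative, and a simplification pass may be needed to bring both maps into one state before checking): two maps with identical ranges communicating through a single intermediate access node without WCR.

import dace

N = dace.symbol('N')

@dace.program
def fusible(A: dace.float64[N], B: dace.float64[N]):
    tmp = dace.define_local([N], dace.float64)
    for i in dace.map[0:N]:
        tmp[i] = A[i] + 1
    for i in dace.map[0:N]:
        B[i] = 2 * tmp[i]

sdfg = fusible.to_sdfg()
sdfg.simplify()
state = sdfg.nodes()[0]
subgraph = SubgraphView(state, state.nodes())
# For this pattern, SubgraphFusion's can_be_applied is expected to hold.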
Example #23
    def apply(self, sdfg: SDFG):
        subgraph = self.subgraph_view(sdfg)

        entry_states_in, entry_states_out = self.get_entry_states(
            sdfg, subgraph)
        _, exit_states_out = self.get_exit_states(sdfg, subgraph)

        entry_state_in = entry_states_in.pop()
        entry_state_out = entry_states_out.pop() \
            if len(entry_states_out) > 0 else None
        exit_state_out = exit_states_out.pop() \
            if len(exit_states_out) > 0 else None

        launch_state = None
        entry_guard_state = None
        exit_guard_state = None

        # generate entry guard state if needed
        if self.include_in_assignment and entry_state_out is not None:
            entry_edge = sdfg.edges_between(entry_state_out, entry_state_in)[0]
            if len(entry_edge.data.assignments) > 0:
                entry_guard_state = sdfg.add_state(
                    label='{}kernel_entry_guard'.format(
                        self.kernel_prefix +
                        '_' if self.kernel_prefix != '' else ''))
                sdfg.add_edge(entry_state_out, entry_guard_state,
                              InterstateEdge(entry_edge.data.condition))
                sdfg.add_edge(
                    entry_guard_state, entry_state_in,
                    InterstateEdge(None, entry_edge.data.assignments))
                sdfg.remove_edge(entry_edge)

                # Update SubgraphView
                new_node_list = subgraph.nodes()
                new_node_list.append(entry_guard_state)
                subgraph = SubgraphView(sdfg, new_node_list)

                launch_state = sdfg.add_state_before(
                    entry_guard_state,
                    label='{}kernel_launch'.format(
                        self.kernel_prefix +
                        '_' if self.kernel_prefix != '' else ''))

        # generate exit guard state
        if exit_state_out is not None:
            exit_guard_state = sdfg.add_state_before(
                exit_state_out,
                label='{}kernel_exit_guard'.format(
                    self.kernel_prefix +
                    '_' if self.kernel_prefix != '' else ''))

            # Update SubgraphView
            new_node_list = subgraph.nodes()
            new_node_list.append(exit_guard_state)
            subgraph = SubgraphView(sdfg, new_node_list)

            if launch_state is None:
                launch_state = sdfg.add_state_before(
                    exit_state_out,
                    label='{}kernel_launch'.format(
                        self.kernel_prefix +
                        '_' if self.kernel_prefix != '' else ''))

        # If the launch state doesn't exist at this point, then there are no
        # other states outside of the kernel, so create a standalone launch state
        if launch_state is None:
            assert (entry_state_in is None and exit_state_out is None)
            launch_state = sdfg.add_state(label='{}kernel_launch'.format(
                self.kernel_prefix + '_' if self.kernel_prefix != '' else ''))

        # create an SDFG for the kernel and fill it with the states and edges
        # from the subgraph; the kernel SDFG will be nested at the end
        kernel_sdfg = SDFG(
            '{}kernel'.format(self.kernel_prefix +
                              '_' if self.kernel_prefix != '' else ''))

        edges = subgraph.edges()
        for edge in edges:
            kernel_sdfg.add_edge(edge.src, edge.dst, edge.data)

        # Setting entry node in nested SDFG if no entry guard was created
        if entry_guard_state is None:
            kernel_sdfg.start_state = kernel_sdfg.node_id(entry_state_in)

        for state in subgraph:
            state.parent = kernel_sdfg

        # remove the now nested nodes from the outer sdfg and make sure the
        # launch state is properly connected to remaining states
        sdfg.remove_nodes_from(subgraph.nodes())

        if entry_state_out is not None \
                and len(sdfg.edges_between(entry_state_out, launch_state)) == 0:
            sdfg.add_edge(entry_state_out, launch_state, InterstateEdge())

        if exit_state_out is not None \
                and len(sdfg.edges_between(launch_state, exit_state_out)) == 0:
            sdfg.add_edge(launch_state, exit_state_out, InterstateEdge())

        # Handle data for kernel
        kernel_data = set(node.data for state in kernel_sdfg
                          for node in state.nodes()
                          if isinstance(node, nodes.AccessNode))

        # move Streams and Register data into the nested SDFG
        # normal data will be added as kernel argument
        kernel_args = []
        for data in kernel_data:
            if (isinstance(sdfg.arrays[data], dace.data.Stream) or
                (isinstance(sdfg.arrays[data], dace.data.Array)
                 and sdfg.arrays[data].storage == StorageType.Register)):
                kernel_sdfg.add_datadesc(data, sdfg.arrays[data])
                del sdfg.arrays[data]
            else:
                copy_desc = copy.deepcopy(sdfg.arrays[data])
                copy_desc.transient = False
                copy_desc.storage = StorageType.Default
                kernel_sdfg.add_datadesc(data, copy_desc)
                kernel_args.append(data)

        # read only data will be passed as input, writeable data will be passed
        # as 'output' otherwise kernel cannot write to data
        kernel_args_read = set()
        kernel_args_write = set()
        for data in kernel_args:
            data_accesses_read_only = [
                node.access == dtypes.AccessType.ReadOnly
                for state in kernel_sdfg for node in state
                if isinstance(node, nodes.AccessNode) and node.data == data
            ]
            if all(data_accesses_read_only):
                kernel_args_read.add(data)
            else:
                kernel_args_write.add(data)

        # Kernel SDFG is complete at this point
        if self.validate:
            kernel_sdfg.validate()

        # Filling launch state with nested SDFG, map and access nodes
        map_entry, map_exit = launch_state.add_map(
            '{}kernel_launch_map'.format(
                self.kernel_prefix + '_' if self.kernel_prefix != '' else ''),
            dict(ignore='0'),
            schedule=ScheduleType.GPU_Persistent,
        )

        nested_sdfg = launch_state.add_nested_sdfg(
            kernel_sdfg,
            sdfg,
            kernel_args_read,
            kernel_args_write,
        )

        # Create and connect read only data access nodes
        for arg in kernel_args_read:
            read_node = launch_state.add_read(arg)
            launch_state.add_memlet_path(read_node,
                                         map_entry,
                                         nested_sdfg,
                                         dst_conn=arg,
                                         memlet=Memlet.from_array(
                                             arg, sdfg.arrays[arg]))

        # Create and connect writable data access nodes
        for arg in kernel_args_write:
            write_node = launch_state.add_write(arg)
            launch_state.add_memlet_path(nested_sdfg,
                                         map_exit,
                                         write_node,
                                         src_conn=arg,
                                         memlet=Memlet.from_array(
                                             arg, sdfg.arrays[arg]))

        # Transformation is done
        if self.validate:
            sdfg.validate()
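
# A hedged usage sketch for the pass above: the transformation operates on a
# SubgraphView over *states* of the SDFG. The class and module names
# (GPUPersistentKernel in dace.transformation.subgraph) and the
# constructor-style setup are assumptions about this DaCe version.
from dace.sdfg.graph import SubgraphView
from dace.transformation.subgraph import GPUPersistentKernel

def wrap_states_in_persistent_kernel(sdfg, states, prefix='persistent'):
    # Select the states to nest into a single persistent GPU kernel
    subgraph = SubgraphView(sdfg, states)
    transform = GPUPersistentKernel(subgraph)
    transform.kernel_prefix = prefix
    if transform.can_be_applied(sdfg, subgraph):
        transform.apply(sdfg)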
Example #24
0
    def apply(self, graph: SDFGState, sdfg: SDFG):
        map_exit = self.map_exit
        outer_map_exit = self.outer_map_exit

        # Choose array
        array = self.array
        if array is None or len(array) == 0:
            array = next(e.data.data
                         for e in graph.edges_between(map_exit, outer_map_exit)
                         if e.data.wcr is not None)

        # Avoid import loop
        from dace.transformation.dataflow.local_storage import OutLocalStorage

        data_node: nodes.AccessNode = OutLocalStorage.apply_to(
            sdfg,
            dict(array=array),
            verify=False,
            save=False,
            node_a=map_exit,
            node_b=outer_map_exit)

        if self.identity is None:
            warnings.warn('AccumulateTransient did not properly initialize '
                          'newly-created transient!')
            return

        sdfg_state: SDFGState = sdfg.node(self.state_id)

        map_entry = sdfg_state.entry_node(map_exit)

        nested_sdfg: NestedSDFG = nest_state_subgraph(
            sdfg=sdfg,
            state=sdfg_state,
            subgraph=SubgraphView(
                sdfg_state, {map_entry, map_exit}
                | sdfg_state.all_nodes_between(map_entry, map_exit)))

        nested_sdfg_state: SDFGState = nested_sdfg.sdfg.nodes()[0]

        init_state = nested_sdfg.sdfg.add_state_before(nested_sdfg_state)

        temp_array: Array = sdfg.arrays[data_node.data]

        init_state.add_mapped_tasklet(
            name='acctrans_init',
            map_ranges={
                '_o%d' % i: '0:%s' % symstr(d)
                for i, d in enumerate(temp_array.shape)
            },
            inputs={},
            code='out = %s' % self.identity,
            outputs={
                'out':
                dace.Memlet.simple(data=data_node.data,
                                   subset_str=','.join([
                                       '_o%d' % i
                                       for i, _ in enumerate(temp_array.shape)
                                   ]))
            },
            external_edges=True)
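
# A minimal usage sketch for the transformation above (hedged): assuming an
# SDFG state that contains nested maps with a WCR memlet between the inner and
# outer MapExit, the transformation can be requested explicitly via apply_to.
# `inner_exit` and `outer_exit` are placeholders for nodes located in such a
# state; the import path is an assumption about this DaCe version.
from dace.transformation.dataflow import AccumulateTransient

def accumulate_into_transient(sdfg, inner_exit, outer_exit, identity=0):
    # The positional options dict mirrors the OutLocalStorage.apply_to call in
    # the apply() body above; `identity` initializes the new transient.
    AccumulateTransient.apply_to(sdfg,
                                 dict(identity=identity),
                                 map_exit=inner_exit,
                                 outer_map_exit=outer_exit)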
Example #25
0
    def calculate_topology(self, subgraph):
        ''' Calculates topology information of the graph.
        self._adjacency_list: adjacency (neighbor) dict of the outermost-scope map entries
        self._source_maps: outermost-scope map entries with in-degree 0 in the subgraph / graph
        self._labels: assigns each map entry an index, ordered by (1) topological level and (2) node ID for determinism
        '''
        sdfg = self._sdfg
        graph = self._graph

        self._adjacency_list = {m: set() for m in self._map_entries}
        # helper dict mapping each map exit back to its map entry for quick lookup
        exit_nodes = {graph.exit_node(me): me for me in self._map_entries}
        if subgraph:
            proximity_in = set(ie.src for me in self._map_entries
                               for ie in graph.in_edges(me))
            proximity_out = set(ie.dst for me in exit_nodes
                                for ie in graph.out_edges(me))
            extended_subgraph = SubgraphView(
                graph,
                set(
                    itertools.chain(subgraph.nodes(), proximity_in,
                                    proximity_out)))

        for node in (extended_subgraph.nodes() if subgraph else graph.nodes()):
            if isinstance(node, nodes.AccessNode):
                adjacent_entries = set()
                for e in graph.in_edges(node):
                    if isinstance(e.src,
                                  nodes.MapExit) and e.src in exit_nodes:
                        adjacent_entries.add(exit_nodes[e.src])
                for e in graph.out_edges(node):
                    if isinstance(
                            e.dst,
                            nodes.MapEntry) and e.dst in self._map_entries:
                        adjacent_entries.add(e.dst)

                # bidirectional mapping
                for entry in adjacent_entries:
                    for other_entry in adjacent_entries:
                        if entry != other_entry:
                            self._adjacency_list[entry].add(other_entry)
                            self._adjacency_list[other_entry].add(entry)

        # get DAG children and parents
        children_dict = defaultdict(set)
        parent_dict = defaultdict(set)

        for map_entry in self._map_entries:
            map_exit = graph.exit_node(map_entry)
            for e in graph.out_edges(map_exit):
                if isinstance(e.dst, nodes.AccessNode):
                    for oe in graph.out_edges(e.dst):
                        if oe.dst in self._map_entries:
                            other_entry = oe.dst
                            children_dict[map_entry].add(other_entry)
                            parent_dict[other_entry].add(map_entry)

        # find source maps (those without parents)
        self._source_maps = [
            me for me in self._map_entries if len(parent_dict[me]) == 0
        ]
        # assign a unique id to each map entry according to topological
        # ordering. If on same level, sort according to ID for determinism

        self._labels = {}  # map -> ID
        current_id = 0
        while current_id < len(self._map_entries):
            # get the map entries whose remaining in-degree is 0
            candidates = list(me for (me, s) in parent_dict.items()
                              if len(s) == 0 and me not in self._labels)
            candidates.sort(key=lambda me: self._graph.node_id(me))
            for c in candidates:
                self._labels[c] = current_id
                current_id += 1
                # remove this candidate from each of its children's parent sets
                for c_child in children_dict[c]:
                    parent_dict[c_child].remove(c)
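
# The labeling loop above is plain Kahn-style topological ordering with a
# deterministic tie-break. A standalone sketch (plain Python, no DaCe; names
# are illustrative only):
from collections import defaultdict

def topological_labels(children):
    """children: dict mapping node -> set of child nodes (a DAG)."""
    parents = defaultdict(set)
    for node, childs in children.items():
        for c in childs:
            parents[c].add(node)
    nodes = set(children) | set(parents)
    labels, current_id = {}, 0
    while current_id < len(nodes):
        # all unlabeled nodes with no remaining parents, tie-broken by sorting
        candidates = sorted(n for n in nodes
                            if n not in labels and len(parents[n]) == 0)
        for c in candidates:
            labels[c] = current_id
            current_id += 1
            for child in children.get(c, ()):
                parents[child].discard(c)
    return labels

# topological_labels({'a': {'b', 'c'}, 'b': {'d'}, 'c': {'d'}, 'd': set()})
# -> {'a': 0, 'b': 1, 'c': 2, 'd': 3}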
Example #26
0
def greedy_fuse(graph_or_subgraph: GraphViewType,
                validate_all: bool,
                device: dace.dtypes.DeviceType = dace.dtypes.DeviceType.CPU,
                recursive: bool = True,
                stencil: bool = False,
                stencil_tile=None,
                permutations_only: bool = True,
                expand_reductions: bool = False) -> None:
    '''
    Greedily fuses maps of an SDFG or graph, operating in-place.
    :param graph_or_subgraph: SDFG, SDFGState or Subgraph
    :param validate_all: Validate the SDFG or graph at each fusion step
    :param device: Device type to specialize for
    :param recursive: Fuse recursively within (fused and unfused) scopes
    :param stencil: Perform stencil fusion instead of regular fusion
    :param stencil_tile: Tile size for StencilTiling; uses the default if None
    :param permutations_only: Disallow splitting of maps during the MultiExpansion stage
    :param expand_reductions: Expand all reduce nodes before fusion
    '''
    debugprint = config.Config.get_bool('debugprint')
    if isinstance(graph_or_subgraph, SDFG):
        # If we have an SDFG, recurse into graphs
        graph_or_subgraph.simplify(validate_all=validate_all)
        # MapFusion for trivial cases
        graph_or_subgraph.apply_transformations_repeated(
            MapFusion, validate_all=validate_all)
        # recurse into graphs
        for graph in graph_or_subgraph.nodes():

            greedy_fuse(graph,
                        validate_all=validate_all,
                        device=device,
                        recursive=recursive,
                        stencil=stencil,
                        stencil_tile=stencil_tile,
                        permutations_only=permutations_only,
                        expand_reductions=expand_reductions)
    else:
        # we were handed an SDFGState or a SubgraphView
        sdfg, graph, subgraph = None, None, None
        if isinstance(graph_or_subgraph, SDFGState):
            sdfg = graph_or_subgraph.parent
            sdfg.apply_transformations_repeated(MapFusion,
                                                validate_all=validate_all)
            graph = graph_or_subgraph
            subgraph = SubgraphView(graph, graph.nodes())
        else:
            sdfg = graph_or_subgraph.graph.parent
            graph = graph_or_subgraph.graph
            subgraph = graph_or_subgraph

        # create condition function object
        fusion_condition = CompositeFusion(SubgraphView(graph, graph.nodes()))

        # within SDFGState: greedily enumerate fusible components
        # and apply transformation
        applied_transformations = 0
        reverse = True if stencil else False

        if stencil:
            # adjust tiling settings
            fusion_condition.allow_tiling = True
            fusion_condition.schedule_innermaps = dtypes.ScheduleType.Sequential
            if device == dtypes.DeviceType.GPU:
                fusion_condition.stencil_unroll_loops = True
            # tile size
            if stencil_tile:
                fusion_condition.stencil_strides = stencil_tile
            # with stencil tiles, only permute for now (no expansion splitting)
            fusion_condition.expansion_split = False

        else:
            fusion_condition.allow_tiling = False
            # expand reductions
            if expand_reductions:
                for graph in sdfg.nodes():
                    for node in graph.nodes():
                        if isinstance(node,
                                      dace.libraries.standard.nodes.Reduce):
                            try:
                                ReduceExpansion.apply_to(sdfg, reduce=node)
                            except ValueError:
                                # expansion not applicable here; skip this node
                                pass
            # permutation settings
            fusion_condition.expansion_split = not permutations_only

        condition_function = lambda sdfg, subgraph: fusion_condition.can_be_applied(
            sdfg, subgraph)
        enumerator = GreedyEnumerator(sdfg,
                                      graph,
                                      subgraph,
                                      condition_function=condition_function)
        for map_entries in enumerator:
            if len(map_entries) > 1:
                current_subgraph = xfsh.subgraph_from_maps(
                    sdfg, graph, map_entries)
                cf = CompositeFusion(current_subgraph)
                # transfer settings
                cf.allow_tiling = fusion_condition.allow_tiling
                cf.schedule_innermaps = fusion_condition.schedule_innermaps
                cf.expansion_split = fusion_condition.expansion_split
                cf.stencil_strides = fusion_condition.stencil_strides

                cf.apply(sdfg)
                applied_transformations += 1

            if recursive:
                global_entry = cf._global_map_entry if len(
                    map_entries) > 1 else map_entries[0]

                greedy_fuse(graph.scope_subgraph(global_entry,
                                                 include_entry=False,
                                                 include_exit=False),
                            validate_all=validate_all,
                            device=device,
                            recursive=recursive,
                            stencil=stencil,
                            stencil_tile=stencil_tile,
                            permutations_only=permutations_only,
                            expand_reductions=expand_reductions)

        for node in graph_or_subgraph.nodes():
            if isinstance(node, nodes.NestedSDFG):
                greedy_fuse(node.sdfg,
                            validate_all=validate_all,
                            device=device,
                            stencil=stencil,
                            stencil_tile=stencil_tile,
                            recursive=recursive,
                            permutations_only=permutations_only,
                            expand_reductions=expand_reductions)

        if applied_transformations > 0:
            if debugprint:
                if stencil:
                    print(f"Applied {applied_transformations} TileFusion")
                else:
                    print(f"Applied {applied_transformations} SubgraphFusion")

        if validate_all:
            graph.validate()
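
# Hedged usage sketch: running the helper above on a complete SDFG built from
# a small dace program. The import path (dace.transformation.auto.auto_optimize)
# is an assumption and may differ between DaCe versions.
import dace
import numpy as np

K = dace.symbol('K')

@dace.program
def two_maps(a: dace.float64[K], b: dace.float64[K]):
    tmp = a * 2.0
    b[:] = tmp + 1.0

def run_fusion_demo():
    from dace.transformation.auto.auto_optimize import greedy_fuse
    sdfg = two_maps.to_sdfg()
    greedy_fuse(sdfg, validate_all=True)
    a = np.random.rand(16)
    b = np.zeros(16)
    sdfg(a=a, b=b, K=16)
    assert np.allclose(b, a * 2.0 + 1.0)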
Example #27
0
    def test_simple_sdfg_map(self):
        sdfg, state, t, me, mx = create_sdfg()
        nest_state_subgraph(sdfg, state, SubgraphView(state, [me, t, mx]))
        sdfg.validate()
Example #28
0
def nest_state_subgraph(sdfg: SDFG,
                        state: SDFGState,
                        subgraph: SubgraphView,
                        name: Optional[str] = None,
                        full_data: bool = False) -> nodes.NestedSDFG:
    """ Turns a state subgraph into a nested SDFG. Operates in-place.
        :param sdfg: The SDFG containing the state subgraph.
        :param state: The state containing the subgraph.
        :param subgraph: Subgraph to nest.
        :param name: An optional name for the nested SDFG.
        :param full_data: If True, nests entire input/output data.
        :return: The nested SDFG node.
        :raise KeyError: Some or all nodes in the subgraph are not located in
                         this state, or the state does not belong to the given
                         SDFG.
        :raise ValueError: The subgraph is contained in more than one scope.
    """
    if state.parent != sdfg:
        raise KeyError('State does not belong to given SDFG')
    if subgraph is not state and subgraph.graph is not state:
        raise KeyError('Subgraph does not belong to given state')

    # Find the top-level scope
    scope_tree = state.scope_tree()
    scope_dict = state.scope_dict()
    scope_dict_children = state.scope_children()
    top_scopenode = -1  # Initialized to -1 since "None" already means top-level

    for node in subgraph.nodes():
        if node not in scope_dict:
            raise KeyError('Node not found in state')

        # If scope entry/exit, ensure entire scope is in subgraph
        if isinstance(node, nodes.EntryNode):
            scope_nodes = scope_dict_children[node]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (entry)')
        elif isinstance(node, nodes.ExitNode):
            entry = state.entry_node(node)
            scope_nodes = scope_dict_children[entry] + [entry]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (exit)')

        scope_node = scope_dict[node]
        if scope_node not in subgraph.nodes():
            if top_scopenode != -1 and top_scopenode != scope_node:
                raise ValueError('Subgraph is contained in more than one scope')
            top_scopenode = scope_node

    scope = scope_tree[top_scopenode]
    ###

    # Consolidate edges in top scope
    utils.consolidate_edges(sdfg, scope)
    snodes = subgraph.nodes()

    # Collect inputs and outputs of the nested SDFG
    inputs: List[MultiConnectorEdge] = []
    outputs: List[MultiConnectorEdge] = []
    for node in snodes:
        for edge in state.in_edges(node):
            if edge.src not in snodes:
                inputs.append(edge)
        for edge in state.out_edges(node):
            if edge.dst not in snodes:
                outputs.append(edge)

    # Collect transients not used outside of the subgraph (they will be removed
    # from the top-level graph)
    data_in_subgraph = set(n.data for n in subgraph.nodes() if isinstance(n, nodes.AccessNode))
    # Find other occurrences in SDFG
    other_nodes = set(n.data for s in sdfg.nodes() for n in s.nodes()
                      if isinstance(n, nodes.AccessNode) and n not in subgraph.nodes())
    subgraph_transients = set()
    for data in data_in_subgraph:
        datadesc = sdfg.arrays[data]
        if datadesc.transient and data not in other_nodes:
            subgraph_transients.add(data)

    # All transients of edges between code nodes are also added to nested graph
    for edge in subgraph.edges():
        if (isinstance(edge.src, nodes.CodeNode) and isinstance(edge.dst, nodes.CodeNode)):
            subgraph_transients.add(edge.data.data)

    # Collect data used in access nodes within subgraph (will be referenced in
    # full upon nesting)
    input_arrays = set()
    output_arrays = {}
    for node in subgraph.nodes():
        if (isinstance(node, nodes.AccessNode) and node.data not in subgraph_transients):
            if node.has_reads(state):
                input_arrays.add(node.data)
            if node.has_writes(state):
                output_arrays[node.data] = state.in_edges(node)[0].data.wcr

    # Create the nested SDFG
    nsdfg = SDFG(name or 'nested_' + state.label)

    # Transients are added to the nested graph as-is
    for name in subgraph_transients:
        nsdfg.add_datadesc(name, sdfg.arrays[name])

    # Input/output data that are not source/sink nodes are added to the graph
    # as non-transients
    for name in (input_arrays | output_arrays.keys()):
        datadesc = copy.deepcopy(sdfg.arrays[name])
        datadesc.transient = False
        nsdfg.add_datadesc(name, datadesc)

    # Connected source/sink nodes outside subgraph become global data
    # descriptors in nested SDFG
    input_names = {}
    output_names = {}
    global_subsets: Dict[str, Tuple[str, Subset]] = {}
    for edge in inputs:
        if edge.data.data is None:  # Skip edges with an empty memlet
            continue
        name = edge.data.data
        if name not in global_subsets:
            datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
            datadesc.transient = False
            if not full_data:
                datadesc.shape = edge.data.subset.size()
            new_name = nsdfg.add_datadesc(name, datadesc, find_new_name=True)
            global_subsets[name] = (new_name, edge.data.subset)
        else:
            new_name, subset = global_subsets[name]
            if not full_data:
                new_subset = union(subset, edge.data.subset)
                if new_subset is None:
                    new_subset = Range.from_array(sdfg.arrays[name])
                global_subsets[name] = (new_name, new_subset)
                nsdfg.arrays[new_name].shape = new_subset.size()
        input_names[edge] = new_name
    for edge in outputs:
        if edge.data.data is None:  # Skip edges with an empty memlet
            continue
        name = edge.data.data
        if name not in global_subsets:
            datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
            datadesc.transient = False
            if not full_data:
                datadesc.shape = edge.data.subset.size()
            new_name = nsdfg.add_datadesc(name, datadesc, find_new_name=True)
            global_subsets[name] = (new_name, edge.data.subset)
        else:
            new_name, subset = global_subsets[name]
            if not full_data:
                new_subset = union(subset, edge.data.subset)
                if new_subset is None:
                    new_subset = Range.from_array(sdfg.arrays[name])
                global_subsets[name] = (new_name, new_subset)
                nsdfg.arrays[new_name].shape = new_subset.size()
        output_names[edge] = new_name
    ###################

    # Add scope symbols to the nested SDFG
    defined_vars = set(
        symbolic.pystr_to_symbolic(s) for s in (state.symbols_defined_at(top_scopenode).keys()
                                                | sdfg.symbols))
    for v in defined_vars:
        if v in sdfg.symbols:
            sym = sdfg.symbols[v]
            nsdfg.add_symbol(v, sym.dtype)

    # Add constants to nested SDFG
    for cstname, cstval in sdfg.constants.items():
        nsdfg.add_constant(cstname, cstval)

    # Create nested state
    nstate = nsdfg.add_state()

    # Add subgraph nodes and edges to nested state
    nstate.add_nodes_from(subgraph.nodes())
    for e in subgraph.edges():
        nstate.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, copy.deepcopy(e.data))

    # Modify nested SDFG parents in subgraph
    for node in subgraph.nodes():
        if isinstance(node, nodes.NestedSDFG):
            node.sdfg.parent = nstate
            node.sdfg.parent_sdfg = nsdfg
            node.sdfg.parent_nsdfg_node = node

    # Add access nodes and edges as necessary
    edges_to_offset = []
    for edge, name in input_names.items():
        node = nstate.add_read(name)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = name
        edges_to_offset.append((edge, nstate.add_edge(node, None, edge.dst, edge.dst_conn, new_edge)))
    for edge, name in output_names.items():
        node = nstate.add_write(name)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = name
        edges_to_offset.append((edge, nstate.add_edge(edge.src, edge.src_conn, node, None, new_edge)))

    # Offset memlet paths inside nested SDFG according to subsets
    for original_edge, new_edge in edges_to_offset:
        for edge in nstate.memlet_tree(new_edge):
            edge.data.data = new_edge.data.data
            if not full_data:
                edge.data.subset.offset(global_subsets[original_edge.data.data][1], True)

    # Add nested SDFG node to the input state
    nested_sdfg = state.add_nested_sdfg(nsdfg, None,
                                        set(input_names.values()) | input_arrays,
                                        set(output_names.values()) | output_arrays.keys())

    # Reconnect memlets to nested SDFG
    reconnected_in = set()
    reconnected_out = set()
    empty_input = None
    empty_output = None
    for edge in inputs:
        if edge.data.data is None:
            empty_input = edge
            continue

        name = input_names[edge]
        if name in reconnected_in:
            continue
        if full_data:
            data = Memlet.from_array(edge.data.data, sdfg.arrays[edge.data.data])
        else:
            data = copy.deepcopy(edge.data)
            data.subset = global_subsets[edge.data.data][1]
        state.add_edge(edge.src, edge.src_conn, nested_sdfg, name, data)
        reconnected_in.add(name)

    for edge in outputs:
        if edge.data.data is None:
            empty_output = edge
            continue

        name = output_names[edge]
        if name in reconnected_out:
            continue
        if full_data:
            data = Memlet.from_array(edge.data.data, sdfg.arrays[edge.data.data])
        else:
            data = copy.deepcopy(edge.data)
            data.subset = global_subsets[edge.data.data][1]
        data.wcr = edge.data.wcr
        state.add_edge(nested_sdfg, name, edge.dst, edge.dst_conn, data)
        reconnected_out.add(name)

    # Connect access nodes to internal input/output data as necessary
    entry = scope.entry
    exit = scope.exit
    for name in input_arrays:
        node = state.add_read(name)
        if entry is not None:
            state.add_nedge(entry, node, Memlet())
        state.add_edge(node, None, nested_sdfg, name, Memlet.from_array(name, sdfg.arrays[name]))
    for name, wcr in output_arrays.items():
        node = state.add_write(name)
        if exit is not None:
            state.add_nedge(node, exit, Memlet())
        state.add_edge(nested_sdfg, name, node, None, Memlet(data=name, wcr=wcr))

    # Graph was not reconnected, but needs to be
    if state.in_degree(nested_sdfg) == 0 and empty_input is not None:
        state.add_edge(empty_input.src, empty_input.src_conn, nested_sdfg, None, empty_input.data)
    if state.out_degree(nested_sdfg) == 0 and empty_output is not None:
        state.add_edge(nested_sdfg, None, empty_output.dst, empty_output.dst_conn, empty_output.data)

    # Remove subgraph nodes from graph
    state.remove_nodes_from(subgraph.nodes())

    # Remove subgraph transients from top-level graph
    for transient in subgraph_transients:
        del sdfg.arrays[transient]

    # Remove newly isolated nodes due to memlet consolidation
    for edge in inputs:
        if state.in_degree(edge.src) + state.out_degree(edge.src) == 0:
            state.remove_node(edge.src)
    for edge in outputs:
        if state.in_degree(edge.dst) + state.out_degree(edge.dst) == 0:
            state.remove_node(edge.dst)

    return nested_sdfg
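
# A self-contained sketch of driving the helper above (hedged): build one state
# with a mapped tasklet and nest the whole map scope. Array names and sizes are
# illustrative; the import paths are assumptions about this DaCe version.
import dace
from dace.sdfg.graph import SubgraphView
from dace.transformation.helpers import nest_state_subgraph

def nest_map_scope_demo():
    sdfg = dace.SDFG('nest_demo')
    sdfg.add_array('A', [16], dace.float64)
    sdfg.add_array('B', [16], dace.float64)
    state = sdfg.add_state()
    tasklet, map_entry, map_exit = state.add_mapped_tasklet(
        'double',
        dict(i='0:16'),
        {'a': dace.Memlet.simple('A', 'i')},
        'b = a * 2',
        {'b': dace.Memlet.simple('B', 'i')},
        external_edges=True)
    # Nest the entire map scope (entry, tasklet, exit) into a nested SDFG
    nest_state_subgraph(sdfg, state,
                        SubgraphView(state, [map_entry, tasklet, map_exit]))
    sdfg.validate()
    return sdfg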
Example #29
0
    def apply(self, sdfg: SDFG):
        state: SDFGState = sdfg.nodes()[self.state_id]
        nsdfg_node = state.nodes()[self.subgraph[InlineSDFG._nested_sdfg]]
        nsdfg: SDFG = nsdfg_node.sdfg
        nstate: SDFGState = nsdfg.nodes()[0]

        if nsdfg_node.schedule is not dtypes.ScheduleType.Default:
            infer_types.set_default_schedule_and_storage_types(
                nsdfg, nsdfg_node.schedule)

        nsdfg_scope_entry = state.entry_node(nsdfg_node)
        nsdfg_scope_exit = (state.exit_node(nsdfg_scope_entry)
                            if nsdfg_scope_entry is not None else None)

        #######################################################
        # Collect and update top-level SDFG metadata

        # Global/init/exit code
        for loc, code in nsdfg.global_code.items():
            sdfg.append_global_code(code.code, loc)
        for loc, code in nsdfg.init_code.items():
            sdfg.append_init_code(code.code, loc)
        for loc, code in nsdfg.exit_code.items():
            sdfg.append_exit_code(code.code, loc)

        # Constants
        for cstname, cstval in nsdfg.constants.items():
            if cstname in sdfg.constants:
                if cstval != sdfg.constants[cstname]:
                    warnings.warn('Constant value mismatch for "%s" while '
                                  'inlining SDFG. Inner = %s != %s = outer' %
                                  (cstname, cstval, sdfg.constants[cstname]))
            else:
                sdfg.add_constant(cstname, cstval)

        # Find original source/destination edges (there is only one edge per
        # connector, according to match)
        inputs: Dict[str, MultiConnectorEdge] = {}
        outputs: Dict[str, MultiConnectorEdge] = {}
        input_set: Dict[str, str] = {}
        output_set: Dict[str, str] = {}
        for e in state.in_edges(nsdfg_node):
            inputs[e.dst_conn] = e
            input_set[e.data.data] = e.dst_conn
        for e in state.out_edges(nsdfg_node):
            outputs[e.src_conn] = e
            output_set[e.data.data] = e.src_conn

        # Access nodes that need to be reshaped
        reshapes: Set[str] = set()
        for aname, array in nsdfg.arrays.items():
            if array.transient:
                continue
            edge = None
            if aname in inputs:
                edge = inputs[aname]
                if len(array.shape) > len(edge.data.subset):
                    reshapes.add(aname)
                    continue
            if aname in outputs:
                edge = outputs[aname]
                if len(array.shape) > len(edge.data.subset):
                    reshapes.add(aname)
                    continue
            if edge is not None and not InlineSDFG._check_strides(
                    array.strides, sdfg.arrays[edge.data.data].strides,
                    edge.data, nsdfg_node):
                reshapes.add(aname)

        # Replace symbols using invocation symbol mapping
        # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace(symname, '__dacesym_' + symname)
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace('__dacesym_' + symname, symvalue)

        # All transients become transients of the parent (if data already
        # exists, find new name)
        # Mapping from nested transient name to top-level name
        transients: Dict[str, str] = {}
        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode):
                datadesc = nsdfg.arrays[node.data]
                if node.data not in transients and datadesc.transient:
                    name = sdfg.add_datadesc('%s_%s' %
                                             (nsdfg.label, node.data),
                                             datadesc,
                                             find_new_name=True)
                    transients[node.data] = name

        # All transients of edges between code nodes are also added to parent
        for edge in nstate.edges():
            if (isinstance(edge.src, nodes.CodeNode)
                    and isinstance(edge.dst, nodes.CodeNode)):
                if edge.data.data is not None:
                    datadesc = nsdfg.arrays[edge.data.data]
                    if edge.data.data not in transients and datadesc.transient:
                        name = sdfg.add_datadesc('%s_%s' %
                                                 (nsdfg.label, edge.data.data),
                                                 datadesc,
                                                 find_new_name=True)
                        transients[edge.data.data] = name

        # Collect nodes to add to top-level graph
        new_incoming_edges: Dict[nodes.Node, MultiConnectorEdge] = {}
        new_outgoing_edges: Dict[nodes.Node, MultiConnectorEdge] = {}

        source_accesses = set()
        sink_accesses = set()
        for node in nstate.source_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients
                    and node.data not in reshapes):
                new_incoming_edges[node] = inputs[node.data]
                source_accesses.add(node)
        for node in nstate.sink_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients
                    and node.data not in reshapes):
                new_outgoing_edges[node] = outputs[node.data]
                sink_accesses.add(node)

        #######################################################
        # Replace data on inlined SDFG nodes/edges

        # Replace data names with their top-level counterparts
        repldict = {}
        repldict.update(transients)
        repldict.update({
            k: v.data.data
            for k, v in itertools.chain(inputs.items(), outputs.items())
        })

        # Add views whenever reshapes are necessary
        for dname in reshapes:
            desc = nsdfg.arrays[dname]
            # To avoid potential confusion, rename protected __return keyword
            if dname.startswith('__return'):
                newname = f'{nsdfg.name}_ret{dname[8:]}'
            else:
                newname = dname
            newname, _ = sdfg.add_view(newname,
                                       desc.shape,
                                       desc.dtype,
                                       storage=desc.storage,
                                       strides=desc.strides,
                                       offset=desc.offset,
                                       debuginfo=desc.debuginfo,
                                       allow_conflicts=desc.allow_conflicts,
                                       total_size=desc.total_size,
                                       alignment=desc.alignment,
                                       may_alias=desc.may_alias,
                                       find_new_name=True)
            repldict[dname] = newname

        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode) and node.data in repldict:
                node.data = repldict[node.data]
        for edge in nstate.edges():
            if edge.data.data in repldict:
                edge.data.data = repldict[edge.data.data]

        # Add extra access nodes for out/in view nodes
        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode) and node.data in reshapes:
                if nstate.in_degree(node) > 0 and nstate.out_degree(node) > 0:
                    # Such a node has to be in the output set
                    edge = outputs[node.data]

                    # Redirect outgoing edges through access node
                    out_edges = list(nstate.out_edges(node))
                    anode = nstate.add_access(edge.data.data)
                    vnode = nstate.add_access(node.data)
                    nstate.add_nedge(node, anode, edge.data)
                    nstate.add_nedge(anode, vnode, edge.data)
                    for e in out_edges:
                        nstate.remove_edge(e)
                        nstate.add_edge(vnode, e.src_conn, e.dst, e.dst_conn,
                                        e.data)

        #######################################################
        # Add nested SDFG into top-level SDFG

        # Add nested nodes into original state
        subgraph = SubgraphView(nstate, [
            n for n in nstate.nodes()
            if n not in (source_accesses | sink_accesses)
        ])
        state.add_nodes_from(subgraph.nodes())
        for edge in subgraph.edges():
            state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

        #######################################################
        # Reconnect inlined SDFG

        # If a source/sink node is one of the inputs/outputs, reconnect it,
        # replacing memlets in outgoing/incoming paths
        modified_edges = set()
        modified_edges |= self._modify_memlet_path(new_incoming_edges, nstate,
                                                   state, True)
        modified_edges |= self._modify_memlet_path(new_outgoing_edges, nstate,
                                                   state, False)

        # Reshape: add connections to viewed data
        self._modify_reshape_data(reshapes, repldict, inputs, nstate, state,
                                  True)
        self._modify_reshape_data(reshapes, repldict, outputs, nstate, state,
                                  False)

        # Modify all other internal edges pertaining to input/output nodes
        for node in subgraph.nodes():
            if isinstance(node, nodes.AccessNode):
                if node.data in input_set or node.data in output_set:
                    if node.data in input_set:
                        outer_edge = inputs[input_set[node.data]]
                    else:
                        outer_edge = outputs[output_set[node.data]]

                    for edge in state.all_edges(node):
                        if (edge not in modified_edges
                                and edge.data.data == node.data):
                            for e in state.memlet_tree(edge):
                                if e.data.data == node.data:
                                    e._data = helpers.unsqueeze_memlet(
                                        e.data, outer_edge.data)

        # If source/sink node is not connected to a source/destination access
        # node, and the nested SDFG is in a scope, connect to scope with empty
        # memlets
        if nsdfg_scope_entry is not None:
            for node in subgraph.nodes():
                if state.in_degree(node) == 0:
                    state.add_edge(nsdfg_scope_entry, None, node, None,
                                   Memlet())
                if state.out_degree(node) == 0:
                    state.add_edge(node, None, nsdfg_scope_exit, None,
                                   Memlet())

        # Replace nested SDFG parents with new SDFG
        for node in nstate.nodes():
            if isinstance(node, nodes.NestedSDFG):
                node.sdfg.parent = state
                node.sdfg.parent_sdfg = sdfg
                node.sdfg.parent_nsdfg_node = node

        # Remove all unused external inputs/output memlet paths, as well as
        # resulting isolated nodes
        removed_in_edges = self._remove_edge_path(state,
                                                  inputs,
                                                  set(inputs.keys()) -
                                                  source_accesses,
                                                  reverse=True)
        removed_out_edges = self._remove_edge_path(state,
                                                   outputs,
                                                   set(outputs.keys()) -
                                                   sink_accesses,
                                                   reverse=False)

        # Re-add in/out edges to first/last nodes in subgraph
        order = [
            x for x in nx.topological_sort(nstate._nx)
            if isinstance(x, nodes.AccessNode)
        ]
        for edge in removed_in_edges:
            # Find first access node that refers to this edge
            node = next(n for n in order if n.data == edge.data.data)
            state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn,
                           edge.data)
        for edge in removed_out_edges:
            # Find last access node that refers to this edge
            node = next(n for n in reversed(order) if n.data == edge.data.data)
            state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

        #######################################################
        # Remove nested SDFG node
        state.remove_node(nsdfg_node)
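
# Hedged usage sketch for the pass above: inlining is usually requested through
# the generic transformation API rather than by calling apply() directly. The
# import path is an assumption about this DaCe version.
from dace.transformation.interstate import InlineSDFG

def inline_all_nested_sdfgs(sdfg):
    # Repeatedly apply InlineSDFG until no more nested SDFGs can be inlined;
    # returns the number of applications.
    return sdfg.apply_transformations_repeated(InlineSDFG, validate_all=True)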
Example #30
0
    def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool:
        '''
        Fusible if
        1. Maps have the same access sets and ranges in order
        2. Any nodes in between two maps are AccessNodes only, without WCR.
           There is at most one AccessNode on any path between two maps;
           no other nodes are allowed.
        3. The exiting memlets' subsets to an intermediate edge must cover
           the respective incoming memlets' subset into the next map.
           Also, as a limitation, the union of all exiting memlets'
           subsets must be contiguous.
        '''
        # get graph
        graph = subgraph.graph
        for node in subgraph.nodes():
            if node not in graph.nodes():
                return False

        # next, get all the maps
        map_entries = helpers.get_outermost_scope_maps(sdfg, graph, subgraph)
        map_exits = [graph.exit_node(map_entry) for map_entry in map_entries]
        maps = [map_entry.map for map_entry in map_entries]

        # 1. basic checks:
        # 1.1 we need to have at least two maps
        if len(maps) <= 1:
            return False
        '''
        # 1.2 Special Case: If we can establish a valid permutation, we can
        #     skip check 1.3
        permutation = self.find_permutation
        '''
        # 1.3 check whether all maps are the same
        base_map = maps[0]
        for map in maps:
            if map.get_param_num() != base_map.get_param_num():
                return False
            if not all(
                [p1 == p2 for (p1, p2) in zip(map.params, base_map.params)]):
                return False
            if not map.range == base_map.range:
                return False
        # 1.4 check whether all map entries have the same schedule
        schedule = map_entries[0].schedule
        if not all([entry.schedule == schedule for entry in map_entries]):
            return False

        # 2. check intermediate feasibility
        # see map_fusion.py for similar checks
        # with the restrictions below being more relaxed

        # 2.1 do some preparation work first:
        # calculate all out_nodes and intermediate_nodes
        # definition see in apply()
        node_config = SubgraphFusion.get_adjacent_nodes(sdfg, graph,
                                                        map_entries)
        _, intermediate_nodes, out_nodes = node_config

        # 2.2 topological feasibility:
        if not SubgraphFusion.check_topo_feasibility(
                sdfg, graph, map_entries, intermediate_nodes, out_nodes):
            return False

        # 2.3 memlet feasibility
        # For each intermediate node, look at whether inner adjacent
        # memlets of the exiting map cover inner adjacent memlets
        # of the next entering map.
        # We also check for any WCRs on the fly.

        for node in intermediate_nodes:
            upper_subsets = set()
            lower_subsets = set()
            # First, determine which dimensions of the memlet ranges change
            # with the map; we do not need to care about the other dimensions.
            try:
                dims_to_discard = SubgraphFusion.get_invariant_dimensions(
                    sdfg, graph, map_entries, map_exits, node)
            except NotImplementedError:
                return False
            # find upper_subsets
            for in_edge in graph.in_edges(node):
                in_in_edge = graph.memlet_path(in_edge)[-2]
                # first check for WCRs
                if in_edge.data.wcr:
                    # check whether the WCR is actually produced at
                    # this edge or further up in the memlet path. If not,
                    # we can still fuse!
                    subset_params = set(
                        [str(s) for s in in_in_edge.data.subset.free_symbols])
                    if any([
                            p not in subset_params
                            for p in in_edge.src.map.params
                    ]):
                        return False
                if in_edge.src in map_exits:
                    subset_to_add = dcpy(in_in_edge.data.subset\
                                         if in_in_edge.data.data == node.data\
                                         else in_in_edge.data.other_subset)
                    subset_to_add.pop(dims_to_discard)
                    upper_subsets.add(subset_to_add)
                else:
                    raise NotImplementedError("Nodes between two maps to be "
                                              "fused with *incoming* edges "
                                              "from outside the maps are not "
                                              "allowed yet.")

            # find lower_subsets
            for out_edge in graph.out_edges(node):
                if out_edge.dst in map_entries:
                    # cannot use the memlet tree here, as more than one map
                    # may succeed this node; do it manually
                    for oedge in graph.out_edges(out_edge.dst):
                        # match connectors: 'OUT_x'[3:] == 'IN_x'[2:] == '_x'
                        if oedge.src_conn[3:] == out_edge.dst_conn[2:]:
                            subset_to_add = dcpy(oedge.data.subset \
                                                 if oedge.data.data == node.data \
                                                 else oedge.data.other_subset)
                            subset_to_add.pop(dims_to_discard)
                            lower_subsets.add(subset_to_add)

            # We assume that upper_subsets are contiguous
            # Check for this.
            try:
                contiguous_upper = find_contiguous_subsets(upper_subsets)
                if len(contiguous_upper) > 1:
                    return False
            except TypeError:
                warnings.warn(
                    'Could not determine whether the subsets are contiguous. '
                    'Exiting check with False.')
                return False

            # now take union of upper subsets
            upper_iter = iter(upper_subsets)
            union_upper = next(upper_iter)
            for subs in upper_iter:
                union_upper = subsets.union(union_upper, subs)
                if not union_upper:
                    # something went wrong using union -- we'd rather abort
                    return False

            # finally check coverage
            # every lower subset must be completely covered by union_upper
            for lower_subset in lower_subsets:
                if not union_upper.covers(lower_subset):
                    return False

        return True
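
# Standalone sketch of the coverage test in step 2.3 above (hedged, names
# illustrative): the exit-memlet subsets are unioned, and the union must cover
# every subset entering the next map.
from dace import subsets

upper_subsets = [subsets.Range.from_string('0:8'),
                 subsets.Range.from_string('8:16')]
lower_subset = subsets.Range.from_string('2:12')

union_upper = upper_subsets[0]
for subs in upper_subsets[1:]:
    union_upper = subsets.union(union_upper, subs)

# The bounding-box union 0:16 covers 2:12, so this intermediate node would not
# prevent fusion.
assert union_upper.covers(lower_subset)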