예제 #1
0
def expand_reduce(sdfg: dace.SDFG,
                  graph: dace.SDFGState,
                  subgraph: Union[SubgraphView, List[SubgraphView]] = None,
                  **kwargs):
    """Expand every applicable Reduce library node in one or more subgraphs.

    Each ``stdlib.Reduce`` node found in a subgraph is checked with
    ``ReduceExpansion.can_be_applied``. Nodes that fail the check are skipped
    with a printed warning. The remaining nodes are expanded with
    ``ReduceExpansion.expand``.

    :param sdfg: SDFG that owns ``graph``.
    :param graph: State containing the reduce nodes.
    :param subgraph: A SubgraphView, a list of SubgraphViews, or None/empty
                     to operate on the whole ``graph``.
    :param kwargs: Attribute overrides assigned (via ``setattr``) to the
                   ``ReduceExpansion`` instance before expanding, e.g.
                   ``create_in_transient=True``.
    """
    # Fall back to the whole state when no subgraph is given.
    subgraph = graph if not subgraph else subgraph
    subgraphs = subgraph if isinstance(subgraph, list) else [subgraph]

    for sg in subgraphs:
        # Collect first: expansion rewrites the state we would be iterating.
        reduce_nodes = []
        for node in sg.nodes():
            if not isinstance(node, stdlib.Reduce):
                continue
            applicable = ReduceExpansion.can_be_applied(
                graph=graph,
                candidate={ReduceExpansion._reduce: graph.node_id(node)},
                expr_index=0,
                sdfg=sdfg)
            if not applicable:
                print(f"WARNING: Cannot expand reduce node {node}: "
                      "can_be_applied() failed.")
                continue
            reduce_nodes.append(node)

        trafo_reduce = ReduceExpansion(0, 0, {}, 0)
        # Forward user-supplied settings as attributes of the transformation.
        for prop_name, value in kwargs.items():
            setattr(trafo_reduce, prop_name, value)

        for reduce_node in reduce_nodes:
            trafo_reduce.expand(sdfg, graph, reduce_node)
            if isinstance(sg, SubgraphView):
                # Replace the expanded node in the view's node list with the
                # nodes created by the expansion. This presumes nodes()
                # returns the view's own mutable list -- verify against dace.
                sg.nodes().remove(reduce_node)
                sg.nodes().append(trafo_reduce._new_reduce)
                sg.nodes().append(trafo_reduce._outer_entry)
예제 #2
0
def test_blockallreduce():
    """Check the 'CUDA (block allreduce)' reduce expansion against 'CUDA (device)'.

    Compiles ``test_program`` with the Reduce node in the first state set to
    the 'CUDA (device)' implementation and runs it. Then the node is expanded
    via ReduceExpansion with ``reduce_implementation`` set to
    'CUDA (block allreduce)'. The recompiled program must give a result that
    is ``np.allclose`` to the first run.
    """
    A = np.random.rand(M.get(), N.get()).astype(np.float32)
    sdfg = test_program.to_sdfg()
    sdfg.apply_gpu_transformations()

    graph = sdfg.nodes()[0]
    reduce_nodes = [node for node in graph.nodes() if isinstance(node, Reduce)]
    # Fail with a clear message instead of an UnboundLocalError further down.
    assert reduce_nodes, "expected a Reduce node in the first state"
    # The last Reduce node in node order is the one under test.
    reduce_node = reduce_nodes[-1]
    reduce_node.implementation = 'CUDA (device)'

    csdfg = sdfg.compile()
    result1 = csdfg(A=A, M=M, N=N)

    sdfg_id = 0
    state_id = 0
    subgraph = {ReduceExpansion._reduce: graph.nodes().index(reduce_node)}
    # Expand the reduce node, requesting the block-allreduce implementation.
    transform = ReduceExpansion(sdfg_id, state_id, subgraph, 0)
    transform.reduce_implementation = 'CUDA (block allreduce)'
    transform.apply(sdfg)
    csdfg = sdfg.compile()
    result2 = csdfg(A=A, M=M, N=N)

    print(np.linalg.norm(result1))
    print(np.linalg.norm(result2))
    assert np.allclose(result1, result2)

    print("PASS")
예제 #3
0
def test_p1(in_transient, out_transient):
    """Check that expanding the Reduce node of ``reduction_test_1`` preserves C.

    The program is compiled and run once as-is. Its reduce nodes are then
    expanded with ``expand_reduce`` and it is compiled and run again. Both
    outputs must agree, and the first must be non-trivial.

    :param in_transient: Forwarded as ``create_in_transient`` to expand_reduce.
    :param out_transient: Forwarded as ``create_out_transient`` to expand_reduce.
    """
    sdfg = reduction_test_1.to_sdfg()
    sdfg.simplify()
    state = sdfg.nodes()[0]
    reduce_nodes = [
        node for node in state.nodes()
        if isinstance(node, dace.libraries.standard.nodes.Reduce)
    ]
    # Fail with a clear message instead of an UnboundLocalError further down.
    assert reduce_nodes, "expected a Reduce node in the first state"
    reduce_node = reduce_nodes[-1]

    rexp = ReduceExpansion(
        sdfg, sdfg.sdfg_id, 0,
        {ReduceExpansion.reduce: state.node_id(reduce_node)}, 0)
    assert rexp.can_be_applied(state, 0, sdfg)

    A = np.random.rand(M.get(), N.get()).astype(np.float64)
    B = np.random.rand(M.get(), N.get()).astype(np.float64)
    C1 = np.zeros([N.get()], dtype=np.float64)
    C2 = np.zeros([N.get()], dtype=np.float64)

    csdfg = sdfg.compile()
    csdfg(A=A, B=B, C=C1, N=N, M=M)
    del csdfg

    expand_reduce(sdfg,
                  state,
                  create_in_transient=in_transient,
                  create_out_transient=out_transient)
    csdfg = sdfg.compile()
    csdfg(A=A, B=B, C=C2, N=N, M=M)
    del csdfg

    # Guard against a vacuous comparison of two all-zero outputs.
    assert np.linalg.norm(C1) > 0.01
    assert np.allclose(C1, C2)
예제 #4
0
def test_p1():
    """Check that expanding the Reduce node of ``program`` preserves output C.

    Asserts that ReduceExpansion is applicable to the Reduce node in the first
    state. The program is run before and after ``expand_reduce``, and the two
    outputs must be ``np.allclose``.
    """
    sdfg = program.to_sdfg()
    sdfg.apply_strict_transformations()
    state = sdfg.nodes()[0]
    reduce_nodes = [
        node for node in state.nodes()
        if isinstance(node, dace.libraries.standard.nodes.Reduce)
    ]
    # Fail with a clear message instead of an UnboundLocalError further down.
    assert reduce_nodes, "expected a Reduce node in the first state"
    reduce_node = reduce_nodes[-1]

    candidate = {ReduceExpansion._reduce: state.nodes().index(reduce_node)}
    assert ReduceExpansion.can_be_applied(state, candidate, 0, sdfg)

    A = np.random.rand(M.get(), N.get()).astype(np.float64)
    B = np.random.rand(M.get(), N.get()).astype(np.float64)
    C1 = np.zeros([N.get()], dtype=np.float64)
    C2 = np.zeros([N.get()], dtype=np.float64)

    csdfg = sdfg.compile()
    csdfg(A=A, B=B, C=C1, N=N, M=M)
    del csdfg

    expand_reduce(sdfg, state)
    csdfg = sdfg.compile()
    csdfg(A=A, B=B, C=C2, N=N, M=M)
    del csdfg

    assert np.allclose(C1, C2)
    print(np.linalg.norm(C1))
    print(np.linalg.norm(C2))
    print("PASS")
예제 #5
0
def _expand_reductions_in_sdfg(sdfg, debugprint: bool = False) -> None:
    '''
    Best-effort expansion of every Reduce library node in all states of sdfg.
    Nodes for which ReduceExpansion.apply_to raises ValueError are left as-is.
    :param sdfg: SDFG whose states are scanned
    :param debugprint: Report skipped reduce nodes on stdout
    '''
    for state in sdfg.nodes():
        # Snapshot the reduce nodes first: each expansion rewrites the state.
        reduce_nodes = [
            node for node in state.nodes()
            if isinstance(node, dace.libraries.standard.nodes.Reduce)
        ]
        for node in reduce_nodes:
            try:
                ReduceExpansion.apply_to(sdfg, reduce=node)
            except ValueError:
                # Deliberate best-effort: presumably apply_to raises
                # ValueError when the pattern does not match -- skip the node
                # rather than abort fusion.
                if debugprint:
                    print(f"Skipping expansion of reduce node {node}")


def greedy_fuse(graph_or_subgraph: GraphViewType,
                validate_all: bool,
                device: dace.dtypes.DeviceType = dace.dtypes.DeviceType.CPU,
                recursive: bool = True,
                stencil: bool = False,
                stencil_tile=None,
                permutations_only: bool = True,
                expand_reductions: bool = False) -> None:
    '''
    Greedily fuses maps of an SDFG or graph, operating in-place.
    :param graph_or_subgraph: SDFG, SDFGState or Subgraph
    :param validate_all: Validate SDFG or graph at each fusion step
    :param device: Device type to specialize for
    :param recursive: Fuse recursively within (fused and unfused) scopes
    :param stencil: Perform stencil fusion instead of regular fusion
    :param stencil_tile: StencilTiling Tile size, default if None
    :param permutations_only: Disallow splitting of maps during MultiExpansion stage
    :param expand_reductions: Expand all reduce nodes of the owning SDFG before
                              fusion (only when ``stencil`` is False)
    '''
    debugprint = config.Config.get_bool('debugprint')
    # Settings forwarded unchanged to every recursive call.
    fuse_kwargs = dict(validate_all=validate_all,
                       device=device,
                       recursive=recursive,
                       stencil=stencil,
                       stencil_tile=stencil_tile,
                       permutations_only=permutations_only,
                       expand_reductions=expand_reductions)

    if isinstance(graph_or_subgraph, SDFG):
        graph_or_subgraph.simplify(validate_all=validate_all)
        # MapFusion for trivial cases
        graph_or_subgraph.apply_transformations_repeated(
            MapFusion, validate_all=validate_all)
        # Recurse into every state of the SDFG.
        for state in graph_or_subgraph.nodes():
            greedy_fuse(state, **fuse_kwargs)
        return

    # We are in a state or in a subgraph view of a state.
    is_state = isinstance(graph_or_subgraph, SDFGState)
    if is_state:
        sdfg = graph_or_subgraph.parent
        sdfg.apply_transformations_repeated(MapFusion,
                                            validate_all=validate_all)
        graph = graph_or_subgraph
    else:
        sdfg = graph_or_subgraph.graph.parent
        graph = graph_or_subgraph.graph

    # Expand reductions before taking any view of the state, so the views
    # built below contain the expanded nodes rather than removed ones. A
    # helper is used so that `graph` is not rebound while iterating the
    # SDFG's states.
    if expand_reductions and not stencil:
        _expand_reductions_in_sdfg(sdfg, debugprint)

    subgraph = SubgraphView(graph, graph.nodes()) if is_state else graph_or_subgraph

    # create condition function object
    fusion_condition = CompositeFusion(SubgraphView(graph, graph.nodes()))
    applied_transformations = 0

    if stencil:
        # adjust tiling settings
        fusion_condition.allow_tiling = True
        fusion_condition.schedule_innermaps = dtypes.ScheduleType.Sequential
        if device == dtypes.DeviceType.GPU:
            fusion_condition.stencil_unroll_loops = True
        # tile size
        if stencil_tile:
            fusion_condition.stencil_strides = stencil_tile
        # always only permutate for now with stencil tiles
        fusion_condition.expansion_split = False
    else:
        fusion_condition.allow_tiling = False
        # permutation settings
        fusion_condition.expansion_split = not permutations_only

    def condition_function(sdfg, subgraph):
        return fusion_condition.can_be_applied(sdfg, subgraph)

    # Within the state: greedily enumerate fusible components and fuse them.
    enumerator = GreedyEnumerator(sdfg,
                                  graph,
                                  subgraph,
                                  condition_function=condition_function)
    for map_entries in enumerator:
        if len(map_entries) > 1:
            current_subgraph = xfsh.subgraph_from_maps(sdfg, graph, map_entries)
            cf = CompositeFusion(current_subgraph)
            # transfer settings
            cf.allow_tiling = fusion_condition.allow_tiling
            cf.schedule_innermaps = fusion_condition.schedule_innermaps
            cf.expansion_split = fusion_condition.expansion_split
            cf.stencil_strides = fusion_condition.stencil_strides

            cf.apply(sdfg)
            applied_transformations += 1

        if recursive:
            global_entry = cf._global_map_entry if len(
                map_entries) > 1 else map_entries[0]
            greedy_fuse(graph.scope_subgraph(global_entry,
                                             include_entry=False,
                                             include_exit=False),
                        **fuse_kwargs)

    for node in graph_or_subgraph.nodes():
        if isinstance(node, nodes.NestedSDFG):
            greedy_fuse(node.sdfg, **fuse_kwargs)

    if applied_transformations > 0 and debugprint:
        fusion_kind = "TileFusion" if stencil else "SubgraphFusion"
        print(f"Applied {applied_transformations} {fusion_kind}")

    if validate_all:
        graph.validate()