Ejemplo n.º 1
0
def remove_unused_sinks(top_sdfg: dace.SDFG):
    """ Drop transient access nodes that are written but never used,
        together with the code node that produces them.

        Every state of ``top_sdfg`` and of all SDFGs reachable through
        ``all_sdfgs_recursive()`` is scanned. A node is a candidate if it is
        a sink of the state, or if its single outgoing edge carries a memlet
        whose ``data`` is ``None``. A candidate is removed only if all of the
        following hold:

        * it is an access node of a transient data descriptor,
        * it is the only access node with that data name in any state of
          the same SDFG,
        * it has exactly one incoming edge, whose source is a code node
          with exactly one outgoing edge.

        The producing code node is removed as well, and so is every source
        feeding it that has no other edge in the state.

        :param top_sdfg: The SDFG to clean up; it is modified in place.
    """
    for sdfg in top_sdfg.all_sdfgs_recursive():
        for state in sdfg.nodes():
            doomed = set()

            # Nodes that are not graph sinks, but whose one outgoing edge
            # carries no data.
            dataless_out = []
            for cand in state.nodes():
                if state.out_degree(cand) != 1:
                    continue
                if state.out_edges(cand)[0].data.data is None:
                    dataless_out.append(cand)

            for sink in state.sink_nodes() + dataless_out:
                if not isinstance(sink, dace.nodes.AccessNode):
                    continue
                if not sdfg.arrays[sink.data].transient:
                    continue

                # The data must not be accessed anywhere else in this SDFG
                occurrences = 0
                for other_state in sdfg.nodes():
                    for other in other_state.nodes():
                        if (isinstance(other, dace.nodes.AccessNode)
                                and other.data == sink.data):
                            occurrences += 1
                if occurrences != 1:
                    continue

                if state.in_degree(sink) != 1:
                    continue

                # The producer may only go away if this sink is the sole
                # consumer of its output and it is a code node
                producer = state.in_edges(sink)[0].src
                if state.out_degree(producer) != 1:
                    continue
                if not isinstance(producer, dace.nodes.CodeNode):
                    continue

                # Inputs of the producer that would be left without any
                # edge are removed too
                for in_edge in state.in_edges(producer):
                    if len(state.all_edges(in_edge.src)) == 1:
                        doomed.add(in_edge.src)
                doomed.add(producer)
                doomed.add(sink)

            state.remove_nodes_from(doomed)
Ejemplo n.º 2
0
def remove_scalar_transients(top_sdfg: dace.SDFG):
    """ Clean up tasklet->scalar-transient, replacing them with constants.

        For every transient ``dace.data.Scalar`` descriptor in ``top_sdfg``
        and in all SDFGs reachable through ``all_sdfgs_recursive()``, the
        access nodes of that scalar are inspected. The scalar is converted
        only if exactly one of its access nodes has an incoming edge, that
        access node has exactly one incoming edge, the code of the node at
        the source of that edge is a single ``ast.Assign`` statement, and
        the unparsed right-hand side can be read by ``float()``. Any other
        situation is reported through ``dprint`` and the scalar is left
        untouched.

        For each converted scalar the initializing node is removed from its
        state, and the collected ``{name: float value}`` mapping is handed
        to ``_remove_transients`` once per SDFG.

        :param top_sdfg: The SDFG to clean up; it is modified in place.
    """
    dprint = print  # lambda *args: pass
    removed_transients = 0
    for sdfg in top_sdfg.all_sdfgs_recursive():
        transients_to_remove = {}
        for dname, desc in sdfg.arrays.items():
            skip = False
            if isinstance(desc, dace.data.Scalar) and desc.transient:
                # Find node where transient is instantiated
                init_tasklet: Optional[dace.nodes.Tasklet] = None
                itstate = None
                for state in sdfg.nodes():
                    # A failure in a previous state ends the whole search
                    if skip:
                        break
                    for node in state.nodes():
                        if (isinstance(node, dace.nodes.AccessNode)
                                and node.data == dname):
                            if state.in_degree(node) > 1:
                                dprint('Cannot remove scalar', dname,
                                       '(more than one input)')
                                skip = True
                                break
                            elif state.in_degree(node) == 1:
                                if init_tasklet is not None:
                                    dprint('Cannot remove scalar', dname,
                                           '(initialized multiple times)')
                                    skip = True
                                    break
                                init_tasklet = state.in_edges(node)[0].src
                                itstate = state
                if init_tasklet is None:
                    dprint('Cannot remove scalar', dname, '(uninitialized)')
                    skip = True
                if skip:
                    continue

                # We can remove transient, find value from tasklet
                if len(init_tasklet.code.code) != 1:
                    dprint('Cannot remove scalar', dname, '(complex tasklet)')
                    continue
                if not isinstance(init_tasklet.code.code[0], ast.Assign):
                    dprint('Cannot remove scalar', dname, '(complex tasklet2)')
                    continue
                # float() only accepts a numeric literal string. Any other
                # right-hand side (e.g., an arithmetic expression) is
                # reported and skipped like the other unsupported cases,
                # instead of aborting the whole pass with a ValueError.
                try:
                    val = float(unparse(init_tasklet.code.code[0].value))
                except ValueError:
                    dprint('Cannot remove scalar', dname,
                           '(non-literal value)')
                    continue

                dprint('Converting', dname, 'to constant with value', val)
                transients_to_remove[dname] = val
                # Remove initialization tasklet
                itstate.remove_node(init_tasklet)

        _remove_transients(sdfg, transients_to_remove)
        removed_transients += len(transients_to_remove)
    print('Cleaned up %d extra scalar transients' % removed_transients)
Ejemplo n.º 3
0
def remove_constant_stencils(top_sdfg: dace.SDFG):
    """ Replace input-less stencils that write a constant by that constant.

        In ``top_sdfg`` and all SDFGs reachable through
        ``all_sdfgs_recursive()``, every ``stencil.Stencil`` node with no
        incoming edges and exactly one outgoing edge is considered. It is
        converted if its code is a single ``ast.Assign`` statement and no
        access node of the written data, other than the one at the end of
        the stencil's output memlet path, has an incoming edge in any state
        of the same SDFG.

        The right-hand side of the assignment is evaluated to a float, the
        stencil node is removed from its state, and the collected
        ``{name: float value}`` mapping is handed to ``_remove_transients``
        (with ``ReplaceSubscript``) once per SDFG.

        :param top_sdfg: The SDFG to clean up; it is modified in place.
    """
    dprint = print  # lambda *args: pass
    removed_transients = 0
    for sdfg in top_sdfg.all_sdfgs_recursive():
        transients_to_remove = {}
        for state in sdfg.nodes():
            for node in state.nodes():
                is_source_stencil = (isinstance(node, stencil.Stencil)
                                     and state.in_degree(node) == 0
                                     and state.out_degree(node) == 1)
                if not is_source_stencil:
                    continue

                # Only a single plain assignment can be turned into a value
                if len(node.code.code) != 1:
                    dprint('Cannot remove scalar stencil', node.name,
                           '(complex code)')
                    continue
                if not isinstance(node.code.code[0], ast.Assign):
                    dprint('Cannot remove scalar stencil', node.name,
                           '(complex code2)')
                    continue

                # Ensure no one else is writing to it
                out_edge = state.out_edges(node)[0]
                onode = state.memlet_path(out_edge)[-1].dst
                dname = out_edge.data.data
                written_elsewhere = any(
                    other != onode
                    and isinstance(other, dace.nodes.AccessNode)
                    and other.data == dname
                    and other_state.in_degree(other) > 0
                    for other_state in sdfg.nodes()
                    for other in other_state.nodes())
                if written_elsewhere:
                    continue

                # NOTE(review): eval() executes the stencil's right-hand
                # side as Python code; only safe if the SDFG is trusted.
                rhs_source = unparse(node.code.code[0].value)
                val = float(eval(rhs_source))

                dprint('Converting scalar stencil result', dname,
                       'to constant with value', val)
                transients_to_remove[dname] = val
                # Remove the stencil node that produced the constant
                state.remove_node(node)

        _remove_transients(sdfg, transients_to_remove, ReplaceSubscript)
        removed_transients += len(transients_to_remove)
    print('Cleaned up %d extra scalar stencils' % removed_transients)
Ejemplo n.º 4
0
    def transfer(sdfg: dace.SDFG, tuner, k: int = 5):
        """
        Apply the best on-the-fly map fusion patterns found by ``tuner`` to
        every state of ``sdfg`` where they yield a measured speedup.

        The tuner is run with ``apply=False``, its top ``k`` configurations
        are turned into subgraph patterns, and for each state (of ``sdfg``
        and all nested SDFGs) with at least two top-level maps the
        following is repeated until no pattern improves the runtime:

        1. For each pattern whose map descriptors are contained in the
           state, a deep copy of the state's cutout is fused and measured
           with ``optim_utils.subprocess_measure``.
        2. The fastest pattern that beats the unfused baseline is applied
           to the actual state.

        :param sdfg: The SDFG to transform in place. It must provide
                     instrumented data (``get_instrumented_data()`` must
                     not return ``None``).
        :param tuner: An ``OnTheFlyMapFusionTuner`` used to obtain the
                      tuning report and to extract the patterns.
        :param k: Number of top configurations taken from the tuning
                  report.
        """
        assert isinstance(tuner, OnTheFlyMapFusionTuner)

        dreport = sdfg.get_instrumented_data()
        assert dreport is not None

        tuning_report = tuner.optimize(apply=False)
        best_configs = cutout_tuner.CutoutTuner.top_k_configs(tuning_report,
                                                              k=k)
        subgraph_patterns = tuner._extract_patterns(best_configs)

        def _top_level_maps(graph):
            # Map entries of ``graph`` that are not nested in another map.
            return [
                n for n in graph.nodes()
                if isinstance(n, dace.nodes.MapEntry)
                and xfh.get_parent_map(graph, n) is None
            ]

        def _collect_inputs(measured_sdfg):
            # Recorded data for every non-transient container accessed in
            # ``measured_sdfg``; containers the report cannot provide are
            # left out (best effort).
            inputs = {}
            for cstate in measured_sdfg.nodes():
                for dnode in cstate.data_nodes():
                    if measured_sdfg.arrays[dnode.data].transient:
                        continue
                    try:
                        inputs[dnode.data] = dreport.get_first_version(
                            dnode.data)
                    except Exception:
                        continue
            return inputs

        for nsdfg in sdfg.all_sdfgs_recursive():
            for state in nsdfg.states():
                if len(_top_level_maps(state)) < 2:
                    continue

                # States for which the cutout cannot be constructed are
                # skipped; the result itself is not needed here.
                try:
                    cutter.cutout_state(state,
                                        *(state.nodes()),
                                        make_copy=False)
                except AttributeError:
                    continue

                while True:
                    base_runtime = None
                    best_pattern = None
                    best_pattern_runtime = math.inf
                    for pattern in subgraph_patterns:
                        maps = _top_level_maps(state)
                        if len(maps) < 2:
                            continue

                        # Group the maps by descriptor and count how often
                        # each descriptor occurs in the state
                        maps_desc = {}
                        state_desc = Counter()
                        for map_entry in maps:
                            map_desc = OnTheFlyMapFusionTuner.map_descriptor(
                                state, map_entry)
                            state_desc.update({map_desc: 1})
                            maps_desc.setdefault(map_desc,
                                                 []).append(map_entry)

                        # The pattern must be fully contained in the state
                        if any(key not in state_desc
                               or pattern[key] > state_desc[key]
                               for key in pattern):
                            continue

                        if base_runtime is None:
                            # Measure the unfused state once per round
                            baseline = cutter.cutout_state(state,
                                                           *(state.nodes()),
                                                           make_copy=False)
                            baseline.start_state.instrument = dace.InstrumentationType.GPU_Events

                            base_runtime = optim_utils.subprocess_measure(
                                baseline,
                                _collect_inputs(baseline),
                                i=192,
                                j=192)
                            best_pattern_runtime = base_runtime
                            if base_runtime == math.inf:
                                break

                        # Construct subgraph greedily
                        subgraph_maps = []
                        for desc in pattern:
                            num = pattern[desc]
                            subgraph_maps.extend(maps_desc[desc][:num])

                        # Apply on a deep copy; the maps are located in the
                        # copy through their node ids
                        experiment_sdfg_ = cutter.cutout_state(
                            state, *(state.nodes()), make_copy=False)
                        experiment_state_ = experiment_sdfg_.start_state
                        experiment_maps_ids = [
                            experiment_state_.node_id(me)
                            for me in subgraph_maps
                        ]
                        experiment_sdfg = copy.deepcopy(experiment_sdfg_)
                        experiment_state = experiment_sdfg.start_state
                        experiment_state.instrument = dace.InstrumentationType.GPU_Events

                        experiment_maps = [
                            experiment_state.node(m_id)
                            for m_id in experiment_maps_ids
                        ]
                        experiment_subgraph = helpers.subgraph_from_maps(
                            sdfg=experiment_sdfg,
                            graph=experiment_state,
                            map_entries=experiment_maps)

                        map_fusion = sg.SubgraphOTFFusion()
                        map_fusion.setup_match(
                            experiment_subgraph, experiment_sdfg.sdfg_id,
                            experiment_sdfg.node_id(experiment_state))
                        if not map_fusion.can_be_applied(
                                experiment_state, experiment_sdfg):
                            continue

                        # A pattern whose fusion fails on the copy is
                        # simply not a candidate
                        try:
                            experiment_fuse_counter = map_fusion.apply(
                                experiment_state, experiment_sdfg)
                        except Exception:
                            continue

                        if experiment_fuse_counter == 0:
                            continue

                        fused_runtime = optim_utils.subprocess_measure(
                            experiment_sdfg,
                            _collect_inputs(experiment_sdfg),
                            i=192,
                            j=192)
                        if fused_runtime >= best_pattern_runtime:
                            continue

                        best_pattern = subgraph_maps
                        best_pattern_runtime = fused_runtime

                    if best_pattern is None:
                        # No pattern improved on the baseline
                        break

                    # Apply the winning pattern to the real state and
                    # search again on the transformed state
                    subgraph = helpers.subgraph_from_maps(
                        sdfg=nsdfg, graph=state, map_entries=best_pattern)
                    map_fusion = sg.SubgraphOTFFusion()
                    map_fusion.setup_match(subgraph, nsdfg.sdfg_id,
                                           nsdfg.node_id(state))
                    map_fusion.apply(state, nsdfg)