Example #1
0
def test_cutout_scope_fail():
    """
    Check that cutting out a tasklet nested in a tiled map is rejected.

    The tasklet lives in an inner map and reads the transient ``local``, which
    is filled inside the enclosing tiled map. Implicit cutout expansion is
    expected to fail on this scope mismatch by raising ``ValueError``.
    """
    graph = dace.SDFG('complex')
    graph.add_array('A', [20], dace.float64)
    graph.add_array('B', [20], dace.float64)
    graph.add_transient('local', [2], dace.float64)

    st = graph.add_state()
    src = st.add_read('A')
    scratch = st.add_access('local')
    dst = st.add_write('B')

    # Outer map steps over tiles of two elements; inner map walks one tile
    outer_entry, outer_exit = st.add_map('somemap', dict(ti='0:20:2'))
    inner_entry, inner_exit = st.add_map('somemap', dict(i='0:2'))

    # The tasklet consumes one element of the per-tile local buffer
    tasklet = st.add_tasklet('doit', {'a'}, {'o'}, 'o = a + 1')
    st.add_memlet_path(src, outer_entry, scratch,
                       memlet=dace.Memlet('A[ti:ti+2]'))
    st.add_memlet_path(scratch, inner_entry, tasklet,
                       memlet=dace.Memlet('local[i]'), dst_conn='a')
    st.add_memlet_path(tasklet, inner_exit, outer_exit, dst,
                       memlet=dace.Memlet('B[ti + i]'), src_conn='o')

    # Cutting out the tasklet alone must raise
    with pytest.raises(ValueError):
        cutout.cutout_state(st, tasklet)
    def _extract_patterns(self, configs: List[Tuple[str, List[int]]]):
        # Describe successful fusions as set of map descriptors
        subgraph_patterns = []
        for label, config in configs:
            nsdfg_id, state_id, _ = label.split(".")
            nsdfg_id = int(nsdfg_id)
            state_id = int(state_id)
            state = list(
                self._sdfg.all_sdfgs_recursive())[nsdfg_id].node(state_id)
            nodes = state.nodes()
            cutout = cutter.cutout_state(state, *(nodes), make_copy=False)

            pattern_desc = Counter()
            fusion_id, map_ids = self.config_from_key(config, cutout)
            if fusion_id == 0:
                continue

            for map_id in map_ids:
                map_entry = cutout.start_state.node(map_id)
                map_desc = OnTheFlyMapFusionTuner.map_descriptor(
                    cutout.start_state, map_entry)
                pattern_desc.update({map_desc: 1})

            subgraph_patterns.append(pattern_desc)

        subgraph_patterns = [
            dict(s) for s in set(
                frozenset(d.items()) for d in subgraph_patterns)
        ]
        subgraph_patterns = [Counter(s) for s in subgraph_patterns]

        return subgraph_patterns
Example #3
0
    def cutouts(self) -> Generator[Tuple[dace.SDFG, str], None, None]:
        """
        Yield a cutout for every outermost map in the tuned SDFG.

        Maps that are nested inside another map are skipped. The cutout covers
        the map's whole scope subgraph. Each item is ``(cutout, label)``, with
        ``label`` formatted as ``"<state_id>.<node_id>.<map label>"``.

        NOTE(review): ``state_id`` is looked up in ``self._sdfg``. For maps that
        live in nested SDFGs this lookup presumably fails. Confirm whether
        nested SDFGs are expected here.
        """
        for candidate, parent_state in self._sdfg.all_nodes_recursive():
            if not isinstance(candidate, dace.nodes.MapEntry):
                continue
            if xfh.get_parent_map(parent_state, candidate) is not None:
                continue

            map_node_id = parent_state.node_id(candidate)
            parent_state_id = self._sdfg.node_id(parent_state)
            scope_nodes = parent_state.scope_subgraph(candidate).nodes()
            sdfg_cut = cutter.cutout_state(parent_state, *scope_nodes)
            yield sdfg_cut, f"{parent_state_id}.{map_node_id}.{candidate.label}"
    def cutouts(self):
        """
        Yield a whole-state cutout for every state of every (nested) SDFG.

        Each item is ``(cutout, label)``, with ``label`` formatted as
        ``"<nsdfg_id>.<state_id>.<state label>"``. ``nsdfg_id`` is the index in
        ``self._sdfg.all_sdfgs_recursive()``. Cutouts are created with
        ``make_copy=False``, presumably sharing nodes with the original state
        (confirm in ``cutter``). States for which ``cutter.cutout_state``
        raises ``AttributeError`` are skipped.
        """
        for nsdfg_id, nsdfg in enumerate(self._sdfg.all_sdfgs_recursive()):
            for state in nsdfg.nodes():
                state_id = nsdfg.node_id(state)
                try:
                    cutout = cutter.cutout_state(state,
                                                 *state.nodes(),
                                                 make_copy=False)
                except AttributeError:
                    # Best effort: states that cannot be cut out are skipped.
                    continue
                # Yield outside the try so that an AttributeError thrown into
                # the generator here is not silently swallowed.
                yield cutout, f"{nsdfg_id}.{state_id}.{state.label}"
Example #5
0
def test_cutout_onenode():
    """
    Cut out a single library node and check the size of the result.

    The cutout must contain one state with four nodes and exactly three
    arrays, none of which is transient.
    """
    @dace.program
    def simple_matmul(A: dace.float64[20, 20], B: dace.float64[20, 20]):
        return A @ B + 5

    sdfg = simple_matmul.to_sdfg(simplify=True)
    assert sdfg.number_of_nodes() == 1
    only_state = sdfg.node(0)
    assert only_state.number_of_nodes() == 8

    lib_node = next(n for n in only_state
                    if isinstance(n, dace.nodes.LibraryNode))
    result = cutout.cutout_state(only_state, lib_node)

    assert result.number_of_nodes() == 1
    assert result.node(0).number_of_nodes() == 4
    assert len(result.arrays) == 3
    assert not any(desc.transient for desc in result.arrays.values())
Example #6
0
def test_cutout_complex_case():
    """
    Cut out one of two tasklets in a map whose range is read from an array.

    The map ``i = b:e`` takes its bounds from ``ind[0]`` and ``ind[1]``.
    Tasklet ``doit1`` computes ``C[i] = A[i] + 1`` and tasklet ``doit2``
    computes ``D[i] = B[i] + 2``. Cutting out ``doit2`` must keep exactly the
    arrays it needs (``B``, ``ind``, ``D``). Running the cutout must write
    ``B + 2`` into ``D[5:10]`` and leave the rest of ``D`` untouched.
    """
    # Prepare graph
    sdfg = dace.SDFG('complex')
    sdfg.add_array('A', [20], dace.float64)
    sdfg.add_array('B', [20], dace.float64)
    sdfg.add_array('ind', [2], dace.int32)
    sdfg.add_array('C', [20], dace.float64)
    sdfg.add_array('D', [20], dace.float64)

    state = sdfg.add_state()
    a = state.add_read('A')
    b = state.add_read('B')
    ind_node = state.add_read('ind')
    c = state.add_write('C')
    d = state.add_write('D')

    # Map with dynamic range: bounds arrive on the 'b' and 'e' connectors
    me, mx = state.add_map('somemap', dict(i='b:e'))
    me.add_in_connector('b')
    me.add_in_connector('e')
    state.add_edge(ind_node, None, me, 'b', dace.Memlet('ind[0]'))
    state.add_edge(ind_node, None, me, 'e', dace.Memlet('ind[1]'))

    # Two tasklets, one that reads from A and another from B
    t1 = state.add_tasklet('doit1', {'a'}, {'o'}, 'o = a + 1')
    t2 = state.add_tasklet('doit2', {'a'}, {'o'}, 'o = a + 2')
    state.add_memlet_path(a, me, t1, memlet=dace.Memlet('A[i]'), dst_conn='a')
    state.add_memlet_path(b, me, t2, memlet=dace.Memlet('B[i]'), dst_conn='a')
    state.add_memlet_path(t1, mx, c, memlet=dace.Memlet('C[i]'), src_conn='o')
    state.add_memlet_path(t2, mx, d, memlet=dace.Memlet('D[i]'), src_conn='o')

    # Cutout of the second tasklet only
    cut_sdfg = cutout.cutout_state(state, t2)
    cut_sdfg.validate()
    assert cut_sdfg.arrays.keys() == {'B', 'ind', 'D'}

    # Functionality
    B = np.random.rand(20)
    D = np.random.rand(20)
    D_before = D.copy()
    ind = np.array([5, 10], dtype=np.int32)
    cut_sdfg(B=B, D=D, ind=ind)

    # The dynamic range [5, 10) is computed...
    assert np.allclose(D[5:10], B[5:10] + 2)
    assert not np.allclose(D, B + 2)
    # ...and every element outside it keeps its previous value.
    assert np.array_equal(D[:5], D_before[:5])
    assert np.array_equal(D[10:], D_before[10:])
Example #7
0
def test_cutout_multinode():
    """
    Cut out the library and tasklet nodes of a state together.

    Exactly two such nodes are expected. The cutout must contain one state
    with eight nodes and four arrays, exactly one of which is transient.
    """
    @dace.program
    def simple_matmul(A: dace.float64[20, 20], B: dace.float64[20, 20]):
        return A @ B + 5

    sdfg = simple_matmul.to_sdfg(simplify=True)
    assert sdfg.number_of_nodes() == 1
    state = sdfg.node(0)
    assert state.number_of_nodes() == 8

    wanted_types = (dace.nodes.LibraryNode, dace.nodes.Tasklet)
    selected = list(filter(lambda n: isinstance(n, wanted_types), state))
    assert len(selected) == 2

    cut_sdfg = cutout.cutout_state(state, *selected)
    assert cut_sdfg.number_of_nodes() == 1
    assert cut_sdfg.node(0).number_of_nodes() == 8
    assert len(cut_sdfg.arrays) == 4
    num_transients = sum(1 for desc in cut_sdfg.arrays.values()
                         if desc.transient)
    assert num_transients == 1
    def apply(self, config: Tuple[int, List[int]], label: str,
              **kwargs) -> None:
        """
        Apply on-the-fly subgraph fusion to one state of the tuned SDFG.

        :param config: ``(fusion_id, map_ids)``. A ``fusion_id`` of 0 means
                       "do not fuse" and returns immediately. ``map_ids`` are
                       node ids of the maps to fuse, taken from the
                       whole-state cutout of the target state.
        :param label: ``"<nsdfg_id>.<state_id>.<state label>"`` identifying the
                      target state. The state label may itself contain dots.
        :param kwargs: Unused.
        """
        if config[0] == 0:
            return

        # maxsplit=2: a '.' inside the state label must not break unpacking
        nsdfg_id, state_id, _ = label.split(".", 2)
        sdfg = list(self._sdfg.all_sdfgs_recursive())[int(nsdfg_id)]
        state_id = int(state_id)
        state = sdfg.node(state_id)
        cutout = cutter.cutout_state(state, *state.nodes(), make_copy=False)

        # Resolve cutout node ids to map entries. With make_copy=False these
        # are presumably the original state's nodes (they are used with
        # graph=state below).
        maps_ = [cutout.start_state.node(map_id) for map_id in config[1]]
        subgraph = helpers.subgraph_from_maps(sdfg=sdfg,
                                              graph=state,
                                              map_entries=maps_)

        map_fusion = sg.SubgraphOTFFusion()
        map_fusion.setup_match(subgraph, sdfg.sdfg_id, state_id)
        if map_fusion.can_be_applied(state, sdfg):
            fuse_counter = map_fusion.apply(state, sdfg)
            print(f"Fusing {fuse_counter} maps")
    def apply(self, config: Tuple[int, List[int]], label: str,
              **kwargs) -> None:
        """
        Apply composite subgraph fusion (with tiling) to one state of the SDFG.

        :param config: ``(fusion_id, map_ids)``. A ``fusion_id`` of 0 means
                       "do not fuse" and returns immediately. ``map_ids`` are
                       node ids of the maps to fuse, taken from the
                       whole-state cutout of the target state.
        :param label: ``"<nsdfg_id>.<state_id>.<state label>"`` identifying the
                      target state. The state label may itself contain dots.
        :param kwargs: Unused.
        """
        if config[0] == 0:
            return

        # maxsplit=2: a '.' inside the state label must not break unpacking
        nsdfg_id, state_id, _ = label.split(".", 2)
        sdfg = list(self._sdfg.all_sdfgs_recursive())[int(nsdfg_id)]
        state_id = int(state_id)
        state = sdfg.node(state_id)
        cutout = cutter.cutout_state(state, *state.nodes(), make_copy=False)

        # Resolve cutout node ids to map entries (used with graph=state below)
        maps_ = [cutout.start_state.node(map_id) for map_id in config[1]]
        subgraph = helpers.subgraph_from_maps(sdfg=sdfg,
                                              graph=state,
                                              map_entries=maps_)

        subgraph_fusion = sg.CompositeFusion()
        subgraph_fusion.setup_match(subgraph, sdfg.sdfg_id, state_id)
        subgraph_fusion.allow_tiling = True
        subgraph_fusion.schedule_innermaps = dace.ScheduleType.GPU_Device
        if subgraph_fusion.can_be_applied(sdfg, subgraph):
            subgraph_fusion.apply(sdfg)
    def transfer(sdfg: dace.SDFG, tuner, k: int = 5):
        """
        Transfer the best fusion patterns found by ``tuner`` onto ``sdfg``.

        The top-``k`` configurations of ``tuner.optimize(apply=False)`` are
        condensed into fusion patterns, i.e. multisets of map descriptors.
        Every state of ``sdfg``, including states of nested SDFGs, is then
        processed if it has at least two top-level maps. For such a state,
        each pattern whose descriptors are all present is tried on a cutout of
        the state: the fused cutout is measured with GPU-event instrumentation
        and compared against an unfused baseline. The fastest improving
        pattern is applied to the real state. This repeats until no pattern
        improves on the baseline.

        :param sdfg: SDFG that is modified in place. Instrumented data must be
                     available via ``sdfg.get_instrumented_data()``.
        :param tuner: An ``OnTheFlyMapFusionTuner``. Its
                      ``optimize(apply=False)`` report supplies the patterns.
        :param k: Number of best tuning configurations to take patterns from.

        NOTE(review): this function takes no ``self``. It is presumably
        decorated with ``@staticmethod`` outside the visible source; confirm.
        """
        assert isinstance(tuner, OnTheFlyMapFusionTuner)

        dreport = sdfg.get_instrumented_data()
        assert dreport is not None

        tuning_report = tuner.optimize(apply=False)
        best_configs = cutout_tuner.CutoutTuner.top_k_configs(tuning_report,
                                                              k=k)
        subgraph_patterns = tuner._extract_patterns(best_configs)

        # NOTE(review): ``i`` counts visited states but is never read below;
        # the ``i=192`` keyword arguments further down are unrelated.
        i = 0
        for nsdfg in sdfg.all_sdfgs_recursive():
            for state in nsdfg.states():
                i = i + 1

                # Only states with at least two outermost maps can be fused
                top_maps = []
                for node in state.nodes():
                    if isinstance(node,
                                  dace.nodes.MapEntry) and xfh.get_parent_map(
                                      state, node) is None:
                        top_maps.append(node)

                if len(top_maps) < 2:
                    continue

                # Skip states whose cutout raises AttributeError. The cutout
                # object itself is not used afterwards.
                try:
                    cutout = cutter.cutout_state(state,
                                                 *(state.nodes()),
                                                 make_copy=False)
                except AttributeError:
                    continue

                # Greedy loop: find the best improving pattern, apply it to the
                # real state, then search again on the modified state.
                while True:
                    base_runtime = None
                    best_pattern = None
                    best_pattern_runtime = math.inf
                    # NOTE(review): ``j`` is unused.
                    for j, pattern in enumerate(subgraph_patterns):
                        # Outermost maps of the state. This is recomputed per
                        # pattern although ``state`` is presumably unchanged
                        # within this loop.
                        maps = []
                        for node in state.nodes():
                            if isinstance(
                                    node, dace.nodes.MapEntry
                            ) and xfh.get_parent_map(state, node) is None:
                                maps.append(node)

                        if len(maps) < 2:
                            continue

                        # state_desc: descriptor -> count of maps in the state.
                        # maps_desc: descriptor -> the map entries with it.
                        maps_desc = {}
                        state_desc = Counter()
                        for map_entry in maps:
                            map_desc = OnTheFlyMapFusionTuner.map_descriptor(
                                state, map_entry)
                            state_desc.update({map_desc: 1})

                            if not map_desc in maps_desc:
                                maps_desc[map_desc] = []

                            maps_desc[map_desc].append(map_entry)

                        # The pattern fits only if the state has at least as
                        # many maps of every descriptor as the pattern needs.
                        included = True
                        for key in pattern:
                            if not key in state_desc or pattern[
                                    key] > state_desc[key]:
                                included = False
                                break

                        if not included:
                            continue

                        # Measure the unfused baseline lazily, once per
                        # iteration of the greedy loop.
                        if base_runtime is None:
                            baseline = cutter.cutout_state(state,
                                                           *(state.nodes()),
                                                           make_copy=False)
                            baseline.start_state.instrument = dace.InstrumentationType.GPU_Events

                            # Recorded data for the non-transient arrays.
                            # Arrays for which the lookup raises are skipped.
                            dreport_ = {}
                            for cstate in baseline.nodes():
                                for dnode in cstate.data_nodes():
                                    array = baseline.arrays[dnode.data]
                                    if array.transient:
                                        continue
                                    # NOTE(review): bare ``except`` also
                                    # swallows KeyboardInterrupt/SystemExit.
                                    try:
                                        data = dreport.get_first_version(
                                            dnode.data)
                                        dreport_[dnode.data] = data
                                    except:
                                        continue

                            base_runtime = optim_utils.subprocess_measure(
                                baseline, dreport_, i=192, j=192)
                            best_pattern_runtime = base_runtime
                            # Baseline not measurable: stop trying patterns.
                            # best_pattern stays None, so the while loop ends.
                            if base_runtime == math.inf:
                                break

                        # Construct subgraph greedily
                        # (first ``num`` maps of each descriptor; other
                        # combinations are not tried)
                        subgraph_maps = []
                        for desc in pattern:
                            num = pattern[desc]
                            subgraph_maps.extend(maps_desc[desc][:num])

                        # Apply
                        # Node ids are resolved in a non-copied cutout, then
                        # used to find the same maps in a deep copy. The trial
                        # fusion therefore runs on the copy, not on ``state``.
                        experiment_sdfg_ = cutter.cutout_state(
                            state, *(state.nodes()), make_copy=False)
                        experiment_state_ = experiment_sdfg_.start_state
                        experiment_maps_ids = list(
                            map(lambda me: experiment_state_.node_id(me),
                                subgraph_maps))
                        experiment_sdfg = copy.deepcopy(experiment_sdfg_)
                        experiment_state = experiment_sdfg.start_state
                        experiment_state.instrument = dace.InstrumentationType.GPU_Events

                        experiment_maps = list(
                            map(lambda m_id: experiment_state.node(m_id),
                                experiment_maps_ids))
                        experiment_subgraph = helpers.subgraph_from_maps(
                            sdfg=experiment_sdfg,
                            graph=experiment_state,
                            map_entries=experiment_maps)

                        map_fusion = sg.SubgraphOTFFusion()
                        map_fusion.setup_match(
                            experiment_subgraph, experiment_sdfg.sdfg_id,
                            experiment_sdfg.node_id(experiment_state))
                        if map_fusion.can_be_applied(experiment_state,
                                                     experiment_sdfg):
                            # A failing or no-op trial fusion skips the pattern.
                            # NOTE(review): bare ``except`` here as well.
                            try:
                                experiment_fuse_counter = map_fusion.apply(
                                    experiment_state, experiment_sdfg)
                            except:
                                continue

                            if experiment_fuse_counter == 0:
                                continue

                            dreport_ = {}
                            for cstate in experiment_sdfg.nodes():
                                for dnode in cstate.data_nodes():
                                    array = experiment_sdfg.arrays[dnode.data]
                                    if array.transient:
                                        continue
                                    try:
                                        data = dreport.get_first_version(
                                            dnode.data)
                                        dreport_[dnode.data] = data
                                    except:
                                        continue

                            # Keep only patterns strictly faster than the best
                            # so far (initially the baseline). Ties keep the
                            # earlier pattern.
                            fused_runtime = optim_utils.subprocess_measure(
                                experiment_sdfg, dreport_, i=192, j=192)
                            if fused_runtime >= best_pattern_runtime:
                                continue

                            best_pattern = subgraph_maps
                            best_pattern_runtime = fused_runtime

                    # Apply the winner to the real state and search again,
                    # or stop when no pattern improved on the baseline.
                    if best_pattern is not None:
                        subgraph = helpers.subgraph_from_maps(
                            sdfg=nsdfg, graph=state, map_entries=best_pattern)
                        map_fusion = sg.SubgraphOTFFusion()
                        map_fusion.setup_match(subgraph, nsdfg.sdfg_id,
                                               nsdfg.node_id(state))
                        # NOTE(review): ``actual_fuse_counter`` is never read.
                        actual_fuse_counter = map_fusion.apply(state, nsdfg)

                        best_pattern = None
                        base_runtime = None
                        best_pattern_runtime = math.inf
                    else:
                        break