def test_cutout_scope_fail():
    """
    Checks that cutting out a tasklet nested in a tiled map raises ``ValueError``.

    The tasklet reads from a transient access node that sits between the outer
    and the inner map, so the cutout request is expected to be rejected.
    """
    # Build the SDFG: two global arrays and one small transient buffer
    sdfg = dace.SDFG('complex')
    sdfg.add_array('A', [20], dace.float64)
    sdfg.add_array('B', [20], dace.float64)
    sdfg.add_transient('local', [2], dace.float64)
    state = sdfg.add_state()
    read_a = state.add_read('A')
    local_node = state.add_access('local')
    write_b = state.add_write('B')

    # Outer (tile) map and inner (element) map
    outer_entry, outer_exit = state.add_map('somemap', {'ti': '0:20:2'})
    inner_entry, inner_exit = state.add_map('somemap', {'i': '0:2'})

    # Tasklet consuming one element of the transient buffer
    tasklet = state.add_tasklet('doit', {'a'}, {'o'}, 'o = a + 1')
    state.add_memlet_path(read_a, outer_entry, local_node, memlet=dace.Memlet('A[ti:ti+2]'))
    state.add_memlet_path(local_node, inner_entry, tasklet, memlet=dace.Memlet('local[i]'), dst_conn='a')
    state.add_memlet_path(tasklet,
                          inner_exit,
                          outer_exit,
                          write_b,
                          memlet=dace.Memlet('B[ti + i]'),
                          src_conn='o')

    # The cutout request itself must raise
    with pytest.raises(ValueError):
        cutout.cutout_state(state, tasklet)
def _extract_patterns(self, configs: List[Tuple[str, List[int]]]):
    """
    Condenses fusion configurations into unique multisets of map descriptors.

    Every label is split on ``.`` into exactly three fields; the first two are
    parsed as the index of the (nested) SDFG and the index of the state.

    :param configs: Pairs of (cutout label, configuration key).
    :return: A list of ``Counter`` objects mapping map descriptors to their
             number of occurrences, with duplicate patterns removed.
    """
    collected = []
    for label, config in configs:
        sdfg_index, state_index, _ = label.split(".")
        sdfg_index, state_index = int(sdfg_index), int(state_index)

        all_sdfgs = list(self._sdfg.all_sdfgs_recursive())
        state = all_sdfgs[sdfg_index].node(state_index)
        cutout = cutter.cutout_state(state, *state.nodes(), make_copy=False)

        fusion_id, map_ids = self.config_from_key(config, cutout)
        if fusion_id == 0:
            # Configurations with fusion id 0 contribute no pattern
            continue

        # Count how often each map descriptor occurs among the selected maps
        collected.append(
            Counter(
                OnTheFlyMapFusionTuner.map_descriptor(cutout.start_state, cutout.start_state.node(map_id))
                for map_id in map_ids))

    # Counters are unhashable, so de-duplicate through frozen item sets
    unique_items = {frozenset(pattern.items()) for pattern in collected}
    return [Counter(dict(items)) for items in unique_items]
def cutouts(self) -> Generator[Tuple[dace.SDFG, str], None, None]:
    """
    Yields one cutout per map entry that has no parent map.

    :return: A generator of (cutout, label) pairs, where the label has the
             form ``<state_id>.<node_id>.<map label>``.
    """
    for candidate, parent in self._sdfg.all_nodes_recursive():
        if not isinstance(candidate, dace.nodes.MapEntry):
            continue
        if xfh.get_parent_map(parent, candidate) is not None:
            # Nested maps are skipped; only outermost maps are cut out
            continue

        node_id = parent.node_id(candidate)
        state_id = self._sdfg.node_id(parent)

        # The cutout consists of every node in the scope subgraph of the map
        scope_nodes = parent.scope_subgraph(candidate).nodes()
        yield cutter.cutout_state(parent, *scope_nodes), f"{state_id}.{node_id}.{candidate.label}"
def cutouts(self):
    """
    Yields one cutout per state, covering all nodes of that state.

    States are visited for every SDFG returned by ``all_sdfgs_recursive``.
    A state for which an ``AttributeError`` is raised is skipped.

    :return: A generator of (cutout, label) pairs, where the label has the
             form ``<nsdfg_id>.<state_id>.<state label>``.
    """
    for sdfg_index, nested in enumerate(self._sdfg.all_sdfgs_recursive()):
        for current in nested.nodes():
            state_index = nested.node_id(current)
            members = current.nodes()
            try:
                whole_state = cutter.cutout_state(current, *members, make_copy=False)
                yield whole_state, f"{sdfg_index}.{state_index}.{current.label}"
            except AttributeError:
                continue
def test_cutout_onenode():
    """ Checks the cutout of a single library node taken from a one-state program. """

    @dace.program
    def simple_matmul(A: dace.float64[20, 20], B: dace.float64[20, 20]):
        return A @ B + 5

    sdfg = simple_matmul.to_sdfg(simplify=True)
    assert sdfg.number_of_nodes() == 1
    state = sdfg.node(0)
    assert state.number_of_nodes() == 8

    # Pick the first library node in the state and cut it out on its own
    library_node = next(candidate for candidate in state if isinstance(candidate, dace.nodes.LibraryNode))
    cut_sdfg = cutout.cutout_state(state, library_node)

    assert cut_sdfg.number_of_nodes() == 1
    assert cut_sdfg.node(0).number_of_nodes() == 4
    assert len(cut_sdfg.arrays) == 3
    # None of the data containers in the cutout may be transient
    assert not any(desc.transient for desc in cut_sdfg.arrays.values())
def test_cutout_complex_case():
    """
    Checks the cutout of one of two tasklets inside a map with a dynamic range.

    Only the arrays used by the selected tasklet and the array feeding the
    map range are expected in the cutout, which is then executed.
    """
    # Build the SDFG with three inputs and two outputs
    sdfg = dace.SDFG('complex')
    sdfg.add_array('A', [20], dace.float64)
    sdfg.add_array('B', [20], dace.float64)
    sdfg.add_array('ind', [2], dace.int32)
    sdfg.add_array('C', [20], dace.float64)
    sdfg.add_array('D', [20], dace.float64)
    state = sdfg.add_state()
    read_a = state.add_read('A')
    read_b = state.add_read('B')
    read_ind = state.add_read('ind')
    write_c = state.add_write('C')
    write_d = state.add_write('D')

    # Map whose range bounds arrive through the connectors 'b' and 'e'
    map_entry, map_exit = state.add_map('somemap', {'i': 'b:e'})
    map_entry.add_in_connector('b')
    map_entry.add_in_connector('e')
    state.add_edge(read_ind, None, map_entry, 'b', dace.Memlet('ind[0]'))
    state.add_edge(read_ind, None, map_entry, 'e', dace.Memlet('ind[1]'))

    # First tasklet reads A and writes C, second reads B and writes D
    first = state.add_tasklet('doit1', {'a'}, {'o'}, 'o = a + 1')
    second = state.add_tasklet('doit2', {'a'}, {'o'}, 'o = a + 2')
    state.add_memlet_path(read_a, map_entry, first, memlet=dace.Memlet('A[i]'), dst_conn='a')
    state.add_memlet_path(read_b, map_entry, second, memlet=dace.Memlet('B[i]'), dst_conn='a')
    state.add_memlet_path(first, map_exit, write_c, memlet=dace.Memlet('C[i]'), src_conn='o')
    state.add_memlet_path(second, map_exit, write_d, memlet=dace.Memlet('D[i]'), src_conn='o')

    # Cut out the second tasklet only
    cut_sdfg = cutout.cutout_state(state, second)
    cut_sdfg.validate()
    assert cut_sdfg.arrays.keys() == {'B', 'ind', 'D'}

    # Run the cutout: only the range [5, 10) given by 'ind' is written
    B = np.random.rand(20)
    D = np.random.rand(20)
    ind = np.array([5, 10], dtype=np.int32)
    cut_sdfg(B=B, D=D, ind=ind)

    assert not np.allclose(D, B + 2)
    assert np.allclose(D[5:10], B[5:10] + 2)
def test_cutout_multinode():
    """ Checks the cutout of a library node together with a tasklet from the same state. """

    @dace.program
    def simple_matmul(A: dace.float64[20, 20], B: dace.float64[20, 20]):
        return A @ B + 5

    sdfg = simple_matmul.to_sdfg(simplify=True)
    assert sdfg.number_of_nodes() == 1
    state = sdfg.node(0)
    assert state.number_of_nodes() == 8

    # Select every library node and tasklet in the state
    wanted_types = (dace.nodes.LibraryNode, dace.nodes.Tasklet)
    selected = [candidate for candidate in state if isinstance(candidate, wanted_types)]
    assert len(selected) == 2

    cut_sdfg = cutout.cutout_state(state, *selected)
    assert cut_sdfg.number_of_nodes() == 1
    assert cut_sdfg.node(0).number_of_nodes() == 8
    assert len(cut_sdfg.arrays) == 4
    # Exactly one data container in the cutout is transient
    assert sum(1 for desc in cut_sdfg.arrays.values() if desc.transient) == 1
def apply(self, config: Tuple[int, List[int]], label: str, **kwargs) -> None:
    """
    Applies on-the-fly map fusion for one configuration to the tuned SDFG.

    :param config: Pair of (fusion id, ids of the maps to fuse). A fusion id
                   of 0 makes this call a no-op.
    :param label: Cutout label, split on ``.`` into SDFG index, state index
                  and a third field that is not used here.
    :param kwargs: Accepted but not used by this method.
    """
    if config[0] == 0:
        return

    sdfg_index, state_index, _ = label.split(".")
    sdfg_index, state_index = int(sdfg_index), int(state_index)
    sdfg = list(self._sdfg.all_sdfgs_recursive())[sdfg_index]
    state = sdfg.node(state_index)

    # Map ids in the configuration refer to nodes of the whole-state cutout
    cutout = cutter.cutout_state(state, *state.nodes(), make_copy=False)
    lookup = cutout.start_state.node
    selected_maps = [lookup(map_id) for map_id in config[1]]
    subgraph = helpers.subgraph_from_maps(sdfg=sdfg, graph=state, map_entries=selected_maps)

    map_fusion = sg.SubgraphOTFFusion()
    map_fusion.setup_match(subgraph, sdfg.sdfg_id, state_index)
    if not map_fusion.can_be_applied(state, sdfg):
        return

    fuse_counter = map_fusion.apply(state, sdfg)
    print(f"Fusing {fuse_counter} maps")
def apply(self, config: Tuple[int, List[int]], label: str, **kwargs) -> None:
    """
    Applies composite subgraph fusion for one configuration to the tuned SDFG.

    :param config: Pair of (fusion id, ids of the maps to fuse). A fusion id
                   of 0 makes this call a no-op.
    :param label: Cutout label, split on ``.`` into SDFG index, state index
                  and a third field that is not used here.
    :param kwargs: Accepted but not used by this method.
    """
    if config[0] == 0:
        return

    sdfg_index, state_index, _ = label.split(".")
    sdfg = list(self._sdfg.all_sdfgs_recursive())[int(sdfg_index)]
    state_index = int(state_index)
    state = sdfg.node(state_index)

    # Map ids in the configuration refer to nodes of the whole-state cutout
    cutout = cutter.cutout_state(state, *state.nodes(), make_copy=False)
    lookup = cutout.start_state.node
    selected_maps = [lookup(map_id) for map_id in config[1]]
    subgraph = helpers.subgraph_from_maps(sdfg=sdfg, graph=state, map_entries=selected_maps)

    fusion = sg.CompositeFusion()
    fusion.setup_match(subgraph, sdfg.sdfg_id, state_index)
    fusion.allow_tiling = True
    fusion.schedule_innermaps = dace.ScheduleType.GPU_Device
    if not fusion.can_be_applied(sdfg, subgraph):
        return

    fusion.apply(sdfg)
def transfer(sdfg: dace.SDFG, tuner, k: int = 5):
    """
    Transfers the best on-the-fly map fusion patterns of a tuner to ``sdfg``.

    The ``k`` best configurations of ``tuner`` are turned into patterns
    (multisets of map descriptors). For every state with at least two
    top-level maps, each pattern contained in the state is tried on a deep
    copy of a whole-state cutout and measured against the unfused baseline.
    The fastest pattern that beats the baseline is applied to the actual
    state, and the search repeats on that state until no pattern improves
    the runtime any more.

    :param sdfg: The SDFG to optimize in place. It must provide instrumented
                 data through ``get_instrumented_data``.
    :param tuner: An ``OnTheFlyMapFusionTuner`` supplying the tuning report.
    :param k: Number of best configurations to extract patterns from.
    """
    assert isinstance(tuner, OnTheFlyMapFusionTuner)

    dreport = sdfg.get_instrumented_data()
    assert dreport is not None

    tuning_report = tuner.optimize(apply=False)
    best_configs = cutout_tuner.CutoutTuner.top_k_configs(tuning_report, k=k)
    subgraph_patterns = tuner._extract_patterns(best_configs)

    def top_level_maps(graph):
        # Map entries of ``graph`` that are not nested in another map
        return [
            node for node in graph.nodes()
            if isinstance(node, dace.nodes.MapEntry) and xfh.get_parent_map(graph, node) is None
        ]

    def recorded_inputs(cut_sdfg):
        # Instrumented data for every non-transient container accessed in ``cut_sdfg``
        inputs = {}
        for cstate in cut_sdfg.nodes():
            for dnode in cstate.data_nodes():
                if cut_sdfg.arrays[dnode.data].transient:
                    continue
                try:
                    inputs[dnode.data] = dreport.get_first_version(dnode.data)
                except Exception:
                    # Best effort: containers without recorded data are left out.
                    # ``Exception`` (not a bare except) lets interrupts propagate.
                    continue
        return inputs

    for nsdfg in sdfg.all_sdfgs_recursive():
        for state in nsdfg.states():
            if len(top_level_maps(state)) < 2:
                continue

            # Skip states for which a whole-state cutout cannot be built
            try:
                cutter.cutout_state(state, *state.nodes(), make_copy=False)
            except AttributeError:
                continue

            while True:
                base_runtime = None
                best_pattern = None
                best_pattern_runtime = math.inf
                for pattern in subgraph_patterns:
                    maps = top_level_maps(state)
                    if len(maps) < 2:
                        continue

                    # Group the maps of the state by descriptor and count them
                    maps_desc = {}
                    state_desc = Counter()
                    for map_entry in maps:
                        map_desc = OnTheFlyMapFusionTuner.map_descriptor(state, map_entry)
                        state_desc.update({map_desc: 1})
                        maps_desc.setdefault(map_desc, []).append(map_entry)

                    # The pattern applies only if the state has enough maps of each descriptor
                    if any(key not in state_desc or pattern[key] > state_desc[key] for key in pattern):
                        continue

                    if base_runtime is None:
                        # Measure the unfused state once per search round
                        baseline = cutter.cutout_state(state, *state.nodes(), make_copy=False)
                        baseline.start_state.instrument = dace.InstrumentationType.GPU_Events

                        base_runtime = optim_utils.subprocess_measure(baseline,
                                                                      recorded_inputs(baseline),
                                                                      i=192,
                                                                      j=192)
                        best_pattern_runtime = base_runtime
                        if base_runtime == math.inf:
                            break

                    # Construct subgraph greedily
                    subgraph_maps = []
                    for desc, num in pattern.items():
                        subgraph_maps.extend(maps_desc[desc][:num])

                    # Apply the fusion on a deep copy, addressing the maps by node id
                    experiment_sdfg_ = cutter.cutout_state(state, *state.nodes(), make_copy=False)
                    experiment_state_ = experiment_sdfg_.start_state
                    experiment_maps_ids = [experiment_state_.node_id(me) for me in subgraph_maps]
                    experiment_sdfg = copy.deepcopy(experiment_sdfg_)
                    experiment_state = experiment_sdfg.start_state
                    experiment_state.instrument = dace.InstrumentationType.GPU_Events

                    experiment_maps = [experiment_state.node(m_id) for m_id in experiment_maps_ids]
                    experiment_subgraph = helpers.subgraph_from_maps(sdfg=experiment_sdfg,
                                                                     graph=experiment_state,
                                                                     map_entries=experiment_maps)

                    map_fusion = sg.SubgraphOTFFusion()
                    map_fusion.setup_match(experiment_subgraph, experiment_sdfg.sdfg_id,
                                           experiment_sdfg.node_id(experiment_state))
                    if not map_fusion.can_be_applied(experiment_state, experiment_sdfg):
                        continue

                    try:
                        experiment_fuse_counter = map_fusion.apply(experiment_state, experiment_sdfg)
                    except Exception:
                        # A failing experiment only disqualifies this pattern
                        continue

                    if experiment_fuse_counter == 0:
                        continue

                    fused_runtime = optim_utils.subprocess_measure(experiment_sdfg,
                                                                   recorded_inputs(experiment_sdfg),
                                                                   i=192,
                                                                   j=192)
                    if fused_runtime >= best_pattern_runtime:
                        continue

                    best_pattern = subgraph_maps
                    best_pattern_runtime = fused_runtime

                if best_pattern is None:
                    # No pattern improved on the baseline; this state is done
                    break

                # Apply the winning pattern to the actual state, then search again
                subgraph = helpers.subgraph_from_maps(sdfg=nsdfg, graph=state, map_entries=best_pattern)
                map_fusion = sg.SubgraphOTFFusion()
                map_fusion.setup_match(subgraph, nsdfg.sdfg_id, nsdfg.node_id(state))
                map_fusion.apply(state, nsdfg)