def fuse_states(sdfg: SDFG) -> int:
    """
    Repeatedly apply StateFusion (in strict mode) to an SDFG and all of its
    nested SDFGs until no more state pairs can be fused.

    Instead of going through generic pattern matching, every edge of each
    SDFG's state machine is offered directly to StateFusion as a
    (first_state, second_state) candidate.

    :param sdfg: The SDFG to transform.
    :return: The total number of states fused.
    """
    from dace.transformation.interstate import StateFusion  # Avoid import loop

    total_fused = 0
    for nested in sdfg.all_sdfgs_recursive():
        sdfg_id = nested.sdfg_id
        made_progress = True
        # Keep sweeping over this SDFG until a full sweep fuses nothing
        while made_progress:
            made_progress = False
            # Snapshot of the edges: applying a fusion mutates the graph
            snapshot = list(nested.nx.edges)
            # States that took part in a fusion during this sweep; edges
            # touching them are skipped until the next sweep
            consumed = set()
            for src, dst in snapshot:
                if not consumed.isdisjoint((src, dst)):
                    continue
                match = {
                    StateFusion.first_state: src,
                    StateFusion.second_state: dst,
                }
                fusion = StateFusion(sdfg_id, -1, match, 0, override=True)
                if not fusion.can_be_applied(nested, match, 0, nested, strict=True):
                    continue
                fusion.apply(nested)
                made_progress = True
                total_fused += 1
                consumed.update((src, dst))

    if config.Config.get_bool('debugprint'):
        print(f'Applied {total_fused} State Fusions')
    return total_fused
def test_fuse_assignment_in_use():
    """
    Two states with an interstate assignment in between, where the assigned
    value is used in the first state. Should fail.

    The SDFG is a chain of four states. The edge state3 -> state4 assigns
    ``k=2`` while the tasklet in state3 reads ``k``, so StateFusion applied to
    (state3, state4) is expected to be rejected with a ``ValueError``.
    """
    sdfg = dace.SDFG('state_fusion_test')
    sdfg.add_array('A', [2], dace.int32)
    state1, state2, state3, state4 = tuple(sdfg.add_state() for _ in range(4))
    sdfg.add_edge(state1, state2, dace.InterstateEdge(assignments=dict(k=1)))
    sdfg.add_edge(state2, state3, dace.InterstateEdge())
    sdfg.add_edge(state3, state4, dace.InterstateEdge(assignments=dict(k=2)))

    # First state of the fusion candidate: reads k and writes A[0]
    state3.add_edge(state3.add_tasklet('one', {}, {'a'}, 'a = k'), 'a',
                    state3.add_write('A'), None, dace.Memlet('A[0]'))

    # Second state: reads k and writes A[1]. The nodes are created in state4
    # itself (not state3) so that they belong to the state owning the edge.
    state4.add_edge(state4.add_tasklet('two', {}, {'a'}, 'a = k'), 'a',
                    state4.add_write('A'), None, dace.Memlet('A[1]'))

    try:
        StateFusion.apply_to(sdfg,
                             strict=True,
                             first_state=state3,
                             second_state=state4)
    except ValueError:
        print('Exception successfully caught')
    else:
        # apply_to returned normally, i.e. the transformation was applied
        raise AssertionError('States fused, test failed')
def fuse_states(sdfg: SDFG, strict: bool = True, progress: bool = False) -> int:
    """
    Repeatedly apply StateFusion to an SDFG and all of its nested SDFGs until
    no more state pairs can be fused.

    Instead of going through generic pattern matching, every edge of each
    SDFG's state machine is offered directly to StateFusion as a
    (first_state, second_state) candidate.

    :param sdfg: The SDFG to transform.
    :param strict: If True (default), operates in strict mode.
    :param progress: If True, prints out a progress bar of fusion (may be
                     inaccurate, requires ``tqdm``)
    :return: The total number of states fused.
    """
    from dace.transformation.interstate import StateFusion  # Avoid import loop

    total_fused = 0

    bar = None
    if progress:
        from tqdm import tqdm
        # The bar's total is the number of state-machine edges over all SDFGs
        edge_total = sum(nested.number_of_edges()
                         for nested in sdfg.all_sdfgs_recursive())
        bar = tqdm(total=edge_total)

    for nested in sdfg.all_sdfgs_recursive():
        sdfg_id = nested.sdfg_id
        made_progress = True
        # Keep sweeping over this SDFG until a full sweep fuses nothing
        while made_progress:
            made_progress = False
            # Snapshot of the edges: applying a fusion mutates the graph
            snapshot = list(nested.nx.edges)
            # States that took part in a fusion during this sweep
            consumed = set()
            for src, dst in snapshot:
                if src in consumed or dst in consumed:
                    continue
                match = {
                    StateFusion.first_state: src,
                    StateFusion.second_state: dst,
                }
                fusion = StateFusion(sdfg_id, -1, match, 0, override=True)
                if not fusion.can_be_applied(nested, match, 0, nested, strict=strict):
                    continue
                fusion.apply(nested)
                made_progress = True
                total_fused += 1
                if bar is not None:
                    bar.update(1)
                consumed.update((src, dst))

    if bar is not None:
        bar.close()
    if config.Config.get_bool('debugprint'):
        print(f'Applied {total_fused} State Fusions')
    return total_fused
def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> int:
    """
    Fuses all possible states of an SDFG (and all sub-SDFGs) using an optimized
    routine that uses the structure of the StateFusion transformation.

    :param sdfg: The SDFG to transform.
    :param permissive: If True, operates in permissive mode, which ignores some
                       race condition checks.
    :param progress: If True, prints out a progress bar of fusion (may be
                     inaccurate, requires ``tqdm``). If None, prints out
                     progress if over 5 seconds have passed. If False, never
                     shows progress bar. If ``tqdm`` cannot be imported, no
                     progress bar is shown (with a warning when ``progress``
                     was explicitly requested).
    :return: The total number of states fused.
    """
    from dace.transformation.interstate import StateFusion  # Avoid import loop

    # Normalize to the three supported modes: True, False or None (automatic)
    if progress is not None:
        progress = bool(progress)

    tqdm = None
    if progress is not False:
        try:
            from tqdm import tqdm
        except ImportError:
            tqdm = None

    if progress and tqdm is None:
        # The progress bar is cosmetic; do not fail the fusion because of it
        import warnings
        warnings.warn('Progress bar requested but "tqdm" is not installed; '
                      'continuing without progress reporting.')
        progress = False

    counter = 0

    # The bar's total is the number of state-machine edges over all SDFGs.
    # It is only needed if a bar can actually be shown.
    fusible_states = 0
    if tqdm is not None:
        fusible_states = sum(sd.number_of_edges() for sd in sdfg.all_sdfgs_recursive())

    pbar = None
    if progress:
        pbar = tqdm(total=fusible_states, desc='Fusing states')

    # Monotonic clock: elapsed time must not be affected by system clock changes
    start = time.monotonic()

    try:
        for sd in sdfg.all_sdfgs_recursive():
            sdfg_id = sd.sdfg_id
            while True:
                # Snapshot of the edges: applying a fusion mutates the graph
                edges = list(sd.nx.edges)
                applied = 0
                # States that took part in a fusion during this sweep
                skip_nodes = set()
                for u, v in edges:
                    # Automatic mode: show the bar once fusion has been
                    # running for more than 5 seconds
                    if (progress is None and tqdm is not None and (time.monotonic() - start) > 5):
                        progress = True
                        pbar = tqdm(total=fusible_states, desc='Fusing states', initial=counter)

                    if u in skip_nodes or v in skip_nodes:
                        continue
                    candidate = {StateFusion.first_state: u, StateFusion.second_state: v}
                    sf = StateFusion(sdfg_id, -1, candidate, 0, override=True)
                    if sf.can_be_applied(sd, candidate, 0, sd, permissive=permissive):
                        sf.apply(sd)
                        applied += 1
                        counter += 1
                        if pbar is not None:
                            pbar.update(1)
                        skip_nodes.add(u)
                        skip_nodes.add(v)
                # A full sweep without any fusion means this SDFG is done
                if applied == 0:
                    break
    finally:
        # Close the bar even if a fusion raised, so the terminal is left clean
        if pbar is not None:
            pbar.close()

    if config.Config.get_bool('debugprint') and counter > 0:
        print(f'Applied {counter} State Fusions')
    return counter