Пример #1
0
def fuse_states(sdfg: SDFG) -> int:
    """
    Repeatedly applies the StateFusion transformation to an SDFG and to every
    SDFG yielded by ``sdfg.all_sdfgs_recursive()``, bypassing generic pattern
    matching by trying each state-machine edge as a fusion candidate.

    Each SDFG is swept until a full pass over its edges applies no fusion.
    Within one pass, a state that already took part in a fusion is not
    considered again, since the edge list was captured before the graph
    was modified.

    :param sdfg: The SDFG to transform.
    :return: The total number of states fused.
    """
    from dace.transformation.interstate import StateFusion  # Avoid import loop

    total_fused = 0
    for nested in sdfg.all_sdfgs_recursive():
        sdfg_id = nested.sdfg_id
        keep_sweeping = True
        while keep_sweeping:
            keep_sweeping = False
            # States already involved in a fusion during this pass
            touched = set()
            # Snapshot of the edges, since applying a fusion edits the graph
            for first, second in list(nested.nx.edges):
                if first in touched or second in touched:
                    continue

                candidate = {
                    StateFusion.first_state: first,
                    StateFusion.second_state: second
                }
                fusion = StateFusion(sdfg_id, -1, candidate, 0, override=True)
                if not fusion.can_be_applied(
                        nested, candidate, 0, nested, strict=True):
                    continue

                fusion.apply(nested)
                total_fused += 1
                keep_sweeping = True
                touched.add(first)
                touched.add(second)

    if config.Config.get_bool('debugprint'):
        print(f'Applied {total_fused} State Fusions')
    return total_fused
Пример #2
0
def test_fuse_assignment_in_use():
    """
    Two states with an interstate assignment in between, where the assigned
    value is used in the first state. Should fail.

    The SDFG is a chain of four states. The edge between the third and fourth
    state assigns ``k``, and both of those states contain a tasklet that reads
    ``k``. The test expects ``StateFusion.apply_to`` on that pair to raise
    ``ValueError`` and fails with ``AssertionError`` if it does not.
    """
    sdfg = dace.SDFG('state_fusion_test')
    sdfg.add_array('A', [2], dace.int32)
    state1, state2, state3, state4 = tuple(sdfg.add_state() for _ in range(4))
    sdfg.add_edge(state1, state2, dace.InterstateEdge(assignments=dict(k=1)))
    sdfg.add_edge(state2, state3, dace.InterstateEdge())
    sdfg.add_edge(state3, state4, dace.InterstateEdge(assignments=dict(k=2)))

    # First state of the fusion candidate: reads ``k`` and writes A[0]
    state3.add_edge(state3.add_tasklet('one', {}, {'a'}, 'a = k'), 'a',
                    state3.add_write('A'), None, dace.Memlet('A[0]'))

    # Second state of the fusion candidate: reads ``k`` and writes A[1].
    # The nodes are created in state4 itself so that they belong to the same
    # state as the edge connecting them.
    state4.add_edge(state4.add_tasklet('two', {}, {'a'}, 'a = k'), 'a',
                    state4.add_write('A'), None, dace.Memlet('A[1]'))

    try:
        StateFusion.apply_to(sdfg,
                             strict=True,
                             first_state=state3,
                             second_state=state4)
    except ValueError:
        print('Exception successfully caught')
    else:
        # Reaching this point means the transformation was applied
        raise AssertionError('States fused, test failed')
Пример #3
0
def fuse_states(sdfg: SDFG,
                strict: bool = True,
                progress: bool = False) -> int:
    """
    Repeatedly applies the StateFusion transformation to an SDFG and to every
    SDFG yielded by ``sdfg.all_sdfgs_recursive()``, bypassing generic pattern
    matching by trying each state-machine edge as a fusion candidate.

    Each SDFG is swept until a full pass over its edges applies no fusion.
    Within one pass, a state that already took part in a fusion is not
    considered again, since the edge list was captured before the graph
    was modified.

    :param sdfg: The SDFG to transform.
    :param strict: If True (default), operates in strict mode.
    :param progress: If True, prints out a progress bar of fusion (may be
                     inaccurate, requires ``tqdm``)
    :return: The total number of states fused.
    """
    from dace.transformation.interstate import StateFusion  # Avoid import loop

    total_fused = 0
    if progress:
        from tqdm import tqdm

        # The bar's total is the number of state-machine edges before fusion
        edge_total = 0
        for nested in sdfg.all_sdfgs_recursive():
            edge_total += nested.number_of_edges()
        pbar = tqdm(total=edge_total)

    for nested in sdfg.all_sdfgs_recursive():
        sdfg_id = nested.sdfg_id
        keep_sweeping = True
        while keep_sweeping:
            keep_sweeping = False
            # States already involved in a fusion during this pass
            touched = set()
            # Snapshot of the edges, since applying a fusion edits the graph
            for first, second in list(nested.nx.edges):
                if first in touched or second in touched:
                    continue

                candidate = {
                    StateFusion.first_state: first,
                    StateFusion.second_state: second
                }
                fusion = StateFusion(sdfg_id, -1, candidate, 0, override=True)
                if not fusion.can_be_applied(
                        nested, candidate, 0, nested, strict=strict):
                    continue

                fusion.apply(nested)
                total_fused += 1
                keep_sweeping = True
                if progress:
                    pbar.update(1)
                touched.add(first)
                touched.add(second)

    if progress:
        pbar.close()
    if config.Config.get_bool('debugprint'):
        print(f'Applied {total_fused} State Fusions')
    return total_fused
Пример #4
0
def fuse_states(sdfg: SDFG,
                permissive: bool = False,
                progress: bool = None) -> int:
    """
    Fuses all possible states of an SDFG (and all sub-SDFGs) using an optimized
    routine that uses the structure of the StateFusion transformation.

    Each SDFG is swept until a full pass over its state-machine edges applies
    no fusion. Within one pass, a state that already took part in a fusion is
    skipped, since the edge list was captured before the graph was modified.

    :param sdfg: The SDFG to transform.
    :param permissive: If True, operates in permissive mode, which ignores some
                       race condition checks.
    :param progress: If True, prints out a progress bar of fusion (may be
                     inaccurate, requires ``tqdm``). If None, prints out
                     progress if over 5 seconds have passed. If False, never
                     shows progress bar. If ``tqdm`` is not installed, no
                     progress bar is shown; a warning is issued when
                     ``progress`` was explicitly True.
    :return: The total number of states fused.
    """
    from dace.transformation.interstate import StateFusion  # Avoid import loop

    tqdm = None
    fusible_states = 0
    if progress is True or progress is None:
        try:
            from tqdm import tqdm
        except ImportError:
            # Progress display is optional: degrade instead of failing later
            # with a confusing "'NoneType' object is not callable" error.
            if progress is True:
                import warnings
                warnings.warn('Progress bar requested but tqdm is not '
                              'installed; continuing without progress bar')

        # Number of state-machine edges before fusion, used as the bar's total
        for sd in sdfg.all_sdfgs_recursive():
            fusible_states += sd.number_of_edges()

    # ``pbar`` stays None for as long as no progress bar is being displayed
    pbar = None
    if progress is True and tqdm is not None:
        pbar = tqdm(total=fusible_states, desc='Fusing states')

    counter = 0
    start = time.time()

    try:
        for sd in sdfg.all_sdfgs_recursive():
            sdfg_id = sd.sdfg_id
            while True:
                # Snapshot of the edges, since applying a fusion edits the graph
                edges = list(sd.nx.edges)
                applied = 0
                skip_nodes = set()
                for u, v in edges:
                    # Automatic mode: show the bar once fusion has been
                    # running for more than 5 seconds
                    if (pbar is None and progress is None
                            and tqdm is not None
                            and (time.time() - start) > 5):
                        pbar = tqdm(total=fusible_states,
                                    desc='Fusing states',
                                    initial=counter)

                    if u in skip_nodes or v in skip_nodes:
                        continue
                    candidate = {
                        StateFusion.first_state: u,
                        StateFusion.second_state: v
                    }
                    sf = StateFusion(sdfg_id, -1, candidate, 0, override=True)
                    if sf.can_be_applied(sd,
                                         candidate,
                                         0,
                                         sd,
                                         permissive=permissive):
                        sf.apply(sd)
                        applied += 1
                        counter += 1
                        if pbar is not None:
                            pbar.update(1)
                        skip_nodes.add(u)
                        skip_nodes.add(v)
                if applied == 0:
                    break
    finally:
        # Close the bar even if a transformation raised
        if pbar is not None:
            pbar.close()

    if config.Config.get_bool('debugprint') and counter > 0:
        print(f'Applied {counter} State Fusions')
    return counter