示例#1
0
    def get_pattern_matches(self,
                            strict=False,
                            states=None,
                            patterns=None,
                            sdfg=None):
        """ Returns all possible transformations for the current SDFG.

            This is a generator: matches are produced lazily. Multi-state
            (state-flow) patterns are only matched when ``states`` is None;
            single-state patterns are then matched in each selected state.

            :param strict: Only consider strict transformations (i.e., ones
                           that surely increase performance or enhance
                           readability)
            :param states: An iterable of SDFG states to consider when pattern
                           matching. If None, considers all.
            :param patterns: An iterable of transformation classes to consider
                             when matching. If None, considers all registered
                             transformations in `Transformation`. One-shot
                             iterators (e.g., generators) are accepted.
            :param sdfg: If not None, searches for patterns on given SDFG.
            :return: Generator of matching `Transformation` objects.
            @see: Transformation
        """
        # Test against None explicitly instead of ``sdfg or self.sdfg``: a
        # graph object may define ``__len__``, in which case an SDFG without
        # states would be falsy and silently replaced by ``self.sdfg``.
        if sdfg is None:
            sdfg = self.sdfg

        # ``patterns`` is filtered twice below (against the state-flow
        # patterns and against the single-state patterns). Materialize it
        # once so a one-shot iterator is not exhausted by the first filter,
        # which would leave no single-state patterns to match.
        if patterns is not None:
            patterns = list(patterns)

        # State-flow patterns are matched on the SDFG itself
        # (match_stateflow_pattern takes no state), so they are only searched
        # when no explicit subset of states was requested.
        if states is None:
            if patterns is None:
                stateflow_patterns = self.stateflow_patterns
            else:
                stateflow_patterns = [
                    p for p in patterns if p in self.stateflow_patterns
                ]

            for pattern in stateflow_patterns:
                yield from pattern_matching.match_stateflow_pattern(
                    sdfg, pattern, strict=strict)

        # Pair each selected state with its index in the SDFG's node list;
        # the single-state matcher takes that index as its first argument.
        # Fetch the node list once rather than once per requested state.
        all_states = sdfg.nodes()
        if states is None:
            state_enum = list(enumerate(all_states))
        else:
            state_enum = [(all_states.index(state), state) for state in states]

        if patterns is None:
            state_patterns = self.patterns
        else:
            state_patterns = [p for p in patterns if p in self.patterns]

        for state_id, state in state_enum:
            for pattern in state_patterns:
                yield from pattern_matching.match_pattern(state_id,
                                                          state,
                                                          pattern,
                                                          sdfg,
                                                          strict=strict)
示例#2
0
import dace
from dace.transformation.pattern_matching import match_pattern
from dace.transformation.dataflow import MapTiling
import numpy as np


# DaCe program under test: multiplies the 200-element float64 array `a` by
# 2.0 in place (the __main__ block compares the argument afterwards against
# ``2 * A``). Kept deliberately minimal, because the __main__ block applies
# MapTiling to the SDFG generated from this body and then re-tiles a map.
# Its exact form therefore matters; presumably it lowers to a single
# elementwise map — confirm when changing it.
@dace.program
def tile_twice_test(a: dace.float64[200]):
    # In-place update; the caller relies on `a` being modified directly.
    a *= 2.0


if __name__ == '__main__':
    import sys

    sdfg = tile_twice_test.to_sdfg()
    sdfg.apply_strict_transformations()
    # First tiling pass: tile size 5.
    sdfg.apply_transformations(MapTiling, options={'tile_sizes': (5, )})

    # Second tiling pass: match the first map again and tile it by 4.
    # Only the first match is needed. Stop right after applying it instead
    # of continuing to iterate match_pattern, which may be consumed lazily
    # and would otherwise keep matching on the SDFG that apply_pattern just
    # modified.
    match = next(iter(match_pattern(sdfg.nodes()[0], MapTiling, sdfg)), None)
    if match is None:
        # Without this check the numeric test below would still pass even
        # though the second tiling was never applied.
        raise ValueError('MapTiling transformation not matched for '
                         'second tiling')
    match.tile_sizes = (4, )
    match.apply_pattern(sdfg)

    A = np.random.rand(200)
    expected = 2 * A

    # The program doubles `a` in place, so A itself holds the result.
    sdfg(a=A)

    diff = np.linalg.norm(A - expected)
    print('Difference:', diff)
    # sys.exit rather than the site-injected `exit`, which is not
    # guaranteed to exist (e.g. under `python -S`).
    sys.exit(1 if diff > 1e-8 else 0)
示例#3
0
    A = np.random.rand(256, 256).astype(np.float32)
    B = np.random.rand(256, 256).astype(np.float32)
    expected_C = A @ B
    C = np.zeros((256, 256), dtype=np.float32)

    sdfg = mm_double_buffered.to_sdfg()
    sdfg(A=A, B=B, C=C)

    diff = np.linalg.norm(expected_C - C) / (256 * 256)
    print('Difference (before):', diff)

    # Apply local storage transformation on inner map (last two transformations)
    sdfg.apply_strict_transformations()
    for i in range(2):
        for match in reversed(
                list(match_pattern(sdfg.node(0), InLocalStorage, sdfg))):
            match.apply(sdfg)
            break
        else:
            raise ValueError('Local storage transformation not applied')

    applied = sdfg.apply_transformations(DoubleBuffering)
    if applied != 1:
        raise ValueError('Double-buffering transformation not applied')
    C = np.zeros((256, 256), dtype=np.float32)
    sdfg(A=A, B=B, C=C)

    diff2 = np.linalg.norm(expected_C - C) / (256 * 256)
    print('Difference (after):', diff2)

    exit(1 if (diff > 1e-5 or diff2 > 1e-5) else 0)