Example #1
0
    def test_dynamic_nthreads(self):
        """
        Check that an OpenMP Operator exposes `nthreads` and `nthreads_nonaffine`
        as runtime arguments: they appear in the `num_threads` clauses, are
        accepted by `apply`, default to the symbol classes' `default_value()`,
        and can be overridden either by name or via the `nthreads0`/`nthreads2`
        aliases.
        """
        grid = Grid(shape=(16, 16, 16))
        field = TimeFunction(name='f', grid=grid)
        sparse = SparseTimeFunction(name='sf', grid=grid, npoint=1, nt=5)

        # One affine update plus the (non-affine) sparse interpolation
        eqns = [Eq(field.forward, field + 1), *sparse.interpolate(field)]

        op = Operator(eqns, opt='openmp')

        regions = FindNodes(ParallelRegion).visit(op)
        assert len(regions) == 2

        # Each parallel region's pragma must name its own thread-count symbol.
        # Matching on the header string is crude, but good enough for a test.
        clauses = ('num_threads(nthreads)', 'num_threads(nthreads_nonaffine)')
        for region, clause in zip(regions, clauses):
            assert clause in str(region.header[0])

        # `apply` must accept every combination of the `nthreads*` kwargs
        op.apply(time=0)
        for extra in ({'nthreads': 4},
                      {'nthreads': 4, 'nthreads_nonaffine': 2},
                      {'nthreads_nonaffine': 2}):
            op.apply(time_m=1, time_M=1, **extra)
        assert np.all(field.data[0] == 2.)

        # Without user input, the values come from the symbol classes
        assert (op.arguments(time=0)['nthreads']
                == NThreads.default_value())
        assert (op.arguments(time=0)['nthreads_nonaffine']
                == NThreadsNonaffine.default_value())

        # User-supplied values take precedence over the defaults...
        for name, value in (('nthreads', 123), ('nthreads_nonaffine', 100)):
            assert op.arguments(time=0, **{name: value})[name] == value

        # ...and the same holds when they are passed through the aliases
        for alias, name in (('nthreads0', 'nthreads'),
                            ('nthreads2', 'nthreads_nonaffine')):
            assert op.arguments(time=0, **{alias: 123})[name] == 123
Example #2
0
    def _symbol_registry(cls):
        """
        Build the SymbolRegistry holding the special symbols an Operator may use.

        The registry holds three thread-count symbols, each also reachable
        through a positional alias ('nthreads0', 'nthreads1', 'nthreads2'). It
        also holds a thread-id dimension, 'tid', whose symbolic size is the
        main thread-count symbol.
        """
        # `tid` is sized by this symbol, so it must be bound to a name first
        nthreads = NThreads(aliases='nthreads0')

        # Construction order matches the registry's positional argument order
        special = (
            nthreads,
            NThreadsNested(aliases='nthreads1'),
            NThreadsNonaffine(aliases='nthreads2'),
            CustomDimension(name='tid', symbolic_size=nthreads),
        )
        return SymbolRegistry(*special)