def test_dynamic_nthreads(self):
    """
    Build an OpenMP Operator with two parallel regions: a plain time update,
    and the interpolation of a SparseTimeFunction. Check that:

    * the first region's pragma uses `num_threads(nthreads)` and the second
      uses `num_threads(nthreads_nonaffine)`;
    * `apply` accepts every combination of the `nthreads*` overrides;
    * `arguments` resolves both symbols to their class default values, to
      user-supplied values, and to values passed via the aliases
      `nthreads0` / `nthreads2`.
    """
    grid = Grid(shape=(16, 16, 16))
    u = TimeFunction(name='f', grid=grid)
    src = SparseTimeFunction(name='sf', grid=grid, npoint=1, nt=5)

    equations = [Eq(u.forward, u + 1)]
    equations.extend(src.interpolate(u))
    op = Operator(equations, opt='openmp')

    regions = FindNodes(ParallelRegion).visit(op)
    assert len(regions) == 2

    # Each parallel region must be sized by its own thread-count symbol.
    # Matching on the stringified pragma is crude, but good enough here
    expected_clauses = ('num_threads(nthreads)',
                        'num_threads(nthreads_nonaffine)')
    for region, clause in zip(regions, expected_clauses):
        assert clause in str(region.header[0])

    # `apply` must accept the `nthreads*` overrides, alone or combined
    op.apply(time=0)
    for overrides in ({'nthreads': 4},
                      {'nthreads': 4, 'nthreads_nonaffine': 2},
                      {'nthreads_nonaffine': 2}):
        op.apply(time_m=1, time_M=1, **overrides)
    assert np.all(u.data[0] == 2.)

    def resolved(**overrides):
        # Runtime arguments `op` would use for a single timestep (`time=0`)
        return op.arguments(time=0, **overrides)

    # No overrides -> the class-level defaults
    assert resolved()['nthreads'] == NThreads.default_value()
    assert resolved()['nthreads_nonaffine'] == \
        NThreadsNonaffine.default_value()

    # Explicit overrides by full name
    assert resolved(nthreads=123)['nthreads'] == 123
    assert resolved(nthreads_nonaffine=100)['nthreads_nonaffine'] == 100

    # Explicit overrides through the aliases
    assert resolved(nthreads0=123)['nthreads'] == 123
    assert resolved(nthreads2=123)['nthreads_nonaffine'] == 123
def _symbol_registry(cls):
    """
    Create the SymbolRegistry of special symbols an Operator might use.

    The registry is built, in this order, from:

    * an ``NThreads`` symbol with alias ``nthreads0``;
    * an ``NThreadsNested`` symbol with alias ``nthreads1``;
    * an ``NThreadsNonaffine`` symbol with alias ``nthreads2``;
    * a ``CustomDimension`` named ``tid`` whose symbolic size is the
      ``NThreads`` symbol above.
    """
    # Construction order is kept identical to the registry argument order
    default_threads = NThreads(aliases='nthreads0')
    nested_threads = NThreadsNested(aliases='nthreads1')
    nonaffine_threads = NThreadsNonaffine(aliases='nthreads2')

    # The thread-id dimension ranges over the default thread count
    tid = CustomDimension(name='tid', symbolic_size=default_threads)

    return SymbolRegistry(
        default_threads,
        nested_threads,
        nonaffine_threads,
        tid,
    )