def _symbol_registry(cls):
    """
    Return a new SymbolRegistry holding the special symbols an Operator
    might use.

    The registry receives, positionally: the `NThreads` symbol (alias
    'nthreads0'), the `NThreadsNested` symbol (alias 'nthreads1'), the
    `NThreadsNonaffine` symbol (alias 'nthreads2') and a `ThreadID` built
    on top of the `NThreads` symbol.
    """
    top = NThreads(aliases='nthreads0')
    # Tuple elements are evaluated left to right, so the symbols are
    # constructed in the same order as they are passed to the registry
    specials = (
        top,
        NThreadsNested(aliases='nthreads1'),
        NThreadsNonaffine(aliases='nthreads2'),
        ThreadID(top),
    )
    return SymbolRegistry(*specials)
def __init__(self):
    """
    Set up an empty registry.

    Attributes set here:
      * `counters`: dict, initially empty; per the original comment it maps
        name -> generator() and is used to create unique names for symbols,
        functions, ...
      * `nthreads`, `nthreads_nested`, `nthreads_nonaffine`: the special
        thread-count symbols (aliases 'nthreads0', 'nthreads1', 'nthreads2')
      * `threadid`: a ThreadID built from `nthreads`
      * `npthreads`: list, initially empty; several groups of pthreads, each
        of size `npthread`, may be created during compilation
    """
    # {name -> generator()} -- unique-name generators
    self.counters = {}

    # Special symbols. The outer NThreads is kept in a local so that the
    # ThreadID can be tied to it directly.
    outer = NThreads(aliases='nthreads0')
    self.nthreads = outer
    self.nthreads_nested = NThreadsNested(aliases='nthreads1')
    self.nthreads_nonaffine = NThreadsNonaffine(aliases='nthreads2')
    self.threadid = ThreadID(outer)

    # Groups of pthreads created during compilation (presumably appended
    # to by later passes -- verify against callers)
    self.npthreads = []
def test_threadid():
    """
    A ThreadID must survive a pickle round-trip with its own name and the
    names of its `nthreads`, `symbolic_min` and `symbolic_max` preserved.
    """
    grid = Grid(shape=(4, 4, 4))
    f = TimeFunction(name='f', grid=grid)
    op = Operator(Eq(f.forward, f + 1.), openmp=True)

    original = ThreadID(op.nthreads)
    restored = pickle.loads(pickle.dumps(original))

    # Compare the symbol itself, then each of the symbols it references
    pairs = [
        (original, restored),
        (original.nthreads, restored.nthreads),
        (original.symbolic_min, restored.symbolic_min),
        (original.symbolic_max, restored.symbolic_max),
    ]
    for before, after in pairs:
        assert before.name == after.name
def test_special_symbols(self):
    """
    This test checks the singletonization, through the caching
    infrastructure, of the special symbols that an Operator may generate
    (e.g., `nthreads`).

    Two Operators are built from identical equations and options. The
    thread-count symbols, ThreadID, DeviceID and NPThreads obtained from
    them (or constructed with identical arguments) must be the very same
    objects. NPThreads built with a different `size` must be distinct.
    """
    grid = Grid(shape=(4, 4, 4))
    f = TimeFunction(name='f', grid=grid)
    sf = SparseTimeFunction(name='sf', grid=grid, npoint=1, nt=10)

    eqns = [Eq(f.forward, f + 1.)] + sf.inject(field=f.forward, expr=sf)

    opt = ('advanced', {'par-nested': 0, 'openmp': True})
    op0 = Operator(eqns, opt=opt)
    op1 = Operator(eqns, opt=opt)

    # Unpacking assumes exactly three NThreadsBase inputs per Operator,
    # in a stable order
    nthreads0, nthreads_nested0, nthreads_nonaffine0 =\
        [i for i in op0.input if isinstance(i, NThreadsBase)]
    nthreads1, nthreads_nested1, nthreads_nonaffine1 =\
        [i for i in op1.input if isinstance(i, NThreadsBase)]

    assert nthreads0 is nthreads1
    assert nthreads_nested0 is nthreads_nested1
    assert nthreads_nonaffine0 is nthreads_nonaffine1

    # Build each ThreadID from a different Operator, so the identity check
    # exercises singletonization across Operators rather than just a
    # repeated call with the same argument
    tid0 = ThreadID(op0.nthreads)
    tid1 = ThreadID(op1.nthreads)
    assert tid0 is tid1

    did0 = DeviceID()
    did1 = DeviceID()
    assert did0 is did1

    # Same (name, size) -> same object; a different size -> a new object
    npt0 = NPThreads(name='npt', size=3)
    npt1 = NPThreads(name='npt', size=3)
    npt2 = NPThreads(name='npt', size=4)
    assert npt0 is npt1
    assert npt0 is not npt2