def test_no_collision(self):
    """Two modules whose cached functions share a name must not collide.

    Runs ``bar1``/``bar2`` locally (expecting fresh cache saves), then
    spawns a child process to re-run them (expecting cache loads only)
    and cross-checks both the results and the cache-log counters.
    """
    bar1 = self.import_bar1()
    bar2 = self.import_bar2()
    with capture_cache_log() as buf:
        res1 = bar1()
    cachelog = buf.getvalue()
    # bar1 should save new index and data
    self.assertEqual(cachelog.count('index saved'), 1)
    self.assertEqual(cachelog.count('data saved'), 1)
    self.assertEqual(cachelog.count('index loaded'), 0)
    self.assertEqual(cachelog.count('data loaded'), 0)
    with capture_cache_log() as buf:
        res2 = bar2()
    cachelog = buf.getvalue()
    # bar2 should save new index and data
    self.assertEqual(cachelog.count('index saved'), 1)
    self.assertEqual(cachelog.count('data saved'), 1)
    self.assertEqual(cachelog.count('index loaded'), 0)
    self.assertEqual(cachelog.count('data loaded'), 0)
    self.assertNotEqual(res1, res2)

    try:
        # Make sure we can spawn new process without inheriting
        # the parent context.
        mp = multiprocessing.get_context('spawn')
    except ValueError:
        # BUGFIX: previously this branch only print()-ed a message and
        # fell through to use the unbound name `mp`, raising NameError.
        # Skip the test cleanly when the spawn context is unavailable.
        self.skipTest("missing spawn context")

    q = mp.Queue()
    # Start new process that calls `cache_file_collision_tester`
    proc = mp.Process(target=cache_file_collision_tester,
                      args=(q, self.tempdir,
                            self.modname_bar1,
                            self.modname_bar2))
    proc.start()
    # Get results from the process
    log1 = q.get()
    got1 = q.get()
    log2 = q.get()
    got2 = q.get()
    proc.join()
    # The remote execution result of bar1() and bar2() should match
    # the one executed locally.
    self.assertEqual(got1, res1)
    self.assertEqual(got2, res2)
    # The remote should have loaded bar1 from cache
    self.assertEqual(log1.count('index saved'), 0)
    self.assertEqual(log1.count('data saved'), 0)
    self.assertEqual(log1.count('index loaded'), 1)
    self.assertEqual(log1.count('data loaded'), 1)
    # The remote should have loaded bar2 from cache
    self.assertEqual(log2.count('index saved'), 0)
    self.assertEqual(log2.count('data saved'), 0)
    self.assertEqual(log2.count('index loaded'), 1)
    self.assertEqual(log2.count('data loaded'), 1)
def cache_file_collision_tester(q, tempdir, modname_bar1, modname_bar2):
    """Child-process helper for the cache-collision test.

    Imports both modules from *tempdir*, runs each module's ``bar`` while
    capturing the cache log, and pushes (log text, result) onto *q* for
    each — four items total, in bar1-then-bar2 order.
    """
    sys.path.insert(0, tempdir)
    # Import both modules up front (same order as the parent expects).
    funcs = [import_dynamic(name).bar
             for name in (modname_bar1, modname_bar2)]
    for func in funcs:
        with capture_cache_log() as buf:
            result = func()
        q.put(buf.getvalue())
        q.put(result)
def check_dufunc_usecase(self, usecase_name):
    """Exercise a dufunc usecase and verify its cache save/load counts.

    Creating the dufunc must not touch the cache; the first call must
    save exactly one entry; recreating and calling again must load it.
    """
    module = self.import_module()
    make_dufunc = getattr(module, usecase_name)
    # Creating the dufunc alone writes nothing to the cache.
    with capture_cache_log() as log_buf:
        dufunc = make_dufunc()
    self.check_cache_saved(log_buf.getvalue(), count=0)
    # First call compiles and saves a single cache entry.
    with capture_cache_log() as log_buf:
        dufunc(np.arange(10))
    contents = log_buf.getvalue()
    self.check_cache_saved(contents, count=1)
    self.check_cache_loaded(contents, count=0)
    # A freshly created dufunc should hit the cache on its first call.
    with capture_cache_log() as log_buf:
        dufunc = make_dufunc()
        dufunc(np.arange(10))
    self.check_cache_loaded(log_buf.getvalue(), count=1)
def check_ufunc_cache(self, usecase_name, n_overloads, **kwargs):
    """
    Check number of cache load/save.
    There should be one per overloaded version.

    Returns the (freshly compiled, cache-loaded) ufunc pair so callers
    can compare their behavior.
    """
    module = self.import_module()
    build_ufunc = getattr(module, usecase_name)
    # First build: every overload saves a new cache entry.
    with capture_cache_log() as log_buf:
        new_ufunc = build_ufunc(**kwargs)
    self.check_cache_saved(log_buf.getvalue(), count=n_overloads)
    # Second build: every overload is served from the cache.
    with capture_cache_log() as log_buf:
        cached_ufunc = build_ufunc(**kwargs)
    self.check_cache_loaded(log_buf.getvalue(), count=n_overloads)
    return new_ufunc, cached_ufunc
def test_filename_prefix(self):
    """Gufunc wrapper cache files carry a "guf-" filename prefix, and
    there should be one prefixed entry per plain overload entry."""
    module = self.import_module()
    run_usecase = getattr(module, "direct_gufunc_cache_usecase")
    with capture_cache_log() as log_buf:
        run_usecase()
    log_text = log_buf.getvalue()
    # Count cache filenames carrying the "guf-" wrapper prefix...
    wrapper_pat = _fix_raw_path(r'/__pycache__/guf-{}')
    prefixed_hits = re.findall(wrapper_pat.format(self.modname), log_text)
    # ...and plain (overload) cache filenames.
    plain_pat = _fix_raw_path(r'/__pycache__/{}')
    plain_hits = re.findall(plain_pat.format(self.modname), log_text)
    # expecting 2 overloads
    self.assertGreater(len(plain_hits), 2)
    # expecting equal number of wrappers and overloads cache entries
    self.assertEqual(len(plain_hits), len(prefixed_hits))