def test_preservation_interoperability():
    """Preserve with ``FileCache``, then preload with ``MultiprocessFileCache``.

    Ten string values are stored in a ``FileCache`` and preserved under the
    name ``'preserved'``. After that cache is closed, a
    ``MultiprocessFileCache`` created on the same directory preloads the
    preserved data and must return the same values.
    """
    n_items = 10

    with tempfile.TemporaryDirectory() as d:
        cache = FileCache(n_items, dir=d, do_pickle=True)
        try:
            for i in range(n_items):
                cache.put(i, str(i))

            assert cache.preserve('preserved') is True

            # The live cache must still serve every entry after preserving.
            for i in range(n_items):
                assert str(i) == cache.get(i)
        finally:
            cache.close()

        cache2 = MultiprocessFileCache(n_items, dir=d, do_pickle=True)
        try:
            assert cache2.preload('preserved') is True
            for i in range(n_items):
                assert str(i) == cache2.get(i)
        finally:
            # Close the second cache as well, before the temporary directory
            # is cleaned up; previously it was left open.
            cache2.close()
def test_preservation():
    """Preserve a ``MultiprocessFileCache`` and preload it into a fresh one.

    Ten string values are stored and preserved under the name
    ``'preserved'``. A second cache on the same directory, imitating a new
    process, preloads them and must return the same values. Once both
    caches are closed, only the two preserved files may remain in the
    directory.
    """
    n_items = 10

    with tempfile.TemporaryDirectory() as d:
        cache = MultiprocessFileCache(n_items, dir=d, do_pickle=True)
        try:
            for i in range(n_items):
                cache.put(i, str(i))

            cache.preserve('preserved')
        finally:
            cache.close()

        # Imitating a new process, fresh load
        cache2 = MultiprocessFileCache(n_items, dir=d, do_pickle=True)
        try:
            cache2.preload('preserved')
            for i in range(n_items):
                assert str(i) == cache2.get(i)
        finally:
            cache2.close()

        # No temporary cache file should remain,
        # and the preserved cache should be kept.
        # os.listdir() returns entries in arbitrary order, so sort the
        # listing before comparing it with the expected file names.
        assert sorted(os.listdir(d)) == ['preserved.cached',
                                         'preserved.cachei']
def test_multiprocess_consistency():
    """Check that parallel writers leave a shared cache consistent.

    Condition: 32k samples (8k * 4 bytes each) cached by 32 workers.
    Each sample is an array of its own repeated sample index, i.e. the
    k-th sample is ``np.array([k, k, k, ..., k], dtype=np.int32)``.
    32 worker processes simultaneously create such data and insert them
    into a single cache, and we check that the data can be correctly
    recovered.
    """
    n_workers = 32
    n_samples_per_worker = 1024
    sample_size = 8192
    n_samples = n_workers * n_samples_per_worker

    def make_sample(sample_idx):
        # Same array as np.array([sample_idx] * sample_size, dtype=np.int32)
        # without building an intermediate Python list for every sample.
        return np.full(sample_size, sample_idx, dtype=np.int32)

    # NOTE(review): ``child`` is a nested function, which cannot be pickled,
    # so this test presumably relies on the ``fork`` start method -- confirm
    # on platforms whose default is ``spawn`` or ``forkserver``.
    def child(cache, worker_idx):
        # Each worker owns one contiguous range of sample indices.
        first = worker_idx * n_samples_per_worker
        for sample_idx in range(first, first + n_samples_per_worker):
            cache.put(sample_idx, make_sample(sample_idx))

    with tempfile.TemporaryDirectory() as d:
        cache = MultiprocessFileCache(n_samples, dir=d, do_pickle=True)
        try:
            # Add tons of data into the cache in parallel
            ps = [
                multiprocessing.Process(target=child,
                                        args=(cache, worker_idx))
                for worker_idx in range(n_workers)
            ]
            for p in ps:
                p.start()
            for p in ps:
                p.join()

            # Fail with a clear message if a worker died, rather than with
            # a data mismatch further down.
            for worker_idx, p in enumerate(ps):
                assert p.exitcode == 0, \
                    'worker {} exited with code {}'.format(worker_idx,
                                                           p.exitcode)

            # Get each sample from the cache and check the content
            for sample_idx in range(n_samples):
                data = cache.get(sample_idx)
                assert np.array_equal(data, make_sample(sample_idx)), \
                    'sample {} is missing or corrupted'.format(sample_idx)
        finally:
            # Close the cache before the temporary directory is removed.
            cache.close()