Example #1
import tempfile

# Assumption: FileCache and MultiprocessFileCache come from pfio's cache module,
# which this test code appears to target.
from pfio.cache import FileCache, MultiprocessFileCache


def test_preservation_interoperability():
    with tempfile.TemporaryDirectory() as d:
        cache = FileCache(10, dir=d, do_pickle=True)

        for i in range(10):
            cache.put(i, str(i))

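        # Preserve the cache contents under the name 'preserved' so that they
        # survive close() and can be reloaded later.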
        assert cache.preserve('preserved') is True

        for i in range(10):
            assert str(i) == cache.get(i)

        cache.close()

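        # A fresh MultiprocessFileCache pointed at the same directory should be
        # able to preload the data preserved by the FileCache above.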
        cache2 = MultiprocessFileCache(10, dir=d, do_pickle=True)

        assert cache2.preload('preserved') is True
        for i in range(10):
            assert str(i) == cache2.get(i)
Example #2
import multiprocessing
import os
import tempfile

import numpy as np

# Assumption: MultiprocessFileCache comes from pfio's cache module.
from pfio.cache import MultiprocessFileCache


def test_preservation():
    with tempfile.TemporaryDirectory() as d:
        cache = MultiprocessFileCache(10, dir=d, do_pickle=True)

        for i in range(10):
            cache.put(i, str(i))

        cache.preserve('preserved')

        cache.close()

        # Imitating a new process, fresh load
        cache2 = MultiprocessFileCache(10, dir=d, do_pickle=True)

        cache2.preload('preserved')
        for i in range(10):
            assert str(i) == cache2.get(i)

        cache2.close()

        # No temporary cache file should remain,
        # and the preserved cache should be kept.
        assert sorted(os.listdir(d)) == ['preserved.cached', 'preserved.cachei']


def test_multiprocess_consistency():
    # Condition: 32k samples (8,192 int32 elements, i.e. 32 KB each) cached by
    # 32 workers. Each sample is an array filled with its own sample index,
    # i.e. the k-th sample is np.array([k, k, k, ..., k], dtype=np.int32).
    # The 32 worker processes create this data simultaneously and insert it
    # into a single cache; we then check that every sample can be recovered.
    n_workers = 32
    n_samples_per_worker = 1024
    sample_size = 8192

    def child(cache, worker_idx):
        for i in range(n_samples_per_worker):
            sample_idx = worker_idx * n_samples_per_worker + i
            data = np.array([sample_idx] * sample_size, dtype=np.int32)
            cache.put(sample_idx, data)

    with tempfile.TemporaryDirectory() as d:
        cache = MultiprocessFileCache(n_samples_per_worker * n_workers,
                                      dir=d,
                                      do_pickle=True)

        # Insert all samples into the single shared cache from the worker
        # processes in parallel.
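        # Note: child is defined locally, so this relies on the 'fork' start
        # method (the Linux default); with 'spawn', a local function cannot be
        # pickled as a Process target.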
        ps = [
            multiprocessing.Process(target=child, args=(cache, worker_idx))
            for worker_idx in range(n_workers)
        ]
        for p in ps:
            p.start()
        for p in ps:
            p.join()

        # Get each sample from the cache and check the content
        for sample_idx in range(n_workers * n_samples_per_worker):
            data = cache.get(sample_idx)
            expected = np.array([sample_idx] * sample_size, dtype=np.int32)
            assert (data == expected).all()