def test_diskcache_reuse():
    with tempfile.TemporaryDirectory() as cache_dir:
        # Create dataset and write to cache
        dataset = lazy_dataset.new(list(range(10))).diskcache(
            cache_dir=cache_dir, reuse=True, clear=False)
        list(dataset)
        del dataset

        # Assert that the cache didn't get deleted
        assert Path(cache_dir).exists()

        # Create a new dataset with the same cache dir and assert that the
        # data gets loaded from the cache and not from the data pipeline
        call_counter = 0

        def _count_calls(x):
            nonlocal call_counter
            call_counter += 1
            return x

        dataset = (
            lazy_dataset.new(list(range(10)))
            .map(_count_calls)
            .diskcache(cache_dir=cache_dir, reuse=True, clear=False)
        )
        list(dataset)
        assert call_counter == 0
def test_key_zip():
    ds1 = lazy_dataset.new({'1': 1, '2': 2, '3': 3, '4': 4})
    ds2 = lazy_dataset.new({'1': 5, '2': 6, '3': 7, '4': 8})
    np.random.seed(2)
    ds2_shuffled = ds2.shuffle()

    ds = lazy_dataset.key_zip(ds1, ds2)
    assert list(ds) == [(1, 5), (2, 6), (3, 7), (4, 8)]

    # key_zip aligns examples by key, so shuffling one input must not
    # change the pairing
    ds = lazy_dataset.key_zip(ds1, ds2_shuffled)
    assert list(ds) == [(1, 5), (2, 6), (3, 7), (4, 8)]

    ds = lazy_dataset.key_zip(ds1, ds2).prefetch(2, 2)
    assert list(ds) == [(1, 5), (2, 6), (3, 7), (4, 8)]
def test_diskcache_clear():
    dataset = lazy_dataset.new(list(range(10)))
    cache_dir = tempfile.mkdtemp()
    dataset = dataset.diskcache(cache_dir=cache_dir, clear=True)
    list(dataset)
    assert Path(cache_dir).is_dir()

    # With clear=True, the cache directory is removed when the dataset
    # gets deleted
    del dataset
    assert not Path(cache_dir).exists()
def test_cache_immutable():
    dataset = lazy_dataset.new({
        'a': {'value': 1},
        'b': {'value': 2},
        'c': {'value': 3},
    })
    cached_dataset = dataset.cache()
    assert cached_dataset['a']['value'] == 1

    # Mutating a returned example must not affect the cached copy
    cached_dataset['a']['value'] = 42
    assert cached_dataset['a']['value'] == 1
def test_diskcache_call_only_once():
    call_counter = Counter()
    dataset = lazy_dataset.new(dict(zip(map(str, range(10)), range(10))))

    def m(x):
        call_counter[x] += 1
        return x

    dataset = dataset.map(m).diskcache()
    for _ in dataset:
        pass
    assert all(v == 1 for v in call_counter.values())

    # A second pass is served from the cache, so the map function is not
    # called again
    for _ in dataset:
        pass
    assert all(v == 1 for v in call_counter.values())
def test_cache_call_only_once():
    call_counter = Counter()
    dataset = lazy_dataset.new(dict(zip(map(str, range(10)), range(10))))

    def m(x):
        call_counter[x] += 1
        return x

    # Set keep_mem_free to a small value to allow testing on machines with
    # less RAM
    dataset = dataset.map(m).cache(keep_mem_free='1GB')
    for _ in dataset:
        pass
    assert all(v == 1 for v in call_counter.values())
    for _ in dataset:
        pass
    assert all(v == 1 for v in call_counter.values())
def test_bucket():
    examples = [1, 10, 5, 7, 8, 2, 4, 3, 20, 1, 6, 9]
    examples = {str(j): i for j, i in enumerate(examples)}
    ds = lazy_dataset.new(examples)

    dynamic_batched_buckets = list(ds.batch_dynamic_time_series_bucket(
        batch_size=1, len_key=lambda x: x, max_padding_rate=0.5))
    assert dynamic_batched_buckets == [
        [1], [10], [5], [7], [8], [2], [4], [3], [20], [1], [6], [9]]

    dynamic_batched_buckets = list(ds.batch_dynamic_time_series_bucket(
        batch_size=2, len_key=lambda x: x, max_padding_rate=0.5))
    assert dynamic_batched_buckets == [
        [10, 5], [7, 8], [1, 2], [4, 3], [6, 9], [20], [1]]
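# The memory test below uses a `gb` helper that is defined elsewhere in
# this module. A minimal sketch, assuming it converts gibibytes to bytes:
def gb(n):
    return n * 1024 ** 3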
def test_cache_mem_percent():
    dataset = lazy_dataset.new(dict(zip(map(str, range(100)), range(100))))
    available_mem = gb(8)

    def virtual_memory():
        # Fake psutil.virtual_memory so the test controls how much memory
        # appears to be available
        return psutil._pslinux.svmem(
            total=gb(16), available=available_mem, percent=0, used=0,
            free=0, active=0, inactive=0, buffers=0, cached=0, shared=0,
            slab=0,
        )

    with mock.patch('psutil.virtual_memory', new=virtual_memory):
        ds = dataset.cache(keep_mem_free='50%')

        # Enough free memory: the first example gets cached
        available_mem = gb(9)
        it = iter(ds)
        assert len(ds._cache) == 0
        next(it)
        assert len(ds._cache) == 1

        # Less than 50% free: caching stops and a warning is emitted
        available_mem = gb(7)
        with pytest.warns(ResourceWarning, match='Max capacity'):
            next(it)
        assert len(ds._cache) == 1
def get_dataset():
    examples = get_examples()
    return lazy_dataset.new(examples)
def get_dataset_predict():
    examples = get_examples_predict()
    return lazy_dataset.new(examples)
def test_unbatch():
    examples = OrderedDict(a=[0, 1, 2], b=[3, 4], c=[5, 6, 7])
    ds = lazy_dataset.new(examples)
    ds = ds.unbatch()
    assert list(ds) == list(range(8))
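# test_prefetch below relies on a `fragment_fn` defined elsewhere in this
# module. A minimal sketch consistent with how it is used here (splitting
# '0_1_2' into the ints [0, 1, 2]):
def fragment_fn(example):
    return [int(token) for token in example.split('_')]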
def test_prefetch():
    examples = OrderedDict(a='0_1_2', b='3_4', c='5_6_7')
    ds = lazy_dataset.new(examples)
    ds = ds.map(fragment_fn).prefetch(2, 2).unbatch()
    assert list(ds) == list(range(8))
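# test() below compares str()/repr() output with `assert_doctest_equal`,
# which is defined elsewhere in this module. A minimal sketch, assuming
# doctest-style comparison where '...' matches arbitrary text (such as
# object addresses or temp directory names):
import doctest

def assert_doctest_equal(actual, expected):
    checker = doctest.OutputChecker()
    assert checker.check_output(expected, actual, doctest.ELLIPSIS), (
        f'{actual!r} does not match {expected!r}'
    )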
def test():
    def check(ds, expected_str, expected_repr=None):
        assert_doctest_equal(str(ds), expected_str)
        if expected_repr is None:
            # Smoke test only: repr must not raise
            repr(ds)
        else:
            assert_doctest_equal(repr(ds), expected_repr)

    ds_dict = lazy_dataset.new(
        {'a': 1, 'b': 2, 'c': 3, 'd': 4}, immutable_warranty='copy')
    check(
        ds_dict,
        'MapDataset(copy.deepcopy)',
        ' DictDataset(len=4)\n'
        'MapDataset(copy.deepcopy)',
    )

    ds_dict = lazy_dataset.new({'a': 1, 'b': 2, 'c': 3, 'd': 4})
    check(
        ds_dict,
        'MapDataset(_pickle.loads)',
        ' DictDataset(len=4)\n'
        'MapDataset(_pickle.loads)',
    )

    ds_list = lazy_dataset.new([1, 2, 3, 4])
    check(
        ds_list,
        'MapDataset(_pickle.loads)',
        ' ListDataset(len=4)\n'
        'MapDataset(_pickle.loads)',
    )

    ds = ds_dict.map(lambda ex: ex)
    check(ds, 'MapDataset(<function test.<locals>.<lambda> at 0x...>)')

    ds = ds_dict.filter(lambda ex: True)
    check(ds, 'FilterDataset(<function test.<locals>.<lambda> at 0x...>)')

    ds = ds_dict.cache(keep_mem_free='5 GiB')
    check(
        ds,
        'CacheDataset(keep_free=5 GiB)',
        ' DictDataset(len=4)\n'
        ' MapDataset(_pickle.loads)\n'
        'CacheDataset(keep_free=5 GiB)',
    )

    ds = ds_dict.cache(keep_mem_free='5 GiB').copy()
    check(
        ds,
        'CacheDataset(keep_free=5 GiB)',
        ' DictDataset(len=4)\n'
        ' MapDataset(_pickle.loads)\n'
        'CacheDataset(keep_free=5 GiB)',
    )

    ds = ds_dict.filter(lambda ex: True).cache(lazy=False)
    check(
        ds,
        'MapDataset(_pickle.loads)',
        ' ListDataset(len=4)\n'
        'MapDataset(_pickle.loads)',
    )

    ds = ds_dict.diskcache()
    check(ds, 'DiskCacheDataset(cache_dir=/tmp/diskcache-..., reuse=False)')

    ds = ds_dict.diskcache().copy()
    check(ds, 'DiskCacheDataset(cache_dir=/tmp/diskcache-..., reuse=False)')

    ds = ds_dict.random_choice(2)
    check(ds, 'SliceDataset([...])')

    ds = ds_dict.tile(2)
    check(
        ds,
        'ConcatenateDataset()',
        ' DictDataset(len=4)\n'
        ' MapDataset(_pickle.loads)\n'
        ' DictDataset(len=4)\n'
        ' MapDataset(_pickle.loads)\n'
        'ConcatenateDataset()',
    )

    ds = ds_dict.batch(2)
    check(ds, 'BatchDataset(batch_size=2)')

    ds = ds.unbatch()
    check(ds, 'UnbatchDataset()')

    ds = ds_dict.catch()
    check(ds, 'CatchExceptionDataset()')

    ds = ds_dict.shuffle(reshuffle=True)
    check(ds, 'ReShuffleDataset()')

    ds = ds_dict.shuffle(reshuffle=False)
    check(ds, 'SliceDataset([...])')

    ds = ds_dict.sort()  # sort by keys
    check(ds, "SliceDataset(['a', 'b', 'c', 'd'])")

    ds = ds_list.sort(key_fn=lambda ex: ex)
    check(ds, 'SliceDataset([0, 1, 2, 3])')

    ds = ds_dict.zip(ds_list)
    check(
        ds,
        'ZipDataset()',
        ' DictDataset(len=4)\n'
        ' MapDataset(_pickle.loads)\n'
        ' ListDataset(len=4)\n'
        ' MapDataset(_pickle.loads)\n'
        'ZipDataset()',
    )

    import numpy as np
    np.random.seed(0)
    ds = ds_dict.key_zip(ds_dict.shuffle())
    check(
        ds,
        'KeyZipDataset()',
        ' DictDataset(len=4)\n'
        ' MapDataset(_pickle.loads)\n'
        ' DictDataset(len=4)\n'
        ' MapDataset(_pickle.loads)\n'
        ' SliceDataset([2 3 1 0])\n'
        'KeyZipDataset()',
    )

    ds = ds_dict.intersperse(ds_list)
    check(
        ds,
        'IntersperseDataset()',
        ' DictDataset(len=4)\n'
        ' MapDataset(_pickle.loads)\n'
        ' ListDataset(len=4)\n'
        ' MapDataset(_pickle.loads)\n'
        'IntersperseDataset()',
    )
def test_cache_from_ordered_not_indexable():
    dataset = lazy_dataset.new(list(range(10)))
    # filter makes the dataset non-indexable; eager caching materializes
    # it, so len() works again
    assert len(dataset.filter(lambda x: x % 2).cache(lazy=False)) == 5
def test_diskcache_raise_if_cache_exists():
    with tempfile.TemporaryDirectory() as cache_dir:
        (Path(cache_dir) / 'file').touch()
        with pytest.raises(RuntimeError):
            lazy_dataset.new(list(range(10))).diskcache(
                cache_dir=cache_dir, reuse=False)
def test_cache_no_keys():
    ds = lazy_dataset.new(list(range(100))).cache()
    assert list(ds) == list(range(100))