def test_lfs_dataref_with_offset_and_shuffle_after_epoch() -> None:
    range_size = 10
    seed = 325
    offset = 15
    checkpoint_path = util.create_lmdb_checkpoint_using_range(range_size=range_size)
    lfs_dataref = dataref.LMDBDataRef(cache_filepath=checkpoint_path)
    stream = lfs_dataref.stream(
        shuffle=True,
        skip_shuffle_at_epoch_end=False,
        shuffle_seed=seed,
        start_offset=offset,
    )
    un_shuffled_keys = list(range(range_size))

    # An offset of 15 with a range of 10 skips all of epoch 0 and the first
    # 5 records of epoch 1, so iteration starts at epoch offset // range_size.
    for epoch in range(offset // range_size, 5):
        # With skip_shuffle_at_epoch_end=False, each epoch is reshuffled with
        # seed + epoch.
        shuffled_keys_for_epoch = copy.deepcopy(un_shuffled_keys)
        shuffler = np.random.RandomState(seed + epoch)
        shuffler.shuffle(shuffled_keys_for_epoch)

        if offset // range_size == epoch:
            shuffled_keys_for_epoch = shuffled_keys_for_epoch[offset % range_size :]

        data_generator = stream.iterator_fn()
        idx = 0
        for data, shuffled_key in zip(data_generator, shuffled_keys_for_epoch):
            assert data == shuffled_key
            idx += 1
        assert idx == len(shuffled_keys_for_epoch)

@pytest.mark.parametrize("drop_remainder", [False, True])
def test_lmdb_access_keys_non_sequential_shard(drop_remainder: bool) -> None:
    range_size = 10
    num_shards = 3
    lmdb_checkpoint_path = util.create_lmdb_checkpoint_using_range(range_size=range_size)

    key_shards = []
    for shard_id in range(num_shards):
        lmdb_reader = yogadl.LmdbAccess(lmdb_path=lmdb_checkpoint_path)
        key_shards.append(
            shard_and_get_keys(
                lmdb_reader=lmdb_reader,
                shard_index=shard_id,
                num_shards=num_shards,
                sequential=False,
                drop_shard_remainder=drop_remainder,
            )
        )

    # Non-sequential sharding distributes keys round-robin, so interleaving
    # the shards should reproduce the original key order.
    merged_keys = []
    for idx in range(len(key_shards[0])):
        for key_shard in key_shards:
            if idx < len(key_shard):
                merged_keys.append(key_shard[idx])

    expected_range_size = (
        range_size if not drop_remainder else range_size - (range_size % num_shards)
    )
    assert len(merged_keys) == expected_range_size
    for idx, key in enumerate(merged_keys):
        assert convert_int_to_byte_string(idx) == key

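# shard_and_get_keys is a helper defined elsewhere in this test module. A
# minimal sketch of the non-sequential (round-robin) behavior the test above
# relies on could look like the following; the name and body here are
# illustrative assumptions, not yogadl's API.
def _round_robin_shard(keys, shard_index, num_shards, drop_shard_remainder):
    # Shard i takes keys i, i + num_shards, i + 2 * num_shards, ...
    if drop_shard_remainder:
        keys = keys[: len(keys) - (len(keys) % num_shards)]
    return keys[shard_index::num_shards]
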
def test_lmdb_access_keys() -> None:
    range_size = 10
    lmdb_reader = yogadl.LmdbAccess(
        lmdb_path=util.create_lmdb_checkpoint_using_range(range_size=range_size)
    )
    keys = lmdb_reader.get_keys()
    assert len(keys) == range_size
    for idx, key in enumerate(keys):
        assert convert_int_to_byte_string(idx) == key

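# convert_int_to_byte_string comes from the shared test utilities. These tests
# only assume it maps an integer index to the byte-string key stored in the
# LMDB checkpoint; a plausible sketch (an assumption, not the helper's actual
# definition) is:
def _int_to_byte_string(value: int) -> bytes:
    return str(value).encode("utf-8")
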
def test_lmdb_access_read_values() -> None:
    range_size = 10
    lmdb_checkpoint_path = util.create_lmdb_checkpoint_using_range(range_size=range_size)
    lmdb_reader = yogadl.LmdbAccess(lmdb_path=lmdb_checkpoint_path)
    keys = lmdb_reader.get_keys()
    for idx, key in enumerate(keys):
        assert lmdb_reader.read_value_by_key(key=key) == idx

def test_lmdb_access_shapes_and_types() -> None:
    range_size = 10
    lmdb_reader = yogadl.LmdbAccess(
        lmdb_path=util.create_lmdb_checkpoint_using_range(range_size=range_size)
    )
    matching_dataset = tf.data.Dataset.range(range_size)
    assert lmdb_reader.get_shapes() == tf.compat.v1.data.get_output_shapes(matching_dataset)
    assert lmdb_reader.get_types() == tf.compat.v1.data.get_output_types(matching_dataset)

def test_lfs_dataref_from_checkpoint() -> None:
    range_size = 10
    checkpoint_path = util.create_lmdb_checkpoint_using_range(range_size=range_size)
    lfs_dataref = dataref.LMDBDataRef(cache_filepath=checkpoint_path)
    stream = lfs_dataref.stream()

    for _ in range(3):
        idx = 0
        data_generator = stream.iterator_fn()
        for data in data_generator:
            assert data == idx
            idx += 1
        assert idx == range_size

def test_lfs_dataref_with_offset() -> None:
    range_size = 10
    offset = 5
    checkpoint_path = util.create_lmdb_checkpoint_using_range(range_size=range_size)
    lfs_dataref = dataref.LMDBDataRef(cache_filepath=checkpoint_path)
    stream = lfs_dataref.stream(start_offset=offset)

    for epoch in range(3):
        # The start offset applies only to the first epoch.
        idx = offset if epoch == 0 else 0
        data_generator = stream.iterator_fn()
        for data in data_generator:
            assert data == idx
            idx += 1
        assert idx == range_size

def test_lfs_dataref_with_shuffle() -> None:
    range_size = 10
    seed = 325
    checkpoint_path = util.create_lmdb_checkpoint_using_range(range_size=range_size)
    lfs_dataref = dataref.LMDBDataRef(cache_filepath=checkpoint_path)
    stream = lfs_dataref.stream(shuffle=True, skip_shuffle_at_epoch_end=True, shuffle_seed=seed)

    # With skip_shuffle_at_epoch_end=True, the same shuffled order is replayed
    # every epoch.
    shuffled_keys = list(range(range_size))
    shuffler = np.random.RandomState(seed)
    shuffler.shuffle(shuffled_keys)

    for _ in range(3):
        data_generator = stream.iterator_fn()
        idx = 0
        for data, shuffled_key in zip(data_generator, shuffled_keys):
            assert data == shuffled_key
            idx += 1
        assert idx == range_size

def test_lmdb_access_shuffle() -> None:
    range_size = 10
    seed_one = 41
    seed_two = 421
    lmdb_checkpoint_path = util.create_lmdb_checkpoint_using_range(range_size=range_size)

    lmdb_reader_one = yogadl.LmdbAccess(lmdb_path=lmdb_checkpoint_path)
    keys_one = lmdb_reader_one.get_keys()
    keys_one = yogadl.shuffle_keys(keys=keys_one, seed=seed_one)

    lmdb_reader_two = yogadl.LmdbAccess(lmdb_path=lmdb_checkpoint_path)
    keys_two = lmdb_reader_two.get_keys()
    keys_two = yogadl.shuffle_keys(keys=keys_two, seed=seed_one)

    lmdb_reader_three = yogadl.LmdbAccess(lmdb_path=lmdb_checkpoint_path)
    keys_three = lmdb_reader_three.get_keys()
    keys_three = yogadl.shuffle_keys(keys=keys_three, seed=seed_two)

    # Shuffling with the same seed is deterministic; a different seed should
    # produce a different order.
    assert keys_one == keys_two
    assert keys_one != keys_three

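# The data-ref tests above predict the shuffled order with
# np.random.RandomState(seed).shuffle(...). A seeded shuffle consistent with
# that expectation (a sketch under that assumption, not necessarily
# yogadl.shuffle_keys itself) would be:
def _seeded_shuffle(keys, seed):
    shuffled = list(keys)
    np.random.RandomState(seed).shuffle(shuffled)
    return shuffled
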