Example #1
def test_lfs_dataref_with_offset_and_shuffle_after_epoch() -> None:
    range_size = 10
    seed = 325
    offset = 15
    checkpoint_path = util.create_lmdb_checkpoint_using_range(
        range_size=range_size)
    lfs_dataref = dataref.LMDBDataRef(cache_filepath=checkpoint_path)
    stream = lfs_dataref.stream(shuffle=True,
                                skip_shuffle_at_epoch_end=False,
                                shuffle_seed=seed,
                                start_offset=offset)
    un_shuffled_keys = list(range(range_size))

    # start_offset=15 lands in epoch offset // range_size == 1, five positions
    # in (offset % range_size == 5), so iteration starts from that epoch.
    for epoch in range(offset // range_size, 5):
        shuffled_keys_for_epoch = copy.deepcopy(un_shuffled_keys)
        shuffler = np.random.RandomState(seed + epoch)
        shuffler.shuffle(shuffled_keys_for_epoch)

        # The first epoch consumed from the stream resumes mid-epoch, so only
        # the last range_size - (offset % range_size) shuffled keys are expected.
        if offset // range_size == epoch:
            shuffled_keys_for_epoch = shuffled_keys_for_epoch[offset % range_size:]

        data_generator = stream.iterator_fn()
        idx = 0
        for data, shuffled_key in zip(data_generator, shuffled_keys_for_epoch):
            assert data == shuffled_key
            idx += 1
        assert idx == len(shuffled_keys_for_epoch)
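Every example in this section leans on a local test helper, util.create_lmdb_checkpoint_using_range, plus module-level imports (copy, numpy as np, tensorflow as tf, yogadl, and yogadl.dataref) that are not shown. A minimal sketch of what that helper could look like follows; the serializer name yogadl.tensorflow.serialize_tf_dataset_to_lmdb and its keyword arguments are assumptions for illustration, not something the examples themselves confirm.

# Hypothetical sketch of the shared test helper (assumed names marked below).
import pathlib
import tempfile

import tensorflow as tf
import yogadl.tensorflow


def create_lmdb_checkpoint_using_range(range_size: int) -> pathlib.Path:
    # Serialize tf.data.Dataset.range(range_size) into a temporary LMDB file so
    # the tests can read it back via yogadl.LmdbAccess / dataref.LMDBDataRef.
    checkpoint_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mdb")
    checkpoint_path = pathlib.Path(checkpoint_file.name)
    dataset = tf.data.Dataset.range(range_size)
    # Assumed yogadl API: serialize a tf.data.Dataset into the LMDB cache format.
    yogadl.tensorflow.serialize_tf_dataset_to_lmdb(dataset=dataset,
                                                   checkpoint_path=checkpoint_path)
    return checkpoint_path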
Example #2
def test_lmdb_access_keys_non_sequential_shard(drop_remainder: bool) -> None:
    range_size = 10
    num_shards = 3
    lmdb_checkpoint_path = util.create_lmdb_checkpoint_using_range(
        range_size=range_size)
    key_shards = []
    for shard_id in range(num_shards):
        lmdb_reader = yogadl.LmdbAccess(lmdb_path=lmdb_checkpoint_path)
        key_shards.append(
            shard_and_get_keys(
                lmdb_reader=lmdb_reader,
                shard_index=shard_id,
                num_shards=num_shards,
                sequential=False,
                drop_shard_remainder=drop_remainder,
            ))

    # Non-sequential sharding is strided, so shard j should hold keys
    # j, j + num_shards, j + 2 * num_shards, ...; interleaving the shards
    # round-robin should therefore reconstruct the original key order.
    merged_keys = []
    for idx in range(len(key_shards[0])):
        for key_shard in key_shards:
            if idx < len(key_shard):
                merged_keys.append(key_shard[idx])

    expected_range_size = (range_size - (range_size % num_shards)
                           if drop_remainder else range_size)
    assert len(merged_keys) == expected_range_size
    for idx, key in enumerate(merged_keys):
        assert convert_int_to_byte_string(idx) == key
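Examples #2 and #3 also reference two local helpers that are not shown: shard_and_get_keys and convert_int_to_byte_string. Plausible sketches are below; the yogadl.shard_keys call and its keyword names are inferred from the wrapper's own signature, and the exact key encoding is a guess consistent with the integer-ordered assertions, so both are assumptions rather than confirmed API.

# Hypothetical sketches of the remaining local helpers (assumed, not confirmed).
from typing import List

import yogadl


def shard_and_get_keys(lmdb_reader: yogadl.LmdbAccess,
                       shard_index: int,
                       num_shards: int,
                       sequential: bool,
                       drop_shard_remainder: bool) -> List[bytes]:
    # Assumption: yogadl exposes a shard_keys() helper mirroring these kwargs.
    keys = lmdb_reader.get_keys()
    return yogadl.shard_keys(keys=keys,
                             shard_index=shard_index,
                             num_shards=num_shards,
                             sequential=sequential,
                             drop_shard_remainder=drop_shard_remainder)


def convert_int_to_byte_string(input_int: int) -> bytes:
    # Assumption: LMDB keys are the decimal sample index encoded as bytes.
    return str(input_int).encode("utf-8")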
Example #3
def test_lmdb_access_keys() -> None:
    range_size = 10
    lmdb_reader = yogadl.LmdbAccess(
        lmdb_path=util.create_lmdb_checkpoint_using_range(
            range_size=range_size))
    keys = lmdb_reader.get_keys()
    assert len(keys) == range_size
    for idx, key in enumerate(keys):
        assert convert_int_to_byte_string(idx) == key
Example #4
def test_lmdb_access_read_values() -> None:
    range_size = 10
    lmdb_checkpoint_path = util.create_lmdb_checkpoint_using_range(
        range_size=range_size)
    lmdb_reader = yogadl.LmdbAccess(lmdb_path=lmdb_checkpoint_path)
    keys = lmdb_reader.get_keys()

    for idx, key in enumerate(keys):
        assert lmdb_reader.read_value_by_key(key=key) == idx
Example #5
def test_lmdb_access_shapes_and_types() -> None:
    range_size = 10
    lmdb_reader = yogadl.LmdbAccess(
        lmdb_path=util.create_lmdb_checkpoint_using_range(
            range_size=range_size))
    matching_dataset = tf.data.Dataset.range(range_size)
    assert lmdb_reader.get_shapes() == tf.compat.v1.data.get_output_shapes(
        matching_dataset)
    assert lmdb_reader.get_types() == tf.compat.v1.data.get_output_types(
        matching_dataset)
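For reference, the TF1 compatibility helpers used in Example #5 report a scalar shape and an int64 dtype for tf.data.Dataset.range, which is what the reader's stored metadata is compared against:

import tensorflow as tf

ds = tf.data.Dataset.range(10)
print(tf.compat.v1.data.get_output_shapes(ds))  # scalar shape: ()
print(tf.compat.v1.data.get_output_types(ds))   # tf.int64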
Example #6
def test_lfs_dataref_from_checkpoint() -> None:
    range_size = 10
    checkpoint_path = util.create_lmdb_checkpoint_using_range(
        range_size=range_size)
    lfs_dataref = dataref.LMDBDataRef(cache_filepath=checkpoint_path)
    stream = lfs_dataref.stream()

    for _ in range(3):
        idx = 0
        data_generator = stream.iterator_fn()
        for data in data_generator:
            assert data == idx
            idx += 1
        assert idx == range_size
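Example #6 drives the stream directly through iterator_fn(). In training code the same stream would more commonly be wrapped into a tf.data.Dataset; a minimal sketch, assuming yogadl's yogadl.tensorflow.make_tf_dataset helper and the create_lmdb_checkpoint_using_range sketch given after Example #1:

import yogadl.tensorflow
from yogadl import dataref

# Minimal usage sketch (checkpoint_path comes from the helper sketched earlier).
checkpoint_path = create_lmdb_checkpoint_using_range(range_size=10)
stream = dataref.LMDBDataRef(cache_filepath=checkpoint_path).stream(
    shuffle=True, shuffle_seed=325)
dataset = yogadl.tensorflow.make_tf_dataset(stream)
for batch in dataset.batch(2).take(2):
    print(batch.numpy())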
Example #7
def test_lfs_dataref_with_offset() -> None:
    range_size = 10
    offset = 5
    checkpoint_path = util.create_lmdb_checkpoint_using_range(
        range_size=range_size)
    lfs_dataref = dataref.LMDBDataRef(cache_filepath=checkpoint_path)
    stream = lfs_dataref.stream(start_offset=offset)

    for epoch in range(3):
        # Only the first pass over the data resumes at the offset; subsequent
        # epochs start at index 0.
        idx = offset if epoch == 0 else 0
        data_generator = stream.iterator_fn()
        for data in data_generator:
            assert data == idx
            idx += 1
        assert idx == range_size
Example #8
def test_lfs_dataref_with_shuffle() -> None:
    range_size = 10
    seed = 325
    checkpoint_path = util.create_lmdb_checkpoint_using_range(
        range_size=range_size)
    lfs_dataref = dataref.LMDBDataRef(cache_filepath=checkpoint_path)
    stream = lfs_dataref.stream(shuffle=True,
                                skip_shuffle_at_epoch_end=True,
                                shuffle_seed=seed)
    shuffled_keys = list(range(range_size))
    shuffler = np.random.RandomState(seed)
    shuffler.shuffle(shuffled_keys)

    for _ in range(3):
        data_generator = stream.iterator_fn()
        idx = 0
        for data, shuffled_key in zip(data_generator, shuffled_keys):
            assert data == shuffled_key
            idx += 1
        assert idx == range_size
Example #9
def test_lmdb_access_shuffle() -> None:
    range_size = 10
    seed_one = 41
    seed_two = 421
    lmdb_checkpoint_path = util.create_lmdb_checkpoint_using_range(
        range_size=range_size)

    lmdb_reader_one = yogadl.LmdbAccess(lmdb_path=lmdb_checkpoint_path)
    keys_one = lmdb_reader_one.get_keys()
    keys_one = yogadl.shuffle_keys(keys=keys_one, seed=seed_one)

    lmdb_reader_two = yogadl.LmdbAccess(lmdb_path=lmdb_checkpoint_path)
    keys_two = lmdb_reader_two.get_keys()
    keys_two = yogadl.shuffle_keys(keys=keys_two, seed=seed_one)

    lmdb_reader_three = yogadl.LmdbAccess(lmdb_path=lmdb_checkpoint_path)
    keys_three = lmdb_reader_three.get_keys()
    keys_three = yogadl.shuffle_keys(keys=keys_three, seed=seed_two)

    assert keys_one == keys_two
    assert keys_one != keys_three
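Example #9 only asserts that yogadl.shuffle_keys is deterministic for a fixed seed and differs across seeds. Example #8 additionally pins the stream order to numpy's RandomState; under that (assumed) reading, shuffle_keys should agree with an in-place np.random.RandomState(seed) shuffle of the key list, as sketched below.

import numpy as np
import yogadl

# Assumption implied by Example #8: shuffle_keys orders keys the same way as
# an in-place np.random.RandomState(seed) shuffle of the same list.
keys = [str(i).encode("utf-8") for i in range(10)]
expected = list(keys)
np.random.RandomState(325).shuffle(expected)
assert yogadl.shuffle_keys(keys=list(keys), seed=325) == expected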