# Imports for the snippets below. `dataref` and `util` are project-local
# modules (exact paths not shown in these snippets): `dataref` provides
# LMDBDataRef, and `util` provides create_lmdb_checkpoint_using_range.
import copy
import logging

import numpy as np

import dataref
import util


def test_lfs_dataref_with_offset_and_shuffle_after_epoch() -> None:
    range_size = 10
    seed = 325
    offset = 15
    checkpoint_path = util.create_lmdb_checkpoint_using_range(
        range_size=range_size)
    lfs_dataref = dataref.LMDBDataRef(cache_filepath=checkpoint_path)
    stream = lfs_dataref.stream(shuffle=True,
                                skip_shuffle_at_epoch_end=False,
                                shuffle_seed=seed,
                                start_offset=offset)
    un_shuffled_keys = list(range(range_size))

    # offset = 15 with range_size = 10 means the stream resumes in epoch 1
    # (offset // range_size) at position 5 (offset % range_size).
    for epoch in range(offset // range_size, 5):
        shuffled_keys_for_epoch = copy.deepcopy(un_shuffled_keys)
        # With skip_shuffle_at_epoch_end=False, each epoch reshuffles with
        # a seed of seed + epoch.
        shuffler = np.random.RandomState(seed + epoch)
        shuffler.shuffle(shuffled_keys_for_epoch)

        if offset // range_size == epoch:
            # The resumed epoch starts partway through its shuffled keys.
            shuffled_keys_for_epoch = shuffled_keys_for_epoch[offset % range_size:]

        data_generator = stream.iterator_fn()
        idx = 0
        for data, shuffled_key in zip(data_generator, shuffled_keys_for_epoch):
            assert data == shuffled_key
            idx += 1
        assert idx == len(shuffled_keys_for_epoch)


    # Method of a storage class; the enclosing class definition is not part
    # of this snippet.
    def fetch(self, dataset_id: str,
              dataset_version: str) -> dataref.LMDBDataRef:
        """
        Fetch a dataset from storage and provide a DataRef
        for streaming it.
        """
        cache_filepath = self._get_cache_filepath(
            dataset_id=dataset_id, dataset_version=dataset_version)
        assert cache_filepath.exists()

        return dataref.LMDBDataRef(cache_filepath=cache_filepath)
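

# A minimal usage sketch of the fetch/stream API shown above. The storage
# object and the dataset_id/dataset_version values here are illustrative;
# only fetch(), stream(), shuffle, shuffle_seed, and iterator_fn() come
# from the snippets in this file.
def example_fetch_and_stream(storage) -> None:
    # `storage` is any object exposing fetch() as defined above.
    data_ref = storage.fetch(dataset_id="mnist", dataset_version="1")
    stream = data_ref.stream(shuffle=True, shuffle_seed=325)
    for record in stream.iterator_fn():
        logging.info(record)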


def test_lfs_dataref_from_checkpoint() -> None:
    range_size = 10
    checkpoint_path = util.create_lmdb_checkpoint_using_range(
        range_size=range_size)
    lfs_dataref = dataref.LMDBDataRef(cache_filepath=checkpoint_path)
    stream = lfs_dataref.stream()

    for _ in range(3):
        idx = 0
        data_generator = stream.iterator_fn()
        for data in data_generator:
            assert data == idx
            idx += 1
        assert idx == range_size


def test_lfs_dataref_with_offset() -> None:
    range_size = 10
    offset = 5
    checkpoint_path = util.create_lmdb_checkpoint_using_range(
        range_size=range_size)
    lfs_dataref = dataref.LMDBDataRef(cache_filepath=checkpoint_path)
    stream = lfs_dataref.stream(start_offset=offset)

    for epoch in range(3):
        idx = offset if epoch == 0 else 0  # the first epoch starts at the offset
        data_generator = stream.iterator_fn()
        for data in data_generator:
            assert data == idx
            idx += 1
        assert idx == range_size


    # Method of a cloud-backed storage class; the enclosing class definition
    # is not part of this snippet.
    def fetch(self, dataset_id: str, dataset_version: str) -> dataref.LMDBDataRef:
        """
        Fetch a dataset from cloud storage and provide a DataRef for streaming it.

        The timestamp of the cache in cloud storage is compared to the creation
        time of the local cache; if they differ, the local cache is overwritten.

        `fetch()` is not safe for concurrent access; for concurrent access, use
        `cacheable()`.
        """

        local_metadata = self._get_local_metadata(
            dataset_id=dataset_id, dataset_version=dataset_version
        )
        local_cache_filepath = self._get_local_cache_filepath(
            dataset_id=dataset_id,
            dataset_version=dataset_version,
        )

        remote_cache_timestamp = self._get_remote_cache_timestamp(
            dataset_id=dataset_id, dataset_version=dataset_version
        ).timestamp()

        if local_metadata.get("time_created") == remote_cache_timestamp:
            logging.info("Local cache matches remote cache.")
        else:
            logging.info(f"Downloading remote cache to {local_cache_filepath}.")
            local_metadata["time_created"] = self._download_from_cloud_storage(
                dataset_id=dataset_id,
                dataset_version=dataset_version,
                local_cache_filepath=local_cache_filepath,
            ).timestamp()
            logging.info("Cache download finished.")

            self._save_local_metadata(
                dataset_id=dataset_id,
                dataset_version=dataset_version,
                metadata=local_metadata,
            )

        assert local_cache_filepath.exists()

        return dataref.LMDBDataRef(cache_filepath=local_cache_filepath)
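

# A standalone sketch of the timestamp guard used by the cloud fetch() above:
# a download is needed only when the locally recorded creation time differs
# from the remote cache's timestamp. Names here are local to this sketch.
def needs_download(local_metadata: dict, remote_cache_timestamp: float) -> bool:
    return local_metadata.get("time_created") != remote_cache_timestamp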


def test_lfs_dataref_with_shuffle() -> None:
    range_size = 10
    seed = 325
    checkpoint_path = util.create_lmdb_checkpoint_using_range(
        range_size=range_size)
    lfs_dataref = dataref.LMDBDataRef(cache_filepath=checkpoint_path)
    stream = lfs_dataref.stream(shuffle=True,
                                skip_shuffle_at_epoch_end=True,
                                shuffle_seed=seed)
    shuffled_keys = list(range(range_size))
    shuffler = np.random.RandomState(seed)
    shuffler.shuffle(shuffled_keys)

    for _ in range(3):
        data_generator = stream.iterator_fn()
        idx = 0
        for data, shuffled_key in zip(data_generator, shuffled_keys):
            assert data == shuffled_key
            idx += 1
        assert idx == range_size
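

# A sketch of the expected-key computation the shuffle tests above rely on:
# with skip_shuffle_at_epoch_end=True the permutation is fixed by `seed` for
# every epoch; with False, each epoch reseeds with seed + epoch. Names here
# are local to this sketch.
def expected_shuffled_keys(range_size: int, seed: int, epoch: int,
                           reshuffle_each_epoch: bool) -> list:
    keys = list(range(range_size))
    rng = np.random.RandomState(seed + epoch if reshuffle_each_epoch else seed)
    rng.shuffle(keys)
    return keys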