def test_resisc45():
    """
    Check dataset size, batch bookkeeping and batch shapes for the train,
    validation and test splits of resisc45.

    The test is skipped when the "resisc45_split/3.0.0" directory does not
    exist under DATASET_DIR.
    """
    local_copy = os.path.join(DATASET_DIR, "resisc45_split", "3.0.0")
    if not os.path.isdir(local_copy):
        pytest.skip("resisc45_split dataset not locally available.")

    # These settings are the same for every split, so define them once.
    batch_size = 16
    epochs = 1
    expected_sizes = {"train": 22500, "validation": 4500, "test": 4500}

    for split, size in expected_sizes.items():
        # NOTE(review): test_pytorch_generator_resisc passes `split_type=`
        # rather than `split=` -- confirm which keyword datasets.resisc45
        # actually accepts.
        dataset = datasets.resisc45(
            split=split,
            epochs=epochs,
            batch_size=batch_size,
            dataset_dir=DATASET_DIR,
        )

        # Ceiling division: a trailing partial batch counts as one more batch.
        whole_batches, leftover = divmod(size, batch_size)
        expected_batches = whole_batches + (1 if leftover else 0)

        assert dataset.size == size
        assert dataset.batch_size == batch_size
        assert dataset.batches_per_epoch == expected_batches

        x, y = dataset.get_batch()
        assert x.shape == (batch_size, 256, 256, 3)
        assert y.shape == (batch_size,)
def test_pytorch_generator_resisc():
    """
    Request the resisc45 train split with framework="pytorch" and check that
    the result is a torch DataLoader whose first batch has the expected
    dtypes and shapes.
    """
    batch_size = 16
    # NOTE(review): test_resisc45 passes `split=` rather than `split_type=`
    # -- confirm which keyword datasets.resisc45 actually accepts.
    loader = datasets.resisc45(
        split_type="train",
        epochs=1,
        batch_size=batch_size,
        dataset_dir=DATASET_DIR,
        framework="pytorch",
    )
    assert isinstance(loader, torch.utils.data.DataLoader)

    # Only the first batch is inspected.
    first_batch = next(iter(loader))
    images, labels = first_batch

    expected_label_shape = (batch_size,)
    expected_image_shape = (batch_size, 256, 256, 3)

    assert labels.dtype == torch.int64
    assert labels.shape == expected_label_shape

    assert images.dtype == torch.uint8
    assert images.shape == expected_image_shape