コード例 #1
0
ファイル: test_dataset.py プロジェクト: zhaohengz/armory
def test_resisc45():
    """
    Check size and batch metadata of every resisc45 split.

    The test is skipped when the "resisc45_split" version "3.0.0"
    directory is not present under DATASET_DIR. For each split it verifies
    the reported size, batch size and batches per epoch, then pulls one
    batch and checks the shapes of the returned arrays.
    """
    local_copy = os.path.join(DATASET_DIR, "resisc45_split", "3.0.0")
    if not os.path.isdir(local_copy):
        pytest.skip("resisc45_split dataset not locally available.")

    # Loop-invariant settings shared by every split.
    batch_size = 16
    epochs = 1
    expected_sizes = (
        ("train", 22500),
        ("validation", 4500),
        ("test", 4500),
    )

    for split_name, n_examples in expected_sizes:
        ds = datasets.resisc45(
            split=split_name,
            epochs=epochs,
            batch_size=batch_size,
            dataset_dir=DATASET_DIR,
        )

        # Ceiling division: a trailing partial batch still counts as a batch.
        n_batches = -(-n_examples // batch_size)

        assert ds.size == n_examples
        assert ds.batch_size == batch_size
        assert ds.batches_per_epoch == n_batches

        images, labels = ds.get_batch()
        assert images.shape == (batch_size, 256, 256, 3)
        assert labels.shape == (batch_size, )
コード例 #2
0
def test_pytorch_generator_resisc():
    """
    Check the resisc45 training data when requested with framework="pytorch".

    The returned object must be a torch DataLoader, and its first batch
    must unpack into (images, labels) with int64 labels of shape
    (batch_size,) and uint8 images of shape (batch_size, 256, 256, 3).
    """
    n_per_batch = 16
    loader = datasets.resisc45(
        split_type="train",
        epochs=1,
        batch_size=n_per_batch,
        dataset_dir=DATASET_DIR,
        framework="pytorch",
    )
    assert isinstance(loader, torch.utils.data.DataLoader)

    # Only the first batch is inspected.
    batch_iter = iter(loader)
    images, labels = next(batch_iter)

    expected_label_shape = (n_per_batch,)
    assert labels.dtype == torch.int64
    assert labels.shape == expected_label_shape

    expected_image_shape = (n_per_batch, 256, 256, 3)
    assert images.dtype == torch.uint8
    assert images.shape == expected_image_shape