Пример #1
0
def load_cifar10(range_01=False, val=False):
    """Load CIFAR-10 prepared for adaptation with STL.

    The dataset is built with ``cifar10.CIFAR10``; for each of the train,
    validation and test splits the samples and labels are read out with a
    full slice, the labels are remapped so that class 6 (frog) becomes -1,
    and every sample labelled -1 is dropped. The remaining classes are
    numbered 0..8 and ``n_classes`` is set to 9 on the returned object.

    Args:
        range_01: when true, every split's samples are rescaled as
            ``X * 2.0 - 1.0``.
            NOTE(review): the flag name suggests a [0, 1] range, whereas
            this transform would map [0, 1] inputs onto [-1, 1]; confirm
            the intended meaning against callers.
        val: when true, the dataset is built with ``n_val=5000``;
            otherwise with ``n_val=0``.

    Returns:
        The ``cifar10.CIFAR10`` instance with its ``train_X``/``train_y``,
        ``val_X``/``val_y`` and ``test_X``/``test_y`` attributes replaced by
        the filtered (and optionally rescaled) arrays.
    """
    print('Loading CIFAR-10...')
    d_cifar = cifar10.CIFAR10(n_val=5000 if val else 0)

    # Remap class indices so that the frog class (6) has an index of -1, as
    # it does not appear in the STL dataset. Classes after it shift down by
    # one.
    cls_mapping = np.array([0, 1, 2, 3, 4, 5, -1, 6, 7, 8])

    for split in ('train', 'val', 'test'):
        x_attr = split + '_X'
        y_attr = split + '_y'

        # Full-slice read of the stored arrays (presumably materialises
        # lazily loaded data in memory; verify against the dataset class).
        samples = getattr(d_cifar, x_attr)[:]
        labels = cls_mapping[getattr(d_cifar, y_attr)[:]]

        # Remove all samples from skipped classes.
        keep = labels != -1
        samples = samples[keep]
        labels = labels[keep]

        if range_01:
            samples = samples * 2.0 - 1.0

        setattr(d_cifar, x_attr, samples)
        setattr(d_cifar, y_attr, labels)

    shape_report = (
        'CIFAR-10: train: X.shape={}, y.shape={}, val: X.shape={}, y.shape={}, test: X.shape={}, y.shape={}'
    )
    print(shape_report.format(
        d_cifar.train_X.shape, d_cifar.train_y.shape,
        d_cifar.val_X.shape, d_cifar.val_y.shape,
        d_cifar.test_X.shape, d_cifar.test_y.shape))

    train_min = d_cifar.train_X.min()
    train_max = d_cifar.train_X.max()
    print('CIFAR-10: train: X.min={}, X.max={}'.format(train_min, train_max))

    # One class (frog) was dropped from the original ten.
    d_cifar.n_classes = 9

    return d_cifar
Пример #2
0
def test_val_0():
    """Check the CIFAR10 dataset built with ``n_val=0``.

    Verifies the shape and dtype of the sample and label arrays of the
    train, validation and test splits, and the list of class names.
    """
    from batchup.datasets import cifar10

    dataset = cifar10.CIFAR10(n_val=0)

    # Expected number of samples per split when no validation set is held
    # out.
    expected_counts = (('train', 50000), ('val', 0), ('test', 10000))
    sample_dims = (3, 32, 32)

    for split, count in expected_counts:
        samples = getattr(dataset, split + '_X')
        assert samples.shape == (count,) + sample_dims
        assert samples.dtype == np.float32

        labels = getattr(dataset, split + '_y')
        assert labels.shape == (count,)
        assert labels.dtype == np.int32

    expected_names = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck'
    ]
    assert dataset.class_names == expected_names
Пример #3
0
def test_train_test_split():
    """Check that no sample hash is shared by the train and test sets.

    Hashes are computed with ``sample_hashes`` over the ``train_X_u8`` and
    ``test_X_u8`` arrays of a CIFAR10 dataset built with ``n_val=0``.
    """
    from batchup.datasets import cifar10
    from batchup.tests.dataset_test_helpers import sample_hashes

    dataset = cifar10.CIFAR10(n_val=0)

    train_hashes = set(sample_hashes(dataset.train_X_u8))
    test_hashes = set(sample_hashes(dataset.test_X_u8))

    # Any hash present in both sets means a sample appears in both splits.
    shared = train_hashes & test_hashes
    assert shared == set()