示例#1
0
def test_val_0():
    """With ``n_val_folds=0`` the STL validation split is empty.

    Verifies shape and dtype of the uint8 images and int32 labels for the
    train, val and test splits (in that order), then the ordered class names.
    """
    from batchup.datasets import stl

    dataset = stl.STL(n_val_folds=0)

    image_shape = (3, 96, 96)
    split_sizes = (('train', 5000), ('val', 0), ('test', 8000))
    for split, count in split_sizes:
        images = getattr(dataset, split + '_X_u8')
        assert images.shape == (count,) + image_shape
        assert images.dtype == np.uint8

        labels = getattr(dataset, split + '_y')
        assert labels.shape == (count, )
        assert labels.dtype == np.int32

    expected_names = [
        'airplane', 'bird', 'car', 'cat', 'deer', 'dog', 'horse', 'monkey',
        'ship', 'truck'
    ]
    assert dataset.class_names == expected_names
示例#2
0
def test_train_test_split():
    """Train and test images share no hash values.

    Hashes are computed by ``sample_hashes`` on the uint8 image arrays of a
    dataset built with one validation fold.
    """
    from batchup.datasets import stl
    from batchup.tests.dataset_test_helpers import sample_hashes

    dataset = stl.STL(n_val_folds=1)

    train_hashes = set(sample_hashes(dataset.train_X_u8))
    test_hashes = sample_hashes(dataset.test_X_u8)

    assert train_hashes.isdisjoint(test_hashes)
示例#3
0
def _prepare_stl_split(X, y, cls_mapping, zero_centre):
    """Convert one STL split to remapped labels and 3x-downsampled images.

    Parameters
    ----------
    X : array-like
        The split's images as exposed by ``stl.STL`` (e.g. ``d_stl.train_X``).
        Sliced with ``[:]`` to obtain an in-memory array. Assumed to be laid
        out as (N, C, 96, 96) -- TODO confirm.
    y : array-like of int
        The split's STL class labels, used as indices into ``cls_mapping``.
    cls_mapping : numpy.ndarray
        Lookup table from STL label to target label; ``-1`` marks a class
        that is dropped.
    zero_centre : bool
        If True, rescale images with ``X * 2.0 - 1.0`` (maps [0, 1] to
        [-1, 1], assuming the input lies in [0, 1]).

    Returns
    -------
    tuple
        ``(X, y)``: samples of dropped classes are removed, labels are
        remapped, and images are downsampled by 3 along the last two axes.
    """
    X = X[:]
    y = cls_mapping[y[:]]

    # Drop samples whose class has no counterpart (mapped to -1).
    keep = y != -1
    X = X[keep]
    y = y[keep]

    # Downsample 96x96 -> 32x32 by averaging non-overlapping 3x3 blocks.
    X = downscale_local_mean(X, (1, 1, 3, 3))

    if zero_centre:
        X = X * 2.0 - 1.0
    return X, y


def load_stl(zero_centre=False, val=False):
    """Load STL for domain adaptation with CIFAR-10.

    The 'monkey' class, which has no CIFAR-10 counterpart, is removed. The
    remaining nine classes are relabelled via ``cls_mapping``, and images are
    downsampled from 96x96 to 32x32.

    Parameters
    ----------
    zero_centre : bool
        If True, rescale images with ``X * 2.0 - 1.0`` (presumably mapping
        [0, 1] to [-1, 1] -- confirm the range of ``stl.STL().train_X``).
    val : bool
        If True, construct ``stl.STL()`` with the library's default number of
        validation folds. Otherwise use ``n_val_folds=0``, which gives an
        empty validation split.

    Returns
    -------
    The ``stl.STL`` instance. Its ``train_X``/``val_X``/``test_X`` and
    ``train_y``/``val_y``/``test_y`` attributes are replaced by the processed
    in-memory arrays, and ``n_classes`` is set to the number of retained
    classes (9).
    """
    print('Loading STL...')
    if val:
        d_stl = stl.STL()
    else:
        d_stl = stl.STL(n_val_folds=0)

    # STL label order: airplane, bird, car, cat, deer, dog, horse, monkey,
    # ship, truck. The targets appear to follow CIFAR-10's class order with
    # 'frog' removed (airplane, automobile, bird, cat, deer, dog, horse,
    # ship, truck). Presumably the CIFAR-10 loader applies the matching
    # 9-class remap -- verify. 'monkey' maps to -1 and is dropped.
    cls_mapping = np.array([0, 2, 1, 3, 4, 5, 6, -1, 7, 8])

    # Process one split at a time so only one full-resolution split is held
    # in memory at once.
    d_stl.train_X, d_stl.train_y = _prepare_stl_split(
        d_stl.train_X, d_stl.train_y, cls_mapping, zero_centre)
    d_stl.val_X, d_stl.val_y = _prepare_stl_split(
        d_stl.val_X, d_stl.val_y, cls_mapping, zero_centre)
    d_stl.test_X, d_stl.test_y = _prepare_stl_split(
        d_stl.test_X, d_stl.test_y, cls_mapping, zero_centre)

    print(
        'STL: train: X.shape={}, y.shape={}, val: X.shape={}, y.shape={}, test: X.shape={}, y.shape={}'
        .format(d_stl.train_X.shape, d_stl.train_y.shape, d_stl.val_X.shape,
                d_stl.val_y.shape, d_stl.test_X.shape, d_stl.test_y.shape))

    print('STL: train: X.min={}, X.max={}'.format(d_stl.train_X.min(),
                                                  d_stl.train_X.max()))

    # Derive the class count from the mapping so the two cannot drift apart.
    d_stl.n_classes = int(cls_mapping.max()) + 1

    return d_stl