def test_val_0():
    """Check STL split shapes, dtypes and class names with ``n_val_folds=0``."""
    from batchup.datasets import stl

    ds = stl.STL(n_val_folds=0)

    # With no validation folds every training sample stays in the training
    # split and the validation split is empty.
    for split, n_samples in (('train', 5000), ('val', 0), ('test', 8000)):
        images = getattr(ds, split + '_X_u8')
        labels = getattr(ds, split + '_y')
        assert images.shape == (n_samples, 3, 96, 96)
        assert images.dtype == np.uint8
        assert labels.shape == (n_samples, )
        assert labels.dtype == np.int32

    expected_names = [
        'airplane', 'bird', 'car', 'cat', 'deer',
        'dog', 'horse', 'monkey', 'ship', 'truck',
    ]
    assert ds.class_names == expected_names
def test_train_test_split():
    """Check that no training sample's hash also appears among the test samples."""
    from batchup.datasets import stl
    from batchup.tests.dataset_test_helpers import sample_hashes

    ds = stl.STL(n_val_folds=1)
    train_hashes = set(sample_hashes(ds.train_X_u8))
    test_hashes = set(sample_hashes(ds.test_X_u8))
    # An empty intersection means the splits share no identical samples.
    assert not (train_hashes & test_hashes)
def load_stl(zero_centre=False, val=False):
    """Load STL for adaptation with CIFAR-10.

    For each of the train/val/test splits:

    - STL labels are remapped onto CIFAR-10 class indices.
    - Samples of the 'monkey' class, which CIFAR-10 lacks, are discarded.
    - Images are downsampled from 96x96 to 32x32 by 3x3 local averaging.

    Parameters
    ----------
    zero_centre : bool
        If True, rescale pixel values with ``x * 2.0 - 1.0``. This maps
        [0, 1] onto [-1, 1], assuming the loader yields values in [0, 1]
        (NOTE(review): confirm against the STL loader's defaults).
    val : bool
        If True, construct ``stl.STL()`` with its default validation split.
        Otherwise use ``n_val_folds=0``, which leaves the validation split
        empty.

    Returns
    -------
    The ``stl.STL`` instance with these changes:

    - ``train_X``/``val_X``/``test_X`` replaced by the processed image arrays.
    - ``train_y``/``val_y``/``test_y`` replaced by the remapped label arrays.
    - ``n_classes`` set to 9.
    """
    print('Loading STL...')
    d_stl = stl.STL() if val else stl.STL(n_val_folds=0)

    # STL class index -> CIFAR-10 class index; -1 marks 'monkey', which has
    # no CIFAR-10 counterpart and is dropped below.
    cls_mapping = np.array([0, 2, 1, 3, 4, 5, 6, -1, 7, 8])

    # Process one split at a time so only a single full-resolution split is
    # held in memory at once.
    for split in ('train', 'val', 'test'):
        # ``[:]`` presumably materialises the split as an in-memory array
        # (the raw attributes may be lazy wrappers -- TODO confirm).
        X = getattr(d_stl, split + '_X')[:]
        y = cls_mapping[getattr(d_stl, split + '_y')[:]]

        # Remove all samples from class -1 (monkey), as it does not appear
        # in the CIFAR-10 dataset.
        keep = y != -1
        X = X[keep]
        y = y[keep]

        # Downsample images from 96x96 to 32x32.
        X = downscale_local_mean(X, (1, 1, 3, 3))
        if zero_centre:
            X = X * 2.0 - 1.0

        setattr(d_stl, split + '_X', X)
        setattr(d_stl, split + '_y', y)

    print(
        'STL: train: X.shape={}, y.shape={}, val: X.shape={}, y.shape={}, test: X.shape={}, y.shape={}'
        .format(d_stl.train_X.shape, d_stl.train_y.shape, d_stl.val_X.shape, d_stl.val_y.shape,
                d_stl.test_X.shape, d_stl.test_y.shape))
    print('STL: train: X.min={}, X.max={}'.format(d_stl.train_X.min(), d_stl.train_X.max()))

    # The 10 STL classes minus 'monkey'.
    d_stl.n_classes = 9
    return d_stl