def test_train_test_split():
    """Ensure no MNIST training sample is duplicated in the test set.

    Samples are compared by the hashes produced by batchup's
    ``sample_hashes`` helper on the uint8 image arrays.
    """
    from batchup.datasets import mnist
    from batchup.tests.dataset_test_helpers import sample_hashes

    dataset = mnist.MNIST(n_val=0)
    train_hashes = set(sample_hashes(dataset.train_X_u8))
    test_hashes = set(sample_hashes(dataset.test_X_u8))
    # The two splits must share no sample hash.
    assert train_hashes.isdisjoint(test_hashes)
def test_val_0():
    """With ``n_val=0`` MNIST has 60000 train, 0 val and 10000 test samples.

    Every split must hold float32 images of shape (n, 1, 28, 28) and int32
    labels of shape (n,).
    """
    from batchup.datasets import mnist

    dataset = mnist.MNIST(n_val=0)
    expected_counts = (('train', 60000), ('val', 0), ('test', 10000))
    for split, count in expected_counts:
        images = getattr(dataset, split + '_X')
        labels = getattr(dataset, split + '_y')
        assert images.shape == (count, 1, 28, 28)
        assert images.dtype == np.float32
        assert labels.shape == (count, )
        assert labels.dtype == np.int32
def load_mnistm(datadir, val=False, zero_centre=False):
    """Load MNIST-M, building and caching it in ``datadir`` on first use.

    If ``datadir/mnistm.pkl`` does not exist, it is built first by
    ``_build_mnistm_pickle``. The pickle is then loaded, and every split's
    images are converted to float32 and divided by 255.

    Parameters
    ----------
    datadir : str
        Directory holding ``mnistm.pkl`` and the ``BSDS500`` checkout.
    val : bool
        If True, the first 10000 training samples are split off and returned
        as the validation set instead of the stored ``valid`` split.
    zero_centre : bool
        If True, images are additionally mapped with ``x * 2.0 - 1.0``.

    Returns
    -------
    tuple
        ``(train, test, valid)``; each element is an ``[images, labels]``
        list.
    """
    mnistm_fname = os.path.join(datadir, 'mnistm.pkl')
    if not os.path.isfile(mnistm_fname):
        _build_mnistm_pickle(datadir, mnistm_fname)

    # NOTE(review): pickle.load can execute arbitrary code. This is only
    # acceptable because the file is the cache written by this module.
    with open(mnistm_fname, 'rb') as f:
        x = pkl.load(f)

    # Preprocess every split identically. Previously 'valid' was returned
    # raw (not scaled, not zero-centred) while train/test were scaled.
    for split in ('train', 'test', 'valid'):
        images = np.asarray(x[split][0]).astype('float32') / 255.
        if zero_centre:
            images = images * 2.0 - 1.0
        x[split][0] = images

    if val:
        train_X, train_y = x['train'][0], x['train'][1]
        mnistm_val = [train_X[:10000], train_y[:10000]]
        mnistm_train = [train_X[10000:], train_y[10000:]]
        return mnistm_train, x['test'], mnistm_val
    return x['train'], x['test'], x['valid']


def _build_mnistm_pickle(datadir, mnistm_fname):
    """Synthesise MNIST-M from MNIST digits and BSDS500 train images; pickle it."""
    import subprocess

    bsds = os.path.join(datadir, 'BSDS500')
    # The checkout is a directory, so isfile() would always be False and the
    # clone would be re-attempted every time.
    if not os.path.isdir(bsds):
        print('BSDS downloads....')
        # An argument list (no shell) keeps datadir from being shell-interpreted.
        result = subprocess.run(
            ['git', 'clone', 'https://github.com/BIDS/BSDS500.git', bsds])
        if result.returncode != 0:
            print('git clone of BSDS500 failed with exit code {}'.format(
                result.returncode))

    d_mnist = mnist.MNIST(n_val=0)
    d_mnist.train_X = d_mnist.train_X[:]
    d_mnist.test_X = d_mnist.test_X[:]
    d_mnist.train_y = d_mnist.train_y[:]
    d_mnist.test_y = d_mnist.test_y[:]

    # Glob only the BSDS train directory rather than all of datadir. Sort for
    # a deterministic order, and keep regular files only.
    train_dir = os.path.join(bsds, 'BSDS500', 'data', 'images', 'train')
    train_files = [
        name for name in sorted(
            glob.glob(os.path.join(train_dir, '**', '*'), recursive=True))
        if os.path.isfile(name)
    ]
    print(len(train_files))

    print("Loading BSR training images")
    background_data = []
    for name in train_files:
        try:
            bg_img = skimage.io.imread(name)
        except Exception:
            # Best-effort: skip files imread cannot decode. Catching Exception
            # rather than using a bare except lets KeyboardInterrupt through.
            continue
        background_data.append(bg_img)

    print(np.max(d_mnist.train_X))
    print(np.min(d_mnist.train_X))
    print(d_mnist.train_X.dtype)
    print(d_mnist.train_X.shape)
    print(len(background_data))

    print("Building train set...")
    train = create_mnistm(d_mnist.train_X, background_data)
    print("Building test set...")
    test = create_mnistm(d_mnist.test_X, background_data)
    print("Building validation set...")
    valid = create_mnistm(d_mnist.val_X, background_data)

    # Write to a temporary file, then rename. An interrupted build then never
    # leaves a truncated cache that later calls would try to load.
    tmp_fname = mnistm_fname + '.tmp'
    with open(tmp_fname, 'wb') as f:
        pkl.dump(
            {
                'train': [train, d_mnist.train_y],
                'test': [test, d_mnist.test_y],
                'valid': [valid, d_mnist.val_y]
            }, f, pkl.HIGHEST_PROTOCOL)
    os.replace(tmp_fname, mnistm_fname)
def load_mnist(invert=False, zero_centre=False, intensity_scale=1.0, val=False,
               pad32=False, downscale_x=1, rgb=False):
    """Load MNIST through batchup and apply optional preprocessing.

    Every transform below is applied identically to the train, val and test
    images. Transforms run in this order: downscale, pad, invert, intensity
    scale, zero-centre, RGB replication.

    Parameters
    ----------
    invert : bool
        Replace images ``X`` with ``1.0 - X``.
    zero_centre : bool
        Map images with ``X * 2.0 - 1.0``. Previously val was skipped; it is
        now included for consistency.
    intensity_scale : float
        Scale contrast about 0.5: ``(X - 0.5) * intensity_scale + 0.5``.
    val : bool
        If True, hold out 10000 samples as a validation split (``n_val``);
        otherwise the val split is empty.
    pad32 : bool
        Zero-pad the last two (spatial) axes up to 32. When the remaining size
        is odd, the extra row/column goes after the image.
    downscale_x : int
        Factor for ``downscale_local_mean`` along the last (width) axis only.
    rgb : bool
        Replicate the single channel 3 times along axis 1.

    Returns
    -------
    The batchup MNIST dataset object, with preprocessed ``*_X`` arrays and
    ``n_classes`` set to 10.
    """
    print('Loading MNIST...')
    d_mnist = mnist.MNIST(n_val=10000 if val else 0)

    splits = ('train', 'val', 'test')

    def _map_images(fn):
        # Replace each split's image array with fn(array).
        for split in splits:
            attr = split + '_X'
            setattr(d_mnist, attr, fn(getattr(d_mnist, attr)))

    # Slicing with [:] materialises each stored array before transforming it.
    for split in splits:
        for suffix in ('_X', '_y'):
            attr = split + suffix
            setattr(d_mnist, attr, getattr(d_mnist, attr)[:])

    if downscale_x != 1:
        _map_images(lambda X: downscale_local_mean(X, (1, 1, 1, downscale_x)))

    if pad32:
        # Pad the spatial axes up to 32x32 (28x28 for unscaled MNIST). Split the
        # padding into before/after so an odd remainder still gives exactly 32.
        pad_h = 32 - d_mnist.train_X.shape[2]
        pad_w = 32 - d_mnist.train_X.shape[3]
        pad_width = [(0, 0), (0, 0),
                     (pad_h // 2, pad_h - pad_h // 2),
                     (pad_w // 2, pad_w - pad_w // 2)]
        _map_images(lambda X: np.pad(X, pad_width, mode='constant'))

    if invert:
        _map_images(lambda X: 1.0 - X)

    if intensity_scale != 1.0:
        _map_images(lambda X: (X - 0.5) * intensity_scale + 0.5)

    if zero_centre:
        _map_images(lambda X: X * 2.0 - 1.0)

    if rgb:
        _map_images(lambda X: np.concatenate([X] * 3, axis=1))

    print(
        'MNIST: train: X.shape={}, y.shape={}, val: X.shape={}, y.shape={}, test: X.shape={}, y.shape={}'
        .format(d_mnist.train_X.shape, d_mnist.train_y.shape,
                d_mnist.val_X.shape, d_mnist.val_y.shape,
                d_mnist.test_X.shape, d_mnist.test_y.shape))
    print('MNIST: train: X.min={}, X.max={}'.format(d_mnist.train_X.min(),
                                                   d_mnist.train_X.max()))

    d_mnist.n_classes = 10

    return d_mnist