def _init_data(self):
    """Load the dataset selected by ``self.hparams`` and store it on ``self``.

    Reads ``hparams.dataset`` (one of ``'mnist'``, ``'cifar10'``, ``'svhn'``)
    and ``hparams.batch_size``. It also reads the dataset-specific fields:
    ``train_start``/``train_end``/``test_start``/``test_end`` for MNIST and
    ``data_path`` for CIFAR-10 and SVHN. Label smoothing is then applied to
    the training labels.

    Sets on ``self``:
        X_train, Y_train, X_test, Y_test: the loaded arrays (``Y_train`` is
            label-smoothed).
        data: the tuple ``(X_train, Y_train, X_test, Y_test)``.
        input_shape: ``(batch_size, 28, 28, 1)`` for MNIST,
            ``(batch_size, 32, 32, 3)`` for CIFAR-10 / SVHN.
        preproc_func: dataset preprocessing callable, or ``None`` for MNIST.

    Raises:
        ValueError: if ``hparams.dataset`` is not supported, or if the
            training labels do not have exactly 10 columns.
    """
    hparams = self.hparams
    batch_size = hparams.batch_size

    if hparams.dataset == 'mnist':
        # Load the requested train/test slices of MNIST.
        X_train, Y_train, X_test, Y_test = data_mnist(
            train_start=hparams.train_start,
            train_end=hparams.train_end,
            test_start=hparams.test_start,
            test_end=hparams.test_end)
        input_shape = (batch_size, 28, 28, 1)
        preproc_func = None
    elif hparams.dataset == 'cifar10':
        X_train, Y_train, X_test, Y_test = cifar_input.read_CIFAR10(
            os.path.join(hparams.data_path, hparams.dataset))
        input_shape = (batch_size, 32, 32, 3)
        preproc_func = cifar_input.cifar_tf_preprocess
    elif hparams.dataset == 'svhn':
        X_train, Y_train, X_test, Y_test = svhn_input.read_SVHN(
            os.path.join(hparams.data_path, hparams.dataset))
        input_shape = (batch_size, 32, 32, 3)
        preproc_func = svhn_input.svhn_tf_preprocess
    else:
        # Fail fast. Otherwise the unbound Y_train below would raise an
        # unhelpful UnboundLocalError.
        raise ValueError(
            "Unsupported dataset %r; expected one of 'mnist', 'cifar10', "
            "'svhn'." % (hparams.dataset,))

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    num_classes = Y_train.shape[1]
    if num_classes != 10:
        raise ValueError(
            "Expected training labels with 10 columns, got %d."
            % (num_classes,))

    # Label smoothing via clipping into [eps / (K - 1), 1 - eps].
    # NOTE(review): assumes Y_train is one-hot -- confirm against loaders.
    # For one-hot rows the target drops to 1 - eps. Each of the other K - 1
    # entries rises to eps / (K - 1), so every row still sums to 1.
    label_smooth = .1
    Y_train = Y_train.clip(label_smooth / (num_classes - 1),
                           1. - label_smooth)

    self.X_train = X_train
    self.Y_train = Y_train
    self.X_test = X_test
    self.Y_test = Y_test
    self.data = (X_train, Y_train, X_test, Y_test)
    self.input_shape = input_shape
    self.preproc_func = preproc_func