def load_reduced(ds_name, validation=None, data_dir=None, force_refresh=False):
    # Note: `validation` is currently unused; the validation split is taken
    # from the stored arrays below.
    dataset_dir = get_dataset_dir('babble', data_dir=data_dir)
    red_path = os.path.join(dataset_dir, 'reduced_{}.npz'.format(ds_name))

    if not os.path.exists(red_path) or force_refresh:
        print('Building reduced {}.'.format(ds_name))
        x_path = os.path.join(dataset_dir, 'reduced_{}_x.npz'.format(ds_name))
        assert os.path.exists(x_path)
        x = np.load(x_path, allow_pickle=True)['reduced_x']
        x = [np.array(mat.tolist().todense()) for mat in x]
        data = np.load(os.path.join(dataset_dir, 'cdrlabels.npz'),
                       allow_pickle=True)
        labels = data['labels']
        np.savez(red_path, labels=labels, x=x)
    else:
        print('Retrieving reduced {}.'.format(ds_name))
        data = np.load(red_path, allow_pickle=True)
        x = data['x']
        labels = data['labels']

    train = DataSet(x[0], labels[0])
    validation = DataSet(x[1], labels[1])
    test = DataSet(x[2], labels[2])

    print('Loaded reduced {}.'.format(ds_name))
    return base.Datasets(train=train, validation=validation, test=test)
def load_small_mnist(validation_size=5000, random_seed=0, data_dir=None):
    dataset_dir = get_dataset_dir('mnist', data_dir=data_dir)
    mnist_small_file = 'mnist_small_val-{}_seed-{}.npz'.format(
        validation_size, random_seed)
    mnist_small_path = os.path.join(dataset_dir, mnist_small_file)

    if not os.path.exists(mnist_small_path):
        rng = np.random.RandomState(seed=random_seed)
        data_sets = load_mnist(validation_size, data_dir=data_dir)

        # Keep a random 10% of the training set.
        train_images = data_sets.train.x
        train_labels = data_sets.train.labels
        perm = np.arange(len(train_labels))
        rng.shuffle(perm)
        num_to_keep = int(len(train_labels) / 10)
        perm = perm[:num_to_keep]
        train_images = train_images[perm, :]
        train_labels = train_labels[perm]

        # The validation and test sets are kept at full size; only the
        # training set is subsampled.
        validation_images = data_sets.validation.x
        validation_labels = data_sets.validation.labels
        test_images = data_sets.test.x
        test_labels = data_sets.test.labels

        np.savez(mnist_small_path,
                 train_images=train_images,
                 train_labels=train_labels,
                 validation_images=validation_images,
                 validation_labels=validation_labels,
                 test_images=test_images,
                 test_labels=test_labels)
    else:
        data = np.load(mnist_small_path)
        train_images = data['train_images']
        train_labels = data['train_labels']
        validation_images = data['validation_images']
        validation_labels = data['validation_labels']
        test_images = data['test_images']
        test_labels = data['test_labels']

    train = DataSet(train_images, train_labels)
    validation = DataSet(validation_images, validation_labels)
    test = DataSet(test_images, test_labels)
    return base.Datasets(train=train, validation=validation, test=test)
def load_reduced_nonfires(ds_name, source_url, data_dir=None, force_refresh=False):
    dataset_dir = get_dataset_dir('babble', data_dir=data_dir)
    red_path = os.path.join(dataset_dir,
                            'reduced_{}_nonfires.npz'.format(ds_name))

    if not os.path.exists(red_path) or force_refresh:
        print('Building reduced {} nonfires.'.format(ds_name))
        x_path = os.path.join(dataset_dir,
                              'reduced_{}_nonfires_x.npz'.format(ds_name))
        assert os.path.exists(x_path)
        x = np.load(x_path, allow_pickle=True)['reduced_x'][0].tolist().todense()
        data = np.load(os.path.join(dataset_dir, 'cdr_nonfireslabels.npz'),
                       allow_pickle=True)
        labels = data['labels'][0]
        np.savez(red_path, labels=labels, x=np.array(x))
    else:
        print('Retrieving reduced {} nonfires.'.format(ds_name))
        data = np.load(red_path, allow_pickle=True)
        x = data['x']
        labels = data['labels']

    red_nonfires = DataSet(x, labels)
    print('Loaded reduced {} nonfires.'.format(ds_name))
    return red_nonfires
def compute_test_infl(self):
    num_subsets = len(self.R['subset_indices'])
    subsets_per_batch = 256
    results = self.task_queue.execute(
        'compute_cex_test_infl_batch',
        [(i, min(i + subsets_per_batch, num_subsets))
         for i in range(0, num_subsets, subsets_per_batch)],
        force_refresh=True)
    res = self.task_queue.collate_results(results)

    ds_test = DataSet(self.R['cex_X'], self.R['cex_Y'])
    model = self.get_model()
    model.load('initial')
    test_grad_losses = model.get_indiv_grad_loss(ds_test, verbose=False)
    test_grad_margins = model.get_indiv_grad_margin(ds_test, verbose=False)

    pred_dparam = self.R['subset_pred_dparam']
    newton_pred_dparam = self.R['subset_newton_pred_dparam']
    res['cex_subset_test_pred_infl'] = np.dot(pred_dparam, test_grad_losses.T)
    res['cex_subset_test_pred_margin_infl'] = np.dot(
        pred_dparam, test_grad_margins.T)
    res['cex_subset_test_newton_pred_infl'] = np.dot(
        newton_pred_dparam, test_grad_losses.T)
    res['cex_subset_test_newton_pred_margin_infl'] = np.dot(
        newton_pred_dparam, test_grad_margins.T)
    return res
def load_mnist(validation_size=5000, data_dir=None):
    SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
    TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
    TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
    TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
    TEST_LABELS = 't10k-labels-idx1-ubyte.gz'

    dataset_dir = get_dataset_dir('mnist', data_dir=data_dir)
    local_files = [
        maybe_download(SOURCE_URL + image_file, image_file, dataset_dir)
        for image_file in (TRAIN_IMAGES, TRAIN_LABELS, TEST_IMAGES, TEST_LABELS)
    ]

    with open(local_files[0], 'rb') as f:
        train_images = extract_images(f)
    with open(local_files[1], 'rb') as f:
        train_labels = extract_labels(f)
    with open(local_files[2], 'rb') as f:
        test_images = extract_images(f)
    with open(local_files[3], 'rb') as f:
        test_labels = extract_labels(f)

    if not 0 <= validation_size <= len(train_images):
        raise ValueError(
            'Validation size should be between 0 and {}. Received: {}.'.format(
                len(train_images), validation_size))

    validation_images = train_images[:validation_size]
    validation_labels = train_labels[:validation_size]
    train_images = train_images[validation_size:]
    train_labels = train_labels[validation_size:]

    train_images = train_images.astype(np.float32) / 255
    validation_images = validation_images.astype(np.float32) / 255
    test_images = test_images.astype(np.float32) / 255

    train = DataSet(train_images, train_labels)
    validation = DataSet(validation_images, validation_labels)
    test = DataSet(test_images, test_labels)
    return base.Datasets(train=train, validation=validation, test=test)
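# Illustrative usage sketch (added for documentation; not part of the original
# source). It only calls load_mnist as defined above; the shapes in the
# comments assume the standard MNIST archives.
def _example_load_mnist():
    data_sets = load_mnist(validation_size=5000)
    # extract_images yields (N, 28, 28, 1) arrays; load_mnist rescales them to
    # float32 in [0, 1] and carves the validation split off the front of the
    # training set.
    print(data_sets.train.x.shape)       # (55000, 28, 28, 1)
    print(data_sets.validation.x.shape)  # (5000, 28, 28, 1)
    print(data_sets.test.x.shape)        # (10000, 28, 28, 1)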
def get_indiv_grad_loss_from_total_grad(self, dataset, **kwargs):
    # Evaluate the total gradient one example at a time (batch size 1) so the
    # per-example loss gradients can be collected individually; the reducer
    # concatenates the per-batch lists.
    indiv_grad_loss = self.batch_evaluate(
        lambda xs, labels: [
            self.get_total_grad_loss(DataSet(xs, labels), l2_reg=0, **kwargs)
        ],
        lambda v1, v2: v1.extend(v2) or v1,
        1, dataset, value_name="Gradients")
    return np.array(indiv_grad_loss)
def load_spam(truncate=None, data_dir=None):
    dataset_dir = get_dataset_dir('spam', data_dir=data_dir)
    spam_path = os.path.join(dataset_dir,
                             'spam_truncate-{}.npz'.format(truncate))

    if not os.path.exists(spam_path):
        SPAM_URL = "http://www.aueb.gr/users/ion/data/enron-spam/preprocessed/enron1.tar.gz"
        raw_spam_path = maybe_download(SPAM_URL, 'enron1.tar.gz', dataset_dir)

        print("Extracting {}".format(raw_spam_path))
        with tarfile.open(raw_spam_path, 'r:gz') as tarf:
            tarf.extractall(path=dataset_dir)

        print("Processing spam")
        X_train, Y_train, X_valid, Y_valid, X_test, Y_test = process_spam(
            dataset_dir, truncate)

        # Convert them to dense matrices
        X_train = X_train.toarray()
        X_valid = X_valid.toarray()
        X_test = X_test.toarray()

        np.savez(spam_path,
                 X_train=X_train, Y_train=Y_train,
                 X_test=X_test, Y_test=Y_test,
                 X_valid=X_valid, Y_valid=Y_valid)
    else:
        data = np.load(spam_path)
        X_train = data['X_train']
        Y_train = data['Y_train']
        X_test = data['X_test']
        Y_test = data['Y_test']
        X_valid = data['X_valid']
        Y_valid = data['Y_valid']

    train = DataSet(X_train, Y_train)
    validation = DataSet(X_valid, Y_valid)
    test = DataSet(X_test, Y_test)
    return base.Datasets(train=train, validation=validation, test=test)
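# Illustrative usage sketch (added; not part of the original source). The
# first call downloads and processes the Enron spam archive, then caches the
# dense matrices in spam_truncate-<truncate>.npz; later calls read the cache.
def _example_load_spam():
    data_sets = load_spam(truncate=None)
    print(data_sets.train.x.shape, data_sets.train.labels.shape)
    print(data_sets.validation.x.shape, data_sets.test.x.shape)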
def get_dataset(self, dataset_id=None):
    dataset_id = dataset_id if dataset_id is not None else self.dataset_id
    if not hasattr(self, 'datasets'):
        self.datasets = dict()
    if dataset_id not in self.datasets:
        ds_keys = [
            '{}_{}'.format(dataset_id, key)
            for key in ('train_X', 'train_Y', 'test_X', 'test_Y')
        ]
        if any(ds_key not in self.R for ds_key in ds_keys):
            raise ValueError(
                'Dataset {} has not been generated'.format(dataset_id))
        train_X, train_Y, test_X, test_Y = [
            self.R[ds_key] for ds_key in ds_keys
        ]
        train = DataSet(train_X, train_Y)
        test = DataSet(test_X, test_Y)
        self.datasets[dataset_id] = base.Datasets(train=train,
                                                  test=test,
                                                  validation=None)
    return self.datasets[dataset_id]
def load_animals(data_dir=None):
    dataset_dir = get_dataset_dir('processed_animals', data_dir=data_dir)
    BASE_URL = "http://mitra.stanford.edu/kundaje/pangwei/"
    TRAIN_FILE_NAME = "animals_900_300_inception_features_train.npz"
    TEST_FILE_NAME = "animals_900_300_inception_features_test.npz"

    train_path = maybe_download(BASE_URL + TRAIN_FILE_NAME,
                                TRAIN_FILE_NAME, dataset_dir)
    test_path = maybe_download(BASE_URL + TEST_FILE_NAME,
                               TEST_FILE_NAME, dataset_dir)

    data_train = np.load(train_path)
    data_test = np.load(test_path)
    X_train = data_train['inception_features_val']
    Y_train = data_train['labels'].astype(np.uint8)
    X_test = data_test['inception_features_val']
    Y_test = data_test['labels'].astype(np.uint8)

    train = DataSet(X_train, Y_train)
    test = DataSet(X_test, Y_test)
    return base.Datasets(train=train, validation=None, test=test)
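# Illustrative usage sketch (added; not part of the original source). The
# animals dataset ships as precomputed Inception features, so x is a 2-D
# feature matrix rather than raw images, and there is no validation split.
def _example_load_animals():
    data_sets = load_animals()
    print(data_sets.train.x.shape, data_sets.train.labels.dtype)  # labels are uint8
    print(data_sets.validation)  # None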
def load_mnli(data_dir=None, non_tf=False):
    dataset_dir = get_dataset_dir('multinli_1.0', data_dir=data_dir)
    path = os.path.join(dataset_dir, 'mnli.npz')

    print('Loading mnli from {}.'.format(path))
    data = np.load(path, allow_pickle=True)
    labels = [data['lab0'], [], data['lab2']]
    x = [data['x0'], [], data['x2']]

    # Keep only the first 600 features; the validation split is left empty.
    mask = range(600)
    x[0] = np.array(x[0])[:, mask]
    x[1] = np.array([], dtype=np.float32)
    labels[1] = np.array([], dtype=np.float32)
    x[2] = np.array(x[2])[:, mask]

    train = DataSet(x[0], labels[0])
    validation = DataSet(x[1], labels[1])
    test = DataSet(x[2], labels[2])
    print('Loaded mnli.')

    if non_tf:
        return (train, validation, test)
    return base.Datasets(train=train, validation=validation, test=test)
def load_small_cifar10(validation_size=1000, random_seed=0, data_dir=None):
    dataset_dir = get_dataset_dir('cifar10', data_dir=data_dir)
    data_sets = load_cifar10(validation_size, data_dir=data_dir)
    rng = np.random.RandomState(random_seed)

    # Keep a random 10% of the training set.
    train_images = data_sets.train.x
    train_labels = data_sets.train.labels
    perm = np.arange(len(train_labels))
    rng.shuffle(perm)
    num_to_keep = int(len(train_labels) / 10)
    perm = perm[:num_to_keep]
    train_images = train_images[perm, :]
    train_labels = train_labels[perm]

    validation_images = data_sets.validation.x
    validation_labels = data_sets.validation.labels
    test_images = data_sets.test.x
    test_labels = data_sets.test.labels

    train = DataSet(train_images, train_labels)
    validation = DataSet(validation_images, validation_labels)
    test = DataSet(test_images, test_labels)
    return base.Datasets(train=train, validation=validation, test=test)
def load_mnli_nonfires(data_dir=None):
    dataset_dir = get_dataset_dir('multinli_1.0', data_dir=data_dir)
    path = os.path.join(dataset_dir, 'mnli_nonfires.npz')

    print('Loading nonfires from {}.'.format(path))
    data = np.load(path, allow_pickle=True)
    x = np.array(data['x'])
    labels = np.array(data['labels'])

    # Keep only the first 600 features, matching load_mnli.
    mask = range(600)
    x = x[:, mask]

    nonfires = DataSet(x, labels)
    print('Loaded mnli nonfires.')
    return nonfires
def compute_cex_test_infl_batch(self, subset_start, subset_end):
    self.load_phases([0, 1, 2, 3, 4], verbose=False)

    start_time = time.time()
    ds_test = DataSet(self.R['cex_X'], self.R['cex_Y'])
    res = dict(('cex_' + key, value)
               for key, value in self.compute_test_influence(
                   subset_start, subset_end, ds_test).items())
    end_time = time.time()

    time_per_subset = (end_time - start_time) / (subset_end - subset_start)
    remaining_time = (len(self.R['subset_indices']) - subset_end) * time_per_subset
    print('Each subset takes {} s, {} s remaining'.format(
        time_per_subset, remaining_time))
    return res
def load_hospital(data_dir=None):
    dataset_dir = get_dataset_dir('hospital', data_dir=data_dir)
    hospital_path = os.path.join(dataset_dir, 'hospital.npz')

    if not os.path.exists(hospital_path):
        HOSPITAL_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/00296/dataset_diabetes.zip"
        raw_path = maybe_download(HOSPITAL_URL, 'dataset_diabetes.zip', dataset_dir)
        with zipfile.ZipFile(raw_path, 'r') as zipf:
            zipf.extractall(path=dataset_dir)
        csv_path = os.path.join(dataset_dir, 'dataset_diabetes', 'diabetic_data.csv')
        df = pd.read_csv(csv_path)

        # Convert categorical variables into numeric ones
        rng = np.random.RandomState(2)
        X = pd.DataFrame()

        # Numerical variables that we can pull directly
        X = df.loc[:, [
            'time_in_hospital', 'num_lab_procedures', 'num_procedures',
            'num_medications', 'number_outpatient', 'number_emergency',
            'number_inpatient', 'number_diagnoses'
        ]]

        categorical_var_names = [
            'gender', 'race', 'age', 'discharge_disposition_id',
            'max_glu_serum', 'A1Cresult', 'metformin', 'repaglinide',
            'nateglinide', 'chlorpropamide', 'glimepiride', 'acetohexamide',
            'glipizide', 'glyburide', 'tolbutamide', 'pioglitazone',
            'rosiglitazone', 'acarbose', 'miglitol', 'troglitazone',
            'tolazamide', 'examide', 'citoglipton', 'insulin',
            'glyburide-metformin', 'glipizide-metformin',
            'glimepiride-pioglitazone', 'metformin-rosiglitazone',
            'metformin-pioglitazone', 'change', 'diabetesMed'
        ]
        for categorical_var_name in categorical_var_names:
            categorical_var = pd.Categorical(df.loc[:, categorical_var_name])

            # Just have one dummy variable if it's boolean
            if len(categorical_var.categories) == 2:
                drop_first = True
            else:
                drop_first = False

            dummies = pd.get_dummies(categorical_var,
                                     prefix=categorical_var_name,
                                     drop_first=drop_first)
            X = pd.concat([X, dummies], axis=1)

        ### Set the Y labels
        readmitted = pd.Categorical(df.readmitted)
        Y = np.copy(readmitted.codes)

        # Combine '>30' and 'NO' and flip labels, so codes 1 ('>30') and
        # 2 ('NO') become -1, while code 0 ('<30') becomes 1.
        Y[Y >= 1] = -1
        Y[Y == 0] = 1

        # Map to feature names
        feature_names = X.columns.values

        ### Find indices of age features
        age_var = pd.Categorical(df.loc[:, 'age'])
        age_var_names = [
            'age_%s' % age_var_name for age_var_name in age_var.categories
        ]
        age_var_indices = []
        for age_var_name in age_var_names:
            age_var_indices.append(
                np.where(X.columns.values == age_var_name)[0][0])
        age_var_indices = np.array(age_var_indices, dtype=int)

        ### Split into training and test sets.
        # For convenience, we balance the training set to have
        # 10k positives and 10k negatives.
        num_examples = len(Y)
        assert X.shape[0] == num_examples
        num_train_examples = 20000
        num_train_examples_per_class = int(num_train_examples / 2)
        num_test_examples = num_examples - num_train_examples
        assert num_test_examples > 0

        pos_idx = np.where(Y == 1)[0]
        neg_idx = np.where(Y == -1)[0]
        rng.shuffle(pos_idx)
        rng.shuffle(neg_idx)
        assert len(pos_idx) + len(neg_idx) == num_examples

        train_idx = np.concatenate((pos_idx[:num_train_examples_per_class],
                                    neg_idx[:num_train_examples_per_class]))
        test_idx = np.concatenate((pos_idx[num_train_examples_per_class:],
                                   neg_idx[num_train_examples_per_class:]))
        rng.shuffle(train_idx)
        rng.shuffle(test_idx)

        X_train = np.array(X.iloc[train_idx, :], dtype=np.float32)
        Y_train = Y[train_idx]
        X_test = np.array(X.iloc[test_idx, :], dtype=np.float32)
        Y_test = Y[test_idx]

        # 0/1 labels for logistic regression
        lr_Y_train = np.array((Y_train + 1) / 2, dtype=int)
        lr_Y_test = np.array((Y_test + 1) / 2, dtype=int)

        # test_children_idx = np.where(X_test[:, age_var_indices[0]] == 1)[0]

        np.savez(hospital_path,
                 X_train=X_train, Y_train=Y_train, lr_Y_train=lr_Y_train,
                 X_test=X_test, Y_test=Y_test, lr_Y_test=lr_Y_test)
    else:
        data = np.load(hospital_path)
        X_train = data['X_train']
        Y_train = data['Y_train']
        lr_Y_train = data['lr_Y_train']
        X_test = data['X_test']
        Y_test = data['Y_test']
        lr_Y_test = data['lr_Y_test']

    train = DataSet(X_train, Y_train)
    validation = None
    test = DataSet(X_test, Y_test)
    data_sets = base.Datasets(train=train, validation=validation, test=test)

    lr_train = DataSet(X_train, lr_Y_train)
    lr_validation = None
    lr_test = DataSet(X_test, lr_Y_test)
    lr_data_sets = base.Datasets(train=lr_train,
                                 validation=lr_validation,
                                 test=lr_test)
    return lr_data_sets
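# Illustrative usage sketch (added; not part of the original source).
# load_hospital returns the logistic-regression variant of the dataset, i.e.
# labels already mapped from {-1, 1} to {0, 1}, with a class-balanced training
# set of 20,000 examples and no validation split. num_features in the comment
# below is a placeholder for the width of the one-hot-encoded feature matrix.
def _example_load_hospital():
    data_sets = load_hospital()
    print(data_sets.train.x.shape)            # (20000, num_features)
    print(np.unique(data_sets.train.labels))  # [0 1]
    print(data_sets.validation)               # None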
def load_nonfires(ds_name, source_url, data_dir=None, force_refresh=False):
    dataset_dir = get_dataset_dir('babble', data_dir=data_dir)
    path = os.path.join(dataset_dir, '{}_nonfires'.format(ds_name))

    if not os.path.exists(path + 'labels.npz') or force_refresh:
        raw_path = maybe_download(source_url, '{}.db'.format(ds_name), dataset_dir)
        print('Extracting {}.'.format(raw_path))

        conn = sqlite3.connect(raw_path)  # ex_id = candidate_id
        LF_labels = pd.read_sql_query("select * from label;", conn)
        # label, ex_id, LF_id
        splits = pd.read_sql_query("select id, split from candidate;", conn)
        # ex_id, train/dev/test split (0-2)
        features = pd.read_sql_query("select * from feature;", conn)
        # feature value, ex_id, feature_id
        gold_labels = pd.read_sql_query(
            "select value, candidate_id from gold_label;", conn)
        # gold, ex_id; all values are probably 1.0
        conn.close()

        split_ids = [splits['id'][splits['split'] == i] for i in range(3)]
        # ex_ids in each split
        ids_dups = np.array(LF_labels['candidate_id'])  # ex_id for each LF_ex
        which_split = [np.isin(ids_dups, split_ids[i]) for i in range(3)]
        # which LF_ex in each split
        print('Extracted labels.')

        # Look through all LF labels and record seen candidate ids.
        # Look through all candidates and find the non-seen ones.
        # Collect the features of those.
        _num_features = max(features['key_id'])

        print('Finding nonfires.')
        LF_labeled = set()
        for _, ex_id, _ in tqdm(np.array(LF_labels)):
            LF_labeled.add(ex_id)

        print('Collecting labels.')
        labels = []
        ex_id_to_ind = {}
        count = 0
        for lab, ex_id in tqdm(np.array(gold_labels)):
            if ex_id not in LF_labeled:
                labels.append((lab + 1) / 2)
                ex_id_to_ind[ex_id] = count
                count += 1
        labels = np.array(labels)
        _num_ex = count

        print('Collecting features.')
        x = dok_matrix((_num_ex, _num_features))
        for val, ex_id, f_id in tqdm(np.array(features)):
            if ex_id not in LF_labeled:
                x[ex_id_to_ind[ex_id], f_id - 1] = val
        x = x.astype(np.float32)

        save_npz(path + 'x.npz', coo_matrix(x))
        np.savez(path + 'labels.npz', labels=[labels])
    else:
        print('Loading {} nonfires.'.format(ds_name))
        data = np.load(path + 'labels.npz', allow_pickle=True)
        labels = data['labels'][0]
        x = load_npz(path + 'x.npz')

    nonfires = DataSet(x, labels)
    print('Loaded {} nonfires.'.format(ds_name))
    return nonfires
def load(ds_name, source_url, validation_size=None, data_dir=None, force_refresh=False):
    dataset_dir = get_dataset_dir('babble', data_dir=data_dir)
    path = os.path.join(dataset_dir, '{}'.format(ds_name))

    if not os.path.exists(path + 'labels.npz') or force_refresh:
        raw_path = maybe_download(source_url, '{}.db'.format(ds_name), dataset_dir)
        print('Extracting {}.'.format(raw_path))

        conn = sqlite3.connect(raw_path)  # ex_id = candidate_id
        LF_labels = pd.read_sql_query("select * from label;", conn)
        # label, ex_id, LF_id
        features = pd.read_sql_query("select * from feature;", conn)
        # feature value, ex_id, feature_id; all values are probably 1.0
        splits = pd.read_sql_query("select id, split from candidate;", conn)
        # ex_id, train/dev/test split (0-2)
        start_test_ind = np.min(
            splits['id'][splits['split'] == 2])  # not the 1-indexing
        test_gold_labels = pd.read_sql_query(
            "select value, candidate_id from gold_label where candidate_id>{} "
            "order by candidate_id asc;".format(start_test_ind - 1),
            conn)  # gold, ex_id
        conn.close()

        split_ids = [splits['id'][splits['split'] == i] for i in range(3)]
        # ex_ids in each split
        ids_dups = np.array(LF_labels['candidate_id'])  # ex_id for each LF_ex
        which_split = [np.isin(ids_dups, split_ids[i]) for i in range(3)]
        # which LF_ex in each split
        labels = [
            np.array(LF_labels['value'][which_split[i]]) for i in range(3)
        ]  # label for each LF_ex
        print('Extracted labels.')

        _num_features = max(features['key_id'])
        _num_ex = splits.shape[0]
        _num_LF_ex = [labels[i].shape[0] for i in range(3)]

        print('Creating map from examples to sparse features.')
        last_seen = features['candidate_id'][0]
        count = 0
        ex_id_to_ind = {last_seen: count}  # in case the ex_id aren't consecutive
        ind_to_features = [dict() for i in range(_num_ex + 1)]
        for val, ex_id, key_id in tqdm(np.array(features)):
            if ex_id > last_seen:
                count += 1
                last_seen = ex_id
                ex_id_to_ind[last_seen] = count
            ind_to_features[count][key_id] = val

        print('Creating sparse feature matrices for LF examples.')
        x = [dok_matrix((_num_LF_ex[i], _num_features)) for i in range(3)]
        counts = [0 for i in range(3)]
        for _, ex_id, _ in tqdm(np.array(LF_labels)):
            split = splits['split'][ex_id_to_ind[ex_id]]
            for key_id, val in ind_to_features[ex_id_to_ind[ex_id]].items():
                x[split][counts[split], key_id - 1] = val
            counts[split] += 1
        print('Extracted feature matrices.')

        print('Reverting test things to gold.')
        _num_test = sum(splits['split'] == 2)
        x[2] = dok_matrix((_num_test, _num_features))
        labels[2] = np.array(test_gold_labels['value'])
        count = 0
        for ex_id, split in tqdm(np.array(splits)):
            if split == 2:
                for key_id, val in ind_to_features[ex_id_to_ind[ex_id]].items():
                    x[2][count, key_id - 1] = val
                count += 1

        labels = [((labels[i] + 1) / 2) for i in range(3)]  # convert (-1,1) to (0,1)

        for i in range(3):
            save_npz(path + 'x{}.npz'.format(i), coo_matrix(x[i]))
        np.savez(path + 'labels.npz', labels=labels)
    else:
        print('Loading {}.'.format(ds_name))
        data = np.load(path + 'labels.npz', allow_pickle=True)
        labels = data['labels']
        x = []
        for i in range(3):
            x.append(load_npz(path + 'x{}.npz'.format(i)))

    train = DataSet(x[0], labels[0])
    validation = DataSet(x[1], labels[1])
    test = DataSet(x[2], labels[2])
    print('Loaded {}.'.format(ds_name))
    return base.Datasets(train=train, validation=validation, test=test)
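# Illustrative usage sketch (added; not part of the original source). The
# split column of the candidate table is 0=train, 1=dev, 2=test; load()
# replaces the test split's LF labels with gold labels and rescales all labels
# from (-1, 1) to (0, 1). The 'cdr' dataset name is assumed here (consistent
# with the cdrlabels.npz file used by load_reduced), and source_url must point
# at the sqlite .db file with the candidate/label/feature/gold_label tables
# that load() expects.
def _example_load_babble(source_url):
    data_sets = load('cdr', source_url)
    print(data_sets.train.x.shape)           # scipy sparse matrix of LF examples
    print(np.unique(data_sets.test.labels))  # gold labels in {0, 1}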
def load_cifar10(validation_size=1000, data_dir=None):
    dataset_dir = get_dataset_dir('cifar10', data_dir=data_dir)
    cifar10_path = os.path.join(dataset_dir, 'cifar10.npz')

    img_size = 32
    num_channels = 3
    img_size_flat = img_size * img_size * num_channels
    num_classes = 10
    _num_files_train = 5
    _images_per_file = 10000
    _num_images_train = _num_files_train * _images_per_file

    if not os.path.exists(cifar10_path):
        def _unpickle(f):
            f = os.path.join(dataset_dir, 'cifar-10-batches-py', f)
            print("Loading data: " + f)
            with open(f, 'rb') as fo:
                # Note: under Python 3 the CIFAR-10 batches (Python 2 pickles)
                # need encoding='bytes'; the byte-string keys below assume this.
                data_dict = pickle.load(fo, encoding='bytes')
            return data_dict

        def _convert_images(raw):
            raw_float = np.array(raw, dtype=float) / 255.0
            images = raw_float.reshape([-1, num_channels, img_size, img_size])
            images = images.transpose([0, 2, 3, 1])
            return images

        def _load_data(f):
            data = _unpickle(f)
            raw = data[b'data']
            labels = np.array(data[b'labels'])
            images = _convert_images(raw)
            return images, labels

        def _load_class_names():
            raw = _unpickle(f='batches.meta')[b'label_names']
            return [x.decode('utf-8') for x in raw]

        def _load_training_data():
            images = np.zeros(shape=[_num_images_train, img_size, img_size,
                                     num_channels], dtype=float)
            labels = np.zeros(shape=[_num_images_train], dtype=int)
            begin = 0
            for i in range(_num_files_train):
                images_batch, labels_batch = _load_data(f='data_batch_' + str(i + 1))
                end = begin + len(images_batch)
                images[begin:end, :] = images_batch
                labels[begin:end] = labels_batch
                begin = end
            return images, labels, _one_hot_encoded(class_numbers=labels,
                                                    num_classes=num_classes)

        def _load_test_data():
            images, labels = _load_data(f='test_batch')
            return images, labels, _one_hot_encoded(class_numbers=labels,
                                                    num_classes=num_classes)

        SOURCE_URL = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
        raw_cifar10_path = maybe_download(SOURCE_URL, 'cifar-10-python.tar.gz',
                                          dataset_dir)
        print('Extracting {}.'.format(raw_cifar10_path))
        with tarfile.open(raw_cifar10_path, 'r:gz') as tarf:
            tarf.extractall(path=dataset_dir)

        # no one-hot encoding of labels
        train_images, train_labels, _ = _load_training_data()
        test_images, test_labels, _ = _load_test_data()
        names = _load_class_names()

        validation_images = train_images[:validation_size]
        validation_labels = train_labels[:validation_size]
        train_images = train_images[validation_size:]
        train_labels = train_labels[validation_size:]

        np.savez(cifar10_path,
                 train_images=train_images,
                 train_labels=train_labels,
                 validation_images=validation_images,
                 validation_labels=validation_labels,
                 test_images=test_images,
                 test_labels=test_labels)
    else:
        data = np.load(cifar10_path)
        train_images = data['train_images']
        train_labels = data['train_labels']
        validation_images = data['validation_images']
        validation_labels = data['validation_labels']
        test_images = data['test_images']
        test_labels = data['test_labels']

    train = DataSet(train_images, train_labels)
    validation = DataSet(validation_images, validation_labels)
    test = DataSet(test_images, test_labels)
    return base.Datasets(train=train, validation=validation, test=test)
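# Illustrative usage sketch (added; not part of the original source). The
# first call downloads cifar-10-python.tar.gz, converts the batches, and
# caches everything in cifar10.npz; the one-hot encodings computed by the
# helpers are discarded, so labels are integer class ids.
def _example_load_cifar10():
    data_sets = load_cifar10(validation_size=1000)
    print(data_sets.train.x.shape)      # (49000, 32, 32, 3)
    print(data_sets.test.x.shape)       # (10000, 32, 32, 3)
    print(data_sets.train.labels[:10])  # integer labels in [0, 9]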