Example #1
0
def load_reduced(ds_name, validation=None, data_dir=None, force_refresh=False):
    """Load the dense "reduced" train/validation/test splits of a babble dataset.

    On the first call (or when ``force_refresh`` is set) the splits are built
    from ``reduced_{ds_name}_x.npz`` and ``cdrlabels.npz`` inside the babble
    dataset directory and cached to ``reduced_{ds_name}.npz``; later calls
    read that cache.

    Args:
        ds_name: Name used to build the file names inside the dataset dir.
        validation: Unused; kept only for backward compatibility with callers.
        data_dir: Optional root directory passed to ``get_dataset_dir``.
        force_refresh: Rebuild the cache even if it already exists.

    Returns:
        ``base.Datasets`` whose ``train``, ``validation`` and ``test`` splits
        are entries 0, 1 and 2 of the stored features and labels.

    Raises:
        AssertionError: If the cache has to be built but
            ``reduced_{ds_name}_x.npz`` does not exist.
    """
    dataset_dir = get_dataset_dir('babble', data_dir=data_dir)
    red_path = os.path.join(dataset_dir, 'reduced_{}.npz'.format(ds_name))

    if not os.path.exists(red_path) or force_refresh:
        print('Building reduced {}.'.format(ds_name))
        x_path = os.path.join(dataset_dir, 'reduced_{}_x.npz'.format(ds_name))
        assert os.path.exists(x_path)

        with np.load(x_path, allow_pickle=True) as x_file:
            reduced = x_file['reduced_x']
        # .tolist() unwraps each stored object to something exposing
        # .todense() (presumably a scipy sparse matrix); densify each split.
        x = [np.array(mat.tolist().todense()) for mat in reduced]
        with np.load(os.path.join(dataset_dir, 'cdrlabels.npz'),
                     allow_pickle=True) as label_file:
            labels = label_file['labels']

        # The splits generally have different row counts. Store them in an
        # explicit 1-D object array: letting np.savez coerce a ragged list
        # raises ValueError on NumPy >= 1.24.
        x_store = np.empty(len(x), dtype=object)
        for i, mat in enumerate(x):
            x_store[i] = mat
        np.savez(red_path, labels=labels, x=x_store)
    else:
        print('Retrieving reduced {}.'.format(ds_name))
        # NpzFile keeps the archive open until it is closed.
        with np.load(red_path, allow_pickle=True) as data:
            x = data['x']
            labels = data['labels']

    train = DataSet(x[0], labels[0])
    validation = DataSet(x[1], labels[1])
    test = DataSet(x[2], labels[2])

    print('Loaded reduced {}.'.format(ds_name))
    return base.Datasets(train=train, validation=validation, test=test)
Example #2
0
def load_small_mnist(validation_size=5000, random_seed=0, data_dir=None):
    """Load MNIST with the training split reduced to a random 10%.

    The subset is drawn with ``np.random.RandomState(random_seed)`` from the
    training split returned by ``load_mnist(validation_size)``. The validation
    and test splits are kept whole. The result is cached in
    ``mnist_small_val-{validation_size}_seed-{random_seed}.npz``.

    Args:
        validation_size: Forwarded to ``load_mnist``; also part of the cache name.
        random_seed: Seed of the subsampling RNG; also part of the cache name.
        data_dir: Optional root directory passed to ``get_dataset_dir``.

    Returns:
        ``base.Datasets`` with ``train``, ``validation`` and ``test`` splits.
    """
    dataset_dir = get_dataset_dir('mnist', data_dir=data_dir)
    mnist_small_file = 'mnist_small_val-{}_seed-{}.npz'.format(
        validation_size, random_seed)
    mnist_small_path = os.path.join(dataset_dir, mnist_small_file)

    if not os.path.exists(mnist_small_path):
        rng = np.random.RandomState(seed=random_seed)
        data_sets = load_mnist(validation_size, data_dir=data_dir)

        # Keep a random tenth of the training examples.
        train_images = data_sets.train.x
        train_labels = data_sets.train.labels
        perm = np.arange(len(train_labels))
        rng.shuffle(perm)
        perm = perm[:len(train_labels) // 10]
        train_images = train_images[perm, :]
        train_labels = train_labels[perm]

        # Validation and test splits are not subsampled.
        validation_images = data_sets.validation.x
        validation_labels = data_sets.validation.labels
        test_images = data_sets.test.x
        test_labels = data_sets.test.labels

        np.savez(mnist_small_path,
                 train_images=train_images,
                 train_labels=train_labels,
                 validation_images=validation_images,
                 validation_labels=validation_labels,
                 test_images=test_images,
                 test_labels=test_labels)
    else:
        # NpzFile keeps the archive open until it is closed.
        with np.load(mnist_small_path) as data:
            train_images = data['train_images']
            train_labels = data['train_labels']
            validation_images = data['validation_images']
            validation_labels = data['validation_labels']
            test_images = data['test_images']
            test_labels = data['test_labels']

    train = DataSet(train_images, train_labels)
    validation = DataSet(validation_images, validation_labels)
    test = DataSet(test_images, test_labels)

    return base.Datasets(train=train, validation=validation, test=test)
Example #3
0
def load_reduced_nonfires(ds_name,
                          source_url,
                          data_dir=None,
                          force_refresh=False):
    """Load the dense "reduced" non-fired examples of a babble dataset.

    On the first call (or when ``force_refresh`` is set) the features are read
    from ``reduced_{ds_name}_nonfires_x.npz`` and the labels from
    ``cdr_nonfireslabels.npz``, then cached to
    ``reduced_{ds_name}_nonfires.npz``; later calls read that cache.

    Args:
        ds_name: Name used to build the file names inside the dataset dir.
        source_url: Unused; kept for backward compatibility with callers.
        data_dir: Optional root directory passed to ``get_dataset_dir``.
        force_refresh: Rebuild the cache even if it already exists.

    Returns:
        ``DataSet`` of the non-fired examples with an ``np.ndarray`` feature
        matrix on both the build and the cached path.

    Raises:
        AssertionError: If the cache has to be built but the source
            ``reduced_{ds_name}_nonfires_x.npz`` file does not exist.
    """
    dataset_dir = get_dataset_dir('babble', data_dir=data_dir)
    red_path = os.path.join(dataset_dir,
                            'reduced_{}_nonfires.npz'.format(ds_name))

    if not os.path.exists(red_path) or force_refresh:
        print('Building reduced {} nonfires.'.format(ds_name))
        x_path = os.path.join(dataset_dir,
                              'reduced_{}_nonfires_x.npz'.format(ds_name))
        assert os.path.exists(x_path)

        with np.load(x_path, allow_pickle=True) as x_file:
            # todense() yields an np.matrix; convert to a plain ndarray so a
            # freshly built dataset has the same type as one read from cache.
            x = np.asarray(x_file['reduced_x'][0].tolist().todense())
        with np.load(os.path.join(dataset_dir, 'cdr_nonfireslabels.npz'),
                     allow_pickle=True) as label_file:
            labels = label_file['labels'][0]

        np.savez(red_path, labels=labels, x=x)
    else:
        print('Retrieving reduced {} nonfires.'.format(ds_name))
        # NpzFile keeps the archive open until it is closed.
        with np.load(red_path, allow_pickle=True) as data:
            x = data['x']
            labels = data['labels']

    red_nonfires = DataSet(x, labels)

    print('Loaded reduced {} nonfires.'.format(ds_name))
    return red_nonfires
Example #4
0
    def compute_test_infl(self):
        """Collate batched counterexample influence and add test-influence products.

        Dispatches ``compute_cex_test_infl_batch`` through ``self.task_queue``
        over consecutive subset ranges of at most 256 subsets, merges the
        batch results, then adds four entries: the stored predicted parameter
        changes (first-order and Newton) multiplied by the per-example loss
        and margin gradients of the ``'initial'`` model on the
        counterexample set ``(self.R['cex_X'], self.R['cex_Y'])``.

        Returns:
            The collated result dict, extended with the
            ``cex_subset_test_*`` influence entries.
        """
        total_subsets = len(self.R['subset_indices'])
        batch_size = 256
        batch_ranges = []
        for lo in range(0, total_subsets, batch_size):
            batch_ranges.append((lo, min(lo + batch_size, total_subsets)))
        batch_results = self.task_queue.execute(
            'compute_cex_test_infl_batch', batch_ranges, force_refresh=True)

        collated = self.task_queue.collate_results(batch_results)

        cex_set = DataSet(self.R['cex_X'], self.R['cex_Y'])
        initial_model = self.get_model()
        initial_model.load('initial')
        loss_grads = initial_model.get_indiv_grad_loss(cex_set, verbose=False)
        margin_grads = initial_model.get_indiv_grad_margin(cex_set,
                                                           verbose=False)

        first_order = self.R['subset_pred_dparam']
        newton = self.R['subset_newton_pred_dparam']
        products = (
            ('cex_subset_test_pred_infl', first_order, loss_grads),
            ('cex_subset_test_pred_margin_infl', first_order, margin_grads),
            ('cex_subset_test_newton_pred_infl', newton, loss_grads),
            ('cex_subset_test_newton_pred_margin_infl', newton, margin_grads),
        )
        for out_key, dparam, grads in products:
            collated[out_key] = np.dot(dparam, grads.T)

        return collated
Example #5
0
def load_mnist(validation_size=5000, data_dir=None):
    """Download (if needed) and load MNIST with images scaled by 1/255.

    The first ``validation_size`` training examples are split off as the
    validation set. Images are cast to ``np.float32`` before scaling.

    Args:
        validation_size: Number of leading training examples used for
            validation.
        data_dir: Optional root directory passed to ``get_dataset_dir``.

    Returns:
        ``base.Datasets`` with ``train``, ``validation`` and ``test`` splits.

    Raises:
        ValueError: If ``validation_size`` is outside
            ``[0, number of training images]``.
    """
    SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
    archive_names = (
        'train-images-idx3-ubyte.gz',
        'train-labels-idx1-ubyte.gz',
        't10k-images-idx3-ubyte.gz',
        't10k-labels-idx1-ubyte.gz',
    )

    dataset_dir = get_dataset_dir('mnist', data_dir=data_dir)
    downloaded = [
        maybe_download(SOURCE_URL + name, name, dataset_dir)
        for name in archive_names
    ]
    train_img_path, train_lbl_path, test_img_path, test_lbl_path = downloaded

    def _read(file_path, extractor):
        # Open each archive in binary mode only for as long as parsing takes.
        with open(file_path, 'rb') as handle:
            return extractor(handle)

    all_train_images = _read(train_img_path, extract_images)
    all_train_labels = _read(train_lbl_path, extract_labels)
    test_images = _read(test_img_path, extract_images)
    test_labels = _read(test_lbl_path, extract_labels)

    num_train = len(all_train_images)
    if not 0 <= validation_size <= num_train:
        raise ValueError(
            'Validation size should be between 0 and {}. Received: {}.'.format(
                num_train, validation_size))

    validation_images = all_train_images[:validation_size].astype(np.float32) / 255
    validation_labels = all_train_labels[:validation_size]
    train_images = all_train_images[validation_size:].astype(np.float32) / 255
    train_labels = all_train_labels[validation_size:]
    test_images = test_images.astype(np.float32) / 255

    return base.Datasets(
        train=DataSet(train_images, train_labels),
        validation=DataSet(validation_images, validation_labels),
        test=DataSet(test_images, test_labels),
    )
Example #6
0
 def get_indiv_grad_loss_from_total_grad(self, dataset, **kwargs):
     """Collect one unregularized total-loss gradient per example of ``dataset``.

     ``self.batch_evaluate`` is called with ``1`` as its third positional
     argument (presumably the batch size; confirm against ``batch_evaluate``).
     For each batch it wraps ``get_total_grad_loss(..., l2_reg=0, **kwargs)``
     in a one-element list, and the lists are concatenated in order.

     Returns:
         ``np.array`` built from the concatenated gradients.
     """
     def _grad_for_batch(xs, labels):
         batch = DataSet(xs, labels)
         return [self.get_total_grad_loss(batch, l2_reg=0, **kwargs)]

     def _concatenate(accumulated, new_items):
         # list.extend mutates in place and returns None; hand back the list.
         accumulated.extend(new_items)
         return accumulated

     gradients = self.batch_evaluate(_grad_for_batch,
                                     _concatenate,
                                     1,
                                     dataset,
                                     value_name="Gradients")
     return np.array(gradients)
Example #7
0
def load_spam(truncate=None, data_dir=None):
    """Download (if needed) and load the Enron-1 spam dataset as dense arrays.

    The first call downloads ``enron1.tar.gz``, extracts it into the dataset
    directory, runs ``process_spam`` and caches the densified splits in
    ``spam_truncate-{truncate}.npz``. Later calls read that cache.

    Args:
        truncate: Forwarded to ``process_spam``; also part of the cache name.
        data_dir: Optional root directory passed to ``get_dataset_dir``.

    Returns:
        ``base.Datasets`` with ``train``, ``validation`` and ``test`` splits.
    """
    dataset_dir = get_dataset_dir('spam', data_dir=data_dir)
    spam_path = os.path.join(dataset_dir,
                             'spam_truncate-{}.npz'.format(truncate))

    if not os.path.exists(spam_path):
        SPAM_URL = "http://www.aueb.gr/users/ion/data/enron-spam/preprocessed/enron1.tar.gz"
        raw_spam_path = maybe_download(SPAM_URL, 'enron1.tar.gz', dataset_dir)

        print("Extracting {}".format(raw_spam_path))
        with tarfile.open(raw_spam_path, 'r:gz') as tarf:
            # The archive is downloaded over the network. Where the running
            # Python supports extraction filters, the 'data' filter refuses
            # members that would escape dataset_dir (absolute paths, '..').
            if hasattr(tarfile, 'data_filter'):
                tarf.extractall(path=dataset_dir, filter='data')
            else:
                tarf.extractall(path=dataset_dir)

        print("Processing spam")
        X_train, Y_train, X_valid, Y_valid, X_test, Y_test = process_spam(
            dataset_dir, truncate)

        # Convert them to dense matrices
        X_train = X_train.toarray()
        X_valid = X_valid.toarray()
        X_test = X_test.toarray()

        np.savez(spam_path,
                 X_train=X_train,
                 Y_train=Y_train,
                 X_test=X_test,
                 Y_test=Y_test,
                 X_valid=X_valid,
                 Y_valid=Y_valid)
    else:
        # NpzFile keeps the archive open until it is closed.
        with np.load(spam_path) as data:
            X_train = data['X_train']
            Y_train = data['Y_train']
            X_test = data['X_test']
            Y_test = data['Y_test']
            X_valid = data['X_valid']
            Y_valid = data['Y_valid']

    train = DataSet(X_train, Y_train)
    validation = DataSet(X_valid, Y_valid)
    test = DataSet(X_test, Y_test)

    return base.Datasets(train=train, validation=validation, test=test)
Example #8
0
 def get_dataset(self, dataset_id=None):
     """Return the memoized train/test ``base.Datasets`` for ``dataset_id``.

     On first request the splits are built from the ``{id}_train_X``,
     ``{id}_train_Y``, ``{id}_test_X`` and ``{id}_test_Y`` entries of
     ``self.R`` and stored in ``self.datasets`` (created on demand). The
     returned ``validation`` split is None.

     Args:
         dataset_id: Dataset identifier; defaults to ``self.dataset_id``.

     Raises:
         ValueError: If any of the four entries is missing from ``self.R``.
     """
     if dataset_id is None:
         dataset_id = self.dataset_id
     if not hasattr(self, 'datasets'):
         self.datasets = dict()
     if dataset_id not in self.datasets:
         ds_keys = [
             '{}_{}'.format(dataset_id, key)
             for key in ('train_X', 'train_Y', 'test_X', 'test_Y')
         ]
         missing = [ds_key for ds_key in ds_keys if ds_key not in self.R]
         if missing:
             raise ValueError('Dataset {} has not been generated '
                              '(missing {})'.format(dataset_id,
                                                    ', '.join(missing)))
         train_X, train_Y, test_X, test_Y = [
             self.R[ds_key] for ds_key in ds_keys
         ]
         train = DataSet(train_X, train_Y)
         test = DataSet(test_X, test_Y)
         self.datasets[dataset_id] = base.Datasets(train=train,
                                                   test=test,
                                                   validation=None)
     return self.datasets[dataset_id]
def load_animals(data_dir=None):
    """Download (if needed) and load precomputed animals Inception features.

    Two ``.npz`` files (train and test) are fetched from ``BASE_URL``; each
    provides the ``inception_features_val`` feature matrix and ``labels``,
    which are cast to ``np.uint8``.

    Args:
        data_dir: Optional root directory passed to ``get_dataset_dir``.

    Returns:
        ``base.Datasets`` with ``train`` and ``test`` splits and no
        validation split.
    """
    dataset_dir = get_dataset_dir('processed_animals', data_dir=data_dir)

    BASE_URL = "http://mitra.stanford.edu/kundaje/pangwei/"
    TRAIN_FILE_NAME = "animals_900_300_inception_features_train.npz"
    TEST_FILE_NAME = "animals_900_300_inception_features_test.npz"
    train_path = maybe_download(BASE_URL + TRAIN_FILE_NAME, TRAIN_FILE_NAME,
                                dataset_dir)
    test_path = maybe_download(BASE_URL + TEST_FILE_NAME, TEST_FILE_NAME,
                               dataset_dir)

    # NpzFile keeps the archive open until it is closed.
    with np.load(train_path) as data_train:
        X_train = data_train['inception_features_val']
        Y_train = data_train['labels'].astype(np.uint8)
    with np.load(test_path) as data_test:
        X_test = data_test['inception_features_val']
        Y_test = data_test['labels'].astype(np.uint8)

    train = DataSet(X_train, Y_train)
    test = DataSet(X_test, Y_test)
    return base.Datasets(train=train, validation=None, test=test)
Example #10
0
def load_mnli(data_dir=None, non_tf=False):
    """Load the preprocessed MultiNLI train/test features from ``mnli.npz``.

    Train comes from the ``x0``/``lab0`` arrays and test from ``x2``/``lab2``.
    Only the first 600 feature columns are kept. There is no validation data
    in the file, so the validation split is an empty float32 placeholder
    (note: 1-D, unlike the 2-D train/test feature matrices).

    Args:
        data_dir: Optional root directory passed to ``get_dataset_dir``.
        non_tf: If true, return a plain ``(train, validation, test)`` tuple
            instead of ``base.Datasets``.

    Returns:
        ``base.Datasets`` (or a tuple when ``non_tf``) of ``DataSet`` splits.
    """
    dataset_dir = get_dataset_dir('multinli_1.0', data_dir=data_dir)
    path = os.path.join(dataset_dir, 'mnli.npz')

    print('Loading mnli from {}.'.format(path))
    # NpzFile keeps the archive open until it is closed.
    with np.load(path, allow_pickle=True) as data:
        train_x, train_labels = data['x0'], data['lab0']
        test_x, test_labels = data['x2'], data['lab2']

    # Fancy indexing with the range keeps the first 600 columns as a copy.
    mask = range(600)
    train_x = np.array(train_x)[:, mask]
    test_x = np.array(test_x)[:, mask]

    train = DataSet(train_x, train_labels)
    validation = DataSet(np.array([], dtype=np.float32),
                         np.array([], dtype=np.float32))
    test = DataSet(test_x, test_labels)

    print('Loaded mnli.')

    if non_tf:
        return (train, validation, test)
    return base.Datasets(train=train, validation=validation, test=test)
def load_small_cifar10(validation_size=1000, random_seed=0, data_dir=None):
    """Load CIFAR-10 with the training split reduced to a random 10%.

    The subset is drawn with ``np.random.RandomState(random_seed)``.
    Validation and test splits from ``load_cifar10(validation_size)`` are
    kept whole. Nothing is cached by this function itself.

    Args:
        validation_size: Forwarded to ``load_cifar10``.
        random_seed: Seed of the subsampling RNG.
        data_dir: Optional root directory forwarded to the loaders.

    Returns:
        ``base.Datasets`` with ``train``, ``validation`` and ``test`` splits.
    """
    # The returned directory is not used here; the call is kept for any side
    # effects of get_dataset_dir (presumably creating the directory).
    get_dataset_dir('cifar10', data_dir=data_dir)

    full = load_cifar10(validation_size, data_dir=data_dir)
    rng = np.random.RandomState(random_seed)

    full_train_x = full.train.x
    full_train_y = full.train.labels
    n_train = len(full_train_y)
    keep = np.arange(n_train)
    rng.shuffle(keep)
    keep = keep[:int(n_train / 10)]

    train = DataSet(full_train_x[keep, :], full_train_y[keep])
    validation = DataSet(full.validation.x, full.validation.labels)
    test = DataSet(full.test.x, full.test.labels)
    return base.Datasets(train=train, validation=validation, test=test)
Example #12
0
def load_mnli_nonfires(data_dir=None):
    """Load the MultiNLI non-fired examples from ``mnli_nonfires.npz``.

    Only the first 600 feature columns of ``x`` are kept.

    Args:
        data_dir: Optional root directory passed to ``get_dataset_dir``.

    Returns:
        ``DataSet`` built from the truncated features and ``labels``.
    """
    dataset_dir = get_dataset_dir('multinli_1.0', data_dir=data_dir)
    path = os.path.join(dataset_dir, 'mnli_nonfires.npz')

    print('Loading nonfires from {}.'.format(path))
    # NpzFile keeps the archive open until it is closed.
    with np.load(path, allow_pickle=True) as data:
        x = np.array(data['x'])
        labels = np.array(data['labels'])

    # Fancy indexing with the range keeps the first 600 columns as a copy.
    mask = range(600)
    x = x[:, mask]

    nonfires = DataSet(x, labels)

    print('Loaded mnli nonfires.')
    return nonfires
Example #13
0
    def compute_cex_test_infl_batch(self, subset_start, subset_end):
        """Compute counterexample test influence for one batch of subsets.

        Loads phases 0-4, runs ``compute_test_influence`` on subsets
        ``subset_start`` to ``subset_end`` (presumably half-open, matching
        the ranges built by ``compute_test_infl``) against the counterexample
        set, and prints a per-subset timing / remaining-time estimate.

        Returns:
            dict of the influence results with every key prefixed by ``'cex_'``.
        """
        self.load_phases([0, 1, 2, 3, 4], verbose=False)

        start_time = time.time()

        ds_test = DataSet(self.R['cex_X'], self.R['cex_Y'])
        influence = self.compute_test_influence(subset_start, subset_end,
                                                ds_test)
        res = {'cex_' + key: value for key, value in influence.items()}

        elapsed = time.time() - start_time
        num_subsets = subset_end - subset_start
        # An empty range would otherwise raise ZeroDivisionError after all
        # the work above has been done; report a zero estimate instead.
        time_per_subset = elapsed / num_subsets if num_subsets > 0 else 0.0
        remaining_time = (len(self.R['subset_indices']) -
                          subset_end) * time_per_subset
        print('Each subset takes {} s, {} s remaining'.format(
            time_per_subset, remaining_time))

        return res
def load_hospital(data_dir=None):
    """Download (if needed) and load the UCI diabetes ("hospital") dataset.

    On first use the raw zip is downloaded and extracted. Selected numeric
    columns are kept and the listed categorical columns are one-hot encoded
    (a single dummy for two-category columns). The readmission category
    with code 0 becomes label +1 and all others -1. A training set of 10,000
    examples per class is drawn with ``np.random.RandomState(2)``; the rest
    becomes the test set. Everything is cached in ``hospital.npz``.

    Args:
        data_dir: Optional root directory passed to ``get_dataset_dir``.

    Returns:
        ``base.Datasets`` with ``train`` and ``test`` splits (no validation)
        whose labels are ``(Y + 1) / 2`` as ints, i.e. {0, 1}. The ±1 labels
        are cached as ``Y_train``/``Y_test`` but not returned.
    """
    dataset_dir = get_dataset_dir('hospital', data_dir=data_dir)
    hospital_path = os.path.join(dataset_dir, 'hospital.npz')

    if not os.path.exists(hospital_path):
        HOSPITAL_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/00296/dataset_diabetes.zip"

        raw_path = maybe_download(HOSPITAL_URL, 'dataset_diabetes.zip',
                                  dataset_dir)

        with zipfile.ZipFile(raw_path, 'r') as zipf:
            zipf.extractall(path=dataset_dir)

        csv_path = os.path.join(dataset_dir, 'dataset_diabetes',
                                'diabetic_data.csv')
        df = pd.read_csv(csv_path)

        rng = np.random.RandomState(2)

        # Numerical variables that we can pull directly
        numeric = df.loc[:, [
            'time_in_hospital', 'num_lab_procedures', 'num_procedures',
            'num_medications', 'number_outpatient', 'number_emergency',
            'number_inpatient', 'number_diagnoses'
        ]]

        # Convert categorical variables into numeric ones
        categorical_var_names = [
            'gender', 'race', 'age', 'discharge_disposition_id',
            'max_glu_serum', 'A1Cresult', 'metformin', 'repaglinide',
            'nateglinide', 'chlorpropamide', 'glimepiride', 'acetohexamide',
            'glipizide', 'glyburide', 'tolbutamide', 'pioglitazone',
            'rosiglitazone', 'acarbose', 'miglitol', 'troglitazone',
            'tolazamide', 'examide', 'citoglipton', 'insulin',
            'glyburide-metformin', 'glipizide-metformin',
            'glimepiride-pioglitazone', 'metformin-rosiglitazone',
            'metformin-pioglitazone', 'change', 'diabetesMed'
        ]
        dummy_frames = []
        for categorical_var_name in categorical_var_names:
            categorical_var = pd.Categorical(df.loc[:, categorical_var_name])
            # Just have one dummy variable if it's boolean
            dummy_frames.append(
                pd.get_dummies(categorical_var,
                               prefix=categorical_var_name,
                               drop_first=len(categorical_var.categories) == 2))
        # One concat instead of one per column group (which re-copies the
        # growing frame every iteration); column order is unchanged.
        X = pd.concat([numeric] + dummy_frames, axis=1)

        ### Set the Y labels
        readmitted = pd.Categorical(df.readmitted)

        Y = np.copy(readmitted.codes)

        # Combine >30 and 0 and flip labels, so 1 (>30) and 2 (No) become -1, while 0 becomes 1
        Y[Y >= 1] = -1
        Y[Y == 0] = 1

        ### Split into training and test sets.
        # For convenience, we balance the training set to have 10k positives and 10k negatives.
        num_examples = len(Y)
        assert X.shape[0] == num_examples
        num_train_examples = 20000
        num_train_examples_per_class = num_train_examples // 2
        num_test_examples = num_examples - num_train_examples
        assert num_test_examples > 0

        pos_idx = np.where(Y == 1)[0]
        neg_idx = np.where(Y == -1)[0]
        rng.shuffle(pos_idx)
        rng.shuffle(neg_idx)
        assert len(pos_idx) + len(neg_idx) == num_examples

        train_idx = np.concatenate((pos_idx[:num_train_examples_per_class],
                                    neg_idx[:num_train_examples_per_class]))
        test_idx = np.concatenate((pos_idx[num_train_examples_per_class:],
                                   neg_idx[num_train_examples_per_class:]))
        rng.shuffle(train_idx)
        rng.shuffle(test_idx)

        X_train = np.array(X.iloc[train_idx, :], dtype=np.float32)
        Y_train = Y[train_idx]

        X_test = np.array(X.iloc[test_idx, :], dtype=np.float32)
        Y_test = Y[test_idx]

        # Map the ±1 labels to {0, 1}.
        lr_Y_train = np.array((Y_train + 1) / 2, dtype=int)
        lr_Y_test = np.array((Y_test + 1) / 2, dtype=int)

        np.savez(hospital_path,
                 X_train=X_train,
                 Y_train=Y_train,
                 lr_Y_train=lr_Y_train,
                 X_test=X_test,
                 Y_test=Y_test,
                 lr_Y_test=lr_Y_test)
    else:
        # NpzFile keeps the archive open until it is closed. Only the
        # features and the {0, 1} labels are needed for the return value.
        with np.load(hospital_path) as data:
            X_train = data['X_train']
            lr_Y_train = data['lr_Y_train']
            X_test = data['X_test']
            lr_Y_test = data['lr_Y_test']

    lr_train = DataSet(X_train, lr_Y_train)
    lr_validation = None
    lr_test = DataSet(X_test, lr_Y_test)
    lr_data_sets = base.Datasets(train=lr_train,
                                 validation=lr_validation,
                                 test=lr_test)

    return lr_data_sets
Example #15
0
def load_nonfires(ds_name, source_url, data_dir=None, force_refresh=False):
    """Load the gold-labelled babble candidates that no labelling function fired on.

    On the first call (or with ``force_refresh``) ``{ds_name}.db`` is
    downloaded from ``source_url``. Every gold-labelled candidate that has
    no row in the ``label`` table is collected. Its label is
    ``(gold + 1) / 2`` ({-1, 1} -> {0, 1}, assuming ±1 gold values) and its
    features go into a sparse float32 matrix. Results are cached as
    ``{ds_name}_nonfiresx.npz`` and ``{ds_name}_nonfireslabels.npz``.

    Args:
        ds_name: Dataset name used for the database and cache file names.
        source_url: URL of the SQLite database.
        data_dir: Optional root directory passed to ``get_dataset_dir``.
        force_refresh: Rebuild the cache even if it already exists.

    Returns:
        ``DataSet`` of the non-fired examples. The features are a sparse COO
        matrix on both the build and the cached path.
    """
    dataset_dir = get_dataset_dir('babble', data_dir=data_dir)
    path = os.path.join(dataset_dir, '{}_nonfires'.format(ds_name))

    if not os.path.exists(path + 'labels.npz') or force_refresh:

        raw_path = maybe_download(source_url, '{}.db'.format(ds_name),
                                  dataset_dir)

        print('Extracting {}.'.format(raw_path))
        conn = sqlite3.connect(raw_path)
        try:
            # ex_id = candidate_id
            LF_labels = pd.read_sql_query("select * from label;",
                                          conn)  # label, ex_id, LF_id
            features = pd.read_sql_query(
                "select * from feature;",
                conn)  # feature value, ex_id, feature_id
            gold_labels = pd.read_sql_query(
                "select value, candidate_id from gold_label;",
                conn)  # gold, ex_id
        finally:
            # Close the connection even if a query fails.
            conn.close()
        print('Extracted labels.')

        _num_features = max(features['key_id'])

        print('Finding nonfires.')
        # Candidates that at least one labelling function fired on.
        LF_labeled = set(LF_labels['candidate_id'])

        print('Collecting labels.')
        labels = []
        ex_id_to_ind = {}  # candidate id -> row in the output matrix
        for lab, ex_id in tqdm(np.array(gold_labels)):
            if ex_id not in LF_labeled:
                ex_id_to_ind[ex_id] = len(labels)
                labels.append((lab + 1) / 2)
        labels = np.array(labels)
        _num_ex = len(labels)

        print('Collecting features.')
        x = dok_matrix((_num_ex, _num_features))
        for val, ex_id, f_id in tqdm(np.array(features)):
            # Only non-fired, gold-labelled candidates have a row; features
            # of any other candidate have nowhere to go.
            row = ex_id_to_ind.get(ex_id)
            if row is not None:
                # Feature ids are 1-based; the row array may be float-typed,
                # so cast the id back to an integer column index.
                x[row, int(f_id) - 1] = val
        # Convert to COO so this path returns the same sparse format the
        # cached path gets from load_npz.
        x = coo_matrix(x.astype(np.float32))

        save_npz(path + 'x.npz', x)
        np.savez(path + 'labels.npz', labels=[labels])
    else:
        print('Loading {} nonfires.'.format(ds_name))
        # NpzFile keeps the archive open until it is closed.
        with np.load(path + 'labels.npz', allow_pickle=True) as data:
            labels = data['labels'][0]
        x = load_npz(path + 'x.npz')

    nonfires = DataSet(x, labels)

    print('Loaded {} nonfires.'.format(ds_name))
    return nonfires
Example #16
0
def load(ds_name,
         source_url,
         validation_size=None,
         data_dir=None,
         force_refresh=False):
    """Load a babble dataset as sparse train/validation/test splits.

    On the first call (or with ``force_refresh``) ``{ds_name}.db`` is
    downloaded from ``source_url``. For splits 0 and 1 there is one row per
    labelling-function output, with its LF label. Split 2 (test) uses one
    row per test candidate with its gold label. Labels are mapped with
    ``(v + 1) / 2`` from (-1, 1) to (0, 1). Feature matrices are cached as
    ``{ds_name}x{i}.npz`` and labels as ``{ds_name}labels.npz``.

    Args:
        ds_name: Dataset name used for the database and cache file names.
        source_url: URL of the SQLite database.
        validation_size: Unused; kept for backward compatibility.
        data_dir: Optional root directory passed to ``get_dataset_dir``.
        force_refresh: Rebuild the cache even if it already exists.

    Returns:
        ``base.Datasets`` whose feature matrices are sparse COO matrices on
        both the build and the cached path.
    """
    dataset_dir = get_dataset_dir('babble', data_dir=data_dir)
    path = os.path.join(dataset_dir, '{}'.format(ds_name))

    if not os.path.exists(path + 'labels.npz') or force_refresh:

        raw_path = maybe_download(source_url, '{}.db'.format(ds_name),
                                  dataset_dir)

        print('Extracting {}.'.format(raw_path))
        conn = sqlite3.connect(raw_path)
        try:
            # ex_id = candidate_id
            LF_labels = pd.read_sql_query("select * from label;",
                                          conn)  # label, ex_id, LF_id
            features = pd.read_sql_query(
                "select * from feature;",
                conn)  # feature value, ex_id, feature_id
            # all values are probably 1.0
            splits = pd.read_sql_query(
                "select id, split from candidate;",
                conn)  # ex_id, train/dev/test split (0-2)
            start_test_ind = np.min(
                splits['id'][splits['split'] == 2])  # note the 1-indexing
            test_gold_labels = pd.read_sql_query(
                "select value, candidate_id from gold_label where candidate_id>{} order by candidate_id asc;"
                .format(start_test_ind - 1), conn)  # gold, ex_id
        finally:
            # Close the connection even if a query fails.
            conn.close()

        split_ids = [splits['id'][splits['split'] == i]
                     for i in range(3)]  # ex_ids in each split
        ids_dups = np.array(LF_labels['candidate_id'])  # ex_id for each LF_ex
        which_split = [np.isin(ids_dups, split_ids[i])
                       for i in range(3)]  # which LF_ex in each split
        labels = [
            np.array(LF_labels['value'][which_split[i]]) for i in range(3)
        ]  # label for each LF_ex
        print('Extracted labels.')

        _num_features = max(features['key_id'])
        _num_ex = splits.shape[0]
        _num_LF_ex = [labels[i].shape[0] for i in range(3)]

        print('Creating map from examples to sparse features.')
        last_seen = features['candidate_id'][0]
        count = 0
        ex_id_to_ind = {
            last_seen: count
        }  # in case the ex_id aren't consecutive
        ind_to_features = [dict() for _ in range(_num_ex + 1)]
        for val, ex_id, key_id in tqdm(np.array(features)):
            if ex_id > last_seen:
                count += 1
                last_seen = ex_id
                ex_id_to_ind[last_seen] = count
            ind_to_features[count][key_id] = val

        print('Creating sparse feature matrices for LF examples.')
        x = [dok_matrix((_num_LF_ex[i], _num_features)) for i in range(3)]
        counts = [0, 0, 0]
        for _, ex_id, _ in tqdm(np.array(LF_labels)):
            split = splits['split'][ex_id_to_ind[ex_id]]
            # dict.items(): dict.iteritems() does not exist on Python 3.
            for key_id, val in ind_to_features[ex_id_to_ind[ex_id]].items():
                # Feature ids are 1-based and may come back float-typed.
                x[split][counts[split], int(key_id) - 1] = val
            counts[split] += 1
        print('Extracted feature matrices.')

        print('Reverting test things to gold.')
        _num_test = sum(splits['split'] == 2)
        x[2] = dok_matrix((_num_test, _num_features))
        labels[2] = np.array(test_gold_labels['value'])
        count = 0
        for ex_id, split in tqdm(np.array(splits)):
            if split == 2:
                for key_id, val in ind_to_features[
                        ex_id_to_ind[ex_id]].items():
                    x[2][count, int(key_id) - 1] = val
                count += 1

        labels = [((labels[i] + 1) / 2)
                  for i in range(3)]  # convert (-1,1) to (0,1)

        # Return COO matrices so this path matches what load_npz gives back
        # on the cached path.
        x = [coo_matrix(mat) for mat in x]
        for i in range(3):
            save_npz(path + 'x{}.npz'.format(i), x[i])

        # The three label vectors generally differ in length. Store them in
        # an explicit object array: np.savez on a ragged list raises
        # ValueError on NumPy >= 1.24.
        labels_store = np.empty(3, dtype=object)
        for i in range(3):
            labels_store[i] = labels[i]
        np.savez(path + 'labels.npz', labels=labels_store)
    else:
        print('Loading {}.'.format(ds_name))
        # NpzFile keeps the archive open until it is closed.
        with np.load(path + 'labels.npz', allow_pickle=True) as data:
            labels = data['labels']
        x = [load_npz(path + 'x{}.npz'.format(i)) for i in range(3)]

    train = DataSet(x[0], labels[0])
    validation = DataSet(x[1], labels[1])
    test = DataSet(x[2], labels[2])

    print('Loaded {}.'.format(ds_name))
    return base.Datasets(train=train, validation=validation, test=test)
def load_cifar10(validation_size=1000, data_dir=None):
    """Download (if needed) and load CIFAR-10 as float images scaled by 1/255.

    Images are ``(N, 32, 32, 3)`` float arrays (channels last) with integer
    class labels. The first ``validation_size`` of the 50,000 training
    images form the validation split. Processed arrays are cached in
    ``cifar10.npz``. If that cache was written with a different
    ``validation_size``, the split is redone from the cached images, so the
    requested size is always honoured.

    Args:
        validation_size: Number of leading training images used for validation.
        data_dir: Optional root directory passed to ``get_dataset_dir``.

    Returns:
        ``base.Datasets`` with ``train``, ``validation`` and ``test`` splits.

    Raises:
        ValueError: If ``validation_size`` is outside ``[0, 50000]``.
    """
    import sys

    dataset_dir = get_dataset_dir('cifar10', data_dir=data_dir)
    cifar10_path = os.path.join(dataset_dir, 'cifar10.npz')

    img_size = 32
    num_channels = 3

    _num_files_train = 5
    _images_per_file = 10000
    _num_images_train = _num_files_train * _images_per_file

    if not 0 <= validation_size <= _num_images_train:
        raise ValueError(
            'Validation size should be between 0 and {}. Received: {}.'.format(
                _num_images_train, validation_size))

    if not os.path.exists(cifar10_path):
        # On Python 3, encoding='bytes' makes the unpickled dict keys bytes,
        # matching the b'data' / b'labels' lookups below. Python 2's
        # pickle.load takes no encoding argument.
        pickle_kwargs = {'encoding': 'bytes'} if sys.version_info[0] >= 3 else {}

        def _unpickle(f):
            f = os.path.join(dataset_dir, 'cifar-10-batches-py', f)
            print("Loading data: " + f)
            with open(f, 'rb') as fo:
                data_dict = pickle.load(fo, **pickle_kwargs)
            return data_dict

        def _convert_images(raw):
            # Scale to [0, 1] and reorder channels-first rows to NHWC.
            raw_float = np.array(raw, dtype=float) / 255.0
            images = raw_float.reshape([-1, num_channels, img_size, img_size])
            images = images.transpose([0, 2, 3, 1])
            return images

        def _load_data(f):
            data = _unpickle(f)
            raw = data[b'data']
            labels = np.array(data[b'labels'])
            images = _convert_images(raw)
            return images, labels

        def _load_training_data():
            images = np.zeros(
                shape=[_num_images_train, img_size, img_size, num_channels],
                dtype=float)
            labels = np.zeros(shape=[_num_images_train], dtype=int)

            begin = 0
            for i in range(_num_files_train):
                images_batch, labels_batch = _load_data(
                    f='data_batch_' + str(i + 1))
                end = begin + len(images_batch)
                images[begin:end, :] = images_batch
                labels[begin:end] = labels_batch
                begin = end
            return images, labels

        SOURCE_URL = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
        raw_cifar10_path = maybe_download(SOURCE_URL, 'cifar-10-python.tar.gz', dataset_dir)

        print('Extracting {}.'.format(raw_cifar10_path))
        with tarfile.open(raw_cifar10_path, 'r:gz') as tarf:
            # The archive is downloaded over the network. Where the running
            # Python supports extraction filters, the 'data' filter refuses
            # members that would escape dataset_dir.
            if hasattr(tarfile, 'data_filter'):
                tarf.extractall(path=dataset_dir, filter='data')
            else:
                tarf.extractall(path=dataset_dir)

        # no one-hot encoding of labels
        train_images, train_labels = _load_training_data()
        test_images, test_labels = _load_data(f='test_batch')

        validation_images = train_images[:validation_size]
        validation_labels = train_labels[:validation_size]
        train_images = train_images[validation_size:]
        train_labels = train_labels[validation_size:]

        np.savez(cifar10_path,
                 train_images=train_images,
                 train_labels=train_labels,
                 validation_images=validation_images,
                 validation_labels=validation_labels,
                 test_images=test_images,
                 test_labels=test_labels)
    else:
        # NpzFile keeps the archive open until it is closed.
        with np.load(cifar10_path) as data:
            train_images = data['train_images']
            train_labels = data['train_labels']
            validation_images = data['validation_images']
            validation_labels = data['validation_labels']
            test_images = data['test_images']
            test_labels = data['test_labels']

        if len(validation_images) != validation_size:
            # The cache holds the split for whatever validation_size built
            # it. Validation images are the leading training images, so
            # concatenating restores the original order before re-splitting.
            all_images = np.concatenate([validation_images, train_images])
            all_labels = np.concatenate([validation_labels, train_labels])
            validation_images = all_images[:validation_size]
            validation_labels = all_labels[:validation_size]
            train_images = all_images[validation_size:]
            train_labels = all_labels[validation_size:]

    train = DataSet(train_images, train_labels)
    validation = DataSet(validation_images, validation_labels)
    test = DataSet(test_images, test_labels)

    return base.Datasets(train=train, validation=validation, test=test)