Example #1
0
def main_show():
    """
    Open a glumpy viewer over the CIFAR10 images, printing each image's label.
    """
    dataset = CIFAR10()
    # Read the metadata first (as before) and only then touch the pixel array.
    labels = [record['label'] for record in dataset.meta]
    side = 32 * 4
    glumpy_viewer(img_array=dataset._pixels,
                  arrays_to_print=[labels],
                  window_shape=(side, side))
Example #2
0
    def load(self):
        """
        Load CIFAR10 via skdata and split it into train/validation/test sets.

        Reads ``self.testset_size`` (asserted to be <= 10000) to size the
        validation and test splits.

        Sets:
            self.data_shape: per-image shape, asserted to be ``(32, 32, 3)``.
            self.data: dict mapping ``'trn'``, ``'val'`` and ``'tst'`` to
                ``[pixels, labels]`` pairs. Pixels are float32, divided by
                255 (presumably uint8 input, giving values in [0, 1]) and
                flattened to one row per image. Rows [0, 40000) are 'trn',
                [40000, 40000 + t) are 'val', and [50000, 50000 + t) are
                'tst', where t = self.testset_size.
        """
        from skdata.cifar10.dataset import CIFAR10
        c = CIFAR10()
        # Result is discarded: this is evaluated only for its side effect,
        # presumably forcing skdata to build its metadata and populate
        # ``_pixels`` / ``_labels``. Keep it before reading those attributes.
        len(c.meta)
        # Cast to float32 *before* dividing so no full-size float64
        # temporary is allocated. For integer-valued pixels this is
        # bit-identical to dividing in float64 and casting afterwards.
        # ``astype`` copies, so the in-place division never touches
        # ``c._pixels``.
        pix = c._pixels.astype(np.float32)
        pix /= np.float32(255.0)
        self.data_shape = pix.shape[1:]
        assert self.data_shape == (32, 32, 3)
        # Flatten each image to a row vector; use the actual image count
        # instead of a hard-coded 60000.
        pix = pix.reshape(pix.shape[0], int(np.prod(self.data_shape)))
        lbl = c._labels

        assert self.testset_size <= 10000
        t = self.testset_size
        self.data = {
            'trn': [pix[:40000], lbl[:40000]],
            'val': [pix[40000:40000 + t], lbl[40000:40000 + t]],
            'tst': [pix[50000:50000 + t], lbl[50000:50000 + t]]
        }
Example #3
0
def get_cifar10(batch_size=16):
    """
    Load CIFAR10 via skdata, standardize it per channel, and return it in
    (batch, height, width, channel) order.

    Per-channel mean and standard deviation are computed on the training
    pixels only, then applied to both the training and the test pixels.

    Args:
        batch_size: number of examples per training batch.

    Returns:
        (trn_set, tst_set):
        - trn_set is a ``batch_iterator`` over an endless cycle of
          (pixels, label) pairs drawn from data_batch_1 .. data_batch_5.
        - tst_set is ``(pixels, labels)`` for test_batch. Pixels are a
          contiguous float32 array of shape (N, 32, 32, 3); labels are an
          ``np.ndarray``.
    """
    print("loading cifar10 data ... ")

    from skdata.cifar10.dataset import CIFAR10
    cifar10 = CIFAR10()
    cifar10.fetch(True)

    trn_labels = []
    trn_pixels = []
    for i in range(1, 6):
        data = cifar10.unpickle("data_batch_%d" % i)
        trn_pixels.append(data['data'])
        trn_labels.extend(data['labels'])

    # Each raw row is one flattened image. Reshape assuming a channel-major
    # (C, H, W) layout per image.
    trn_pixels = np.vstack(trn_pixels).reshape(-1, 3, 32, 32).astype(np.float32)

    tst_data = cifar10.unpickle("test_batch")
    tst_labels = tst_data["labels"]
    tst_pixels = tst_data["data"].reshape(-1, 3, 32, 32).astype(np.float32)

    print("trn.shape=%s tst.shape=%s" % (trn_pixels.shape, tst_pixels.shape))

    print("computing mean & stddev for cifar10 ...")
    mu = np.mean(trn_pixels, axis=(0, 2, 3))
    std = np.std(trn_pixels, axis=(0, 2, 3))
    print("cifar10 mu  = %s" % mu)
    print("cifar10 std = %s" % std)
    print("whitening cifar10 pixels ... ")

    # Per-channel standardization, using training statistics for both
    # splits. Reshape so the statistics broadcast over (N, C, H, W), and
    # update in place to avoid extra full-size copies.
    mu = mu.reshape(1, 3, 1, 1)
    std = std.reshape(1, 3, 1, 1)
    for pixels in (trn_pixels, tst_pixels):
        pixels -= mu
        pixels /= std

    # transpose to tensorflow's bhwc order assuming bchw order
    trn_pixels = trn_pixels.transpose(0, 2, 3, 1)
    tst_pixels = tst_pixels.transpose(0, 2, 3, 1)

    trn_set = batch_iterator(it.cycle(zip(trn_pixels, trn_labels)),
                             batch_size,
                             cycle=True,
                             batch_fn=lambda x: zip(*x))
    # Do NOT np.vstack the 4-D test array. vstack iterates over axis 0 and
    # stacks the (32, 32, 3) images into an (N*32, 32, 3) array that no
    # longer matches tst_labels. Just make a contiguous copy of the
    # transposed view.
    tst_set = (np.ascontiguousarray(tst_pixels), np.array(tst_labels))

    return trn_set, tst_set
Example #4
0
def get_cifar10(batch_size=16):
    """
    Load the raw (unprocessed) CIFAR10 batches via skdata.

    Returns ``(trn_set, tst_set)``:
    - ``trn_set`` is a ``batch_iterator`` over an endless cycle of
      (pixel_row, label) pairs from data_batch_1 .. data_batch_5.
    - ``tst_set`` is ``(pixels, labels)`` for test_batch. Pixels are passed
      through ``np.vstack``; labels are an ``np.ndarray``.
    """
    from skdata.cifar10.dataset import CIFAR10
    dataset = CIFAR10()
    dataset.fetch(True)

    train_batches = [dataset.unpickle("data_batch_%d" % idx)
                     for idx in range(1, 6)]
    train_images = np.vstack([batch['data'] for batch in train_batches])
    train_labels = [label
                    for batch in train_batches
                    for label in batch['labels']]

    test_batch = dataset.unpickle("test_batch")

    def unzip(pairs):
        # Turn a batch of (pixels, label) pairs into parallel sequences.
        return zip(*pairs)

    endless_pairs = it.cycle(zip(train_images, train_labels))
    trn_set = batch_iterator(endless_pairs, batch_size, batch_fn=unzip)
    tst_set = (np.vstack(test_batch["data"]), np.array(test_batch["labels"]))
    return trn_set, tst_set
Example #5
0
def get_cifar10(batch_size=16):
    """
    Load CIFAR10 via skdata as float32 images in (batch, height, width,
    channel) order, without normalization.

    Returns ``(trn_set, tst_set)``:
    - ``trn_set`` is a ``batch_iterator`` over an endless cycle of
      (image, label) pairs from data_batch_1 .. data_batch_5.
    - ``tst_set`` is ``(images, labels)`` for test_batch, with labels as an
      ``np.ndarray``.
    """
    print("loading cifar10 data ... ")

    from skdata.cifar10.dataset import CIFAR10
    dataset = CIFAR10()
    dataset.fetch(True)

    train_batches = [dataset.unpickle("data_batch_%d" % idx)
                     for idx in range(1, 6)]
    train_labels = [label
                    for batch in train_batches
                    for label in batch['labels']]
    stacked = np.vstack([batch['data'] for batch in train_batches])
    train_images = stacked.reshape(-1, 3, 32, 32).astype(np.float32)

    test_batch = dataset.unpickle("test_batch")
    test_labels = test_batch["labels"]
    test_images = test_batch["data"].reshape(-1, 3, 32, 32).astype(np.float32)

    print("-- trn shape = %s" % list(train_images.shape))
    print("-- tst shape = %s" % list(test_images.shape))

    # Reorder axes from (b, c, h, w) to TensorFlow's (b, h, w, c).
    nchw_to_nhwc = (0, 2, 3, 1)
    train_images = train_images.transpose(nchw_to_nhwc)
    test_images = test_images.transpose(nchw_to_nhwc)

    endless_pairs = it.cycle(zip(train_images, train_labels))
    trn_set = batch_iterator(endless_pairs,
                             batch_size,
                             cycle=True,
                             batch_fn=lambda pairs: zip(*pairs))
    return trn_set, (test_images, np.array(test_labels))
Example #6
0
def main_clean_up():
    """
    Remove every memmap and data set file belonging to CIFAR10.
    """
    dataset = CIFAR10()
    dataset.clean_up()
Example #7
0
def main_fetch():
    """
    Download CIFAR10 into the skdata cache directory.
    """
    dataset = CIFAR10()
    dataset.fetch(True)