# Example 1
def _gen_25gaussians(batch_size):
    """Yield float32 batches from a fixed 5x5 grid of Gaussians on [-4, 4]^2."""
    dataset = []
    for _ in range(100000 // 25):
        for x in range(-2, 3):
            for y in range(-2, 3):
                point = np.random.randn(2) * 0.05
                point[0] += 2 * x
                point[1] += 2 * y
                dataset.append(point)
    dataset = np.array(dataset, dtype='float32')
    np.random.shuffle(dataset)
    dataset /= 2.828  # stdev
    while True:
        for i in range(len(dataset) // batch_size):
            yield dataset[i * batch_size:(i + 1) * batch_size]


def _gen_stacked_mnist(batch_size):
    """Yield (images, labels) batches from the stacked-MNIST reader.

    Side effect: saves the dataset's label array to
    '../dataset/Stacked_MNIST/dist_info.npy' once, before iterating.
    """
    ds = Reader.DS('stacked_train.npy')
    np.save('../dataset/Stacked_MNIST/dist_info.npy', ds.labels)
    while True:
        for i in range(int(ds.size) // batch_size):
            start = i * batch_size
            end = (i + 1) * batch_size
            yield ds.images[start:end], ds.labels[start:end]


def _gen_1200d(batch_size):
    """Yield (images, -1) batches from the 1200D dataset file.

    Side effect: saves the dataset's 'y_dist' and 'means' arrays to
    '../dataset/1200D/dist_ydist.npy' and '../dataset/1200D/dist_means.npy'.
    The label slot is a constant -1 (labels unused for this dataset).
    """
    ds = np.load('../dataset/1200D/1200D_train.npy')
    # NOTE(review): .item() implies a pickled dict inside the .npy file;
    # numpy >= 1.16.3 requires allow_pickle=True to load it — confirm
    # against the numpy version this project pins.
    ds = ds.item()
    ds_images = ds['images']
    ds_dist = ds['y_dist']
    means = ds['means']
    ds_size = ds_images.shape[0]
    np.save('../dataset/1200D/dist_ydist.npy', ds_dist)
    np.save('../dataset/1200D/dist_means.npy', means)
    while True:
        for i in range(ds_size // batch_size):
            start = i * batch_size
            end = (i + 1) * batch_size
            yield ds_images[start:end], -1


def _gen_swissroll(batch_size):
    """Yield freshly sampled 2-D swiss-roll batches (x/z plane, rescaled)."""
    while True:
        data = datasets.make_swiss_roll(n_samples=batch_size,
                                        noise=0.25)[0]
        data = data.astype('float32')[:, [0, 2]]
        data /= 7.5  # stdev plus a little
        yield data


def _gen_8gaussians(batch_size):
    """Yield float32 batches from 8 Gaussians on a circle of radius 2."""
    scale = 2.
    centers = [(1, 0), (-1, 0), (0, 1), (0, -1),
               (1. / np.sqrt(2), 1. / np.sqrt(2)),
               (1. / np.sqrt(2), -1. / np.sqrt(2)),
               (-1. / np.sqrt(2), 1. / np.sqrt(2)),
               (-1. / np.sqrt(2), -1. / np.sqrt(2))]
    centers = [(scale * x, scale * y) for x, y in centers]
    while True:
        dataset = []
        for _ in range(batch_size):
            point = np.random.randn(2) * .02
            point[0] += center[0] if False else 0  # placeholder removed below
            dataset.append(point)
        dataset = np.array(dataset, dtype='float32')
        dataset /= 1.414  # stdev
        yield dataset


def inf_train_gen(DATASET, BATCH_SIZE):
    """Return an infinite minibatch generator for the named dataset.

    Args:
        DATASET: one of '25gaussians', 'stacked_mnist', '1200D',
            'swissroll', '8gaussians'.
        BATCH_SIZE: number of samples per yielded batch.

    Yields:
        '25gaussians' / 'swissroll' / '8gaussians': float32 arrays of
        shape (BATCH_SIZE, 2). 'stacked_mnist': (images, labels) slice
        pairs. '1200D': (images, -1) pairs. An unrecognized DATASET
        name yields nothing (empty generator), matching the original
        if/elif chain with no else branch.
    """
    if DATASET == '25gaussians':
        yield from _gen_25gaussians(BATCH_SIZE)
    elif DATASET == 'stacked_mnist':
        yield from _gen_stacked_mnist(BATCH_SIZE)
    elif DATASET == '1200D':
        yield from _gen_1200d(BATCH_SIZE)
    elif DATASET == 'swissroll':
        yield from _gen_swissroll(BATCH_SIZE)
    elif DATASET == '8gaussians':
        yield from _gen_8gaussians(BATCH_SIZE)