Example #1
0
 def test_downloading_will_create_necessary_files(self):
     """Check that download_dataset() leaves the four MNIST archives on disk.

     Slow (it downloads data), so it only runs when the ``run_slow``
     environment variable is set; otherwise it is reported as skipped.
     """
     if not os.environ.get('run_slow'):
         # skipTest makes the runner report SKIPPED; a bare `return`
         # would report a PASS even though nothing was checked.
         self.skipTest('set the run_slow environment variable to run')
     download_dataset()
     # The number is passed to self.file_exists alongside the name;
     # presumably the expected file size in bytes -- confirm in file_exists.
     expected_files = (
         ('train-images-idx3-ubyte.gz', 9912422),
         ('train-labels-idx1-ubyte.gz', 28881),
         ('t10k-images-idx3-ubyte.gz', 1648877),
         ('t10k-labels-idx1-ubyte.gz', 4542),
     )
     for fname, expected_size in expected_files:
         # Pass the name as the failure message so it is clear which file failed.
         self.assertTrue(self.file_exists(fname, expected_size), fname)
Example #2
0
    def test_get_training_data(self):
        """Check the size, per-item shape and dtype of the training split.

        Slow (it downloads the dataset first), so it only runs when the
        ``run_slow`` environment variable is set; otherwise it is reported
        as skipped.
        """
        if not os.environ.get('run_slow'):
            # skipTest reports SKIPPED instead of a misleading PASS.
            self.skipTest('set the run_slow environment variable to run')
        download_dataset()
        X_train, Y_train = get_training_data()
        self.assertEqual(len(X_train), 60000)
        self.assertEqual(len(Y_train), 60000)

        # Each input is a flat vector of 28 * 28 values; each target has
        # length 10 (presumably one class per digit -- not checked here).
        self.assertTupleEqual(X_train[0].shape, (28 * 28, ))
        self.assertTupleEqual(Y_train[0].shape, (10, ))

        self.assertEqual(X_train[0].dtype, 'float64')
        self.assertEqual(Y_train[0].dtype, 'float64')
Example #3
0
    def test_get_test_data(self):
        """Check the size, per-item shape and dtype of the test split.

        Slow (it downloads the dataset first), so it only runs when the
        ``run_slow`` environment variable is set; otherwise it is reported
        as skipped.
        """
        if not os.environ.get('run_slow'):
            # skipTest reports SKIPPED instead of a misleading PASS.
            self.skipTest('set the run_slow environment variable to run')

        download_dataset()
        X, Y = get_test_data()
        self.assertEqual(len(X), 10000)
        self.assertEqual(len(Y), 10000)

        # Each input is a flat vector of 28 * 28 values; each target has
        # length 10 (presumably one class per digit -- not checked here).
        self.assertTupleEqual(X[0].shape, (28 * 28, ))
        self.assertTupleEqual(Y[0].shape, (10, ))

        self.assertEqual(X[0].dtype, 'float64')
        self.assertEqual(Y[0].dtype, 'float64')
Example #4
0
    def _load_mnist_examples(self):
        """Fetch MNIST and wrap its train/test splits in PreloadSource objects.

        When ``self._config['dataset_size']`` is truthy, both splits are cut
        down to that many leading items; otherwise each split is used whole.
        The results are stored on ``self._data_src`` and
        ``self._test_data_src``.
        """
        mnist.download_dataset()

        training_set = mnist.get_training_data()
        testing_set = mnist.get_test_data()

        # A falsy limit (e.g. None or 0) means "keep the whole split".
        limit = self._config['dataset_size']
        n_train = limit or len(training_set)
        n_test = limit or len(testing_set)

        self._data_src = PreloadSource(training_set[:n_train])
        self._test_data_src = PreloadSource(testing_set[:n_test])
Example #5
0
def step(context):
    """Train a DigitGenerator on the MNIST training data.

    The trained model is stored on ``context.generator``.
    """
    mnist.download_dataset()
    training_pairs = mnist.get_training_data()
    digit_model = DigitGenerator()
    digit_model.train(pixels_to_categories=training_pairs)
    context.generator = digit_model
Example #6
0
def step(context):
    """Download MNIST and expose both splits on the context as PreloadSources.

    Sets ``context.training_data`` and ``context.test_data``.
    """
    mnist.download_dataset()
    train_source = PreloadSource(mnist.get_training_data())
    context.training_data = train_source
    test_source = PreloadSource(mnist.get_test_data())
    context.test_data = test_source
Example #7
0
import sys
import os

# Put the parent directory on the import path before the project imports
# below; presumably digit_drawing, helpers and datasets live there.
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from digit_drawing import DigitGenerator
import helpers
from datasets import mnist

# Script: train a DigitGenerator on MNIST, then write one image per digit
# 0-9 to <dest_folder>/digit_<n>.png.
dest_folder = 'generated_digits'
image_width = 28
image_height = 28
nepochs = 5

# Create the output folder before the (slow) training run. A fresh checkout
# may not have it. NOTE(review): helpers.create_image is not known to create
# missing directories -- confirm. Creating the folder first means a completed
# training run is not lost to a missing-folder error at save time.
if not os.path.isdir(dest_folder):
    os.makedirs(dest_folder)

gen = DigitGenerator()

mnist.download_dataset()
pixels_to_categories = mnist.get_training_data()
gen.train(pixels_to_categories=pixels_to_categories, nepochs=nepochs)

print('Training for {} epochs is complete'.format(nepochs))

for digit in range(10):
    pixels = gen.generate_digit(digit)
    helpers.create_image(dest_fname=os.path.join(dest_folder,
                                                 'digit_{}.png'.format(digit)),
                         pixel_vector=pixels,
                         width=image_width,
                         height=image_height)
    print('Generated image of a digit {}'.format(digit))