def test_downloading_will_create_necessary_files(self): if not os.environ.get('run_slow', None): return download_dataset() self.assertTrue(self.file_exists('train-images-idx3-ubyte.gz', 9912422)) self.assertTrue(self.file_exists('train-labels-idx1-ubyte.gz', 28881)) self.assertTrue(self.file_exists('t10k-images-idx3-ubyte.gz', 1648877)) self.assertTrue(self.file_exists('t10k-labels-idx1-ubyte.gz', 4542))
def test_get_training_data(self): if not os.environ.get('run_slow', None): return download_dataset() X_train, Y_train = get_training_data() self.assertEqual(len(X_train), 60000) self.assertEqual(len(Y_train), 60000) self.assertTupleEqual(X_train[0].shape, (28 * 28, )) self.assertTupleEqual(Y_train[0].shape, (10, )) self.assertEqual(X_train[0].dtype, 'float64') self.assertEqual(Y_train[0].dtype, 'float64')
def test_get_test_data(self): if not os.environ.get('run_slow', None): return download_dataset() X, Y = get_test_data() self.assertEqual(len(X), 10000) self.assertEqual(len(Y), 10000) self.assertTupleEqual(X[0].shape, (28 * 28, )) self.assertTupleEqual(Y[0].shape, (10, )) self.assertEqual(X[0].dtype, 'float64') self.assertEqual(Y[0].dtype, 'float64')
def _load_mnist_examples(self):
    """Download MNIST and wrap the training and test data in PreloadSource.

    When ``self._config['dataset_size']`` is truthy, both data sets are cut
    down to that many leading items; otherwise each is used in full. The
    results are stored on ``self._data_src`` and ``self._test_data_src``.
    """
    mnist.download_dataset()
    training_set = mnist.get_training_data()
    testing_set = mnist.get_test_data()

    limit = self._config['dataset_size']

    # NOTE(review): if get_training_data()/get_test_data() return an
    # (inputs, targets) pair, len() and the slices below act on that pair
    # rather than on individual examples -- confirm the return type.
    train_count = limit if limit else len(training_set)
    test_count = limit if limit else len(testing_set)

    self._data_src = PreloadSource(training_set[:train_count])
    self._test_data_src = PreloadSource(testing_set[:test_count])
def step(context):
    """Train a DigitGenerator on the MNIST training data.

    The trained generator is stored on ``context.generator``.
    """
    mnist.download_dataset()
    training_examples = mnist.get_training_data()

    digit_generator = DigitGenerator()
    digit_generator.train(pixels_to_categories=training_examples)

    context.generator = digit_generator
def step(context):
    """Download MNIST and attach preloaded training and test sources.

    Sets ``context.training_data`` and ``context.test_data``, each a
    PreloadSource built from the corresponding mnist loader.
    """
    mnist.download_dataset()

    training_examples = mnist.get_training_data()
    context.training_data = PreloadSource(training_examples)

    test_examples = mnist.get_test_data()
    context.test_data = PreloadSource(test_examples)
# Script: train a DigitGenerator on the MNIST training data, then write one
# generated image per digit (0-9) into dest_folder.
import sys
import os

# When run as a script, sys.path[0] is the script's own directory, so this
# adds its parent directory to the import path. Presumably the project
# modules imported below live there -- this must happen before importing them.
sys.path.insert(1, os.path.join(sys.path[0], '..'))

from digit_drawing import DigitGenerator
import helpers
from datasets import mnist

dest_folder = 'generated_digits'
image_width = 28
image_height = 28
nepochs = 5

# Make sure the output folder exists before the (slow) training starts.
# Whether helpers.create_image creates missing folders is not visible here,
# so do not rely on it.
if not os.path.isdir(dest_folder):
    os.makedirs(dest_folder)

gen = DigitGenerator()
mnist.download_dataset()
pixels_to_categories = mnist.get_training_data()
gen.train(pixels_to_categories=pixels_to_categories, nepochs=nepochs)
print('Training for {} epochs is complete'.format(nepochs))

for i in range(10):
    pixels = gen.generate_digit(i)
    image_path = os.path.join(dest_folder, 'digit_{}.png'.format(i))
    helpers.create_image(dest_fname=image_path, pixel_vector=pixels,
                         width=image_width, height=image_height)
    print('Generated image of a digit {}'.format(i))