def test_can_only_generate_digits_from_0_to_9(self):
    """generate_digit must raise InvalidDigitError for integers outside 0-9.

    Each rejected value is checked inside its own subTest. A failure then
    names the offending digit, instead of the uninformative
    "not raised by <lambda>" that the lambda form of assertRaises produces.
    """
    diggen = DigitGenerator()
    # Values just past each boundary (-1 and 10) plus a few further out.
    for digit in (-1, -5, -50, 10, 50):
        with self.subTest(digit=digit):
            with self.assertRaises(DigitGenerator.InvalidDigitError):
                diggen.generate_digit(digit=digit)
def test_returns_non_uniform_array_after_training(self):
    """After training on one sample, digit 1 is no longer the flat 127 image.

    The single training pair maps an all-zero 784-value float input to a
    one-hot float vector with its first entry set. The generated pixels for
    digit 1 must then differ from an array of 784 uint8 values all equal
    to 127.
    """
    blank_input = np.array([0] * 784, float)
    one_hot_target = np.array([1] + [0] * 9, float)
    training_pairs = ([blank_input], [one_hot_target])

    generator = DigitGenerator()
    generator.train(pixels_to_categories=training_pairs)

    uniform = np.full(784, 127, dtype=np.uint8)
    generated = generator.generate_digit(digit=1)
    self.assertNotEqual(generated.tolist(), uniform.tolist())
# Train a DigitGenerator on the MNIST training data, then write one generated
# image per digit (0-9) as digit_<n>.png into the generated_digits folder.
import sys
import os

# Put the parent of sys.path[0] on the import path so the project modules
# below resolve. When run as `python script.py`, sys.path[0] is the script's
# own directory.
sys.path.insert(1, os.path.join(sys.path[0], '..'))

from digit_drawing import DigitGenerator
import helpers
from datasets import mnist

dest_folder = 'generated_digits'
image_width = 28
image_height = 28
nepochs = 5

# The output folder is not created anywhere else in this script. Make sure it
# exists before any image is written, so a fresh checkout does not fail at
# the first helpers.create_image call.
if not os.path.isdir(dest_folder):
    os.makedirs(dest_folder)

gen = DigitGenerator()
mnist.download_dataset()
pixels_to_categories = mnist.get_training_data()
gen.train(pixels_to_categories=pixels_to_categories, nepochs=nepochs)
print('Training for {} epochs is complete'.format(nepochs))

for i in range(10):
    pixels = gen.generate_digit(i)
    helpers.create_image(
        dest_fname=os.path.join(dest_folder, 'digit_{}.png'.format(i)),
        pixel_vector=pixels,
        width=image_width,
        height=image_height,
    )
    print('Generated image of a digit {}'.format(i))
def test_method_outputs_numpy_array_of_integers(self):
    """generate_digit returns a NumPy ndarray whose dtype is uint8."""
    result = DigitGenerator().generate_digit(digit=5)
    self.assertIsInstance(result, np.ndarray)
    self.assertEqual(result.dtype, np.uint8)
def test_untrained_returns_uniform_array(self):
    """A freshly constructed generator yields 784 pixels that are all 127."""
    generated = DigitGenerator().generate_digit(digit=1)
    uniform = np.full(784, 127, dtype=np.uint8)
    self.assertEqual(generated.tolist(), uniform.tolist())
def test_argument_must_be_integer(self):
    """generate_digit must raise InvalidDigitError for non-integer arguments.

    Each rejected value runs in its own subTest, so a failure reports which
    argument was accepted. The lambda form of assertRaises only reports
    "not raised by <lambda>".
    """
    diggen = DigitGenerator()
    # A float and a one-element list: neither is an int.
    for digit in (2.5, [4]):
        with self.subTest(digit=digit):
            with self.assertRaises(DigitGenerator.InvalidDigitError):
                diggen.generate_digit(digit=digit)
def test_works_for_valid_digits(self):
    """generate_digit accepts every digit from 0 to 9 without raising.

    Each digit runs in its own subTest. An unexpected exception is then
    attributed to the digit that triggered it, and the remaining digits are
    still exercised.
    """
    diggen = DigitGenerator()
    for digit in range(10):
        with self.subTest(digit=digit):
            diggen.generate_digit(digit=digit)
def test_returned_array_has_valid_shape(self):
    """generate_digit returns a one-dimensional array of 784 values."""
    shape = DigitGenerator().generate_digit(digit=1).shape
    self.assertTupleEqual(shape, (784, ))