示例#1
0
class TestMNIST(unittest.TestCase):
    """Check the split shapes reported by the MNIST dataset wrapper.

    The dataset is built with ``concat_train_valid=True``; how that flag
    affects the individual splits is defined by ``MNIST`` (not visible here).
    """

    def setUp(self):
        """Configure logging and build a fresh MNIST instance before each test."""
        # configure the root logger
        logger.config_root_logger()
        # get a logger for this session
        self.log = logging.getLogger(__name__)
        # get the mnist dataset
        self.mnist = MNIST(binary=False, concat_train_valid=True)

    def testSizes(self):
        """Each split reports the expected 2-D data shape."""
        # assertEqual (unlike a bare ``assert``) is not stripped under
        # ``python -O`` and reports both values when the check fails.
        self.assertEqual(self.mnist.getDataShape(dataset.TRAIN), (60000, 784))
        self.assertEqual(self.mnist.getDataShape(dataset.VALID), (10000, 784))
        self.assertEqual(self.mnist.getDataShape(dataset.TEST), (10000, 784))

    def tearDown(self):
        """Release this test's reference to the dataset."""
        del self.mnist
示例#2
0
class TestMNIST(unittest.TestCase):
    """Check the split shapes reported by the MNIST dataset wrapper.

    The dataset is built with ``concat_train_valid=True``; how that flag
    affects the individual splits is defined by ``MNIST`` (not visible here).
    """

    def setUp(self):
        """Configure logging and build a fresh MNIST instance before each test."""
        # configure the root logger
        logger.config_root_logger()
        # get a logger for this session
        self.log = logging.getLogger(__name__)
        # get the mnist dataset
        self.mnist = MNIST(binary=False, concat_train_valid=True)

    def testSizes(self):
        """Each split reports the expected 2-D data shape."""
        # assertEqual (unlike a bare ``assert``) is not stripped under
        # ``python -O`` and reports both values when the check fails.
        self.assertEqual(self.mnist.getDataShape(dataset.TRAIN), (60000, 784))
        self.assertEqual(self.mnist.getDataShape(dataset.VALID), (10000, 784))
        self.assertEqual(self.mnist.getDataShape(dataset.TEST), (10000, 784))

    def tearDown(self):
        """Release this test's reference to the dataset."""
        del self.mnist
示例#3
0
class TestMNIST(unittest.TestCase):
    """Check MNIST split shapes and how many items its TRAIN iterators yield."""

    # Number of items both iterators are expected to yield over the TRAIN
    # split with the (255, 255) arguments passed in setUp.
    # NOTE(review): 235 == 60000 // 255, which suggests a batch size of 255
    # with the final partial batch dropped -- confirm against the iterator
    # implementations.
    EXPECTED_NUM_BATCHES = 235

    def setUp(self):
        """Configure logging, load MNIST and build both TRAIN iterators."""
        # configure the root logger
        logger.config_root_logger()
        # get a logger for this session
        self.log = logging.getLogger(__name__)
        # get the mnist dataset
        self.mnist = MNIST(binary=False)
        # instantiate the sequential iterator
        # NOTE(review): the meaning of the two 255 arguments is defined by
        # the iterator classes (not visible here) -- presumably batch sizes.
        self.sequentialIterator = SequentialIterator(self.mnist, dataset.TRAIN,
                                                     255, 255)
        # instantiate the random iterator
        self.randomIterator = RandomIterator(self.mnist, dataset.TRAIN, 255,
                                             255)

    def _count_batches(self, iterator):
        """Exhaust *iterator* and return how many items it yielded.

        Each item is unpacked as a pair; the second element (presumably the
        labels) of the first two items is logged at DEBUG level.
        """
        num_batches = 0  # stays 0 if the iterator yields nothing
        for num_batches, (_, y) in enumerate(iterator, start=1):
            if num_batches <= 2:
                self.log.debug(y)
        return num_batches

    def testSizes(self):
        """Each split reports the expected 2-D data shape."""
        # assertEqual (unlike a bare ``assert``) is not stripped under
        # ``python -O`` and reports both values when the check fails.
        self.assertEqual(self.mnist.getDataShape(dataset.TRAIN), (60000, 784))
        self.assertEqual(self.mnist.getDataShape(dataset.VALID), (10000, 784))
        self.assertEqual(self.mnist.getDataShape(dataset.TEST), (10000, 784))

    def testSequentialIterator(self):
        """The sequential iterator yields EXPECTED_NUM_BATCHES items."""
        self.log.debug('TESTING SEQUENTIAL ITERATOR')
        self.assertEqual(self._count_batches(self.sequentialIterator),
                         self.EXPECTED_NUM_BATCHES)

    def testRandomIterator(self):
        """The random iterator yields EXPECTED_NUM_BATCHES items."""
        self.log.debug('TESTING RANDOM ITERATOR')
        self.assertEqual(self._count_batches(self.randomIterator),
                         self.EXPECTED_NUM_BATCHES)

    def tearDown(self):
        """Release this test's references to the dataset and iterators."""
        del self.mnist
        del self.sequentialIterator
        del self.randomIterator