class TestBatchProcessor(unittest.TestCase):
    """Tests for ``BatchProcessor`` iteration, length, batch size and randomness.

    ``setUp`` builds one processor over the train fixture and one over the
    validation fixture, both under ``./tests/data``. The paths are relative
    to the current working directory (presumably the repository root --
    confirm against how the suite is invoked).
    """

    def setUp(self):
        # Both processors use identical settings apart from the X glob.
        self.train_batch = BatchProcessor(
            X_dirpath='./tests/data/train/*',
            y_dirpath='./tests/data/test/',
            batchsize=100,
            border=1,
            limit=None,
            rnd=rnd)

        # Same base directory as the train fixture: a '../' prefix here would
        # resolve outside the tree './tests/data/train' lives in, leaving the
        # glob (and therefore the validation tests) silently empty.
        self.valid_batch = BatchProcessor(
            X_dirpath='./tests/data/valid/*',
            y_dirpath='./tests/data/test/',
            batchsize=100,
            border=1,
            limit=None,
            rnd=rnd)

    def tearDown(self):
        del self.train_batch, self.valid_batch

    @staticmethod
    def _count_batches(batch):
        """Return how many ``(X, y)`` pairs one full pass over *batch* yields."""
        # Unpacking each item also checks that every batch is a 2-item pair.
        return sum(1 for _X, _y in batch)

    def test_iterating(self):
        """Iterating the train processor repeatedly yields ``len()`` batches each pass."""
        expected = len(self.train_batch)
        for attempt in range(5):
            self.assertEqual(
                self._count_batches(self.train_batch), expected,
                "pass %d yielded a different number of batches" % attempt)

    def _time_pass(self, bp, label, use_random, slow, random_mode=None):
        """Configure *bp*, time one full pass over it and print the seconds taken.

        ``random_mode`` is only assigned when given, so earlier passes leave
        the processor's own setting untouched.
        """
        bp.random = use_random
        bp.slow = slow
        if random_mode is not None:
            bp.random_mode = random_mode
        start = timer()
        for _X, _y in bp:
            pass
        elapsed = timer() - start
        # '%.3f' rather than '%d' so sub-second passes do not print as 0.
        print("%s\t\t %.3f" % (label, elapsed))

    @unittest.skipUnless(config.slow, 'slow test')
    def test_bench(self):
        """Print wall-clock timings for each iteration mode over the full dataset."""
        bp = BatchProcessor(
            X_dirpath='../data/train/*',
            y_dirpath='../data/train_cleaned/',
            batchsize=4000000,
            border=2,
            limit=None,
            rnd=rnd)

        self._time_pass(bp, "slow:", use_random=False, slow=True)
        self._time_pass(bp, "slow rand:", use_random=True, slow=True)
        self._time_pass(bp, "fast:", use_random=False, slow=False)
        self._time_pass(bp, "fast rand:", use_random=True, slow=False)
        self._time_pass(bp, "fully rand:", use_random=True, slow=False,
                        random_mode='fully')

    def test_iterating_fully_random(self):
        """Two passes in ``'fully'`` random mode yield the same number of batches."""
        self.train_batch.random_mode = 'fully'
        first_pass = list(self.train_batch)
        second_pass = list(self.train_batch)
        self.assertEqual(len(first_pass), len(second_pass))

    def test_len(self):
        """``len()`` equals the train fixture's pixel count floor-divided by ``batchsize``."""
        # Total pixel count of the train fixture images, calculated by hand.
        total_pixels = 50 * 34 * 3
        # '//' keeps the Python 2 integer-division result under Python 3 too.
        expected_batches = total_pixels // self.train_batch.batchsize
        self.assertEqual(len(self.train_batch), expected_batches)

    def test_consitency(self):
        """After ``reset()`` the first batch's first sample has the same centre value."""
        first = self.train_batch.next()
        self.train_batch.reset()
        again = self.train_batch.next()
        center = (len(first[0][0]) - 1) // 2
        self.assertEqual(first[0][0][center], again[0][0][center])

    def test_randomness(self):
        """With ``random`` enabled, batches differ after ``reset()`` and between calls."""
        self.train_batch.random = True
        X, _ = self.train_batch.next()
        self.train_batch.reset()
        X2, _ = self.train_batch.next()
        # Index of the middle element of a row of X; rows share one length.
        center = (len(X[0]) - 1) // 2
        self.assertNotEqual(X[0][center], X2[0][center])
        self.assertNotEqual(X[-1][center], X2[-1][center])

        X3, _ = self.train_batch.next()
        self.assertNotEqual(X[0][center], X3[0][center])
        self.assertNotEqual(X[-1][center], X3[-1][center])

    def test_batch_lengths(self):
        """``len()`` matches the number of batches actually iterated, per processor."""
        for name, batch in (('valid', self.valid_batch),
                            ('train', self.train_batch)):
            reported = len(batch)
            iterated = len(list(batch))
            self.assertEqual(reported, iterated,
                             "%s: len() disagrees with iteration" % name)

    def test_next_fully_random(self):
        """``next_fully_random()`` returns the same type and shapes as ``next()``."""
        fully_random = self.train_batch.next_fully_random()
        sequential = self.train_batch.next()

        self.assertEqual(type(fully_random), type(sequential))
        self.assertEqual(fully_random[0].shape, sequential[0].shape)
        self.assertEqual(fully_random[1].shape, sequential[1].shape)

    def test_batch_sizes(self):
        """Every batch from either processor has exactly ``batchsize`` rows in X."""
        for name, batch in (('valid', self.valid_batch),
                            ('train', self.train_batch)):
            for index, (X, _y) in enumerate(batch):
                self.assertEqual(
                    X.shape[0], batch.batchsize,
                    "%s batch %d has %d rows, expected %d"
                    % (name, index, X.shape[0], batch.batchsize))