def test_seed(self):
        """Check that ``base_seed`` determines what a ``SeqSampler`` yields.

        Two samplers built with the same seed must produce equal epochs. A
        sampler built with ``base_seed=None`` must expose a non-``None``
        ``base_seed`` and, once that seed differs from the fixed one, must
        produce a different epoch.
        """
        dataset = data.TrainSeqDataSet(TEST_DATASET)
        dataset.pad = 15

        def build(seed):
            return samp.SeqSampler(dataset, true_frac=0.5, base_seed=seed)

        fixed_seed = 0

        # Two independent samplers sharing one seed must agree exactly.
        first_run = list(iter(build(fixed_seed)))
        second_run = list(iter(build(fixed_seed)))
        self.assertEqual(first_run, second_run)

        unseeded = build(None)
        self.assertNotEqual(unseeded.base_seed, None)

        # Rebuild in the unlikely case that the seed chosen for the unseeded
        # sampler equals the fixed seed; otherwise the final inequality check
        # below could fail spuriously.
        while unseeded.base_seed == fixed_seed:
            unseeded = build(None)

        unseeded_run = list(iter(unseeded))
        self.assertNotEqual(first_run, unseeded_run)
    def test_getitem(self):
        """Exercise ``TrainSeqDataSet`` indexing with valid and invalid keys.

        A valid key is a pair of ``(int, int, slice)`` triples. The result is
        checked for its length, the shape of its first element, the content
        of its second element, and (in)equality of the two entries of its
        third element. Malformed keys must raise ``ValueError``.
        """
        dataset = data.TrainSeqDataSet(TEST_DATASET)

        # Both triples start with the same first index.
        matching_key = ((0, 0, slice(0, 20)), (0, 2, slice(10, 30)))
        result = dataset[matching_key]

        self.assertEqual(len(result), 3)
        self.assertEqual(result[0].shape, (2, 20, 3, 120, 120))
        expected_second = np.array((20, 20))
        self.assertTrue(np.array_equal(result[1], expected_second))
        self.assertEqual(result[2][0], result[2][1])

        # The triples now differ in their first index.
        differing_key = ((0, 0, slice(0, 20)), (1, 0, slice(10, 30)))
        result = dataset[differing_key]

        self.assertNotEqual(result[2][0], result[2][1])

        # Keys that do not have the pair-of-triples form must be rejected.
        two_element_pairs = ((1, slice(0, 20)), (0, slice(20, 40)))
        malformed_keys = (
            (0, np.arange(90)),
            0,
            two_element_pairs + two_element_pairs,
        )
        for bad_key in malformed_keys:
            with self.assertRaises(ValueError):
                dataset[bad_key]
    def test_size(self):
        """Check ``SeqSampler`` length and pairing for extreme ``true_frac``.

        For ``true_frac=0.0`` the sampler length is expected to be half the
        dataset length (rounded down) and every sample must pair elements
        whose first index differs. For ``true_frac=1.0`` the length is
        expected to equal the dataset length and every sample must pair
        elements sharing their first index. ``SeqSampler.get_max_size`` is
        checked against the same two cases for a size of 10.
        """
        dataset = data.TrainSeqDataSet(TEST_DATASET)
        dataset.pad = 15
        half_size = len(dataset) // 2

        # Length checks must use assertEqual: assertTrue(a, b) only tests
        # that ``a`` is truthy and uses ``b`` as the failure message, so it
        # would never compare the two values.
        sampler = samp.SeqSampler(dataset, true_frac=0.0)
        self.assertEqual(len(sampler), half_size)

        sampler = samp.SeqSampler(dataset, true_frac=1.0)
        self.assertEqual(len(sampler), len(dataset))
        epoch = list(iter(sampler))
        for itx, sample in enumerate(epoch):
            with self.subTest(i=itx):
                self.assertEqual(sample[0][0], sample[1][0])

        # NOTE(review): the expected length here used to be len(dataset),
        # which contradicts both the first check above and
        # get_max_size(10, 0.0) == 5 below; it is aligned to half the
        # dataset length. Confirm against SeqSampler.__len__.
        sampler = samp.SeqSampler(dataset, true_frac=0.0)
        self.assertEqual(len(sampler), half_size)

        epoch = list(iter(sampler))
        for itx, sample in enumerate(epoch):
            with self.subTest(i=itx):
                self.assertNotEqual(sample[0][0], sample[1][0])

        self.assertEqual(samp.SeqSampler.get_max_size(10, 1.0), 10)
        self.assertEqual(samp.SeqSampler.get_max_size(10, 0.0), 5)
    def test_repetition(self):
        """Check that ``RepeatingSeqSampler`` yields the same epoch each pass.

        One sampler is iterated ten times and consecutive passes must be
        equal. A second sampler with the same seed must reproduce the first
        pass, while a sampler with a different seed must not.
        """
        dataset = data.TrainSeqDataSet(TEST_DATASET)
        dataset.pad = 15
        seed = 0

        repeating = samp.RepeatingSeqSampler(dataset, base_seed=seed)

        # Ten consecutive passes over the same sampler object.
        passes = []
        for _ in range(10):
            passes.append(list(iter(repeating)))

        # Compare each pass with the one that follows it.
        for idx, (current, following) in enumerate(zip(passes, passes[1:])):
            with self.subTest(i=idx):
                self.assertEqual(current, following)

        same_seed = samp.RepeatingSeqSampler(dataset, base_seed=seed)
        same_seed_pass = list(iter(same_seed))

        self.assertEqual(same_seed_pass, passes[0])

        other_seed = samp.RepeatingSeqSampler(dataset, base_seed=seed + 1)
        other_seed_pass = list(iter(other_seed))

        self.assertNotEqual(same_seed_pass, other_seed_pass)
    def test_restore(self):
        """Check that ``start_epoch`` lets a ``SeqSampler`` resume a run.

        A sampler started at epoch 1 is iterated ten times; a sampler with
        the same seed started at epoch 6 is iterated five times and must
        reproduce the last five epochs of the first run. Using a different
        seed with the same ``start_epoch`` must give different epochs.
        """
        dataset = data.TrainSeqDataSet(TEST_DATASET)
        dataset.pad = 15
        seed = 0

        def collect(sampler, n_epochs):
            # Iterate the sampler ``n_epochs`` times, one list per epoch.
            return [list(iter(sampler)) for _ in range(n_epochs)]

        full_run = collect(
            samp.SeqSampler(dataset, start_epoch=1, base_seed=seed), 10)

        resumed_run = collect(
            samp.SeqSampler(dataset, start_epoch=6, base_seed=seed), 5)

        # Epochs 6..10 of the full run sit at list positions 5..9.
        self.assertEqual(full_run[5:], resumed_run)

        reseeded_run = collect(
            samp.SeqSampler(dataset, start_epoch=6, base_seed=seed + 1), 5)

        self.assertNotEqual(reseeded_run, resumed_run)
    def test_iter(self):
        """Check both elements of every yielded sample pass ``_valid_t``."""
        dataset = data.TrainSeqDataSet(TEST_DATASET)
        dataset.pad = 15

        sampler = samp.SeqSampler(dataset, true_frac=0.5)

        for pair in iter(sampler):
            # Validate the first element, then the second.
            for position in (0, 1):
                self.assertTrue(dataset._valid_t(pair[position]))
    def test_make_sample(self):
        """Check ``SeqSampler.make_sample`` for both values of its first argument.

        With ``1`` the two elements of the sample must share their first
        index and differ in their second; with ``0`` they must differ in
        their first index. In both cases each element must pass
        ``dataset._valid_t``.
        """
        dataset = data.TrainSeqDataSet(TEST_DATASET)
        dataset.pad = 15

        sampler = samp.SeqSampler(dataset, true_frac=0.5)

        index_stream = (idx for idx in range(len(dataset)))
        rng = np.random.RandomState(0)

        def assert_both_valid(pair):
            self.assertTrue(dataset._valid_t(pair[0]))
            self.assertTrue(dataset._valid_t(pair[1]))

        # "Correct" sample: same first index, different second index.
        matching = sampler.make_sample(1, index_stream, rng)

        assert_both_valid(matching)
        self.assertEqual(matching[0][0], matching[1][0])
        self.assertNotEqual(matching[0][1], matching[1][1])

        # "Wrong" sample: the first indices differ.
        mismatched = sampler.make_sample(0, index_stream, rng)

        assert_both_valid(mismatched)
        self.assertNotEqual(mismatched[0][0], mismatched[1][0])