def test_size(self):
        """Check ``len(sampler)`` and pair composition at the extreme true fractions.

        With ``true_frac=1.0`` the sampler is expected to hold ``len(dataset)``
        samples, each pairing two entries with the same first component (the
        object). With ``true_frac=0.0`` it is expected to hold
        ``len(dataset) // 2`` samples whose first components always differ.
        ``SeqSampler.get_max_size`` is checked against the same rule.
        """
        dataset = data.TrainSeqDataSet(TEST_DATASET)
        dataset.pad = 15

        # All-true pairs: both halves must come from the same object.
        # assertEqual, not assertTrue: assertTrue treats its second argument
        # as the failure message, so the length would never be compared.
        sampler = samp.SeqSampler(dataset, true_frac=1.0)
        self.assertEqual(len(sampler), len(dataset))
        epoch = list(sampler)
        for idx, pair in enumerate(epoch):
            with self.subTest(i=idx):
                self.assertEqual(pair[0][0], pair[1][0])

        # All-false pairs: half as many samples, consistent with
        # get_max_size(10, 0.0) == 5 below; halves from different objects.
        sampler = samp.SeqSampler(dataset, true_frac=0.0)
        self.assertEqual(len(sampler), len(dataset) // 2)
        epoch = list(sampler)
        for idx, pair in enumerate(epoch):
            with self.subTest(i=idx):
                self.assertNotEqual(pair[0][0], pair[1][0])

        self.assertEqual(samp.SeqSampler.get_max_size(10, 1.0), 10)
        self.assertEqual(samp.SeqSampler.get_max_size(10, 0.0), 5)
    def test_restore(self):
        """A sampler restarted at a later epoch replays the same epochs.

        Epochs drawn from a sampler started at epoch 1 (from the sixth drain
        onward) must equal those from a sampler started at epoch 6 with the
        same ``base_seed``. Presumably each full iteration advances the
        sampler by one epoch — confirm against SeqSampler. A different
        ``base_seed`` must give a different sequence of epochs.
        """
        dataset = data.TrainSeqDataSet(TEST_DATASET)
        dataset.pad = 15
        seed = 0

        def drain(sampler, n_epochs):
            # Iterate the sampler once per epoch, keeping every epoch's samples.
            return [[item for item in sampler] for _ in range(n_epochs)]

        epochs_from_1 = drain(
            samp.SeqSampler(dataset, start_epoch=1, base_seed=seed), 10)
        epochs_from_6 = drain(
            samp.SeqSampler(dataset, start_epoch=6, base_seed=seed), 5)
        self.assertEqual(epochs_from_1[5:], epochs_from_6)

        reseeded = drain(
            samp.SeqSampler(dataset, start_epoch=6, base_seed=seed + 1), 5)
        self.assertNotEqual(reseeded, epochs_from_6)
    def test_iter(self):
        """Both halves of every sample yielded by the sampler pass ``dataset._valid_t``."""
        dataset = data.TrainSeqDataSet(TEST_DATASET)
        dataset.pad = 15
        sampler = samp.SeqSampler(dataset, true_frac=0.5)

        for pair in iter(sampler):
            # Check the halves in order, stopping at the first invalid one.
            for slot in (0, 1):
                self.assertTrue(dataset._valid_t(pair[slot]))
    def test_make_sample(self):
        """``make_sample`` builds a correct pair for 1 and a wrong pair for 0.

        A correct pair holds two different sequences of the same object
        (equal first components, different second components). A wrong pair
        holds entries whose first components differ.
        """
        dataset = data.TrainSeqDataSet(TEST_DATASET)
        dataset.pad = 15
        sampler = samp.SeqSampler(dataset, true_frac=0.5)

        # Both make_sample calls below receive the same index generator
        # and the same seeded RandomState.
        indices = (idx for idx in range(len(dataset)))
        rng = np.random.RandomState(0)

        def assert_valid(pair):
            # Both halves of the pair must be accepted by the dataset.
            for slot in (0, 1):
                self.assertTrue(dataset._valid_t(pair[slot]))

        # Correct pair: two different sequences from the same object.
        positive = sampler.make_sample(1, indices, rng)
        assert_valid(positive)
        self.assertEqual(positive[0][0], positive[1][0])
        self.assertNotEqual(positive[0][1], positive[1][1])

        # Wrong pair: entries from different objects.
        negative = sampler.make_sample(0, indices, rng)
        assert_valid(negative)
        self.assertNotEqual(negative[0][0], negative[1][0])
 def get_sampler(seed):
     """Return a SeqSampler over ``dataset`` with true_frac=0.5 and base seed *seed*.

     NOTE(review): ``dataset`` is not defined in this function; it is looked
     up from an enclosing or module scope that is not visible here — confirm.
     """
     sampler = samp.SeqSampler(dataset, base_seed=seed, true_frac=0.5)
     return sampler