def test_size(self):
    """Check ``SeqSampler`` length and pairing for ``true_frac`` of 1.0 and 0.0.

    With ``true_frac=1.0`` the sampler is expected to produce one sample per
    dataset entry, each pairing two entries with the same first field (the
    object). With ``true_frac=0.0`` it is expected to produce half as many
    samples, none of which pair the same object. ``get_max_size`` is also
    checked directly for both settings.
    """
    dataset = data.TrainSeqDataSet(TEST_DATASET)
    dataset.pad = 15

    # Expected sizes mirror get_max_size(10, 1.0) == 10 and
    # get_max_size(10, 0.0) == 5, asserted at the end of this test.
    # len(dataset) // 2 equals int(np.floor(len(dataset) / 2)) for lengths >= 0.
    half_size = len(dataset) // 2

    # NOTE: assertTrue(a, b) uses ``b`` only as the failure message and passes
    # for any truthy ``a``; assertEqual is required to compare the lengths.
    sampler = samp.SeqSampler(dataset, true_frac=0.0)
    self.assertEqual(len(sampler), half_size)

    sampler = samp.SeqSampler(dataset, true_frac=1.0)
    self.assertEqual(len(sampler), len(dataset))
    epoch = [sample for sample in sampler]
    for itx, sample in enumerate(epoch):
        with self.subTest(i=itx):
            self.assertEqual(sample[0][0], sample[1][0])

    # Same configuration as the first sampler, so the same expected size.
    # The previous expectation here, len(dataset), contradicted both the first
    # check and get_max_size(10, 0.0) == 5.
    sampler = samp.SeqSampler(dataset, true_frac=0.0)
    self.assertEqual(len(sampler), half_size)
    epoch = [sample for sample in sampler]
    for itx, sample in enumerate(epoch):
        with self.subTest(i=itx):
            self.assertNotEqual(sample[0][0], sample[1][0])

    self.assertEqual(samp.SeqSampler.get_max_size(10, 1.0), 10)
    self.assertEqual(samp.SeqSampler.get_max_size(10, 0.0), 5)
def test_restore(self):
    """Check that a sampler resumed at a later epoch replays the same epochs.

    A sampler created with ``start_epoch=6`` is iterated five times and
    compared against passes 6-10 of a sampler created with ``start_epoch=1``
    and iterated ten times, both sharing one ``base_seed``. A third sampler
    with a different ``base_seed`` must not reproduce those epochs.
    """
    dataset = data.TrainSeqDataSet(TEST_DATASET)
    dataset.pad = 15
    seed = 0

    def collect(sampler, n_epochs):
        # One full iteration of ``sampler`` per epoch, each materialised as a list.
        return [[sample for sample in sampler] for _ in range(n_epochs)]

    sampler_a = samp.SeqSampler(dataset, base_seed=seed, start_epoch=1)
    epochs_a = collect(sampler_a, 10)  # epochs 1..10

    sampler_b = samp.SeqSampler(dataset, base_seed=seed, start_epoch=6)
    epochs_b = collect(sampler_b, 5)  # epochs 6..10

    # Index 5 of the first run corresponds to epoch 6.
    self.assertEqual(epochs_a[5:], epochs_b)

    reseeded = samp.SeqSampler(dataset, base_seed=seed + 1, start_epoch=6)
    reseeded_epochs = collect(reseeded, 5)
    self.assertNotEqual(reseeded_epochs, epochs_b)
def test_iter(self):
    """Check that both entries of every sample from a ``true_frac=0.5`` sampler pass ``_valid_t``."""
    dataset = data.TrainSeqDataSet(TEST_DATASET)
    dataset.pad = 15
    sampler = samp.SeqSampler(dataset, true_frac=0.5)
    for pair in sampler:
        # Validate the first entry, then the second, in that order.
        for slot in (0, 1):
            self.assertTrue(dataset._valid_t(pair[slot]))
def test_make_sample(self):
    """Check ``SeqSampler.make_sample`` for a matching and a non-matching draw.

    Per this test's expectations, passing ``1`` yields two different sequences
    of the same object, while passing ``0`` yields entries from different
    objects. Both calls draw indices from one shared iterator and one seeded
    ``RandomState``.
    """
    dataset = data.TrainSeqDataSet(TEST_DATASET)
    dataset.pad = 15
    sampler = samp.SeqSampler(dataset, true_frac=0.5)
    indices = (idx for idx in range(len(dataset)))
    rng = np.random.RandomState(0)

    def assert_both_valid(pair):
        # Each of the two entries must be accepted by the dataset's validator.
        self.assertTrue(dataset._valid_t(pair[0]))
        self.assertTrue(dataset._valid_t(pair[1]))

    # Matching sample: same object (field 0), different sequence (field 1).
    positive = sampler.make_sample(1, indices, rng)
    assert_both_valid(positive)
    self.assertEqual(positive[0][0], positive[1][0])
    self.assertNotEqual(positive[0][1], positive[1][1])

    # Non-matching sample: the two entries come from different objects.
    negative = sampler.make_sample(0, indices, rng)
    assert_both_valid(negative)
    self.assertNotEqual(negative[0][0], negative[1][0])
def get_sampler(seed):
    """Return a ``SeqSampler`` over ``dataset`` with ``true_frac=0.5`` and the given ``base_seed``.

    NOTE(review): ``dataset`` is a free name resolved from an enclosing scope
    that is not visible here — confirm it is defined there.
    """
    sampler = samp.SeqSampler(dataset, base_seed=seed, true_frac=0.5)
    return sampler