def test_seed(self):
    """Samplers built from the same ``base_seed`` must yield identical epochs.

    Also checks that passing ``base_seed=None`` makes the sampler pick a
    concrete seed of its own, and that the epoch produced with that seed
    differs from the one produced with ``base``.
    """
    dataset = data.TrainSeqDataSet(TEST_DATASET)
    dataset.pad = 15

    def get_sampler(seed):
        return samp.SeqSampler(dataset, true_frac=0.5, base_seed=seed)

    base = 0
    epoch_one = list(get_sampler(base))
    epoch_two = list(get_sampler(base))
    self.assertEqual(epoch_one, epoch_two)

    new_sampler = get_sampler(None)
    self.assertIsNotNone(new_sampler.base_seed)
    # The auto-chosen seed could coincide with ``base`` (the original note
    # estimates this at ~2**-32, presumably a 32-bit seed range); re-draw in
    # that case so the inequality check below is meaningful.
    while new_sampler.base_seed == base:
        new_sampler = get_sampler(None)
    new_epoch_one = list(new_sampler)
    self.assertNotEqual(epoch_one, new_epoch_one)
def test_getitem(self):
    """Check ``TrainSeqDataSet`` indexing with paired keys and malformed keys.

    A valid key is a pair of 3-tuples (presumably object index, sequence
    index, time slice -- confirm against ``TrainSeqDataSet``). The result is
    a 3-tuple whose last element compares equal between the two halves only
    when both halves share the same first index. Malformed keys must raise
    ``ValueError``.
    """
    dataset = data.TrainSeqDataSet(TEST_DATASET)

    same_first = ((0, 0, slice(0, 20)), (0, 2, slice(10, 30)))
    result = dataset[same_first]
    self.assertEqual(len(result), 3)
    self.assertEqual(result[0].shape, (2, 20, 3, 120, 120))
    self.assertTrue(np.array_equal(result[1], np.array((20, 20))))
    self.assertEqual(result[2][0], result[2][1])

    different_first = ((0, 0, slice(0, 20)), (1, 0, slice(10, 30)))
    result = dataset[different_first]
    self.assertNotEqual(result[2][0], result[2][1])

    # Each of these keys has the wrong structure and must be rejected.
    short_pair = ((1, slice(0, 20)), (0, slice(20, 40)))
    malformed_keys = (
        (0, np.arange(90)),
        0,
        short_pair + short_pair,
    )
    for bad_key in malformed_keys:
        with self.assertRaises(ValueError):
            dataset[bad_key]
def test_size(self):
    """Check ``SeqSampler`` lengths and pair composition for extreme ``true_frac``.

    With ``true_frac=1.0`` every sample pairs two keys sharing the same first
    element, and the epoch length equals ``len(dataset)``. With
    ``true_frac=0.0`` every pair differs in its first element, and the epoch
    length is half of ``len(dataset)`` rounded down. Both expectations mirror
    ``SeqSampler.get_max_size``, which is checked at the end.
    """
    dataset = data.TrainSeqDataSet(TEST_DATASET)
    dataset.pad = 15
    half_size = len(dataset) // 2

    # Use assertEqual, not assertTrue: assertTrue(a, b) treats ``b`` as the
    # failure message and only checks that ``a`` is truthy.
    sampler = samp.SeqSampler(dataset, true_frac=0.0)
    self.assertEqual(len(sampler), half_size)

    sampler = samp.SeqSampler(dataset, true_frac=1.0)
    self.assertEqual(len(sampler), len(dataset))
    epoch = list(sampler)
    for itx, sample in enumerate(epoch):
        with self.subTest(i=itx):
            self.assertEqual(sample[0][0], sample[1][0])

    # Same configuration as the first sampler, so the same expected length.
    sampler = samp.SeqSampler(dataset, true_frac=0.0)
    self.assertEqual(len(sampler), half_size)
    epoch = list(sampler)
    for itx, sample in enumerate(epoch):
        with self.subTest(i=itx):
            self.assertNotEqual(sample[0][0], sample[1][0])

    self.assertEqual(samp.SeqSampler.get_max_size(10, 1.0), 10)
    self.assertEqual(samp.SeqSampler.get_max_size(10, 0.0), 5)
def test_repetition(self):
    """``RepeatingSeqSampler`` replays the same epoch on every pass.

    Ten consecutive passes over one sampler must be identical. A fresh
    sampler with the same ``base_seed`` must reproduce them, and a sampler
    with a different seed must not.
    """
    dataset = data.TrainSeqDataSet(TEST_DATASET)
    dataset.pad = 15
    seed = 0

    repeating = samp.RepeatingSeqSampler(dataset, base_seed=seed)
    epochs = [list(repeating) for _ in range(10)]
    for idx, (current, following) in enumerate(zip(epochs, epochs[1:])):
        with self.subTest(i=idx):
            self.assertEqual(current, following)

    replay = list(samp.RepeatingSeqSampler(dataset, base_seed=seed))
    self.assertEqual(replay, epochs[0])

    other_seed = list(samp.RepeatingSeqSampler(dataset, base_seed=seed + 1))
    self.assertNotEqual(replay, other_seed)
def test_restore(self):
    """Starting a ``SeqSampler`` at a later epoch resumes the same sequence.

    The last five of ten epochs from a sampler started at epoch 1 must match
    five epochs from a sampler started at epoch 6 with the same
    ``base_seed``. A different seed must yield different epochs.
    """
    dataset = data.TrainSeqDataSet(TEST_DATASET)
    dataset.pad = 15
    seed = 0

    def run_epochs(start, count, base_seed):
        # Build one sampler and draw ``count`` consecutive epochs from it.
        sampler = samp.SeqSampler(dataset, start_epoch=start, base_seed=base_seed)
        return [list(sampler) for _ in range(count)]

    full_run = run_epochs(1, 10, seed)
    resumed_run = run_epochs(6, 5, seed)
    self.assertEqual(full_run[5:], resumed_run)

    other_seed_run = run_epochs(6, 5, seed + 1)
    self.assertNotEqual(other_seed_run, resumed_run)
def test_iter(self):
    """Both keys of every pair yielded by ``SeqSampler`` pass ``_valid_t``."""
    dataset = data.TrainSeqDataSet(TEST_DATASET)
    dataset.pad = 15
    sampler = samp.SeqSampler(dataset, true_frac=0.5)
    for pair in sampler:
        first, second = pair[0], pair[1]
        self.assertTrue(dataset._valid_t(first))
        self.assertTrue(dataset._valid_t(second))
def test_make_sample(self):
    """``SeqSampler.make_sample`` builds matching and non-matching pairs.

    Flag 1 must give two valid keys that share their first element but
    differ in the second (two different sequences of the same object). Flag 0
    must give two valid keys with different first elements.
    """
    dataset = data.TrainSeqDataSet(TEST_DATASET)
    dataset.pad = 15
    sampler = samp.SeqSampler(dataset, true_frac=0.5)
    index_stream = iter(range(len(dataset)))
    rng = np.random.RandomState(0)

    # Positive sample: two different sequences from the same object.
    positive = sampler.make_sample(1, index_stream, rng)
    self.assertTrue(dataset._valid_t(positive[0]))
    self.assertTrue(dataset._valid_t(positive[1]))
    self.assertEqual(positive[0][0], positive[1][0])
    self.assertNotEqual(positive[0][1], positive[1][1])

    # Negative sample: sequences from different objects.
    negative = sampler.make_sample(0, index_stream, rng)
    self.assertTrue(dataset._valid_t(negative[0]))
    self.assertTrue(dataset._valid_t(negative[1]))
    self.assertNotEqual(negative[0][0], negative[1][0])