def training_step(self, batch, *args):
    """Route one training batch to the supervised or unsupervised step.

    ``batch`` is inspected with ``SemiSupervisedIterator.is_labeled`` to pick
    the step, and ``SemiSupervisedIterator.get_batch`` unwraps the payload
    that is actually handed to that step (presumably ``batch`` is what a
    ``SemiSupervisedIterator`` yields — confirm against the caller).

    Args:
        batch: A wrapped batch understood by ``SemiSupervisedIterator``.
        *args: Extra positional arguments forwarded unchanged to the
            selected step method.

    Returns:
        Whatever the selected step method returns.
    """
    if SemiSupervisedIterator.is_labeled(batch):
        step = self.supervised_training_step
    else:
        step = self.unsupervised_training_step
    payload = SemiSupervisedIterator.get_batch(batch)
    return step(payload, *args)
def test_epoch(self):
    """Drain ``self.ss_iterator`` and check how its items were split.

    Asserts that:
    - the labeled and unlabeled streams yielded the same number of items;
    - every labeled item is even;
    - every unlabeled item is odd.

    Presumably the fixture dataset encodes the labeled/unlabeled split in
    the parity of its values — confirm against the fixture.
    """
    labeled_data = []
    unlabeled_data = []
    for batch in self.ss_iterator:
        # Read the flag from the wrapped batch, then unwrap the payload once
        # instead of repeating get_batch in each branch.
        is_labeled = SemiSupervisedIterator.is_labeled(batch)
        items = SemiSupervisedIterator.get_batch(batch)
        if is_labeled:
            labeled_data.extend(items)
        else:
            unlabeled_data.extend(items)
    labeled_data = torch.tensor(labeled_data)
    unlabeled_data = torch.tensor(unlabeled_data)
    assert len(labeled_data) == len(unlabeled_data)
    assert torch.all(labeled_data % 2 == 0)
    assert torch.all(unlabeled_data % 2 != 0)
def test_p(self):
    """Check that ``p=0.1`` keeps the labeled share of items small.

    Builds a ``SemiSupervisedIterator`` over ``self.al_dataset`` with
    ``p=0.1`` and asserts two things about everything it yields:
    - labeled items are under 15% of the total;
    - unlabeled items are over 85% of the total.

    ``p`` is presumably the probability of drawing a labeled batch —
    confirm against ``SemiSupervisedIterator``.
    """
    ss_iterator = SemiSupervisedIterator(
        self.al_dataset, p=0.1, num_steps=None, batch_size=10
    )
    labeled_data = []
    unlabeled_data = []
    for batch in ss_iterator:
        # Read the flag from the wrapped batch, then unwrap the payload once.
        is_labeled = SemiSupervisedIterator.is_labeled(batch)
        items = SemiSupervisedIterator.get_batch(batch)
        if is_labeled:
            labeled_data.extend(items)
        else:
            unlabeled_data.extend(items)
    total = len(labeled_data) + len(unlabeled_data)
    # Fail with a readable message rather than a ZeroDivisionError below.
    assert total > 0, "SemiSupervisedIterator yielded no items"
    l_ratio = len(labeled_data) / total
    u_ratio = len(unlabeled_data) / total
    # 0.05 of slack around p=0.1, since the draw is presumably random.
    assert l_ratio < 0.15
    assert u_ratio > 0.85
def test_no_pool(self):
    """Check that the iterator yields only labeled items when nothing is unlabeled.

    A 100-item ``SSLTestDataset`` is wrapped in an ``ActiveLearningDataset``
    and ``label_randomly(100)`` is called on it, presumably labeling every
    item and leaving the unlabeled pool empty. Even with ``p=0.1``, every
    item the ``SemiSupervisedIterator`` yields must then be labeled.
    """
    d1 = SSLTestDataset(labeled=True, length=100)
    al_dataset = ActiveLearningDataset(d1)
    al_dataset.label_randomly(100)
    ss_iterator = SemiSupervisedIterator(
        al_dataset, p=0.1, num_steps=None, batch_size=10
    )
    labeled_data = []
    unlabeled_data = []
    for batch in ss_iterator:
        # Read the flag from the wrapped batch, then unwrap the payload once.
        is_labeled = SemiSupervisedIterator.is_labeled(batch)
        items = SemiSupervisedIterator.get_batch(batch)
        if is_labeled:
            labeled_data.extend(items)
        else:
            unlabeled_data.extend(items)
    # Checking the lists directly is equivalent to "labeled ratio == 1 and
    # unlabeled ratio == 0". It avoids float equality and cannot divide by
    # zero when nothing is yielded.
    assert labeled_data, "SemiSupervisedIterator yielded no labeled items"
    assert not unlabeled_data