Beispiel #1
0
 def training_step(self, batch, *args):
     """Run one training step, dispatching on whether ``batch`` is labeled.

     ``batch`` is a wrapped batch as produced by ``SemiSupervisedIterator``.
     ``SemiSupervisedIterator.is_labeled`` decides which handler is used, and
     ``SemiSupervisedIterator.get_batch`` unwraps the payload that is handed
     to it.

     Args:
         batch: Wrapped batch yielded by a ``SemiSupervisedIterator``.
         *args: Extra positional arguments forwarded unchanged to the handler.

     Returns:
         Whatever ``supervised_training_step`` (labeled batches) or
         ``unsupervised_training_step`` (unlabeled batches) returns.
     """
     # Choose the handler first, then unwrap and call it once.
     step_fn = (self.supervised_training_step
                if SemiSupervisedIterator.is_labeled(batch)
                else self.unsupervised_training_step)
     return step_fn(SemiSupervisedIterator.get_batch(batch), *args)
Beispiel #2
0
    def test_epoch(self):
        """Drain ``self.ss_iterator`` once and check how items were split.

        Each yielded batch is unwrapped with ``SemiSupervisedIterator.get_batch``
        and its items are collected into a labeled or unlabeled bucket,
        depending on ``SemiSupervisedIterator.is_labeled``. The test asserts
        that both buckets hold the same number of items, that every labeled
        item is even, and that every unlabeled item is odd.
        """
        labeled_items, unlabeled_items = [], []
        for wrapped in self.ss_iterator:
            target = (labeled_items
                      if SemiSupervisedIterator.is_labeled(wrapped)
                      else unlabeled_items)
            target.extend(SemiSupervisedIterator.get_batch(wrapped))

        # NOTE(review): the parity checks assume the fixture yields
        # integer-valued items — confirm against the test setUp.
        labeled = torch.tensor(labeled_items)
        unlabeled = torch.tensor(unlabeled_items)

        assert len(labeled) == len(unlabeled)
        assert torch.all(labeled % 2 == 0)
        assert torch.all(unlabeled % 2 != 0)
Beispiel #3
0
    def test_p(self):
        """Check the labeled/unlabeled mix produced with ``p=0.1``.

        Builds a ``SemiSupervisedIterator`` over ``self.al_dataset`` with
        ``p=0.1``, ``num_steps=None`` and ``batch_size=10``, then drains it.
        The test asserts that under 15% of the collected items came from
        labeled batches and over 85% from unlabeled ones. (``p`` is presumably
        the probability of drawing a labeled batch — confirm against
        ``SemiSupervisedIterator``.)
        """
        iterator = SemiSupervisedIterator(self.al_dataset,
                                          p=0.1,
                                          num_steps=None,
                                          batch_size=10)

        from_labeled, from_unlabeled = [], []
        for wrapped in iterator:
            if not SemiSupervisedIterator.is_labeled(wrapped):
                from_unlabeled.extend(SemiSupervisedIterator.get_batch(wrapped))
                continue
            from_labeled.extend(SemiSupervisedIterator.get_batch(wrapped))

        n_total = len(from_labeled) + len(from_unlabeled)
        assert len(from_labeled) / n_total < .15
        assert len(from_unlabeled) / n_total > 0.85
Beispiel #4
0
    def test_no_pool(self):
        """With every sample labeled, only labeled batches should be yielded.

        Wraps a 100-item ``SSLTestDataset`` in an ``ActiveLearningDataset``
        and labels 100 items at random. This presumably leaves the unlabeled
        pool empty. The iterator is still built with ``p=0.1``, and the test
        asserts that all collected items came from labeled batches.
        """
        fully_labeled = ActiveLearningDataset(
            SSLTestDataset(labeled=True, length=100))
        fully_labeled.label_randomly(100)
        iterator = SemiSupervisedIterator(fully_labeled,
                                          p=0.1,
                                          num_steps=None,
                                          batch_size=10)

        from_labeled, from_unlabeled = [], []
        for wrapped in iterator:
            target = (from_labeled
                      if SemiSupervisedIterator.is_labeled(wrapped)
                      else from_unlabeled)
            target.extend(SemiSupervisedIterator.get_batch(wrapped))

        n_total = len(from_labeled) + len(from_unlabeled)
        labeled_share = len(from_labeled) / n_total
        unlabeled_share = len(from_unlabeled) / n_total
        assert labeled_share == 1
        assert unlabeled_share == 0