Example #1
 def _inp_warn(self):
     logger.log_warn(
         'More than one ground truth of the task model is given in SSL_NULL\n'
         'You try to train the task model with more than one (pred & gt) pair\n'
         'Please make sure that:\n'
         '  (1) The prediction tuple has the same size as the ground truth tuple\n'
         '  (2) The elements with the same index in the two tuples correspond to each other\n'
     )
Example #2
 def _algorithm_warn(self):
     logger.log_warn(
         'This SSL_CUTMIX algorithm reproduces the SSL algorithm from the paper: \n'
         '  \'Semi-Supervised Semantic Segmentation Needs Strong, Varied Perturbations\'\n'
         'This implementation supports pixel-wise classification only due to the hyper-parameter: \n'
         '  \'cons-threshold\' \n'
         'The \'CutOut\' mode proposed by their paper is not implemented in this code\n'
     )
Example #3
 def _algorithm_warn(self):
     logger.log_warn(
         'This SSL_S4L algorithm reproduces the SSL algorithm from the paper:\n'
         '  \'S4L: Self-Supervised Semi-Supervised Learning\'\n'
         'The main differences between this implementation and the original paper are:\n'
         '  (1) This is an implementation for pixel-wise vision tasks\n'
         '  (2) This implementation only supports the 4-angle (0, 90, 180, 270) rotation-based self-supervised pretext task\n'
     )
Example #4
 def _algorithm_warn(self):
     logger.log_warn('This SSL_CCT algorithm reproduces the SSL algorithm from the paper:\n'
                     '  \'Semi-Supervised Semantic Segmentation with Cross-Consistency Training\'\n'
                     'The code of the auxiliary decoders is adapted from the official repository:\n'
                     '   https://github.com/yassouali/CCT \n'
                     'These auxiliary decoders may only be suitable for pixel-wise classification\n'
                     'Hence, this implementation does not currently support pixel-wise regression tasks\n'
                     'Besides, the auxiliary decoders consume a large amount of GPU memory\n'
                     'Please reduce the number of the auxiliary decoders if you run out of GPU memory\n')
Example #5
File: ssl_mt.py Project: lwzbuaa/PixelSSL
 def _inp_warn(self):
     logger.log_warn('More than one ground truth of the task model is given in SSL_MT\n'
                     'You try to train the task model with more than one (pred & gt) pair\n'
                     'Please make sure that: \n'
                     '  (1) The prediction tuple has the same size as the ground truth tuple\n'
                     '  (2) The elements with the same index in the two tuples correspond to each other\n'
                     '  (3) The first element of (pred & gt) will be used to calculate the consistency constraint\n'
                     'Please implement a new SSL algorithm if you want a variant of SSL_MT to\n' 
                     'calculate multiple consistency constraints (for multiple predictions)\n')
Example #6
File: ssl_mt.py Project: lwzbuaa/PixelSSL
 def _algorithm_warn(self):
     logger.log_warn('This SSL_MT algorithm reproduces the SSL algorithm from the paper:\n'
                     '  \'Mean Teachers are Better Role Models: Weight-Averaged Consistency Targets '
                     'Improve Semi-Supervised Deep Learning Results\'\n'
                     'The main differences between this implementation and the original paper are:\n'
                     '  (1) This is an implementation for pixel-wise vision tasks\n'
                     '  (2) The two-head output trick is disabled in this implementation\n'
                     '  (3) There are no extra perturbations between the inputs of the teacher and the student\n'
                     '      (A Gaussian noiser is provided, but it will degrade the performance)\n')
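The 'weight-averaged' teacher mentioned above is maintained as an exponential moving average (EMA) of the student weights. A minimal sketch of that update is shown below; it is not the PixelSSL code, and the function name and the alpha value are illustrative only:

import torch

def update_ema_teacher(teacher, student, alpha=0.99):
    # teacher weights <- EMA of student weights (the core idea of Mean Teacher)
    with torch.no_grad():
        for t_param, s_param in zip(teacher.parameters(), student.parameters()):
            t_param.mul_(alpha).add_(s_param, alpha=1.0 - alpha)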
Example #7
 def _algorithm_warn(self):
     logger.log_warn('This SSL_ADV algorithm reproduces the SSL algorithm from the paper:\n'
                     '  \'Adversarial Learning for Semi-supervised Semantic Segmentation\'\n'
                     'The main differences between this implementation and the original paper are:\n'
                     '  (1) This implementation does not support the constraint named \'L_semi\' in the\n'
                     '      original paper since it can only be used for pixel-wise classification\n'
                     '\nThe semi-supervised constraint in this implementation refers to the constraint\n'
                     'named \'L_adv\' in the original paper\n'
                     '\nAs in the original paper, the FC discriminator is trained with the Adam optimizer\n'
                     'with the PolynomialLR scheduler\n')
Example #8
 def _inp_warn(self):
     logger.log_warn(
         'More than one ground truth of the task model is given in SSL_S4L\n'
         'You try to train the task model with more than one (pred & gt) pair\n'
         'Please make sure that:\n'
         '  (1) The prediction tuple has the same size as the ground truth tuple\n'
         '  (2) The elements with the same index in the two tuples correspond to each other\n'
         '  (3) All elements in the ground truth tuple should be 4-dim tensors since S4L\n'
         '      will rotate them to match the rotated inputs\n'
         'Please implement a new SSL algorithm if you want a variant of SSL_S4L that\n'
         'supports other formats (not 4-dim tensors) of the ground truth\n')
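The 4-dim requirement in item (3) exists because the rotation pretext task rotates whole NCHW batches, and the ground truth must be rotated together with the input. A minimal sketch with torch.rot90 is shown below; the helper name is hypothetical and this is not the PixelSSL implementation:

import torch

def rotate_nchw(tensor, k):
    # rotate an (N, C, H, W) batch by k * 90 degrees in the spatial plane;
    # the same k should be applied to the input and its ground truth
    return torch.rot90(tensor, k, dims=(2, 3))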
Example #9
    def visualize(self, out_path, id_str='', inp=None, pred=None, gt=None):
        """ Visualize images during training/validation.

        Please refer to the implemented tasks to finish this function.

        Arguments:
            out_path (str): path to save the images
            id_str (str): identifier for recording
            inp (tuple): inputs of the task model
            pred (tuple): prediction of the task model
            gt (tuple): ground truth of the task model
        """

        logger.log_warn('No implementation of the \'visualize\' function for the current task.\n'
                        'Please implement it in \'task/xxx/func.py\'.\n')
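As the warning says, each task supplies its own visualization in 'task/xxx/func.py'. A minimal, hypothetical sketch for an image-to-image task is shown below; the NCHW layout, the [0, 1] value range, and the use of torchvision are assumptions, not part of PixelSSL:

import os
import torchvision.utils as vutils

def visualize(out_path, id_str='', inp=None, pred=None, gt=None):
    # save every element of the (inp, pred, gt) tuples as an image grid;
    # assumes each element is an NCHW float tensor with values in [0, 1]
    os.makedirs(out_path, exist_ok=True)
    for name, batch in (('inp', inp), ('pred', pred), ('gt', gt)):
        if batch is None:
            continue
        for idx, tensor in enumerate(batch):
            vutils.save_image(
                tensor.detach().cpu(),
                os.path.join(out_path, '{0}_{1}_{2}.png'.format(id_str, name, idx)))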
Example #10
    def metrics(self, pred, gt, inp, meters, id_str=''):
        """ Calculate metrics for the task model.

        This function calculates all performance metrics for the task model
        and saves them into 'meters' (with prefix '[id_str]-[str_METRIC_STR]_xxx').
        
        Arguments:
            pred (torch.Tensor): prediction of the task model
            gt (torch.Tensor): ground truth of the task model
            inp (torch.Tensor): input of the task model
            meters (pixelssl.utils.AvgMeterSet): recorder
            id_str (str): identifier for recording
        """

        logger.log_warn('No implementation of the \'metrics\' function for the current task.\n'
                        'Please implement it in \'task/xxx/func.py\'.\n')
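For comparison, below is a minimal sketch of a pixel-accuracy metric for pixel-wise classification. The meters.update(name, value) call follows the common AverageMeterSet pattern but is an assumption here, as is the recording key format:

import torch

def metrics(pred, gt, inp, meters, id_str=''):
    # assumes pred is an (N, C, H, W) logit tensor and gt is an (N, H, W) label tensor
    pred_label = torch.argmax(pred, dim=1)
    correct = (pred_label == gt).float().sum().item()
    # assumed AvgMeterSet-style interface: update(name, value)
    meters.update('{0}-pixel_acc'.format(id_str), correct / max(gt.numel(), 1))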
Example #11
    def step(self, epoch=None):
        if epoch is not None and epoch != 0:
            # update lr after each epoch if epoch is given
            # after each epoch, set epoch += 1 and call this function
            if not self.is_warn:
                logger.log_warn(
                    'PolynomialLR is designed for updating learning rate after each iteration.\n'
                    'However, it will be updated after each epoch now, please be careful.\n'
                )
                self.is_warn = True

            self.last_epoch = epoch
            assert self.last_epoch <= self.epochs
            self.cur_iter = self.last_epoch * self.iters_per_epoch

        elif epoch is None:
            # update lr after each iteration if epoch is None
            self.cur_iter += 1
            self.last_epoch = math.floor(self.cur_iter / self.iters_per_epoch)

        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr
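step() relies on get_lr(), which is not shown above. A polynomial scheduler typically computes something like the following sketch; the 'power' and 'base_lrs' attributes and the exact decay formula are assumptions rather than the quoted PixelSSL code:

    def get_lr(self):
        # assumed polynomial decay: lr = base_lr * (1 - cur_iter / max_iter) ** power
        max_iter = self.epochs * self.iters_per_epoch
        factor = (1.0 - float(self.cur_iter) / max_iter) ** self.power
        return [base_lr * factor for base_lr in self.base_lrs]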
Example #12
 def _algorithm_warn(self):
     logger.log_warn(
         'This SSL_NULL algorithm is a fully-supervised baseline for SSL.\n'
     )
Example #13
 def _data_err(self):
     logger.log_warn('More than one ground truth of the task model is given in SSL_CCT\n'
                     'Currently, this implementation of the CCT algorithm supports only one (pred & gt) pair\n'
                     'Please implement a new SSL algorithm if you want a variant of SSL_CCT that\n' 
                     'supports more than one (pred & gt) pairs\n')
Example #14
 def _gt_warn(self):
     logger.log_warn(
         'More than one ground truth of the task model is given in SSL_CUTMIX\n'
         'You try to train the task model with more than one ground truth\n'
         'All ground truths are preprocessed with a CutMix operation using the same mask\n'
     )
Example #15
 def _inp_warn(self):
     logger.log_warn(
         'More than one input of the task model is given in SSL_CUTMIX\n'
         'You try to train the task model with more than one input\n'
         'All inputs are preprocessed with a CutMix operation using the same mask\n'
     )
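Both warnings describe applying one CutMix mask to every element of the input and ground-truth tuples. A minimal sketch of that idea is shown below; the function name, the mask convention, and the mixing rule are illustrative and not the PixelSSL implementation:

import torch

def cutmix_with_shared_mask(tensors_a, tensors_b, mask):
    # mask: an (N, 1, H, W) binary tensor; 1 keeps the pixel from 'a', 0 pastes it from 'b'
    # the same mask is reused for every element, as the warnings describe
    return tuple(a * mask + b * (1.0 - mask) for a, b in zip(tensors_a, tensors_b))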
Example #16
File: proxy.py Project: lwzbuaa/PixelSSL
    def _create_dataloader(self):
        """ Create data loaders for experiment.
        """

        # ---------------------------------------------------------------------
        # create dataloader for training
        # ---------------------------------------------------------------------

        # ignore_unlabeled == False & unlabeled_batch_size != 0
        #   means that both labeled and unlabeled data are used
        with_unlabeled_data = not self.args.ignore_unlabeled and self.args.unlabeled_batch_size != 0
        # ignore_unlabeled == True & unlabeled_batch_size == 0
        #   means that only the labeled data is used
        without_unlabeled_data = self.args.ignore_unlabeled and self.args.unlabeled_batch_size == 0

        labeled_train_samples, unlabeled_train_samples = 0, 0
        if not self.args.validation:
            # ignore_unlabeled == True & unlabeled_batch_size != 0 -> error
            if self.args.ignore_unlabeled and self.args.unlabeled_batch_size != 0:
                logger.log_err(
                    'Arguments conflict => ignore_unlabeled == True requires unlabeled_batch_size == 0\n'
                )
            # ignore_unlabeled == False & unlabeled_batch_size == 0 -> error
            if not self.args.ignore_unlabeled and self.args.unlabeled_batch_size == 0:
                logger.log_err(
                    'Arguments conflict => ignore_unlabeled == False requires unlabeled_batch_size != 0\n'
                )

            # calculate the number of trainsets
            trainset_num = 0
            for key, value in self.args.trainset.items():
                trainset_num += len(value)

            # calculate the number of unlabeledsets
            unlabeledset_num = 0
            for key, value in self.args.unlabeledset.items():
                unlabeledset_num += len(value)

            # if only one labeled training set is given and no unlabeled set is used
            if trainset_num == 1 and unlabeledset_num == 0:
                trainset = self._load_dataset(
                    list(self.args.trainset.keys())[0],
                    list(self.args.trainset.values())[0][0])
                labeled_train_samples = len(trainset.idxs)

                # if the 'sublabeled_path' is given
                sublabeled_prefix = None
                if self.args.sublabeled_path is not None and self.args.sublabeled_path != '':
                    if not os.path.exists(self.args.sublabeled_path):
                        logger.log_err(
                            'Cannot find labeled file: {0}\n'.format(
                                self.args.sublabeled_path))
                    else:
                        with open(self.args.sublabeled_path) as f:
                            sublabeled_prefix = [
                                line.strip() for line in f.read().splitlines()
                            ]
                        sublabeled_prefix = None if len(
                            sublabeled_prefix) == 0 else sublabeled_prefix

                if sublabeled_prefix is not None:
                    # wrap the trainset by 'SplitUnlabeledWrapper'
                    trainset = nndata.SplitUnlabeledWrapper(
                        trainset,
                        sublabeled_prefix,
                        ignore_unlabeled=self.args.ignore_unlabeled)
                    labeled_train_samples = len(trainset.labeled_idxs)
                    unlabeled_train_samples = len(trainset.unlabeled_idxs)

                # if 'sublabeled_prefix' is None but you want to use the unlabeled data for training
                elif with_unlabeled_data:
                    logger.log_err(
                        'Try to use the unlabeled samples without any SSL dataset wrapper\n'
                    )

            # if more than one labeled training set is given or any unlabeled training set is given
            elif trainset_num > 1 or unlabeledset_num > 0:
                # 'args.sublabeled_path' is disabled
                if self.args.sublabeled_path is not None and self.args.sublabeled_path != '':
                    logger.log_err(
                        'Multiple training datasets are given. \n'
                        'Inter-split unlabeled set is not allowed.\n'
                        'Please remove the argument \'sublabeled_path\' in the script\n'
                    )

                # load all training sets
                labeled_sets = []
                for set_name, set_dirs in self.args.trainset.items():
                    for set_dir in set_dirs:
                        labeled_sets.append(
                            self._load_dataset(set_name, set_dir))

                # load all extra unlabeled sets
                unlabeled_sets = []
                # if any extra unlabeled set is given
                if unlabeledset_num > 0:
                    for set_name, set_dirs in self.args.unlabeledset.items():
                        for set_dir in set_dirs:
                            unlabeled_sets.append(
                                self._load_dataset(set_name, set_dir))

                # if unlabeledset_num == 0 but you want to use the unlabeled data for training
                elif with_unlabeled_data:
                    logger.log_err(
                        'Try to use the unlabeled samples without any SSL dataset wrapper\n'
                        'Please add the argument \'unlabeledset\' in the script\n'
                    )

                # wrap both 'labeled_set' and 'unlabeled_set' by 'JointDatasetsWrapper'
                trainset = nndata.JointDatasetsWrapper(
                    labeled_sets,
                    unlabeled_sets,
                    ignore_unlabeled=self.args.ignore_unlabeled)
                labeled_train_samples = len(trainset.labeled_idxs)
                unlabeled_train_samples = len(trainset.unlabeled_idxs)

            # if using labeled data only
            if without_unlabeled_data:
                self.train_loader = torch.utils.data.DataLoader(
                    trainset,
                    batch_size=self.args.batch_size,
                    shuffle=True,
                    num_workers=self.args.num_workers,
                    pin_memory=True,
                    drop_last=True)
            # if using both labeled and unlabeled data
            elif with_unlabeled_data:
                train_sampler = nndata.TwoStreamBatchSampler(
                    trainset.labeled_idxs, trainset.unlabeled_idxs,
                    self.args.labeled_batch_size,
                    self.args.unlabeled_batch_size)
                self.train_loader = torch.utils.data.DataLoader(
                    trainset,
                    batch_sampler=train_sampler,
                    num_workers=self.args.num_workers,
                    pin_memory=True)

        # ---------------------------------------------------------------------
        # create dataloader for validation
        # ---------------------------------------------------------------------

        # calculate the number of valsets
        valset_num = 0
        for key, value in self.args.valset.items():
            valset_num += len(value)

        # if only one validation set is given
        if valset_num == 1:
            valset = self._load_dataset(list(self.args.valset.keys())[0],
                                        list(self.args.valset.values())[0][0],
                                        is_train=False)
            val_samples = len(valset.idxs)

        # if more than one validation set is given
        elif valset_num > 1:
            valsets = []
            for set_name, set_dirs in self.args.valset.items():
                for set_dir in set_dirs:
                    valsets.append(
                        self._load_dataset(set_name, set_dir, is_train=False))
            valset = nndata.JointDatasetsWrapper(valsets, [],
                                                 ignore_unlabeled=True)
            val_samples = len(valset.labeled_idxs)

        # NOTE: batch size is set to 1 during the validation
        self.val_loader = torch.utils.data.DataLoader(
            valset,
            batch_size=1,
            shuffle=False,
            num_workers=self.args.num_workers,
            pin_memory=True)

        # check the data loaders
        if self.train_loader is None and not self.args.validation:
            logger.log_err(
                'Train data loader is required if validation mode is disabled\n')
        elif self.val_loader is None and self.args.validation:
            logger.log_err(
                'Validation data loader is required if validation mode is enabled\n'
            )
        elif self.val_loader is None:
            logger.log_warn(
                'No validation data loader, so there will be no validation during training\n'
            )

        # set 'iters_per_epoch', which is required by ITER_LRERS
        self.args.iters_per_epoch = len(
            self.train_loader) if self.train_loader is not None else -1

        logger.log_info(
            'Dataset:\n'
            '  Trainset\t=>\tlabeled samples = {0}\t\tunlabeled samples = {1}\n'
            '  Valset\t=>\tsamples = {2}\n'.format(labeled_train_samples,
                                                   unlabeled_train_samples,
                                                   val_samples))
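The argument checks at the beginning of this function admit exactly two valid training configurations. The snippet below restates them in script form; the argument names come from the code above, while the concrete values are illustrative only:

# labeled-only training: unlabeled data is ignored entirely
args.ignore_unlabeled = True
args.unlabeled_batch_size = 0

# semi-supervised training: every batch mixes labeled and unlabeled samples
# (requires either 'sublabeled_path' for a single trainset or an extra 'unlabeledset')
args.ignore_unlabeled = False
args.labeled_batch_size = 8        # illustrative value
args.unlabeled_batch_size = 8      # illustrative value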