def _inp_warn(self):
    logger.log_warn('More than one ground truth of the task model is given in SSL_NULL\n'
                    'You try to train the task model with more than one (pred & gt) pair\n'
                    'Please make sure that:\n'
                    '  (1) The prediction tuple has the same size as the ground truth tuple\n'
                    '  (2) The elements with the same index in the two tuples correspond to each other\n')
def _algorithm_warn(self):
    logger.log_warn('This SSL_CUTMIX algorithm reproduces the SSL algorithm from the paper:\n'
                    '  \'Semi-Supervised Semantic Segmentation Needs Strong, Varied Perturbations\'\n'
                    'This implementation supports pixel-wise classification only, due to the hyper-parameter:\n'
                    '  \'cons-threshold\'\n'
                    'The \'CutOut\' mode proposed in the paper is not implemented in this code\n')
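
# A minimal sketch, not the PixelSSL implementation, of the CutMix operation that
# SSL_CUTMIX applies to unlabeled data: one random rectangular mask is sampled and
# the same mask is used to paste a region of one image into another; the helper
# names below ('random_box_mask', 'cutmix_pair') are illustrative assumptions.
import torch

def random_box_mask(h, w, ratio=0.5):
    # sample one rectangular {0, 1} mask that covers roughly `ratio` of the area
    cut_h, cut_w = int(h * ratio ** 0.5), int(w * ratio ** 0.5)
    top = torch.randint(0, h - cut_h + 1, (1,)).item()
    left = torch.randint(0, w - cut_w + 1, (1,)).item()
    mask = torch.zeros(1, 1, h, w)
    mask[..., top:top + cut_h, left:left + cut_w] = 1.0
    return mask

def cutmix_pair(x_a, x_b, mask):
    # paste the masked region of `x_b` into `x_a`; applying the identical mask to
    # the inputs and to the pseudo labels keeps predictions and targets aligned
    return x_a * (1.0 - mask) + x_b * mask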
def _algorithm_warn(self):
    logger.log_warn('This SSL_S4L algorithm reproduces the SSL algorithm from the paper:\n'
                    '  \'S4L: Self-Supervised Semi-Supervised Learning\'\n'
                    'The main differences between this implementation and the original paper are:\n'
                    '  (1) This is an implementation for pixel-wise vision tasks\n'
                    '  (2) This implementation only supports the 4-angle (0, 90, 180, 270) rotation-based\n'
                    '      self-supervised pretext task\n')
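
# A minimal sketch, not the PixelSSL code, of the 4-angle rotation pretext task
# referenced above: each sample is rotated by a random multiple of 90 degrees and
# the rotation index is the self-supervised label; square (N, C, H, W) inputs are
# assumed so that all rotated tensors share the same shape.
import torch

def rotate_batch(images):
    # images: (N, C, H, W) with H == W; returns rotated images and rotation labels
    labels = torch.randint(0, 4, (images.size(0),))
    rotated = torch.stack([torch.rot90(img, k=int(k), dims=(1, 2))
                           for img, k in zip(images, labels)])
    return rotated, labels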
def _algorithm_warn(self):
    logger.log_warn('This SSL_CCT algorithm reproduces the SSL algorithm from the paper:\n'
                    '  \'Semi-Supervised Semantic Segmentation with Cross-Consistency Training\'\n'
                    'The code of the auxiliary decoders is adapted from the official repository:\n'
                    '  https://github.com/yassouali/CCT \n'
                    'These auxiliary decoders may only be suitable for pixel-wise classification\n'
                    'Hence, this implementation does not currently support pixel-wise regression tasks\n'
                    'Besides, the auxiliary decoders consume a large amount of GPU memory\n'
                    'Please reduce the number of auxiliary decoders if you run out of GPU memory\n')
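
# A minimal sketch, an assumption about the idea rather than the adapted decoder
# code, of the cross-consistency constraint behind SSL_CCT: every auxiliary
# decoder is pushed toward the (detached) prediction of the main decoder on
# unlabeled data; since each auxiliary decoder holds its own parameters, GPU
# memory grows with their number.
import torch.nn.functional as F

def cct_consistency(main_logits, aux_logits_list):
    # main_logits: (N, K, H, W); detaching the target keeps gradients flowing
    # only through the auxiliary decoders
    target = F.softmax(main_logits.detach(), dim=1)
    losses = [F.mse_loss(F.softmax(aux_logits, dim=1), target)
              for aux_logits in aux_logits_list]
    return sum(losses) / max(len(losses), 1)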
def _inp_warn(self):
    logger.log_warn('More than one ground truth of the task model is given in SSL_MT\n'
                    'You try to train the task model with more than one (pred & gt) pair\n'
                    'Please make sure that:\n'
                    '  (1) The prediction tuple has the same size as the ground truth tuple\n'
                    '  (2) The elements with the same index in the two tuples correspond to each other\n'
                    '  (3) The first element of (pred & gt) will be used to calculate the consistency constraint\n'
                    'Please implement a new SSL algorithm if you want a variant of SSL_MT that\n'
                    'calculates multiple consistency constraints (for multiple predictions)\n')
def _algorithm_warn(self):
    logger.log_warn('This SSL_MT algorithm reproduces the SSL algorithm from the paper:\n'
                    '  \'Mean Teachers are Better Role Models: Weight-Averaged Consistency Targets '
                    'Improve Semi-Supervised Deep Learning Results\'\n'
                    'The main differences between this implementation and the original paper are:\n'
                    '  (1) This is an implementation for pixel-wise vision tasks\n'
                    '  (2) The two-heads outputs trick is disabled in this implementation\n'
                    '  (3) No extra perturbations are applied between the inputs of the teacher and the student\n'
                    '      (a Gaussian noiser is provided, but it will degrade the performance)\n')
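
# A minimal sketch, not the PixelSSL implementation, of the weight-averaged
# teacher that gives SSL_MT its name: the teacher's parameters are an exponential
# moving average (EMA) of the student's, updated after every training step.
import torch

@torch.no_grad()
def update_ema_teacher(teacher, student, alpha=0.999):
    # teacher <- alpha * teacher + (1 - alpha) * student, parameter-wise
    for t_param, s_param in zip(teacher.parameters(), student.parameters()):
        t_param.mul_(alpha).add_(s_param, alpha=1.0 - alpha)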
def _algorithm_warn(self):
    logger.log_warn('This SSL_ADV algorithm reproduces the SSL algorithm from the paper:\n'
                    '  \'Adversarial Learning for Semi-supervised Semantic Segmentation\'\n'
                    'The main difference between this implementation and the original paper is:\n'
                    '  (1) This implementation does not support the constraint named \'L_semi\' in the\n'
                    '      original paper since it can only be used for pixel-wise classification\n'
                    '\nThe semi-supervised constraint in this implementation refers to the constraint\n'
                    'named \'L_adv\' in the original paper\n'
                    '\nAs in the original paper, the FC discriminator is trained by the Adam optimizer\n'
                    'with the PolynomialLR scheduler\n')
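
# A minimal sketch, under the assumption of a fully-convolutional discriminator
# with one output channel, of the 'L_adv' constraint mentioned above: the task
# model is trained so that the discriminator labels its predictions as
# "ground-truth-like" (target value 1) at every spatial location.
import torch
import torch.nn.functional as F

def adv_constraint(discriminator, task_pred):
    # task_pred: (N, K, H, W) probability map produced by the task model
    d_logits = discriminator(task_pred)      # (N, 1, h, w) realness logits
    real_target = torch.ones_like(d_logits)  # 1 == "looks like ground truth"
    return F.binary_cross_entropy_with_logits(d_logits, real_target)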
def _inp_warn(self):
    logger.log_warn('More than one ground truth of the task model is given in SSL_S4L\n'
                    'You try to train the task model with more than one (pred & gt) pair\n'
                    'Please make sure that:\n'
                    '  (1) The prediction tuple has the same size as the ground truth tuple\n'
                    '  (2) The elements with the same index in the two tuples correspond to each other\n'
                    '  (3) All elements in the ground truth tuple should be 4-dim tensors, since S4L\n'
                    '      will rotate them to match the rotated inputs\n'
                    'Please implement a new SSL algorithm if you want a variant of SSL_S4L that\n'
                    'supports other formats (not 4-dim tensors) of the ground truth\n')
def visualize(self, out_path, id_str='', inp=None, pred=None, gt=None):
    """ Visualize images during training/validation.

    Please refer to the implemented tasks to finish this function.

    Arguments:
        out_path (str): path to save the images
        id_str (str): identifier for recording
        inp (tuple): inputs of the task model
        pred (tuple): prediction of the task model
        gt (tuple): ground truth of the task model
    """
    logger.log_warn('No implementation of the \'visualize\' function for the current task.\n'
                    'Please implement it in \'task/xxx/func.py\'.\n')
def metrics(self, pred, gt, inp, meters, id_str=''):
    """ Calculate metrics for the task model.

    This function calculates all performance metrics for the task model
    and saves them into 'meters' (with the prefix '[id_str]-[METRIC_STR]_xxx').

    Arguments:
        pred (torch.Tensor): prediction of the task model
        gt (torch.Tensor): ground truth of the task model
        meters (pixelssl.utils.AvgMeterSet): recorder
        id_str (str): identifier for recording
    """
    logger.log_warn('No implementation of the \'metrics\' function for the current task.\n'
                    'Please implement it in \'task/xxx/func.py\'.\n')
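
# A hypothetical example of how a pixel-wise classification task might implement
# 'metrics': plain pixel accuracy is recorded here, and the recorder interface
# 'meters.update(name, value)' is an assumption about pixelssl.utils.AvgMeterSet,
# not a documented API.
import torch

def pixel_accuracy_metrics(pred, gt, meters, id_str=''):
    # pred: (N, K, H, W) logits; gt: (N, H, W) class indices
    pred_label = torch.argmax(pred, dim=1)
    acc = (pred_label == gt).float().mean().item()
    meters.update('{0}-pixel_acc'.format(id_str), acc)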
def step(self, epoch=None):
    if epoch is not None and epoch != 0:
        # update lr after each epoch if 'epoch' is given
        # after each epoch, set epoch += 1 and call this function
        if not self.is_warn:
            logger.log_warn('PolynomialLR is designed to update the learning rate after each iteration.\n'
                            'However, it will now be updated after each epoch, please be careful.\n')
            self.is_warn = True

        self.last_epoch = epoch
        assert self.last_epoch <= self.epochs
        self.cur_iter = self.last_epoch * self.iters_per_epoch
    elif epoch is None:
        # update lr after each iteration if 'epoch' is None
        self.cur_iter += 1
        self.last_epoch = math.floor(self.cur_iter / self.iters_per_epoch)

    for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
        param_group['lr'] = lr
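
# A minimal sketch, an assumption about what 'get_lr' computes rather than a copy
# of it, of the per-iteration polynomial decay used by schedulers like this one.
def polynomial_lr(base_lr, cur_iter, max_iter, power=0.9):
    # decays from base_lr to 0 as cur_iter approaches max_iter
    return base_lr * (1.0 - float(cur_iter) / max_iter) ** power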
def _algorithm_warn(self):
    logger.log_warn('This SSL_NULL algorithm is a fully-supervised baseline for SSL.\n')
def _data_err(self):
    logger.log_warn('More than one ground truth of the task model is given in SSL_CCT\n'
                    'Currently, this implementation of the CCT algorithm supports only one (pred & gt) pair\n'
                    'Please implement a new SSL algorithm if you want a variant of SSL_CCT that\n'
                    'supports more than one (pred & gt) pair\n')
def _gt_warn(self):
    logger.log_warn('More than one ground truth of the task model is given in SSL_CUTMIX\n'
                    'You try to train the task model with more than one ground truth\n'
                    'All ground truths are preprocessed with a CutMix operation using the same mask\n')
def _inp_warn(self):
    logger.log_warn('More than one input of the task model is given in SSL_CUTMIX\n'
                    'You try to train the task model with more than one input\n'
                    'All inputs are preprocessed with a CutMix operation using the same mask\n')
def _create_dataloader(self):
    """ Create data loaders for the experiment. """

    # ---------------------------------------------------------------------
    # create dataloader for training
    # ---------------------------------------------------------------------

    # ignore_unlabeled == False & unlabeled_batch_size != 0
    # means that both labeled and unlabeled data are used
    with_unlabeled_data = not self.args.ignore_unlabeled and self.args.unlabeled_batch_size != 0
    # ignore_unlabeled == True & unlabeled_batch_size == 0
    # means that only the labeled data is used
    without_unlabeled_data = self.args.ignore_unlabeled and self.args.unlabeled_batch_size == 0

    labeled_train_samples, unlabeled_train_samples = 0, 0
    if not self.args.validation:
        # ignore_unlabeled == True & unlabeled_batch_size != 0 -> error
        if self.args.ignore_unlabeled and self.args.unlabeled_batch_size != 0:
            logger.log_err('Arguments conflict => ignore_unlabeled == True requires unlabeled_batch_size == 0\n')
        # ignore_unlabeled == False & unlabeled_batch_size == 0 -> error
        if not self.args.ignore_unlabeled and self.args.unlabeled_batch_size == 0:
            logger.log_err('Arguments conflict => ignore_unlabeled == False requires unlabeled_batch_size != 0\n')

        # calculate the number of trainsets
        trainset_num = 0
        for key, value in self.args.trainset.items():
            trainset_num += len(value)

        # calculate the number of unlabeled sets
        unlabeledset_num = 0
        for key, value in self.args.unlabeledset.items():
            unlabeledset_num += len(value)

        # if only one labeled training set is given, without any unlabeled set
        if trainset_num == 1 and unlabeledset_num == 0:
            trainset = self._load_dataset(list(self.args.trainset.keys())[0],
                                          list(self.args.trainset.values())[0][0])
            labeled_train_samples = len(trainset.idxs)

            # if the 'sublabeled_path' is given
            sublabeled_prefix = None
            if self.args.sublabeled_path is not None and self.args.sublabeled_path != '':
                if not os.path.exists(self.args.sublabeled_path):
                    logger.log_err('Cannot find labeled file: {0}\n'.format(self.args.sublabeled_path))
                else:
                    with open(self.args.sublabeled_path) as f:
                        sublabeled_prefix = [line.strip() for line in f.read().splitlines()]
                    sublabeled_prefix = None if len(sublabeled_prefix) == 0 else sublabeled_prefix

            if sublabeled_prefix is not None:
                # wrap the trainset with 'SplitUnlabeledWrapper'
                trainset = nndata.SplitUnlabeledWrapper(
                    trainset, sublabeled_prefix, ignore_unlabeled=self.args.ignore_unlabeled)
                labeled_train_samples = len(trainset.labeled_idxs)
                unlabeled_train_samples = len(trainset.unlabeled_idxs)

            # if 'sublabeled_prefix' is None but you want to use the unlabeled data for training
            elif with_unlabeled_data:
                logger.log_err('Try to use the unlabeled samples without any SSL dataset wrapper\n')

        # if more than one labeled training set or any unlabeled training set is given
        elif trainset_num > 1 or unlabeledset_num > 0:
            # 'args.sublabeled_path' is disabled
            if self.args.sublabeled_path is not None and self.args.sublabeled_path != '':
                logger.log_err('Multiple training datasets are given.\n'
                               'Inter-split unlabeled set is not allowed.\n'
                               'Please remove the argument \'sublabeled_path\' in the script\n')

            # load all labeled training sets
            labeled_sets = []
            for set_name, set_dirs in self.args.trainset.items():
                for set_dir in set_dirs:
                    labeled_sets.append(self._load_dataset(set_name, set_dir))

            # load all extra unlabeled sets
            unlabeled_sets = []
            # if any extra unlabeled set is given
            if unlabeledset_num > 0:
                for set_name, set_dirs in self.args.unlabeledset.items():
                    for set_dir in set_dirs:
                        unlabeled_sets.append(self._load_dataset(set_name, set_dir))
            # if unlabeledset_num == 0 but you want to use the unlabeled data for training
            elif with_unlabeled_data:
                logger.log_err('Try to use the unlabeled samples without any SSL dataset wrapper\n'
                               'Please add the argument \'unlabeledset\' in the script\n')

            # wrap both 'labeled_sets' and 'unlabeled_sets' with 'JointDatasetsWrapper'
            trainset = nndata.JointDatasetsWrapper(
                labeled_sets, unlabeled_sets, ignore_unlabeled=self.args.ignore_unlabeled)
            labeled_train_samples = len(trainset.labeled_idxs)
            unlabeled_train_samples = len(trainset.unlabeled_idxs)

        # if only the labeled data is used
        if without_unlabeled_data:
            self.train_loader = torch.utils.data.DataLoader(
                trainset, batch_size=self.args.batch_size, shuffle=True,
                num_workers=self.args.num_workers, pin_memory=True, drop_last=True)
        # if both the labeled and unlabeled data are used
        elif with_unlabeled_data:
            train_sampler = nndata.TwoStreamBatchSampler(
                trainset.labeled_idxs, trainset.unlabeled_idxs,
                self.args.labeled_batch_size, self.args.unlabeled_batch_size)
            self.train_loader = torch.utils.data.DataLoader(
                trainset, batch_sampler=train_sampler,
                num_workers=self.args.num_workers, pin_memory=True)

    # ---------------------------------------------------------------------
    # create dataloader for validation
    # ---------------------------------------------------------------------

    # calculate the number of valsets
    valset_num = 0
    for key, value in self.args.valset.items():
        valset_num += len(value)

    val_samples = 0
    # if only one validation set is given
    if valset_num == 1:
        valset = self._load_dataset(list(self.args.valset.keys())[0],
                                    list(self.args.valset.values())[0][0], is_train=False)
        val_samples = len(valset.idxs)
    # if more than one validation set is given
    elif valset_num > 1:
        valsets = []
        for set_name, set_dirs in self.args.valset.items():
            for set_dir in set_dirs:
                valsets.append(self._load_dataset(set_name, set_dir, is_train=False))
        valset = nndata.JointDatasetsWrapper(valsets, [], ignore_unlabeled=True)
        val_samples = len(valset.labeled_idxs)

    # NOTE: the batch size is set to 1 during validation
    if valset_num >= 1:
        self.val_loader = torch.utils.data.DataLoader(
            valset, batch_size=1, shuffle=False,
            num_workers=self.args.num_workers, pin_memory=True)

    # check the data loaders
    if self.train_loader is None and not self.args.validation:
        logger.log_err('Train data loader is required if the validation mode is disabled\n')
    elif self.val_loader is None and self.args.validation:
        logger.log_err('Validation data loader is required if the validation mode is enabled\n')
    elif self.val_loader is None:
        logger.log_warn('No validation data loader, there will be no validation during the training\n')

    # set 'iters_per_epoch', which is required by ITER_LRERS
    self.args.iters_per_epoch = len(self.train_loader) if self.train_loader is not None else -1

    logger.log_info('Dataset:\n'
                    '  Trainset\t=>\tlabeled samples = {0}\t\tunlabeled samples = {1}\n'
                    '  Valset\t=>\tsamples = {2}\n'
                    .format(labeled_train_samples, unlabeled_train_samples, val_samples))
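
# A minimal sketch, a simplified stand-in for nndata.TwoStreamBatchSampler rather
# than its actual implementation, of the two-stream batching used above: every
# batch concatenates a fixed number of labeled indices with a fixed number of
# unlabeled indices, so each iteration sees both kinds of data in the ratio set
# by 'labeled_batch_size' and 'unlabeled_batch_size'.
import random

def two_stream_batches(labeled_idxs, unlabeled_idxs, labeled_bs, unlabeled_bs):
    # one epoch is defined by consuming the unlabeled indices once;
    # the labeled indices are cycled as often as needed
    labeled, unlabeled = list(labeled_idxs), list(unlabeled_idxs)
    random.shuffle(labeled)
    random.shuffle(unlabeled)
    for i in range(len(unlabeled) // unlabeled_bs):
        lab = [labeled[(i * labeled_bs + j) % len(labeled)] for j in range(labeled_bs)]
        unl = unlabeled[i * unlabeled_bs:(i + 1) * unlabeled_bs]
        yield lab + unl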