def _pred_err(self):
    logger.log_err(
        'In SSL_MT, the \'resulter\' dict returned by the task model should contain the following keys:\n'
        '   (1) \'pred\'\t=>\tunactivated task predictions\n'
        '   (2) \'activated_pred\'\t=>\tactivated task predictions\n'
        'We need both of them since some losses include the activation functions,\n'
        'e.g., the CrossEntropyLoss already contains the SoftMax\n')
def __init__(self, args):
    super(SSLMT, self).__init__(args)

    # define the student model and the teacher model
    self.s_model, self.t_model = None, None
    self.s_optimizer = None
    self.s_lrer = None
    self.s_criterion = None

    self.cons_criterion = None
    self.gaussian_noiser = None
    self.zero_tensor = torch.zeros(1)

    # check SSL arguments
    if self.args.cons_for_labeled or self.args.unlabeled_batch_size > 0:
        if self.args.cons_scale < 0:
            logger.log_err('The argument - cons_scale - is not set (or invalid)\n'
                           'You set the argument - cons_for_labeled - to True\n'
                           'or\n'
                           'You set the argument - unlabeled_batch_size - larger than 0\n'
                           'Please set - cons_scale >= 0 - for training\n')
        if self.args.cons_rampup_epochs < 0:
            logger.log_err('The argument - cons_rampup_epochs - is not set (or invalid)\n'
                           'You set the argument - cons_for_labeled - to True\n'
                           'or\n'
                           'You set the argument - unlabeled_batch_size - larger than 0\n'
                           'Please set - cons_rampup_epochs >= 0 - for training\n')
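# How the checks above are typically consumed: the consistency weight is
# warmed up from 0 to 'cons_scale' over the first 'cons_rampup_epochs' epochs.
# A minimal, self-contained sketch of the sigmoid rampup commonly used by
# Mean-Teacher-style methods; 'sigmoid_rampup' is a hypothetical helper for
# illustration, not a function of this file.
import math

def sigmoid_rampup(epoch, rampup_epochs):
    """ Return a weight in [0, 1] that increases smoothly with the epoch. """
    if rampup_epochs <= 0:
        return 1.0
    t = max(0.0, min(float(epoch), float(rampup_epochs))) / rampup_epochs
    return math.exp(-5.0 * (1.0 - t) ** 2)

# usage sketch: cons_weight = args.cons_scale * sigmoid_rampup(epoch, args.cons_rampup_epochs)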
def forward(self, inp):
    resulter, debugger = {}, {}

    t_resulter, t_debugger = self.task_model.forward(inp)
    if not 'pred' in t_resulter.keys() or not 'activated_pred' in t_resulter.keys():
        logger.log_err(
            'In SSL_S4L, the \'resulter\' dict returned by the task model should contain the following keys:\n'
            '   (1) \'pred\'\t=>\tunactivated task predictions\n'
            '   (2) \'activated_pred\'\t=>\tactivated task predictions\n'
            'We need both of them since some losses include the activation functions,\n'
            'e.g., the CrossEntropyLoss already contains the SoftMax\n')

    if not 'ssls4l_rc_inp' in t_resulter.keys():
        logger.log_err(
            'In SSL_S4L, the \'resulter\' dict returned by the task model should contain the key:\n'
            '   \'ssls4l_rc_inp\'\t=>\tinputs of the rotation classifier (a 4-dim tensor)\n'
            'It can be the feature map encoded by the task model or the output of the task model\n'
            'Please add the key \'ssls4l_rc_inp\' in your task model\'s resulter\n')

    rc_inp = tool.dict_value(t_resulter, 'ssls4l_rc_inp')
    pred_rotation = self.rotation_classifier.forward(rc_inp)

    resulter['pred'] = tool.dict_value(t_resulter, 'pred')
    resulter['activated_pred'] = tool.dict_value(t_resulter, 'activated_pred')
    resulter['rotation'] = pred_rotation

    return resulter, debugger
def __init__(self, args):
    super(SSLADV, self).__init__(args)

    # define the task model and the FC discriminator
    self.model, self.d_model = None, None
    self.optimizer, self.d_optimizer = None, None
    self.lrer, self.d_lrer = None, None
    self.criterion, self.d_criterion = None, None

    # prepare the arguments for multiple GPUs
    self.args.discriminator_lr *= self.args.gpus

    # check SSL arguments
    if self.args.adv_for_labeled:
        if self.args.labeled_adv_scale < 0:
            logger.log_err('The argument - labeled_adv_scale - is not set (or invalid)\n'
                           'You set the argument - adv_for_labeled - to True\n'
                           'Please set - labeled_adv_scale >= 0 - for calculating the '
                           'adversarial loss on the labeled data\n')
    if self.args.unlabeled_batch_size > 0:
        if self.args.unlabeled_adv_scale < 0:
            logger.log_err('The argument - unlabeled_adv_scale - is not set (or invalid)\n'
                           'You set the argument - unlabeled_batch_size - larger than 0\n'
                           'Please set - unlabeled_adv_scale >= 0 - for calculating the '
                           'adversarial loss on the unlabeled data\n')
def _build_ssl_algorithm(self):
    """ Build the semi-supervised learning algorithm given in the script. """
    for cname in self.args.models.keys():
        self.model_dict[cname] = self.model.__dict__[self.args.models[cname]]()
        self.criterion_dict[cname] = self.criterion.__dict__[self.args.criterions[cname]]()
        self.lrer_dict[cname] = nnlrer.__dict__[self.args.lrers[cname]](self.args)
        self.optimizer_dict[cname] = nnoptimizer.__dict__[self.args.optimizers[cname]](self.args)

    logger.log_info('SSL algorithm: \n  {0}\n'.format(self.args.ssl_algorithm))
    logger.log_info('Models: ')

    self.ssl_algorithm = pixelssl.ssl_algorithm.__dict__[self.args.ssl_algorithm].__dict__[
        self.args.ssl_algorithm](self.args, self.model_dict, self.optimizer_dict,
                                 self.lrer_dict, self.criterion_dict,
                                 self.func.task_func()(self.args))

    # check whether the SSL algorithm supports the given task
    if not self.TASK_TYPE in self.ssl_algorithm.SUPPORTED_TASK_TYPES:
        logger.log_err('SSL algorithm - {0} - supports task types {1}\n'
                       'However, the given task - {2} - belongs to {3}\n'.format(
                           self.ssl_algorithm.NAME, self.ssl_algorithm.SUPPORTED_TASK_TYPES,
                           self.args.task, self.TASK_TYPE))
def pytorch_support(required_version='1.0.0', info_str=''):
    # NOTE: compare parsed version fields rather than raw strings; a plain
    # string comparison mis-orders versions such as '1.10.0' vs '1.2.0'
    current = [int(x) for x in torch.__version__.split('+')[0].split('.') if x.isdigit()]
    required = [int(x) for x in required_version.split('.') if x.isdigit()]
    if current < required:
        logger.log_err('{0} requires PyTorch >= {1}\n'
                       'However, the current PyTorch == {2}\n'.format(
                           info_str, required_version, torch.__version__))
    else:
        return True
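# Why the field-wise comparison in 'pytorch_support' matters -- a raw string
# comparison mis-orders multi-digit version fields. Self-contained check:
assert '1.10.0' < '1.2.0'        # lexicographic comparison: wrong order
assert [1, 10, 0] > [1, 2, 0]    # field-wise comparison: correct order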
def _batch_prehandle(self, inp, gt, is_train):
    # add extra data augmentation process here if necessary
    inp_var = []
    for i in inp:
        inp_var.append(Variable(i).cuda())
    inp = tuple(inp_var)

    gt_var = []
    for g in gt:
        gt_var.append(Variable(g).cuda())
    gt = tuple(gt_var)

    mix_u_inp = None
    mix_u_mask = None

    # -------------------------------------------------
    # Operations for CUTMIX
    # -------------------------------------------------
    if is_train:
        lbs = self.args.labeled_batch_size
        ubs = self.args.unlabeled_batch_size

        # check the shape of input and gt
        # NOTE: this implementation of CUTMIX supports multiple inputs and gts,
        #       but all of them should have the same image size
        sample_shape = (inp[0].shape[2], inp[0].shape[3])
        for i in inp:
            if not tuple(i.shape[2:]) == sample_shape:
                logger.log_err('This SSL_CUTMIX algorithm requires all inputs to have the same shape\n')
        for g in gt:
            if not tuple(g.shape[2:]) == sample_shape:
                logger.log_err('This SSL_CUTMIX algorithm requires all ground truths to have the same shape\n')

        # generate the mask for mixing the unlabeled samples
        mix_u_mask = self.mask_generator.produce(int(ubs / 2), sample_shape)
        mix_u_mask = torch.tensor(mix_u_mask).cuda()

        # mix the unlabeled samples
        u_inp_1 = func.split_tensor_tuple(inp, lbs, int(lbs + ubs / 2))
        u_inp_2 = func.split_tensor_tuple(inp, int(lbs + ubs / 2), self.args.batch_size)

        mix_u_inp = []
        for ui_1, ui_2 in zip(u_inp_1, u_inp_2):
            mi = mix_u_mask * ui_1 + (1 - mix_u_mask) * ui_2
            mix_u_inp.append(mi)
        mix_u_inp = tuple(mix_u_inp)

    return inp, gt, mix_u_inp, mix_u_mask
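# Self-contained toy illustration of the CutMix-style mixing above: a binary
# mask keeps pixels from one unlabeled image and pastes the complement from
# another, exactly as 'mix_u_mask * ui_1 + (1 - mix_u_mask) * ui_2'.
# All tensors below are synthetic.
_demo_mask = torch.zeros(1, 1, 4, 4)
_demo_mask[..., 1:3, 1:3] = 1.0              # a 2x2 'cut' region
_demo_u1 = torch.ones(1, 3, 4, 4)            # first unlabeled sample
_demo_u2 = torch.zeros(1, 3, 4, 4)           # second unlabeled sample
_demo_mix = _demo_mask * _demo_u1 + (1 - _demo_mask) * _demo_u2
assert _demo_mix[0, 0, 1, 1] == 1.0 and _demo_mix[0, 0, 0, 0] == 0.0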
def _run(self):
    """ Main pipeline of the experiment.
        Please override this function if you want a special pipeline.
    """
    start_epoch = 0
    if self.args.resume is not None and self.args.resume != '':
        logger.log_info('Load checkpoint from: {0}'.format(self.args.resume))
        start_epoch = self.ssl_algorithm.load_checkpoint()

    if self.args.validation:
        if self.val_loader is None:
            logger.log_err('No data loader for validation.\n'
                           'Please set a valid \'valset\' in the script.\n')
        logger.log_info(['=' * 78, '\nStart to validate model\n', '=' * 78])
        with torch.no_grad():
            self.ssl_algorithm.validate(self.val_loader, start_epoch)
        return

    # NOTE: the first epoch index for 'train' and 'validate' is 0
    for epoch in range(start_epoch, self.args.epochs):
        timer = time.time()

        logger.log_info(['=' * 78, '\nStart to train epoch-{0}\n'.format(epoch + 1), '=' * 78])
        self.ssl_algorithm.train(self.train_loader, epoch)

        if (epoch + 1) % self.args.val_freq == 0 and self.val_loader is not None:
            logger.log_info(['=' * 78, '\nStart to validate epoch-{0}\n'.format(epoch + 1), '=' * 78])
            with torch.no_grad():
                self.ssl_algorithm.validate(self.val_loader, epoch)

        if (epoch + 1) % self.args.checkpoint_freq == 0:
            self.ssl_algorithm.save_checkpoint(epoch + 1)
            logger.log_info('Save checkpoint for epoch {0}'.format(epoch + 1))

        logger.log_info('Finish epoch in {0} seconds\n'.format(time.time() - timer))

    logger.log_info('Finish experiment {0}\n'.format(self.args.exp_id))
def __init__(self, args):
    super(SSLCUTMIX, self).__init__(args)

    self.s_model, self.t_model = None, None
    self.s_optimizer = None
    self.s_lrer = None
    self.s_criterion = None

    # define the auxiliary modules required by CUTMIX
    self.mask_generator = None

    # check SSL arguments
    if self.args.unlabeled_batch_size > 0:
        if not self.args.unlabeled_batch_size > 2 or not self.args.unlabeled_batch_size % 2 == 0:
            logger.log_err('This implementation of SSL_CUTMIX requires an unlabeled batch size that is: \n'
                           '  1. larger than 2 \n'
                           '  2. divisible by 2 \n')
        if self.args.cons_scale < 0:
            logger.log_err('The argument - cons_scale - is not set (or invalid)\n'
                           'Please set - cons_scale >= 0 - for training\n')
        if self.args.cons_rampup_epochs < 0:
            logger.log_err('The argument - cons_rampup_epochs - is not set (or invalid)\n'
                           'Please set - cons_rampup_epochs >= 0 - for training\n')
        if self.args.cons_threshold < 0 or self.args.cons_threshold > 1:
            logger.log_err('The argument - cons_threshold - is not set (or invalid)\n'
                           'Please set - 0 <= cons_threshold <= 1 - for training\n')
def _load_checkpoint(self):
    checkpoint = torch.load(self.args.resume)

    checkpoint_algorithm = tool.dict_value(checkpoint, 'algorithm', default='unknown')
    if checkpoint_algorithm != self.NAME:
        logger.log_err('Unmatched SSL algorithm format in checkpoint => required: {0} - given: {1}\n'
                       .format(self.NAME, checkpoint_algorithm))

    self.s_model.load_state_dict(checkpoint['s_model'])
    self.t_model.load_state_dict(checkpoint['t_model'])
    self.s_optimizer.load_state_dict(checkpoint['s_optimizer'])
    self.s_lrer.load_state_dict(checkpoint['s_lrer'])

    return checkpoint['epoch']
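# Counterpart sketch of the checkpoint layout that '_load_checkpoint' above
# expects. The keys mirror exactly what the loader reads; the function name
# and argument list are illustrative, not this repo's exact save routine.
def _save_checkpoint_sketch(self, epoch, path):
    checkpoint = {
        'algorithm': self.NAME,                         # validated on load
        'epoch': epoch,                                 # returned by the loader
        's_model': self.s_model.state_dict(),
        't_model': self.t_model.state_dict(),
        's_optimizer': self.s_optimizer.state_dict(),
        's_lrer': self.s_lrer.state_dict(),
    }
    torch.save(checkpoint, path)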
def create_parser(algorithm):
    parser = argparse.ArgumentParser(description='PixelSSL Static Script Parser')

    if not algorithm in ssl_algorithm.SSL_ALGORITHMS:
        logger.log_err('Unknown semi-supervised learning algorithm: {0}\n'
                       'The supported algorithms are: {1}\n'.format(
                           algorithm, ssl_algorithm.SSL_ALGORITHMS))

    optimizer.add_parser_arguments(parser)
    lrer.add_parser_arguments(parser)
    ssl_algorithm.__dict__[algorithm].add_parser_arguments(parser)

    return parser
def ssl_cct(args, model_dict, optimizer_dict, lrer_dict, criterion_dict, task_func):
    if not len(model_dict) == len(optimizer_dict) == len(lrer_dict) == len(criterion_dict) == 1:
        logger.log_err('The len(element_dict) of SSL_CCT should be 1\n')
    elif list(model_dict.keys())[0] != 'model':
        logger.log_err('In SSL_CCT, the key of element_dict should be \'model\',\n'
                       'but \'{0}\' is given\n'.format(model_dict.keys()))

    model_funcs = [model_dict['model']]
    optimizer_funcs = [optimizer_dict['model']]
    lrer_funcs = [lrer_dict['model']]
    criterion_funcs = [criterion_dict['model']]

    algorithm = SSLCCT(args)
    algorithm.build(model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs, task_func)
    return algorithm
def _load_dataset(self, dataset_name, dataset_dir, is_train=True):
    """ Load one dataset. """
    if not dataset_name in self.data.__dict__.keys():
        logger.log_err('Unknown dataset type: {0}\n'.format(dataset_name))
    elif not os.path.exists(dataset_dir):
        logger.log_err('Cannot find the path of dataset: {0}\n'.format(dataset_dir))
    else:
        dataset_args = copy.deepcopy(self.args)
        if is_train:
            dataset_args.trainset = {dataset_name: dataset_dir}
        else:
            dataset_args.valset = {dataset_name: dataset_dir}
        return self.data.__dict__[dataset_name]()(dataset_args, is_train)
def _load_checkpoint(self):
    checkpoint = torch.load(self.args.resume)

    checkpoint_algorithm = tool.dict_value(checkpoint, 'algorithm', default='unknown')
    if checkpoint_algorithm != self.NAME:
        logger.log_err('Unmatched SSL algorithm format in checkpoint => required: {0} - given: {1}\n'
                       .format(self.NAME, checkpoint_algorithm))

    self.model.load_state_dict(checkpoint['model'])
    self.optimizer.load_state_dict(checkpoint['optimizer'])
    self.lrer.load_state_dict(checkpoint['lrer'])

    self.main_model = self.model.module.main_model
    self.auxiliary_decoders = self.model.module.auxiliary_decoders

    return checkpoint['epoch']
def _load_checkpoint(self):
    checkpoint = torch.load(self.args.resume)

    checkpoint_algorithm = tool.dict_value(checkpoint, 'algorithm', default='unknown')
    if checkpoint_algorithm != self.NAME:
        logger.log_err('Unmatched SSL algorithm format in checkpoint => required: {0} - given: {1}\n'
                       .format(self.NAME, checkpoint_algorithm))

    self.model.load_state_dict(checkpoint['model'])
    self.optimizer.load_state_dict(checkpoint['optimizer'])
    self.lrer.load_state_dict(checkpoint['lrer'])

    self.task_model = self.model.module.task_model
    self.rotation_classifier = self.model.module.rotation_classifier

    return checkpoint['epoch']
def forward(self, x):
    """
    Arguments:
        x (torch.Tensor): input 4D tensor

    Returns:
        torch.Tensor: Blurred version of the input
    """
    if not len(list(x.shape)) == 4:
        logger.log_err('\'GaussianBlurLayer\' requires a 4D tensor as input\n')
    elif not x.shape[1] == self.channels:
        logger.log_err('In \'GaussianBlurLayer\', the required channel ({0}) is '
                       'not the same as input ({1})\n'.format(self.channels, x.shape[1]))

    return self.op(x)
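# Self-contained sketch of the depthwise Gaussian blur that a layer like
# 'GaussianBlurLayer' typically wraps ('self.op' above). The kernel build
# and module wiring are assumptions for illustration, not this repo's code.
import torch
import torch.nn as nn

def make_gaussian_blur(channels, kernel_size=3, sigma=1.0):
    # 1D Gaussian -> outer product -> normalized 2D kernel
    coords = torch.arange(kernel_size, dtype=torch.float32) - (kernel_size - 1) / 2.0
    g1d = torch.exp(-coords ** 2 / (2.0 * sigma ** 2))
    g2d = g1d[:, None] * g1d[None, :]
    g2d = g2d / g2d.sum()

    # depthwise conv: one fixed (non-trainable) Gaussian kernel per channel
    op = nn.Conv2d(channels, channels, kernel_size, padding=kernel_size // 2,
                   groups=channels, bias=False)
    op.weight.data.copy_(g2d.expand(channels, 1, kernel_size, kernel_size))
    op.weight.requires_grad = False
    return op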
def __init__(self, args):
    super(SSLS4L, self).__init__(args)

    self.task_model = None
    self.rotation_classifier = None

    self.model = None
    self.optimizer = None
    self.lrer = None
    self.criterion = None

    # check SSL arguments
    if self.args.rotation_scale < 0:
        logger.log_err('The argument - rotation_scale - is not set (or invalid)\n'
                       'Please set - rotation_scale >= 0 - for training\n')
    if self.args.rotated_sup_scale < 0:
        logger.log_err('The argument - rotated_sup_scale - is not set (or invalid)\n'
                       'Please set - rotated_sup_scale >= 0 - for training\n')
def __init__(self, args):
    super(SSLCCT, self).__init__(args)

    self.main_model = None
    self.auxiliary_decoders = None

    self.model = None
    self.optimizer = None
    self.lrer = None
    self.criterion = None

    self.cons_criterion = None

    # check SSL arguments
    if self.args.unlabeled_batch_size > 0:
        if self.args.cons_scale < 0:
            logger.log_err('The argument - cons_scale - is not set (or invalid)\n'
                           'You set the argument - unlabeled_batch_size - larger than 0\n'
                           'Please set - cons_scale >= 0 - for training\n')
        elif self.args.cons_rampup_epochs < 0:
            logger.log_err('The argument - cons_rampup_epochs - is not set (or invalid)\n'
                           'You set the argument - unlabeled_batch_size - larger than 0\n'
                           'Please set - cons_rampup_epochs >= 0 - for training\n')
        if self.args.ad_lr_scale < 0:
            logger.log_err('The argument - ad_lr_scale - is not set (or invalid)\n'
                           'You set the argument - unlabeled_batch_size - larger than 0\n'
                           'Please set - ad_lr_scale >= 0 - for training\n')
    else:
        self.args.ad_lr_scale = 0
def forward(self, inp, gt, is_unlabeled):
    resulter, debugger = {}, {}

    # forward the task model
    m_resulter, m_debugger = self.main_model.forward(inp)

    if not 'pred' in m_resulter.keys() or not 'activated_pred' in m_resulter.keys():
        logger.log_err('In SSL_CCT, the \'resulter\' dict returned by the task model should contain the following keys:\n'
                       '   (1) \'pred\'\t=>\tunactivated task predictions\n'
                       '   (2) \'activated_pred\'\t=>\tactivated task predictions\n'
                       'We need both of them since some losses include the activation functions,\n'
                       'e.g., the CrossEntropyLoss already contains the SoftMax\n')

    resulter['pred'] = tool.dict_value(m_resulter, 'pred')
    resulter['activated_pred'] = tool.dict_value(m_resulter, 'activated_pred')

    if not len(resulter['pred']) == len(resulter['activated_pred']) == 1:
        logger.log_err('This implementation of SSL_CCT only supports a task model with a single prediction (output).\n'
                       'However, there are {0} predictions.\n'.format(len(resulter['pred'])))

    # calculate the task loss
    resulter['task_loss'] = None if is_unlabeled else \
        torch.mean(self.task_criterion.forward(resulter['pred'], gt, inp))

    # for the unlabeled data
    if is_unlabeled and self.args.unlabeled_batch_size > 0:
        if not 'sslcct_ad_inp' in m_resulter.keys():
            logger.log_err('In SSL_CCT, the \'resulter\' dict returned by the task model should contain the key:\n'
                           '   \'sslcct_ad_inp\'\t=>\tinputs of the auxiliary decoders (a 4-dim tensor)\n'
                           'It is the feature map encoded by the task model\n'
                           'Please add the key \'sslcct_ad_inp\' in your task model\'s resulter\n'
                           'Note that for different task models, the shape of \'sslcct_ad_inp\' may be different\n')

        ul_ad_inp = tool.dict_value(m_resulter, 'sslcct_ad_inp')
        ul_main_pred = resulter['pred'][0].detach()

        # forward the auxiliary decoders
        ul_ad_preds = []
        for ad in self.auxiliary_decoders:
            ul_ad_preds.append(ad.forward(ul_ad_inp, pred_of_main_decoder=ul_main_pred))
        resulter['ul_ad_preds'] = ul_ad_preds

        # calculate the consistency loss
        ul_ad_gt = resulter['activated_pred'][0].detach()
        ul_ad_preds = [F.interpolate(ul_ad_pred, size=(ul_ad_gt.shape[2], ul_ad_gt.shape[3]), mode='bilinear')
                       for ul_ad_pred in ul_ad_preds]
        ul_activated_ad_preds = self.ad_activation_func(ul_ad_preds)

        cons_loss = sum([self.cons_criterion.forward(ul_activated_ad_pred, ul_ad_gt)
                         for ul_activated_ad_pred in ul_activated_ad_preds])
        cons_loss = torch.mean(cons_loss) / len(ul_activated_ad_preds)
        resulter['cons_loss'] = cons_loss
    else:
        resulter['ul_ad_preds'] = None
        resulter['cons_loss'] = None

    return resulter, debugger
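# Toy numeric illustration of the consistency term above: average a per-
# decoder distance between each (interpolated, activated) auxiliary
# prediction and the detached main prediction. MSE stands in here for the
# configurable 'cons_criterion'; all tensors are synthetic.
_aux_preds = [torch.full((1, 1, 2, 2), v) for v in (0.0, 0.5, 1.0)]  # 3 aux decoders
_main_pred = torch.full((1, 1, 2, 2), 0.5).detach()                  # consistency target
_cons = sum(torch.mean((p - _main_pred) ** 2) for p in _aux_preds) / len(_aux_preds)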
def ssl_gct(args, model_dict, optimizer_dict, lrer_dict, criterion_dict, task_func):
    if not len(model_dict) == len(optimizer_dict) == len(lrer_dict) == len(criterion_dict):
        logger.log_err('The len(element_dict) of SSL_GCT should be the same\n')

    if len(model_dict) == 1:
        if list(model_dict.keys())[0] != 'model':
            logger.log_err('In SSL_GCT, the key of 1-value element_dict should be \'model\',\n'
                           'but \'{0}\' is given\n'.format(model_dict.keys()))
        model_funcs = [model_dict['model'], model_dict['model']]
        optimizer_funcs = [optimizer_dict['model'], optimizer_dict['model']]
        lrer_funcs = [lrer_dict['model'], lrer_dict['model']]
        criterion_funcs = [criterion_dict['model'], criterion_dict['model']]
    elif len(model_dict) == 2:
        if 'lmodel' not in list(model_dict.keys()) or 'rmodel' not in list(model_dict.keys()):
            logger.log_err('In SSL_GCT, the keys of 2-value element_dict should be \'(lmodel, rmodel)\', '
                           'but \'{0}\' is given\n'.format(model_dict.keys()))
        model_funcs = [model_dict['lmodel'], model_dict['rmodel']]
        optimizer_funcs = [optimizer_dict['lmodel'], optimizer_dict['rmodel']]
        lrer_funcs = [lrer_dict['lmodel'], lrer_dict['rmodel']]
        criterion_funcs = [criterion_dict['lmodel'], criterion_dict['rmodel']]
    else:
        logger.log_err('The SSL_GCT algorithm supports element_dict with 1 or 2 elements, '
                       'but {0} elements are given\n'.format(len(model_dict)))

    algorithm = SSLGCT(args)
    algorithm.build(model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs, task_func)
    return algorithm
def _preprocess_arguments(self):
    """ Preprocess the arguments in the script. """
    # create the output folder to store the results
    self.args.out_path = '{root}/{exp_id}/{date:%Y-%m-%d_%H:%M:%S}/'.format(
        root=self.args.out_path, exp_id=self.args.exp_id, date=datetime.now())
    if not os.path.exists(self.args.out_path):
        os.makedirs(self.args.out_path)

    # prepare the logger
    exp_op = 'val' if self.args.validation else 'train'
    logger.log_mode(self.args.debug)
    logger.log_file(os.path.join(self.args.out_path, '{0}.log'.format(exp_op)), self.args.debug)
    logger.log_info('Result folder: \n  {0} \n'.format(self.args.out_path))

    # print the experimental args
    cmd.print_args()

    # set the task name
    self.args.task = self.NAME

    # check the task-specific components dicts required by the SSL algorithm
    if not len(self.args.models) == len(self.args.optimizers) == len(self.args.lrers) == len(self.args.criterions):
        logger.log_err('Condition:\n'
                       '\tlen(self.args.models) == len(self.args.optimizers) == len(self.args.lrers) == len(self.args.criterions)\n'
                       'is not satisfied in the script\n')

    for (model, criterion, optimizer, lrer) in \
            zip(self.args.models.values(), self.args.criterions.values(),
                self.args.optimizers.values(), self.args.lrers.values()):
        if model not in self.model.__dict__:
            logger.log_err('Unsupported model: {0} for task: {1}\n'
                           'Please add the export function in the task\'s \'model.py\'\n'.format(model, self.args.task))
        elif criterion not in self.criterion.__dict__:
            logger.log_err('Unsupported criterion: {0} for task: {1}\n'
                           'Please add the export function in the task\'s \'criterion.py\'\n'.format(criterion, self.args.task))
        elif optimizer not in nnoptimizer.__dict__:
            logger.log_err('Unsupported optimizer: {0}\n'
                           'Please implement the optimizer wrapper in \'pixelssl/nn/optimizer.py\'\n'.format(optimizer))
        elif lrer not in nnlrer.__dict__:
            logger.log_err('Unsupported learning rate scheduler: {0}\n'
                           'Please implement the lr scheduler wrapper in \'pixelssl/nn/lrer.py\'\n'.format(lrer))

    # check the types of the lrers
    for lrer in self.args.lrers.values():
        if lrer in nnlrer.EPOCH_LRERS:
            is_epoch_lrer = True
        elif lrer in nnlrer.ITER_LRERS:
            is_epoch_lrer = False
        else:
            logger.log_err('Unknown learning rate scheduler ({0}) type\n'
                           'Please add it into either EPOCH_LRERS or ITER_LRERS in \'pixelssl/nn/lrer.py\'\n'
                           'PixelSSL supports: \n'
                           '  EPOCH_LRERS\t=>\t{1}\n  ITER_LRERS\t=>\t{2}\n'.format(
                               lrer, nnlrer.EPOCH_LRERS, nnlrer.ITER_LRERS))

        if self.args.is_epoch_lrer is None:
            self.args.is_epoch_lrer = is_epoch_lrer
        elif self.args.is_epoch_lrer != is_epoch_lrer:
            logger.log_err('Unmatched lr scheduler types\t=>\t{0}\n'
                           'All lrers of the task models should have the same type (either EPOCH_LRERS or ITER_LRERS)\n'
                           'PixelSSL supports: \n'
                           '  EPOCH_LRERS\t=>\t{1}\n  ITER_LRERS\t=>\t{2}\n'.format(
                               self.args.lrers, nnlrer.EPOCH_LRERS, nnlrer.ITER_LRERS))

    self.args.checkpoint_path = os.path.join(self.args.out_path, 'ckpt')
    if not os.path.exists(self.args.checkpoint_path):
        os.makedirs(self.args.checkpoint_path)

    if self.args.visualize:
        self.args.visual_debug_path = os.path.join(self.args.out_path, 'visualization/debug')
        self.args.visual_train_path = os.path.join(self.args.out_path, 'visualization/train')
        self.args.visual_val_path = os.path.join(self.args.out_path, 'visualization/val')
        for vpath in [self.args.visual_debug_path, self.args.visual_train_path, self.args.visual_val_path]:
            if not os.path.exists(vpath):
                os.makedirs(vpath)

    # handle the arguments for multi-GPU training
    self.args.gpus = torch.cuda.device_count()
    if self.args.gpus < 1:
        logger.log_err('No GPU detected\n'
                       'PixelSSL requires at least one Nvidia GPU\n')
    logger.log_info('GPU: \n  Total GPU(s): {0}'.format(self.args.gpus))

    self.args.lr *= self.args.gpus
    self.args.num_workers *= self.args.gpus
    self.args.batch_size *= self.args.gpus
    self.args.unlabeled_batch_size *= self.args.gpus

    # TODO: support unsupervised and self-supervised training
    if self.args.unlabeled_batch_size >= self.args.batch_size:
        logger.log_err('The argument \'unlabeled_batch_size\' ({0}) should be smaller than \'batch_size\' ({1}) '
                       'since PixelSSL only supports semi-supervised learning now\n'.format(
                           self.args.unlabeled_batch_size, self.args.batch_size))
    self.args.labeled_batch_size = self.args.batch_size - self.args.unlabeled_batch_size

    logger.log_info('  Total learning rate: {0}\n  Total labeled batch size: {1}\n'
                    '  Total unlabeled batch size: {2}\n  Total data workers: {3}\n'.format(
                        self.args.lr, self.args.labeled_batch_size,
                        self.args.unlabeled_batch_size, self.args.num_workers))
def _task_model_iter(self, epoch, idx, is_train, mid, lbs, inp, gt, dc_gt, fc_mask, dc_rampup_scale):
    if mid == 'l':
        model, criterion = self.l_model, self.l_criterion
    elif mid == 'r':
        model, criterion = self.r_model, self.r_criterion
    else:
        model, criterion = None, None

    # forward the task model
    resulter, debugger = model.forward(inp)
    if not 'pred' in resulter.keys() or not 'activated_pred' in resulter.keys():
        self._pred_err()

    pred = tool.dict_value(resulter, 'pred')
    activated_pred = tool.dict_value(resulter, 'activated_pred')

    fd_resulter, fd_debugger = self.fd_model.forward(inp, activated_pred[0])
    flawmap = tool.dict_value(fd_resulter, 'flawmap')

    # calculate the supervised task constraint on the labeled data
    labeled_pred = func.split_tensor_tuple(pred, 0, lbs)
    labeled_gt = func.split_tensor_tuple(gt, 0, lbs)
    labeled_inp = func.split_tensor_tuple(inp, 0, lbs)
    task_loss = torch.mean(criterion.forward(labeled_pred, labeled_gt, labeled_inp))
    self.meters.update('{0}_task_loss'.format(mid), task_loss.data)

    # calculate the flaw correction constraint
    if self.args.ssl_mode in [MODE_GCT, MODE_FC]:
        if flawmap.shape == self.zero_df_gt.shape:
            fc_ssl_loss = self.fd_criterion.forward(flawmap, self.zero_df_gt, is_ssl=True, reduction=False)
        else:
            fc_ssl_loss = self.fd_criterion.forward(flawmap, torch.zeros(flawmap.shape).cuda(),
                                                    is_ssl=True, reduction=False)

        if self.args.ssl_mode == MODE_GCT:
            fc_ssl_loss = fc_mask * fc_ssl_loss

        fc_ssl_loss = self.args.fc_ssl_scale * torch.mean(fc_ssl_loss)
        self.meters.update('{0}_fc_loss'.format(mid), fc_ssl_loss.data)
    else:
        fc_ssl_loss = 0
        self.meters.update('{0}_fc_loss'.format(mid), fc_ssl_loss)

    # calculate the dynamic consistency constraint
    if self.args.ssl_mode in [MODE_GCT, MODE_DC]:
        if dc_gt is None:
            logger.log_err('The dynamic consistency constraint is enabled, '
                           'but no pseudo ground truth is given.')

        dc_ssl_loss = self.dc_criterion.forward(activated_pred[0], dc_gt)
        dc_ssl_loss = dc_rampup_scale * self.args.dc_ssl_scale * torch.mean(dc_ssl_loss)
        self.meters.update('{0}_dc_loss'.format(mid), dc_ssl_loss.data)
    else:
        dc_ssl_loss = 0
        self.meters.update('{0}_dc_loss'.format(mid), dc_ssl_loss)

    with torch.no_grad():
        flawmap_gt = self.fdgt_generator.forward(
            activated_pred[0], self.task_func.sslgct_prepare_task_gt_for_fdgt(gt[0]))

    # for validation
    if not is_train:
        fd_loss = self.args.fd_scale * self.fd_criterion.forward(flawmap, flawmap_gt)
        self.meters.update('{0}_fd_loss'.format(mid), torch.mean(fd_loss).data)
        self.task_func.metrics(activated_pred, gt, inp, self.meters, id_str=mid)

    # visualization
    if self.args.visualize and idx % self.args.visual_freq == 0:
        with torch.no_grad():
            handled_flawmap = self.flawmap_handler(flawmap)[0]

        self._visualize(epoch, idx, is_train, mid,
                        func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                        func.split_tensor_tuple(activated_pred, 0, 1, reduce_dim=True),
                        func.split_tensor_tuple(gt, 0, 1, reduce_dim=True),
                        handled_flawmap, flawmap_gt[0], dc_gt[0])

    loss = task_loss + fc_ssl_loss + dc_ssl_loss
    return loss
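# Tiny numeric illustration of the flaw-correction term above: an element-
# wise loss map is masked (MODE_GCT only) and then averaged, mirroring
# 'fc_ssl_loss = fc_mask * fc_ssl_loss; ... torch.mean(fc_ssl_loss)'.
# All values are synthetic; 0.5 stands in for 'args.fc_ssl_scale'.
_loss_map = torch.tensor([[0.2, 0.8], [0.4, 0.6]])
_fc_mask = torch.tensor([[1.0, 0.0], [1.0, 1.0]])   # 0 drops unreliable pixels
_fc_loss = 0.5 * torch.mean(_fc_mask * _loss_map)
# note: masked-out entries still contribute to the mean's denominator,
# matching the behavior of the code above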
def _create_dataloader(self):
    """ Create the data loaders for the experiment. """
    # ---------------------------------------------------------------------
    # create the dataloader for training
    # ---------------------------------------------------------------------
    # ignore_unlabeled == False & unlabeled_batch_size != 0
    # means that both labeled and unlabeled data are used
    with_unlabeled_data = not self.args.ignore_unlabeled and self.args.unlabeled_batch_size != 0
    # ignore_unlabeled == True & unlabeled_batch_size == 0
    # means that only the labeled data is used
    without_unlabeled_data = self.args.ignore_unlabeled and self.args.unlabeled_batch_size == 0

    labeled_train_samples, unlabeled_train_samples = 0, 0
    if not self.args.validation:
        # ignore_unlabeled == True & unlabeled_batch_size != 0 -> error
        if self.args.ignore_unlabeled and self.args.unlabeled_batch_size != 0:
            logger.log_err('Arguments conflict => ignore_unlabeled == True requires unlabeled_batch_size == 0\n')
        # ignore_unlabeled == False & unlabeled_batch_size == 0 -> error
        if not self.args.ignore_unlabeled and self.args.unlabeled_batch_size == 0:
            logger.log_err('Arguments conflict => ignore_unlabeled == False requires unlabeled_batch_size != 0\n')

        # calculate the number of trainsets
        trainset_num = 0
        for key, value in self.args.trainset.items():
            trainset_num += len(value)

        # calculate the number of unlabeledsets
        unlabeledset_num = 0
        for key, value in self.args.unlabeledset.items():
            unlabeledset_num += len(value)

        # if only one labeled training set is given, without any unlabeled set
        if trainset_num == 1 and unlabeledset_num == 0:
            trainset = self._load_dataset(list(self.args.trainset.keys())[0],
                                          list(self.args.trainset.values())[0][0])
            labeled_train_samples = len(trainset.idxs)

            # if the 'sublabeled_path' is given
            sublabeled_prefix = None
            if self.args.sublabeled_path is not None and self.args.sublabeled_path != '':
                if not os.path.exists(self.args.sublabeled_path):
                    logger.log_err('Cannot find the labeled file: {0}\n'.format(self.args.sublabeled_path))
                else:
                    with open(self.args.sublabeled_path) as f:
                        sublabeled_prefix = [line.strip() for line in f.read().splitlines()]
                    sublabeled_prefix = None if len(sublabeled_prefix) == 0 else sublabeled_prefix

            if sublabeled_prefix is not None:
                # wrap the trainset by 'SplitUnlabeledWrapper'
                trainset = nndata.SplitUnlabeledWrapper(trainset, sublabeled_prefix,
                                                        ignore_unlabeled=self.args.ignore_unlabeled)
                labeled_train_samples = len(trainset.labeled_idxs)
                unlabeled_train_samples = len(trainset.unlabeled_idxs)
            # if 'sublabeled_prefix' is None but you want to use the unlabeled data for training
            elif with_unlabeled_data:
                logger.log_err('Try to use the unlabeled samples without any SSL dataset wrapper\n')

        # if more than one labeled training set or any unlabeled training set is given
        elif trainset_num > 1 or unlabeledset_num > 0:
            # 'args.sublabeled_path' is disabled
            if self.args.sublabeled_path is not None and self.args.sublabeled_path != '':
                logger.log_err('Multiple training datasets are given.\n'
                               'Inter-split unlabeled set is not allowed.\n'
                               'Please remove the argument \'sublabeled_path\' in the script\n')

            # load all labeled training sets
            labeled_sets = []
            for set_name, set_dirs in self.args.trainset.items():
                for set_dir in set_dirs:
                    labeled_sets.append(self._load_dataset(set_name, set_dir))

            # load all extra unlabeled sets
            unlabeled_sets = []
            # if any extra unlabeled set is given
            if unlabeledset_num > 0:
                for set_name, set_dirs in self.args.unlabeledset.items():
                    for set_dir in set_dirs:
                        unlabeled_sets.append(self._load_dataset(set_name, set_dir))
            # if unlabeledset_num == 0 but you want to use the unlabeled data for training
            elif with_unlabeled_data:
                logger.log_err('Try to use the unlabeled samples without any SSL dataset wrapper\n'
                               'Please add the argument \'unlabeledset\' in the script\n')

            # wrap both 'labeled_sets' and 'unlabeled_sets' by 'JointDatasetsWrapper'
            trainset = nndata.JointDatasetsWrapper(labeled_sets, unlabeled_sets,
                                                   ignore_unlabeled=self.args.ignore_unlabeled)
            labeled_train_samples = len(trainset.labeled_idxs)
            unlabeled_train_samples = len(trainset.unlabeled_idxs)

        # if only the labeled data is used
        if without_unlabeled_data:
            self.train_loader = torch.utils.data.DataLoader(
                trainset, batch_size=self.args.batch_size, shuffle=True,
                num_workers=self.args.num_workers, pin_memory=True, drop_last=True)
        # if both the labeled and the unlabeled data are used
        elif with_unlabeled_data:
            train_sampler = nndata.TwoStreamBatchSampler(
                trainset.labeled_idxs, trainset.unlabeled_idxs,
                self.args.labeled_batch_size, self.args.unlabeled_batch_size)
            self.train_loader = torch.utils.data.DataLoader(
                trainset, batch_sampler=train_sampler,
                num_workers=self.args.num_workers, pin_memory=True)

    # ---------------------------------------------------------------------
    # create the dataloader for validation
    # ---------------------------------------------------------------------
    # calculate the number of valsets
    valset_num = 0
    for key, value in self.args.valset.items():
        valset_num += len(value)

    val_samples = 0
    # if only one validation set is given
    if valset_num == 1:
        valset = self._load_dataset(list(self.args.valset.keys())[0],
                                    list(self.args.valset.values())[0][0], is_train=False)
        val_samples = len(valset.idxs)
    # if more than one validation set is given
    elif valset_num > 1:
        valsets = []
        for set_name, set_dirs in self.args.valset.items():
            for set_dir in set_dirs:
                valsets.append(self._load_dataset(set_name, set_dir, is_train=False))
        valset = nndata.JointDatasetsWrapper(valsets, [], ignore_unlabeled=True)
        val_samples = len(valset.labeled_idxs)

    # NOTE: the batch size is set to 1 during the validation
    self.val_loader = torch.utils.data.DataLoader(
        valset, batch_size=1, shuffle=False,
        num_workers=self.args.num_workers, pin_memory=True)

    # check the data loaders
    if self.train_loader is None and not self.args.validation:
        logger.log_err('A train data loader is required if the validation mode is disabled\n')
    elif self.val_loader is None and self.args.validation:
        logger.log_err('A validation data loader is required if the validation mode is enabled\n')
    elif self.val_loader is None:
        logger.log_warn('No validation data loader, there will be no validation during the training\n')

    # set 'iters_per_epoch', which is required by ITER_LRERS
    self.args.iters_per_epoch = len(self.train_loader) if self.train_loader is not None else -1

    logger.log_info('Dataset:\n'
                    '  Trainset\t=>\tlabeled samples = {0}\t\tunlabeled samples = {1}\n'
                    '  Valset\t=>\tsamples = {2}\n'.format(
                        labeled_train_samples, unlabeled_train_samples, val_samples))
def __init__(self, args):
    super(SSLGCT, self).__init__(args)

    # define the task models and the flaw detector
    self.l_model, self.r_model, self.fd_model = None, None, None
    self.l_optimizer, self.r_optimizer, self.fd_optimizer = None, None, None
    self.l_lrer, self.r_lrer, self.fd_lrer = None, None, None
    self.l_criterion, self.r_criterion, self.fd_criterion = None, None, None

    # define the extra modules required by GCT
    self.flawmap_handler = None
    self.dcgt_generator = None
    self.fdgt_generator = None
    self.zero_df_gt = torch.zeros(
        [self.args.batch_size, 1, self.args.im_size, self.args.im_size]).cuda()

    # prepare the arguments for multiple GPUs
    self.args.fd_lr *= self.args.gpus

    # check SSL arguments
    if self.args.unlabeled_batch_size > 0:
        if self.args.ssl_mode in [MODE_GCT, MODE_FC]:
            if self.args.fc_ssl_scale < 0:
                logger.log_err('The argument - fc_ssl_scale - is not set (or invalid)\n'
                               'You enabled the flaw correction constraint\n'
                               'Please set - fc_ssl_scale >= 0 - for training\n')
        if self.args.ssl_mode in [MODE_GCT, MODE_DC]:
            if self.args.dc_rampup_epochs < 0:
                logger.log_err('The argument - dc_rampup_epochs - is not set (or invalid)\n'
                               'You enabled the dynamic consistency constraint\n'
                               'Please set - dc_rampup_epochs >= 0 - for training\n')
            elif self.args.dc_ssl_scale < 0:
                logger.log_err('The argument - dc_ssl_scale - is not set (or invalid)\n'
                               'You enabled the dynamic consistency constraint\n'
                               'Please set - dc_ssl_scale >= 0 - for training\n')
            elif self.args.dc_threshold < 0:
                logger.log_err('The argument - dc_threshold - is not set (or invalid)\n'
                               'You enabled the dynamic consistency constraint\n'
                               'Please set - dc_threshold >= 0 - for training\n')
            elif self.args.mu < 0:
                logger.log_err('The argument - mu - is not set (or invalid)\n'
                               'Please set - 0 < mu <= 1 - for training\n')
            elif self.args.nu < 0:
                logger.log_err('The argument - nu - is not set (or invalid)\n'
                               'Please set - nu > 0 - for training\n')
def _train(self, data_loader, epoch):
    # SSL_NULL requires that the unlabeled data is disabled
    without_unlabeled_data = self.args.ignore_unlabeled and self.args.unlabeled_batch_size == 0
    if not without_unlabeled_data:
        logger.log_err('SSL_NULL is a supervised-only algorithm\n'
                       'Please set ignore_unlabeled = True and unlabeled_batch_size = 0\n')

    self.meters.reset()
    lbs = self.args.labeled_batch_size

    self.model.train()

    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        # both 'inp' and 'gt' are tuples
        inp, gt = self._batch_prehandle(inp, gt)
        if len(gt) > 1 and idx == 0:
            self._inp_warn()

        self.optimizer.zero_grad()

        # forward the task model
        resulter, debugger = self.model.forward(inp)
        if not 'pred' in resulter.keys() or not 'activated_pred' in resulter.keys():
            self._pred_err()

        pred = tool.dict_value(resulter, 'pred')
        activated_pred = tool.dict_value(resulter, 'activated_pred')

        # calculate the supervised task constraint on the labeled data
        l_pred = func.split_tensor_tuple(pred, 0, lbs)
        l_gt = func.split_tensor_tuple(gt, 0, lbs)
        l_inp = func.split_tensor_tuple(inp, 0, lbs)

        # 'task_loss' is a 1-dim tensor with n elements, where n == batch_size
        task_loss = self.criterion.forward(l_pred, l_gt, l_inp)
        task_loss = torch.mean(task_loss)
        self.meters.update('task_loss', task_loss.data)

        # backward and update the task model
        loss = task_loss
        loss.backward()
        self.optimizer.step()

        # logging
        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info('step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                            '  task-{3}\t=>\t'
                            'task-loss: {meters[task_loss]:.6f}\t'.format(
                                epoch, idx, len(data_loader), self.args.task, meters=self.meters))

        # visualization
        if self.args.visualize and idx % self.args.visual_freq == 0:
            self._visualize(epoch, idx, True,
                            func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                            func.split_tensor_tuple(activated_pred, 0, 1, reduce_dim=True),
                            func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

        # update the iteration-based lrers
        if not self.args.is_epoch_lrer:
            self.lrer.step()

    # update the epoch-based lrers
    if self.args.is_epoch_lrer:
        self.lrer.step()