def _validate(self, data_loader, epoch):
    self.meters.reset()
    self.model.eval()

    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        inp, gt = self._batch_prehandle(inp, gt)
        if len(gt) > 1 and idx == 0:
            self._data_err()

        resulter, debugger = self.model.forward(inp, gt, False)
        pred = tool.dict_value(resulter, 'pred')
        activated_pred = tool.dict_value(resulter, 'activated_pred')

        task_loss = tool.dict_value(resulter, 'task_loss', err=True)
        task_loss = task_loss.mean()
        self.meters.update('task_loss', task_loss.data)

        self.task_func.metrics(activated_pred, gt, inp, self.meters, id_str='task')

        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                ' task-{3}\t=>\t'
                'task-loss: {meters[task_loss]:.6f}\t'.format(
                    epoch + 1, idx, len(data_loader), self.args.task, meters=self.meters))

        # visualization
        if self.args.visualize and idx % self.args.visual_freq == 0:
            self._visualize(
                epoch, idx, False,
                func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

    # metrics
    metrics_info = {'task': ''}
    for key in sorted(list(self.meters.keys())):
        if self.task_func.METRIC_STR in key:
            for id_str in metrics_info.keys():
                if key.startswith(id_str):
                    metrics_info[id_str] += '{0}: {1:.6f}\t'.format(key, self.meters[key])

    logger.log_info('Validation metrics:\n task-metrics\t=>\t{0}\n'.format(
        metrics_info['task'].replace('_', '-')))
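# A minimal sketch of the 'tool.dict_value' helper used throughout this file
# (an assumption inferred from the call sites; the real helper in the library
# may log the error via 'logger.log_err' instead of raising):
def dict_value_sketch(d, key, err=False, default=None):
    # fetch 'key' from the result dict; complain only when 'err' is set
    if key not in d:
        if err:
            raise KeyError('required key \'{0}\' is missing from the result dict'.format(key))
        return default
    return d[key]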
def _build_ssl_algorithm(self):
    """ Build the semi-supervised learning algorithm given in the script.
    """
    for cname in self.args.models.keys():
        self.model_dict[cname] = self.model.__dict__[self.args.models[cname]]()
        self.criterion_dict[cname] = self.criterion.__dict__[self.args.criterions[cname]]()
        self.lrer_dict[cname] = nnlrer.__dict__[self.args.lrers[cname]](self.args)
        self.optimizer_dict[cname] = nnoptimizer.__dict__[self.args.optimizers[cname]](self.args)

    logger.log_info('SSL algorithm: \n {0}\n'.format(self.args.ssl_algorithm))
    logger.log_info('Models: ')

    self.ssl_algorithm = pixelssl.ssl_algorithm.__dict__[
        self.args.ssl_algorithm].__dict__[self.args.ssl_algorithm](
            self.args, self.model_dict, self.optimizer_dict,
            self.lrer_dict, self.criterion_dict, self.func.task_func()(self.args))

    # check whether the SSL algorithm supports the given task
    if self.TASK_TYPE not in self.ssl_algorithm.SUPPORTED_TASK_TYPES:
        logger.log_err(
            'SSL algorithm - {0} - supports task types {1}\n'
            'However, the given task - {2} - belongs to {3}\n'.format(
                self.ssl_algorithm.NAME, self.ssl_algorithm.SUPPORTED_TASK_TYPES,
                self.args.task, self.TASK_TYPE))
def create_model(mclass, mname, **kwargs):
    """ Create a nn.Module and set it up on multiple GPUs.
    """
    model = mclass(**kwargs)
    model = torch.nn.DataParallel(model)
    model = model.cuda()

    logger.log_info(' ' + '=' * 76 + '\n {0} parameters \n{1}'.format(mname, model_str(model)))
    return model
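# Example usage of 'create_model' (a minimal, self-contained sketch; the
# one-layer 'ToyModel' below is a hypothetical stand-in for a real exported
# task model):
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self, in_channels=3, out_channels=1):
        super(ToyModel, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

# wraps the module in DataParallel and moves it to the GPU(s);
# access the raw module via '.module' when you need its parameter groups
# toy = create_model(ToyModel, 'toy_model', in_channels=3, out_channels=1)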
def _build(self, model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs, task_func):
    self.task_func = task_func

    # create models
    self.task_model = func.create_model(model_funcs[0], 'task_model', args=self.args).module
    self.rotation_classifier = RotationClassifer(self.task_func.ssls4l_rc_in_channels())

    # wrap 'self.task_model' and 'self.rotation_classifier' into a single model
    self.model = WrappedS4LModel(self.args, self.task_model, self.rotation_classifier)
    self.model = nn.DataParallel(self.model).cuda()

    # call 'patch_replication_callback' to use the `sync_batchnorm` layer
    patch_replication_callback(self.model)
    self.models = {'model': self.model}

    # create optimizers
    self.optimizer = optimizer_funcs[0](self.model.module.param_groups)
    self.optimizers = {'optimizer': self.optimizer}

    # create lrers
    self.lrer = lrer_funcs[0](self.optimizer)
    self.lrers = {'lrer': self.lrer}

    # create criterions
    self.criterion = criterion_funcs[0](self.args)
    self.rotation_criterion = nn.CrossEntropyLoss()
    self.criterions = {
        'criterion': self.criterion,
        'rotation_criterion': self.rotation_criterion,
    }

    # the batch size is doubled in S4L since it creates an extra rotated sample for each sample
    self.args.batch_size *= 2
    self.args.labeled_batch_size *= 2
    self.args.unlabeled_batch_size *= 2

    logger.log_info('In SSL_S4L algorithm, batch sizes are doubled: \n'
                    ' Total labeled batch size: {0}\n'
                    ' Total unlabeled batch size: {1}\n'.format(
                        self.args.labeled_batch_size, self.args.unlabeled_batch_size))

    self._algorithm_warn()
def _train(self, data_loader, epoch):
    # disable unlabeled data
    without_unlabeled_data = self.args.ignore_unlabeled and self.args.unlabeled_batch_size == 0
    if not without_unlabeled_data:
        logger.log_err(
            'SSL_NULL is a supervised-only algorithm\n'
            'Please set ignore_unlabeled = True and unlabeled_batch_size = 0\n')

    self.meters.reset()
    lbs = self.args.labeled_batch_size

    self.model.train()

    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        # both 'inp' and 'gt' are tuples
        inp, gt = self._batch_prehandle(inp, gt)
        if len(gt) > 1 and idx == 0:
            self._inp_warn()

        self.optimizer.zero_grad()

        # forward the task model
        resulter, debugger = self.model.forward(inp)
        if 'pred' not in resulter.keys() or 'activated_pred' not in resulter.keys():
            self._pred_err()

        pred = tool.dict_value(resulter, 'pred')
        activated_pred = tool.dict_value(resulter, 'activated_pred')

        # calculate the supervised task constraint on the labeled data
        l_pred = func.split_tensor_tuple(pred, 0, lbs)
        l_gt = func.split_tensor_tuple(gt, 0, lbs)
        l_inp = func.split_tensor_tuple(inp, 0, lbs)

        # 'task_loss' is a 1-dim tensor with n elements, where n == batch_size
        task_loss = self.criterion.forward(l_pred, l_gt, l_inp)
        task_loss = torch.mean(task_loss)
        self.meters.update('task_loss', task_loss.data)

        # backward and update the task model
        loss = task_loss
        loss.backward()
        self.optimizer.step()

        # logging
        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                ' task-{3}\t=>\t'
                'task-loss: {meters[task_loss]:.6f}\t'.format(
                    epoch, idx, len(data_loader), self.args.task, meters=self.meters))

        # visualization
        if self.args.visualize and idx % self.args.visual_freq == 0:
            self._visualize(
                epoch, idx, True,
                func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

        # update iteration-based lrers
        if not self.args.is_epoch_lrer:
            self.lrer.step()

    # update epoch-based lrers
    if self.args.is_epoch_lrer:
        self.lrer.step()
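# A minimal sketch of 'func.split_tensor_tuple' as it is used above
# (an assumption inferred from the call sites, not the library's actual code):
# it slices every tensor in a tuple along the batch dimension.
def split_tensor_tuple_sketch(tensor_tuple, start, end, reduce_dim=False):
    sliced = []
    for t in tensor_tuple:
        s = t[start:end, ...]
        if reduce_dim:
            # drop the batch dimension when a single sample is requested
            s = s.squeeze(0)
        sliced.append(s)
    return tuple(sliced)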
def _train(self, data_loader, epoch):
    self.meters.reset()
    lbs = self.args.labeled_batch_size

    self.model.train()

    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        inp, gt = self._batch_prehandle(inp, gt)
        if len(gt) > 1 and idx == 0:
            self._data_err()

        # TODO: support more ramp-up functions
        # calculate the ramp-up coefficient of the consistency constraint
        cur_step = len(data_loader) * epoch + idx
        total_steps = len(data_loader) * self.args.cons_rampup_epochs
        cons_rampup_scale = func.sigmoid_rampup(cur_step, total_steps)

        self.optimizer.zero_grad()

        # -----------------------------------------------------------
        # For Labeled Data
        # -----------------------------------------------------------
        l_gt = func.split_tensor_tuple(gt, 0, lbs)
        l_inp = func.split_tensor_tuple(inp, 0, lbs)

        # forward the wrapped CCT model
        resulter, debugger = self.model.forward(l_inp, l_gt, False)
        l_pred = tool.dict_value(resulter, 'pred')
        l_activated_pred = tool.dict_value(resulter, 'activated_pred')

        task_loss = tool.dict_value(resulter, 'task_loss', err=True)
        task_loss = task_loss.mean()
        self.meters.update('task_loss', task_loss.data)

        # -----------------------------------------------------------
        # For Unlabeled Data
        # -----------------------------------------------------------
        if self.args.unlabeled_batch_size > 0:
            ul_gt = func.split_tensor_tuple(gt, lbs, self.args.batch_size)
            ul_inp = func.split_tensor_tuple(inp, lbs, self.args.batch_size)

            # forward the wrapped CCT model
            resulter, debugger = self.model.forward(ul_inp, ul_gt, True)
            ul_pred = tool.dict_value(resulter, 'pred')
            ul_activated_pred = tool.dict_value(resulter, 'activated_pred')
            ul_ad_preds = tool.dict_value(resulter, 'ul_ad_preds')

            cons_loss = tool.dict_value(resulter, 'cons_loss', err=True)
            cons_loss = cons_loss.mean()
            cons_loss = cons_rampup_scale * self.args.cons_scale * cons_loss
            self.meters.update('cons_loss', cons_loss.data)
        else:
            cons_loss = 0
            self.meters.update('cons_loss', cons_loss)

        # backward and update the wrapped CCT model
        loss = task_loss + cons_loss
        loss.backward()
        self.optimizer.step()

        # logging
        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                ' task-{3}\t=>\t'
                'task-loss: {meters[task_loss]:.6f}\t'
                'cons-loss: {meters[cons_loss]:.6f}\n'.format(
                    epoch, idx, len(data_loader), self.args.task, meters=self.meters))

        # visualization
        if self.args.visualize and idx % self.args.visual_freq == 0:
            self._visualize(
                epoch, idx, True,
                func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(l_activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

        # update iteration-based lrers
        if not self.args.is_epoch_lrer:
            self.lrer.step()

    # update epoch-based lrers
    if self.args.is_epoch_lrer:
        self.lrer.step()
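# A minimal sketch of the sigmoid ramp-up used above. Assumption: it follows
# the common schedule exp(-5 * (1 - t)^2) from Laine & Aila (2017); the real
# 'func.sigmoid_rampup' in the library may differ in details.
import math

def sigmoid_rampup_sketch(cur_step, total_steps):
    if total_steps == 0:
        return 1.0
    # clamp the progress t to [0, 1], then apply the sigmoid-shaped schedule
    t = max(0.0, min(1.0, float(cur_step) / total_steps))
    return math.exp(-5.0 * (1.0 - t) ** 2)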
def _validate(self, data_loader, epoch):
    self.meters.reset()

    self.s_model.eval()
    self.t_model.eval()

    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        inp, gt, _, _ = self._batch_prehandle(inp, gt, False)
        if len(inp) > 1 and idx == 0:
            self._inp_warn()
        if len(gt) > 1 and idx == 0:
            self._gt_warn()

        s_resulter, s_debugger = self.s_model.forward(inp)
        if 'pred' not in s_resulter.keys() or 'activated_pred' not in s_resulter.keys():
            self._pred_err()

        s_pred = tool.dict_value(s_resulter, 'pred')
        s_activated_pred = tool.dict_value(s_resulter, 'activated_pred')

        s_task_loss = self.s_criterion.forward(s_pred, gt, inp)
        s_task_loss = torch.mean(s_task_loss)
        self.meters.update('s_task_loss', s_task_loss.data)

        t_resulter, t_debugger = self.t_model.forward(inp)
        if 'pred' not in t_resulter.keys() or 'activated_pred' not in t_resulter.keys():
            self._pred_err()

        t_pred = tool.dict_value(t_resulter, 'pred')
        t_activated_pred = tool.dict_value(t_resulter, 'activated_pred')

        t_task_loss = self.s_criterion.forward(t_pred, gt, inp)
        t_task_loss = torch.mean(t_task_loss)
        self.meters.update('t_task_loss', t_task_loss.data)

        t_pseudo_gt = []
        for tap in t_activated_pred:
            t_pseudo_gt.append(tap.detach())
        t_pseudo_gt = tuple(t_pseudo_gt)

        cons_loss = 0
        for sap, tpg in zip(s_activated_pred, t_pseudo_gt):
            cons_loss += torch.mean(self.cons_criterion(sap, tpg))
        cons_loss = self.args.cons_scale * torch.mean(cons_loss)
        self.meters.update('cons_loss', cons_loss.data)

        self.task_func.metrics(s_activated_pred, gt, inp, self.meters, id_str='student')
        self.task_func.metrics(t_activated_pred, gt, inp, self.meters, id_str='teacher')

        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                ' student-{3}\t=>\t'
                's-task-loss: {meters[s_task_loss]:.6f}\t'
                's-cons-loss: {meters[cons_loss]:.6f}\n'
                ' teacher-{3}\t=>\t'
                't-task-loss: {meters[t_task_loss]:.6f}\n'.format(
                    epoch + 1, idx, len(data_loader), self.args.task, meters=self.meters))

        # visualization
        if self.args.visualize and idx % self.args.visual_freq == 0:
            self._visualize(
                epoch, idx, False,
                func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(s_activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

    # metrics
    metrics_info = {'student': '', 'teacher': ''}
    for key in sorted(list(self.meters.keys())):
        if self.task_func.METRIC_STR in key:
            for id_str in metrics_info.keys():
                if key.startswith(id_str):
                    metrics_info[id_str] += '{0}: {1:.6f}\t'.format(key, self.meters[key])

    logger.log_info(
        'Validation metrics:\n student-metrics\t=>\t{0}\n teacher-metrics\t=>\t{1}\n'.format(
            metrics_info['student'].replace('_', '-'), metrics_info['teacher'].replace('_', '-')))
def _train(self, data_loader, epoch):
    self.meters.reset()
    lbs = self.args.labeled_batch_size
    ubs = self.args.unlabeled_batch_size

    self.s_model.train()
    self.t_model.train()

    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        # 'inp' and 'gt' are tuples
        inp, gt, mix_u_inp, mix_u_mask = self._batch_prehandle(inp, gt, True)
        if len(inp) > 1 and idx == 0:
            self._inp_warn()
        if len(gt) > 1 and idx == 0:
            self._gt_warn()

        # calculate the ramp-up coefficient of the consistency constraint
        cur_step = len(data_loader) * epoch + idx
        total_steps = len(data_loader) * self.args.cons_rampup_epochs
        cons_rampup_scale = func.sigmoid_rampup(cur_step, total_steps)

        self.s_optimizer.zero_grad()

        # -------------------------------------------------
        # For Labeled Samples
        # -------------------------------------------------
        l_inp = func.split_tensor_tuple(inp, 0, lbs)
        l_gt = func.split_tensor_tuple(gt, 0, lbs)

        # forward the labeled samples by the student model
        l_s_resulter, l_s_debugger = self.s_model.forward(l_inp)
        if 'pred' not in l_s_resulter.keys() or 'activated_pred' not in l_s_resulter.keys():
            self._pred_err()

        l_s_pred = tool.dict_value(l_s_resulter, 'pred')
        l_s_activated_pred = tool.dict_value(l_s_resulter, 'activated_pred')

        # calculate the supervised task loss on the labeled samples
        task_loss = self.s_criterion.forward(l_s_pred, l_gt, l_inp)
        task_loss = torch.mean(task_loss)
        self.meters.update('task_loss', task_loss.data)

        # -------------------------------------------------
        # For Unlabeled Samples
        # -------------------------------------------------
        if self.args.unlabeled_batch_size > 0:
            u_inp = func.split_tensor_tuple(inp, lbs, self.args.batch_size)

            # forward the original samples by the teacher model
            with torch.no_grad():
                u_t_resulter, u_t_debugger = self.t_model.forward(u_inp)
            if 'pred' not in u_t_resulter.keys() or 'activated_pred' not in u_t_resulter.keys():
                self._pred_err()

            u_t_activated_pred = tool.dict_value(u_t_resulter, 'activated_pred')

            # mix the activated pred from the teacher model as the pseudo gt
            u_t_activated_pred_1 = func.split_tensor_tuple(u_t_activated_pred, 0, ubs // 2)
            u_t_activated_pred_2 = func.split_tensor_tuple(u_t_activated_pred, ubs // 2, ubs)

            mix_u_t_activated_pred = []
            mix_u_t_confidence = []
            for up_1, up_2 in zip(u_t_activated_pred_1, u_t_activated_pred_2):
                mp = mix_u_mask * up_1 + (1 - mix_u_mask) * up_2
                mix_u_t_activated_pred.append(mp.detach())

                # NOTE: here we just follow the official code of CutMix to calculate
                # the confidence, but it is odd that all the samples use the same
                # confidence (mean confidence)
                u_t_confidence = (mp.max(dim=1)[0] > self.args.cons_threshold).float().mean()
                mix_u_t_confidence.append(u_t_confidence.detach())

            mix_u_t_activated_pred = tuple(mix_u_t_activated_pred)

            # forward the mixed samples by the student model
            u_s_resulter, u_s_debugger = self.s_model.forward(mix_u_inp)
            if 'pred' not in u_s_resulter.keys() or 'activated_pred' not in u_s_resulter.keys():
                self._pred_err()

            mix_u_s_activated_pred = tool.dict_value(u_s_resulter, 'activated_pred')

            # calculate the consistency constraint
            cons_loss = 0
            for msap, mtap, confidence in zip(mix_u_s_activated_pred,
                                              mix_u_t_activated_pred,
                                              mix_u_t_confidence):
                cons_loss += torch.mean(self.cons_criterion(msap, mtap)) * confidence

            cons_loss = cons_rampup_scale * self.args.cons_scale * torch.mean(cons_loss)
            self.meters.update('cons_loss', cons_loss.data)
        else:
            cons_loss = 0
            self.meters.update('cons_loss', cons_loss)

        # backward and update the student model
        loss = task_loss + cons_loss
        loss.backward()
        self.s_optimizer.step()

        # update the teacher model by EMA
        self._update_ema_variables(self.s_model, self.t_model, self.args.ema_decay, cur_step)

        # logging
        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                ' student-{3}\t=>\t'
                's-task-loss: {meters[task_loss]:.6f}\t'
                's-cons-loss: {meters[cons_loss]:.6f}\n'.format(
                    epoch + 1, idx, len(data_loader), self.args.task, meters=self.meters))

        # visualization
        if self.args.visualize and idx % self.args.visual_freq == 0:
            self._visualize(
                epoch, idx, True,
                func.split_tensor_tuple(l_inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(l_s_activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(l_gt, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(mix_u_inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(mix_u_s_activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(mix_u_t_activated_pred, 0, 1, reduce_dim=True),
                mix_u_mask[0])

        # update iteration-based lrers
        if not self.args.is_epoch_lrer:
            self.s_lrer.step()

    # update epoch-based lrers
    if self.args.is_epoch_lrer:
        self.s_lrer.step()
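# A minimal sketch of how the CutMix mask 'mix_u_mask' might be generated in
# '_batch_prehandle' (an assumption based on the CutMix formulation; the real
# implementation may sample the box size and position differently):
import torch

def random_cutmix_mask_sketch(batch_size, height, width, prop_range=0.5):
    # one binary mask per sample: 1 inside the pasted box, 0 outside
    masks = torch.zeros(batch_size, 1, height, width)
    for i in range(batch_size):
        mask_h = int(height * prop_range)
        mask_w = int(width * prop_range)
        top = torch.randint(0, height - mask_h + 1, (1,)).item()
        left = torch.randint(0, width - mask_w + 1, (1,)).item()
        masks[i, :, top:top + mask_h, left:left + mask_w] = 1.0
    return masks

# the mixed pseudo gt above is then 'mask * up_1 + (1 - mask) * up_2',
# i.e., the box region comes from the first half of the unlabeled batch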
def _create_dataloader(self):
    """ Create data loaders for the experiment.
    """
    # ---------------------------------------------------------------------
    # create dataloader for training
    # ---------------------------------------------------------------------

    # ignore_unlabeled == False & unlabeled_batch_size != 0
    # means that both labeled and unlabeled data are used
    with_unlabeled_data = not self.args.ignore_unlabeled and self.args.unlabeled_batch_size != 0
    # ignore_unlabeled == True & unlabeled_batch_size == 0
    # means that only the labeled data is used
    without_unlabeled_data = self.args.ignore_unlabeled and self.args.unlabeled_batch_size == 0

    labeled_train_samples, unlabeled_train_samples = 0, 0
    if not self.args.validation:
        # ignore_unlabeled == True & unlabeled_batch_size != 0 -> error
        if self.args.ignore_unlabeled and self.args.unlabeled_batch_size != 0:
            logger.log_err('Arguments conflict => ignore_unlabeled == True requires unlabeled_batch_size == 0\n')
        # ignore_unlabeled == False & unlabeled_batch_size == 0 -> error
        if not self.args.ignore_unlabeled and self.args.unlabeled_batch_size == 0:
            logger.log_err('Arguments conflict => ignore_unlabeled == False requires unlabeled_batch_size != 0\n')

        # calculate the number of trainsets
        trainset_num = 0
        for key, value in self.args.trainset.items():
            trainset_num += len(value)

        # calculate the number of unlabeled sets
        unlabeledset_num = 0
        for key, value in self.args.unlabeledset.items():
            unlabeledset_num += len(value)

        # if only one labeled training set is given and no extra unlabeled set is used
        if trainset_num == 1 and unlabeledset_num == 0:
            trainset = self._load_dataset(
                list(self.args.trainset.keys())[0],
                list(self.args.trainset.values())[0][0])
            labeled_train_samples = len(trainset.idxs)

            # if the 'sublabeled_path' is given
            sublabeled_prefix = None
            if self.args.sublabeled_path is not None and self.args.sublabeled_path != '':
                if not os.path.exists(self.args.sublabeled_path):
                    logger.log_err('Cannot find labeled file: {0}\n'.format(self.args.sublabeled_path))
                else:
                    with open(self.args.sublabeled_path) as f:
                        sublabeled_prefix = [line.strip() for line in f.read().splitlines()]
                    sublabeled_prefix = None if len(sublabeled_prefix) == 0 else sublabeled_prefix

            if sublabeled_prefix is not None:
                # wrap the trainset by 'SplitUnlabeledWrapper'
                trainset = nndata.SplitUnlabeledWrapper(
                    trainset, sublabeled_prefix, ignore_unlabeled=self.args.ignore_unlabeled)
                labeled_train_samples = len(trainset.labeled_idxs)
                unlabeled_train_samples = len(trainset.unlabeled_idxs)

            # if 'sublabeled_prefix' is None but the unlabeled data is required for training
            elif with_unlabeled_data:
                logger.log_err('Try to use the unlabeled samples without any SSL dataset wrapper\n')

        # if more than one labeled training set or any extra unlabeled training set is given
        elif trainset_num > 1 or unlabeledset_num > 0:
            # 'args.sublabeled_path' is disabled
            if self.args.sublabeled_path is not None and self.args.sublabeled_path != '':
                logger.log_err('Multiple training datasets are given.\n'
                               'Inter-split unlabeled set is not allowed.\n'
                               'Please remove the argument \'sublabeled_path\' in the script\n')

            # load all training sets
            labeled_sets = []
            for set_name, set_dirs in self.args.trainset.items():
                for set_dir in set_dirs:
                    labeled_sets.append(self._load_dataset(set_name, set_dir))

            # load all extra unlabeled sets
            unlabeled_sets = []
            # if any extra unlabeled set is given
            if unlabeledset_num > 0:
                for set_name, set_dirs in self.args.unlabeledset.items():
                    for set_dir in set_dirs:
                        unlabeled_sets.append(self._load_dataset(set_name, set_dir))
            # if unlabeledset_num == 0 but the unlabeled data is required for training
            elif with_unlabeled_data:
                logger.log_err('Try to use the unlabeled samples without any SSL dataset wrapper\n'
                               'Please add the argument \'unlabeledset\' in the script\n')

            # wrap both 'labeled_sets' and 'unlabeled_sets' by 'JointDatasetsWrapper'
            trainset = nndata.JointDatasetsWrapper(
                labeled_sets, unlabeled_sets, ignore_unlabeled=self.args.ignore_unlabeled)
            labeled_train_samples = len(trainset.labeled_idxs)
            unlabeled_train_samples = len(trainset.unlabeled_idxs)

        # if only labeled data is used
        if without_unlabeled_data:
            self.train_loader = torch.utils.data.DataLoader(
                trainset, batch_size=self.args.batch_size, shuffle=True,
                num_workers=self.args.num_workers, pin_memory=True, drop_last=True)
        # if both labeled and unlabeled data are used
        elif with_unlabeled_data:
            train_sampler = nndata.TwoStreamBatchSampler(
                trainset.labeled_idxs, trainset.unlabeled_idxs,
                self.args.labeled_batch_size, self.args.unlabeled_batch_size)
            self.train_loader = torch.utils.data.DataLoader(
                trainset, batch_sampler=train_sampler,
                num_workers=self.args.num_workers, pin_memory=True)

    # ---------------------------------------------------------------------
    # create dataloader for validation
    # ---------------------------------------------------------------------

    # calculate the number of valsets
    valset_num = 0
    for key, value in self.args.valset.items():
        valset_num += len(value)

    val_samples = 0
    # if only one validation set is given
    if valset_num == 1:
        valset = self._load_dataset(
            list(self.args.valset.keys())[0],
            list(self.args.valset.values())[0][0], is_train=False)
        val_samples = len(valset.idxs)
    # if more than one validation set is given
    elif valset_num > 1:
        valsets = []
        for set_name, set_dirs in self.args.valset.items():
            for set_dir in set_dirs:
                valsets.append(self._load_dataset(set_name, set_dir, is_train=False))
        valset = nndata.JointDatasetsWrapper(valsets, [], ignore_unlabeled=True)
        val_samples = len(valset.labeled_idxs)

    # NOTE: batch size is set to 1 during the validation
    self.val_loader = torch.utils.data.DataLoader(
        valset, batch_size=1, shuffle=False,
        num_workers=self.args.num_workers, pin_memory=True)

    # check the data loaders
    if self.train_loader is None and not self.args.validation:
        logger.log_err('Train data loader is required if the validation mode is disabled\n')
    elif self.val_loader is None and self.args.validation:
        logger.log_err('Validation data loader is required if the validation mode is enabled\n')
    elif self.val_loader is None:
        logger.log_warn('No validation data loader, there will be no validation during the training\n')

    # set 'iters_per_epoch', which is required by ITER_LRERS
    self.args.iters_per_epoch = len(self.train_loader) if self.train_loader is not None else -1

    logger.log_info(
        'Dataset:\n'
        ' Trainset\t=>\tlabeled samples = {0}\t\tunlabeled samples = {1}\n'
        ' Valset\t=>\tsamples = {2}\n'.format(
            labeled_train_samples, unlabeled_train_samples, val_samples))
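# A minimal sketch of a two-stream batch sampler like 'nndata.TwoStreamBatchSampler'
# (an assumption based on its usage above; the real class also handles epoch
# sizing and cycling of the secondary stream):
import random

class TwoStreamBatchSamplerSketch(object):
    def __init__(self, labeled_idxs, unlabeled_idxs, labeled_batch_size, unlabeled_batch_size):
        self.labeled_idxs = list(labeled_idxs)
        self.unlabeled_idxs = list(unlabeled_idxs)
        self.lbs = labeled_batch_size
        self.ubs = unlabeled_batch_size

    def __iter__(self):
        random.shuffle(self.labeled_idxs)
        random.shuffle(self.unlabeled_idxs)
        for i in range(len(self)):
            # each batch concatenates a labeled chunk and an unlabeled chunk,
            # so 'batch[:lbs]' is labeled and 'batch[lbs:]' is unlabeled
            l_chunk = self.labeled_idxs[i * self.lbs:(i + 1) * self.lbs]
            u_chunk = self.unlabeled_idxs[i * self.ubs:(i + 1) * self.ubs]
            yield l_chunk + u_chunk

    def __len__(self):
        return min(len(self.labeled_idxs) // self.lbs,
                   len(self.unlabeled_idxs) // self.ubs)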
def _run(self):
    """ Main pipeline of the experiment.
        Please override this function if you want a special pipeline.
    """
    start_epoch = 0
    if self.args.resume is not None and self.args.resume != '':
        logger.log_info('Load checkpoint from: {0}'.format(self.args.resume))
        start_epoch = self.ssl_algorithm.load_checkpoint()

    if self.args.validation:
        if self.val_loader is None:
            logger.log_err('No data loader for validation.\n'
                           'Please set a valid \'valset\' in the script.\n')
        logger.log_info(['=' * 78, '\nStart to validate model\n', '=' * 78])
        with torch.no_grad():
            self.ssl_algorithm.validate(self.val_loader, start_epoch - 1)
        return

    for epoch in range(start_epoch, self.args.epochs):
        timer = time.time()

        logger.log_info(['=' * 78, '\nStart to train epoch-{0}\n'.format(epoch), '=' * 78])
        self.ssl_algorithm.train(self.train_loader, epoch)

        if epoch % self.args.val_freq == 0 and self.val_loader is not None:
            logger.log_info(['=' * 78, '\nStart to validate epoch-{0}\n'.format(epoch), '=' * 78])
            with torch.no_grad():
                self.ssl_algorithm.validate(self.val_loader, epoch)

        if (epoch + 1) % self.args.checkpoint_freq == 0:
            self.ssl_algorithm.save_checkpoint(epoch + 1)
            logger.log_info('Save checkpoint for epoch {0}'.format(epoch))

        logger.log_info('Finish epoch in {0} seconds\n'.format(time.time() - timer))

    logger.log_info('Finish experiment {0}\n'.format(self.args.exp_id))
def _validate(self, data_loader, epoch):
    self.meters.reset()

    self.model.eval()
    self.d_model.eval()

    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        inp, gt = self._batch_prehandle(inp, gt)
        if len(gt) > 1 and idx == 0:
            self._inp_warn()

        resulter, debugger = self.model.forward(inp)
        if 'pred' not in resulter.keys() or 'activated_pred' not in resulter.keys():
            self._pred_err()

        pred = tool.dict_value(resulter, 'pred')
        activated_pred = tool.dict_value(resulter, 'activated_pred')

        task_loss = self.criterion.forward(pred, gt, inp)
        task_loss = torch.mean(task_loss)
        self.meters.update('task_loss', task_loss.data)

        d_resulter, d_debugger = self.d_model.forward(activated_pred[0])
        unhandled_fake_confidence_map = tool.dict_value(d_resulter, 'confidence')
        fake_confidence_map, fake_confidence_gt = \
            self.task_func.ssladv_preprocess_fcd_criterion(unhandled_fake_confidence_map, gt[0], False)

        fake_d_loss = self.d_criterion.forward(fake_confidence_map, fake_confidence_gt)
        fake_d_loss = self.args.discriminator_scale * torch.mean(fake_d_loss)
        self.meters.update('fake_d_loss', fake_d_loss.data)

        real_gt = self.task_func.ssladv_convert_task_gt_to_fcd_input(gt[0])
        d_resulter, d_debugger = self.d_model.forward(real_gt)
        unhandled_real_confidence_map = tool.dict_value(d_resulter, 'confidence')
        real_confidence_map, real_confidence_gt = \
            self.task_func.ssladv_preprocess_fcd_criterion(unhandled_real_confidence_map, gt[0], True)

        real_d_loss = self.d_criterion.forward(real_confidence_map, real_confidence_gt)
        real_d_loss = self.args.discriminator_scale * torch.mean(real_d_loss)
        self.meters.update('real_d_loss', real_d_loss.data)

        self.task_func.metrics(activated_pred, gt, inp, self.meters, id_str='task')

        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                ' task-{3}\t=>\t'
                'task-loss: {meters[task_loss]:.6f}\t'
                ' fc-discriminator\t=>\t'
                'fake-d-loss: {meters[fake_d_loss]:.6f}\t'
                'real-d-loss: {meters[real_d_loss]:.6f}\n'.format(
                    epoch + 1, idx, len(data_loader), self.args.task, meters=self.meters))

        if self.args.visualize and idx % self.args.visual_freq == 0:
            self._visualize(
                epoch, idx, False,
                func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(gt, 0, 1, reduce_dim=True),
                torch.sigmoid(unhandled_fake_confidence_map[0]))

    # metrics
    metrics_info = {'task': ''}
    for key in sorted(list(self.meters.keys())):
        if self.task_func.METRIC_STR in key:
            for id_str in metrics_info.keys():
                if key.startswith(id_str):
                    metrics_info[id_str] += '{0}: {1:.6f}\t'.format(key, self.meters[key])

    logger.log_info('Validation metrics:\n task-metrics\t=>\t{0}\n'.format(
        metrics_info['task'].replace('_', '-')))
def _train(self, data_loader, epoch):
    self.meters.reset()
    lbs = self.args.labeled_batch_size

    self.model.train()
    self.d_model.train()

    # both 'inp' and 'gt' are tuples
    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        inp, gt = self._batch_prehandle(inp, gt)
        if len(gt) > 1 and idx == 0:
            self._inp_warn()

        # -----------------------------------------------------------------------------
        # step-1: train the task model
        # -----------------------------------------------------------------------------
        self.optimizer.zero_grad()

        # forward the task model
        resulter, debugger = self.model.forward(inp)
        if 'pred' not in resulter.keys() or 'activated_pred' not in resulter.keys():
            self._pred_err()

        pred = tool.dict_value(resulter, 'pred')
        activated_pred = tool.dict_value(resulter, 'activated_pred')

        # forward the FC discriminator
        # 'confidence_map' is a tensor
        d_resulter, d_debugger = self.d_model.forward(activated_pred[0])
        confidence_map = tool.dict_value(d_resulter, 'confidence')

        # calculate the supervised task constraint on the labeled data
        l_pred = func.split_tensor_tuple(pred, 0, lbs)
        l_gt = func.split_tensor_tuple(gt, 0, lbs)
        l_inp = func.split_tensor_tuple(inp, 0, lbs)

        # 'task_loss' is a 1-dim tensor with n elements, where n == batch_size
        task_loss = self.criterion.forward(l_pred, l_gt, l_inp)
        task_loss = torch.mean(task_loss)
        self.meters.update('task_loss', task_loss.data)

        # calculate the adversarial constraint
        # calculate the adversarial constraint for the labeled data
        if self.args.adv_for_labeled:
            l_confidence_map = confidence_map[:lbs, ...]

            # preprocess prediction and ground truth for the adversarial constraint
            l_adv_confidence_map, l_adv_confidence_gt = \
                self.task_func.ssladv_preprocess_fcd_criterion(l_confidence_map, l_gt[0], True)

            l_adv_loss = self.d_criterion(l_adv_confidence_map, l_adv_confidence_gt)
            labeled_adv_loss = self.args.labeled_adv_scale * torch.mean(l_adv_loss)
            self.meters.update('labeled_adv_loss', labeled_adv_loss.data)
        else:
            labeled_adv_loss = 0
            self.meters.update('labeled_adv_loss', labeled_adv_loss)

        # calculate the adversarial constraint for the unlabeled data
        if self.args.unlabeled_batch_size > 0:
            u_confidence_map = confidence_map[lbs:self.args.batch_size, ...]

            # preprocess prediction and ground truth for the adversarial constraint
            u_adv_confidence_map, u_adv_confidence_gt = \
                self.task_func.ssladv_preprocess_fcd_criterion(u_confidence_map, None, True)

            u_adv_loss = self.d_criterion(u_adv_confidence_map, u_adv_confidence_gt)
            unlabeled_adv_loss = self.args.unlabeled_adv_scale * torch.mean(u_adv_loss)
            self.meters.update('unlabeled_adv_loss', unlabeled_adv_loss.data)
        else:
            unlabeled_adv_loss = 0
            self.meters.update('unlabeled_adv_loss', unlabeled_adv_loss)

        adv_loss = labeled_adv_loss + unlabeled_adv_loss

        # backward and update the task model
        loss = task_loss + adv_loss
        loss.backward()
        self.optimizer.step()

        # -----------------------------------------------------------------------------
        # step-2: train the FC discriminator
        # -----------------------------------------------------------------------------
        self.d_optimizer.zero_grad()

        # forward the task prediction (fake)
        if self.args.unlabeled_for_discriminator:
            fake_pred = activated_pred[0].detach()
        else:
            fake_pred = activated_pred[0][:lbs, ...].detach()

        d_resulter, d_debugger = self.d_model.forward(fake_pred)
        fake_confidence_map = tool.dict_value(d_resulter, 'confidence')

        l_fake_confidence_map = fake_confidence_map[:lbs, ...]
        l_fake_confidence_map, l_fake_confidence_gt = \
            self.task_func.ssladv_preprocess_fcd_criterion(l_fake_confidence_map, l_gt[0], False)

        if self.args.unlabeled_for_discriminator and self.args.unlabeled_batch_size != 0:
            u_fake_confidence_map = fake_confidence_map[lbs:self.args.batch_size, ...]
            u_fake_confidence_map, u_fake_confidence_gt = \
                self.task_func.ssladv_preprocess_fcd_criterion(u_fake_confidence_map, None, False)

            fake_confidence_map = torch.cat((l_fake_confidence_map, u_fake_confidence_map), dim=0)
            fake_confidence_gt = torch.cat((l_fake_confidence_gt, u_fake_confidence_gt), dim=0)
        else:
            fake_confidence_map, fake_confidence_gt = l_fake_confidence_map, l_fake_confidence_gt

        fake_d_loss = self.d_criterion.forward(fake_confidence_map, fake_confidence_gt)
        fake_d_loss = self.args.discriminator_scale * torch.mean(fake_d_loss)
        self.meters.update('fake_d_loss', fake_d_loss.data)

        # forward the ground truth (real)
        # convert the format of the ground truth
        real_gt = self.task_func.ssladv_convert_task_gt_to_fcd_input(l_gt[0])
        d_resulter, d_debugger = self.d_model.forward(real_gt)
        real_confidence_map = tool.dict_value(d_resulter, 'confidence')
        real_confidence_map, real_confidence_gt = \
            self.task_func.ssladv_preprocess_fcd_criterion(real_confidence_map, l_gt[0], True)

        real_d_loss = self.d_criterion(real_confidence_map, real_confidence_gt)
        real_d_loss = self.args.discriminator_scale * torch.mean(real_d_loss)
        self.meters.update('real_d_loss', real_d_loss.data)

        # backward and update the FC discriminator
        d_loss = (fake_d_loss + real_d_loss) / 2
        d_loss.backward()
        self.d_optimizer.step()

        # logging
        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                ' task-{3}\t=>\t'
                'task-loss: {meters[task_loss]:.6f}\t'
                'labeled-adv-loss: {meters[labeled_adv_loss]:.6f}\t'
                'unlabeled-adv-loss: {meters[unlabeled_adv_loss]:.6f}\n'
                ' fc-discriminator\t=>\t'
                'fake-d-loss: {meters[fake_d_loss]:.6f}\t'
                'real-d-loss: {meters[real_d_loss]:.6f}\n'.format(
                    epoch + 1, idx, len(data_loader), self.args.task, meters=self.meters))

        # visualization
        if self.args.visualize and idx % self.args.visual_freq == 0:
            u_inp_sample, u_pred_sample, u_cmap_sample = None, None, None
            if self.args.unlabeled_batch_size > 0:
                u_inp_sample = func.split_tensor_tuple(inp, lbs, lbs + 1, reduce_dim=True)
                u_pred_sample = func.split_tensor_tuple(activated_pred, lbs, lbs + 1, reduce_dim=True)
                u_cmap_sample = torch.sigmoid(fake_confidence_map[lbs])

            self._visualize(
                epoch, idx, True,
                func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(gt, 0, 1, reduce_dim=True),
                torch.sigmoid(confidence_map[0]),
                u_inp_sample, u_pred_sample, u_cmap_sample)

        # the FC discriminator uses polynomiallr [ITER_LRERS]
        self.d_lrer.step()

        # update iteration-based lrers
        if not self.args.is_epoch_lrer:
            self.lrer.step()

    # update epoch-based lrers
    if self.args.is_epoch_lrer:
        self.lrer.step()
def _train(self, data_loader, epoch):
    self.meters.reset()
    original_lbs = int(self.args.labeled_batch_size / 2)
    original_bs = int(self.args.batch_size / 2)

    self.model.train()

    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        # the rotated samples are generated in the 'self._batch_prehandle' function
        # (see the sketch after this function); both 'inp' and 'gt' are tuples, and
        # the last element in the tuple 'gt' is the ground truth of the rotation angle
        inp, gt = self._batch_prehandle(inp, gt, True)
        if len(gt) - 1 > 1 and idx == 0:
            self._inp_warn()

        self.optimizer.zero_grad()

        # forward the model
        resulter, debugger = self.model.forward(inp)
        pred = tool.dict_value(resulter, 'pred')
        activated_pred = tool.dict_value(resulter, 'activated_pred')
        pred_rotation = tool.dict_value(resulter, 'rotation')

        # calculate the supervised task constraint on the un-rotated labeled data
        l_pred = func.split_tensor_tuple(pred, 0, original_lbs)
        l_gt = func.split_tensor_tuple(gt, 0, original_lbs)
        l_inp = func.split_tensor_tuple(inp, 0, original_lbs)

        unrotated_task_loss = self.criterion.forward(l_pred, l_gt[:-1], l_inp)
        unrotated_task_loss = torch.mean(unrotated_task_loss)
        self.meters.update('unrotated_task_loss', unrotated_task_loss.data)

        # calculate the supervised task constraint on the rotated labeled data
        l_rotated_pred = func.split_tensor_tuple(pred, original_bs, original_bs + original_lbs)
        l_rotated_gt = func.split_tensor_tuple(gt, original_bs, original_bs + original_lbs)
        l_rotated_inp = func.split_tensor_tuple(inp, original_bs, original_bs + original_lbs)

        rotated_task_loss = self.criterion.forward(l_rotated_pred, l_rotated_gt[:-1], l_rotated_inp)
        rotated_task_loss = self.args.rotated_sup_scale * torch.mean(rotated_task_loss)
        self.meters.update('rotated_task_loss', rotated_task_loss.data)

        task_loss = unrotated_task_loss + rotated_task_loss

        # calculate the self-supervised rotation constraint
        rotation_loss = self.rotation_criterion.forward(pred_rotation, gt[-1])
        rotation_loss = self.args.rotation_scale * torch.mean(rotation_loss)
        self.meters.update('rotation_loss', rotation_loss.data)

        # backward and update the model
        loss = task_loss + rotation_loss
        loss.backward()
        self.optimizer.step()

        # calculate the accuracy of the rotation classifier
        _, angle_idx = pred_rotation.topk(1, 1, True, True)
        angle_idx = angle_idx.t()
        rotation_acc = angle_idx.eq(gt[-1].view(1, -1).expand_as(angle_idx))
        rotation_acc = rotation_acc.view(-1).float().sum(0, keepdim=True).mul_(100.0 / self.args.batch_size)
        self.meters.update('rotation_acc', rotation_acc.data[0])

        # logging
        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                ' task-{3}\t=>\t'
                'unrotated-task-loss: {meters[unrotated_task_loss]:.6f}\t'
                'rotated-task-loss: {meters[rotated_task_loss]:.6f}\n'
                ' rotation-{3}\t=>\t'
                'rotation-loss: {meters[rotation_loss]:.6f}\t'
                'rotation-acc: {meters[rotation_acc]:.6f}\n'.format(
                    epoch, idx, len(data_loader), self.args.task, meters=self.meters))

        # visualization
        if self.args.visualize and idx % self.args.visual_freq == 0:
            self._visualize(
                epoch, idx, True,
                func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(gt[:-1], 0, 1, reduce_dim=True))

        # update iteration-based lrers
        if not self.args.is_epoch_lrer:
            self.lrer.step()

    # update epoch-based lrers
    if self.args.is_epoch_lrer:
        self.lrer.step()
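# A minimal sketch of how '_batch_prehandle' might build the rotated half of
# the S4L batch (an assumption for illustration only: the real prehandler works
# on tuples of tensors and may also rotate the first half; here the original
# half is assigned angle class 0):
import torch

def append_rotated_samples_sketch(inp):
    # rotate every sample by a random multiple of 90 degrees
    angles = torch.randint(0, 4, (inp.shape[0],))
    rotated = torch.stack([torch.rot90(x, int(k), dims=(-2, -1))
                           for x, k in zip(inp, angles)])
    # the original batch is followed by its rotated copy, and the angle
    # indices are returned as the rotation ground truth (class 0 = unrotated)
    rotation_gt = torch.cat([torch.zeros(inp.shape[0], dtype=torch.long), angles])
    return torch.cat([inp, rotated], dim=0), rotation_gt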
def _train(self, data_loader, epoch):
    self.meters.reset()
    lbs = self.args.labeled_batch_size

    self.s_model.train()
    self.t_model.train()

    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        # 's_inp', 't_inp' and 'gt' are tuples
        s_inp, t_inp, gt = self._batch_prehandle(inp, gt, True)
        if len(gt) > 1 and idx == 0:
            self._inp_warn()

        # calculate the ramp-up coefficient of the consistency constraint
        cur_step = len(data_loader) * epoch + idx
        total_steps = len(data_loader) * self.args.cons_rampup_epochs
        cons_rampup_scale = func.sigmoid_rampup(cur_step, total_steps)

        self.s_optimizer.zero_grad()

        # forward the student model
        s_resulter, s_debugger = self.s_model.forward(s_inp)
        if 'pred' not in s_resulter.keys() or 'activated_pred' not in s_resulter.keys():
            self._pred_err()

        s_pred = tool.dict_value(s_resulter, 'pred')
        s_activated_pred = tool.dict_value(s_resulter, 'activated_pred')

        # calculate the supervised task constraint on the labeled data
        l_s_pred = func.split_tensor_tuple(s_pred, 0, lbs)
        l_gt = func.split_tensor_tuple(gt, 0, lbs)
        l_s_inp = func.split_tensor_tuple(s_inp, 0, lbs)

        # 'task_loss' is a 1-dim tensor with n elements, where n == batch_size
        s_task_loss = self.s_criterion.forward(l_s_pred, l_gt, l_s_inp)
        s_task_loss = torch.mean(s_task_loss)
        self.meters.update('s_task_loss', s_task_loss.data)

        # forward the teacher model
        with torch.no_grad():
            t_resulter, t_debugger = self.t_model.forward(t_inp)
            if 'pred' not in t_resulter.keys():
                self._pred_err()

            t_pred = tool.dict_value(t_resulter, 'pred')
            t_activated_pred = tool.dict_value(t_resulter, 'activated_pred')

            # calculate 't_task_loss' for recording
            l_t_pred = func.split_tensor_tuple(t_pred, 0, lbs)
            l_t_inp = func.split_tensor_tuple(t_inp, 0, lbs)
            t_task_loss = self.s_criterion.forward(l_t_pred, l_gt, l_t_inp)
            t_task_loss = torch.mean(t_task_loss)
            self.meters.update('t_task_loss', t_task_loss.data)

        # calculate the consistency constraint from the teacher model to the student model
        t_pseudo_gt = t_pred[0].detach()

        if self.args.cons_for_labeled:
            cons_loss = self.cons_criterion(s_pred[0], t_pseudo_gt)
        elif self.args.unlabeled_batch_size > 0:
            cons_loss = self.cons_criterion(s_pred[0][lbs:, ...], t_pseudo_gt[lbs:, ...])
        else:
            cons_loss = self.zero_tensor
        cons_loss = cons_rampup_scale * self.args.cons_scale * torch.mean(cons_loss)
        self.meters.update('cons_loss', cons_loss.data)

        # backward and update the student model
        loss = s_task_loss + cons_loss
        loss.backward()
        self.s_optimizer.step()

        # update the teacher model by EMA
        self._update_ema_variables(self.s_model, self.t_model, self.args.ema_decay, cur_step)

        # logging
        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                ' student-{3}\t=>\t'
                's-task-loss: {meters[s_task_loss]:.6f}\t'
                's-cons-loss: {meters[cons_loss]:.6f}\n'
                ' teacher-{3}\t=>\t'
                't-task-loss: {meters[t_task_loss]:.6f}\n'.format(
                    epoch + 1, idx, len(data_loader), self.args.task, meters=self.meters))

        # visualization
        if self.args.visualize and idx % self.args.visual_freq == 0:
            self._visualize(
                epoch, idx, True,
                func.split_tensor_tuple(s_inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(s_activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(t_inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(t_activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

        # update iteration-based lrers
        if not self.args.is_epoch_lrer:
            self.s_lrer.step()

    # update epoch-based lrers
    if self.args.is_epoch_lrer:
        self.s_lrer.step()
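# A minimal sketch of the EMA update performed by '_update_ema_variables'
# above (an assumption following the standard Mean-Teacher formulation; the
# real method may handle buffers or clamp the decay differently):
def update_ema_variables_sketch(s_model, t_model, ema_decay, cur_step):
    # ramp the effective decay up from 0 towards 'ema_decay' in early steps
    decay = min(1.0 - 1.0 / (cur_step + 1), ema_decay)
    for t_param, s_param in zip(t_model.parameters(), s_model.parameters()):
        # t = decay * t + (1 - decay) * s, applied in-place on the teacher
        t_param.data.mul_(decay).add_(s_param.data, alpha=1.0 - decay)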
def _train(self, data_loader, epoch):
    self.meters.reset()
    lbs = self.args.labeled_batch_size

    self.l_model.train()
    self.r_model.train()
    self.fd_model.train()

    # both 'inp' and 'gt' are tuples
    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        (l_inp, l_gt), (r_inp, r_gt) = self._batch_prehandle(inp, gt)
        if len(l_gt) == len(r_gt) > 1 and idx == 0:
            self._inp_warn()

        # calculate the ramp-up coefficient of the dynamic consistency constraint
        cur_steps = len(data_loader) * epoch + idx
        total_steps = len(data_loader) * self.args.dc_rampup_epochs
        dc_rampup_scale = func.sigmoid_rampup(cur_steps, total_steps)

        # -----------------------------------------------------------------------------
        # step-0: pre-forwarding to save GPU memory
        #   - forward the task models and the flaw detector
        #   - generate pseudo ground truth for the unlabeled data if the dynamic
        #     consistency constraint is enabled
        # -----------------------------------------------------------------------------
        with torch.no_grad():
            l_resulter, l_debugger = self.l_model.forward(l_inp)
            l_activated_pred = tool.dict_value(l_resulter, 'activated_pred')

            r_resulter, r_debugger = self.r_model.forward(r_inp)
            r_activated_pred = tool.dict_value(r_resulter, 'activated_pred')

            # 'l_flawmap' and 'r_flawmap' will be used in step-2
            fd_resulter, fd_debugger = self.fd_model.forward(l_inp, l_activated_pred[0])
            l_flawmap = tool.dict_value(fd_resulter, 'flawmap')

            fd_resulter, fd_debugger = self.fd_model.forward(r_inp, r_activated_pred[0])
            r_flawmap = tool.dict_value(fd_resulter, 'flawmap')

        l_dc_gt, r_dc_gt = None, None
        l_fc_mask, r_fc_mask = None, None

        # generate the pseudo ground truth for the dynamic consistency constraint
        if self.args.ssl_mode in [MODE_GCT, MODE_DC]:
            with torch.no_grad():
                l_handled_flawmap = self.flawmap_handler.forward(l_flawmap)
                r_handled_flawmap = self.flawmap_handler.forward(r_flawmap)
                l_dc_gt, r_dc_gt, l_fc_mask, r_fc_mask = self.dcgt_generator.forward(
                    l_activated_pred[0].detach(), r_activated_pred[0].detach(),
                    l_handled_flawmap, r_handled_flawmap)

        # -----------------------------------------------------------------------------
        # step-1: train the task models
        # -----------------------------------------------------------------------------
        for param in self.fd_model.parameters():
            param.requires_grad = False

        # train the 'l' task model
        l_loss = self._task_model_iter(epoch, idx, True, 'l', lbs, l_inp, l_gt,
                                       l_dc_gt, l_fc_mask, dc_rampup_scale)
        self.l_optimizer.zero_grad()
        l_loss.backward()
        self.l_optimizer.step()

        # train the 'r' task model
        r_loss = self._task_model_iter(epoch, idx, True, 'r', lbs, r_inp, r_gt,
                                       r_dc_gt, r_fc_mask, dc_rampup_scale)
        self.r_optimizer.zero_grad()
        r_loss.backward()
        self.r_optimizer.step()

        # -----------------------------------------------------------------------------
        # step-2: train the flaw detector
        # -----------------------------------------------------------------------------
        for param in self.fd_model.parameters():
            param.requires_grad = True

        # generate the ground truth for the flaw detector (on labeled data only)
        with torch.no_grad():
            l_flawmap_gt = self.fdgt_generator.forward(
                l_activated_pred[0][:lbs, ...].detach(),
                self.task_func.sslgct_prepare_task_gt_for_fdgt(l_gt[0][:lbs, ...]))
            r_flawmap_gt = self.fdgt_generator.forward(
                r_activated_pred[0][:lbs, ...].detach(),
                self.task_func.sslgct_prepare_task_gt_for_fdgt(r_gt[0][:lbs, ...]))

        l_fd_loss = self.fd_criterion.forward(l_flawmap[:lbs, ...], l_flawmap_gt)
        l_fd_loss = self.args.fd_scale * torch.mean(l_fd_loss)
        self.meters.update('l_fd_loss', l_fd_loss.data)

        r_fd_loss = self.fd_criterion.forward(r_flawmap[:lbs, ...], r_flawmap_gt)
        r_fd_loss = self.args.fd_scale * torch.mean(r_fd_loss)
        self.meters.update('r_fd_loss', r_fd_loss.data)

        fd_loss = (l_fd_loss + r_fd_loss) / 2
        self.fd_optimizer.zero_grad()
        fd_loss.backward()
        self.fd_optimizer.step()

        # logging
        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                ' l-{3}\t=>\t'
                'l-task-loss: {meters[l_task_loss]:.6f}\t'
                'l-dc-loss: {meters[l_dc_loss]:.6f}\t'
                'l-fc-loss: {meters[l_fc_loss]:.6f}\n'
                ' r-{3}\t=>\t'
                'r-task-loss: {meters[r_task_loss]:.6f}\t'
                'r-dc-loss: {meters[r_dc_loss]:.6f}\t'
                'r-fc-loss: {meters[r_fc_loss]:.6f}\n'
                ' fd\t=>\t'
                'l-fd-loss: {meters[l_fd_loss]:.6f}\t'
                'r-fd-loss: {meters[r_fd_loss]:.6f}\n'.format(
                    epoch, idx, len(data_loader), self.args.task, meters=self.meters))

        # the flaw detector uses polynomiallr [ITER_LRERS]
        self.fd_lrer.step()

        # update iteration-based lrers
        if not self.args.is_epoch_lrer:
            self.l_lrer.step()
            self.r_lrer.step()

    # update epoch-based lrers
    if self.args.is_epoch_lrer:
        self.l_lrer.step()
        self.r_lrer.step()
def _validate(self, data_loader, epoch):
    self.meters.reset()
    lbs = self.args.labeled_batch_size

    self.l_model.eval()
    self.r_model.eval()
    self.fd_model.eval()

    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        (l_inp, l_gt), (r_inp, r_gt) = self._batch_prehandle(inp, gt)
        if len(l_gt) == len(r_gt) > 1 and idx == 0:
            self._inp_warn()

        l_dc_gt, r_dc_gt = None, None
        l_fc_mask, r_fc_mask = None, None

        if self.args.ssl_mode in [MODE_GCT, MODE_DC]:
            l_resulter, l_debugger = self.l_model.forward(l_inp)
            l_activated_pred = tool.dict_value(l_resulter, 'activated_pred')

            r_resulter, r_debugger = self.r_model.forward(r_inp)
            r_activated_pred = tool.dict_value(r_resulter, 'activated_pred')

            fd_resulter, fd_debugger = self.fd_model.forward(l_inp, l_activated_pred[0])
            l_flawmap = tool.dict_value(fd_resulter, 'flawmap')

            fd_resulter, fd_debugger = self.fd_model.forward(r_inp, r_activated_pred[0])
            r_flawmap = tool.dict_value(fd_resulter, 'flawmap')

            l_handled_flawmap = self.flawmap_handler.forward(l_flawmap)
            r_handled_flawmap = self.flawmap_handler.forward(r_flawmap)
            l_dc_gt, r_dc_gt, l_fc_mask, r_fc_mask = self.dcgt_generator.forward(
                l_activated_pred[0].detach(), r_activated_pred[0].detach(),
                l_handled_flawmap, r_handled_flawmap)

        l_loss = self._task_model_iter(epoch, idx, False, 'l', lbs, l_inp, l_gt, l_dc_gt, l_fc_mask, 1)
        r_loss = self._task_model_iter(epoch, idx, False, 'r', lbs, r_inp, r_gt, r_dc_gt, r_fc_mask, 1)

        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                ' l-{3}\t=>\t'
                'l-task-loss: {meters[l_task_loss]:.6f}\t'
                'l-dc-loss: {meters[l_dc_loss]:.6f}\t'
                'l-fc-loss: {meters[l_fc_loss]:.6f}\n'
                ' r-{3}\t=>\t'
                'r-task-loss: {meters[r_task_loss]:.6f}\t'
                'r-dc-loss: {meters[r_dc_loss]:.6f}\t'
                'r-fc-loss: {meters[r_fc_loss]:.6f}\n'
                ' fd\t=>\t'
                'l-fd-loss: {meters[l_fd_loss]:.6f}\t'
                'r-fd-loss: {meters[r_fd_loss]:.6f}\n'.format(
                    epoch, idx, len(data_loader), self.args.task, meters=self.meters))

    # metrics
    metrics_info = {'l': '', 'r': ''}
    for key in sorted(list(self.meters.keys())):
        if self.task_func.METRIC_STR in key:
            for id_str in metrics_info.keys():
                if key.startswith(id_str):
                    metrics_info[id_str] += '{0}: {1:.6f}\t'.format(key, self.meters[key])

    logger.log_info(
        'Validation metrics:\n l-metrics\t=>\t{0}\n r-metrics\t=>\t{1}\n'.format(
            metrics_info['l'].replace('_', '-'), metrics_info['r'].replace('_', '-')))
def _preprocess_arguments(self):
    """ Preprocess the arguments in the script.
    """
    # create the output folder to store the results
    self.args.out_path = '{root}/{exp_id}/{date:%Y-%m-%d_%H:%M:%S}/'.format(
        root=self.args.out_path, exp_id=self.args.exp_id, date=datetime.now())
    if not os.path.exists(self.args.out_path):
        os.makedirs(self.args.out_path)

    # prepare logger
    exp_op = 'val' if self.args.validation else 'train'
    logger.log_mode(self.args.debug)
    logger.log_file(os.path.join(self.args.out_path, '{0}.log'.format(exp_op)), self.args.debug)
    logger.log_info('Result folder: \n {0} \n'.format(self.args.out_path))

    # print experimental args
    cmd.print_args()

    # set task name
    self.args.task = self.NAME

    # check the task-specific components dicts required by the SSL algorithm
    if not len(self.args.models) == len(self.args.optimizers) == len(self.args.lrers) == len(self.args.criterions):
        logger.log_err(
            'Condition:\n'
            '\tlen(self.args.models) == len(self.args.optimizers) == len(self.args.lrers) == len(self.args.criterions)\n'
            'is not satisfied in the script\n')

    for (model, criterion, optimizer, lrer) in \
            zip(self.args.models.values(), self.args.criterions.values(),
                self.args.optimizers.values(), self.args.lrers.values()):
        if model not in self.model.__dict__:
            logger.log_err('Unsupported model: {0} for task: {1}\n'
                           'Please add the export function in the task\'s \'model.py\'\n'.format(model, self.args.task))
        elif criterion not in self.criterion.__dict__:
            logger.log_err('Unsupported criterion: {0} for task: {1}\n'
                           'Please add the export function in the task\'s \'criterion.py\'\n'.format(criterion, self.args.task))
        elif optimizer not in nnoptimizer.__dict__:
            logger.log_err('Unsupported optimizer: {0}\n'
                           'Please implement the optimizer wrapper in \'pixelssl/nn/optimizer.py\'\n'.format(optimizer))
        elif lrer not in nnlrer.__dict__:
            logger.log_err('Unsupported learning rate scheduler: {0}\n'
                           'Please implement the lr scheduler wrapper in \'pixelssl/nn/lrer.py\'\n'.format(lrer))

    # check the types of lrers
    for lrer in self.args.lrers.values():
        if lrer in nnlrer.EPOCH_LRERS:
            is_epoch_lrer = True
        elif lrer in nnlrer.ITER_LRERS:
            is_epoch_lrer = False
        else:
            logger.log_err(
                'Unknown learning rate scheduler ({0}) type\n'
                'Please add it into either EPOCH_LRERS or ITER_LRERS in \'pixelssl/nn/lrer.py\'\n'
                'PixelSSL supports: \n'
                ' EPOCH_LRERS\t=>\t{1}\n ITER_LRERS\t=>\t{2}\n'.format(
                    lrer, nnlrer.EPOCH_LRERS, nnlrer.ITER_LRERS))

        if self.args.is_epoch_lrer is None:
            self.args.is_epoch_lrer = is_epoch_lrer
        elif self.args.is_epoch_lrer != is_epoch_lrer:
            logger.log_err(
                'Unmatched lr scheduler types\t=>\t{0}\n'
                'All lrers of the task models should have the same type (either EPOCH_LRERS or ITER_LRERS)\n'
                'PixelSSL supports: \n'
                ' EPOCH_LRERS\t=>\t{1}\n ITER_LRERS\t=>\t{2}\n'.format(
                    self.args.lrers, nnlrer.EPOCH_LRERS, nnlrer.ITER_LRERS))

    self.args.checkpoint_path = os.path.join(self.args.out_path, 'ckpt')
    if not os.path.exists(self.args.checkpoint_path):
        os.makedirs(self.args.checkpoint_path)

    if self.args.visualize:
        self.args.visual_debug_path = os.path.join(self.args.out_path, 'visualization/debug')
        self.args.visual_train_path = os.path.join(self.args.out_path, 'visualization/train')
        self.args.visual_val_path = os.path.join(self.args.out_path, 'visualization/val')
        for vpath in [self.args.visual_debug_path, self.args.visual_train_path, self.args.visual_val_path]:
            if not os.path.exists(vpath):
                os.makedirs(vpath)

    # handle arguments for multi-GPU training
    self.args.gpus = torch.cuda.device_count()
    if self.args.gpus < 1:
        logger.log_err('No GPU detected\n'
                       'PixelSSL requires at least one NVIDIA GPU\n')

    logger.log_info('GPU: \n Total GPU(s): {0}'.format(self.args.gpus))

    self.args.lr *= self.args.gpus
    self.args.num_workers *= self.args.gpus
    self.args.batch_size *= self.args.gpus
    self.args.unlabeled_batch_size *= self.args.gpus

    # TODO: support unsupervised and self-supervised training
    if self.args.unlabeled_batch_size >= self.args.batch_size:
        logger.log_err(
            'The argument \'unlabeled_batch_size\' ({0}) should be smaller than \'batch_size\' ({1}) '
            'since PixelSSL only supports semi-supervised learning now\n'.format(
                self.args.unlabeled_batch_size, self.args.batch_size))

    self.args.labeled_batch_size = self.args.batch_size - self.args.unlabeled_batch_size

    logger.log_info(
        ' Total learning rate: {0}\n Total labeled batch size: {1}\n'
        ' Total unlabeled batch size: {2}\n Total data workers: {3}\n'.format(
            self.args.lr, self.args.labeled_batch_size,
            self.args.unlabeled_batch_size, self.args.num_workers))
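# A worked example of the multi-GPU scaling above (the script values are
# assumed for illustration): with lr = 0.001, batch_size = 8,
# unlabeled_batch_size = 4 and 4 GPUs, the effective settings become
#   lr                   = 0.001 * 4 = 0.004
#   batch_size           = 8 * 4     = 32
#   unlabeled_batch_size = 4 * 4     = 16
#   labeled_batch_size   = 32 - 16   = 16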