class SSLGCT(ssl_base._SSLBase):
    NAME = 'ssl_gct'
    SUPPORTED_TASK_TYPES = [REGRESSION, CLASSIFICATION]

    def __init__(self, args):
        super(SSLGCT, self).__init__(args)

        # define the task models and the flaw detector
        self.l_model, self.r_model, self.fd_model = None, None, None
        self.l_optimizer, self.r_optimizer, self.fd_optimizer = None, None, None
        self.l_lrer, self.r_lrer, self.fd_lrer = None, None, None
        self.l_criterion, self.r_criterion, self.fd_criterion = None, None, None

        # define the extra modules required by GCT
        self.flawmap_handler = None
        self.dcgt_generator = None
        self.fdgt_generator = None
        self.zero_df_gt = torch.zeros(
            [self.args.batch_size, 1, self.args.im_size, self.args.im_size]).cuda()

        # prepare the arguments for multiple GPUs
        self.args.fd_lr *= self.args.gpus

        # check SSL arguments
        if self.args.unlabeled_batch_size > 0:
            if self.args.ssl_mode in [MODE_GCT, MODE_FC]:
                if self.args.fc_ssl_scale < 0:
                    logger.log_err('The argument - fc_ssl_scale - is not set (or invalid)\n'
                                   'You enable the flaw correction constraint\n'
                                   'Please set - fc_ssl_scale >= 0 - for training\n')
            if self.args.ssl_mode in [MODE_GCT, MODE_DC]:
                if self.args.dc_rampup_epochs < 0:
                    logger.log_err('The argument - dc_rampup_epochs - is not set (or invalid)\n'
                                   'You enable the dynamic consistency constraint\n'
                                   'Please set - dc_rampup_epochs >= 0 - for training\n')
                elif self.args.dc_ssl_scale < 0:
                    logger.log_err('The argument - dc_ssl_scale - is not set (or invalid)\n'
                                   'You enable the dynamic consistency constraint\n'
                                   'Please set - dc_ssl_scale >= 0 - for training\n')
                elif self.args.dc_threshold < 0:
                    logger.log_err('The argument - dc_threshold - is not set (or invalid)\n'
                                   'You enable the dynamic consistency constraint\n'
                                   'Please set - dc_threshold >= 0 - for training\n')
                elif self.args.mu < 0:
                    logger.log_err('The argument - mu - is not set (or invalid)\n'
                                   'Please set - 0 < mu <= 1 - for training\n')
                elif self.args.nu < 0:
                    logger.log_err('The argument - nu - is not set (or invalid)\n'
                                   'Please set - nu > 0 - for training\n')

    def _build(self, model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs, task_func):
        self.task_func = task_func

        # create models
        # 'l_' denotes the first task model while 'r_' denotes the second task model
        self.l_model = func.create_model(model_funcs[0], 'l_model', args=self.args)
        self.r_model = func.create_model(model_funcs[1], 'r_model', args=self.args)
        self.fd_model = func.create_model(FlawDetector, 'fd_model',
                                          in_channels=self.task_func.sslgct_fd_in_channels())
        # call 'patch_replication_callback' to enable the `sync_batchnorm` layer
        patch_replication_callback(self.l_model)
        patch_replication_callback(self.r_model)
        patch_replication_callback(self.fd_model)
        self.models = {'l_model': self.l_model, 'r_model': self.r_model, 'fd_model': self.fd_model}

        # create optimizers
        self.l_optimizer = optimizer_funcs[0](self.l_model.module.param_groups)
        self.r_optimizer = optimizer_funcs[1](self.r_model.module.param_groups)
        self.fd_optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.fd_model.parameters()),
                                       lr=self.args.fd_lr, betas=(0.9, 0.99))
        self.optimizers = {'l_optimizer': self.l_optimizer, 'r_optimizer': self.r_optimizer,
                           'fd_optimizer': self.fd_optimizer}

        # create lrers
        self.l_lrer = lrer_funcs[0](self.l_optimizer)
        self.r_lrer = lrer_funcs[1](self.r_optimizer)
        self.fd_lrer = PolynomialLR(self.fd_optimizer, self.args.epochs, self.args.iters_per_epoch,
                                    power=0.9, last_epoch=-1)
        self.lrers = {'l_lrer': self.l_lrer, 'r_lrer': self.r_lrer, 'fd_lrer': self.fd_lrer}

        # create criterions
        self.l_criterion = criterion_funcs[0](self.args)
        self.r_criterion = criterion_funcs[1](self.args)
        self.fd_criterion = FlawDetectorCriterion()
        self.dc_criterion = torch.nn.MSELoss()
        self.criterions = {'l_criterion': self.l_criterion, 'r_criterion': self.r_criterion,
                           'fd_criterion': self.fd_criterion, 'dc_criterion': self.dc_criterion}

        # build the extra modules required by GCT
        self.flawmap_handler = nn.DataParallel(FlawmapHandler(self.args)).cuda()
        self.dcgt_generator = nn.DataParallel(DCGTGenerator(self.args)).cuda()
        self.fdgt_generator = nn.DataParallel(FDGTGenerator(self.args)).cuda()

    def _train(self, data_loader, epoch):
        self.meters.reset()
        lbs = self.args.labeled_batch_size

        self.l_model.train()
        self.r_model.train()
        self.fd_model.train()

        # both 'inp' and 'gt' are tuples
        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            (l_inp, l_gt), (r_inp, r_gt) = self._batch_prehandle(inp, gt)
            if len(l_gt) == len(r_gt) > 1 and idx == 0:
                self._inp_warn()

            # calculate the ramp-up coefficient of the dynamic consistency constraint
            cur_steps = len(data_loader) * epoch + idx
            total_steps = len(data_loader) * self.args.dc_rampup_epochs
            dc_rampup_scale = func.sigmoid_rampup(cur_steps, total_steps)

            # -----------------------------------------------------------------------------
            # step-0: pre-forwarding to save GPU memory
            #   - forward the task models and the flaw detector
            #   - generate pseudo ground truth for the unlabeled data if the dynamic
            #     consistency constraint is enabled
            # -----------------------------------------------------------------------------
            with torch.no_grad():
                l_resulter, l_debugger = self.l_model.forward(l_inp)
                l_activated_pred = tool.dict_value(l_resulter, 'activated_pred')
                r_resulter, r_debugger = self.r_model.forward(r_inp)
                r_activated_pred = tool.dict_value(r_resulter, 'activated_pred')

            # 'l_flawmap' and 'r_flawmap' will be used in step-2
            fd_resulter, fd_debugger = self.fd_model.forward(l_inp, l_activated_pred[0])
            l_flawmap = tool.dict_value(fd_resulter, 'flawmap')
            fd_resulter, fd_debugger = self.fd_model.forward(r_inp, r_activated_pred[0])
            r_flawmap = tool.dict_value(fd_resulter, 'flawmap')

            l_dc_gt, r_dc_gt = None, None
            l_fc_mask, r_fc_mask = None, None

            # generate the pseudo ground truth for the dynamic consistency constraint
            if self.args.ssl_mode in [MODE_GCT, MODE_DC]:
                with torch.no_grad():
                    l_handled_flawmap = self.flawmap_handler.forward(l_flawmap)
                    r_handled_flawmap = self.flawmap_handler.forward(r_flawmap)
                    l_dc_gt, r_dc_gt, l_fc_mask, r_fc_mask = self.dcgt_generator.forward(
                        l_activated_pred[0].detach(), r_activated_pred[0].detach(),
                        l_handled_flawmap, r_handled_flawmap)

            # -----------------------------------------------------------------------------
            # step-1: train the task models
            # -----------------------------------------------------------------------------
            for param in self.fd_model.parameters():
                param.requires_grad = False

            # train the 'l' task model
            l_loss = self._task_model_iter(epoch, idx, True, 'l', lbs, l_inp, l_gt,
                                           l_dc_gt, l_fc_mask, dc_rampup_scale)
            self.l_optimizer.zero_grad()
            l_loss.backward()
            self.l_optimizer.step()

            # train the 'r' task model
            r_loss = self._task_model_iter(epoch, idx, True, 'r', lbs, r_inp, r_gt,
                                           r_dc_gt, r_fc_mask, dc_rampup_scale)
            self.r_optimizer.zero_grad()
            r_loss.backward()
            self.r_optimizer.step()

            # -----------------------------------------------------------------------------
            # step-2: train the flaw detector
            # -----------------------------------------------------------------------------
            for param in self.fd_model.parameters():
                param.requires_grad = True

            # generate the ground truth for the flaw detector (on labeled data only)
            with torch.no_grad():
                l_flawmap_gt = self.fdgt_generator.forward(
                    l_activated_pred[0][:lbs, ...].detach(),
                    self.task_func.sslgct_prepare_task_gt_for_fdgt(l_gt[0][:lbs, ...]))
                r_flawmap_gt = self.fdgt_generator.forward(
                    r_activated_pred[0][:lbs, ...].detach(),
                    self.task_func.sslgct_prepare_task_gt_for_fdgt(r_gt[0][:lbs, ...]))

            l_fd_loss = self.fd_criterion.forward(l_flawmap[:lbs, ...], l_flawmap_gt)
            l_fd_loss = self.args.fd_scale * torch.mean(l_fd_loss)
            self.meters.update('l_fd_loss', l_fd_loss.data)

            r_fd_loss = self.fd_criterion.forward(r_flawmap[:lbs, ...], r_flawmap_gt)
            r_fd_loss = self.args.fd_scale * torch.mean(r_fd_loss)
            self.meters.update('r_fd_loss', r_fd_loss.data)

            fd_loss = (l_fd_loss + r_fd_loss) / 2
            self.fd_optimizer.zero_grad()
            fd_loss.backward()
            self.fd_optimizer.step()

            # logging
            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info('step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                                '  l-{3}\t=>\t'
                                'l-task-loss: {meters[l_task_loss]:.6f}\t'
                                'l-dc-loss: {meters[l_dc_loss]:.6f}\t'
                                'l-fc-loss: {meters[l_fc_loss]:.6f}\n'
                                '  r-{3}\t=>\t'
                                'r-task-loss: {meters[r_task_loss]:.6f}\t'
                                'r-dc-loss: {meters[r_dc_loss]:.6f}\t'
                                'r-fc-loss: {meters[r_fc_loss]:.6f}\n'
                                '  fd\t=>\t'
                                'l-fd-loss: {meters[l_fd_loss]:.6f}\t'
                                'r-fd-loss: {meters[r_fd_loss]:.6f}\n'
                                .format(epoch, idx, len(data_loader), self.args.task, meters=self.meters))

            # the flaw detector uses polynomiallr [ITER_LRERS]
            self.fd_lrer.step()

            # update iteration-based lrers
            if not self.args.is_epoch_lrer:
                self.l_lrer.step()
                self.r_lrer.step()

        # update epoch-based lrers
        if self.args.is_epoch_lrer:
            self.l_lrer.step()
            self.r_lrer.step()

    def _validate(self, data_loader, epoch):
        self.meters.reset()
        lbs = self.args.labeled_batch_size

        self.l_model.eval()
        self.r_model.eval()
        self.fd_model.eval()

        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            (l_inp, l_gt), (r_inp, r_gt) = self._batch_prehandle(inp, gt)
            if len(l_gt) == len(r_gt) > 1 and idx == 0:
                self._inp_warn()

            l_dc_gt, r_dc_gt = None, None
            l_fc_mask, r_fc_mask = None, None
            if self.args.ssl_mode in [MODE_GCT, MODE_DC]:
                l_resulter, l_debugger = self.l_model.forward(l_inp)
                l_activated_pred = tool.dict_value(l_resulter, 'activated_pred')
                r_resulter, r_debugger = self.r_model.forward(r_inp)
                r_activated_pred = tool.dict_value(r_resulter, 'activated_pred')

                fd_resulter, fd_debugger = self.fd_model.forward(l_inp, l_activated_pred[0])
                l_flawmap = tool.dict_value(fd_resulter, 'flawmap')
                fd_resulter, fd_debugger = self.fd_model.forward(r_inp, r_activated_pred[0])
                r_flawmap = tool.dict_value(fd_resulter, 'flawmap')

                l_handled_flawmap = self.flawmap_handler.forward(l_flawmap)
                r_handled_flawmap = self.flawmap_handler.forward(r_flawmap)
                l_dc_gt, r_dc_gt, l_fc_mask, r_fc_mask = self.dcgt_generator.forward(
                    l_activated_pred[0].detach(), r_activated_pred[0].detach(),
                    l_handled_flawmap, r_handled_flawmap)

            l_loss = self._task_model_iter(epoch, idx, False, 'l', lbs, l_inp, l_gt,
                                           l_dc_gt, l_fc_mask, 1)
            r_loss = self._task_model_iter(epoch, idx, False, 'r', lbs, r_inp, r_gt,
                                           r_dc_gt, r_fc_mask, 1)

            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info('step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                                '  l-{3}\t=>\t'
                                'l-task-loss: {meters[l_task_loss]:.6f}\t'
                                'l-dc-loss: {meters[l_dc_loss]:.6f}\t'
                                'l-fc-loss: {meters[l_fc_loss]:.6f}\n'
                                '  r-{3}\t=>\t'
                                'r-task-loss: {meters[r_task_loss]:.6f}\t'
                                'r-dc-loss: {meters[r_dc_loss]:.6f}\t'
                                'r-fc-loss: {meters[r_fc_loss]:.6f}\n'
                                '  fd\t=>\t'
                                'l-fd-loss: {meters[l_fd_loss]:.6f}\t'
                                'r-fd-loss: {meters[r_fd_loss]:.6f}\n'
                                .format(epoch, idx, len(data_loader), self.args.task, meters=self.meters))

        # metrics
        metrics_info = {'l': '', 'r': ''}
        for key in sorted(list(self.meters.keys())):
            if self.task_func.METRIC_STR in key:
                for id_str in metrics_info.keys():
                    if key.startswith(id_str):
                        metrics_info[id_str] += '{0}: {1:.6f}\t'.format(key, self.meters[key])

        logger.log_info('Validation metrics:\n  l-metrics\t=>\t{0}\n  r-metrics\t=>\t{1}\n'
                        .format(metrics_info['l'].replace('_', '-'),
                                metrics_info['r'].replace('_', '-')))

    def _save_checkpoint(self, epoch):
        state = {
            'algorithm': self.NAME,
            'epoch': epoch + 1,
            'l_model': self.l_model.state_dict(),
            'r_model': self.r_model.state_dict(),
            'fd_model': self.fd_model.state_dict(),
            'l_optimizer': self.l_optimizer.state_dict(),
            'r_optimizer': self.r_optimizer.state_dict(),
            'fd_optimizer': self.fd_optimizer.state_dict(),
            'l_lrer': self.l_lrer.state_dict(),
            'r_lrer': self.r_lrer.state_dict(),
            'fd_lrer': self.fd_lrer.state_dict(),
        }

        checkpoint = os.path.join(self.args.checkpoint_path, 'checkpoint_{0}.ckpt'.format(epoch))
        torch.save(state, checkpoint)

    def _load_checkpoint(self):
        checkpoint = torch.load(self.args.resume)

        checkpoint_algorithm = tool.dict_value(checkpoint, 'algorithm', default='unknown')
        if checkpoint_algorithm != self.NAME:
            logger.log_err('Unmatched SSL algorithm format in checkpoint => required: {0} - given: {1}\n'
                           .format(self.NAME, checkpoint_algorithm))

        self.l_model.load_state_dict(checkpoint['l_model'])
        self.r_model.load_state_dict(checkpoint['r_model'])
        self.fd_model.load_state_dict(checkpoint['fd_model'])
        self.l_optimizer.load_state_dict(checkpoint['l_optimizer'])
        self.r_optimizer.load_state_dict(checkpoint['r_optimizer'])
        self.fd_optimizer.load_state_dict(checkpoint['fd_optimizer'])
        self.l_lrer.load_state_dict(checkpoint['l_lrer'])
        self.r_lrer.load_state_dict(checkpoint['r_lrer'])
        self.fd_lrer.load_state_dict(checkpoint['fd_lrer'])

        return checkpoint['epoch']

    def _task_model_iter(self, epoch, idx, is_train, mid, lbs, inp, gt,
                         dc_gt, fc_mask, dc_rampup_scale):
        if mid == 'l':
            model, criterion = self.l_model, self.l_criterion
        elif mid == 'r':
            model, criterion = self.r_model, self.r_criterion
        else:
            model, criterion = None, None

        # forward the task model
        resulter, debugger = model.forward(inp)
        if not 'pred' in resulter.keys() or not 'activated_pred' in resulter.keys():
            self._pred_err()

        pred = tool.dict_value(resulter, 'pred')
        activated_pred = tool.dict_value(resulter, 'activated_pred')

        fd_resulter, fd_debugger = self.fd_model.forward(inp, activated_pred[0])
        flawmap = tool.dict_value(fd_resulter, 'flawmap')

        # calculate the supervised task constraint on the labeled data
        labeled_pred = func.split_tensor_tuple(pred, 0, lbs)
        labeled_gt = func.split_tensor_tuple(gt, 0, lbs)
        labeled_inp = func.split_tensor_tuple(inp, 0, lbs)
        task_loss = torch.mean(criterion.forward(labeled_pred, labeled_gt, labeled_inp))
        self.meters.update('{0}_task_loss'.format(mid), task_loss.data)

        # calculate the flaw correction constraint
        if self.args.ssl_mode in [MODE_GCT, MODE_FC]:
            if flawmap.shape == self.zero_df_gt.shape:
                fc_ssl_loss = self.fd_criterion.forward(flawmap, self.zero_df_gt,
                                                        is_ssl=True, reduction=False)
            else:
                fc_ssl_loss = self.fd_criterion.forward(flawmap, torch.zeros(flawmap.shape).cuda(),
                                                        is_ssl=True, reduction=False)

            if self.args.ssl_mode == MODE_GCT:
                fc_ssl_loss = fc_mask * fc_ssl_loss

            fc_ssl_loss = self.args.fc_ssl_scale * torch.mean(fc_ssl_loss)
            self.meters.update('{0}_fc_loss'.format(mid), fc_ssl_loss.data)
        else:
            fc_ssl_loss = 0
            self.meters.update('{0}_fc_loss'.format(mid), fc_ssl_loss)

        # calculate the dynamic consistency constraint
        if self.args.ssl_mode in [MODE_GCT, MODE_DC]:
            if dc_gt is None:
                logger.log_err('The dynamic consistency constraint is enabled, '
                               'but no pseudo ground truth is given.')

            dc_ssl_loss = self.dc_criterion.forward(activated_pred[0], dc_gt)
            dc_ssl_loss = dc_rampup_scale * self.args.dc_ssl_scale * torch.mean(dc_ssl_loss)
            self.meters.update('{0}_dc_loss'.format(mid), dc_ssl_loss.data)
        else:
            dc_ssl_loss = 0
            self.meters.update('{0}_dc_loss'.format(mid), dc_ssl_loss)

        with torch.no_grad():
            flawmap_gt = self.fdgt_generator.forward(
                activated_pred[0], self.task_func.sslgct_prepare_task_gt_for_fdgt(gt[0]))

        # for validation
        if not is_train:
            fd_loss = self.args.fd_scale * self.fd_criterion.forward(flawmap, flawmap_gt)
            self.meters.update('{0}_fd_loss'.format(mid), torch.mean(fd_loss).data)
            self.task_func.metrics(activated_pred, gt, inp, self.meters, id_str=mid)

        # visualization
        if self.args.visualize and idx % self.args.visual_freq == 0:
            with torch.no_grad():
                handled_flawmap = self.flawmap_handler(flawmap)[0]

            self._visualize(epoch, idx, is_train, mid,
                            func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                            func.split_tensor_tuple(activated_pred, 0, 1, reduce_dim=True),
                            func.split_tensor_tuple(gt, 0, 1, reduce_dim=True),
                            handled_flawmap, flawmap_gt[0],
                            dc_gt[0] if dc_gt is not None else None)

        loss = task_loss + fc_ssl_loss + dc_ssl_loss
        return loss

    def _visualize(self, epoch, idx, is_train, id_str, inp, pred, gt,
                   flawmap, flawmap_gt, dc_gt=None):
        visualize_path = self.args.visual_train_path if is_train else self.args.visual_val_path
        out_path = os.path.join(visualize_path, '{0}_{1}'.format(epoch, idx))

        self.task_func.visualize(out_path, id_str=id_str, inp=inp, pred=pred, gt=gt)

        flawmap = flawmap[0].data.cpu().numpy()
        Image.fromarray((flawmap * 255).astype('uint8'), mode='L')\
             .save(out_path + '_{0}-fmap.png'.format(id_str))
        flawmap_gt = flawmap_gt[0].data.cpu().numpy()
        Image.fromarray((flawmap_gt * 255).astype('uint8'), mode='L')\
             .save(out_path + '_{0}-fmap-gt.png'.format(id_str))

        if dc_gt is not None:
            self.task_func.visualize(out_path, id_str=id_str + '_dc', inp=None, pred=[dc_gt], gt=None)

    def _batch_prehandle(self, inp, gt):
        # add extra data augmentation process here if necessary
        inp_var = []
        for i in inp:
            inp_var.append(Variable(i).cuda())
        inp = tuple(inp_var)

        gt_var = []
        for g in gt:
            gt_var.append(Variable(g).cuda())
        gt = tuple(gt_var)

        # augment 'inp' for different task models here if necessary
        l_inp, r_inp = inp, inp
        l_gt, r_gt = gt, gt

        return (l_inp, l_gt), (r_inp, r_gt)

    def _inp_warn(self):
        logger.log_warn('More than one ground truth of the task model is given in SSL_GCT\n'
                        'You try to train the task model with more than one (pred & gt) pair\n'
                        'Please make sure that:\n'
                        '  (1) The prediction tuple has the same size as the ground truth tuple\n'
                        '  (2) The elements with the same index in the two tuples are corresponding\n'
                        '  (3) The first element of (pred & gt) will be used to train the flaw detector\n'
                        '      and calculate the SSL constraints\n'
                        'Please implement a new SSL algorithm if you want a variant of SSL_GCT with\n'
                        'multiple flaw detectors (for multiple predictions)\n')

    def _pred_err(self):
        logger.log_err('In SSL_GCT, the \'resulter\' dict returned by the task model should contain the following keys:\n'
                       '  (1) \'pred\'\t=>\tunactivated task predictions\n'
                       '  (2) \'activated_pred\'\t=>\tactivated task predictions\n'
                       'We need both of them since some losses include the activation functions,\n'
                       'e.g., the CrossEntropyLoss already contains the SoftMax activation\n')
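# The dynamic consistency weight in SSLGCT._train is produced by `func.sigmoid_rampup(cur_steps,
# total_steps)`. The helper below is only an illustrative sketch of the exponential sigmoid
# ramp-up commonly used for such schedules; the exact formula inside `func.sigmoid_rampup` is an
# assumption, and `_sigmoid_rampup_sketch` is a hypothetical name that is not referenced by SSLGCT.
import math


def _sigmoid_rampup_sketch(current_step, total_steps):
    """Ramp a weight from ~0 to 1 following exp(-5 * (1 - t)^2), with t = current/total clamped to [0, 1]."""
    if total_steps <= 0:
        return 1.0
    t = max(0.0, min(1.0, float(current_step) / float(total_steps)))
    return math.exp(-5.0 * (1.0 - t) ** 2)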
class SSLADV(ssl_base._SSLBase):
    NAME = 'ssl_adv'
    SUPPORTED_TASK_TYPES = [REGRESSION, CLASSIFICATION]

    def __init__(self, args):
        super(SSLADV, self).__init__(args)

        # define the task model and the FC discriminator
        self.model, self.d_model = None, None
        self.optimizer, self.d_optimizer = None, None
        self.lrer, self.d_lrer = None, None
        self.criterion, self.d_criterion = None, None

        # prepare the arguments for multiple GPUs
        self.args.discriminator_lr *= self.args.gpus

        # check SSL arguments
        if self.args.adv_for_labeled:
            if self.args.labeled_adv_scale < 0:
                logger.log_err('The argument - labeled_adv_scale - is not set (or invalid)\n'
                               'You set argument - adv_for_labeled - to True\n'
                               'Please set - labeled_adv_scale >= 0 - for calculating the '
                               'adversarial loss on the labeled data\n')
        if self.args.unlabeled_batch_size > 0:
            if self.args.unlabeled_adv_scale < 0:
                logger.log_err('The argument - unlabeled_adv_scale - is not set (or invalid)\n'
                               'You set argument - unlabeled_batch_size - larger than 0\n'
                               'Please set - unlabeled_adv_scale >= 0 - for calculating the '
                               'adversarial loss on the unlabeled data\n')

    def _build(self, model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs, task_func):
        self.task_func = task_func

        # create models
        self.model = func.create_model(model_funcs[0], 'model', args=self.args)
        self.d_model = func.create_model(FCDiscriminator, 'd_model',
                                         in_channels=self.task_func.ssladv_fcd_in_channels())
        # call 'patch_replication_callback' to enable the `sync_batchnorm` layer
        patch_replication_callback(self.model)
        patch_replication_callback(self.d_model)
        self.models = {'model': self.model, 'd_model': self.d_model}

        # create optimizers
        self.optimizer = optimizer_funcs[0](self.model.module.param_groups)
        self.d_optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.d_model.parameters()),
                                      lr=self.args.discriminator_lr, betas=(0.9, 0.99))
        self.optimizers = {'optimizer': self.optimizer, 'd_optimizer': self.d_optimizer}

        # create lrers
        self.lrer = lrer_funcs[0](self.optimizer)
        self.d_lrer = PolynomialLR(self.d_optimizer, self.args.epochs, self.args.iters_per_epoch,
                                   power=self.args.discriminator_power, last_epoch=-1)
        self.lrers = {'lrer': self.lrer, 'd_lrer': self.d_lrer}

        # create criterions
        self.criterion = criterion_funcs[0](self.args)
        self.d_criterion = FCDiscriminatorCriterion()
        self.criterions = {'criterion': self.criterion, 'd_criterion': self.d_criterion}

        self._algorithm_warn()

    def _train(self, data_loader, epoch):
        self.meters.reset()
        lbs = self.args.labeled_batch_size

        self.model.train()
        self.d_model.train()

        # both 'inp' and 'gt' are tuples
        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            inp, gt = self._batch_prehandle(inp, gt)
            if len(gt) > 1 and idx == 0:
                self._inp_warn()

            # -----------------------------------------------------------------------------
            # step-1: train the task model
            # -----------------------------------------------------------------------------
            self.optimizer.zero_grad()

            # forward the task model
            resulter, debugger = self.model.forward(inp)
            if not 'pred' in resulter.keys() or not 'activated_pred' in resulter.keys():
                self._pred_err()

            pred = tool.dict_value(resulter, 'pred')
            activated_pred = tool.dict_value(resulter, 'activated_pred')

            # forward the FC discriminator
            # 'confidence_map' is a tensor
            d_resulter, d_debugger = self.d_model.forward(activated_pred[0])
            confidence_map = tool.dict_value(d_resulter, 'confidence')

            # calculate the supervised task constraint on the labeled data
            l_pred = func.split_tensor_tuple(pred, 0, lbs)
            l_gt = func.split_tensor_tuple(gt, 0, lbs)
            l_inp = func.split_tensor_tuple(inp, 0, lbs)

            # 'task_loss' is a tensor of 1-dim & n elements, where n == batch_size
            task_loss = self.criterion.forward(l_pred, l_gt, l_inp)
            task_loss = torch.mean(task_loss)
            self.meters.update('task_loss', task_loss.data)

            # calculate the adversarial constraint for the labeled data
            if self.args.adv_for_labeled:
                l_confidence_map = confidence_map[:lbs, ...]

                # preprocess prediction and ground truth for the adversarial constraint
                l_adv_confidence_map, l_adv_confidence_gt = \
                    self.task_func.ssladv_preprocess_fcd_criterion(l_confidence_map, l_gt[0], True)

                l_adv_loss = self.d_criterion(l_adv_confidence_map, l_adv_confidence_gt)
                labeled_adv_loss = self.args.labeled_adv_scale * torch.mean(l_adv_loss)
                self.meters.update('labeled_adv_loss', labeled_adv_loss.data)
            else:
                labeled_adv_loss = 0
                self.meters.update('labeled_adv_loss', labeled_adv_loss)

            # calculate the adversarial constraint for the unlabeled data
            if self.args.unlabeled_batch_size > 0:
                u_confidence_map = confidence_map[lbs:self.args.batch_size, ...]

                # preprocess prediction and ground truth for the adversarial constraint
                u_adv_confidence_map, u_adv_confidence_gt = \
                    self.task_func.ssladv_preprocess_fcd_criterion(u_confidence_map, None, True)

                u_adv_loss = self.d_criterion(u_adv_confidence_map, u_adv_confidence_gt)
                unlabeled_adv_loss = self.args.unlabeled_adv_scale * torch.mean(u_adv_loss)
                self.meters.update('unlabeled_adv_loss', unlabeled_adv_loss.data)
            else:
                unlabeled_adv_loss = 0
                self.meters.update('unlabeled_adv_loss', unlabeled_adv_loss)

            adv_loss = labeled_adv_loss + unlabeled_adv_loss

            # backward and update the task model
            loss = task_loss + adv_loss
            loss.backward()
            self.optimizer.step()

            # -----------------------------------------------------------------------------
            # step-2: train the FC discriminator
            # -----------------------------------------------------------------------------
            self.d_optimizer.zero_grad()

            # forward the task prediction (fake)
            if self.args.unlabeled_for_discriminator:
                fake_pred = activated_pred[0].detach()
            else:
                fake_pred = activated_pred[0][:lbs, ...].detach()

            d_resulter, d_debugger = self.d_model.forward(fake_pred)
            fake_confidence_map = tool.dict_value(d_resulter, 'confidence')

            l_fake_confidence_map = fake_confidence_map[:lbs, ...]
            l_fake_confidence_map, l_fake_confidence_gt = \
                self.task_func.ssladv_preprocess_fcd_criterion(l_fake_confidence_map, l_gt[0], False)

            if self.args.unlabeled_for_discriminator and self.args.unlabeled_batch_size != 0:
                u_fake_confidence_map = fake_confidence_map[lbs:self.args.batch_size, ...]
                u_fake_confidence_map, u_fake_confidence_gt = \
                    self.task_func.ssladv_preprocess_fcd_criterion(u_fake_confidence_map, None, False)

                fake_confidence_map = torch.cat((l_fake_confidence_map, u_fake_confidence_map), dim=0)
                fake_confidence_gt = torch.cat((l_fake_confidence_gt, u_fake_confidence_gt), dim=0)
            else:
                fake_confidence_map, fake_confidence_gt = l_fake_confidence_map, l_fake_confidence_gt

            fake_d_loss = self.d_criterion.forward(fake_confidence_map, fake_confidence_gt)
            fake_d_loss = self.args.discriminator_scale * torch.mean(fake_d_loss)
            self.meters.update('fake_d_loss', fake_d_loss.data)

            # forward the ground truth (real)
            # convert the format of the ground truth
            real_gt = self.task_func.ssladv_convert_task_gt_to_fcd_input(l_gt[0])
            d_resulter, d_debugger = self.d_model.forward(real_gt)
            real_confidence_map = tool.dict_value(d_resulter, 'confidence')
            real_confidence_map, real_confidence_gt = \
                self.task_func.ssladv_preprocess_fcd_criterion(real_confidence_map, l_gt[0], True)

            real_d_loss = self.d_criterion(real_confidence_map, real_confidence_gt)
            real_d_loss = self.args.discriminator_scale * torch.mean(real_d_loss)
            self.meters.update('real_d_loss', real_d_loss.data)

            # backward and update the FC discriminator
            d_loss = (fake_d_loss + real_d_loss) / 2
            d_loss.backward()
            self.d_optimizer.step()

            # logging
            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info('step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                                '  task-{3}\t=>\t'
                                'task-loss: {meters[task_loss]:.6f}\t'
                                'labeled-adv-loss: {meters[labeled_adv_loss]:.6f}\t'
                                'unlabeled-adv-loss: {meters[unlabeled_adv_loss]:.6f}\n'
                                '  fc-discriminator\t=>\t'
                                'fake-d-loss: {meters[fake_d_loss]:.6f}\t'
                                'real-d-loss: {meters[real_d_loss]:.6f}\n'
                                .format(epoch + 1, idx, len(data_loader), self.args.task, meters=self.meters))

            # visualization
            if self.args.visualize and idx % self.args.visual_freq == 0:
                u_inp_sample, u_pred_sample, u_cmap_sample = None, None, None
                if self.args.unlabeled_batch_size > 0:
                    u_inp_sample = func.split_tensor_tuple(inp, lbs, lbs + 1, reduce_dim=True)
                    u_pred_sample = func.split_tensor_tuple(activated_pred, lbs, lbs + 1, reduce_dim=True)
                    u_cmap_sample = torch.sigmoid(fake_confidence_map[lbs])

                self._visualize(epoch, idx, True,
                                func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                                func.split_tensor_tuple(activated_pred, 0, 1, reduce_dim=True),
                                func.split_tensor_tuple(gt, 0, 1, reduce_dim=True),
                                torch.sigmoid(confidence_map[0]),
                                u_inp_sample, u_pred_sample, u_cmap_sample)

            # the FC discriminator uses polynomiallr [ITER_LRERS]
            self.d_lrer.step()

            # update iteration-based lrers
            if not self.args.is_epoch_lrer:
                self.lrer.step()

        # update epoch-based lrers
        if self.args.is_epoch_lrer:
            self.lrer.step()

    def _validate(self, data_loader, epoch):
        self.meters.reset()

        self.model.eval()
        self.d_model.eval()

        for idx, (inp, gt) in enumerate(data_loader):
            timer = time.time()

            inp, gt = self._batch_prehandle(inp, gt)
            if len(gt) > 1 and idx == 0:
                self._inp_warn()

            resulter, debugger = self.model.forward(inp)
            if not 'pred' in resulter.keys() or not 'activated_pred' in resulter.keys():
                self._pred_err()

            pred = tool.dict_value(resulter, 'pred')
            activated_pred = tool.dict_value(resulter, 'activated_pred')

            task_loss = self.criterion.forward(pred, gt, inp)
            task_loss = torch.mean(task_loss)
            self.meters.update('task_loss', task_loss.data)

            d_resulter, d_debugger = self.d_model.forward(activated_pred[0])
            unhandled_fake_confidence_map = tool.dict_value(d_resulter, 'confidence')
            fake_confidence_map, fake_confidence_gt = \
                self.task_func.ssladv_preprocess_fcd_criterion(unhandled_fake_confidence_map, gt[0], False)

            fake_d_loss = self.d_criterion.forward(fake_confidence_map, fake_confidence_gt)
            fake_d_loss = self.args.discriminator_scale * torch.mean(fake_d_loss)
            self.meters.update('fake_d_loss', fake_d_loss.data)

            real_gt = self.task_func.ssladv_convert_task_gt_to_fcd_input(gt[0])
            d_resulter, d_debugger = self.d_model.forward(real_gt)
            unhandled_real_confidence_map = tool.dict_value(d_resulter, 'confidence')
            real_confidence_map, real_confidence_gt = \
                self.task_func.ssladv_preprocess_fcd_criterion(unhandled_real_confidence_map, gt[0], True)

            real_d_loss = self.d_criterion.forward(real_confidence_map, real_confidence_gt)
            real_d_loss = self.args.discriminator_scale * torch.mean(real_d_loss)
            self.meters.update('real_d_loss', real_d_loss.data)

            self.task_func.metrics(activated_pred, gt, inp, self.meters, id_str='task')

            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info('step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                                '  task-{3}\t=>\t'
                                'task-loss: {meters[task_loss]:.6f}\t'
                                '  fc-discriminator\t=>\t'
                                'fake-d-loss: {meters[fake_d_loss]:.6f}\t'
                                'real-d-loss: {meters[real_d_loss]:.6f}\n'
                                .format(epoch + 1, idx, len(data_loader), self.args.task, meters=self.meters))

            if self.args.visualize and idx % self.args.visual_freq == 0:
                self._visualize(epoch, idx, False,
                                func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                                func.split_tensor_tuple(activated_pred, 0, 1, reduce_dim=True),
                                func.split_tensor_tuple(gt, 0, 1, reduce_dim=True),
                                torch.sigmoid(unhandled_fake_confidence_map[0]))

        # metrics
        metrics_info = {'task': ''}
        for key in sorted(list(self.meters.keys())):
            if self.task_func.METRIC_STR in key:
                for id_str in metrics_info.keys():
                    if key.startswith(id_str):
                        metrics_info[id_str] += '{0}: {1:.6f}\t'.format(key, self.meters[key])

        logger.log_info('Validation metrics:\n  task-metrics\t=>\t{0}\n'
                        .format(metrics_info['task'].replace('_', '-')))

    def _save_checkpoint(self, epoch):
        state = {
            'algorithm': self.NAME,
            'epoch': epoch,
            'model': self.model.state_dict(),
            'd_model': self.d_model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'd_optimizer': self.d_optimizer.state_dict(),
            'lrer': self.lrer.state_dict(),
            'd_lrer': self.d_lrer.state_dict(),
        }

        checkpoint = os.path.join(self.args.checkpoint_path, 'checkpoint_{0}.ckpt'.format(epoch))
        torch.save(state, checkpoint)

    def _load_checkpoint(self):
        checkpoint = torch.load(self.args.resume)

        checkpoint_algorithm = tool.dict_value(checkpoint, 'algorithm', default='unknown')
        if checkpoint_algorithm != self.NAME:
            logger.log_err('Unmatched SSL algorithm format in checkpoint => required: {0} - given: {1}\n'
                           .format(self.NAME, checkpoint_algorithm))

        self.model.load_state_dict(checkpoint['model'])
        self.d_model.load_state_dict(checkpoint['d_model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.d_optimizer.load_state_dict(checkpoint['d_optimizer'])
        self.lrer.load_state_dict(checkpoint['lrer'])
        self.d_lrer.load_state_dict(checkpoint['d_lrer'])

        return checkpoint['epoch']

    # -------------------------------------------------------------------------------------------
    # Tool Functions for SSL_ADV
    # -------------------------------------------------------------------------------------------

    def _visualize(self, epoch, idx, is_train, l_inp, l_pred, l_gt, l_cmap,
                   u_inp=None, u_pred=None, u_cmap=None):
        # 'cmap' is the output of the FC discriminator
        visualize_path = self.args.visual_train_path if is_train else self.args.visual_val_path
        out_path = os.path.join(visualize_path, '{0}_{1}'.format(epoch, idx))

        self.task_func.visualize(out_path, id_str='labeled', inp=l_inp, pred=l_pred, gt=l_gt)
        l_cmap = l_cmap[0].data.cpu().numpy()
        Image.fromarray((l_cmap * 255).astype('uint8'), mode='L').save(out_path + '_labeled-cmap.png')

        if u_inp is not None and u_pred is not None and u_cmap is not None:
            self.task_func.visualize(out_path, id_str='unlabeled', inp=u_inp, pred=u_pred, gt=None)
            u_cmap = u_cmap[0].data.cpu().numpy()
            Image.fromarray((u_cmap * 255).astype('uint8'), mode='L').save(out_path + '_unlabeled-cmap.png')

    def _batch_prehandle(self, inp, gt):
        # add extra data augmentation process here if necessary
        inp_var = []
        for i in inp:
            inp_var.append(Variable(i).cuda())
        inp = tuple(inp_var)

        gt_var = []
        for g in gt:
            gt_var.append(Variable(g).cuda())
        gt = tuple(gt_var)

        return inp, gt

    def _algorithm_warn(self):
        logger.log_warn('This SSL_ADV algorithm reproduces the SSL algorithm from the paper:\n'
                        '  \'Adversarial Learning for Semi-supervised Semantic Segmentation\'\n'
                        'The main differences between this implementation and the original paper are:\n'
                        '  (1) This implementation does not support the constraint named \'L_semi\' in the\n'
                        '      original paper since it can only be used for pixel-wise classification\n'
                        '\nThe semi-supervised constraint in this implementation refers to the constraint\n'
                        'named \'L_adv\' in the original paper\n'
                        '\nSame as the original paper, the FC discriminator is trained by the Adam optimizer\n'
                        'with the PolynomialLR scheduler\n')

    def _inp_warn(self):
        logger.log_warn('More than one ground truth of the task model is given in SSL_ADV\n'
                        'You try to train the task model with more than one (pred & gt) pair\n'
                        'Please make sure that:\n'
                        '  (1) The prediction tuple has the same size as the ground truth tuple\n'
                        '  (2) The elements with the same index in the two tuples are corresponding\n'
                        '  (3) The first element of (pred & gt) will be used to train the FC discriminator\n'
                        '      and calculate the SSL constraints\n'
                        'Please implement a new SSL algorithm if you want a variant of SSL_ADV with\n'
                        'multiple FC discriminators (for multiple predictions)\n')

    def _pred_err(self):
        logger.log_err('In SSL_ADV, the \'resulter\' dict returned by the task model should contain the following keys:\n'
                       '  (1) \'pred\'\t=>\tunactivated task predictions\n'
                       '  (2) \'activated_pred\'\t=>\tactivated task predictions\n'
                       'We need both of them since some losses include the activation functions,\n'
                       'e.g., the CrossEntropyLoss already contains the SoftMax activation\n')
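# Both SSLGCT and SSLADV step a `PolynomialLR` scheduler once per iteration for their auxiliary
# network (flaw detector / FC discriminator), as marked by the [ITER_LRERS] comments above. The
# helper below is only an illustrative sketch of the usual polynomial decay such a scheduler
# applies, lr = base_lr * (1 - cur_iter / max_iter) ** power; the actual `PolynomialLR` class may
# differ, and `_poly_lr_sketch` is a hypothetical name that is not referenced by either algorithm.
def _poly_lr_sketch(base_lr, cur_iter, epochs, iters_per_epoch, power=0.9):
    """Polynomially decay the learning rate over all training iterations."""
    max_iter = max(1, epochs * iters_per_epoch)
    progress = min(float(cur_iter) / float(max_iter), 1.0)
    return base_lr * (1.0 - progress) ** power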