def forward(self, inp):
    """Forward pass of the wrapped S4L model.

    Runs the underlying task model on ``inp``, validates that its resulter
    dict exposes the required keys, and additionally predicts the rotation
    from the 'ssls4l_rc_inp' feature via the rotation classifier.

    Returns:
        (resulter, debugger): ``resulter`` contains 'pred',
        'activated_pred' and 'rotation'.
    """
    resulter, debugger = {}, {}

    t_resulter, t_debugger = self.task_model.forward(inp)

    # the task model must expose both raw and activated predictions
    if 'pred' not in t_resulter.keys() or 'activated_pred' not in t_resulter.keys():
        logger.log_err(
            'In SSL_S4L, the \'resulter\' dict returned by the task model should contain the following keys:\n'
            ' (1) \'pred\'\t=>\tunactivated task predictions\n'
            ' (2) \'activated_pred\'\t=>\tactivated task predictions\n'
            'We need both of them since some losses include the activation functions,\n'
            'e.g., the CrossEntropyLoss has contained SoftMax\n')

    # the task model must also provide the input of the rotation classifier
    if 'ssls4l_rc_inp' not in t_resulter.keys():
        logger.log_err(
            'In SSL_S4L, the \'resulter\' dict returned by the task model should contain the key:\n'
            ' \'ssls4l_rc_inp\'\t=>\tinputs of the rotation classifier (a 4-dim tensor)\n'
            'It can be the feature map encoded by the task model or the output of the task model\n'
            'Please add the key \'ssls4l_rc_inp\' in your task model\'s resulter\n')

    # classify the rotation angle from the feature exposed by the task model
    rotation_input = tool.dict_value(t_resulter, 'ssls4l_rc_inp')
    pred_rotation = self.rotation_classifier.forward(rotation_input)

    resulter['pred'] = tool.dict_value(t_resulter, 'pred')
    resulter['activated_pred'] = tool.dict_value(t_resulter, 'activated_pred')
    resulter['rotation'] = pred_rotation
    return resulter, debugger
def _validate(self, data_loader, epoch):
    """Run one validation epoch.

    Forwards every batch through the wrapped model (labeled mode, third
    argument False), records the task loss and the task-specific metrics,
    optionally visualizes samples, and finally logs all task metrics.

    Arguments:
        data_loader: iterable yielding (inp, gt) tuples of tensors
        epoch: current epoch index (0-based; logged as epoch + 1)
    """
    self.meters.reset()
    self.model.eval()

    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        inp, gt = self._batch_prehandle(inp, gt)
        # this algorithm expects a single ground-truth element per sample
        if len(gt) > 1 and idx == 0:
            self._data_err()

        # forward in labeled mode (is_unlabeled=False)
        resulter, debugger = self.model.forward(inp, gt, False)
        pred = tool.dict_value(resulter, 'pred')
        activated_pred = tool.dict_value(resulter, 'activated_pred')

        # the wrapped model computes the task loss itself; 'err=True' makes
        # a missing key an error
        task_loss = tool.dict_value(resulter, 'task_loss', err=True)
        task_loss = task_loss.mean()
        self.meters.update('task_loss', task_loss.data)

        # accumulate task-specific metrics (keys prefixed with 'task')
        self.task_func.metrics(activated_pred, gt, inp, self.meters, id_str='task')

        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                ' task-{3}\t=>\t'
                'task-loss: {meters[task_loss]:.6f}\t'.format(
                    epoch + 1, idx, len(data_loader), self.args.task, meters=self.meters))

        # visualization
        if self.args.visualize and idx % self.args.visual_freq == 0:
            self._visualize(
                epoch, idx, False,
                func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

    # metrics
    # collect every meter whose key marks a task metric and log them once
    metrics_info = {'task': ''}
    for key in sorted(list(self.meters.keys())):
        if self.task_func.METRIC_STR in key:
            for id_str in metrics_info.keys():
                if key.startswith(id_str):
                    metrics_info[id_str] += '{0}: {1:.6}\t'.format(
                        key, self.meters[key])

    logger.log_info('Validation metrics:\n task-metrics\t=>\t{0}\n'.format(
        metrics_info['task'].replace('_', '-')))
def forward(self, inp, gt, is_unlabeled):
    """Forward pass of the wrapped CCT model.

    Labeled mode (``is_unlabeled=False``) computes the supervised task loss.
    Unlabeled mode forwards the encoded feature through the auxiliary
    decoders and computes the consistency loss between their activated
    predictions and the (detached) main-decoder prediction.

    Returns:
        (resulter, debugger): ``resulter`` contains 'pred',
        'activated_pred', 'task_loss', 'ul_ad_preds', 'cons_loss'
        (the latter three may be None depending on the mode).
    """
    resulter, debugger = {}, {}

    # forward the task model
    m_resulter, m_debugger = self.main_model.forward(inp)

    if not 'pred' in m_resulter.keys() or not 'activated_pred' in m_resulter.keys():
        logger.log_err('In SSL_CCT, the \'resulter\' dict returned by the task model should contain the following keys:\n'
                       ' (1) \'pred\'\t=>\tunactivated task predictions\n'
                       ' (2) \'activated_pred\'\t=>\tactivated task predictions\n'
                       'We need both of them since some losses include the activation functions,\n'
                       'e.g., the CrossEntropyLoss has contained SoftMax\n')

    resulter['pred'] = tool.dict_value(m_resulter, 'pred')
    resulter['activated_pred'] = tool.dict_value(m_resulter, 'activated_pred')

    # the auxiliary-decoder logic below indexes prediction [0] only
    if not len(resulter['pred']) == len(resulter['activated_pred']) == 1:
        logger.log_err('This implementation of SSL_CCT only support the task model with only one prediction (output). \n'
                       'However, there are {0} predictions.\n'.format(len(resulter['pred'])))

    # calculate the task loss
    # (no supervised loss on unlabeled batches)
    resulter['task_loss'] = None if is_unlabeled else torch.mean(self.task_criterion.forward(resulter['pred'], gt, inp))

    # for the unlabeled data
    if is_unlabeled and self.args.unlabeled_batch_size > 0:
        if not 'sslcct_ad_inp' in m_resulter.keys():
            logger.log_err('In SSL_CCT, the \'resulter\' dict returned by the task model should contain the key:\n'
                           ' \'sslcct_ad_inp\'\t=>\tinputs of the auxiliary decoders (a 4-dim tensor)\n'
                           'It is the feature map encoded by the task model\n'
                           'Please add the key \'sslcct_ad_inp\' in your task model\'s resulter\n'
                           'Note that for different task models, the shape of \'sslcct_ad_inp\' may be different\n')

        ul_ad_inp = tool.dict_value(m_resulter, 'sslcct_ad_inp')
        # detach: the main decoder must not receive gradients from the
        # auxiliary decoders
        ul_main_pred = resulter['pred'][0].detach()

        # forward the auxiliary decoders
        ul_ad_preds = []
        for ad in self.auxiliary_decoders:
            ul_ad_preds.append(ad.forward(ul_ad_inp, pred_of_main_decoder=ul_main_pred))
        resulter['ul_ad_preds'] = ul_ad_preds

        # calculate the consistency loss
        # the detached activated main prediction serves as the target
        ul_ad_gt = resulter['activated_pred'][0].detach()
        # resize each auxiliary prediction to the target's spatial size
        ul_ad_preds = [F.interpolate(ul_ad_pred, size=(ul_ad_gt.shape[2], ul_ad_gt.shape[3]), mode='bilinear')
                       for ul_ad_pred in ul_ad_preds]
        ul_activated_ad_preds = self.ad_activation_func(ul_ad_preds)

        # average the consistency loss over all auxiliary decoders
        cons_loss = sum([self.cons_criterion.forward(ul_activated_ad_pred, ul_ad_gt)
                         for ul_activated_ad_pred in ul_activated_ad_preds])
        cons_loss = torch.mean(cons_loss) / len(ul_activated_ad_preds)
        resulter['cons_loss'] = cons_loss
    else:
        resulter['ul_ad_preds'] = None
        resulter['cons_loss'] = None

    return resulter, debugger
def _load_checkpoint(self):
    """Restore student/teacher models and the student optimizer/lrer.

    Loads the checkpoint located at ``self.args.resume``, verifies that it
    was produced by this SSL algorithm, and returns the stored epoch.
    """
    ckpt = torch.load(self.args.resume)

    algo = tool.dict_value(ckpt, 'algorithm', default='unknown')
    if algo != self.NAME:
        logger.log_err('Unmatched ssl algorithm format in checkpoint => required: {0} - given: {1}\n'
                       .format(self.NAME, algo))

    # restore each component from its matching checkpoint entry
    for component, key in ((self.s_model, 's_model'),
                           (self.t_model, 't_model'),
                           (self.s_optimizer, 's_optimizer'),
                           (self.s_lrer, 's_lrer')):
        component.load_state_dict(ckpt[key])

    return ckpt['epoch']
def _load_checkpoint(self):
    """Restore the wrapped CCT model, optimizer and lrer from a checkpoint.

    Also re-binds the convenience references to the submodules of the
    reloaded wrapper. Returns the stored epoch.
    """
    ckpt = torch.load(self.args.resume)

    algo = tool.dict_value(ckpt, 'algorithm', default='unknown')
    if algo != self.NAME:
        logger.log_err('Unmatched SSL algorithm format in checkpoint => required: {0} - given: {1}\n'
                       .format(self.NAME, algo))

    self.model.load_state_dict(ckpt['model'])
    self.optimizer.load_state_dict(ckpt['optimizer'])
    self.lrer.load_state_dict(ckpt['lrer'])

    # refresh shortcuts into the (DataParallel-style) wrapped module
    wrapped = self.model.module
    self.main_model = wrapped.main_model
    self.auxiliary_decoders = wrapped.auxiliary_decoders

    return ckpt['epoch']
def _load_checkpoint(self):
    """Restore the wrapped S4L model, optimizer and lrer from a checkpoint.

    Also re-binds the convenience references to the submodules of the
    reloaded wrapper. Returns the stored epoch.
    """
    ckpt = torch.load(self.args.resume)

    algo = tool.dict_value(ckpt, 'algorithm', default='unknown')
    if algo != self.NAME:
        logger.log_err('Unmatched ssl algorithm format in checkpoint => required: {0} - given: {1}\n'
                       .format(self.NAME, algo))

    self.model.load_state_dict(ckpt['model'])
    self.optimizer.load_state_dict(ckpt['optimizer'])
    self.lrer.load_state_dict(ckpt['lrer'])

    # refresh shortcuts into the (DataParallel-style) wrapped module
    wrapped = self.model.module
    self.task_model = wrapped.task_model
    self.rotation_classifier = wrapped.rotation_classifier

    return ckpt['epoch']
def _train(self, data_loader, epoch):
    """Train the GCT models (two task models + flaw detector) for one epoch.

    Per batch:
      step-0  forward everything without gradients to obtain activated
              predictions, flaw maps, and (if enabled) the pseudo ground
              truth of the dynamic consistency constraint;
      step-1  train the 'l' and 'r' task models (flaw detector frozen);
      step-2  train the flaw detector on the labeled sub-batch.

    Arguments:
        data_loader: iterable yielding (inp, gt) tuples of tensors
        epoch: current epoch index (0-based)
    """
    self.meters.reset()
    lbs = self.args.labeled_batch_size

    self.l_model.train()
    self.r_model.train()
    self.fd_model.train()

    # both 'inp' and 'gt' are tuples
    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        (l_inp, l_gt), (r_inp, r_gt) = self._batch_prehandle(inp, gt)
        if len(l_gt) == len(r_gt) > 1 and idx == 0:
            self._inp_warn()

        # calculate the ramp-up coefficient of the dynamic consistency constraint
        cur_steps = len(data_loader) * epoch + idx
        total_steps = len(data_loader) * self.args.dc_rampup_epochs
        dc_rampup_scale = func.sigmoid_rampup(cur_steps, total_steps)

        # -----------------------------------------------------------------------------
        # step-0: pre-forwarding to save GPU memory
        #   - forward the task models and the flaw detector
        #   - generate pseudo ground truth for the unlabeled data if the dynamic
        #     consistency constraint is enabled
        # -----------------------------------------------------------------------------
        with torch.no_grad():
            l_resulter, l_debugger = self.l_model.forward(l_inp)
            l_activated_pred = tool.dict_value(l_resulter, 'activated_pred')
            r_resulter, r_debugger = self.r_model.forward(r_inp)
            r_activated_pred = tool.dict_value(r_resulter, 'activated_pred')

            # 'l_flawmap' and 'r_flawmap' will be used in step-2
            fd_resulter, fd_debugger = self.fd_model.forward(
                l_inp, l_activated_pred[0])
            l_flawmap = tool.dict_value(fd_resulter, 'flawmap')
            fd_resulter, fd_debugger = self.fd_model.forward(
                r_inp, r_activated_pred[0])
            r_flawmap = tool.dict_value(fd_resulter, 'flawmap')

        l_dc_gt, r_dc_gt = None, None
        l_fc_mask, r_fc_mask = None, None

        # generate the pseudo ground truth for the dynamic consistency constraint
        if self.args.ssl_mode in [MODE_GCT, MODE_DC]:
            with torch.no_grad():
                l_handled_flawmap = self.flawmap_handler.forward(l_flawmap)
                r_handled_flawmap = self.flawmap_handler.forward(r_flawmap)
                l_dc_gt, r_dc_gt, l_fc_mask, r_fc_mask = self.dcgt_generator.forward(
                    l_activated_pred[0].detach(),
                    r_activated_pred[0].detach(),
                    l_handled_flawmap, r_handled_flawmap)

        # -----------------------------------------------------------------------------
        # step-1: train the task models
        # -----------------------------------------------------------------------------
        # freeze the flaw detector while the task models are updated
        for param in self.fd_model.parameters():
            param.requires_grad = False

        # train the 'l' task model
        l_loss = self._task_model_iter(epoch, idx, True, 'l', lbs, l_inp,
                                       l_gt, l_dc_gt, l_fc_mask, dc_rampup_scale)
        self.l_optimizer.zero_grad()
        l_loss.backward()
        self.l_optimizer.step()

        # train the 'r' task model
        r_loss = self._task_model_iter(epoch, idx, True, 'r', lbs, r_inp,
                                       r_gt, r_dc_gt, r_fc_mask, dc_rampup_scale)
        self.r_optimizer.zero_grad()
        r_loss.backward()
        self.r_optimizer.step()

        # -----------------------------------------------------------------------------
        # step-2: train the flaw detector
        # -----------------------------------------------------------------------------
        # unfreeze the flaw detector
        for param in self.fd_model.parameters():
            param.requires_grad = True

        # generate the ground truth for the flaw detector (on labeled data only)
        with torch.no_grad():
            l_flawmap_gt = self.fdgt_generator.forward(
                l_activated_pred[0][:lbs, ...].detach(),
                self.task_func.sslgct_prepare_task_gt_for_fdgt(
                    l_gt[0][:lbs, ...]))
            r_flawmap_gt = self.fdgt_generator.forward(
                r_activated_pred[0][:lbs, ...].detach(),
                self.task_func.sslgct_prepare_task_gt_for_fdgt(
                    r_gt[0][:lbs, ...]))

        l_fd_loss = self.fd_criterion.forward(l_flawmap[:lbs, ...], l_flawmap_gt)
        l_fd_loss = self.args.fd_scale * torch.mean(l_fd_loss)
        self.meters.update('l_fd_loss', l_fd_loss.data)

        r_fd_loss = self.fd_criterion.forward(r_flawmap[:lbs, ...], r_flawmap_gt)
        r_fd_loss = self.args.fd_scale * torch.mean(r_fd_loss)
        self.meters.update('r_fd_loss', r_fd_loss.data)

        # average the two flaw-detector losses before the update
        fd_loss = (l_fd_loss + r_fd_loss) / 2
        self.fd_optimizer.zero_grad()
        fd_loss.backward()
        self.fd_optimizer.step()

        # logging
        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                ' l-{3}\t=>\t'
                'l-task-loss: {meters[l_task_loss]:.6f}\t'
                'l-dc-loss: {meters[l_dc_loss]:.6f}\t'
                'l-fc-loss: {meters[l_fc_loss]:.6f}\n'
                ' r-{3}\t=>\t'
                'r-task-loss: {meters[r_task_loss]:.6f}\t'
                'r-dc-loss: {meters[r_dc_loss]:.6f}\t'
                'r-fc-loss: {meters[r_fc_loss]:.6f}\n'
                ' fd\t=>\t'
                'l-fd-loss: {meters[l_fd_loss]:.6f}\t'
                'r-fd-loss: {meters[r_fd_loss]:.6f}\n'.format(
                    epoch, idx, len(data_loader), self.args.task,
                    meters=self.meters))

        # the flaw detector uses polynomiallr [ITER_LRERS]
        self.fd_lrer.step()
        # update iteration-based lrers
        if not self.args.is_epoch_lrer:
            self.l_lrer.step()
            self.r_lrer.step()

    # update epoch-based lrers
    if self.args.is_epoch_lrer:
        self.l_lrer.step()
        self.r_lrer.step()
def _train(self, data_loader, epoch):
    """Train the supervised-only (SSL_NULL) model for one epoch.

    SSL_NULL uses labeled data exclusively: the loop forwards each batch,
    computes the task loss on the labeled sub-batch only, and updates the
    single task model.

    Arguments:
        data_loader: iterable yielding (inp, gt) tuples of tensors
        epoch: current epoch index (0-based)
    """
    # disable unlabeled data
    # this algorithm is supervised-only; refuse to run otherwise
    without_unlabeled_data = self.args.ignore_unlabeled and self.args.unlabeled_batch_size == 0
    if not without_unlabeled_data:
        logger.log_err(
            'SSL_NULL is a supervised-only algorithm\n'
            'Please set ignore_unlabeled = True and unlabeled_batch_size = 0\n')

    self.meters.reset()
    lbs = self.args.labeled_batch_size

    self.model.train()

    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        # both 'inp' and 'gt' are tuples
        inp, gt = self._batch_prehandle(inp, gt)
        if len(gt) > 1 and idx == 0:
            self._inp_warn()

        self.optimizer.zero_grad()

        # forward the task model
        resulter, debugger = self.model.forward(inp)
        if not 'pred' in resulter.keys() or not 'activated_pred' in resulter.keys():
            self._pred_err()

        pred = tool.dict_value(resulter, 'pred')
        activated_pred = tool.dict_value(resulter, 'activated_pred')

        # calculate the supervised task constraint on the labeled data
        l_pred = func.split_tensor_tuple(pred, 0, lbs)
        l_gt = func.split_tensor_tuple(gt, 0, lbs)
        l_inp = func.split_tensor_tuple(inp, 0, lbs)

        # 'task_loss' is a tensor of 1-dim & n elements, where n == batch_size
        task_loss = self.criterion.forward(l_pred, l_gt, l_inp)
        task_loss = torch.mean(task_loss)
        self.meters.update('task_loss', task_loss.data)

        # backward and update the task model
        loss = task_loss
        loss.backward()
        self.optimizer.step()

        # logging
        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                ' task-{3}\t=>\t'
                'task-loss: {meters[task_loss]:.6f}\t'.format(
                    epoch, idx, len(data_loader), self.args.task, meters=self.meters))

        # visualization
        if self.args.visualize and idx % self.args.visual_freq == 0:
            self._visualize(
                epoch, idx, True,
                func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

        # update iteration-based lrers
        if not self.args.is_epoch_lrer:
            self.lrer.step()

    # update epoch-based lrers
    if self.args.is_epoch_lrer:
        self.lrer.step()
def _train(self, data_loader, epoch):
    """Train the wrapped CCT model for one epoch.

    Per batch: forwards the labeled sub-batch to obtain the supervised
    task loss, then (if unlabeled data is enabled) forwards the unlabeled
    sub-batch to obtain the ramped-up consistency loss from the auxiliary
    decoders; the sum is back-propagated through the single optimizer.

    Arguments:
        data_loader: iterable yielding (inp, gt) tuples of tensors
        epoch: current epoch index (0-based)
    """
    self.meters.reset()
    lbs = self.args.labeled_batch_size

    self.model.train()

    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        inp, gt = self._batch_prehandle(inp, gt)
        if len(gt) > 1 and idx == 0:
            self._data_err()

        # TODO: support more ramp-up functions
        # calculate the ramp-up coefficient of the consistency constraint
        cur_step = len(data_loader) * epoch + idx
        total_steps = len(data_loader) * self.args.cons_rampup_epochs
        cons_rampup_scale = func.sigmoid_rampup(cur_step, total_steps)

        self.optimizer.zero_grad()

        # -----------------------------------------------------------
        # For Labeled Data
        # -----------------------------------------------------------
        l_gt = func.split_tensor_tuple(gt, 0, lbs)
        l_inp = func.split_tensor_tuple(inp, 0, lbs)

        # forward the wrapped CCT model
        resulter, debugger = self.model.forward(l_inp, l_gt, False)
        l_pred = tool.dict_value(resulter, 'pred')
        l_activated_pred = tool.dict_value(resulter, 'activated_pred')

        task_loss = tool.dict_value(resulter, 'task_loss', err=True)
        task_loss = task_loss.mean()
        self.meters.update('task_loss', task_loss.data)

        # -----------------------------------------------------------
        # For Unlabeled Data
        # -----------------------------------------------------------
        if self.args.unlabeled_batch_size > 0:
            ul_gt = func.split_tensor_tuple(gt, lbs, self.args.batch_size)
            ul_inp = func.split_tensor_tuple(inp, lbs, self.args.batch_size)

            # forward the wrapped CCT model
            resulter, debugger = self.model.forward(ul_inp, ul_gt, True)
            ul_pred = tool.dict_value(resulter, 'pred')
            ul_activated_pred = tool.dict_value(resulter, 'activated_pred')
            ul_ad_preds = tool.dict_value(resulter, 'ul_ad_preds')

            cons_loss = tool.dict_value(resulter, 'cons_loss', err=True)
            cons_loss = cons_loss.mean()
            cons_loss = cons_rampup_scale * self.args.cons_scale * cons_loss
            self.meters.update('cons_loss', cons_loss.data)
        else:
            cons_loss = 0
            self.meters.update('cons_loss', cons_loss)

        # backward and update the wrapped CCT model
        loss = task_loss + cons_loss
        loss.backward()
        self.optimizer.step()

        # logging
        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info('step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                            ' task-{3}\t=>\t'
                            'task-loss: {meters[task_loss]:.6f}\t'
                            'cons-loss: {meters[cons_loss]:.6f}\n'
                            .format(epoch, idx, len(data_loader), self.args.task, meters=self.meters))

        # visualization
        # BUGFIX: the original code referenced 'activated_pred' here, a name
        # that is never defined in this scope (only 'l_activated_pred' and
        # 'ul_activated_pred' exist), so enabling visualization raised a
        # NameError. Sample 0 belongs to the labeled sub-batch, hence the
        # labeled activated predictions are the correct tensor to show.
        if self.args.visualize and idx % self.args.visual_freq == 0:
            self._visualize(epoch, idx, True,
                            func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                            func.split_tensor_tuple(l_activated_pred, 0, 1, reduce_dim=True),
                            func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

        # update iteration-based lrers
        if not self.args.is_epoch_lrer:
            self.lrer.step()

    # update epoch-based lrers
    if self.args.is_epoch_lrer:
        self.lrer.step()
def _validate(self, data_loader, epoch):
    """Run one validation epoch for the mean-teacher student/teacher pair.

    Records the task loss of both models, the consistency loss between
    the student predictions and the detached teacher predictions, and the
    task metrics of both models; finally logs all collected metrics.

    Arguments:
        data_loader: iterable yielding (inp, gt) tuples of tensors
        epoch: current epoch index (0-based; logged as epoch + 1)
    """
    self.meters.reset()

    self.s_model.eval()
    self.t_model.eval()

    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        inp, gt, _, _ = self._batch_prehandle(inp, gt, False)
        if len(inp) > 1 and idx == 0:
            self._inp_warn()
        if len(gt) > 1 and idx == 0:
            self._gt_warn()

        # forward the student model and record its task loss
        s_resulter, s_debugger = self.s_model.forward(inp)
        if not 'pred' in s_resulter.keys() or not 'activated_pred' in s_resulter.keys():
            self._pred_err()
        s_pred = tool.dict_value(s_resulter, 'pred')
        s_activated_pred = tool.dict_value(s_resulter, 'activated_pred')

        s_task_loss = self.s_criterion.forward(s_pred, gt, inp)
        s_task_loss = torch.mean(s_task_loss)
        self.meters.update('s_task_loss', s_task_loss.data)

        # forward the teacher model and record its task loss
        t_resulter, t_debugger = self.t_model.forward(inp)
        if not 'pred' in t_resulter.keys() or not 'activated_pred' in t_resulter.keys():
            self._pred_err()
        t_pred = tool.dict_value(t_resulter, 'pred')
        t_activated_pred = tool.dict_value(t_resulter, 'activated_pred')

        # NOTE(review): the teacher loss reuses the student criterion
        # 's_criterion' -- confirm this is intended (the teacher may not
        # own a criterion of its own)
        t_task_loss = self.s_criterion.forward(t_pred, gt, inp)
        t_task_loss = torch.mean(t_task_loss)
        self.meters.update('t_task_loss', t_task_loss.data)

        # detach the teacher predictions to serve as the pseudo ground
        # truth of the consistency constraint
        t_pseudo_gt = []
        for tap in t_activated_pred:
            t_pseudo_gt.append(tap.detach())
        t_pseudo_gt = tuple(t_pseudo_gt)

        # consistency between student predictions and teacher pseudo gt
        cons_loss = 0
        for sap, tpg in zip(s_activated_pred, t_pseudo_gt):
            cons_loss += torch.mean(self.cons_criterion(sap, tpg))
        cons_loss = self.args.cons_scale * torch.mean(cons_loss)
        self.meters.update('cons_loss', cons_loss.data)

        # task metrics of both models, under separate meter prefixes
        self.task_func.metrics(s_activated_pred, gt, inp, self.meters, id_str='student')
        self.task_func.metrics(t_activated_pred, gt, inp, self.meters, id_str='teacher')

        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                ' student-{3}\t=>\t'
                's-task-loss: {meters[s_task_loss]:.6f}\t'
                's-cons-loss: {meters[cons_loss]:.6f}\n'
                ' teacher-{3}\t=>\t'
                't-task-loss: {meters[t_task_loss]:.6f}\n'.format(
                    epoch + 1, idx, len(data_loader), self.args.task,
                    meters=self.meters))

        # visualization
        if self.args.visualize and idx % self.args.visual_freq == 0:
            self._visualize(
                epoch, idx, False,
                func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(s_activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

    # metrics
    # collect and log the metric meters of both models
    metrics_info = {'student': '', 'teacher': ''}
    for key in sorted(list(self.meters.keys())):
        if self.task_func.METRIC_STR in key:
            for id_str in metrics_info.keys():
                if key.startswith(id_str):
                    metrics_info[id_str] += '{0}: {1:.6}\t'.format(
                        key, self.meters[key])

    logger.log_info(
        'Validation metrics:\n student-metrics\t=>\t{0}\n teacher-metrics\t=>\t{1}\n'
        .format(metrics_info['student'].replace('_', '-'),
                metrics_info['teacher'].replace('_', '-')))
def _validate(self, data_loader, epoch):
    """Run one validation epoch for the adversarial (task + FC-discriminator) pair.

    Records the task loss, the discriminator losses on both the task
    predictions (fake) and the converted ground truth (real), and the task
    metrics; finally logs all collected metrics.

    Arguments:
        data_loader: iterable yielding (inp, gt) tuples of tensors
        epoch: current epoch index (0-based; logged as epoch + 1)
    """
    self.meters.reset()

    self.model.eval()
    self.d_model.eval()

    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        inp, gt = self._batch_prehandle(inp, gt)
        if len(gt) > 1 and idx == 0:
            self._inp_warn()

        # forward the task model and record its loss
        resulter, debugger = self.model.forward(inp)
        if not 'pred' in resulter.keys() or not 'activated_pred' in resulter.keys():
            self._pred_err()
        pred = tool.dict_value(resulter, 'pred')
        activated_pred = tool.dict_value(resulter, 'activated_pred')

        task_loss = self.criterion.forward(pred, gt, inp)
        task_loss = torch.mean(task_loss)
        self.meters.update('task_loss', task_loss.data)

        # discriminator on the task prediction ('fake' branch)
        d_resulter, d_debugger = self.d_model.forward(activated_pred[0])
        unhandled_fake_confidence_map = tool.dict_value(
            d_resulter, 'confidence')
        fake_confidence_map, fake_confidence_gt = \
            self.task_func.ssladv_preprocess_fcd_criterion(unhandled_fake_confidence_map, gt[0], False)

        fake_d_loss = self.d_criterion.forward(fake_confidence_map, fake_confidence_gt)
        fake_d_loss = self.args.discriminator_scale * torch.mean(
            fake_d_loss)
        self.meters.update('fake_d_loss', fake_d_loss.data)

        # discriminator on the converted ground truth ('real' branch)
        real_gt = self.task_func.ssladv_convert_task_gt_to_fcd_input(gt[0])
        d_resulter, d_debugger = self.d_model.forward(real_gt)
        unhandled_real_confidence_map = tool.dict_value(
            d_resulter, 'confidence')
        real_confidence_map, real_confidence_gt = \
            self.task_func.ssladv_preprocess_fcd_criterion(unhandled_real_confidence_map, gt[0], True)

        real_d_loss = self.d_criterion.forward(real_confidence_map, real_confidence_gt)
        real_d_loss = self.args.discriminator_scale * torch.mean(
            real_d_loss)
        self.meters.update('real_d_loss', real_d_loss.data)

        # task metrics
        self.task_func.metrics(activated_pred, gt, inp, self.meters, id_str='task')

        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                ' task-{3}\t=>\t'
                'task-loss: {meters[task_loss]:.6f}\t'
                ' fc-discriminator\t=>\t'
                'fake-d-loss: {meters[fake_d_loss]:.6f}\t'
                'real-d-loss: {meters[real_d_loss]:.6f}\n'.format(
                    epoch + 1, idx, len(data_loader), self.args.task,
                    meters=self.meters))

        # visualization (also shows the sigmoid-ed fake confidence map)
        if self.args.visualize and idx % self.args.visual_freq == 0:
            self._visualize(
                epoch, idx, False,
                func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(gt, 0, 1, reduce_dim=True),
                torch.sigmoid(unhandled_fake_confidence_map[0]))

    # metrics
    # collect every meter whose key marks a task metric and log them once
    metrics_info = {'task': ''}
    for key in sorted(list(self.meters.keys())):
        if self.task_func.METRIC_STR in key:
            for id_str in metrics_info.keys():
                if key.startswith(id_str):
                    metrics_info[id_str] += '{0}: {1:.6}\t'.format(
                        key, self.meters[key])

    logger.log_info('Validation metrics:\n task-metrics\t=>\t{0}\n'.format(
        metrics_info['task'].replace('_', '-')))
def _train(self, data_loader, epoch):
    """Train the adversarial pair (task model + FC discriminator) for one epoch.

    Per batch:
      step-1  update the task model with the supervised task loss plus the
              adversarial losses on labeled/unlabeled data;
      step-2  update the FC discriminator against detached 'fake' task
              predictions and 'real' converted ground truth.

    Arguments:
        data_loader: iterable yielding (inp, gt) tuples of tensors
        epoch: current epoch index (0-based; logged as epoch + 1)
    """
    self.meters.reset()
    lbs = self.args.labeled_batch_size

    self.model.train()
    self.d_model.train()

    # both 'inp' and 'gt' are tuples
    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        inp, gt = self._batch_prehandle(inp, gt)
        if len(gt) > 1 and idx == 0:
            self._inp_warn()

        # -----------------------------------------------------------------------------
        # step-1: train the task model
        # -----------------------------------------------------------------------------
        self.optimizer.zero_grad()

        # forward the task model
        resulter, debugger = self.model.forward(inp)
        if not 'pred' in resulter.keys() or not 'activated_pred' in resulter.keys():
            self._pred_err()

        pred = tool.dict_value(resulter, 'pred')
        activated_pred = tool.dict_value(resulter, 'activated_pred')

        # forward the FC discriminator
        # 'confidence_map' is a tensor
        d_resulter, d_debugger = self.d_model.forward(activated_pred[0])
        confidence_map = tool.dict_value(d_resulter, 'confidence')

        # calculate the supervised task constraint on the labeled data
        l_pred = func.split_tensor_tuple(pred, 0, lbs)
        l_gt = func.split_tensor_tuple(gt, 0, lbs)
        l_inp = func.split_tensor_tuple(inp, 0, lbs)

        # 'task_loss' is a tensor of 1-dim & n elements, where n == batch_size
        task_loss = self.criterion.forward(l_pred, l_gt, l_inp)
        task_loss = torch.mean(task_loss)
        self.meters.update('task_loss', task_loss.data)

        # calculate the adversarial constraint
        # calculate the adversarial constraint for the labeled data
        if self.args.adv_for_labeled:
            l_confidence_map = confidence_map[:lbs, ...]

            # preprocess prediction and ground truch for the adversarial constraint
            l_adv_confidence_map, l_adv_confidence_gt = \
                self.task_func.ssladv_preprocess_fcd_criterion(l_confidence_map, l_gt[0], True)

            l_adv_loss = self.d_criterion(l_adv_confidence_map, l_adv_confidence_gt)
            labeled_adv_loss = self.args.labeled_adv_scale * torch.mean(l_adv_loss)
            self.meters.update('labeled_adv_loss', labeled_adv_loss.data)
        else:
            labeled_adv_loss = 0
            self.meters.update('labeled_adv_loss', labeled_adv_loss)

        # calculate the adversarial constraint for the unlabeled data
        if self.args.unlabeled_batch_size > 0:
            u_confidence_map = confidence_map[lbs:self.args.batch_size, ...]

            # preprocess prediction and ground truch for the adversarial constraint
            u_adv_confidence_map, u_adv_confidence_gt = \
                self.task_func.ssladv_preprocess_fcd_criterion(u_confidence_map, None, True)

            u_adv_loss = self.d_criterion(u_adv_confidence_map, u_adv_confidence_gt)
            unlabeled_adv_loss = self.args.unlabeled_adv_scale * torch.mean(u_adv_loss)
            self.meters.update('unlabeled_adv_loss', unlabeled_adv_loss.data)
        else:
            unlabeled_adv_loss = 0
            self.meters.update('unlabeled_adv_loss', unlabeled_adv_loss)

        adv_loss = labeled_adv_loss + unlabeled_adv_loss

        # backward and update the task model
        loss = task_loss + adv_loss
        loss.backward()
        self.optimizer.step()

        # -----------------------------------------------------------------------------
        # step-2: train the FC discriminator
        # -----------------------------------------------------------------------------
        self.d_optimizer.zero_grad()

        # forward the task prediction (fake)
        # detach so no gradient flows back into the task model
        if self.args.unlabeled_for_discriminator:
            fake_pred = activated_pred[0].detach()
        else:
            fake_pred = activated_pred[0][:lbs, ...].detach()
        d_resulter, d_debugger = self.d_model.forward(fake_pred)
        fake_confidence_map = tool.dict_value(d_resulter, 'confidence')

        l_fake_confidence_map = fake_confidence_map[:lbs, ...]
        l_fake_confidence_map, l_fake_confidence_gt = \
            self.task_func.ssladv_preprocess_fcd_criterion(l_fake_confidence_map, l_gt[0], False)

        # optionally include the unlabeled part of the fake batch
        if self.args.unlabeled_for_discriminator and self.args.unlabeled_batch_size != 0:
            u_fake_confidence_map = fake_confidence_map[lbs:self.args.batch_size, ...]
            u_fake_confidence_map, u_fake_confidence_gt = \
                self.task_func.ssladv_preprocess_fcd_criterion(u_fake_confidence_map, None, False)

            fake_confidence_map = torch.cat(
                (l_fake_confidence_map, u_fake_confidence_map), dim=0)
            fake_confidence_gt = torch.cat(
                (l_fake_confidence_gt, u_fake_confidence_gt), dim=0)
        else:
            fake_confidence_map, fake_confidence_gt = l_fake_confidence_map, l_fake_confidence_gt

        fake_d_loss = self.d_criterion.forward(fake_confidence_map, fake_confidence_gt)
        fake_d_loss = self.args.discriminator_scale * torch.mean(fake_d_loss)
        self.meters.update('fake_d_loss', fake_d_loss.data)

        # forward the ground truth (real)
        # convert the format of ground truch
        real_gt = self.task_func.ssladv_convert_task_gt_to_fcd_input(l_gt[0])
        d_resulter, d_debugger = self.d_model.forward(real_gt)
        real_confidence_map = tool.dict_value(d_resulter, 'confidence')

        real_confidence_map, real_confidence_gt = \
            self.task_func.ssladv_preprocess_fcd_criterion(real_confidence_map, l_gt[0], True)

        real_d_loss = self.d_criterion(real_confidence_map, real_confidence_gt)
        real_d_loss = self.args.discriminator_scale * torch.mean(real_d_loss)
        self.meters.update('real_d_loss', real_d_loss.data)

        # backward and update the FC discriminator
        d_loss = (fake_d_loss + real_d_loss) / 2
        d_loss.backward()
        self.d_optimizer.step()

        # logging
        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                ' task-{3}\t=>\t'
                'task-loss: {meters[task_loss]:.6f}\t'
                'labeled-adv-loss: {meters[labeled_adv_loss]:.6f}\t'
                'unlabeled-adv-loss: {meters[unlabeled_adv_loss]:.6f}\n'
                ' fc-discriminator\t=>\t'
                'fake-d-loss: {meters[fake_d_loss]:.6f}\t'
                'real-d-loss: {meters[real_d_loss]:.6f}\n'.format(
                    epoch + 1, idx, len(data_loader), self.args.task,
                    meters=self.meters))

        # visualization
        if self.args.visualize and idx % self.args.visual_freq == 0:
            # optionally include an unlabeled sample in the visualization
            u_inp_sample, u_pred_sample, u_cmap_sample = None, None, None
            if self.args.unlabeled_batch_size > 0:
                u_inp_sample = func.split_tensor_tuple(inp, lbs, lbs + 1, reduce_dim=True)
                u_pred_sample = func.split_tensor_tuple(activated_pred, lbs, lbs + 1, reduce_dim=True)
                u_cmap_sample = torch.sigmoid(fake_confidence_map[lbs])

            self._visualize(
                epoch, idx, True,
                func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(gt, 0, 1, reduce_dim=True),
                torch.sigmoid(confidence_map[0]),
                u_inp_sample, u_pred_sample, u_cmap_sample)

        # the FC discriminator uses polynomiallr [ITER_LRERS]
        self.d_lrer.step()
        # update iteration-based lrers
        if not self.args.is_epoch_lrer:
            self.lrer.step()

    # update epoch-based lrers
    if self.args.is_epoch_lrer:
        self.lrer.step()
def _train(self, data_loader, epoch):
    """Train the wrapped S4L model for one epoch.

    Each prehandled batch contains an un-rotated half followed by a rotated
    half; the supervised task loss is computed on both labeled sub-batches,
    and the self-supervised rotation loss on the whole batch. The rotation
    classification accuracy is also tracked.

    Arguments:
        data_loader: iterable yielding (inp, gt) tuples of tensors
        epoch: current epoch index (0-based)
    """
    self.meters.reset()
    # the prehandled batch doubles the samples (original + rotated),
    # so the per-half sizes are half of the configured batch sizes
    original_lbs = int(self.args.labeled_batch_size / 2)
    original_bs = int(self.args.batch_size / 2)

    self.model.train()

    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        # the rotated samples are generated in the 'self._batch_prehandle' function
        # both 'inp' and 'gt' are tuples
        # the last element in the tuple 'gt' is the ground truth of the rotation angle
        inp, gt = self._batch_prehandle(inp, gt, True)
        if len(gt) - 1 > 1 and idx == 0:
            self._inp_warn()

        self.optimizer.zero_grad()

        # forward the model
        resulter, debugger = self.model.forward(inp)
        pred = tool.dict_value(resulter, 'pred')
        activated_pred = tool.dict_value(resulter, 'activated_pred')
        pred_rotation = tool.dict_value(resulter, 'rotation')

        # calculate the supervised task constraint on the un-rotated labeled data
        l_pred = func.split_tensor_tuple(pred, 0, original_lbs)
        l_gt = func.split_tensor_tuple(gt, 0, original_lbs)
        l_inp = func.split_tensor_tuple(inp, 0, original_lbs)

        # gt[:-1] drops the rotation label; the criterion only sees task gt
        unrotated_task_loss = self.criterion.forward(l_pred, l_gt[:-1], l_inp)
        unrotated_task_loss = torch.mean(unrotated_task_loss)
        self.meters.update('unrotated_task_loss', unrotated_task_loss.data)

        # calculate the supervised task constraint on the rotated labeled data
        l_rotated_pred = func.split_tensor_tuple(pred, original_bs, original_bs + original_lbs)
        l_rotated_gt = func.split_tensor_tuple(gt, original_bs, original_bs + original_lbs)
        l_rotated_inp = func.split_tensor_tuple(inp, original_bs, original_bs + original_lbs)

        rotated_task_loss = self.criterion.forward(l_rotated_pred, l_rotated_gt[:-1], l_rotated_inp)
        rotated_task_loss = self.args.rotated_sup_scale * torch.mean(rotated_task_loss)
        self.meters.update('rotated_task_loss', rotated_task_loss.data)

        task_loss = unrotated_task_loss + rotated_task_loss

        # calculate the self-supervised rotation constraint
        # gt[-1] is the rotation-angle label
        rotation_loss = self.rotation_criterion.forward(pred_rotation, gt[-1])
        rotation_loss = self.args.rotation_scale * torch.mean(rotation_loss)
        self.meters.update('rotation_loss', rotation_loss.data)

        # backward and update the model
        loss = task_loss + rotation_loss
        loss.backward()
        self.optimizer.step()

        # calculate the accuracy of the rotation classifier
        # (top-1 accuracy over the whole batch, in percent)
        _, angle_idx = pred_rotation.topk(1, 1, True, True)
        angle_idx = angle_idx.t()
        rotation_acc = angle_idx.eq(gt[-1].view(1, -1).expand_as(angle_idx))
        rotation_acc = rotation_acc.view(-1).float().sum(0, keepdim=True).mul_(100.0 / self.args.batch_size)
        self.meters.update('rotation_acc', rotation_acc.data[0])

        # logging
        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                ' task-{3}\t=>\t'
                'unrotated-task-loss: {meters[unrotated_task_loss]:.6f}\t'
                'rotated-task-loss: {meters[rotated_task_loss]:.6f}\n'
                ' rotation-{3}\t=>\t'
                'rotation-loss: {meters[rotation_loss]:.6f}\t'
                'rotation-acc: {meters[rotation_acc]:.6f}\n'.format(
                    epoch, idx, len(data_loader), self.args.task,
                    meters=self.meters))

        # visualization
        if self.args.visualize and idx % self.args.visual_freq == 0:
            self._visualize(
                epoch, idx, True,
                func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(gt[:-1], 0, 1, reduce_dim=True))

        # update iteration-based lrers
        if not self.args.is_epoch_lrer:
            self.lrer.step()

    # update epoch-based lrers
    if self.args.is_epoch_lrer:
        self.lrer.step()
def _train(self, data_loader, epoch):
    """Train the student/teacher (mean-teacher) model pair for one epoch.

    Per batch: the student is optimized with a supervised task loss on the
    labeled samples plus a ramped-up consistency loss toward the teacher's
    predictions; the teacher is then updated as an EMA of the student.

    Arguments:
        data_loader: iterable yielding (inp, gt) batches
        epoch: 0-based index of the current epoch (drives the rampup schedule)
    """
    self.meters.reset()
    lbs = self.args.labeled_batch_size

    self.s_model.train()
    self.t_model.train()

    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        # 's_inp', 't_inp' and 'gt' are tuples
        s_inp, t_inp, gt = self._batch_prehandle(inp, gt, True)
        if len(gt) > 1 and idx == 0:
            self._inp_warn()

        # calculate the ramp-up coefficient of the consistency constraint
        cur_step = len(data_loader) * epoch + idx
        total_steps = len(data_loader) * self.args.cons_rampup_epochs
        cons_rampup_scale = func.sigmoid_rampup(cur_step, total_steps)

        self.s_optimizer.zero_grad()

        # forward the student model
        s_resulter, s_debugger = self.s_model.forward(s_inp)
        if not 'pred' in s_resulter.keys() or not 'activated_pred' in s_resulter.keys():
            self._pred_err()
        s_pred = tool.dict_value(s_resulter, 'pred')
        s_activated_pred = tool.dict_value(s_resulter, 'activated_pred')

        # calculate the supervised task constraint on the labeled data
        l_s_pred = func.split_tensor_tuple(s_pred, 0, lbs)
        l_gt = func.split_tensor_tuple(gt, 0, lbs)
        l_s_inp = func.split_tensor_tuple(s_inp, 0, lbs)

        # 'task_loss' is a tensor of 1-dim & n elements, where n == batch_size
        s_task_loss = self.s_criterion.forward(l_s_pred, l_gt, l_s_inp)
        s_task_loss = torch.mean(s_task_loss)
        self.meters.update('s_task_loss', s_task_loss.data)

        # forward the teacher model (no gradients flow through the teacher)
        with torch.no_grad():
            t_resulter, t_debugger = self.t_model.forward(t_inp)
            if not 'pred' in t_resulter.keys():
                self._pred_err()
            t_pred = tool.dict_value(t_resulter, 'pred')
            t_activated_pred = tool.dict_value(t_resulter, 'activated_pred')

            # calculate 't_task_loss' for recording
            l_t_pred = func.split_tensor_tuple(t_pred, 0, lbs)
            l_t_inp = func.split_tensor_tuple(t_inp, 0, lbs)
            t_task_loss = self.s_criterion.forward(l_t_pred, l_gt, l_t_inp)
            t_task_loss = torch.mean(t_task_loss)
            self.meters.update('t_task_loss', t_task_loss.data)

        # calculate the consistency constraint from the teacher model to the student model
        # FIX: 'torch.autograd.Variable' is deprecated since PyTorch 0.4 (the
        # file already uses post-0.4 'torch.no_grad()'); 'detach()' alone
        # returns an equivalent tensor with requires_grad=False
        t_pseudo_gt = t_pred[0].detach()

        if self.args.cons_for_labeled:
            cons_loss = self.cons_criterion(s_pred[0], t_pseudo_gt)
        elif self.args.unlabeled_batch_size > 0:
            # apply the consistency constraint to the unlabeled tail only
            cons_loss = self.cons_criterion(s_pred[0][lbs:, ...], t_pseudo_gt[lbs:, ...])
        else:
            cons_loss = self.zero_tensor
        cons_loss = cons_rampup_scale * self.args.cons_scale * torch.mean(cons_loss)
        self.meters.update('cons_loss', cons_loss.data)

        # backward and update the student model
        loss = s_task_loss + cons_loss
        loss.backward()
        self.s_optimizer.step()

        # update the teacher model by EMA
        self._update_ema_variables(self.s_model, self.t_model, self.args.ema_decay, cur_step)

        # logging
        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                ' student-{3}\t=>\t'
                's-task-loss: {meters[s_task_loss]:.6f}\t'
                's-cons-loss: {meters[cons_loss]:.6f}\n'
                ' teacher-{3}\t=>\t'
                't-task-loss: {meters[t_task_loss]:.6f}\n'.format(
                    epoch + 1, idx, len(data_loader), self.args.task,
                    meters=self.meters))

        # visualization
        if self.args.visualize and idx % self.args.visual_freq == 0:
            self._visualize(
                epoch, idx, True,
                func.split_tensor_tuple(s_inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(s_activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(t_inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(t_activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

        # update iteration-based lrers
        if not self.args.is_epoch_lrer:
            self.s_lrer.step()

    # update epoch-based lrers
    if self.args.is_epoch_lrer:
        self.s_lrer.step()
def _validate(self, data_loader, epoch):
    """Validate the two GCT task models ('l'/'r') and the flaw detector for one epoch.

    For each batch, optionally generates the dynamic-consistency pseudo gt /
    flaw-correction masks (only when ssl_mode enables them), then delegates
    the per-model loss/metric bookkeeping to '_task_model_iter'.
    """
    self.meters.reset()
    lbs = self.args.labeled_batch_size

    self.l_model.eval()
    self.r_model.eval()
    self.fd_model.eval()

    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        # '_batch_prehandle' yields one (input, gt) pair per task model
        (l_inp, l_gt), (r_inp, r_gt) = self._batch_prehandle(inp, gt)
        if len(l_gt) == len(r_gt) > 1 and idx == 0:
            self._inp_warn()

        # pseudo gt / masks stay None unless the dynamic consistency
        # constraint is active for the current ssl_mode
        l_dc_gt, r_dc_gt = None, None
        l_fc_mask, r_fc_mask = None, None

        if self.args.ssl_mode in [MODE_GCT, MODE_DC]:
            # forward both task models to get their activated predictions
            l_resulter, l_debugger = self.l_model.forward(l_inp)
            l_activated_pred = tool.dict_value(l_resulter, 'activated_pred')
            r_resulter, r_debugger = self.r_model.forward(r_inp)
            r_activated_pred = tool.dict_value(r_resulter, 'activated_pred')

            # score each model's predictions with the flaw detector
            fd_resulter, fd_debugger = self.fd_model.forward(l_inp, l_activated_pred[0])
            l_flawmap = tool.dict_value(fd_resulter, 'flawmap')
            fd_resulter, fd_debugger = self.fd_model.forward(r_inp, r_activated_pred[0])
            r_flawmap = tool.dict_value(fd_resulter, 'flawmap')

            # post-process the flawmaps and derive the dynamic-consistency
            # pseudo ground truths and flaw-correction masks
            l_handled_flawmap = self.flawmap_handler.forward(l_flawmap)
            r_handled_flawmap = self.flawmap_handler.forward(r_flawmap)
            l_dc_gt, r_dc_gt, l_fc_mask, r_fc_mask = self.dcgt_generator.forward(
                l_activated_pred[0].detach(), r_activated_pred[0].detach(),
                l_handled_flawmap, r_handled_flawmap)

        # one evaluation iteration per task model; the dc rampup scale is
        # fixed to 1 during validation
        l_loss = self._task_model_iter(epoch, idx, False, 'l', lbs, l_inp,
                                       l_gt, l_dc_gt, l_fc_mask, 1)
        r_loss = self._task_model_iter(epoch, idx, False, 'r', lbs, r_inp,
                                       r_gt, r_dc_gt, r_fc_mask, 1)

        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            # NOTE(review): this logs the 0-based 'epoch' while other
            # train/validate loops in this file log 'epoch + 1' -- confirm
            # which convention is intended
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                ' l-{3}\t=>\t'
                'l-task-loss: {meters[l_task_loss]:.6f}\t'
                'l-dc-loss: {meters[l_dc_loss]:.6f}\t'
                'l-fc-loss: {meters[l_fc_loss]:.6f}\n'
                ' r-{3}\t=>\t'
                'r-task-loss: {meters[r_task_loss]:.6f}\t'
                'r-dc-loss: {meters[r_dc_loss]:.6f}\t'
                'r-fc-loss: {meters[r_fc_loss]:.6f}\n'
                ' fd\t=>\t'
                'l-fd-loss: {meters[l_fd_loss]:.6f}\t'
                'r-fd-loss: {meters[r_fd_loss]:.6f}\n'.format(
                    epoch, idx, len(data_loader),
                    self.args.task, meters=self.meters))

    # metrics: gather every task-metric meter per model prefix ('l'/'r')
    metrics_info = {'l': '', 'r': ''}
    for key in sorted(list(self.meters.keys())):
        if self.task_func.METRIC_STR in key:
            for id_str in metrics_info.keys():
                if key.startswith(id_str):
                    metrics_info[id_str] += '{0}: {1:.6f}\t'.format(
                        key, self.meters[key])

    logger.log_info(
        'Validation metrics:\n l-metrics\t=>\t{0}\n r-metrics\t=>\t{1}\n'
        .format(metrics_info['l'].replace('_', '-'),
                metrics_info['r'].replace('_', '-')))
def _task_model_iter(self, epoch, idx, is_train, mid, lbs, inp, gt, dc_gt,
                     fc_mask, dc_rampup_scale):
    """Run one forward/loss iteration for a single GCT task model.

    Arguments:
        epoch, idx: current epoch / batch indices (for visualization)
        is_train: True during training; False adds fd-loss recording and metrics
        mid: 'l' or 'r' -- selects the task model and its criterion
        lbs: labeled batch size
        inp, gt: input / ground-truth tuples
        dc_gt: pseudo gt for the dynamic consistency constraint (may be None)
        fc_mask: mask for the flaw correction constraint (may be None)
        dc_rampup_scale: ramp-up coefficient for the dc constraint

    Returns:
        combined loss = task_loss + fc_ssl_loss + dc_ssl_loss
    """
    if mid == 'l':
        model, criterion = self.l_model, self.l_criterion
    elif mid == 'r':
        model, criterion = self.r_model, self.r_criterion
    else:
        model, criterion = None, None

    # forward the task model
    resulter, debugger = model.forward(inp)
    if not 'pred' in resulter.keys() or not 'activated_pred' in resulter.keys():
        self._pred_err()

    pred = tool.dict_value(resulter, 'pred')
    activated_pred = tool.dict_value(resulter, 'activated_pred')

    # score the activated predictions with the flaw detector
    fd_resulter, fd_debugger = self.fd_model.forward(inp, activated_pred[0])
    flawmap = tool.dict_value(fd_resulter, 'flawmap')

    # calculate the supervised task constraint on the labeled data
    labeled_pred = func.split_tensor_tuple(pred, 0, lbs)
    labeled_gt = func.split_tensor_tuple(gt, 0, lbs)
    labeled_inp = func.split_tensor_tuple(inp, 0, lbs)
    task_loss = torch.mean(
        criterion.forward(labeled_pred, labeled_gt, labeled_inp))
    self.meters.update('{0}_task_loss'.format(mid), task_loss.data)

    # calculate the flaw correction constraint (drives the flawmap toward zero)
    if self.args.ssl_mode in [MODE_GCT, MODE_FC]:
        # reuse the preallocated zero target when shapes match so a fresh
        # zero tensor is not allocated every iteration
        if flawmap.shape == self.zero_df_gt.shape:
            fc_ssl_loss = self.fd_criterion.forward(flawmap, self.zero_df_gt,
                                                    is_ssl=True, reduction=False)
        else:
            fc_ssl_loss = self.fd_criterion.forward(
                flawmap, torch.zeros(flawmap.shape).cuda(), is_ssl=True,
                reduction=False)

        if self.args.ssl_mode == MODE_GCT:
            # in full GCT mode, the fc loss is gated by the generated mask
            fc_ssl_loss = fc_mask * fc_ssl_loss

        fc_ssl_loss = self.args.fc_ssl_scale * torch.mean(fc_ssl_loss)
        self.meters.update('{0}_fc_loss'.format(mid), fc_ssl_loss.data)
    else:
        fc_ssl_loss = 0
        self.meters.update('{0}_fc_loss'.format(mid), fc_ssl_loss)

    # calculate the dynamic consistency constraint
    if self.args.ssl_mode in [MODE_GCT, MODE_DC]:
        if dc_gt is None:
            logger.log_err(
                'The dynamic consistency constraint is enabled, '
                'but no pseudo ground truth is given.')

        dc_ssl_loss = self.dc_criterion.forward(activated_pred[0], dc_gt)
        dc_ssl_loss = dc_rampup_scale * self.args.dc_ssl_scale * torch.mean(
            dc_ssl_loss)
        self.meters.update('{0}_dc_loss'.format(mid), dc_ssl_loss.data)
    else:
        dc_ssl_loss = 0
        self.meters.update('{0}_dc_loss'.format(mid), dc_ssl_loss)

    # generate the flawmap ground truth (no gradients needed)
    with torch.no_grad():
        flawmap_gt = self.fdgt_generator.forward(
            activated_pred[0],
            self.task_func.sslgct_prepare_task_gt_for_fdgt(gt[0]))

    # for validation
    if not is_train:
        fd_loss = self.args.fd_scale * self.fd_criterion.forward(
            flawmap, flawmap_gt)
        self.meters.update('{0}_fd_loss'.format(mid),
                           torch.mean(fd_loss).data)
        self.task_func.metrics(activated_pred, gt, inp, self.meters,
                               id_str=mid)

    # visualization
    if self.args.visualize and idx % self.args.visual_freq == 0:
        with torch.no_grad():
            handled_flawmap = self.flawmap_handler(flawmap)[0]

        # NOTE(review): 'dc_gt[0]' raises a TypeError when dc_gt is None
        # (i.e., ssl_mode == MODE_FC) and visualization is enabled --
        # verify whether visualization is supported in that mode
        self._visualize(
            epoch, idx, is_train, mid,
            func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
            func.split_tensor_tuple(activated_pred, 0, 1, reduce_dim=True),
            func.split_tensor_tuple(gt, 0, 1, reduce_dim=True),
            handled_flawmap, flawmap_gt[0], dc_gt[0])

    loss = task_loss + fc_ssl_loss + dc_ssl_loss
    return loss
def _train(self, data_loader, epoch):
    """Train the CutMix mean-teacher student/teacher pair for one epoch.

    Per batch:
      1. supervised task loss on the labeled samples (student model);
      2. consistency loss between the student's predictions on CutMix-mixed
         unlabeled samples and the correspondingly mixed teacher predictions,
         weighted by a confidence score and a sigmoid ramp-up coefficient;
      3. student SGD step, then teacher update by EMA of the student.
    """
    self.meters.reset()
    lbs = self.args.labeled_batch_size
    ubs = self.args.unlabeled_batch_size

    self.s_model.train()
    self.t_model.train()

    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        # 'inp' and 'gt' are tuples; '_batch_prehandle' also returns the
        # CutMix-mixed unlabeled inputs and the mixing mask
        inp, gt, mix_u_inp, mix_u_mask = self._batch_prehandle(
            inp, gt, True)
        if len(inp) > 1 and idx == 0:
            self._inp_warn()
        if len(gt) > 1 and idx == 0:
            self._gt_warn()

        # calculate the ramp-up coefficient of the consistency constraint
        cur_step = len(data_loader) * epoch + idx
        total_steps = len(data_loader) * self.args.cons_rampup_epochs
        cons_rampup_scale = func.sigmoid_rampup(cur_step, total_steps)

        self.s_optimizer.zero_grad()

        # -------------------------------------------------
        # For Labeled Samples
        # -------------------------------------------------
        l_inp = func.split_tensor_tuple(inp, 0, lbs)
        l_gt = func.split_tensor_tuple(gt, 0, lbs)

        # forward the labeled samples by the student model
        l_s_resulter, l_s_debugger = self.s_model.forward(l_inp)
        if not 'pred' in l_s_resulter.keys() or not 'activated_pred' in l_s_resulter.keys():
            self._pred_err()
        l_s_pred = tool.dict_value(l_s_resulter, 'pred')
        l_s_activated_pred = tool.dict_value(l_s_resulter,
                                             'activated_pred')

        # calculate the supervised task loss on the labeled samples
        task_loss = self.s_criterion.forward(l_s_pred, l_gt, l_inp)
        task_loss = torch.mean(task_loss)
        self.meters.update('task_loss', task_loss.data)

        # -------------------------------------------------
        # For Unlabeled Samples
        # -------------------------------------------------
        if self.args.unlabeled_batch_size > 0:
            # unlabeled samples occupy the tail of the batch
            u_inp = func.split_tensor_tuple(inp, lbs, self.args.batch_size)

            # forward the original samples by the teacher model
            # (no gradients flow through the teacher)
            with torch.no_grad():
                u_t_resulter, u_t_debugger = self.t_model.forward(u_inp)
                if not 'pred' in u_t_resulter.keys() or not 'activated_pred' in u_t_resulter.keys():
                    self._pred_err()
                u_t_activated_pred = tool.dict_value(u_t_resulter,
                                                     'activated_pred')

            # mix the activated pred from the teacher model as the pseudo gt:
            # the two halves of the unlabeled batch are combined with the
            # CutMix mask to mirror how 'mix_u_inp' was built
            u_t_activated_pred_1 = func.split_tensor_tuple(
                u_t_activated_pred, 0, int(ubs / 2))
            u_t_activated_pred_2 = func.split_tensor_tuple(
                u_t_activated_pred, int(ubs / 2), ubs)

            mix_u_t_activated_pred = []
            mix_u_t_confidence = []
            for up_1, up_2 in zip(u_t_activated_pred_1,
                                  u_t_activated_pred_2):
                mp = mix_u_mask * up_1 + (1 - mix_u_mask) * up_2
                mix_u_t_activated_pred.append(mp.detach())

                # NOTE: here we just follow the official code of CutMix to calculate the confidence
                # but it is odd that all the samples use the same confidence (mean confidence)
                u_t_confidence = (mp.max(dim=1)[0] >
                                  self.args.cons_threshold).float().mean()
                mix_u_t_confidence.append(u_t_confidence.detach())
            mix_u_t_activated_pred = tuple(mix_u_t_activated_pred)

            # forward the mixed samples by the student model
            u_s_resulter, u_s_debugger = self.s_model.forward(mix_u_inp)
            if not 'pred' in u_s_resulter.keys() or not 'activated_pred' in u_s_resulter.keys():
                self._pred_err()
            mix_u_s_activated_pred = tool.dict_value(
                u_s_resulter, 'activated_pred')

            # calculate the consistency constraint: per-element mean
            # consistency, weighted by the teacher's confidence
            cons_loss = 0
            for msap, mtap, confidence in zip(mix_u_s_activated_pred,
                                              mix_u_t_activated_pred,
                                              mix_u_t_confidence):
                cons_loss += torch.mean(self.cons_criterion(
                    msap, mtap)) * confidence
            cons_loss = cons_rampup_scale * self.args.cons_scale * torch.mean(
                cons_loss)
            self.meters.update('cons_loss', cons_loss.data)
        else:
            cons_loss = 0
            self.meters.update('cons_loss', cons_loss)

        # backward and update the student model
        loss = task_loss + cons_loss
        loss.backward()
        self.s_optimizer.step()

        # update the teacher model by EMA
        self._update_ema_variables(self.s_model, self.t_model,
                                   self.args.ema_decay, cur_step)

        # logging
        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                ' student-{3}\t=>\t'
                's-task-loss: {meters[task_loss]:.6f}\t'
                's-cons-loss: {meters[cons_loss]:.6f}\n'.format(
                    epoch + 1, idx, len(data_loader), self.args.task,
                    meters=self.meters))

        # visualization
        if self.args.visualize and idx % self.args.visual_freq == 0:
            self._visualize(
                epoch, idx, True,
                func.split_tensor_tuple(l_inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(l_s_activated_pred, 0, 1,
                                        reduce_dim=True),
                func.split_tensor_tuple(l_gt, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(mix_u_inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(mix_u_s_activated_pred, 0, 1,
                                        reduce_dim=True),
                func.split_tensor_tuple(mix_u_t_activated_pred, 0, 1,
                                        reduce_dim=True),
                mix_u_mask[0])

        # update iteration-based lrers
        if not self.args.is_epoch_lrer:
            self.s_lrer.step()

    # update epoch-based lrers
    if self.args.is_epoch_lrer:
        self.s_lrer.step()