def _train(self, data_loader, epoch):
    self.meters.reset()
    lbs = self.args.labeled_batch_size

    self.l_model.train()
    self.r_model.train()
    self.fd_model.train()

    # both 'inp' and 'gt' are tuples
    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        (l_inp, l_gt), (r_inp, r_gt) = self._batch_prehandle(inp, gt)
        if len(l_gt) == len(r_gt) > 1 and idx == 0:
            self._inp_warn()

        # calculate the ramp-up coefficient of the dynamic consistency constraint
        cur_steps = len(data_loader) * epoch + idx
        total_steps = len(data_loader) * self.args.dc_rampup_epochs
        dc_rampup_scale = func.sigmoid_rampup(cur_steps, total_steps)

        # -----------------------------------------------------------------------------
        # step-0: pre-forwarding to save GPU memory
        #   - forward the task models and the flaw detector
        #   - generate pseudo ground truth for the unlabeled data if the dynamic
        #     consistency constraint is enabled
        # -----------------------------------------------------------------------------
        with torch.no_grad():
            l_resulter, l_debugger = self.l_model.forward(l_inp)
            l_activated_pred = tool.dict_value(l_resulter, 'activated_pred')
            r_resulter, r_debugger = self.r_model.forward(r_inp)
            r_activated_pred = tool.dict_value(r_resulter, 'activated_pred')

            # 'l_flawmap' and 'r_flawmap' will be used in step-2
            fd_resulter, fd_debugger = self.fd_model.forward(l_inp, l_activated_pred[0])
            l_flawmap = tool.dict_value(fd_resulter, 'flawmap')
            fd_resulter, fd_debugger = self.fd_model.forward(r_inp, r_activated_pred[0])
            r_flawmap = tool.dict_value(fd_resulter, 'flawmap')

        l_dc_gt, r_dc_gt = None, None
        l_fc_mask, r_fc_mask = None, None

        # generate the pseudo ground truth for the dynamic consistency constraint
        if self.args.ssl_mode in [MODE_GCT, MODE_DC]:
            with torch.no_grad():
                l_handled_flawmap = self.flawmap_handler.forward(l_flawmap)
                r_handled_flawmap = self.flawmap_handler.forward(r_flawmap)
                l_dc_gt, r_dc_gt, l_fc_mask, r_fc_mask = self.dcgt_generator.forward(
                    l_activated_pred[0].detach(), r_activated_pred[0].detach(),
                    l_handled_flawmap, r_handled_flawmap)

        # -----------------------------------------------------------------------------
        # step-1: train the task models
        # -----------------------------------------------------------------------------
        for param in self.fd_model.parameters():
            param.requires_grad = False

        # train the 'l' task model
        l_loss = self._task_model_iter(epoch, idx, True, 'l', lbs, l_inp, l_gt,
                                       l_dc_gt, l_fc_mask, dc_rampup_scale)
        self.l_optimizer.zero_grad()
        l_loss.backward()
        self.l_optimizer.step()

        # train the 'r' task model
        r_loss = self._task_model_iter(epoch, idx, True, 'r', lbs, r_inp, r_gt,
                                       r_dc_gt, r_fc_mask, dc_rampup_scale)
        self.r_optimizer.zero_grad()
        r_loss.backward()
        self.r_optimizer.step()

        # -----------------------------------------------------------------------------
        # step-2: train the flaw detector
        # -----------------------------------------------------------------------------
        for param in self.fd_model.parameters():
            param.requires_grad = True

        # generate the ground truth for the flaw detector (on labeled data only)
        with torch.no_grad():
            l_flawmap_gt = self.fdgt_generator.forward(
                l_activated_pred[0][:lbs, ...].detach(),
                self.task_func.sslgct_prepare_task_gt_for_fdgt(l_gt[0][:lbs, ...]))
            r_flawmap_gt = self.fdgt_generator.forward(
                r_activated_pred[0][:lbs, ...].detach(),
                self.task_func.sslgct_prepare_task_gt_for_fdgt(r_gt[0][:lbs, ...]))

        l_fd_loss = self.fd_criterion.forward(l_flawmap[:lbs, ...], l_flawmap_gt)
        l_fd_loss = self.args.fd_scale * torch.mean(l_fd_loss)
        self.meters.update('l_fd_loss', l_fd_loss.data)

        r_fd_loss = self.fd_criterion.forward(r_flawmap[:lbs, ...], r_flawmap_gt)
        r_fd_loss = self.args.fd_scale * torch.mean(r_fd_loss)
        self.meters.update('r_fd_loss', r_fd_loss.data)

        fd_loss = (l_fd_loss + r_fd_loss) / 2
        self.fd_optimizer.zero_grad()
        fd_loss.backward()
        self.fd_optimizer.step()

        # logging
        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                '  l-{3}\t=>\t'
                'l-task-loss: {meters[l_task_loss]:.6f}\t'
                'l-dc-loss: {meters[l_dc_loss]:.6f}\t'
                'l-fc-loss: {meters[l_fc_loss]:.6f}\n'
                '  r-{3}\t=>\t'
                'r-task-loss: {meters[r_task_loss]:.6f}\t'
                'r-dc-loss: {meters[r_dc_loss]:.6f}\t'
                'r-fc-loss: {meters[r_fc_loss]:.6f}\n'
                '  fd\t=>\t'
                'l-fd-loss: {meters[l_fd_loss]:.6f}\t'
                'r-fd-loss: {meters[r_fd_loss]:.6f}\n'.format(
                    epoch, idx, len(data_loader), self.args.task, meters=self.meters))

        # the flaw detector uses PolynomialLR, which is iteration-based (ITER_LRERS)
        self.fd_lrer.step()

        # update iteration-based lrers
        if not self.args.is_epoch_lrer:
            self.l_lrer.step()
            self.r_lrer.step()

    # update epoch-based lrers
    if self.args.is_epoch_lrer:
        self.l_lrer.step()
        self.r_lrer.step()
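
# ---------------------------------------------------------------------------------------
# Illustrative sketch (not part of the trainer above). 'func.sigmoid_rampup' is used to
# scale the dynamic consistency constraint, but its body is not shown here. The function
# below is a minimal sketch of the sigmoid ramp-up commonly used in SSL codebases
# (exp(-5 * (1 - t)^2)); the repository's actual implementation may differ.
# ---------------------------------------------------------------------------------------
import numpy as np

def sigmoid_rampup_sketch(current_step, rampup_length):
    # ramp the scale from ~0 to 1 over 'rampup_length' steps
    if rampup_length == 0:
        return 1.0
    current_step = np.clip(current_step, 0.0, rampup_length)
    phase = 1.0 - current_step / rampup_length
    return float(np.exp(-5.0 * phase * phase))

# usage mirroring the call above: dc_rampup_scale = sigmoid_rampup_sketch(cur_steps, total_steps)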
def _train(self, data_loader, epoch):
    self.meters.reset()
    lbs = self.args.labeled_batch_size
    ubs = self.args.unlabeled_batch_size

    self.s_model.train()
    self.t_model.train()

    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        # 'inp' and 'gt' are tuples
        inp, gt, mix_u_inp, mix_u_mask = self._batch_prehandle(inp, gt, True)
        if len(inp) > 1 and idx == 0:
            self._inp_warn()
        if len(gt) > 1 and idx == 0:
            self._gt_warn()

        # calculate the ramp-up coefficient of the consistency constraint
        cur_step = len(data_loader) * epoch + idx
        total_steps = len(data_loader) * self.args.cons_rampup_epochs
        cons_rampup_scale = func.sigmoid_rampup(cur_step, total_steps)

        self.s_optimizer.zero_grad()

        # -------------------------------------------------
        # For Labeled Samples
        # -------------------------------------------------
        l_inp = func.split_tensor_tuple(inp, 0, lbs)
        l_gt = func.split_tensor_tuple(gt, 0, lbs)

        # forward the labeled samples by the student model
        l_s_resulter, l_s_debugger = self.s_model.forward(l_inp)
        if not 'pred' in l_s_resulter.keys() or not 'activated_pred' in l_s_resulter.keys():
            self._pred_err()
        l_s_pred = tool.dict_value(l_s_resulter, 'pred')
        l_s_activated_pred = tool.dict_value(l_s_resulter, 'activated_pred')

        # calculate the supervised task loss on the labeled samples
        task_loss = self.s_criterion.forward(l_s_pred, l_gt, l_inp)
        task_loss = torch.mean(task_loss)
        self.meters.update('task_loss', task_loss.data)

        # -------------------------------------------------
        # For Unlabeled Samples
        # -------------------------------------------------
        if self.args.unlabeled_batch_size > 0:
            u_inp = func.split_tensor_tuple(inp, lbs, self.args.batch_size)

            # forward the original samples by the teacher model
            with torch.no_grad():
                u_t_resulter, u_t_debugger = self.t_model.forward(u_inp)
            if not 'pred' in u_t_resulter.keys() or not 'activated_pred' in u_t_resulter.keys():
                self._pred_err()
            u_t_activated_pred = tool.dict_value(u_t_resulter, 'activated_pred')

            # mix the activated pred from the teacher model as the pseudo gt
            u_t_activated_pred_1 = func.split_tensor_tuple(u_t_activated_pred, 0, int(ubs / 2))
            u_t_activated_pred_2 = func.split_tensor_tuple(u_t_activated_pred, int(ubs / 2), ubs)

            mix_u_t_activated_pred = []
            mix_u_t_confidence = []
            for up_1, up_2 in zip(u_t_activated_pred_1, u_t_activated_pred_2):
                mp = mix_u_mask * up_1 + (1 - mix_u_mask) * up_2
                mix_u_t_activated_pred.append(mp.detach())

                # NOTE: here we just follow the official code of CutMix to calculate the
                #       confidence, but it is odd that all the samples use the same
                #       (mean) confidence
                u_t_confidence = (mp.max(dim=1)[0] > self.args.cons_threshold).float().mean()
                mix_u_t_confidence.append(u_t_confidence.detach())

            mix_u_t_activated_pred = tuple(mix_u_t_activated_pred)

            # forward the mixed samples by the student model
            u_s_resulter, u_s_debugger = self.s_model.forward(mix_u_inp)
            if not 'pred' in u_s_resulter.keys() or not 'activated_pred' in u_s_resulter.keys():
                self._pred_err()
            mix_u_s_activated_pred = tool.dict_value(u_s_resulter, 'activated_pred')

            # calculate the consistency constraint
            cons_loss = 0
            for msap, mtap, confidence in zip(mix_u_s_activated_pred,
                                              mix_u_t_activated_pred,
                                              mix_u_t_confidence):
                cons_loss += torch.mean(self.cons_criterion(msap, mtap)) * confidence
            cons_loss = cons_rampup_scale * self.args.cons_scale * torch.mean(cons_loss)
            self.meters.update('cons_loss', cons_loss.data)
        else:
            cons_loss = 0
            self.meters.update('cons_loss', cons_loss)

        # backward and update the student model
        loss = task_loss + cons_loss
        loss.backward()
        self.s_optimizer.step()

        # update the teacher model by EMA
        self._update_ema_variables(self.s_model, self.t_model, self.args.ema_decay, cur_step)

        # logging
        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                '  student-{3}\t=>\t'
                's-task-loss: {meters[task_loss]:.6f}\t'
                's-cons-loss: {meters[cons_loss]:.6f}\n'.format(
                    epoch + 1, idx, len(data_loader), self.args.task, meters=self.meters))

        # visualization
        if self.args.visualize and idx % self.args.visual_freq == 0:
            self._visualize(
                epoch, idx, True,
                func.split_tensor_tuple(l_inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(l_s_activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(l_gt, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(mix_u_inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(mix_u_s_activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(mix_u_t_activated_pred, 0, 1, reduce_dim=True),
                mix_u_mask[0])

        # update iteration-based lrers
        if not self.args.is_epoch_lrer:
            self.s_lrer.step()

    # update epoch-based lrers
    if self.args.is_epoch_lrer:
        self.s_lrer.step()
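
# ---------------------------------------------------------------------------------------
# Illustrative sketch (not part of the trainer above). '_update_ema_variables' is assumed
# to perform the standard Mean-Teacher exponential-moving-average update of the teacher
# weights from the student weights; the repository's actual implementation may differ.
# ---------------------------------------------------------------------------------------
import torch

def update_ema_variables_sketch(s_model, t_model, ema_decay, global_step):
    # ramp up the decay early in training so the teacher tracks the student closely at first
    alpha = min(1.0 - 1.0 / (global_step + 1), ema_decay)
    with torch.no_grad():
        for t_param, s_param in zip(t_model.parameters(), s_model.parameters()):
            # teacher = alpha * teacher + (1 - alpha) * student
            t_param.data.mul_(alpha).add_(s_param.data, alpha=1.0 - alpha)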
def _train(self, data_loader, epoch):
    self.meters.reset()
    lbs = self.args.labeled_batch_size

    self.model.train()

    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        inp, gt = self._batch_prehandle(inp, gt)
        if len(gt) > 1 and idx == 0:
            self._data_err()

        # TODO: support more ramp-up functions
        # calculate the ramp-up coefficient of the consistency constraint
        cur_step = len(data_loader) * epoch + idx
        total_steps = len(data_loader) * self.args.cons_rampup_epochs
        cons_rampup_scale = func.sigmoid_rampup(cur_step, total_steps)

        self.optimizer.zero_grad()

        # -----------------------------------------------------------
        # For Labeled Data
        # -----------------------------------------------------------
        l_gt = func.split_tensor_tuple(gt, 0, lbs)
        l_inp = func.split_tensor_tuple(inp, 0, lbs)

        # forward the wrapped CCT model
        resulter, debugger = self.model.forward(l_inp, l_gt, False)
        l_pred = tool.dict_value(resulter, 'pred')
        l_activated_pred = tool.dict_value(resulter, 'activated_pred')

        task_loss = tool.dict_value(resulter, 'task_loss', err=True)
        task_loss = task_loss.mean()
        self.meters.update('task_loss', task_loss.data)

        # -----------------------------------------------------------
        # For Unlabeled Data
        # -----------------------------------------------------------
        if self.args.unlabeled_batch_size > 0:
            ul_gt = func.split_tensor_tuple(gt, lbs, self.args.batch_size)
            ul_inp = func.split_tensor_tuple(inp, lbs, self.args.batch_size)

            # forward the wrapped CCT model
            resulter, debugger = self.model.forward(ul_inp, ul_gt, True)
            ul_pred = tool.dict_value(resulter, 'pred')
            ul_activated_pred = tool.dict_value(resulter, 'activated_pred')
            ul_ad_preds = tool.dict_value(resulter, 'ul_ad_preds')

            cons_loss = tool.dict_value(resulter, 'cons_loss', err=True)
            cons_loss = cons_loss.mean()
            cons_loss = cons_rampup_scale * self.args.cons_scale * cons_loss
            self.meters.update('cons_loss', cons_loss.data)
        else:
            cons_loss = 0
            self.meters.update('cons_loss', cons_loss)

        # backward and update the wrapped CCT model
        loss = task_loss + cons_loss
        loss.backward()
        self.optimizer.step()

        # logging
        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                '  task-{3}\t=>\t'
                'task-loss: {meters[task_loss]:.6f}\t'
                'cons-loss: {meters[cons_loss]:.6f}\n'.format(
                    epoch, idx, len(data_loader), self.args.task, meters=self.meters))

        # visualization
        if self.args.visualize and idx % self.args.visual_freq == 0:
            # NOTE: visualize the labeled prediction ('l_activated_pred'); no full-batch
            #       'activated_pred' variable is defined in this loop
            self._visualize(
                epoch, idx, True,
                func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(l_activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

        # update iteration-based lrers
        if not self.args.is_epoch_lrer:
            self.lrer.step()

    # update epoch-based lrers
    if self.args.is_epoch_lrer:
        self.lrer.step()
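
# ---------------------------------------------------------------------------------------
# Illustrative sketch (not part of the trainer above). 'func.split_tensor_tuple' is used
# throughout these trainers to slice every tensor in a tuple along the batch dimension
# (e.g., labeled samples are rows [0, lbs)). The helper below is a minimal sketch of that
# behavior, inferred from the call sites; the repository's actual helper may differ.
# ---------------------------------------------------------------------------------------
def split_tensor_tuple_sketch(tensor_tuple, start, end, reduce_dim=False):
    sliced = []
    for t in tensor_tuple:
        s = t[start:end, ...]
        # 'reduce_dim=True' is used when a single sample is extracted for visualization
        if reduce_dim and s.shape[0] == 1:
            s = s.squeeze(0)
        sliced.append(s)
    return tuple(sliced)

# usage mirroring the calls above: l_inp = split_tensor_tuple_sketch(inp, 0, lbs)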
def _train(self, data_loader, epoch):
    self.meters.reset()
    lbs = self.args.labeled_batch_size

    self.s_model.train()
    self.t_model.train()

    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        # 's_inp', 't_inp' and 'gt' are tuples
        s_inp, t_inp, gt = self._batch_prehandle(inp, gt, True)
        if len(gt) > 1 and idx == 0:
            self._inp_warn()

        # calculate the ramp-up coefficient of the consistency constraint
        cur_step = len(data_loader) * epoch + idx
        total_steps = len(data_loader) * self.args.cons_rampup_epochs
        cons_rampup_scale = func.sigmoid_rampup(cur_step, total_steps)

        self.s_optimizer.zero_grad()

        # forward the student model
        s_resulter, s_debugger = self.s_model.forward(s_inp)
        if not 'pred' in s_resulter.keys() or not 'activated_pred' in s_resulter.keys():
            self._pred_err()
        s_pred = tool.dict_value(s_resulter, 'pred')
        s_activated_pred = tool.dict_value(s_resulter, 'activated_pred')

        # calculate the supervised task constraint on the labeled data
        l_s_pred = func.split_tensor_tuple(s_pred, 0, lbs)
        l_gt = func.split_tensor_tuple(gt, 0, lbs)
        l_s_inp = func.split_tensor_tuple(s_inp, 0, lbs)

        # 'task_loss' is a 1-dim tensor of n elements, where n == batch_size
        s_task_loss = self.s_criterion.forward(l_s_pred, l_gt, l_s_inp)
        s_task_loss = torch.mean(s_task_loss)
        self.meters.update('s_task_loss', s_task_loss.data)

        # forward the teacher model
        with torch.no_grad():
            t_resulter, t_debugger = self.t_model.forward(t_inp)
            if not 'pred' in t_resulter.keys():
                self._pred_err()
            t_pred = tool.dict_value(t_resulter, 'pred')
            t_activated_pred = tool.dict_value(t_resulter, 'activated_pred')

            # calculate 't_task_loss' for recording
            l_t_pred = func.split_tensor_tuple(t_pred, 0, lbs)
            l_t_inp = func.split_tensor_tuple(t_inp, 0, lbs)
            t_task_loss = self.s_criterion.forward(l_t_pred, l_gt, l_t_inp)
            t_task_loss = torch.mean(t_task_loss)
            self.meters.update('t_task_loss', t_task_loss.data)

        # calculate the consistency constraint from the teacher model to the student model
        t_pseudo_gt = Variable(t_pred[0].detach().data, requires_grad=False)

        if self.args.cons_for_labeled:
            cons_loss = self.cons_criterion(s_pred[0], t_pseudo_gt)
        elif self.args.unlabeled_batch_size > 0:
            cons_loss = self.cons_criterion(s_pred[0][lbs:, ...], t_pseudo_gt[lbs:, ...])
        else:
            cons_loss = self.zero_tensor
        cons_loss = cons_rampup_scale * self.args.cons_scale * torch.mean(cons_loss)
        self.meters.update('cons_loss', cons_loss.data)

        # backward and update the student model
        loss = s_task_loss + cons_loss
        loss.backward()
        self.s_optimizer.step()

        # update the teacher model by EMA
        self._update_ema_variables(self.s_model, self.t_model, self.args.ema_decay, cur_step)

        # logging
        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                '  student-{3}\t=>\t'
                's-task-loss: {meters[s_task_loss]:.6f}\t'
                's-cons-loss: {meters[cons_loss]:.6f}\n'
                '  teacher-{3}\t=>\t'
                't-task-loss: {meters[t_task_loss]:.6f}\n'.format(
                    epoch + 1, idx, len(data_loader), self.args.task, meters=self.meters))

        # visualization
        if self.args.visualize and idx % self.args.visual_freq == 0:
            self._visualize(
                epoch, idx, True,
                func.split_tensor_tuple(s_inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(s_activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(t_inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(t_activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

        # update iteration-based lrers
        if not self.args.is_epoch_lrer:
            self.s_lrer.step()

    # update epoch-based lrers
    if self.args.is_epoch_lrer:
        self.s_lrer.step()
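
# ---------------------------------------------------------------------------------------
# Illustrative sketch (not part of the trainer above). 'self.cons_criterion' returns an
# element-wise consistency loss that the trainer reduces with torch.mean. The module
# below is a hypothetical stand-in assuming the softmax-MSE consistency loss used by the
# original Mean Teacher code; the criterion actually configured here may differ.
# ---------------------------------------------------------------------------------------
import torch
import torch.nn as nn

class SoftmaxMSEConsistencySketch(nn.Module):
    def forward(self, student_pred, teacher_pseudo_gt):
        # compare class-probability maps element-wise; no reduction is applied, so the
        # caller can apply torch.mean and the ramp-up/scale factors itself
        s_prob = torch.softmax(student_pred, dim=1)
        t_prob = torch.softmax(teacher_pseudo_gt, dim=1)
        return (s_prob - t_prob) ** 2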