def _validate(self, data_loader, epoch):
    self.meters.reset()

    self.model.eval()
    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        inp, gt = self._batch_prehandle(inp, gt)
        if len(gt) > 1 and idx == 0:
            self._data_err()

        resulter, debugger = self.model.forward(inp, gt, False)
        pred = tool.dict_value(resulter, 'pred')
        activated_pred = tool.dict_value(resulter, 'activated_pred')

        task_loss = tool.dict_value(resulter, 'task_loss', err=True)
        task_loss = task_loss.mean()
        self.meters.update('task_loss', task_loss.data)

        self.task_func.metrics(activated_pred, gt, inp, self.meters, id_str='task')

        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                '  task-{3}\t=>\t'
                'task-loss: {meters[task_loss]:.6f}\t'.format(
                    epoch + 1, idx, len(data_loader), self.args.task,
                    meters=self.meters))

        # visualization
        if self.args.visualize and idx % self.args.visual_freq == 0:
            self._visualize(
                epoch, idx, False,
                func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

    # metrics
    metrics_info = {'task': ''}
    for key in sorted(list(self.meters.keys())):
        if self.task_func.METRIC_STR in key:
            for id_str in metrics_info.keys():
                if key.startswith(id_str):
                    metrics_info[id_str] += '{0}: {1:.6}\t'.format(key, self.meters[key])

    logger.log_info('Validation metrics:\n  task-metrics\t=>\t{0}\n'.format(
        metrics_info['task'].replace('_', '-')))
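
# ------------------------------------------------------------------------------
# NOTE: the logging calls above resolve entries such as '{meters[task_loss]:.6f}'
# through item access on 'self.meters'. The repo's meter class is not shown in
# this section; the sketch below is a hypothetical minimal running-average
# meter set with that interface, for illustration only.

class _SketchAverageMeterSet:
    def __init__(self):
        self._sums, self._counts = {}, {}

    def reset(self):
        self._sums.clear()
        self._counts.clear()

    def update(self, name, value, n=1):
        # keep a weighted sum and a count so __getitem__ can return the mean
        self._sums[name] = self._sums.get(name, 0.0) + float(value) * n
        self._counts[name] = self._counts.get(name, 0) + n

    def keys(self):
        return self._sums.keys()

    def __getitem__(self, name):
        return self._sums[name] / max(self._counts[name], 1)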
def _train(self, data_loader, epoch):
    self.meters.reset()
    lbs = self.args.labeled_batch_size

    self.model.train()
    self.d_model.train()

    # both 'inp' and 'gt' are tuples
    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        inp, gt = self._batch_prehandle(inp, gt)
        if len(gt) > 1 and idx == 0:
            self._inp_warn()

        # -----------------------------------------------------------------------------
        # step-1: train the task model
        # -----------------------------------------------------------------------------
        self.optimizer.zero_grad()

        # forward the task model
        resulter, debugger = self.model.forward(inp)
        if 'pred' not in resulter.keys() or 'activated_pred' not in resulter.keys():
            self._pred_err()

        pred = tool.dict_value(resulter, 'pred')
        activated_pred = tool.dict_value(resulter, 'activated_pred')

        # forward the FC discriminator
        # 'confidence_map' is a tensor
        d_resulter, d_debugger = self.d_model.forward(activated_pred[0])
        confidence_map = tool.dict_value(d_resulter, 'confidence')

        # calculate the supervised task constraint on the labeled data
        l_pred = func.split_tensor_tuple(pred, 0, lbs)
        l_gt = func.split_tensor_tuple(gt, 0, lbs)
        l_inp = func.split_tensor_tuple(inp, 0, lbs)

        # 'task_loss' is a tensor of 1-dim & n elements, where n == batch_size
        task_loss = self.criterion.forward(l_pred, l_gt, l_inp)
        task_loss = torch.mean(task_loss)
        self.meters.update('task_loss', task_loss.data)

        # calculate the adversarial constraint for the labeled data
        if self.args.adv_for_labeled:
            l_confidence_map = confidence_map[:lbs, ...]

            # preprocess prediction and ground truth for the adversarial constraint
            l_adv_confidence_map, l_adv_confidence_gt = \
                self.task_func.ssladv_preprocess_fcd_criterion(l_confidence_map, l_gt[0], True)

            l_adv_loss = self.d_criterion(l_adv_confidence_map, l_adv_confidence_gt)
            labeled_adv_loss = self.args.labeled_adv_scale * torch.mean(l_adv_loss)
            self.meters.update('labeled_adv_loss', labeled_adv_loss.data)
        else:
            labeled_adv_loss = 0
            self.meters.update('labeled_adv_loss', labeled_adv_loss)

        # calculate the adversarial constraint for the unlabeled data
        if self.args.unlabeled_batch_size > 0:
            u_confidence_map = confidence_map[lbs:self.args.batch_size, ...]

            # preprocess prediction and ground truth for the adversarial constraint
            u_adv_confidence_map, u_adv_confidence_gt = \
                self.task_func.ssladv_preprocess_fcd_criterion(u_confidence_map, None, True)

            u_adv_loss = self.d_criterion(u_adv_confidence_map, u_adv_confidence_gt)
            unlabeled_adv_loss = self.args.unlabeled_adv_scale * torch.mean(u_adv_loss)
            self.meters.update('unlabeled_adv_loss', unlabeled_adv_loss.data)
        else:
            unlabeled_adv_loss = 0
            self.meters.update('unlabeled_adv_loss', unlabeled_adv_loss)

        adv_loss = labeled_adv_loss + unlabeled_adv_loss

        # backward and update the task model
        loss = task_loss + adv_loss
        loss.backward()
        self.optimizer.step()

        # -----------------------------------------------------------------------------
        # step-2: train the FC discriminator
        # -----------------------------------------------------------------------------
        self.d_optimizer.zero_grad()

        # forward the task prediction (fake)
        if self.args.unlabeled_for_discriminator:
            fake_pred = activated_pred[0].detach()
        else:
            fake_pred = activated_pred[0][:lbs, ...].detach()

        d_resulter, d_debugger = self.d_model.forward(fake_pred)
        fake_confidence_map = tool.dict_value(d_resulter, 'confidence')

        l_fake_confidence_map = fake_confidence_map[:lbs, ...]
        l_fake_confidence_map, l_fake_confidence_gt = \
            self.task_func.ssladv_preprocess_fcd_criterion(l_fake_confidence_map, l_gt[0], False)

        if self.args.unlabeled_for_discriminator and self.args.unlabeled_batch_size != 0:
            u_fake_confidence_map = fake_confidence_map[lbs:self.args.batch_size, ...]
            u_fake_confidence_map, u_fake_confidence_gt = \
                self.task_func.ssladv_preprocess_fcd_criterion(u_fake_confidence_map, None, False)

            fake_confidence_map = torch.cat((l_fake_confidence_map, u_fake_confidence_map), dim=0)
            fake_confidence_gt = torch.cat((l_fake_confidence_gt, u_fake_confidence_gt), dim=0)
        else:
            fake_confidence_map, fake_confidence_gt = l_fake_confidence_map, l_fake_confidence_gt

        fake_d_loss = self.d_criterion.forward(fake_confidence_map, fake_confidence_gt)
        fake_d_loss = self.args.discriminator_scale * torch.mean(fake_d_loss)
        self.meters.update('fake_d_loss', fake_d_loss.data)

        # forward the ground truth (real)
        # convert the format of the ground truth
        real_gt = self.task_func.ssladv_convert_task_gt_to_fcd_input(l_gt[0])
        d_resulter, d_debugger = self.d_model.forward(real_gt)
        real_confidence_map = tool.dict_value(d_resulter, 'confidence')
        real_confidence_map, real_confidence_gt = \
            self.task_func.ssladv_preprocess_fcd_criterion(real_confidence_map, l_gt[0], True)

        real_d_loss = self.d_criterion(real_confidence_map, real_confidence_gt)
        real_d_loss = self.args.discriminator_scale * torch.mean(real_d_loss)
        self.meters.update('real_d_loss', real_d_loss.data)

        # backward and update the FC discriminator
        d_loss = (fake_d_loss + real_d_loss) / 2
        d_loss.backward()
        self.d_optimizer.step()

        # logging
        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                '  task-{3}\t=>\t'
                'task-loss: {meters[task_loss]:.6f}\t'
                'labeled-adv-loss: {meters[labeled_adv_loss]:.6f}\t'
                'unlabeled-adv-loss: {meters[unlabeled_adv_loss]:.6f}\n'
                '  fc-discriminator\t=>\t'
                'fake-d-loss: {meters[fake_d_loss]:.6f}\t'
                'real-d-loss: {meters[real_d_loss]:.6f}\n'.format(
                    epoch + 1, idx, len(data_loader), self.args.task,
                    meters=self.meters))

        # visualization
        if self.args.visualize and idx % self.args.visual_freq == 0:
            u_inp_sample, u_pred_sample, u_cmap_sample = None, None, None
            if self.args.unlabeled_batch_size > 0:
                u_inp_sample = func.split_tensor_tuple(inp, lbs, lbs + 1, reduce_dim=True)
                u_pred_sample = func.split_tensor_tuple(activated_pred, lbs, lbs + 1, reduce_dim=True)
                # index the full-batch confidence map from step-1 here;
                # 'fake_confidence_map' may contain the labeled samples only
                u_cmap_sample = torch.sigmoid(confidence_map[lbs])

            self._visualize(
                epoch, idx, True,
                func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(gt, 0, 1, reduce_dim=True),
                torch.sigmoid(confidence_map[0]),
                u_inp_sample, u_pred_sample, u_cmap_sample)

        # the FC discriminator uses a polynomial lr [ITER_LRERS]
        self.d_lrer.step()
        # update iteration-based lrers
        if not self.args.is_epoch_lrer:
            self.lrer.step()

    # update epoch-based lrers
    if self.args.is_epoch_lrer:
        self.lrer.step()
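
# ------------------------------------------------------------------------------
# NOTE: 'ssladv_preprocess_fcd_criterion' is a task-specific hook whose body is
# not shown in this section. A common choice for this kind of FC-discriminator
# training is to pair the confidence map with an all-ones target for real
# samples (ground truth) and an all-zeros target for fake samples (predictions),
# suitable for a BCE-style 'd_criterion'. The sketch below is an assumption,
# not the repo's actual implementation; the function name is hypothetical.
import torch

def _sketch_preprocess_fcd_criterion(confidence_map, gt, is_real):
    # label the whole map as 1 (real / ground truth) or 0 (fake / prediction);
    # a real task hook might additionally mask ignored pixels using 'gt'
    if is_real:
        target = torch.ones_like(confidence_map)
    else:
        target = torch.zeros_like(confidence_map)
    return confidence_map, target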
def _train(self, data_loader, epoch):
    self.meters.reset()
    original_lbs = int(self.args.labeled_batch_size / 2)
    original_bs = int(self.args.batch_size / 2)

    self.model.train()

    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        # the rotated samples are generated in the 'self._batch_prehandle' function
        # both 'inp' and 'gt' are tuples
        # the last element in the tuple 'gt' is the ground truth of the rotation angle
        inp, gt = self._batch_prehandle(inp, gt, True)
        if len(gt) - 1 > 1 and idx == 0:
            self._inp_warn()

        self.optimizer.zero_grad()

        # forward the model
        resulter, debugger = self.model.forward(inp)
        pred = tool.dict_value(resulter, 'pred')
        activated_pred = tool.dict_value(resulter, 'activated_pred')
        pred_rotation = tool.dict_value(resulter, 'rotation')

        # calculate the supervised task constraint on the un-rotated labeled data
        l_pred = func.split_tensor_tuple(pred, 0, original_lbs)
        l_gt = func.split_tensor_tuple(gt, 0, original_lbs)
        l_inp = func.split_tensor_tuple(inp, 0, original_lbs)

        unrotated_task_loss = self.criterion.forward(l_pred, l_gt[:-1], l_inp)
        unrotated_task_loss = torch.mean(unrotated_task_loss)
        self.meters.update('unrotated_task_loss', unrotated_task_loss.data)

        # calculate the supervised task constraint on the rotated labeled data
        l_rotated_pred = func.split_tensor_tuple(pred, original_bs, original_bs + original_lbs)
        l_rotated_gt = func.split_tensor_tuple(gt, original_bs, original_bs + original_lbs)
        l_rotated_inp = func.split_tensor_tuple(inp, original_bs, original_bs + original_lbs)

        rotated_task_loss = self.criterion.forward(l_rotated_pred, l_rotated_gt[:-1], l_rotated_inp)
        rotated_task_loss = self.args.rotated_sup_scale * torch.mean(rotated_task_loss)
        self.meters.update('rotated_task_loss', rotated_task_loss.data)

        task_loss = unrotated_task_loss + rotated_task_loss

        # calculate the self-supervised rotation constraint
        rotation_loss = self.rotation_criterion.forward(pred_rotation, gt[-1])
        rotation_loss = self.args.rotation_scale * torch.mean(rotation_loss)
        self.meters.update('rotation_loss', rotation_loss.data)

        # backward and update the model
        loss = task_loss + rotation_loss
        loss.backward()
        self.optimizer.step()

        # calculate the accuracy of the rotation classifier
        _, angle_idx = pred_rotation.topk(1, 1, True, True)
        angle_idx = angle_idx.t()
        rotation_acc = angle_idx.eq(gt[-1].view(1, -1).expand_as(angle_idx))
        rotation_acc = rotation_acc.view(-1).float().sum(0, keepdim=True).mul_(100.0 / self.args.batch_size)
        self.meters.update('rotation_acc', rotation_acc.data[0])

        # logging
        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                '  task-{3}\t=>\t'
                'unrotated-task-loss: {meters[unrotated_task_loss]:.6f}\t'
                'rotated-task-loss: {meters[rotated_task_loss]:.6f}\n'
                '  rotation-{3}\t=>\t'
                'rotation-loss: {meters[rotation_loss]:.6f}\t'
                'rotation-acc: {meters[rotation_acc]:.6f}\n'.format(
                    epoch + 1, idx, len(data_loader), self.args.task,
                    meters=self.meters))

        # visualization
        if self.args.visualize and idx % self.args.visual_freq == 0:
            self._visualize(
                epoch, idx, True,
                func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(gt[:-1], 0, 1, reduce_dim=True))

        # update iteration-based lrers
        if not self.args.is_epoch_lrer:
            self.lrer.step()

    # update epoch-based lrers
    if self.args.is_epoch_lrer:
        self.lrer.step()
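
# ------------------------------------------------------------------------------
# NOTE: the rotated samples consumed above are produced inside
# 'self._batch_prehandle', which is not shown in this section. The sketch below
# illustrates one way to generate them: rotate each image by a random multiple
# of 90 degrees and use the rotation index (0-3) as the self-supervised label
# (the role of 'gt[-1]' above). The function name is a hypothetical assumption.
import torch

def _sketch_rotate_batch(images):
    # images: (N, C, H, W) tensor; draw one rotation class per sample
    angles = torch.randint(0, 4, (images.size(0),))
    rotated = torch.stack([
        torch.rot90(img, int(k), dims=(1, 2)) for img, k in zip(images, angles)
    ])
    return rotated, angles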
def _train(self, data_loader, epoch):
    self.meters.reset()
    lbs = self.args.labeled_batch_size

    self.s_model.train()
    self.t_model.train()

    for idx, (inp, gt) in enumerate(data_loader):
        timer = time.time()

        # 's_inp', 't_inp' and 'gt' are tuples
        s_inp, t_inp, gt = self._batch_prehandle(inp, gt, True)
        if len(gt) > 1 and idx == 0:
            self._inp_warn()

        # calculate the ramp-up coefficient of the consistency constraint
        cur_step = len(data_loader) * epoch + idx
        total_steps = len(data_loader) * self.args.cons_rampup_epochs
        cons_rampup_scale = func.sigmoid_rampup(cur_step, total_steps)

        self.s_optimizer.zero_grad()

        # forward the student model
        s_resulter, s_debugger = self.s_model.forward(s_inp)
        if 'pred' not in s_resulter.keys() or 'activated_pred' not in s_resulter.keys():
            self._pred_err()

        s_pred = tool.dict_value(s_resulter, 'pred')
        s_activated_pred = tool.dict_value(s_resulter, 'activated_pred')

        # calculate the supervised task constraint on the labeled data
        l_s_pred = func.split_tensor_tuple(s_pred, 0, lbs)
        l_gt = func.split_tensor_tuple(gt, 0, lbs)
        l_s_inp = func.split_tensor_tuple(s_inp, 0, lbs)

        # 'task_loss' is a tensor of 1-dim & n elements, where n == batch_size
        s_task_loss = self.s_criterion.forward(l_s_pred, l_gt, l_s_inp)
        s_task_loss = torch.mean(s_task_loss)
        self.meters.update('s_task_loss', s_task_loss.data)

        # forward the teacher model
        with torch.no_grad():
            t_resulter, t_debugger = self.t_model.forward(t_inp)
            if 'pred' not in t_resulter.keys():
                self._pred_err()

            t_pred = tool.dict_value(t_resulter, 'pred')
            t_activated_pred = tool.dict_value(t_resulter, 'activated_pred')

            # calculate 't_task_loss' for recording
            l_t_pred = func.split_tensor_tuple(t_pred, 0, lbs)
            l_t_inp = func.split_tensor_tuple(t_inp, 0, lbs)
            t_task_loss = self.s_criterion.forward(l_t_pred, l_gt, l_t_inp)
            t_task_loss = torch.mean(t_task_loss)
            self.meters.update('t_task_loss', t_task_loss.data)

        # calculate the consistency constraint from the teacher model to the student model
        t_pseudo_gt = Variable(t_pred[0].detach().data, requires_grad=False)

        if self.args.cons_for_labeled:
            cons_loss = self.cons_criterion(s_pred[0], t_pseudo_gt)
        elif self.args.unlabeled_batch_size > 0:
            cons_loss = self.cons_criterion(s_pred[0][lbs:, ...], t_pseudo_gt[lbs:, ...])
        else:
            cons_loss = self.zero_tensor
        cons_loss = cons_rampup_scale * self.args.cons_scale * torch.mean(cons_loss)
        self.meters.update('cons_loss', cons_loss.data)

        # backward and update the student model
        loss = s_task_loss + cons_loss
        loss.backward()
        self.s_optimizer.step()

        # update the teacher model by EMA
        self._update_ema_variables(self.s_model, self.t_model, self.args.ema_decay, cur_step)

        # logging
        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info(
                'step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                '  student-{3}\t=>\t'
                's-task-loss: {meters[s_task_loss]:.6f}\t'
                's-cons-loss: {meters[cons_loss]:.6f}\n'
                '  teacher-{3}\t=>\t'
                't-task-loss: {meters[t_task_loss]:.6f}\n'.format(
                    epoch + 1, idx, len(data_loader), self.args.task,
                    meters=self.meters))

        # visualization
        if self.args.visualize and idx % self.args.visual_freq == 0:
            self._visualize(
                epoch, idx, True,
                func.split_tensor_tuple(s_inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(s_activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(t_inp, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(t_activated_pred, 0, 1, reduce_dim=True),
                func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

        # update iteration-based lrers
        if not self.args.is_epoch_lrer:
            self.s_lrer.step()

    # update epoch-based lrers
    if self.args.is_epoch_lrer:
        self.s_lrer.step()
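
# ------------------------------------------------------------------------------
# NOTE: this trainer relies on two helpers whose bodies are not shown here:
# 'func.sigmoid_rampup' (used for 'cons_rampup_scale') and
# '_update_ema_variables'. The sketches below follow the standard Mean Teacher
# formulation (Tarvainen & Valpola, 2017); the repo's actual implementations
# may differ, and these names are placeholders.
import math
import torch

def _sketch_sigmoid_rampup(cur_step, total_steps):
    # exp(-5 * (1 - t)^2) rises smoothly from ~0 to 1 as t goes from 0 to 1
    if total_steps == 0:
        return 1.0
    t = min(max(float(cur_step) / total_steps, 0.0), 1.0)
    return math.exp(-5.0 * (1.0 - t) ** 2)

def _sketch_update_ema_variables(s_model, t_model, ema_decay, cur_step):
    # warm up the decay so the teacher tracks the student closely at first
    alpha = min(1.0 - 1.0 / (cur_step + 1), ema_decay)
    with torch.no_grad():
        for t_param, s_param in zip(t_model.parameters(), s_model.parameters()):
            t_param.mul_(alpha).add_(s_param, alpha=1.0 - alpha)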