# Common imports assumed by the training loops below. Repo-local helpers
# (LOG, cfg, ramps, to_cuda_if_available, AverageMeterSet, LDSLoss,
# Custom_BCE_Loss_difficulty, update_ema_variables) are defined elsewhere
# in the codebase.
import time

import numpy as np
import torch.nn as nn


def train(train_loader, model, optimizer, epoch, weak_mask=None, strong_mask=None):
    """One epoch of supervised training with weak (clip-level) and strong
    (frame-level) BCE losses."""
    class_criterion = nn.BCELoss()
    [class_criterion] = to_cuda_if_available([class_criterion])

    meters = AverageMeterSet()
    meters.update('lr', optimizer.param_groups[0]['lr'])
    LOG.debug("Nb batches: {}".format(len(train_loader)))
    start = time.time()
    for i, (batch_input, target) in enumerate(train_loader):
        [batch_input, target] = to_cuda_if_available([batch_input, target])
        LOG.debug(batch_input.mean())
        strong_pred, weak_pred = model(batch_input)

        loss = 0
        if weak_mask is not None:
            # Weak BCE loss.
            # Max over the time axis: a trick to avoid using unlabeled data.
            # TODO: figure out another way.
            target_weak = target.max(-2)[0]
            weak_class_loss = class_criterion(weak_pred[weak_mask], target_weak[weak_mask])
            if i == 1:
                LOG.debug("target: {}".format(target.mean(-2)))
                LOG.debug("Target_weak: {}".format(target_weak))
                LOG.debug(weak_class_loss)
            meters.update('Weak loss', weak_class_loss.item())
            loss += weak_class_loss

        if strong_mask is not None:
            # Strong BCE loss on the frame-level predictions.
            strong_class_loss = class_criterion(strong_pred[strong_mask], target[strong_mask])
            meters.update('Strong loss', strong_class_loss.item())
            loss += strong_class_loss

        assert not (np.isnan(loss.item()) or loss.item() > 1e5), \
            'Loss explosion: {}'.format(loss.item())
        assert loss.item() >= 0, 'Loss problem, cannot be negative'
        meters.update('Loss', loss.item())

        # Compute gradients and take an optimizer step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    epoch_time = time.time() - start
    LOG.info('Epoch: {}\t'
             'Time {:.2f}\t'
             '{meters}'.format(epoch, epoch_time, meters=meters))
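# `to_cuda_if_available` and `AverageMeterSet` above are repo-local helpers
# that are not shown here. The sketch below illustrates the behavior this
# code assumes of them (inferred from usage; not the actual implementation):
import torch

def to_cuda_if_available(tensors_or_modules):
    """Move a list of tensors/modules to the GPU when one is available."""
    if torch.cuda.is_available():
        return [t.cuda() for t in tensors_or_modules]
    return tensors_or_modules

class AverageMeter:
    """Running average of a scalar metric."""
    def __init__(self):
        self.sum, self.count = 0.0, 0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n

    @property
    def avg(self):
        return self.sum / max(self.count, 1)

class AverageMeterSet(dict):
    """Dict of named AverageMeters, created lazily on first update."""
    def update(self, name, value, n=1):
        self.setdefault(name, AverageMeter()).update(value, n)

    def reset(self):
        self.clear()

    def __format__(self, format_spec):
        return ' '.join('{}: {:.4f}'.format(k, m.avg) for k, m in self.items())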
def train(self, get_loss, get_acc, model_file, pretrain_file):
    """Train with UDA (or MixMatch) semi-supervised objectives."""
    ssl_mode = self.cfg.uda_mode or self.cfg.mixmatch_mode

    # tensorboardX logging
    # NOTE: assumes cfg.results_dir is set; `writer` is used unconditionally below.
    if self.cfg.results_dir:
        dir = os.path.join('results', self.cfg.results_dir)
        if os.path.exists(dir) and os.path.isdir(dir):
            shutil.rmtree(dir)
        writer = SummaryWriter(log_dir=dir)

    meters = AverageMeterSet()
    self.model.train()

    if self.cfg.model == "custom":
        # Between model_file and pretrain_file, only one model will be loaded.
        self.load(model_file, pretrain_file)

    model = self.model.to(self.device)
    ema_model = self.ema_model.to(self.device) if self.ema_model else None
    if self.cfg.data_parallel:  # parallel GPU mode
        model = nn.DataParallel(model)
        ema_model = nn.DataParallel(ema_model) if ema_model else None

    global_step = 0
    max_acc = [0., 0, 0., 0.]  # acc, step, val_loss, train_loss
    no_improvement = 0
    unsup_batch_size = None

    # The progress bar is driven by the unsup or sup data:
    # ssl_mode == True  --> sup_iter is repeated (cycled with next())
    # ssl_mode == False --> sup_iter is not repeated
    iter_bar = tqdm(self.unsup_iter, total=self.cfg.total_steps, disable=self.cfg.hide_tqdm) if ssl_mode \
        else tqdm(self.sup_iter, total=self.cfg.total_steps, disable=self.cfg.hide_tqdm)

    start = time.time()
    for i, batch in enumerate(iter_bar):
        # Device assignment
        if ssl_mode:
            sup_batch = [t.to(self.device) for t in next(self.sup_iter)]
            unsup_batch = [t.to(self.device) for t in batch]
            # Skip ragged final batches so unsup tensor shapes stay constant.
            unsup_batch_size = unsup_batch_size or unsup_batch[0].shape[0]
            if unsup_batch[0].shape[0] != unsup_batch_size:
                continue
        else:
            sup_batch = [t.to(self.device) for t in batch]
            unsup_batch = None

        # Update
        self.optimizer.zero_grad()
        final_loss, sup_loss, unsup_loss, weighted_unsup_loss = get_loss(
            model, sup_batch, unsup_batch, global_step)
        if self.cfg.no_sup_loss:
            final_loss = unsup_loss
        elif self.cfg.no_unsup_loss:
            final_loss = sup_loss

        meters.update('train_loss', final_loss.item())
        meters.update('sup_loss', sup_loss.item())
        meters.update('unsup_loss', unsup_loss.item())
        meters.update('w_unsup_loss', weighted_unsup_loss.item())
        meters.update('lr', self.optimizer.get_lr()[0])

        final_loss.backward()
        self.optimizer.step()
        if self.ema_optimizer:
            self.ema_optimizer.step()

        # Progress reporting
        global_step += 1
        if not self.cfg.hide_tqdm:
            if ssl_mode:
                iter_bar.set_description('final=%5.3f unsup=%5.3f sup=%5.3f'
                                         % (final_loss.item(), unsup_loss.item(), sup_loss.item()))
            else:
                iter_bar.set_description('loss=%5.3f' % final_loss.item())

        if global_step % self.cfg.save_steps == 0:
            self.save(global_step)

        if get_acc and global_step % self.cfg.check_steps == 0 and global_step > self.cfg.check_after:
            if self.cfg.mixmatch_mode:
                # NOTE: this branch does not set total_accuracy/avg_val_loss,
                # which the logging below relies on.
                results = self.eval(get_acc, None, ema_model)
            else:
                total_accuracy, avg_val_loss = self.validate()

            # logging
            writer.add_scalars('data/eval_acc', {'eval_acc': total_accuracy}, global_step)
            writer.add_scalars('data/eval_loss', {'eval_loss': avg_val_loss}, global_step)
            writer.add_scalars('data/train_loss', {'train_loss': meters['train_loss'].avg}, global_step)
            if not self.cfg.no_unsup_loss:
                writer.add_scalars('data/sup_loss', {'sup_loss': meters['sup_loss'].avg}, global_step)
                writer.add_scalars('data/unsup_loss', {'unsup_loss': meters['unsup_loss'].avg}, global_step)
                writer.add_scalars('data/w_unsup_loss', {'w_unsup_loss': meters['w_unsup_loss'].avg}, global_step)
            writer.add_scalars('data/lr', {'lr': meters['lr'].avg}, global_step)
            meters.reset()

            if max_acc[0] < total_accuracy:
                self.save(global_step)
                max_acc = total_accuracy, global_step, avg_val_loss, final_loss.item()
                no_improvement = 0
            else:
                no_improvement += 1

            print(" Top 1 Accuracy: {0:.4f}".format(total_accuracy))
            print(" Validation Loss: {0:.4f}".format(avg_val_loss))
            print(" Train Loss: {0:.4f}".format(final_loss.item()))
            if ssl_mode:
                print(" Sup Loss: {0:.4f}".format(sup_loss.item()))
                print(" Unsup Loss: {0:.4f}".format(unsup_loss.item()))
            print(" Learning rate: {0:.7f}".format(self.optimizer.get_lr()[0]))
            print('Max Accuracy : %5.3f Best Val Loss : %5.3f Best Train Loss : %5.4f '
                  'Max global_steps : %d Cur global_steps : %d'
                  % (max_acc[0], max_acc[2], max_acc[3], max_acc[1], global_step), end='\n\n')

            if no_improvement == self.cfg.early_stopping:
                print("Early stopped")
                print('Total Training Time: %d' % (time.time() - start))
                break

        if self.cfg.total_steps and self.cfg.total_steps < global_step:
            print('The total steps have been reached')
            print('Total Training Time: %d' % (time.time() - start))
            if get_acc:
                if self.cfg.mixmatch_mode:
                    results = self.eval(get_acc, None, ema_model)
                else:
                    total_accuracy, avg_val_loss = self.validate()
                if max_acc[0] < total_accuracy:
                    max_acc = total_accuracy, global_step, avg_val_loss, final_loss.item()
                print(" Top 1 Accuracy: {0:.4f}".format(total_accuracy))
                print(" Validation Loss: {0:.2f}".format(avg_val_loss))
                print(" Train Loss: {0:.2f}".format(final_loss.item()))
                print('Max Accuracy : %5.3f Best Val Loss : %5.3f Best Train Loss : %5.3f '
                      'Max global_steps : %d Cur global_steps : %d'
                      % (max_acc[0], max_acc[2], max_acc[3], max_acc[1], global_step), end='\n\n')
            self.save(global_step)
            break  # was a bare `return`; break instead so the writer below is closed

    writer.close()
    return global_step
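# `get_loss` is a callback injected by the caller. A minimal sketch of the
# signature and return convention the loop above assumes, using a plain
# UDA-style objective: supervised cross-entropy plus a KL consistency term
# between predictions on clean and augmented unlabeled inputs. The real
# objective may add TSA schedules, confidence masking, or sharpening; the
# batch layout and `unsup_weight` here are illustrative assumptions.
import torch
import torch.nn.functional as F

def get_loss_sketch(model, sup_batch, unsup_batch, global_step, unsup_weight=1.0):
    inputs, labels = sup_batch[0], sup_batch[-1]
    sup_loss = F.cross_entropy(model(inputs), labels)

    if unsup_batch is None:
        zero = torch.zeros((), device=sup_loss.device)
        return sup_loss, sup_loss, zero, zero

    orig_input, aug_input = unsup_batch[0], unsup_batch[1]
    with torch.no_grad():
        # The clean branch provides the target distribution; no gradient.
        orig_prob = F.softmax(model(orig_input), dim=-1)
    aug_log_prob = F.log_softmax(model(aug_input), dim=-1)
    # KL(p_clean || p_aug), averaged over the batch.
    unsup_loss = F.kl_div(aug_log_prob, orig_prob, reduction='batchmean')

    weighted_unsup_loss = unsup_weight * unsup_loss
    return sup_loss + weighted_unsup_loss, sup_loss, unsup_loss, weighted_unsup_loss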
def train(cfg, train_loader, model, optimizer, epoch, ema_model=None, weak_mask=None, strong_mask=None):
    """ One epoch of a Mean Teacher model
    :param cfg: config namespace with the training hyper-parameters
    :param train_loader: torch.utils.data.DataLoader, iterator of training batches for an epoch.
        Should return 3 values: student input, teacher input, labels
    :param model: torch.nn.Module, model to be trained, should return a weak and a strong prediction
    :param optimizer: torch.optim.Optimizer, optimizer used to train the model
    :param epoch: int, the current epoch of training
    :param ema_model: torch.nn.Module, teacher model (EMA of the student), should return a weak and a strong prediction
    :param weak_mask: mask of the batch selecting the weakly labeled data (used to compute the loss)
    :param strong_mask: mask of the batch selecting the strongly labeled data (used to compute the loss)
    """
    class_criterion = nn.BCELoss()
    consistency_criterion_strong = nn.MSELoss()
    lds_criterion = LDSLoss(xi=cfg.vat_xi, eps=cfg.vat_eps, n_power_iter=cfg.vat_n_power_iter)
    [class_criterion, consistency_criterion_strong, lds_criterion] = to_cuda_if_available(
        [class_criterion, consistency_criterion_strong, lds_criterion])

    meters = AverageMeterSet()
    LOG.debug("Nb batches: {}".format(len(train_loader)))
    start = time.time()
    rampup_length = len(train_loader) * cfg.n_epoch // 2
    for i, (batch_input, ema_batch_input, target) in enumerate(train_loader):
        global_step = epoch * len(train_loader) + i
        if global_step < rampup_length:
            rampup_value = ramps.sigmoid_rampup(global_step, rampup_length)
        else:
            rampup_value = 1.0

        # TODO: check if this improves performance
        # adjust_learning_rate(optimizer, rampup_value, rampdown_value)
        meters.update('lr', optimizer.param_groups[0]['lr'])

        [batch_input, ema_batch_input, target] = to_cuda_if_available(
            [batch_input, ema_batch_input, target])
        LOG.debug(batch_input.mean())

        # Outputs; the teacher's predictions are detached so no gradient flows into it.
        strong_pred_ema, weak_pred_ema = ema_model(ema_batch_input)
        strong_pred_ema = strong_pred_ema.detach()
        weak_pred_ema = weak_pred_ema.detach()
        strong_pred, weak_pred = model(batch_input)

        loss = None
        # Weak BCE loss.
        # Take the max over axis -2 (assumed to be time).
        if len(target.shape) > 2:
            target_weak = target.max(-2)[0]
        else:
            target_weak = target
        if weak_mask is not None:
            weak_class_loss = class_criterion(weak_pred[weak_mask], target_weak[weak_mask])
            ema_class_loss = class_criterion(weak_pred_ema[weak_mask], target_weak[weak_mask])
            if i == 0:
                LOG.debug("target: {}".format(target.mean(-2)))
                LOG.debug("Target_weak: {}".format(target_weak))
                LOG.debug("Target_weak mask: {}".format(target_weak[weak_mask]))
                LOG.debug(weak_class_loss)
                LOG.debug("rampup_value: {}".format(rampup_value))
            meters.update('weak_class_loss', weak_class_loss.item())
            meters.update('Weak EMA loss', ema_class_loss.item())
            loss = weak_class_loss

        # Strong BCE loss
        if strong_mask is not None:
            strong_class_loss = class_criterion(strong_pred[strong_mask], target[strong_mask])
            meters.update('Strong loss', strong_class_loss.item())
            strong_ema_class_loss = class_criterion(strong_pred_ema[strong_mask], target[strong_mask])
            meters.update('Strong EMA loss', strong_ema_class_loss.item())
            loss = strong_class_loss if loss is None else loss + strong_class_loss

        # Teacher-student consistency cost, ramped up over the first half of training.
        if ema_model is not None:
            consistency_cost = cfg.max_consistency_cost * rampup_value
            meters.update('Consistency weight', consistency_cost)

            # Consistency on the strong predictions (weak and unlabeled data included)
            consistency_loss_strong = consistency_cost * consistency_criterion_strong(
                strong_pred, strong_pred_ema)
            meters.update('Consistency strong', consistency_loss_strong.item())
            loss = consistency_loss_strong if loss is None else loss + consistency_loss_strong

            # Consistency on the weak predictions (weak and unlabeled data included)
            consistency_loss_weak = consistency_cost * consistency_criterion_strong(
                weak_pred, weak_pred_ema)
            meters.update('Consistency weak', consistency_loss_weak.item())
            loss = consistency_loss_weak if loss is None else loss + consistency_loss_weak

        # LDS (virtual adversarial training) loss
        if cfg.vat_enabled:
            lds_loss = cfg.vat_coeff * lds_criterion(model, batch_input, weak_pred)
            # lds_loss already includes cfg.vat_coeff; do not multiply it in again when logging.
            LOG.info('loss: {:.3f}, lds loss: {:.3f}'.format(
                loss, lds_loss.detach().cpu().item()))
            loss += lds_loss
        else:
            if i % 25 == 0:
                LOG.info('loss: {:.3f}'.format(loss))

        assert not (np.isnan(loss.item()) or loss.item() > 1e5), \
            'Loss explosion: {}'.format(loss.item())
        assert loss.item() >= 0, 'Loss problem, cannot be negative'
        meters.update('Loss', loss.item())

        # Compute gradients and take an optimizer step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        global_step += 1
        if ema_model is not None:
            update_ema_variables(model, ema_model, 0.999, global_step)

    epoch_time = time.time() - start
    LOG.info('Epoch: {}\t'
             'Time {:.2f}\t'
             '{meters}'.format(epoch, epoch_time, meters=meters))
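# `update_ema_variables` is the standard Mean Teacher EMA update (Tarvainen &
# Valpola, 2017). A minimal version consistent with how it is called above,
# assuming the teacher mirrors the student's parameter layout:
def update_ema_variables(model, ema_model, alpha, global_step):
    # Ramp the decay from 0 toward `alpha` so the teacher tracks the student
    # closely during the first steps of training.
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)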
def train(train_loader, model, optimizer, epoch, ema_model=None, weak_mask=None, strong_mask=None):
    """ One epoch of a Mean Teacher model with difficulty-weighted BCE losses
    :param train_loader: torch.utils.data.DataLoader, iterator of training batches for an epoch.
        Should return 3 values: student input, teacher input, labels
    :param model: torch.nn.Module, model to be trained, should return a weak and a strong prediction
    :param optimizer: torch.optim.Optimizer, optimizer used to train the model
    :param epoch: int, the current epoch of training
    :param ema_model: torch.nn.Module, teacher model (EMA of the student), should return a weak and a strong prediction
    :param weak_mask: mask of the batch selecting the weakly labeled data (used to compute the loss)
    :param strong_mask: mask of the batch selecting the strongly labeled data (used to compute the loss)

    NOTE: `cfg` is assumed to be a module-level config import (e.g. `import config as cfg`).
    """
    class_criterion = nn.BCELoss()
    # Per-element BCE, consumed by the custom difficulty-weighted losses below.
    class_criterion1 = nn.BCELoss(reduction='none')
    consistency_criterion = nn.MSELoss()
    [class_criterion, class_criterion1, consistency_criterion] = to_cuda_if_available(
        [class_criterion, class_criterion1, consistency_criterion])

    meters = AverageMeterSet()
    LOG.debug("Nb batches: {}".format(len(train_loader)))
    start = time.time()
    rampup_length = len(train_loader) * cfg.n_epoch // 2
    print("Train\n")

    count = 0
    check_cus_weak = 0.
    difficulty_loss_logged = False  # log the loss definition only once
    loss_w = 1
    LOG.info("loss parameter: {}".format(loss_w))

    for i, (batch_input, ema_batch_input, target) in enumerate(train_loader):
        global_step = epoch * len(train_loader) + i
        if global_step < rampup_length:
            rampup_value = ramps.sigmoid_rampup(global_step, rampup_length)
        else:
            rampup_value = 1.0

        meters.update('lr', optimizer.param_groups[0]['lr'])
        [batch_input, ema_batch_input, target] = to_cuda_if_available(
            [batch_input, ema_batch_input, target])
        LOG.debug("batch_input: {}".format(batch_input.mean()))

        # Outputs: both models additionally return frame-level attention weights (sof).
        strong_pred_ema, weak_pred_ema, sof_ema = ema_model(ema_batch_input)
        sof_ema = sof_ema.detach()
        strong_pred_ema = strong_pred_ema.detach()
        weak_pred_ema = weak_pred_ema.detach()
        strong_pred, weak_pred, sof = model(batch_input)

        # Difficulty-weighted BCE losses for the teacher and the student.
        if not difficulty_loss_logged:
            LOG.info("############### Define Difficulty Loss ###############")
            difficulty_loss_logged = True
        custom_ema_loss = Custom_BCE_Loss_difficulty(ema_batch_input, class_criterion1, paramater=loss_w)
        custom_ema_loss.initialize(strong_pred_ema, sof_ema)
        custom_loss = Custom_BCE_Loss_difficulty(batch_input, class_criterion1, paramater=loss_w)
        custom_loss.initialize(strong_pred, sof)

        loss = None
        # Weak BCE loss: take the max over the time axis.
        target_weak = target.max(-2)[0]
        if weak_mask is not None:
            weak_class_loss = class_criterion(weak_pred[weak_mask], target_weak[weak_mask])
            print("normal_weak:", weak_class_loss)

            custom_weak_class_loss = custom_loss.weak(target_weak, weak_mask)
            custom_ema_class_loss = custom_ema_loss.weak(target_weak, weak_mask)
            print("custom_weak:", custom_weak_class_loss)

            count += 1
            # Accumulate a float (not the tensor) to avoid retaining graphs.
            check_cus_weak += custom_weak_class_loss.item()

            if i == 0:
                LOG.debug("target: {}".format(target.mean(-2)))
                LOG.debug("Target_weak: {}".format(target_weak))
                LOG.debug("Target_weak mask: {}".format(target_weak[weak_mask]))
                LOG.debug(custom_weak_class_loss)
                LOG.debug("rampup_value: {}".format(rampup_value))
            meters.update('weak_class_loss', custom_weak_class_loss.item())
            meters.update('Weak EMA loss', custom_ema_class_loss.item())
            loss = custom_weak_class_loss

        # Strong BCE loss
        if strong_mask is not None:
            strong_class_loss = class_criterion(strong_pred[strong_mask], target[strong_mask])
            print("normal_strong:", strong_class_loss)

            custom_strong_class_loss = custom_loss.strong(target, strong_mask)
            meters.update('Strong loss', custom_strong_class_loss.item())
            custom_strong_ema_class_loss = custom_ema_loss.strong(target, strong_mask)
            meters.update('Strong EMA loss', custom_strong_ema_class_loss.item())
            print("custom_strong:", custom_strong_class_loss)

            loss = custom_strong_class_loss if loss is None else loss + custom_strong_class_loss

        # Teacher-student consistency cost
        if ema_model is not None:
            consistency_cost = cfg.max_consistency_cost * rampup_value
            meters.update('Consistency weight', consistency_cost)

            # Consistency on the strong predictions (all data)
            consistency_loss_strong = consistency_cost * consistency_criterion(
                strong_pred, strong_pred_ema)
            meters.update('Consistency strong', consistency_loss_strong.item())
            loss = consistency_loss_strong if loss is None else loss + consistency_loss_strong

            # Consistency on the weak predictions (all data)
            consistency_loss_weak = consistency_cost * consistency_criterion(
                weak_pred, weak_pred_ema)
            meters.update('Consistency weak', consistency_loss_weak.item())
            loss = consistency_loss_weak if loss is None else loss + consistency_loss_weak

        assert not (np.isnan(loss.item()) or loss.item() > 1e5), \
            'Loss explosion: {}'.format(loss.item())
        assert loss.item() >= 0, 'Loss problem, cannot be negative'
        meters.update('Loss', loss.item())

        # Compute gradients and take an optimizer step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        global_step += 1
        if ema_model is not None:
            update_ema_variables(model, ema_model, 0.999, global_step)

    epoch_time = time.time() - start
    LOG.info('Epoch: {}\t'
             'Time {:.2f}\t'
             '{meters}'.format(epoch, epoch_time, meters=meters))
    print("\ncheck_cus_weak:\n", check_cus_weak / count)
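# `ramps.sigmoid_rampup`, used by the Mean Teacher loops above, follows the
# exponential rampup of Laine & Aila (2017); a minimal version consistent
# with how it is called:
import numpy as np

def sigmoid_rampup(current, rampup_length):
    """Exponential rampup from 0 to 1 over `rampup_length` steps."""
    if rampup_length == 0:
        return 1.0
    current = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - current / rampup_length
    return float(np.exp(-5.0 * phase * phase))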