def train(self):
    """Run the single-network training loop.

    For each epoch: step the LR scheduler, train one epoch, evaluate on the
    test set, keep the best-performing weights on disk, log one tab-separated
    row to console and logfile, and write a resumable checkpoint.  After the
    final epoch the logfile is renamed to embed the run configuration and the
    best test accuracy.
    """
    header = 'Epoch\tTrain_Loss\tTrain_Accuracy\tTest_Accuracy\tEpoch_Runtime\tLearning_Rate'
    print_to_console(header)
    print_to_logfile(self._logfile, header, init=True)

    for ep in range(self._start_epoch, self._epochs):
        tic = time.time()
        self._scheduler.step(epoch=ep)

        # start the epoch with fresh meters and the net in train mode
        self._train_loss.reset()
        self._train_accuracy.reset()
        self._net.train(True)
        self.single_epoch_training(ep)

        test_accuracy = evaluate(self._test_loader, self._net)
        lr = get_lr_from_optimizer(self._optimizer)

        # keep the best-performing snapshot on disk
        if test_accuracy > self._best_accuracy:
            self._best_accuracy = test_accuracy
            self._best_epoch = ep + 1
            torch.save(self._net.state_dict(),
                       'model/step{}_best_epoch.pth'.format(self._step))
            # print('*', end='')

        elapsed = time.time() - tic

        # per-epoch report: same row goes to console and logfile
        row = '{:05d}\t{:10.4f}\t{:14.4f}\t{:13.4f}\t{:13.2f}\t{:13.1e}'.format(
            ep + 1, self._train_loss.avg, self._train_accuracy.avg,
            test_accuracy, elapsed, lr)
        print_to_console(row)
        print_to_logfile(self._logfile, row, init=False)

        # resumable checkpoint every epoch
        save_checkpoint({
            'epoch': ep + 1,
            'state_dict': self._net.state_dict(),
            'best_epoch': self._best_epoch,
            'best_accuracy': self._best_accuracy,
            'optimizer': self._optimizer.state_dict(),
            'step': self._step,
            'scheduler': self._scheduler.state_dict(),
            'memory_pool': self.memory_pool,
        })

    summary = 'Best at epoch {}, test accuracy is {}'.format(self._best_epoch, self._best_accuracy)
    print_to_console(summary)
    # tag the log file name with the run configuration and best accuracy
    os.rename(self._logfile, self._logfile.replace('.txt', '-{}_{}_{}_{:.4f}.txt'.format(
        self._config['net'], self._config['batch_size'], self._config['lr'], self._best_accuracy)))
def single_epoch_training(self, epoch, log_iter=True, log_freq=100):
    """Train ``self._net`` for one epoch over ``self._train_loader``.

    Uses ``std_loss`` to select which samples contribute to the gradient
    (full batch before ``self.T_k``, a clean/reusable subset afterwards)
    and feeds every batch's raw CE stats into ``self.memory_pool``.

    :param epoch:    current epoch index (0-based)
    :param log_iter: if True, print progress every ``log_freq`` iterations
    :param log_freq: iteration interval between progress lines
    """
    if epoch >= self.T_k:
        # From T_k on, std_loss appends per-batch split statistics;
        # create the per-epoch csv file with its header first.
        stats_log_path = 'stats/drop_n_reuse_stats_epoch{:03d}.csv'.format(epoch+1)
        stats_log_header = 'clean_sample_num,reusable_sample_num,irrelevant_sample_num'
        print_to_logfile(stats_log_path, stats_log_header, init=True, end='\n')

    n_batches = len(self._train_loader)
    for it, (x, y, indices) in enumerate(self._train_loader):
        tic = time.time()
        x, y = x.cuda(), y.cuda()

        self._optimizer.zero_grad()
        logits = self._net(x)
        losses, ce_loss = std_loss(logits, y, indices, self.T_k, epoch,
                                   self.memory_pool, eps=self._config['eps'])
        loss = losses.mean()

        # record this batch's raw CE losses / scores / labels for label reuse
        self.memory_pool.update(indices=indices,
                                losses=ce_loss.detach().data.cpu(),
                                scores=F.softmax(logits, dim=1).detach().data.cpu(),
                                labels=y.detach().data.cpu())

        train_accuracy = accuracy(logits, y, topk=(1,))
        self._train_loss.update(loss.item(), x.size(0))
        self._train_accuracy.update(train_accuracy[0], x.size(0))

        loss.backward()
        self._optimizer.step()

        self._epoch_train_time.update(time.time() - tic, 1)
        if (log_iter and (it+1) % log_freq == 0) or (it+1 == n_batches):
            progress = 'Epoch:[{0:03d}/{1:03d}] Iter:[{2:04d}/{3:04d}] ' \
                       'Train Accuracy :[{4:6.2f}] Loss:[{5:4.4f}] ' \
                       'Iter Runtime:[{6:6.2f}]'.format(epoch + 1, self._epochs, it + 1, n_batches,
                                                        self._train_accuracy.avg, self._train_loss.avg,
                                                        self._epoch_train_time.avg)
            print_to_console(progress)
def single_epoch_training(self, epoch, log_iter=True, log_freq=200):
    """Train both co-teaching networks for one epoch over ``self._train_loader``.

    Each batch is forwarded through both nets; ``cot_std_loss`` cross-selects
    training subsets between the peers, and each network's raw CE statistics
    are recorded in its own memory pool for later label reuse.

    :param epoch:    current epoch index (0-based)
    :param log_iter: if True, print progress every ``log_freq`` iterations
    :param log_freq: iteration interval between progress lines
    """
    if epoch >= self.T_k:
        # From T_k on, cot_std_loss appends per-batch split statistics;
        # create the two per-epoch csv files with their headers first.
        stats_log_path1 = 'stats/net1_drop_n_reuse_stats_epoch{:03d}.csv'.format(
            epoch + 1)
        stats_log_path2 = 'stats/net2_drop_n_reuse_stats_epoch{:03d}.csv'.format(
            epoch + 1)
        stats_log_header = 'clean_sample_num,reusable_sample_num,irrelevant_sample_num'
        print_to_logfile(stats_log_path1, stats_log_header, init=True, end='\n')
        print_to_logfile(stats_log_path2, stats_log_header, init=True, end='\n')

    for it, (x, y, indices) in enumerate(self._train_loader):
        s = time.time()
        x = x.cuda()
        y = y.cuda()
        self._optimizer1.zero_grad()
        self._optimizer2.zero_grad()
        logits1 = self._net1(x)
        logits2 = self._net2(x)
        # BUG FIX: memory_pool1 was previously passed for BOTH pool arguments,
        # so net1's reusable labels were also used for net2.
        losses1, ce_loss1, losses2, ce_loss2 = \
            cot_std_loss(logits1, logits2, y, indices, self.T_k, epoch,
                         self.memory_pool1, self.memory_pool2,
                         eps=self._config['eps'])
        loss1 = losses1.mean()
        loss2 = losses2.mean()

        self.memory_pool1.update(indices=indices,
                                 losses=ce_loss1.detach().data.cpu(),
                                 scores=F.softmax(logits1, dim=1).detach().data.cpu(),
                                 labels=y.detach().data.cpu())
        # BUG FIX: net2's statistics were previously pushed into memory_pool1
        # a second time, leaving memory_pool2 never updated.
        self.memory_pool2.update(indices=indices,
                                 losses=ce_loss2.detach().data.cpu(),
                                 scores=F.softmax(logits2, dim=1).detach().data.cpu(),
                                 labels=y.detach().data.cpu())

        train_accuracy1 = accuracy(logits1, y, topk=(1, ))
        train_accuracy2 = accuracy(logits2, y, topk=(1, ))
        self._train_loss1.update(loss1.item(), losses1.size(0))
        # BUG FIX: loss2's meter count previously used losses1.size(0); the two
        # selected subsets can differ in length after T_k.
        self._train_loss2.update(loss2.item(), losses2.size(0))
        self._train_accuracy1.update(train_accuracy1[0], x.size(0))
        self._train_accuracy2.update(train_accuracy2[0], x.size(0))

        loss1.backward()
        loss2.backward()
        self._optimizer1.step()
        self._optimizer2.step()

        e = time.time()
        self._epoch_train_time.update(e - s, 1)
        if (log_iter and (it + 1) % log_freq == 0) or (it + 1 == len(
                self._train_loader)):
            console_content = 'Epoch:[{:03d}/{:03d}] Iter:[{:04d}/{:04d}] ' \
                              'Train Accuracy1 :[{:6.2f}] Train Accuracy2 :[{:6.2f}] ' \
                              'Loss1:[{:4.4f}] Loss2:[{:4.4f}] ' \
                              'Iter Runtime:[{:6.2f}]'.format(epoch + 1, self._epochs, it + 1,
                                                              len(self._train_loader),
                                                              self._train_accuracy1.avg,
                                                              self._train_accuracy2.avg,
                                                              self._train_loss1.avg,
                                                              self._train_loss2.avg,
                                                              self._epoch_train_time.avg)
            print_to_console(console_content)
def train(self):
    """Run the co-teaching training loop for the two-network setup.

    Per epoch: step both LR schedulers, co-train for one epoch, evaluate each
    network on the test set, snapshot whichever net improved its best accuracy,
    log a tab-separated row to console and logfile, and write a resumable
    checkpoint covering both networks.  The logfile is finally renamed to
    embed the run configuration and both best accuracies.
    """
    header = 'Epoch\tTrain_Loss1\tTrain_Loss2\tTrain_Accuracy1\tTrain_Accuracy2\t' \
             'Test_Accuracy1\tTest_Accuracy2\tEpoch_Runtime\tLearning_Rate1\tLearning_Rate2'
    print_to_console(header)
    print_to_logfile(self._logfile, header, init=True)

    for ep in range(self._start_epoch, self._epochs):
        tic = time.time()
        self._scheduler1.step(epoch=ep)
        self._scheduler2.step(epoch=ep)

        # fresh meters, both nets in train mode
        self._train_loss1.reset()
        self._train_loss2.reset()
        self._train_accuracy1.reset()
        self._train_accuracy2.reset()
        self._net1.train(True)
        self._net2.train(True)
        self.single_epoch_training(ep)

        test_accuracy1 = evaluate(self._test_loader, self._net1)
        test_accuracy2 = evaluate(self._test_loader, self._net2)
        lr1 = get_lr_from_optimizer(self._optimizer1)
        lr2 = get_lr_from_optimizer(self._optimizer2)

        # each network keeps its own best snapshot on disk
        if test_accuracy1 > self._best_accuracy1:
            self._best_accuracy1 = test_accuracy1
            self._best_epoch1 = ep + 1
            torch.save(self._net1.state_dict(),
                       'model/net1_step{}_best_epoch.pth'.format(self._step))
        if test_accuracy2 > self._best_accuracy2:
            self._best_accuracy2 = test_accuracy2
            self._best_epoch2 = ep + 1
            torch.save(self._net2.state_dict(),
                       'model/net2_step{}_best_epoch.pth'.format(self._step))

        elapsed = time.time() - tic

        # per-epoch report: same row goes to console and logfile
        row = '{:05d}\t{:10.4f}\t{:10.4f}\t{:14.4f}\t{:14.4f}\t' \
              '{:13.4f}\t{:13.4f}\t{:13.2f}\t' \
              '{:13.1e}\t{:13.1e}'.format(ep + 1, self._train_loss1.avg, self._train_loss2.avg,
                                          self._train_accuracy1.avg, self._train_accuracy2.avg,
                                          test_accuracy1, test_accuracy2, elapsed, lr1, lr2)
        print_to_console(row)
        print_to_logfile(self._logfile, row, init=False)

        # resumable checkpoint covering both nets, pools and schedulers
        save_checkpoint({
            'epoch': ep + 1,
            'state_dict1': self._net1.state_dict(),
            'state_dict2': self._net2.state_dict(),
            'best_epoch1': self._best_epoch1,
            'best_epoch2': self._best_epoch2,
            'best_accuracy1': self._best_accuracy1,
            'best_accuracy2': self._best_accuracy2,
            'optimizer1': self._optimizer1.state_dict(),
            'optimizer2': self._optimizer2.state_dict(),
            'step': self._step,
            'scheduler1': self._scheduler1.state_dict(),
            'scheduler2': self._scheduler2.state_dict(),
            'memory_pool1': self.memory_pool1,
            'memory_pool2': self.memory_pool2,
        })

    print_to_console('Net1: Best at epoch {}, test accuracy is {}'.format(
        self._best_epoch1, self._best_accuracy1))
    print_to_console('Net2: Best at epoch {}, test accuracy is {}'.format(
        self._best_epoch2, self._best_accuracy2))
    # tag the log file name with the run configuration and best accuracies
    os.rename(self._logfile,
              self._logfile.replace('.txt', '-{}_{}_{}_{:.4f}_{:.4f}.txt'.format(
                  self._config['net'], self._config['batch_size'], self._config['lr'],
                  self._best_accuracy1, self._best_accuracy2)))
def cot_std_loss(logits1, logits2, labels, indices, T_k, epoch,
                 memory_pool1, memory_pool2, eps=0.1):
    """Co-teaching variant of ``std_loss``.

    Each network is trained on the small-loss ("clean") samples selected by
    its PEER, plus "reusable" samples mined from the peer's high-loss set and
    relabeled with the network's own memory-pool labels.

    :param logits1: (N, C) outputs of network 1 on the batch
    :param logits2: (N, C) outputs of network 2 on the batch
    :param labels:  (N,) noisy training labels
    :param indices: (N,) dataset indices of the batch samples
    :param T_k:     warm-up length; before it, plain CE on the whole batch
    :param epoch:   current epoch index (0-based)
    :param memory_pool1: pool providing ``most_prob_labels`` for net 1
    :param memory_pool2: pool providing ``most_prob_labels`` for net 2
    :param eps:     label-smoothing factor
    :returns: (losses1, ce_losses1, losses2, ce_losses2) — per-sample losses
              on each network's kept subset, and the raw smoothed CE losses
              over the full batch for each network.
    """
    ce_losses1 = label_smoothing_cross_entropy(logits1, labels, epsilon=eps, reduction='none')  # (N,)
    ce_losses2 = label_smoothing_cross_entropy(logits2, labels, epsilon=eps, reduction='none')  # (N,)

    # in the first T_k epochs, train with the entire training set
    if epoch < T_k:
        return ce_losses1, ce_losses1, ce_losses2, ce_losses2

    # after T_k epochs, divide the batch into clean / uncertain / irrelevant
    ind_loss_sorted1 = torch.argsort(ce_losses1.data)
    ind_loss_sorted2 = torch.argsort(ce_losses2.data)
    # "clean" = per-sample loss below the batch mean, judged per network
    num_remember1 = torch.nonzero(ce_losses1 < ce_losses1.mean()).shape[0]
    num_remember2 = torch.nonzero(ce_losses2 < ce_losses2.mean()).shape[0]

    stats_log_path1 = 'stats/net1_drop_n_reuse_stats_epoch{:03d}.csv'.format(
        epoch + 1)
    stats_log_path2 = 'stats/net2_drop_n_reuse_stats_epoch{:03d}.csv'.format(
        epoch + 1)
    print_to_logfile(stats_log_path1, '{:03d}'.format(num_remember1),
                     init=False, end=',')
    print_to_logfile(stats_log_path2, '{:03d}'.format(num_remember2),
                     init=False, end=',')

    ind_clean1 = ind_loss_sorted1[:num_remember1]
    ind_clean2 = ind_loss_sorted2[:num_remember2]
    ind_forget1 = ind_loss_sorted1[num_remember1:]
    ind_forget2 = ind_loss_sorted2[num_remember2:]

    # co-teaching cross-selection: each net trains on the clean set its PEER chose
    logits_clean1 = logits1[ind_clean2]
    logits_clean2 = logits2[ind_clean1]
    labels_clean1 = labels[ind_clean2]
    labels_clean2 = labels[ind_clean1]
    logits_final1 = logits_clean1
    logits_final2 = logits_clean2
    labels_final1 = labels_clean1
    labels_final2 = labels_clean2

    # For samples with high loss:
    #   high loss, high prediction std --> mislabeled (reusable)
    #   high loss, low prediction std  --> irrelevant category (dropped)
    if ind_forget1.shape[0] > 1:
        logits_forget1 = logits1[ind_forget1]
        pred_distribution1 = F.softmax(logits_forget1, dim=1)
        batch_std1 = pred_distribution1.std(dim=1)
        flag1 = F.softmax(logits_clean1, dim=1).std(dim=1).mean().item()
        batch_std_sorted1, ind_std_sorted1 = torch.sort(batch_std1.data, descending=True)
        ind_split1 = split_set(batch_std_sorted1, flag1)
        if ind_split1 is None:
            ind_split1 = -1
        # uncertain could be either mislabeled or a hard example
        ind_uncertain1 = ind_std_sorted1[:(ind_split1 + 1)]
        # NOTE(review): init/end added for consistency with std_loss; confirm
        # print_to_logfile defaults if csv rows looked different before.
        print_to_logfile(stats_log_path1,
                         '{:03d},{:03d}'.format(
                             ind_split1 + 1,
                             logits1.shape[0] - num_remember1 - ind_split1 - 1),
                         init=False, end='\n')
        ind_mislabeled1 = ind_forget1[ind_uncertain1]
        # reuse for net 2: these samples get pool-2's most probable labels
        logits_mislabeled2 = logits2[ind_mislabeled1]
        indices_mislabeled2 = indices[ind_mislabeled1]
        labels_mislabeled2 = memory_pool2.most_prob_labels[
            indices_mislabeled2].to(logits_mislabeled2.device)
        logits_final2 = torch.cat((logits_final2, logits_mislabeled2), dim=0)
        labels_final2 = torch.cat((labels_final2, labels_mislabeled2), dim=0)
    else:
        # BUG FIX: this zero-row was previously written only in the else of the
        # SECOND if below, so net1's stats row was missing (forget1 empty,
        # forget2 non-empty) or duplicated (forget1 non-empty, forget2 empty).
        print_to_logfile(stats_log_path1,
                         '{:03d},{:03d}'.format(0, logits1.shape[0] - num_remember1),
                         init=False, end='\n')

    if ind_forget2.shape[0] > 1:
        logits_forget2 = logits2[ind_forget2]
        pred_distribution2 = F.softmax(logits_forget2, dim=1)
        batch_std2 = pred_distribution2.std(dim=1)
        flag2 = F.softmax(logits_clean2, dim=1).std(dim=1).mean().item()
        batch_std_sorted2, ind_std_sorted2 = torch.sort(batch_std2.data, descending=True)
        ind_split2 = split_set(batch_std_sorted2, flag2)
        if ind_split2 is None:
            ind_split2 = -1
        # uncertain could be either mislabeled or a hard example
        ind_uncertain2 = ind_std_sorted2[:(ind_split2 + 1)]
        print_to_logfile(stats_log_path2,
                         '{:03d},{:03d}'.format(
                             ind_split2 + 1,
                             logits2.shape[0] - num_remember2 - ind_split2 - 1),
                         init=False, end='\n')
        ind_mislabeled2 = ind_forget2[ind_uncertain2]
        # reuse for net 1: these samples get pool-1's most probable labels
        logits_mislabeled1 = logits1[ind_mislabeled2]
        indices_mislabeled1 = indices[ind_mislabeled2]
        labels_mislabeled1 = memory_pool1.most_prob_labels[
            indices_mislabeled1].to(logits_mislabeled1.device)
        logits_final1 = torch.cat((logits_final1, logits_mislabeled1), dim=0)
        labels_final1 = torch.cat((labels_final1, labels_mislabeled1), dim=0)
    else:
        print_to_logfile(stats_log_path2,
                         '{:03d},{:03d}'.format(0, logits2.shape[0] - num_remember2),
                         init=False, end='\n')

    losses1 = label_smoothing_cross_entropy(logits_final1, labels_final1,
                                            epsilon=eps, reduction='none')
    losses2 = label_smoothing_cross_entropy(logits_final2, labels_final2,
                                            epsilon=eps, reduction='none')
    return losses1, ce_losses1, losses2, ce_losses2
def std_loss(logits, labels, indices, T_k, epoch, memory_pool, eps=0.1):
    """Drop-and-reuse loss for a single network.

    Before epoch ``T_k`` every sample contributes plain label-smoothed CE.
    From ``T_k`` on, the batch is split into clean (loss below the batch
    mean), reusable (high loss but high prediction std — relabeled from the
    memory pool), and irrelevant (high loss, low std — dropped) samples; the
    split counts are appended to a per-epoch csv.

    :param logits:  (N, C) network outputs
    :param labels:  (N,) noisy training labels
    :param indices: (N,) dataset indices of the batch samples
    :param T_k:     warm-up epoch count
    :param epoch:   current epoch index (0-based)
    :param memory_pool: pool providing ``most_prob_labels`` per sample
    :param eps:     label-smoothing factor
    :returns: (std_losses, ce_losses) — per-sample losses on the kept subset
              and raw smoothed CE over the full batch.
    """
    ce_losses = label_smoothing_cross_entropy(logits, labels, epsilon=eps, reduction='none')

    # warm-up: train on everything
    if epoch < T_k:
        return ce_losses, ce_losses

    # clean = per-sample loss below the batch mean
    order = torch.argsort(ce_losses.data)
    n_clean = torch.nonzero(ce_losses < ce_losses.mean()).shape[0]

    stats_log_path = 'stats/drop_n_reuse_stats_epoch{:03d}.csv'.format(epoch + 1)
    print_to_logfile(stats_log_path, '{:03d}'.format(n_clean), init=False, end=',')

    keep_idx = order[:n_clean]
    drop_idx = order[n_clean:]
    logits_clean = logits[keep_idx]
    labels_clean = labels[keep_idx]

    if drop_idx.shape[0] > 1:
        # High-loss samples: high prediction std --> likely mislabeled
        # (reusable); low std --> likely irrelevant category (dropped).
        drop_indices = indices[drop_idx]
        drop_logits = logits[drop_idx]
        drop_std = F.softmax(drop_logits, dim=1).std(dim=1)
        # threshold: mean prediction std over the clean subset
        threshold = F.softmax(logits_clean, dim=1).std(dim=1).mean().item()
        std_sorted, std_order = torch.sort(drop_std.data, descending=True)
        split_at = split_set(std_sorted, threshold)
        if split_at is None:
            split_at = -1
        # uncertain could be either mislabeled or a hard example
        uncertain = std_order[:(split_at + 1)]
        print_to_logfile(stats_log_path,
                         '{:03d},{:03d}'.format(
                             split_at + 1,
                             logits.shape[0] - n_clean - split_at - 1),
                         init=False, end='\n')
        logits_mislabeled = drop_logits[uncertain]
        indices_mislabeled = drop_indices[uncertain]
        # relabel reusable samples with the pool's most probable labels
        labels_mislabeled = memory_pool.most_prob_labels[
            indices_mislabeled].to(logits_mislabeled.device)
        logits_final = torch.cat((logits_clean, logits_mislabeled), dim=0)
        labels_final = torch.cat((labels_clean, labels_mislabeled), dim=0)
    else:
        print_to_logfile(stats_log_path,
                         '{:03d},{:03d}'.format(0, logits.shape[0] - n_clean),
                         init=False, end='\n')
        logits_final = logits_clean
        labels_final = labels_clean

    std_losses = label_smoothing_cross_entropy(logits_final, labels_final,
                                               epsilon=eps, reduction='none')
    return std_losses, ce_losses