def evaluate(dataloader, model, topk=(1,)):
    """Run *model* over *dataloader* and return its average top-k accuracy.

    :param dataloader: iterable yielding (input, target, _) batches
    :param model: network to evaluate (switched to eval mode here)
    :param topk: [tuple] which top-k accuracies to request from `accuracy`
    :return: [list[float]] topk accuracy averaged over all samples
    """
    model.eval()
    meter = AverageMeter()
    meter.reset()
    with torch.no_grad():
        for inputs, targets, _ in dataloader:
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs = model(inputs)
            batch_acc = accuracy(outputs, targets, topk)
            # weight the running average by the batch size
            meter.update(batch_acc[0], inputs.size(0))
    return meter.avg
def main(seed=25):
    """Train DoiNet on DoiDataset, validating and checkpointing every epoch.

    When ``args.evaluation`` is set, the training loop is skipped and only the
    validation pass runs.

    :param seed: random seed forwarded to ``seed_everything`` for reproducibility
    """
    # Fix: forward the `seed` argument instead of the hard-coded 25 —
    # previously calling main(seed=...) silently had no effect.
    seed_everything(seed)
    device = torch.device('cuda:0')
    # arguments
    args = Args().parse()
    n_class = args.n_class
    img_path_train = args.img_path_train
    mask_path_train = args.mask_path_train
    img_path_val = args.img_path_val
    mask_path_val = args.mask_path_val
    model_path = os.path.join(args.model_path, args.task_name)  # save model
    log_path = args.log_path
    output_path = args.output_path
    # exist_ok=True replaces the check-then-create pattern (race-free, shorter)
    os.makedirs(model_path, exist_ok=True)
    os.makedirs(log_path, exist_ok=True)
    os.makedirs(output_path, exist_ok=True)
    task_name = args.task_name
    print(task_name)
    ###################################
    evaluation = args.evaluation
    test = evaluation and False
    print("evaluation:", evaluation, "test:", test)
    ###################################
    print("preparing datasets and dataloaders......")
    batch_size = args.batch_size
    num_workers = args.num_workers
    config = args.config
    data_time = AverageMeter("DataTime", ':3.3f')
    batch_time = AverageMeter("BatchTime", ':3.3f')
    dataset_train = DoiDataset(img_path_train, config, train=True,
                               root_mask=mask_path_train)
    dataloader_train = DataLoader(dataset_train, batch_size=batch_size,
                                  shuffle=True, num_workers=num_workers)
    dataset_val = DoiDataset(img_path_val, config, train=True,
                             root_mask=mask_path_val)
    dataloader_val = DataLoader(dataset_val, batch_size=batch_size,
                                shuffle=False, num_workers=num_workers)
    ###################################
    print("creating models......")
    model = DoiNet(n_class, config['min_descriptor'] + 6, 4)
    model = create_model_load_weights(model, evaluation=False,
                                      ckpt_path=args.ckpt_path)
    model.to(device)
    ###################################
    num_epochs = args.epochs
    learning_rate = args.lr
    optimizer = get_optimizer(model, learning_rate=learning_rate)
    scheduler = LR_Scheduler(args.scheduler, learning_rate, num_epochs,
                             len(dataloader_train))
    ##################################
    criterion_node = nn.CrossEntropyLoss()
    criterion_edge = nn.BCELoss()
    alpha = args.alpha
    writer = SummaryWriter(log_dir=log_path + task_name)
    f_log = open(log_path + task_name + ".log", 'w')
    #######################################
    trainer = Trainer(criterion_node, criterion_edge, optimizer, n_class,
                      device, alpha=alpha)
    evaluator = Evaluator(n_class, device)
    best_pred = 0.0
    print("start training......")
    log = task_name + '\n'
    for k, v in args.__dict__.items():
        log += str(k) + ' = ' + str(v) + '\n'
    print(log)
    f_log.write(log)
    f_log.flush()
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        tbar = tqdm(dataloader_train)
        train_loss = 0
        train_loss_edge = 0
        train_loss_node = 0
        # Fix: initialise the train scores up front so the logging calls below
        # do not raise NameError when `evaluation` skips the training loop.
        train_scores_node, train_scores_edge = trainer.get_scores()
        start_time = time.time()
        for i_batch, sample in enumerate(tbar):
            data_time.update(time.time() - start_time)
            if evaluation:  # evaluation pattern: no training
                break
            scheduler(optimizer, i_batch, epoch, best_pred)
            loss, loss_node, loss_edge = trainer.train(sample, model)
            train_loss += loss.item()
            train_loss_node += loss_node.item()
            train_loss_edge += loss_edge.item()
            train_scores_node, train_scores_edge = trainer.get_scores()
            batch_time.update(time.time() - start_time)
            start_time = time.time()
            if i_batch % 2 == 0:  # refresh the progress bar every other batch
                tbar.set_description(
                    'Train loss: %.4f (loss_node=%.4f loss_edge=%.4f); F1 node: %.4f F1 edge: %.4f; data time: %.2f; batch time: %.2f' %
                    (train_loss / (i_batch + 1), train_loss_node / (i_batch + 1),
                     train_loss_edge / (i_batch + 1),
                     train_scores_node["macro_f1"],
                     train_scores_edge["macro_f1"],
                     data_time.avg, batch_time.avg))
        trainer.reset_metrics()
        data_time.reset()
        batch_time.reset()
        if epoch % 1 == 0:  # validate every epoch
            with torch.no_grad():
                model.eval()
                print("evaluating...")
                tbar = tqdm(dataloader_val)
                start_time = time.time()
                for i_batch, sample in enumerate(tbar):
                    data_time.update(time.time() - start_time)
                    pred_node, pred_edge = evaluator.eval(sample, model)
                    val_scores_node, val_scores_edge = evaluator.get_scores()
                    batch_time.update(time.time() - start_time)
                    tbar.set_description(
                        'F1 node: %.4f F1 edge: %.4f; data time: %.2f; batch time: %.2f' %
                        (val_scores_node["macro_f1"],
                         val_scores_edge["macro_f1"],
                         data_time.avg, batch_time.avg))
                    start_time = time.time()
                data_time.reset()
                batch_time.reset()
                # Fix: the original unpacked into `val_scores_node` twice
                # (`val_scores_node, val_scores_node = ...`), leaving the node
                # slot holding the *edge* scores for model selection/logging.
                val_scores_node, val_scores_edge = evaluator.get_scores()
                evaluator.reset_metrics()
                best_pred = save_model(model, model_path, val_scores_node,
                                       val_scores_edge, alpha, task_name,
                                       epoch, best_pred)
                write_log(f_log, train_scores_node, train_scores_edge,
                          val_scores_node, val_scores_edge, epoch, num_epochs)
                write_summaryWriter(writer, train_loss / len(dataloader_train),
                                    optimizer, train_scores_node,
                                    train_scores_edge, val_scores_node,
                                    val_scores_edge, epoch)
    f_log.close()
def main():
    """Train a P3D model (RGB or Flow modality) with gradient accumulation,
    per-epoch validation, periodic checkpointing and stepped LR decay.

    Resumes from ``iter.txt`` when ``--continue_train`` is given; otherwise
    starts from a modality-specific pretrained checkpoint.
    """
    opt = TrainOptions().parse()
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    if opt.continue_train:
        try:
            start_epoch, epoch_iter = np.loadtxt(iter_path,
                                                 delimiter=',', dtype=int)
        except (IOError, OSError, ValueError):
            # Fix: catch only what np.loadtxt raises for a missing or
            # malformed file, instead of a bare `except:`.
            start_epoch, epoch_iter = 1, 0
        print('Resuming from epoch %d at iteration %d' %
              (start_epoch, epoch_iter))
    else:
        start_epoch, epoch_iter = 1, 0
    # Fix: the original mixed Python-2 `print 'x'` statements with Python-3
    # code; normalised everything to the print() function so the file parses
    # under Python 3.
    print('===========training options==========')
    # print(opt)
    # def model for train
    model = P3dModel()
    model.initialize(opt)
    # def train: continue, or finetune from a pretrained model
    if opt.continue_train:
        model.load(fineture=False, which_epoch=start_epoch - 1, pretrain='')
    else:
        if opt.modality == 'RGB':
            pretrained_file = 'p3d_rgb_199.checkpoint.pth.tar'
        elif opt.modality == 'Flow':
            pretrained_file = 'p3d_flow_199.checkpoint.pth.tar'
        else:
            # Fix: previously an unknown modality fell through and crashed
            # later with NameError on `pretrained_file`.
            raise ValueError('unsupported modality: %s' % opt.modality)
        model.load(fineture=True, pretrain=pretrained_file)
    print('%s is using' % (model.name()))
    # def vis
    Visual = Visualizer(opt)
    dummy_input = torch.rand(1, 3, 16, 224, 224).cuda()
    Visual.tbx_write_net(model.model, dummy_input)
    # def all data loaders
    train_data_loader, val_data_loader, dataset_size, _ = CreateDataLoader(
        opt, model)
    print('#training images = %d' % (dataset_size))
    model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    cudnn.benchmark = True
    # def metrics
    top1 = AverageMeter()
    top5 = AverageMeter()
    losses = AverageMeter()
    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    print_delta = total_steps % opt.print_freq
    # accumulate gradients over several mini-batches to emulate a larger batch
    update_size = opt.larger_batch_size // opt.batch_size
    update_num = 0
    model.module.optimizer.zero_grad()
    for epoch in range(start_epoch, opt.epochs + 1):
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size
        top1.reset()
        top5.reset()
        losses.reset()
        model.train()
        for i, data in enumerate(train_data_loader, start=epoch_iter):
            iter_start_time = time.time()
            total_steps += opt.batch_size
            epoch_iter += opt.batch_size
            # renamed from `input`/`label` to avoid shadowing the builtin
            inputs = Variable(data['data'].cuda())
            labels = Variable(data['label'].cuda())
            ############## Forward Pass ######################
            pred, loss = model(inputs, labels)
            # loss = torch.mean(loss)  # enable if the model returns per-GPU losses
            pt1, pt5, _ = accuracy(pred.data, data['label'].cuda(),
                                   topk=(1, 5))
            top1.update(pt1.item(), inputs.size(0))
            top5.update(pt5.item(), inputs.size(0))
            losses.update(loss.item(), inputs.size(0))
            ############### Backward Pass ####################
            # update model weights once every `update_size` mini-batches
            loss.backward()
            update_num += 1
            if update_num == update_size:
                model.module.optimizer.step()
                model.module.optimizer.zero_grad()
                update_num = 0
            ############## Display results and errors ##########
            if total_steps % opt.print_freq == print_delta:
                errors = {
                    'train_loss': losses.get(),
                    'top1': top1.get(),
                    'top5': top5.get()
                }
                t = (time.time() - iter_start_time) / opt.batch_size
                Visual.print_current_errors(epoch, epoch_iter, errors, t)
                Visual.tbx_write_errors(errors, total_steps, 'Train/loss')
                top1.reset()
                top5.reset()
                losses.reset()
            if epoch_iter >= dataset_size:
                break
        ### validation at the end of each epoch
        top1.reset()
        top5.reset()
        losses.reset()
        v__start_time = time.time()
        this_iter = 0
        model.eval()
        with torch.no_grad():
            for i, data in enumerate(val_data_loader):
                inputs = Variable(data['data'].cuda())
                labels = Variable(data['label'].cuda())
                pred, loss = model(inputs, labels)
                pt1, pt5, _ = accuracy(pred.data, data['label'].cuda(),
                                       topk=(1, 5))
                top1.update(pt1.item(), inputs.size(0))
                top5.update(pt5.item(), inputs.size(0))
                losses.update(loss.item(), inputs.size(0))
                this_iter = this_iter + opt.batch_size
        errors = {
            'valid_loss': losses.get(),
            'top1': top1.get(),
            'top5': top5.get()
        }
        t = (time.time() - v__start_time) / opt.batch_size
        Visual.print_current_errors(epoch, this_iter, errors, t)
        Visual.tbx_write_errors(errors, total_steps, 'Valid/loss')
        top1.reset()
        top5.reset()
        losses.reset()
        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.module.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')
        ### decay learning rate after certain epochs
        if epoch % opt.lr_decay_epoch == 0:
            model.module.update_learning_rate()
def main():
    """Test a trained P3D model with multi-crop / multi-segment inference,
    report overall and per-class top-1/top-5 accuracy, and plot the
    confusion matrix.

    Only ``batch_size == 1`` is supported (asserted per sample below).
    """
    opt = TestOptions().parse(False)
    if opt.modality == 'RGB':
        pass  # RGB uses the default 3-channel input; no extra setup needed
    data_length = 16  # frames per segment fed to the network
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    start_epoch, epoch_iter = np.loadtxt(iter_path, delimiter=',', dtype=int)
    # def model for test
    model = P3dModel()
    model.initialize(opt)
    num_classes = model.num_classes
    # default to the last finished epoch unless one is given explicitly
    which_epoch = start_epoch - 1
    if opt.epoch_num > 0:
        which_epoch = opt.epoch_num
    model.load(fineture=False, which_epoch=which_epoch, pretrain='')
    # Fix: Python-2 `print 'x'` statements normalised to print() calls so the
    # file parses under Python 3 (and the "useing"/"isprocessing" typos fixed).
    print('%s is using' % (model.name()))
    # def data loader
    test_data_loader, dataset_size = CreateTestLoader(opt, model)
    print('#testing images = %d' % (dataset_size))
    model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    cudnn.benchmark = True
    # def metrics
    top1 = AverageMeter()
    top5 = AverageMeter()
    class_top1 = [AverageMeter() for i in range(num_classes)]
    class_top5 = [AverageMeter() for i in range(num_classes)]
    # confusion counts: rows = true class, columns = predicted class
    mix_m = np.zeros((num_classes, num_classes), np.float32)
    top1.reset()
    top5.reset()
    for i in range(num_classes):
        class_top1[i].reset()
        class_top5[i].reset()
    v__start_time = time.time()
    this_iter = 0
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(test_data_loader):
            print('%d video is processing' % (i))
            inputs = Variable(data['data'].cuda())
            label = Variable(data['label'].cuda())
            sizes = inputs.size()
            assert sizes[2] == opt.crop_num * opt.num_segments * data_length, \
                'shape error'
            # split the long clip axis into (crops*segments, frames), move the
            # crop/segment axis ahead of channels, then fold it into the batch
            inputs = inputs.view(sizes[0], sizes[1],
                                 opt.crop_num * opt.num_segments, data_length,
                                 sizes[3], sizes[4])
            inputs = inputs.permute(0, 2, 1, 3, 4, 5)
            inputs = inputs.view(sizes[0] * opt.crop_num * opt.num_segments,
                                 sizes[1], data_length, sizes[3], sizes[4])
            pred = model.module.inference(inputs)
            # average-vote over crops/segments by summing their logits
            pred = pred.view(opt.batch_size,
                             opt.crop_num * opt.num_segments, -1)
            new_pred = torch.sum(pred.data, 1, False)
            pt1, pt5, acc_v = accuracy(new_pred, data['label'].cuda(),
                                       topk=(1, 5))
            top1.update(pt1.item(), opt.batch_size)
            top5.update(pt5.item(), opt.batch_size)
            d = data['label'].numpy()
            assert d.shape[0] == 1, 'only support batch size ==1'
            key = d[0]
            class_top1[key].update(pt1.item(), opt.batch_size)
            class_top5[key].update(pt5.item(), opt.batch_size)
            cc = acc_v[0]
            mix_m[key, cc] += 1
            this_iter = this_iter + opt.batch_size
    t = (time.time() - v__start_time) / opt.batch_size
    print('%f times test result top5: %f, top1: %f' %
          (t, top5.get(), top1.get()))
    total_loss_file = '%s/total_loss.txt' % (opt.checkpoints_dir)
    message = '%s_nseg_%d_ncrop_%d_nepoch_%d : ' % (
        opt.name, opt.num_segments, opt.crop_num, opt.epoch_num)
    with open(total_loss_file, "a") as tlf:
        message += '%f times test result top5: %f, top1: %f \n' % (
            t, top5.get(), top1.get())
        tlf.write('%s ' % message)
    print('==============each class accuracy=========================')
    for i in range(num_classes):
        print('class %d test result top5: %f, top1: %f' %
              (i, class_top5[i].get(), class_top1[i].get()))
        # normalise each confusion row by the number of samples of that class
        # NOTE(review): divides by zero (numpy warning, inf/nan row) for a
        # class with no test samples — confirm every class is represented
        mix_m[i, :] /= class_top1[i].count
    plot_mix(mix_m, opt)
    top1.reset()
    top5.reset()
    for i in range(num_classes):
        class_top1[i].reset()
        class_top5[i].reset()
class Model:
    """Pixel-level active-learning segmentation pipeline.

    When ``n_pixels_by_us == 0`` a single fully-supervised model is trained;
    otherwise the pipeline iterates query stages: train a model from scratch,
    select new pixels to label with ``QuerySelector``, label them in the
    training set, and repeat.
    """

    def __init__(self, args):
        # common args
        self.args = args
        self.best_miou = -1.0  # best validation mIoU seen so far
        self.dataset_name = args.dataset_name
        self.debug = args.debug  # when set, loops break after one batch/epoch
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu:0")
        self.dir_checkpoints = f"{args.dir_root}/checkpoints/{args.experim_name}"
        self.experim_name = args.experim_name
        self.ignore_index = args.ignore_index  # label value excluded from the loss
        self.init_n_pixels = args.n_init_pixels
        self.max_budget = args.max_budget  # total number of pixels to label
        self.n_classes = args.n_classes
        self.n_epochs = args.n_epochs
        self.n_pixels_by_us = args.n_pixels_by_us  # pixels queried per stage; 0 = fully supervised
        self.network_name = args.network_name
        self.nth_query = -1  # current query stage; set in __call__
        self.stride_total = args.stride_total  # network output stride (used for VOC padding)
        # three loaders over the same data config: training batches, the
        # query pool (batch 1, unshuffled), and validation (batch 1)
        self.dataloader = get_dataloader(deepcopy(args),
                                         val=False,
                                         query=False,
                                         shuffle=True,
                                         batch_size=args.batch_size,
                                         n_workers=args.n_workers)
        self.dataloader_query = get_dataloader(deepcopy(args),
                                               val=False,
                                               query=True,
                                               shuffle=False,
                                               batch_size=1,
                                               n_workers=args.n_workers)
        self.dataloader_val = get_dataloader(deepcopy(args),
                                             val=True,
                                             query=False,
                                             shuffle=False,
                                             batch_size=1,
                                             n_workers=args.n_workers)
        self.model = get_model(args).to(self.device)
        self.lr_scheduler_type = args.lr_scheduler_type
        self.query_selector = QuerySelector(args, self.dataloader_query)
        self.vis = Visualiser(args.dataset_name)
        # for tracking stats
        self.running_loss, self.running_score = AverageMeter(), RunningScore(
            args.n_classes)
        # if active learning
        # if self.n_pixels_by_us > 0:
        #     self.model_0_query = f"{self.dir_checkpoints}/0_query_{args.seed}.pt"

    def __call__(self):
        """Run the whole pipeline: one fully-supervised training, or a
        sequence of active-learning query stages."""
        # fully-supervised model
        if self.n_pixels_by_us == 0:
            dir_checkpoints = f"{self.dir_checkpoints}/fully_sup"
            os.makedirs(f"{dir_checkpoints}", exist_ok=True)
            self.log_train, self.log_val = f"{dir_checkpoints}/log_train.txt", f"{dir_checkpoints}/log_val.txt"
            write_log(f"{self.log_train}",
                      header=["epoch", "mIoU", "pixel_acc", "loss"])
            write_log(f"{self.log_val}", header=["epoch", "mIoU", "pixel_acc"])
            self._train()
        # active learning model
        else:
            # one stage per query budget chunk, plus one for the initial
            # labelled pixels if any
            n_stages = self.max_budget // self.n_pixels_by_us
            n_stages += 1 if self.init_n_pixels > 0 else 0
            print("n_stages:", n_stages)
            for nth_query in range(n_stages):
                # each stage logs and checkpoints into its own directory
                dir_checkpoints = f"{self.dir_checkpoints}/{nth_query}_query"
                os.makedirs(f"{dir_checkpoints}", exist_ok=True)
                self.log_train, self.log_val = f"{dir_checkpoints}/log_train.txt", f"{dir_checkpoints}/log_val.txt"
                write_log(f"{self.log_train}",
                          header=["epoch", "mIoU", "pixel_acc", "loss"])
                write_log(f"{self.log_val}",
                          header=["epoch", "mIoU", "pixel_acc"])
                self.nth_query = nth_query
                model = self._train()
                # select queries using the current model and label them.
                # NOTE(review): on the final stage queries are still selected
                # before the break — wasted work, presumably harmless; confirm
                queries = self.query_selector(nth_query, model)
                self.dataloader.dataset.label_queries(queries, nth_query + 1)
                if nth_query == n_stages - 1:
                    break
                # if nth_query == 0:
                #     torch.save({"model": model.state_dict()}, self.model_0_query)
        return

    def _train_epoch(self, epoch, model, optimizer, lr_scheduler):
        """Train `model` for one epoch; returns (model, optimizer, lr_scheduler).

        Also visualises the last batch's prediction and uncertainty maps.
        """
        if self.n_pixels_by_us != 0:
            print(
                f"training an epoch {epoch} of {self.nth_query}th query ({self.dataloader.dataset.n_pixels_total} labelled pixels)"
            )
            fp = f"{self.dir_checkpoints}/{self.nth_query}_query/{epoch}_train.png"
        else:
            fp = f"{self.dir_checkpoints}/fully_sup/{epoch}_train.png"
        log = f"{self.log_train}"

        dataloader_iter, tbar = iter(self.dataloader), tqdm(
            range(len(self.dataloader)))
        model.train()
        for _ in tbar:
            dict_data = next(dataloader_iter)
            x, y = dict_data['x'].to(self.device), dict_data['y'].to(
                self.device)
            # if queries: supervise only at queried pixels, ignore the rest
            if self.n_pixels_by_us != 0:
                mask = dict_data['queries'].to(self.device, torch.bool)
                # NOTE(review): relies on y.flatten() returning a view of y
                # (true for contiguous tensors) so the write sticks — confirm
                y.flatten()[~mask.flatten()] = self.ignore_index

            # forward pass
            dict_outputs = model(x)

            logits = dict_outputs["pred"]
            dict_losses = {
                "ce": F.cross_entropy(logits, y, ignore_index=self.ignore_index)
            }

            # backward pass
            loss = sum(dict_losses.values())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            prob, pred = F.softmax(logits.detach(), dim=1), logits.argmax(dim=1)
            self.running_score.update(y.cpu().numpy(), pred.cpu().numpy())
            self.running_loss.update(loss.detach().item())
            scores = self.running_score.get_scores()[0]
            miou, pixel_acc = scores['Mean IoU'], scores['Pixel Acc']

            # description
            description = f"({self.experim_name}) Epoch {epoch} | mIoU.: {miou:.3f} | pixel acc.: {pixel_acc:.3f} | " \
                          f"avg loss: {self.running_loss.avg:.3f}"
            for loss_k, loss_v in dict_losses.items():
                description += f" | {loss_k}: {loss_v.detach().cpu().item():.3f}"
            tbar.set_description(description)

            # Poly schedule steps once per batch; MultiStepLR once per epoch
            if self.lr_scheduler_type == "Poly":
                lr_scheduler.step(epoch=epoch - 1)
            if self.debug:
                break

        if self.lr_scheduler_type == "MultiStepLR":
            lr_scheduler.step(epoch=epoch - 1)

        write_log(
            log,
            list_entities=[epoch, miou, pixel_acc, self.running_loss.avg])
        self._reset_meters()

        # uncertainty maps of the last batch's first sample, for visualisation
        ent, lc, ms, = [
            self._query(prob, uc)[0].cpu()
            for uc in ["entropy", "least_confidence", "margin_sampling"]
        ]
        dict_tensors = {
            'input': dict_data['x'][0].cpu(),
            'target': dict_data['y'][0].cpu(),
            'pred': pred[0].detach().cpu(),
            'confidence': lc,
            'margin': -ms,  # minus sign is to draw smaller margin part brighter
            'entropy': ent
        }
        self.vis(dict_tensors, fp=fp)
        return model, optimizer, lr_scheduler

    def _train(self):
        """Train a fresh model for n_epochs (validating each epoch) and return it.

        A new model/optimizer/scheduler is built every call, so each query
        stage retrains from scratch.
        """
        print(f"\n({self.experim_name}) training...\n")
        model = get_model(self.args).to(self.device)
        optimizer = get_optimizer(self.args, model)
        lr_scheduler = get_lr_scheduler(self.args,
                                        optimizer=optimizer,
                                        iters_per_epoch=len(self.dataloader))
        for e in range(1, 1 + self.n_epochs):
            model, optimizer, lr_scheduler = self._train_epoch(
                e, model, optimizer, lr_scheduler)
            self._val(e, model)
            if self.debug:
                break
        # reset so the next stage's best-model tracking starts fresh
        self.best_miou = -1.0
        return model

    @torch.no_grad()
    def _val(self, epoch, model):
        """Validate `model`, save it if mIoU improved, log and visualise."""
        dataloader_iter, tbar = iter(self.dataloader_val), tqdm(
            range(len(self.dataloader_val)))
        model.eval()
        for _ in tbar:
            dict_data = next(dataloader_iter)
            x, y = dict_data['x'].to(self.device), dict_data['y'].to(
                self.device)
            if self.dataset_name == "voc":
                # pad input to a multiple of the network stride, then crop the
                # prediction back to the label size
                h, w = y.shape[1:]
                pad_h = ceil(
                    h / self.stride_total) * self.stride_total - x.shape[2]
                pad_w = ceil(
                    w / self.stride_total) * self.stride_total - x.shape[3]
                x = F.pad(x, pad=(0, pad_w, 0, pad_h), mode='reflect')
                dict_outputs = model(x)
                dict_outputs['pred'] = dict_outputs['pred'][:, :, :h, :w]
            else:
                dict_outputs = model(x)

            logits = dict_outputs['pred']
            prob, pred = F.softmax(logits.detach(), dim=1), logits.argmax(dim=1)
            self.running_score.update(y.cpu().numpy(), pred.cpu().numpy())
            scores = self.running_score.get_scores()[0]
            miou, pixel_acc = scores['Mean IoU'], scores['Pixel Acc']
            tbar.set_description(
                f"mIoU: {miou:.3f} | pixel acc.: {pixel_acc:.3f}")
            if self.debug:
                break

        # checkpoint whenever validation mIoU improves
        if miou > self.best_miou:
            state_dict = {"model": model.state_dict()}
            if self.n_pixels_by_us != 0:
                torch.save(
                    state_dict,
                    f"{self.dir_checkpoints}/{self.nth_query}_query/best_miou_model.pt"
                )
            else:
                torch.save(
                    state_dict,
                    f"{self.dir_checkpoints}/fully_sup/best_miou_model.pt")
            print(
                f"best model has been saved"
                f"(epoch: {epoch} | prev. miou: {self.best_miou:.4f} => new miou: {miou:.4f})."
            )
            self.best_miou = miou
        write_log(self.log_val, list_entities=[epoch, miou, pixel_acc])
        print(
            f"\n{'=' * 100}"
            f"\nExperim name: {self.experim_name}"
            f"\nEpoch {epoch} | miou: {miou:.3f} | pixel_acc.: {pixel_acc:.3f}"
            f"\n{'=' * 100}\n")
        self._reset_meters()

        # uncertainty maps of the last batch's first sample, for visualisation
        ent, lc, ms, = [
            self._query(prob, uc)[0].cpu()
            for uc in ["entropy", "least_confidence", "margin_sampling"]
        ]
        dict_tensors = {
            'input': dict_data['x'][0].cpu(),
            'target': dict_data['y'][0].cpu(),
            'pred': pred[0].detach().cpu(),
            'confidence': lc,
            'margin': -ms,  # minus sign is to draw smaller margin part brighter
            'entropy': ent
        }
        if self.n_pixels_by_us != 0:
            self.vis(
                dict_tensors,
                fp=
                f"{self.dir_checkpoints}/{self.nth_query}_query/{epoch}_val.png"
            )
        else:
            self.vis(dict_tensors,
                     fp=f"{self.dir_checkpoints}/fully_sup/{epoch}_val.png")
        return

    @staticmethod
    def _query(prob, query_strategy):
        """Compute a per-pixel uncertainty map from class probabilities.

        :param prob: b x n_classes x h x w probability tensor
        :param query_strategy: one of "least_confidence", "margin_sampling",
            "entropy", "random"
        :return: b x h x w uncertainty scores (higher = more uncertain,
            except margin_sampling where smaller margin = more uncertain)
        :raises ValueError: on an unknown strategy name
        """
        # prob: b x n_classes x h x w
        if query_strategy == "least_confidence":
            query = 1.0 - prob.max(dim=1)[0]  # b x h x w
        elif query_strategy == "margin_sampling":
            # gap between top-2 class probabilities
            query = prob.topk(k=2, dim=1).values  # b x k x h x w
            query = (query[:, 0, :, :] - query[:, 1, :, :]).abs()  # b x h x w
        elif query_strategy == "entropy":
            query = (-prob * torch.log(prob)).sum(dim=1)  # b x h x w
        elif query_strategy == "random":
            b, _, h, w = prob.shape
            query = torch.rand((b, h, w))
        else:
            raise ValueError
        return query

    def _reset_meters(self):
        """Reset the shared loss meter and score tracker between phases."""
        self.running_loss.reset()
        self.running_score.reset()
class CoteachingTrainer(object):
    """Co-teaching trainer: two networks trained jointly, each maintaining a
    memory pool of per-sample loss/score history used by ``cot_std_loss`` to
    down-weight or drop suspected noisy-label samples after a warm-up period.

    ``step`` selects the training stage: 0 = full training from pretrained
    weights, 1 = FC-only warm-up, 2 = full training from step-1 weights.
    """

    def __init__(self, config):
        # Config
        self._config = config
        self._epochs = config['epochs']
        self._step = config['step']
        self._logfile = config['log']
        self._n_classes = config['n_classes']

        # Network: two architectures (comma-separated) or the same one twice
        if ',' in config['net']:
            net_name_1, net_name_2 = config['net'].split(',')
        else:
            net_name_1, net_name_2 = config['net'], config['net']
        Net1, _ = make_network(net_name_1)
        Net2, _ = make_network(net_name_2)
        if self._step == 0:
            # different FC inits so the two nets diverge and disagree usefully
            net1 = Net1(n_classes=self._n_classes,
                        pretrained=True,
                        use_two_step=False,
                        fc_init='He')
            net2 = Net2(n_classes=self._n_classes,
                        pretrained=True,
                        use_two_step=False,
                        fc_init='Xavier')
        elif self._step == 1:
            net1 = Net1(n_classes=self._n_classes,
                        pretrained=True,
                        use_two_step=True)
            net2 = Net2(n_classes=self._n_classes,
                        pretrained=True,
                        use_two_step=True)
        elif self._step == 2:
            net1 = Net1(n_classes=self._n_classes,
                        pretrained=False,
                        use_two_step=True)
            net2 = Net2(n_classes=self._n_classes,
                        pretrained=False,
                        use_two_step=True)
        else:
            raise AssertionError('step can only be 0, 1, 2')

        # Move network to cuda
        print('| Number of available GPUs : {} ({})'.format(
            torch.cuda.device_count(), os.environ["CUDA_VISIBLE_DEVICES"]))
        if torch.cuda.device_count() >= 1:
            self._net1 = nn.DataParallel(net1).cuda()
            self._net2 = nn.DataParallel(net2).cuda()
        else:
            raise AssertionError('CPU version is not implemented yet!')

        # Loss criterion: warm-up length before sample dropping kicks in;
        # step 1 never drops samples (warm-up lasts the whole run)
        self.T_k = config['warmup_epochs']
        if self._step == 1:
            self.T_k = self._epochs

        # Optimizer: step 1 tunes only the FC heads; net1 runs at half LR
        if self._step == 1:
            params_to_optimize1 = self._net1.module.fc.parameters()
            params_to_optimize2 = self._net2.module.fc.parameters()
        else:
            params_to_optimize1 = self._net1.parameters()
            params_to_optimize2 = self._net2.parameters()
        self._optimizer1 = make_optimizer(params_to_optimize1,
                                          lr=config['lr'] / 2,
                                          weight_decay=config['weight_decay'],
                                          opt='SGD')
        self._optimizer2 = make_optimizer(params_to_optimize2,
                                          lr=config['lr'],
                                          weight_decay=config['weight_decay'],
                                          opt='SGD')
        self._scheduler1 = optim.lr_scheduler.CosineAnnealingLR(
            self._optimizer1, T_max=self._epochs, eta_min=0)
        self._scheduler2 = optim.lr_scheduler.CosineAnnealingLR(
            self._optimizer2, T_max=self._epochs, eta_min=0)

        # metrics
        self._train_loss1 = AverageMeter()
        self._train_loss2 = AverageMeter()
        self._train_accuracy1 = AverageMeter()
        self._train_accuracy2 = AverageMeter()
        self._epoch_train_time = AverageMeter()

        # Dataloader
        train_transform = make_transform(phase='train', output_size=448)
        test_transform = make_transform(phase='test', output_size=448)
        train_data = IndexedImageFolder(os.path.join(config['data_base'],
                                                     'train'),
                                        transform=train_transform)
        test_data = IndexedImageFolder(os.path.join(config['data_base'],
                                                    'val'),
                                       transform=test_transform)
        self._train_loader = data.DataLoader(train_data,
                                             batch_size=config['batch_size'],
                                             shuffle=True,
                                             num_workers=4,
                                             pin_memory=True)
        self._test_loader = data.DataLoader(test_data,
                                            batch_size=16,
                                            shuffle=False,
                                            num_workers=4,
                                            pin_memory=True)
        print('|-----------------------------------------------------')
        print('| Number of samples in train set : {}'.format(len(train_data)))
        print('| Number of samples in test set : {}'.format(len(test_data)))
        print('| Number of classes in train set : {}'.format(
            len(train_data.classes)))
        print('| Number of classes in test set : {}'.format(
            len(test_data.classes)))
        print('|-----------------------------------------------------')
        assert len(train_data.classes) == self._n_classes and \
               len(test_data.classes) == self._n_classes, 'number of classes is wrong'

        # Resume or not
        if config['resume']:
            assert os.path.isfile(
                'checkpoint.pth'), 'no checkpoint.pth exists!'
            print('---> loading checkpoint.pth <---')
            checkpoint = torch.load('checkpoint.pth')
            assert self._step == checkpoint[
                'step'], 'step in checkpoint does not match step in argument'
            self._start_epoch = checkpoint['epoch']
            self._best_accuracy1 = checkpoint['best_accuracy1']
            self._best_accuracy2 = checkpoint['best_accuracy2']
            self._best_epoch1 = checkpoint['best_epoch1']
            self._best_epoch2 = checkpoint['best_epoch2']
            self._net1.load_state_dict(checkpoint['state_dict1'])
            self._net2.load_state_dict(checkpoint['state_dict2'])
            self._optimizer1.load_state_dict(checkpoint['optimizer1'])
            self._optimizer2.load_state_dict(checkpoint['optimizer2'])
            self._scheduler1.load_state_dict(checkpoint['scheduler1'])
            self._scheduler2.load_state_dict(checkpoint['scheduler2'])
            self.memory_pool1 = checkpoint['memory_pool1']
            self.memory_pool2 = checkpoint['memory_pool2']
        else:
            print('---> no checkpoint loaded <---')
            if self._step == 2:
                print('---> loading step1_best_epoch.pth <---')
                # Fix: the original asserted on 'model/step1_best_epoch.pth',
                # a file it never loads; check the two files actually loaded.
                assert os.path.isfile('model/net1_step1_best_epoch.pth') and \
                       os.path.isfile('model/net2_step1_best_epoch.pth')
                self._net1.load_state_dict(
                    torch.load('model/net1_step1_best_epoch.pth'))
                self._net2.load_state_dict(
                    torch.load('model/net2_step1_best_epoch.pth'))
            self._start_epoch = 0
            self._best_accuracy1 = 0.0
            self._best_accuracy2 = 0.0
            self._best_epoch1 = None
            self._best_epoch2 = None
            self.memory_pool1 = Queue(n_samples=len(train_data),
                                      memory_length=config['memory_length'])
            self.memory_pool2 = Queue(n_samples=len(train_data),
                                      memory_length=config['memory_length'])
        self._scheduler1.last_epoch = self._start_epoch
        self._scheduler2.last_epoch = self._start_epoch

    def train(self):
        """Run the full training loop, checkpointing every epoch and keeping
        the best-accuracy weights of each network."""
        console_header = 'Epoch\tTrain_Loss1\tTrain_Loss2\tTrain_Accuracy1\tTrain_Accuracy2\t' \
                         'Test_Accuracy1\tTest_Accuracy2\tEpoch_Runtime\tLearning_Rate1\tLearning_Rate2'
        print_to_console(console_header)
        print_to_logfile(self._logfile, console_header, init=True)

        for t in range(self._start_epoch, self._epochs):
            epoch_start = time.time()
            self._scheduler1.step(epoch=t)
            self._scheduler2.step(epoch=t)
            # reset average meters
            self._train_loss1.reset()
            self._train_loss2.reset()
            self._train_accuracy1.reset()
            self._train_accuracy2.reset()

            self._net1.train(True)
            self._net2.train(True)
            self.single_epoch_training(t)
            test_accuracy1 = evaluate(self._test_loader, self._net1)
            test_accuracy2 = evaluate(self._test_loader, self._net2)
            lr1 = get_lr_from_optimizer(self._optimizer1)
            lr2 = get_lr_from_optimizer(self._optimizer2)

            if test_accuracy1 > self._best_accuracy1:
                self._best_accuracy1 = test_accuracy1
                self._best_epoch1 = t + 1
                torch.save(
                    self._net1.state_dict(),
                    'model/net1_step{}_best_epoch.pth'.format(self._step))
            if test_accuracy2 > self._best_accuracy2:
                self._best_accuracy2 = test_accuracy2
                self._best_epoch2 = t + 1
                torch.save(
                    self._net2.state_dict(),
                    'model/net2_step{}_best_epoch.pth'.format(self._step))
            epoch_end = time.time()
            single_epoch_runtime = epoch_end - epoch_start
            # Logging
            console_content = '{:05d}\t{:10.4f}\t{:10.4f}\t{:14.4f}\t{:14.4f}\t' \
                              '{:13.4f}\t{:13.4f}\t{:13.2f}\t' \
                              '{:13.1e}\t{:13.1e}'.format(t + 1, self._train_loss1.avg, self._train_loss2.avg,
                                                          self._train_accuracy1.avg, self._train_accuracy2.avg,
                                                          test_accuracy1, test_accuracy2,
                                                          single_epoch_runtime, lr1, lr2)
            print_to_console(console_content)
            print_to_logfile(self._logfile, console_content, init=False)

            # save checkpoint
            save_checkpoint({
                'epoch': t + 1,
                'state_dict1': self._net1.state_dict(),
                'state_dict2': self._net2.state_dict(),
                'best_epoch1': self._best_epoch1,
                'best_epoch2': self._best_epoch2,
                'best_accuracy1': self._best_accuracy1,
                'best_accuracy2': self._best_accuracy2,
                'optimizer1': self._optimizer1.state_dict(),
                'optimizer2': self._optimizer2.state_dict(),
                'step': self._step,
                'scheduler1': self._scheduler1.state_dict(),
                'scheduler2': self._scheduler2.state_dict(),
                'memory_pool1': self.memory_pool1,
                'memory_pool2': self.memory_pool2,
            })

        console_content = 'Net1: Best at epoch {}, test accuracy is {}'.format(
            self._best_epoch1, self._best_accuracy1)
        print_to_console(console_content)
        console_content = 'Net2: Best at epoch {}, test accuracy is {}'.format(
            self._best_epoch2, self._best_accuracy2)
        print_to_console(console_content)

        # rename log file to embed the final configuration and results
        os.rename(
            self._logfile,
            self._logfile.replace(
                '.txt', '-{}_{}_{}_{:.4f}_{:.4f}.txt'.format(
                    self._config['net'], self._config['batch_size'],
                    self._config['lr'], self._best_accuracy1,
                    self._best_accuracy2)))

    def single_epoch_training(self, epoch, log_iter=True, log_freq=200):
        """Train both networks for one epoch with the co-teaching loss.

        :param epoch: zero-based epoch index
        :param log_iter: also log every `log_freq` iterations (not only at
            the end of the epoch)
        :param log_freq: iterations between intermediate console logs
        """
        if epoch >= self.T_k:
            # after warm-up, cot_std_loss starts dropping/reusing samples;
            # open per-epoch stats files for both networks
            stats_log_path1 = 'stats/net1_drop_n_reuse_stats_epoch{:03d}.csv'.format(
                epoch + 1)
            stats_log_path2 = 'stats/net2_drop_n_reuse_stats_epoch{:03d}.csv'.format(
                epoch + 1)
            stats_log_header = 'clean_sample_num,reusable_sample_num,irrelevant_sample_num'
            print_to_logfile(stats_log_path1,
                             stats_log_header,
                             init=True,
                             end='\n')
            print_to_logfile(stats_log_path2,
                             stats_log_header,
                             init=True,
                             end='\n')
        for it, (x, y, indices) in enumerate(self._train_loader):
            s = time.time()
            x = x.cuda()
            y = y.cuda()

            self._optimizer1.zero_grad()
            self._optimizer2.zero_grad()
            logits1 = self._net1(x)
            logits2 = self._net2(x)
            # Fix: the original passed memory_pool1 twice, so net2's loss was
            # computed against net1's history and pool 2 was never consulted.
            losses1, ce_loss1, losses2, ce_loss2 = \
                cot_std_loss(logits1, logits2, y, indices, self.T_k, epoch,
                             self.memory_pool1, self.memory_pool2,
                             eps=self._config['eps'])
            loss1 = losses1.mean()
            loss2 = losses2.mean()

            self.memory_pool1.update(indices=indices,
                                     losses=ce_loss1.detach().data.cpu(),
                                     scores=F.softmax(
                                         logits1, dim=1).detach().data.cpu(),
                                     labels=y.detach().data.cpu())
            # Fix: the original updated memory_pool1 again here with net2's
            # losses/scores, leaving memory_pool2 permanently empty.
            self.memory_pool2.update(indices=indices,
                                     losses=ce_loss2.detach().data.cpu(),
                                     scores=F.softmax(
                                         logits2, dim=1).detach().data.cpu(),
                                     labels=y.detach().data.cpu())

            train_accuracy1 = accuracy(logits1, y, topk=(1, ))
            train_accuracy2 = accuracy(logits2, y, topk=(1, ))
            self._train_loss1.update(loss1.item(), losses1.size(0))
            # Fix: loss2 was weighted by losses1.size(0); use its own batch size.
            self._train_loss2.update(loss2.item(), losses2.size(0))
            self._train_accuracy1.update(train_accuracy1[0], x.size(0))
            self._train_accuracy2.update(train_accuracy2[0], x.size(0))

            loss1.backward()
            loss2.backward()
            self._optimizer1.step()
            self._optimizer2.step()

            e = time.time()
            self._epoch_train_time.update(e - s, 1)
            if (log_iter and (it + 1) % log_freq == 0) or (it + 1 == len(
                    self._train_loader)):
                console_content = 'Epoch:[{:03d}/{:03d}] Iter:[{:04d}/{:04d}] ' \
                                  'Train Accuracy1 :[{:6.2f}] Train Accuracy2 :[{:6.2f}] ' \
                                  'Loss1:[{:4.4f}] Loss2:[{:4.4f}] ' \
                                  'Iter Runtime:[{:6.2f}]'.format(epoch + 1, self._epochs,
                                                                  it + 1, len(self._train_loader),
                                                                  self._train_accuracy1.avg,
                                                                  self._train_accuracy2.avg,
                                                                  self._train_loss1.avg,
                                                                  self._train_loss2.avg,
                                                                  self._epoch_train_time.avg)
                print_to_console(console_content)
class Trainer: """Pipeline to train a NN model using a certain dataset, both specified by an YML config.""" @use_seed() def __init__(self, config_path, run_dir): self.config_path = coerce_to_path_and_check_exist(config_path) self.run_dir = coerce_to_path_and_create_dir(run_dir) self.logger = get_logger(self.run_dir, name="trainer") self.print_and_log_info( "Trainer initialisation: run directory is {}".format(run_dir)) shutil.copy(self.config_path, self.run_dir) self.print_and_log_info("Config {} copied to run directory".format( self.config_path)) with open(self.config_path) as fp: cfg = yaml.load(fp, Loader=yaml.FullLoader) if torch.cuda.is_available(): type_device = "cuda" nb_device = torch.cuda.device_count() # XXX: set to False when input image sizes are not fixed torch.backends.cudnn.benchmark = cfg["training"].get( "cudnn_benchmark", True) else: type_device = "cpu" nb_device = None self.device = torch.device(type_device) self.print_and_log_info("Using {} device, nb_device is {}".format( type_device, nb_device)) # Datasets and dataloaders self.dataset_kwargs = cfg["dataset"] self.dataset_name = self.dataset_kwargs.pop("name") train_dataset = get_dataset(self.dataset_name)("train", **self.dataset_kwargs) val_dataset = get_dataset(self.dataset_name)("val", **self.dataset_kwargs) self.restricted_labels = sorted( self.dataset_kwargs["restricted_labels"]) self.n_classes = len(self.restricted_labels) + 1 self.is_val_empty = len(val_dataset) == 0 self.print_and_log_info("Dataset {} instantiated with {}".format( self.dataset_name, self.dataset_kwargs)) self.print_and_log_info( "Found {} classes, {} train samples, {} val samples".format( self.n_classes, len(train_dataset), len(val_dataset))) self.batch_size = cfg["training"]["batch_size"] self.n_workers = cfg["training"]["n_workers"] self.train_loader = DataLoader(train_dataset, batch_size=self.batch_size, num_workers=self.n_workers, shuffle=True) self.val_loader = DataLoader(val_dataset, batch_size=self.batch_size, 
num_workers=self.n_workers) self.print_and_log_info( "Dataloaders instantiated with batch_size={} and n_workers={}". format(self.batch_size, self.n_workers)) self.n_batches = len(self.train_loader) self.n_iterations, self.n_epoches = cfg["training"].get( "n_iterations"), cfg["training"].get("n_epoches") assert not (self.n_iterations is not None and self.n_epoches is not None) if self.n_iterations is not None: self.n_epoches = max(self.n_iterations // self.n_batches, 1) else: self.n_iterations = self.n_epoches * len(self.train_loader) # Model self.model_kwargs = cfg["model"] self.model_name = self.model_kwargs.pop("name") model = get_model(self.model_name)(self.n_classes, **self.model_kwargs).to(self.device) self.model = torch.nn.DataParallel(model, device_ids=range( torch.cuda.device_count())) self.print_and_log_info("Using model {} with kwargs {}".format( self.model_name, self.model_kwargs)) self.print_and_log_info('Number of trainable parameters: {}'.format( f'{count_parameters(self.model):,}')) # Optimizer optimizer_params = cfg["training"]["optimizer"] or {} optimizer_name = optimizer_params.pop("name", None) self.optimizer = get_optimizer(optimizer_name)(model.parameters(), **optimizer_params) self.print_and_log_info("Using optimizer {} with kwargs {}".format( optimizer_name, optimizer_params)) # Scheduler scheduler_params = cfg["training"].get("scheduler", {}) or {} scheduler_name = scheduler_params.pop("name", None) self.scheduler_update_range = scheduler_params.pop( "update_range", "epoch") assert self.scheduler_update_range in ["epoch", "batch"] if scheduler_name == "multi_step" and isinstance( scheduler_params["milestones"][0], float): n_tot = self.n_epoches if self.scheduler_update_range == "epoch" else self.n_iterations scheduler_params["milestones"] = [ round(m * n_tot) for m in scheduler_params["milestones"] ] self.scheduler = get_scheduler(scheduler_name)(self.optimizer, **scheduler_params) self.cur_lr = -1 self.print_and_log_info("Using scheduler {} 
with parameters {}".format( scheduler_name, scheduler_params)) # Loss loss_name = cfg["training"]["loss"] self.criterion = get_loss(loss_name)() self.print_and_log_info("Using loss {}".format(self.criterion)) # Pretrained / Resume checkpoint_path = cfg["training"].get("pretrained") checkpoint_path_resume = cfg["training"].get("resume") assert not (checkpoint_path is not None and checkpoint_path_resume is not None) if checkpoint_path is not None: self.load_from_tag(checkpoint_path) elif checkpoint_path_resume is not None: self.load_from_tag(checkpoint_path_resume, resume=True) else: self.start_epoch, self.start_batch = 1, 1 # Train metrics train_iter_interval = cfg["training"].get( "train_stat_interval", self.n_epoches * self.n_batches // 200) self.train_stat_interval = train_iter_interval self.train_time = AverageMeter() self.train_loss = AverageMeter() self.train_metrics_path = self.run_dir / TRAIN_METRICS_FILE with open(self.train_metrics_path, mode="w") as f: f.write( "iteration\tepoch\tbatch\ttrain_loss\ttrain_time_per_img\n") # Val metrics val_iter_interval = cfg["training"].get( "val_stat_interval", self.n_epoches * self.n_batches // 100) self.val_stat_interval = val_iter_interval self.val_loss = AverageMeter() self.val_metrics = RunningMetrics(self.restricted_labels) self.val_current_score = None self.val_metrics_path = self.run_dir / VAL_METRICS_FILE with open(self.val_metrics_path, mode="w") as f: f.write("iteration\tepoch\tbatch\tval_loss\t" + "\t".join(self.val_metrics.names) + "\n") def print_and_log_info(self, string): print_info(string) self.logger.info(string) def load_from_tag(self, tag, resume=False): self.print_and_log_info("Loading model from run {}".format(tag)) path = coerce_to_path_and_check_exist(MODELS_PATH / tag / MODEL_FILE) checkpoint = torch.load(path, map_location=self.device) try: self.model.load_state_dict(checkpoint["model_state"]) except RuntimeError: state = safe_model_state_dict(checkpoint["model_state"]) 
self.model.module.load_state_dict(state) self.start_epoch, self.start_batch = 1, 1 if resume: self.start_epoch, self.start_batch = checkpoint[ "epoch"], checkpoint.get("batch", 0) + 1 self.optimizer.load_state_dict(checkpoint["optimizer_state"]) self.scheduler.load_state_dict(checkpoint["scheduler_state"]) self.cur_lr = self.scheduler.get_lr() self.print_and_log_info( "Checkpoint loaded at epoch {}, batch {}".format( self.start_epoch, self.start_batch - 1)) def _create_external_val_loader_and_monitor(self, dataset_name): val_dataset = get_dataset(dataset_name)(split="val", **self.dataset_kwargs) val_loader = DataLoader(val_dataset, batch_size=self.batch_size, num_workers=self.n_workers) self.print_and_log_info( "External {} validation dataset instantiated with kwargs {}: {} samples" .format(dataset_name, self.dataset_kwargs, len(val_dataset))) monitor = {} monitor["name"] = dataset_name monitor["loss"] = AverageMeter() monitor["metrics"] = RunningMetrics(val_dataset.restricted_labels, val_dataset.metric_labels) monitor["metrics_path"] = self.run_dir / "{}_metrics.tsv".format( dataset_name) with open(monitor["metrics_path"], mode="w") as f: f.write("iteration\tepoch\tbatch\t{}_loss\t".format(dataset_name) + "\t".join(monitor["metrics"].names) + "\n") return val_loader, monitor @property def score_name(self): return self.val_metrics.score_name def print_memory_usage(self, prefix): usage = {} for attr in [ "memory_allocated", "max_memory_allocated", "memory_cached", "max_memory_cached" ]: usage[attr] = getattr(torch.cuda, attr)() * 0.000001 self.print_and_log_info("{}:\t{}".format( prefix, " / ".join( ["{}: {:.0f}MiB".format(k, v) for k, v in usage.items()]))) @use_seed() def run(self): self.model.train() cur_iter = (self.start_epoch - 1) * self.n_batches + self.start_batch - 1 prev_train_stat_iter, prev_val_stat_iter = cur_iter, cur_iter for epoch in range(self.start_epoch, self.n_epoches + 1): batch_start = self.start_batch if epoch == self.start_epoch else 1 if 
self.scheduler_update_range == "epoch": if batch_start == 1: self.update_scheduler(epoch, batch=batch_start) for batch, (images, labels) in enumerate(self.train_loader, start=1): if batch < batch_start: continue cur_iter += 1 if cur_iter > self.n_iterations: break if self.scheduler_update_range == "batch": self.update_scheduler(epoch, batch=batch) self.single_train_batch_run(images, labels) if (cur_iter - prev_train_stat_iter) >= self.train_stat_interval: prev_train_stat_iter = cur_iter self.log_train_metrics(cur_iter, epoch, batch) if (cur_iter - prev_val_stat_iter) >= self.val_stat_interval: prev_val_stat_iter = cur_iter self.run_val() self.log_val_metrics(cur_iter, epoch, batch) self.save(epoch=epoch, batch=batch) self.print_and_log_info("Training run is over") def update_scheduler(self, epoch, batch): self.scheduler.step() lr = self.scheduler.get_lr() if lr != self.cur_lr: self.cur_lr = lr msg = PRINT_LR_UPD_FMT.format(epoch, self.n_epoches, batch, self.n_batches, lr) self.print_and_log_info(msg) def single_train_batch_run(self, images, labels): start_time = time.time() images, labels = images.to(self.device), labels.to(self.device) self.optimizer.zero_grad() loss = self.criterion(self.model(images), labels) loss.backward() self.optimizer.step() self.train_loss.update(loss.item()) self.train_time.update((time.time() - start_time) / self.batch_size) def log_train_metrics(self, cur_iter, epoch, batch): stat = PRINT_TRAIN_STAT_FMT.format(epoch, self.n_epoches, batch, self.n_batches, self.train_loss.avg, self.train_time.avg) self.print_and_log_info(stat) with open(self.train_metrics_path, mode="a") as f: f.write("{}\t{}\t{}\t{:.4f}\t{:.4f}\n".format( cur_iter, epoch, batch, self.train_loss.avg, self.train_time.avg)) self.train_loss.reset() self.train_time.reset() def run_val(self): self.model.eval() with torch.no_grad(): for images, labels in self.val_loader: images, labels = images.to(self.device), labels.to(self.device) outputs = self.model(images) loss = 
self.criterion(outputs, labels) pred = outputs.data.max(1)[1].cpu().numpy() if images.size() == labels.size(): gt = labels.data.max(1)[1].cpu().numpy() else: gt = labels.cpu().numpy() self.val_metrics.update(gt, pred) self.val_loss.update(loss.item()) self.model.train() def log_val_metrics(self, cur_iter, epoch, batch): stat = PRINT_VAL_STAT_FMT.format(epoch, self.n_epoches, batch, self.n_batches, self.val_loss.avg) self.print_and_log_info(stat) metrics = self.val_metrics.get() self.print_and_log_info( "Val metrics: " + ", ".join(["{} = {:.4f}".format(k, v) for k, v in metrics.items()])) with open(self.val_metrics_path, mode="a") as f: f.write("{}\t{}\t{}\t{:.4f}\t".format(cur_iter, epoch, batch, self.val_loss.avg) + "\t".join(map("{:.4f}".format, metrics.values())) + "\n") self.val_current_score = metrics[self.score_name] self.val_loss.reset() self.val_metrics.reset() def save(self, epoch, batch): state = { "epoch": epoch, "batch": batch, "model_name": self.model_name, "model_kwargs": self.model_kwargs, "model_state": self.model.state_dict(), "n_classes": self.n_classes, "optimizer_state": self.optimizer.state_dict(), "scheduler_state": self.scheduler.state_dict(), "score": self.val_current_score, "train_resolution": self.dataset_kwargs["img_size"], "restricted_labels": self.dataset_kwargs["restricted_labels"], "normalize": self.dataset_kwargs["normalize"], } save_path = self.run_dir / MODEL_FILE torch.save(state, save_path) self.print_and_log_info("Model saved at {}".format(save_path))
class Trainer(object):
    """Single-network trainer with a sample memory pool for noisy-label training.

    Supports a three-step regime selected by config['step']:
      0 - train the whole (pretrained) network in one step;
      1 - warm-up: train only the final fc layer of a pretrained network;
      2 - train the full network from the step-1 best checkpoint.
    """

    def __init__(self, config):
        """Build network, optimizer, scheduler, meters, dataloaders and
        (optionally) restore training state from 'checkpoint.pth'."""
        # Config
        self._config = config
        self._epochs = config['epochs']
        self._step = config['step']
        self._logfile = config['log']
        self._n_classes = config['n_classes']

        # Network
        Net, feature_dim = make_network(config['net'])
        if self._step == 0:
            net = Net(n_classes=self._n_classes, pretrained=True, use_two_step=False)
        elif self._step == 1:
            net = Net(n_classes=self._n_classes, pretrained=True, use_two_step=True)
        elif self._step == 2:
            # step 2 starts from the step-1 checkpoint, not ImageNet weights
            net = Net(n_classes=self._n_classes, pretrained=False, use_two_step=True)
        else:
            raise AssertionError('step can only be 0, 1, 2')

        # Move network to cuda
        print('| Number of available GPUs : {} ({})'.format(torch.cuda.device_count(),
                                                            os.environ["CUDA_VISIBLE_DEVICES"]))
        if torch.cuda.device_count() >= 1:
            self._net = nn.DataParallel(net).cuda()
        else:
            raise AssertionError('CPU version is not implemented yet!')

        # Loss Criterion
        self.T_k = config['warmup_epochs']
        if self._step == 1:
            # step 1 is the warm-up itself: keep warm-up behaviour for all epochs
            self.T_k = self._epochs

        # Optimizer: step 1 only tunes the classifier head, otherwise all params
        if self._step == 1:
            params_to_optimize = self._net.module.fc.parameters()
        else:
            params_to_optimize = self._net.parameters()
        self._optimizer = make_optimizer(params_to_optimize, lr=config['lr'],
                                         weight_decay=config['weight_decay'], opt='SGD')
        self._scheduler = optim.lr_scheduler.CosineAnnealingLR(self._optimizer,
                                                               T_max=self._epochs, eta_min=0)

        # metrics
        self._train_loss = AverageMeter()
        self._train_accuracy = AverageMeter()
        self._epoch_train_time = AverageMeter()

        # Dataloader
        train_transform = make_transform(phase='train', output_size=448)
        test_transform = make_transform(phase='test', output_size=448)
        train_data = IndexedImageFolder(os.path.join(config['data_base'], 'train'),
                                        transform=train_transform)
        test_data = IndexedImageFolder(os.path.join(config['data_base'], 'val'),
                                       transform=test_transform)
        self._train_loader = data.DataLoader(train_data, batch_size=config['batch_size'],
                                             shuffle=True, num_workers=4, pin_memory=True)
        self._test_loader = data.DataLoader(test_data, batch_size=16,
                                            shuffle=False, num_workers=4, pin_memory=True)
        print('|-----------------------------------------------------')
        print('| Number of samples in train set : {}'.format(len(train_data)))
        print('| Number of samples in test set  : {}'.format(len(test_data)))
        print('| Number of classes in train set : {}'.format(len(train_data.classes)))
        print('| Number of classes in test set  : {}'.format(len(test_data.classes)))
        print('|-----------------------------------------------------')
        assert len(train_data.classes) == self._n_classes and \
               len(test_data.classes) == self._n_classes, 'number of classes is wrong'

        # Resume or not
        if config['resume']:
            assert os.path.isfile('checkpoint.pth'), 'no checkpoint.pth exists!'
            print('---> loading checkpoint.pth <---')
            checkpoint = torch.load('checkpoint.pth')
            assert self._step == checkpoint['step'], 'step in checkpoint does not match step in argument'
            self._start_epoch = checkpoint['epoch']
            self._best_accuracy = checkpoint['best_accuracy']
            self._best_epoch = checkpoint['best_epoch']
            self._net.load_state_dict(checkpoint['state_dict'])
            self._optimizer.load_state_dict(checkpoint['optimizer'])
            self._scheduler.load_state_dict(checkpoint['scheduler'])
            self.memory_pool = checkpoint['memory_pool']
        else:
            print('---> no checkpoint loaded <---')
            if self._step == 2:
                # initialise from the best warm-up weights produced by step 1
                print('---> loading step1_best_epoch.pth <---')
                assert os.path.isfile('model/step1_best_epoch.pth')
                self._net.load_state_dict(torch.load('model/step1_best_epoch.pth'))
            self._start_epoch = 0
            self._best_accuracy = 0.0
            self._best_epoch = None
            # per-sample rolling history of losses/scores/labels
            self.memory_pool = Queue(n_samples=len(train_data),
                                     memory_length=config['memory_length'])
        # keep the cosine schedule aligned with the (possibly resumed) epoch
        self._scheduler.last_epoch = self._start_epoch

    def train(self):
        """Full training loop: one scheduler step + one training epoch +
        one evaluation per epoch, with checkpointing and best-model saving.

        On completion the log file is renamed to embed the run settings and
        the best test accuracy.
        """
        console_header = 'Epoch\tTrain_Loss\tTrain_Accuracy\tTest_Accuracy\tEpoch_Runtime\tLearning_Rate'
        print_to_console(console_header)
        print_to_logfile(self._logfile, console_header, init=True)

        for t in range(self._start_epoch, self._epochs):
            epoch_start = time.time()
            # NOTE(review): step(epoch=...) before the epoch relies on the
            # deprecated epoch argument of older torch schedulers -- confirm
            # against the pinned torch version.
            self._scheduler.step(epoch=t)
            # reset average meters
            self._train_loss.reset()
            self._train_accuracy.reset()

            self._net.train(True)
            self.single_epoch_training(t)
            test_accuracy = evaluate(self._test_loader, self._net)
            lr = get_lr_from_optimizer(self._optimizer)

            if test_accuracy > self._best_accuracy:
                self._best_accuracy = test_accuracy
                self._best_epoch = t + 1
                torch.save(self._net.state_dict(), 'model/step{}_best_epoch.pth'.format(self._step))
                # print('*', end='')
            epoch_end = time.time()
            single_epoch_runtime = epoch_end - epoch_start
            # Logging
            console_content = '{:05d}\t{:10.4f}\t{:14.4f}\t{:13.4f}\t{:13.2f}\t{:13.1e}'.format(
                t + 1, self._train_loss.avg, self._train_accuracy.avg, test_accuracy,
                single_epoch_runtime, lr)
            print_to_console(console_content)
            print_to_logfile(self._logfile, console_content, init=False)

            # save checkpoint (every epoch, for resume support)
            save_checkpoint({
                'epoch': t + 1,
                'state_dict': self._net.state_dict(),
                'best_epoch': self._best_epoch,
                'best_accuracy': self._best_accuracy,
                'optimizer': self._optimizer.state_dict(),
                'step': self._step,
                'scheduler': self._scheduler.state_dict(),
                'memory_pool': self.memory_pool,
            })

        console_content = 'Best at epoch {}, test accuracy is {}'.format(self._best_epoch, self._best_accuracy)
        print_to_console(console_content)

        # rename log file, stats files and model
        os.rename(self._logfile, self._logfile.replace('.txt', '-{}_{}_{}_{:.4f}.txt'.format(
            self._config['net'], self._config['batch_size'],
            self._config['lr'], self._best_accuracy)))

    def single_epoch_training(self, epoch, log_iter=True, log_freq=100):
        """Train for one epoch, refreshing the memory pool with per-sample
        CE losses / softmax scores and logging every `log_freq` iterations.

        :param epoch:    [int] current epoch index (0-based)
        :param log_iter: [bool] enable periodic iteration logging
        :param log_freq: [int] iterations between log lines
        """
        if epoch >= self.T_k:
            # past warm-up: start a fresh drop-and-reuse statistics CSV
            stats_log_path = 'stats/drop_n_reuse_stats_epoch{:03d}.csv'.format(epoch+1)
            stats_log_header = 'clean_sample_num,reusable_sample_num,irrelevant_sample_num'
            print_to_logfile(stats_log_path, stats_log_header, init=True, end='\n')
        for it, (x, y, indices) in enumerate(self._train_loader):
            s = time.time()
            x = x.cuda()
            y = y.cuda()

            self._optimizer.zero_grad()
            logits = self._net(x)

            losses, ce_loss = std_loss(logits, y, indices, self.T_k, epoch, self.memory_pool,
                                       eps=self._config['eps'])
            loss = losses.mean()

            # record this batch's per-sample losses/scores/labels for the
            # sample-selection logic in std_loss
            self.memory_pool.update(indices=indices, losses=ce_loss.detach().data.cpu(),
                                    scores=F.softmax(logits, dim=1).detach().data.cpu(),
                                    labels=y.detach().data.cpu())

            train_accuracy = accuracy(logits, y, topk=(1,))

            self._train_loss.update(loss.item(), x.size(0))
            self._train_accuracy.update(train_accuracy[0], x.size(0))

            loss.backward()
            self._optimizer.step()
            e = time.time()
            self._epoch_train_time.update(e-s, 1)
            if (log_iter and (it+1) % log_freq == 0) or (it+1 == len(self._train_loader)):
                console_content = 'Epoch:[{0:03d}/{1:03d}] Iter:[{2:04d}/{3:04d}] ' \
                                  'Train Accuracy :[{4:6.2f}] Loss:[{5:4.4f}] ' \
                                  'Iter Runtime:[{6:6.2f}]'.format(epoch + 1, self._epochs, it + 1,
                                                                   len(self._train_loader),
                                                                   self._train_accuracy.avg,
                                                                   self._train_loss.avg,
                                                                   self._epoch_train_time.avg)
                print_to_console(console_content)
def main(cfg, distributed=True):
    """Train a GLNet semantic-segmentation model (optionally distributed).

    Builds train/val dataloaders, the model (plus an optional frozen global
    branch for mode 3), optimizer, scheduler and loss from `cfg`; runs the
    epoch loop with periodic evaluation, TensorBoard logging, and
    best-checkpoint saving (rank 0 only).

    :param cfg: config object; attributes are read directly (mode, paths,
        batch sizes, loss choice, etc.)
    :param distributed: [bool] if True, initialise an NCCL process group and
        use a DistributedSampler for training
    """
    if distributed:
        # DPP 1
        dist.init_process_group('nccl')
        # DPP 2
        local_rank = dist.get_rank()
        print(local_rank)
        torch.cuda.set_device(local_rank)
        device = torch.device('cuda', local_rank)
    else:
        device = torch.device("cuda:0")
        local_rank = 0
    ###################################################
    mode = cfg.mode
    n_class = cfg.n_class
    model_path = cfg.model_path  # save model
    log_path = cfg.log_path
    output_path = cfg.output_path
    # Only rank 0 touches the filesystem for directory creation.
    if local_rank == 0:
        if not os.path.exists(model_path):
            os.makedirs(model_path)
        if not os.path.exists(log_path):
            os.makedirs(log_path)
        if not os.path.exists(output_path):
            os.makedirs(output_path)
    task_name = cfg.task_name
    print(task_name)
    ###################################
    print("preparing datasets and dataloaders......")
    batch_size = cfg.batch_size
    sub_batch_size = cfg.sub_batch_size
    size_g = (cfg.size_g, cfg.size_g)  # global-branch input size
    size_p = (cfg.size_p, cfg.size_p)  # local-patch input size
    num_workers = cfg.num_workers
    trainset_cfg = cfg.trainset_cfg
    valset_cfg = cfg.valset_cfg

    data_time = AverageMeter("DataTime", ':3.3f')
    batch_time = AverageMeter("BatchTime", ':3.3f')

    transformer_train = TransformerSegGL(crop_size=cfg.crop_size)
    dataset_train = OralDatasetSeg(
        trainset_cfg["img_dir"],
        trainset_cfg["mask_dir"],
        trainset_cfg["meta_file"],
        label=trainset_cfg["label"],
        transform=transformer_train,
    )
    if distributed:
        sampler_train = DistributedSampler(dataset_train, shuffle=True)
        dataloader_train = DataLoader(dataset_train,
                                      num_workers=num_workers,
                                      batch_size=batch_size,
                                      collate_fn=collateGL,
                                      sampler=sampler_train,
                                      pin_memory=True)
    else:
        dataloader_train = DataLoader(dataset_train,
                                      num_workers=num_workers,
                                      batch_size=batch_size,
                                      collate_fn=collateGL,
                                      shuffle=True,
                                      pin_memory=True)
    transformer_val = TransformerSegGLVal()
    dataset_val = OralDatasetSeg(valset_cfg["img_dir"],
                                 valset_cfg["mask_dir"],
                                 valset_cfg["meta_file"],
                                 label=valset_cfg["label"],
                                 transform=transformer_val)
    dataloader_val = DataLoader(dataset_val,
                                num_workers=2,
                                batch_size=batch_size,
                                collate_fn=collateGL,
                                shuffle=False,
                                pin_memory=True)

    ###################################
    print("creating models......")
    path_g = cfg.path_g
    path_g2l = cfg.path_g2l
    path_l2g = cfg.path_l2g
    model = GLNet(n_class, cfg.encoder, **cfg.model_cfg)
    if mode == 3:
        # mode 3 additionally needs a frozen global model for local training
        global_fixed = GLNet(n_class, cfg.encoder, **cfg.model_cfg)
    else:
        global_fixed = None
    model, global_fixed = create_model_load_weights(model,
                                                    global_fixed,
                                                    device,
                                                    mode=mode,
                                                    distributed=distributed,
                                                    local_rank=local_rank,
                                                    evaluation=False,
                                                    path_g=path_g,
                                                    path_g2l=path_g2l,
                                                    path_l2g=path_l2g)

    ###################################
    num_epochs = cfg.num_epochs
    learning_rate = cfg.lr
    optimizer = get_optimizer(model, mode, learning_rate=learning_rate)
    scheduler = LR_Scheduler(cfg.scheduler, learning_rate, num_epochs,
                             len(dataloader_train))
    ##################################
    # NOTE(review): an unrecognised cfg.loss value would leave `criterion`
    # unbound -- confirm cfg.loss is validated upstream.
    if cfg.loss == "ce":
        criterion = nn.CrossEntropyLoss(reduction='mean')
    elif cfg.loss == "sce":
        criterion = SymmetricCrossEntropyLoss(alpha=cfg.alpha,
                                              beta=cfg.beta,
                                              num_classes=cfg.n_class)
        # criterion4 = NormalizedSymmetricCrossEntropyLoss(alpha=cfg.alpha, beta=cfg.beta, num_classes=cfg.n_class)
    elif cfg.loss == "focal":
        criterion = FocalLoss(gamma=3)
    elif cfg.loss == "ce-dice":
        criterion = nn.CrossEntropyLoss(reduction='mean')
        # criterion2 =

    #######################################
    trainer = Trainer(criterion, optimizer, n_class, size_g, size_p,
                      sub_batch_size, mode, cfg.lamb_fmreg)
    evaluator = Evaluator(n_class, size_g, size_p, sub_batch_size, mode)

    evaluation = cfg.evaluation
    val_vis = cfg.val_vis
    best_pred = 0.0
    print("start training......")

    # log (rank 0 only)
    if local_rank == 0:
        # NOTE(review): this opens a hidden file literally named ".log" inside
        # log_path -- looks like it was meant to include the task name; confirm.
        f_log = open(os.path.join(log_path, ".log"), 'w')
        log = task_name + '\n'
        for k, v in cfg.__dict__.items():
            log += str(k) + ' = ' + str(v) + '\n'
        f_log.write(log)
        f_log.flush()
    # writer (rank 0 only)
    if local_rank == 0:
        writer = SummaryWriter(log_dir=log_path)
        writer_info = {}

    for epoch in range(num_epochs):
        trainer.set_train(model)
        optimizer.zero_grad()
        tbar = tqdm(dataloader_train)
        train_loss = 0
        start_time = time.time()
        for i_batch, sample in enumerate(tbar):
            data_time.update(time.time() - start_time)
            scheduler(optimizer, i_batch, epoch, best_pred)
            # loss = trainer.train(sample, model)
            loss = trainer.train(sample, model, global_fixed)
            train_loss += loss.item()
            score_train, score_train_global, score_train_local = trainer.get_scores(
            )

            batch_time.update(time.time() - start_time)
            start_time = time.time()
            if i_batch % 20 == 0 and local_rank == 0:
                if mode == 1:
                    tbar.set_description(
                        'Train loss: %.4f;global mIoU: %.4f; data time: %.2f; batch time: %.2f'
                        % (train_loss / (i_batch + 1),
                           score_train_global["iou_mean"], data_time.avg,
                           batch_time.avg))
                elif mode == 2:
                    tbar.set_description(
                        'Train loss: %.4f;agg mIoU: %.4f; local mIoU: %.4f; data time: %.2f; batch time: %.2f'
                        % (train_loss / (i_batch + 1),
                           score_train["iou_mean"],
                           score_train_local["iou_mean"], data_time.avg,
                           batch_time.avg))
                else:
                    # BUG FIX: key was "iouu_mean", which would raise KeyError
                    # on the first mode-3 progress update.
                    tbar.set_description(
                        'Train loss: %.4f;agg mIoU: %.4f; global mIoU: %.4f; local mIoU: %.4f; data time: %.2f; batch time: %.2f'
                        % (train_loss / (i_batch + 1),
                           score_train["iou_mean"],
                           score_train_global["iou_mean"],
                           score_train_local["iou_mean"], data_time.avg,
                           batch_time.avg))
        score_train, score_train_global, score_train_local = trainer.get_scores(
        )
        trainer.reset_metrics()
        data_time.reset()
        batch_time.reset()

        if evaluation and epoch % 1 == 0 and local_rank == 0:
            with torch.no_grad():
                model.eval()
                print("evaluating...")
                tbar = tqdm(dataloader_val)
                start_time = time.time()
                for i_batch, sample in enumerate(tbar):
                    data_time.update(time.time() - start_time)
                    predictions, predictions_global, predictions_local = evaluator.eval_test(
                        sample, model, global_fixed)
                    score_val, score_val_global, score_val_local = evaluator.get_scores(
                    )
                    batch_time.update(time.time() - start_time)
                    if i_batch % 20 == 0 and local_rank == 0:
                        if mode == 1:
                            tbar.set_description(
                                'global mIoU: %.4f; data time: %.2f; batch time: %.2f'
                                % (score_val_global["iou_mean"],
                                   data_time.avg, batch_time.avg))
                        elif mode == 2:
                            tbar.set_description(
                                'agg mIoU: %.4f; local mIoU: %.4f; data time: %.2f; batch time: %.2f'
                                % (score_val["iou_mean"],
                                   score_val_local["iou_mean"],
                                   data_time.avg, batch_time.avg))
                        else:
                            tbar.set_description(
                                'agg mIoU: %.4f; global mIoU: %.4f; local mIoU: %.4f; data time: %.2f; batch time: %.2f'
                                % (score_val["iou_mean"],
                                   score_val_global["iou_mean"],
                                   score_val_local["iou_mean"],
                                   data_time.avg, batch_time.avg))
                    if val_vis and i_batch == len(
                            tbar) // 2:  # val set result visualize
                        mask_rgb = class_to_RGB(np.array(sample['mask'][1]))
                        mask_rgb = ToTensor()(mask_rgb)
                        writer_info.update(mask=mask_rgb,
                                           prediction_global=ToTensor()(
                                               class_to_RGB(
                                                   predictions_global[1])))
                        if mode == 2 or mode == 3:
                            # BUG FIX: was writer.update(...) -- SummaryWriter
                            # has no update(); these tensors belong in the
                            # writer_info dict consumed by update_writer below.
                            writer_info.update(prediction=ToTensor()(class_to_RGB(
                                predictions[1])),
                                               prediction_local=ToTensor()(
                                                   class_to_RGB(
                                                       predictions_local[1])))
                    start_time = time.time()
                data_time.reset()
                batch_time.reset()
                score_val, score_val_global, score_val_local = evaluator.get_scores(
                )
                evaluator.reset_metrics()

            # save model
            best_pred = save_ckpt_model(model, cfg, score_val,
                                        score_val_global, best_pred, epoch)
            # log
            update_log(
                f_log, cfg,
                [score_train, score_train_global, score_train_local],
                [score_val, score_val_global, score_val_local], epoch)

            # writer
            if mode == 1:
                writer_info.update(
                    loss=train_loss / len(tbar),
                    lr=optimizer.param_groups[0]['lr'],
                    mIOU={
                        "train": score_train_global["iou_mean"],
                        "val": score_val_global["iou_mean"],
                    },
                    global_mIOU={
                        "train": score_train_global["iou_mean"],
                        "val": score_val_global["iou_mean"],
                    },
                    mucosa_iou={
                        "train": score_train_global["iou"][2],
                        "val": score_val_global["iou"][2],
                    },
                    tumor_iou={
                        "train": score_train_global["iou"][3],
                        "val": score_val_global["iou"][3],
                    },
                )
            else:
                writer_info.update(
                    loss=train_loss / len(tbar),
                    lr=optimizer.param_groups[0]['lr'],
                    mIOU={
                        "train": score_train["iou_mean"],
                        "val": score_val["iou_mean"],
                    },
                    global_mIOU={
                        "train": score_train_global["iou_mean"],
                        "val": score_val_global["iou_mean"],
                    },
                    local_mIOU={
                        "train": score_train_local["iou_mean"],
                        "val": score_val_local["iou_mean"],
                    },
                    mucosa_iou={
                        "train": score_train["iou"][2],
                        "val": score_val["iou"][2],
                    },
                    tumor_iou={
                        "train": score_train["iou"][3],
                        "val": score_val["iou"][3],
                    },
                )
            update_writer(writer, writer_info, epoch)
    if local_rank == 0:
        f_log.close()
def main():
    """Train DenseDeepGCN on the Bigred point-cloud dataset.

    Builds train/validation dataloaders, computes damped inverse-frequency
    class weights for the CrossEntropyLoss, then runs the epoch loop:
    per-batch training metrics are written to TensorBoard, a validation pass
    runs after every epoch, and a checkpoint package is saved each epoch
    (plus ``saves/best_model.pth`` whenever validation mIoU improves).
    """
    NUM_POINT = 20000  # points per cloud; used to normalise point accuracy
    opt = OptInit().initialize()
    opt.num_worker = 32
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpuNum

    opt.printer.info('===> Creating dataloader ...')
    train_dataset = BigredDataset(root=opt.train_path,
                                  is_train=True,
                                  is_validation=False,
                                  is_test=False,
                                  num_channel=opt.num_channel,
                                  pre_transform=T.NormalizeScale())
    train_loader = DenseDataLoader(train_dataset,
                                   batch_size=opt.batch_size,
                                   shuffle=False,
                                   num_workers=opt.num_worker)
    validation_dataset = BigredDataset(root=opt.train_path,
                                       is_train=False,
                                       is_validation=True,
                                       is_test=False,
                                       num_channel=opt.num_channel,
                                       pre_transform=T.NormalizeScale())
    validation_loader = DenseDataLoader(validation_dataset,
                                        batch_size=opt.batch_size,
                                        shuffle=True,
                                        num_workers=opt.num_worker)

    opt.printer.info('===> computing Labelweight ...')
    # Inverse-frequency class weighting, cube-root damped (2 classes, so the
    # histogram uses bin edges range(3)).
    labelweights, _ = np.histogram(train_dataset.data.y.numpy(), range(3))
    labelweights = labelweights.astype(np.float32)
    labelweights = labelweights / np.sum(labelweights)
    labelweights = np.power(np.amax(labelweights) / labelweights, 1 / 3.0)
    weights = torch.Tensor(labelweights).cuda()
    print("labelweights", weights)
    opt.n_classes = train_loader.dataset.num_classes

    opt.printer.info('===> Loading the network ...')
    opt.best_value = 0
    print("GPU:", opt.device)
    model = DenseDeepGCN(opt).to(opt.device)
    if opt.multi_gpus:
        model = DataParallel(DenseDeepGCN(opt)).to(device=opt.device)
    opt.printer.info('===> loading pre-trained ...')
    # model, opt.best_value, opt.epoch = load_pretrained_models(model, opt.pretrained_model, opt.phase)

    opt.printer.info('===> Init the optimizer ...')
    criterion = torch.nn.CrossEntropyLoss(weight=weights).to(opt.device)
    # criterion_test = torch.nn.CrossEntropyLoss(weight = weights)
    if opt.optim.lower() == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    elif opt.optim.lower() == 'radam':
        optimizer = optim.RAdam(model.parameters(), lr=opt.lr)
    else:
        raise NotImplementedError('opt.optim is not supported')
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                opt.lr_adjust_freq,
                                                opt.lr_decay_rate)
    # optimizer, scheduler, opt.lr = load_pretrained_optimizer(opt.pretrained_model, optimizer, scheduler, opt.lr)

    opt.printer.info('===> Init Metric ...')
    opt.losses = AverageMeter()
    # opt.test_metric = miou
    opt.test_values = AverageMeter()
    opt.test_value = 0.

    opt.printer.info('===> start training ...')
    writer = SummaryWriter()
    writer_test = SummaryWriter()
    counter_test = 0
    counter_play = 0
    start_epoch = 0
    mean_miou = AverageMeter()
    mean_loss = AverageMeter()
    mean_acc = AverageMeter()
    best_value = 0
    # FIX: checkpoints are written into 'saves/'; create it up front so the
    # first torch.save() cannot fail on a fresh checkout.
    os.makedirs('saves', exist_ok=True)

    for epoch in range(start_epoch, opt.total_epochs):
        opt.epoch += 1
        model.train()
        total_seen_class = [0 for _ in range(opt.n_classes)]
        total_correct_class = [0 for _ in range(opt.n_classes)]
        total_iou_deno_class = [0 for _ in range(opt.n_classes)]
        ave_mIoU = 0
        total_correct = 0
        total_seen = 0
        loss_sum = 0
        mean_miou.reset()
        mean_loss.reset()
        mean_acc.reset()
        for i, data in tqdm(enumerate(train_loader),
                            total=len(train_loader), smoothing=0.9):
            opt.iter += 1
            if not opt.multi_gpus:
                data = data.to(opt.device)
            target = data.y
            batch_label2 = target.cpu().data.numpy()
            # Build a (B, C, N, 1) dense input: xyz first, then the extra
            # feature channels, truncated to the configured channel count.
            inputs = torch.cat((data.pos.transpose(2, 1).unsqueeze(3),
                                data.x.transpose(2, 1).unsqueeze(3)), 1)
            inputs = inputs[:, :opt.num_channel, :, :]
            gt = data.y.to(opt.device)

            # ------------------ zero, output, loss
            optimizer.zero_grad()
            out = model(inputs)
            loss = criterion(out, gt)
            # ------------------ optimization
            loss.backward()
            optimizer.step()

            seg_pred = out.transpose(2, 1)
            pred_val = seg_pred.contiguous().cpu().data.numpy()
            seg_pred = seg_pred.contiguous().view(-1, opt.n_classes)
            pred_val = np.argmax(pred_val, 2)
            batch_label = target.view(-1, 1)[:, 0].cpu().data.numpy()
            target = target.view(-1, 1)[:, 0]
            pred_choice = seg_pred.cpu().data.max(1)[1].numpy()
            correct = np.sum(pred_choice == batch_label)
            total_correct += correct
            total_seen += (opt.batch_size * NUM_POINT)
            loss_sum += loss.item()

            # Per-batch (current_*) and running (total_*) per-class counters
            # for the IoU computation below.
            current_seen_class = [0 for _ in range(opt.n_classes)]
            current_correct_class = [0 for _ in range(opt.n_classes)]
            current_iou_deno_class = [0 for _ in range(opt.n_classes)]
            for l in range(opt.n_classes):
                total_seen_class[l] += np.sum((batch_label2 == l))
                total_correct_class[l] += np.sum((pred_val == l) & (batch_label2 == l))
                total_iou_deno_class[l] += np.sum(((pred_val == l) | (batch_label2 == l)))
                current_seen_class[l] = np.sum((batch_label2 == l))
                current_correct_class[l] = np.sum((pred_val == l) & (batch_label2 == l))
                current_iou_deno_class[l] = np.sum(((pred_val == l) | (batch_label2 == l)))

            writer.add_scalar('training_loss', loss.item(), counter_play)
            writer.add_scalar('training_accuracy',
                              correct / float(opt.batch_size * NUM_POINT),
                              counter_play)
            # FIX: np.float was deprecated in NumPy 1.20 and removed in 1.24;
            # the builtin float is the documented replacement.
            m_iou = np.mean(np.array(current_correct_class) /
                            (np.array(current_iou_deno_class, dtype=float) + 1e-6))
            writer.add_scalar('training_mIoU', m_iou, counter_play)
            ave_mIoU = np.mean(np.array(total_correct_class) /
                               (np.array(total_iou_deno_class, dtype=float) + 1e-6))
            mean_miou.update(m_iou)
            mean_loss.update(loss.item())
            mean_acc.update(correct / float(opt.batch_size * NUM_POINT))
            counter_play = counter_play + 1

        train_mIoU = mean_miou.avg
        train_macc = mean_acc.avg
        train_mloss = mean_loss.avg
        print('Epoch: %d, Training point avg class IoU: %f' % (epoch, train_mIoU))
        print('Epoch: %d, Training mean loss: %f' % (epoch, train_mloss))
        print('Epoch: %d, Training accuracy: %f' % (epoch, train_macc))
        mean_miou.reset()
        mean_loss.reset()
        mean_acc.reset()

        print('validation_loader')
        model.eval()
        with torch.no_grad():
            for i, data in tqdm(enumerate(validation_loader),
                                total=len(validation_loader), smoothing=0.9):
                if not opt.multi_gpus:
                    data = data.to(opt.device)
                target = data.y
                batch_label2 = target.cpu().data.numpy()
                inputs = torch.cat((data.pos.transpose(2, 1).unsqueeze(3),
                                    data.x.transpose(2, 1).unsqueeze(3)), 1)
                inputs = inputs[:, :opt.num_channel, :, :]
                gt = data.y.to(opt.device)
                out = model(inputs)
                loss = criterion(out, gt)
                seg_pred = out.transpose(2, 1)
                pred_val = seg_pred.contiguous().cpu().data.numpy()
                seg_pred = seg_pred.contiguous().view(-1, opt.n_classes)
                pred_val = np.argmax(pred_val, 2)
                batch_label = target.view(-1, 1)[:, 0].cpu().data.numpy()
                target = target.view(-1, 1)[:, 0]
                pred_choice = seg_pred.cpu().data.max(1)[1].numpy()
                correct = np.sum(pred_choice == batch_label)
                current_seen_class = [0 for _ in range(opt.n_classes)]
                current_correct_class = [0 for _ in range(opt.n_classes)]
                current_iou_deno_class = [0 for _ in range(opt.n_classes)]
                for l in range(opt.n_classes):
                    current_seen_class[l] = np.sum((batch_label2 == l))
                    current_correct_class[l] = np.sum((pred_val == l) & (batch_label2 == l))
                    current_iou_deno_class[l] = np.sum(((pred_val == l) | (batch_label2 == l)))
                # FIX: same np.float -> float replacement as the training loop.
                m_iou = np.mean(np.array(current_correct_class) /
                                (np.array(current_iou_deno_class, dtype=float) + 1e-6))
                mean_miou.update(m_iou)
                mean_loss.update(loss.item())
                mean_acc.update(correct / float(opt.batch_size * NUM_POINT))

        validation_mIoU = mean_miou.avg
        validation_macc = mean_acc.avg
        validation_mloss = mean_loss.avg
        writer.add_scalar('validation_loss', validation_mloss, epoch)
        print('Epoch: %d, validation mean loss: %f' % (epoch, validation_mloss))
        writer.add_scalar('validation_accuracy', validation_macc, epoch)
        print('Epoch: %d, validation accuracy: %f' % (epoch, validation_macc))
        writer.add_scalar('validation_mIoU', validation_mIoU, epoch)
        print('Epoch: %d, validation point avg class IoU: %f' % (epoch, validation_mIoU))

        # Package a CPU copy of the weights so the checkpoint can be loaded
        # on a machine without a GPU.
        model_cpu = {k: v.cpu() for k, v in model.state_dict().items()}
        package = {
            'epoch': opt.epoch,
            'state_dict': model_cpu,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_miou': train_mIoU,
            'train_accuracy': train_macc,
            'train_loss': train_mloss,
            'validation_mIoU': validation_mIoU,
            'validation_macc': validation_macc,
            'validation_mloss': validation_mloss,
            'num_channel': opt.num_channel,
            'gpuNum': opt.gpuNum,
            'time': time.ctime(),
        }
        torch.save(package, 'saves/val_miou_%f_val_acc_%f_%d.pth'
                   % (validation_mIoU, validation_macc, epoch))
        is_best = (best_value < validation_mIoU)
        print('Is Best? ', is_best)
        if best_value < validation_mIoU:
            best_value = validation_mIoU
            torch.save(package, 'saves/best_model.pth')
        print('Best IoU: %f' % (best_value))
        scheduler.step()

    opt.printer.info('Saving the final model.Finish!')
train_loss += loss.item() scores_train = trainer.get_scores() batch_time.update(time.time() - start_time) start_time = time.time() if i_batch % 10 == 0 and local_rank == 0: tbar.set_description( 'Train loss: %.4f; mIoU: %.4f; data time: %.2f; batch time: %.2f' % (train_loss / (i_batch + 1), scores_train["iou_mean"], data_time.avg, batch_time.avg)) if local_rank == 0: writer.add_scalar('loss', train_loss / len(tbar), epoch) trainer.reset_metrics() data_time.reset() batch_time.reset() if epoch % 1 == 0 and local_rank == 0: with torch.no_grad(): model.eval() print("evaluating...") if test: tbar = tqdm(dataloader_test) else: tbar = tqdm(dataloader_val) start_time = time.time() for i_batch, sample in enumerate(tbar): data_time.update(time.time() - start_time)
def train_one_epoch(self, epoch):
    """Run one training epoch over ``self.train_loader`` and, when
    ``self.use_val`` is set, one validation pass over ``self.val_loader``.

    :param epoch: epoch index; used only in the wandb log payload.
    :return: tuple ``(train_errors.avg, train_losses.avg)``. Both meters are
        reset every ``self.config.log_freq`` batches, so the returned values
        cover the trailing window since the last reset, not the whole epoch
        — NOTE(review): this mirrors the original behaviour; confirm it is
        intended.
    """
    import torch  # for torch.no_grad() below; harmless if already imported

    train_errors = AverageMeter()
    train_losses = AverageMeter()
    train_iter = tqdm.tqdm(self.train_loader,
                           desc='Train Epoch',
                           total=self.n_batch_train,
                           leave=False)
    self.model.train()
    for i, batch in enumerate(train_iter):
        image = batch['image']
        gaze = batch['gaze']
        if self.pose_mode:
            pose = batch['pose']
            out = self.model(image, pose)
        else:
            out = self.model(image)
        num = image.size()[0]
        # Mean angular error over the batch (units per angular_error's contract).
        gaze_error_batch = np.mean(
            angular_error(out.cpu().data.numpy(), gaze.cpu().data.numpy()))
        train_errors.update(gaze_error_batch.item(), num)
        loss_gaze = self.criterion(out, gaze)
        self.optimizer.zero_grad()
        # accelerator handles the backward pass (mixed precision / distributed).
        accelerator.backward(loss_gaze)
        self.optimizer.step()
        train_losses.update(loss_gaze.item(), num)
        if i % self.config.log_freq == 0:
            if self.config.wandb:
                wandb.log({
                    'epoch': epoch,
                    "batch": i,
                    "Train Errors": train_errors.avg,
                    "Train Losses": train_losses.avg
                })
            postfix = {'Error': train_errors.avg, 'Loss': train_losses.avg}
            train_iter.set_postfix(postfix)
            train_errors.reset()
            train_losses.reset()

    if self.use_val:
        self.model.eval()
        val_errors = AverageMeter()
        val_losses = AverageMeter()
        val_iter = tqdm.tqdm(self.val_loader,
                             desc='Val',
                             total=self.n_batch_val,
                             leave=False)
        # FIX: the original ran validation with gradient tracking enabled
        # (only model.eval(), no no_grad), building autograd graphs for
        # every forward pass and wasting memory. Outputs are unchanged.
        with torch.no_grad():
            for i, batch in enumerate(val_iter):
                image = batch['image']
                gaze = batch['gaze']
                if self.pose_mode:
                    pose = batch['pose']
                    out = self.model(image, pose)
                else:
                    out = self.model(image)
                num = image.size()[0]
                gaze_error_batch = np.mean(
                    angular_error(out.cpu().data.numpy(), gaze.cpu().data.numpy()))
                val_errors.update(gaze_error_batch.item(), num)
                loss_gaze = self.criterion(out, gaze)
                val_losses.update(loss_gaze.item(), num)
                if i % self.config.log_freq == 0:
                    postfix = {'Error': val_errors.avg, 'Loss': val_losses.avg}
                    val_iter.set_postfix(postfix)
        if self.config.wandb:
            wandb.log({
                'epoch': epoch,
                "Val Errors": val_errors.avg,
                "Val Losses": val_losses.avg
            })
    return train_errors.avg, train_losses.avg
def train(self, dataset_train, dataset_val, criterion, optimizer_func,
          trainer_func, evaluator_func, collate,
          dataset_test=None, tester_func=None):
    """Full training driver: builds dataloaders, trains for
    ``self.cfg.num_epochs`` epochs, and (on rank 0) evaluates on the
    validation set each epoch, optionally tests slide-by-slide on
    ``dataset_test``, checkpoints the best model, and logs to file and
    TensorBoard.

    :param dataset_train/dataset_val: map-style datasets for train/val.
    :param criterion: loss passed through to ``trainer_func``.
    :param optimizer_func/trainer_func/evaluator_func: factories for the
        optimizer, trainer, and evaluator objects.
    :param collate: collate_fn for all DataLoaders.
    :param dataset_test: optional slide-level test dataset; requires
        ``tester_func`` to also be supplied.
    :param tester_func: factory for the slide-level tester.
    """
    if self.distributed:
        sampler_train = DistributedSampler(dataset_train, shuffle=True)
        dataloader_train = DataLoader(dataset_train,
                                      num_workers=self.cfg.num_workers,
                                      batch_size=self.cfg.batch_size,
                                      collate_fn=collate,
                                      sampler=sampler_train,
                                      pin_memory=True)
    else:
        dataloader_train = DataLoader(dataset_train,
                                      num_workers=self.cfg.num_workers,
                                      batch_size=self.cfg.batch_size,
                                      collate_fn=collate,
                                      shuffle=True,
                                      pin_memory=True)
    dataloader_val = DataLoader(dataset_val,
                                num_workers=self.cfg.num_workers,
                                batch_size=self.cfg.batch_size,
                                collate_fn=collate,
                                shuffle=False,
                                pin_memory=True)

    ###################################
    print("creating models......")
    model = self.model_loader(self.model, self.device,
                              distributed=self.distributed,
                              local_rank=self.local_rank,
                              evaluation=True,
                              ckpt_path=self.cfg.ckpt_path)

    ###################################
    num_epochs = self.cfg.num_epochs
    learning_rate = self.cfg.lr
    data_time = AverageMeter("DataTime", ':3.3f')
    batch_time = AverageMeter("BatchTime", ':3.3f')
    optimizer = optimizer_func(model, learning_rate=learning_rate)
    scheduler = LR_Scheduler(self.cfg.scheduler, learning_rate,
                             num_epochs, len(dataloader_train))

    ##################################
    trainer = trainer_func(criterion, optimizer, self.cfg.n_class)
    evaluator = evaluator_func(self.cfg.n_class)
    if tester_func:
        tester = tester_func(self.cfg.n_class, self.cfg.num_workers,
                             self.cfg.batch_size)
    evaluation = self.cfg.evaluation
    val_vis = self.cfg.val_vis
    best_pred = 0.0
    print("start training......")

    # log (rank 0 only)
    if self.local_rank == 0:
        f_log = open(self.cfg.log_path + self.cfg.task_name + ".log", 'w')
        log = self.cfg.task_name + '\n'
        for k, v in self.cfg.__dict__.items():
            log += str(k) + ' = ' + str(v) + '\n'
        print(log)
        f_log.write(log)
        f_log.flush()
    # writer (rank 0 only)
    if self.local_rank == 0:
        writer = SummaryWriter(log_dir=self.cfg.writer_path)
    writer_info = {}

    for epoch in range(num_epochs):
        optimizer.zero_grad()
        num_batch = len(dataloader_train)
        tbar = tqdm(dataloader_train)
        train_loss = 0
        start_time = time.time()
        model.train()
        for i_batch, sample in enumerate(tbar):
            data_time.update(time.time() - start_time)
            scheduler(optimizer, i_batch, epoch, best_pred)
            if self.distributed:
                loss = trainer.train(sample, model)
            else:
                loss = trainer.train_acc(sample, model, i_batch, 2, num_batch)
            train_loss += loss.item()
            scores_train = trainer.get_scores()
            batch_time.update(time.time() - start_time)
            start_time = time.time()
            if i_batch % 20 == 0 and self.local_rank == 0:
                tbar.set_description(
                    'Train loss: %.4f; mIoU: %.4f; data time: %.2f; batch time: %.2f'
                    % (train_loss / (i_batch + 1), scores_train["iou_mean"],
                       data_time.avg, batch_time.avg))
        trainer.reset_metrics()
        data_time.reset()
        batch_time.reset()
        train_model_fr, train_seg_fr = trainer.calculate_avg_fr()

        if evaluation and epoch % 1 == 0 and self.local_rank == 0:
            # FIX: default the test-set results so the logging/writer code
            # below cannot raise NameError when dataset_test is None.
            scores_test = None
            test_model_fr = None
            test_seg_fr = None
            with torch.no_grad():
                model.eval()
                ##--** evaluating **--
                print("evaluating...")
                tbar = tqdm(dataloader_val)
                start_time = time.time()
                for i_batch, sample in enumerate(tbar):
                    data_time.update(time.time() - start_time)
                    predictions = evaluator.eval(sample, model)
                    scores_val = evaluator.get_scores()
                    batch_time.update(time.time() - start_time)
                    if i_batch % 20 == 0 and self.local_rank == 0:
                        tbar.set_description(
                            'mIoU: %.4f; data time: %.2f; batch time: %.2f'
                            % (scores_val["iou_mean"], data_time.avg,
                               batch_time.avg))
                    if val_vis and (1 + epoch) % 10 == 0:
                        # val set result visualize, grouped per slide
                        for i in range(len(sample['id'])):
                            name = sample['id'][i] + '.png'
                            slide = name.split('_')[0]
                            slide_dir = os.path.join(
                                self.cfg.val_output_path, slide)
                            if not os.path.exists(slide_dir):
                                os.makedirs(slide_dir)
                            predictions_rgb = class_to_RGB(predictions[i])
                            predictions_rgb = cv2.cvtColor(
                                predictions_rgb, cv2.COLOR_BGR2RGB)
                            cv2.imwrite(os.path.join(slide_dir, name),
                                        predictions_rgb)
                    start_time = time.time()
                data_time.reset()
                batch_time.reset()
                scores_val = evaluator.get_scores()
                evaluator.reset_metrics()
                val_model_fr, val_seg_fr = evaluator.calculate_avg_fr()

                ##--** testing **--
                if dataset_test:
                    print("testing...")
                    num_slides = len(dataset_test.slides)
                    tbar2 = tqdm(range(num_slides))
                    start_time = time.time()
                    for i in tbar2:
                        dataset_test.get_patches_from_index(i)
                        data_time.update(time.time() - start_time)
                        predictions, output, _ = tester.inference(
                            dataset_test, model)
                        mask = dataset_test.get_slide_mask_from_index(i)
                        tester.update_scores(mask, predictions)
                        scores_test = tester.get_scores()
                        batch_time.update(time.time() - start_time)
                        tbar2.set_description(
                            'mIoU: %.4f; data time: %.2f; slide time: %.2f'
                            % (scores_test["iou_mean"], data_time.avg,
                               batch_time.avg))
                        output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
                        cv2.imwrite(
                            os.path.join(self.cfg.test_output_path,
                                         dataset_test.slide + '.png'),
                            output)
                        start_time = time.time()
                    data_time.reset()
                    batch_time.reset()
                    scores_test = tester.get_scores()
                    tester.reset_metrics()
                    test_model_fr, test_seg_fr = tester.calculate_avg_fr()

                # save model
                best_pred = save_ckpt_model(model, self.cfg, scores_val,
                                            best_pred, epoch)
                # log (test entries are None when no test set was supplied)
                update_log(f_log, self.cfg, scores_train, scores_val,
                           [train_model_fr, train_seg_fr],
                           [val_model_fr, val_seg_fr],
                           epoch,
                           scores_test=scores_test,
                           test_fr=[test_model_fr, test_seg_fr])

                # writer
                def _metric(train_v, val_v, test_v=None):
                    # One scalar group per metric; include "test" only when
                    # a test run actually produced a value.
                    d = {"train": train_v, "val": val_v}
                    if test_v is not None:
                        d["test"] = test_v
                    return d

                has_test = scores_test is not None
                if self.cfg.n_class == 4:
                    writer_info.update(
                        loss=train_loss / len(tbar),
                        lr=optimizer.param_groups[0]['lr'],
                        mIOU=_metric(scores_train["iou_mean"],
                                     scores_val["iou_mean"],
                                     scores_test["iou_mean"] if has_test else None),
                        mucosa_iou=_metric(scores_train["iou"][2],
                                           scores_val["iou"][2],
                                           scores_test["iou"][2] if has_test else None),
                        tumor_iou=_metric(scores_train["iou"][3],
                                          scores_val["iou"][3],
                                          scores_test["iou"][3] if has_test else None),
                        mucosa_model_fr=_metric(train_model_fr[0],
                                                val_model_fr[0],
                                                test_model_fr[0] if has_test else None),
                        # FIX: the "test" entry previously read val_model_fr[1]
                        # (copy-paste), logging the val value under "test".
                        tumor_model_fr=_metric(train_model_fr[1],
                                               val_model_fr[1],
                                               test_model_fr[1] if has_test else None),
                        mucosa_seg_fr=_metric(train_seg_fr[0],
                                              val_seg_fr[0],
                                              test_seg_fr[0] if has_test else None),
                        tumor_seg_fr=_metric(train_seg_fr[1],
                                             val_seg_fr[1],
                                             test_seg_fr[1] if has_test else None))
                else:
                    writer_info.update(
                        loss=train_loss / len(tbar),
                        lr=optimizer.param_groups[0]['lr'],
                        mIOU=_metric(scores_train["iou_mean"],
                                     scores_val["iou_mean"],
                                     scores_test["iou_mean"] if has_test else None),
                        merge_iou=_metric(scores_train["iou"][2],
                                          scores_val["iou"][2],
                                          scores_test["iou"][2] if has_test else None),
                        merge_model_fr=_metric(train_model_fr[0],
                                               val_model_fr[0],
                                               test_model_fr[0] if has_test else None),
                        # FIX: the "test" entry previously read val_seg_fr[0]
                        # (copy-paste), logging the val value under "test".
                        merge_seg_fr=_metric(train_seg_fr[0],
                                             val_seg_fr[0],
                                             test_seg_fr[0] if has_test else None))
                update_writer(writer, writer_info, epoch)

    if self.local_rank == 0:
        f_log.close()