class Trainer(object):
    """Train and validate a ConsistentDeepLab model on pairs of images.

    Every loader item is ``(sample1, sample2, proxy_label, sample_indices)``;
    the model is fed ``sample1['image']`` and ``sample2['image']`` and is
    supervised with ``proxy_label``.
    """

    def __init__(self, args):
        """Set up logging, data, model, optimizer, loss, scheduler and resume.

        Args:
            args: parsed command-line options (dataset, workers, backbone,
                lr, weight_decay, loss_type, cuda, gpu_ids, resume, ft, ...).

        Raises:
            NotImplementedError: if ``args.dataset`` is neither
                ``'pairwise_lits'`` nor ``'pairwise_chaos'``.
            RuntimeError: if ``args.resume`` is set but is not a file.
        """
        self.args = args

        # Define Saver and tee stdout into a timestamped log file in the
        # experiment directory so the console output of the run is kept.
        self.saver = Saver(args)
        sys.stdout = Logger(
            os.path.join(
                self.saver.experiment_dir,
                'log_train-%s.txt' % time.strftime("%Y-%m-%d-%H-%M-%S")))
        self.saver.save_experiment_config()

        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Number of output channels predicted for an image pair.
        if args.dataset == 'pairwise_lits':
            # NOTE(review): the chained assignment also overwrites
            # self.nclass with 3, which feeds the Evaluator and the channel
            # slicing in training/validation. Confirm this is intended.
            proxy_nclasses = self.nclass = 3
        elif args.dataset == 'pairwise_chaos':
            proxy_nclasses = 2 * self.nclass
        else:
            raise NotImplementedError

        # Define network
        model = ConsistentDeepLab(in_channels=3,
                                  num_classes=proxy_nclasses,
                                  pretrained=args.pretrained,
                                  backbone=args.backbone,
                                  output_stride=args.out_stride,
                                  sync_bn=args.sync_bn,
                                  freeze_bn=args.freeze_bn)

        # The second parameter group trains with ten times the base lr.
        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer: Adam, with per-group learning rates taken from
        # train_params.
        optimizer = torch.optim.Adam(train_params,
                                     weight_decay=args.weight_decay)

        # Define Criterion, optionally with class-balanced weights.
        if args.use_balanced_weights:
            weights = calculate_weigths_labels(args.dataset,
                                               self.train_loader,
                                               proxy_nclasses)
        else:
            weights = None
        print("Initializing loss: {}".format(args.loss_type))
        self.criterion = losses.init_loss(args.loss_type, weights=weights)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        # Using cuda: wrap in DataParallel only in this case, so code that
        # needs the bare network must go through _unwrap_model().
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            # Map GPU-saved tensors onto the CPU when CUDA is disabled.
            # NOTE: torch.load unpickles the file; only resume from trusted
            # checkpoints.
            checkpoint = torch.load(args.resume,
                                    map_location=None if args.cuda else 'cpu')
            args.start_epoch = checkpoint['epoch']
            self._unwrap_model(self.model).load_state_dict(
                checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    @staticmethod
    def _vis_interval(num_iters):
        """Return the iteration stride for ~10 visualizations per epoch.

        Never returns 0, so loaders with fewer than 10 batches do not raise
        ZeroDivisionError in the modulo check.
        """
        return max(num_iters // 10, 1)

    @staticmethod
    def _unwrap_model(model):
        """Return the network inside a DataParallel wrapper, else ``model``."""
        if isinstance(model, torch.nn.DataParallel):
            return model.module
        return model

    def _sync_cuda(self):
        """Block until queued CUDA work finishes (no-op without CUDA)."""
        if self.args.cuda:
            torch.cuda.synchronize()

    def _save_checkpoint(self, epoch, is_best):
        """Save model/optimizer state for ``epoch`` via the Saver.

        The stored epoch is ``epoch + 1``, i.e. the epoch to resume from.
        """
        self.saver.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': self._unwrap_model(self.model).state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def training(self, epoch):
        """Run one training epoch over ``self.train_loader``.

        Logs per-iteration and per-epoch loss to TensorBoard, visualizes
        predictions about ten times per epoch and, when ``args.no_val`` is
        set, saves a checkpoint at the end of the epoch.
        """
        train_loss = 0.0
        num_images = 0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        vis_interval = self._vis_interval(num_img_tr)
        for i, (sample1, sample2, proxy_label, _) in enumerate(tbar):
            image1 = sample1['image']
            image2 = sample2['image']
            if self.args.cuda:
                image1, image2 = image1.cuda(), image2.cuda()
                proxy_label = proxy_label.cuda()
            num_images += image1.shape[0]

            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image1, image2)
            loss = self.criterion(output, proxy_label)
            loss.backward()
            self.optimizer.step()

            loss_value = loss.item()
            train_loss += loss_value
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss_value,
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % vis_interval == 0:
                global_step = i + num_img_tr * epoch
                # Join the pair along dim -2 (height, assuming NCHW) so both
                # images appear in a single visualization.
                image = torch.cat((image1, image2), dim=-2)
                vis_output, vis_label = output, proxy_label
                if len(proxy_label.shape) > 3:
                    # Assumes channel-wise proxy labels (N, C, H, W); only
                    # the first nclass channels are visualized -- confirm.
                    vis_output = output[:, 0:self.nclass]
                    vis_label = torch.argmax(proxy_label[:, 0:self.nclass],
                                             dim=1)
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, vis_label, vis_output,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            self._save_checkpoint(epoch, is_best=False)

    def validation(self, epoch):
        """Evaluate on ``self.val_loader``, log metrics, save best checkpoint.

        When ``args.save_predict`` is set, each batch's predictions are passed
        to ``saver.save_predict_mask`` and the base names of the validation
        ``data1_files`` are pickled to ``<save_dir>/namelist.pkl.gz``. A
        checkpoint is saved whenever mIoU improves on ``self.best_pred``.
        """
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        val_time = 0.0
        num_images = 0
        for i, (sample1, sample2, proxy_label, sample_indices) in enumerate(tbar):
            image1 = sample1['image']
            image2 = sample2['image']
            if self.args.cuda:
                image1, image2 = image1.cuda(), image2.cuda()
                proxy_label = proxy_label.cuda()
            num_images += image1.shape[0]

            with torch.no_grad():
                # CUDA kernels run asynchronously: synchronize around the
                # forward pass so val_time measures the actual compute.
                self._sync_cuda()
                start = time.perf_counter()
                output = self.model(image1, image2, is_val=True)
                self._sync_cuda()
                val_time += time.perf_counter() - start
                loss = self.criterion(output, proxy_label)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))

            pred = output.data.cpu().numpy()
            proxy_label = proxy_label.cpu().numpy()
            # Add batch sample into evaluator
            if len(proxy_label.shape) > 3:
                # Channel-wise labels: score only the first nclass channels.
                pred = np.argmax(pred[:, 0:self.nclass], axis=1)
                proxy_label = np.argmax(proxy_label[:, 0:self.nclass], axis=1)
            else:
                pred = np.argmax(pred, axis=1)
            self.evaluator.add_batch(proxy_label, pred)
            if self.args.save_predict:
                self.saver.save_predict_mask(
                    pred, sample_indices, self.val_loader.dataset.data1_files)

        print("Val time: {}".format(val_time))
        print("Total paramerters: {}".format(
            sum(x.numel() for x in self.model.parameters())))
        if self.args.save_predict:
            # File name of every data1 file, truncated at the first '.'.
            namelist = [
                os.path.basename(fname).split('.')[0]
                for fname in self.val_loader.dataset.data1_files
            ]
            with gzip.open(
                    os.path.join(self.saver.save_dir, 'namelist.pkl.gz'),
                    'wb') as fh:
                pickle.dump(namelist, fh, protocol=-1)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)
        dice = self.evaluator.Dice()
        self.writer.add_scalar('val/Dice_2', dice[2], epoch)
        print("Dice:{}".format(dice))

        new_pred = mIoU
        if new_pred > self.best_pred:
            self.best_pred = new_pred
            self._save_checkpoint(epoch, is_best=True)
class Trainer(object):
    """Train and validate a single-image segmentation model (DeepLab or UNet)."""

    def __init__(self, args):
        """Set up saver, data, model, optimizer, loss, scheduler and resume.

        Args:
            args: parsed command-line options (model, dataset, workers,
                backbone, lr, momentum, weight_decay, nesterov, loss_type,
                cuda, gpu_ids, resume, ft, ...).

        Raises:
            NotImplementedError: if ``args.model`` is not ``'deeplab'`` or
                ``'unet'``.
            RuntimeError: if ``args.resume`` is set but is not a file.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        if args.model == 'deeplab':
            model = DeepLab(num_classes=self.nclass,
                            backbone=args.backbone,
                            output_stride=args.out_stride,
                            sync_bn=args.sync_bn,
                            freeze_bn=args.freeze_bn)
            # The second parameter group trains with ten times the base lr.
            model_params = [{
                'params': model.get_1x_lr_params(),
                'lr': args.lr
            }, {
                'params': model.get_10x_lr_params(),
                'lr': args.lr * 10
            }]
        elif args.model == 'unet':
            model = UNet(n_channels=3, n_classes=self.nclass)
            model_params = model.parameters()
        else:
            # Without this branch an unknown model fell through to a
            # NameError on model_params below.
            raise NotImplementedError(
                "Unsupported model: '{}'".format(args.model))

        # Define Optimizer
        optimizer = torch.optim.SGD(model_params,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights (cached as a .npy file next
        # to the dataset when available, otherwise computed from the loader)
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        # Using cuda: wrap in DataParallel only in this case, so code that
        # needs the bare network must go through _unwrap_model().
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            # Map GPU-saved tensors onto the CPU when CUDA is disabled.
            # NOTE: torch.load unpickles the file; only resume from trusted
            # checkpoints.
            checkpoint = torch.load(args.resume,
                                    map_location=None if args.cuda else 'cpu')
            args.start_epoch = checkpoint['epoch']
            self._unwrap_model(self.model).load_state_dict(
                checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    @staticmethod
    def _vis_interval(num_iters):
        """Return the iteration stride for ~10 visualizations per epoch.

        Never returns 0, so loaders with fewer than 10 batches do not raise
        ZeroDivisionError in the modulo check.
        """
        return max(num_iters // 10, 1)

    @staticmethod
    def _unwrap_model(model):
        """Return the network inside a DataParallel wrapper, else ``model``."""
        if isinstance(model, torch.nn.DataParallel):
            return model.module
        return model

    @staticmethod
    def _flag_enabled(value):
        """Interpret a flag given either as a bool or as the string 'True'."""
        return value in (True, 'True')

    def _forward(self, image):
        """Run the network on ``image`` and return the segmentation output.

        For DeepLab the model returns a pair; only its first element is
        used (the second, originally named ``fuse``, is discarded).
        """
        if self.args.model == 'deeplab':
            output, _ = self.model(image)
        else:
            output = self.model(image)
        return output

    def _save_checkpoint(self, epoch, is_best):
        """Save model/optimizer state for ``epoch`` via the Saver.

        The stored epoch is ``epoch + 1``, i.e. the epoch to resume from.
        """
        self.saver.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': self._unwrap_model(self.model).state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def training(self, epoch):
        """Run one training epoch over ``self.train_loader``.

        Logs per-iteration and per-epoch loss to TensorBoard, visualizes
        predictions about ten times per epoch and, when ``args.no_val`` is
        enabled (bool ``True`` or the string ``'True'``), saves a checkpoint
        at the end of the epoch.
        """
        train_loss = 0.0
        num_images = 0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        vis_interval = self._vis_interval(num_img_tr)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            num_images += image.shape[0]

            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self._forward(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            loss_value = loss.item()
            train_loss += loss_value
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss_value,
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % vis_interval == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print('Loss: %.3f' % train_loss)

        # The flag was compared against the string 'True' only, so a boolean
        # no_val never triggered a save; accept both spellings.
        if self._flag_enabled(self.args.no_val):
            # save checkpoint every epoch
            self._save_checkpoint(epoch, is_best=False)

    def validation(self, epoch):
        """Evaluate on ``self.val_loader``, log metrics, save best checkpoint.

        A checkpoint is saved whenever mIoU improves on ``self.best_pred``.
        """
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        num_images = 0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            num_images += image.shape[0]

            with torch.no_grad():
                output = self._forward(image)
                loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))

            pred = np.argmax(output.data.cpu().numpy(), axis=1)
            target = target.cpu().numpy()
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        jaccard = self.evaluator.Jaccard()
        dice = self.evaluator.Dice()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        self.writer.add_scalar('val/dice', dice, epoch)
        self.writer.add_scalar('val/jaccard', jaccard, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            self.best_pred = new_pred
            self._save_checkpoint(epoch, is_best=True)