class Trainer(object):
    """Train a DeepLab segmentation model with an adversarial discriminator.

    The segmentation network is supervised on the source loader.  A
    ``Discriminator`` fed with softmax outputs is trained to separate source
    batches from target batches, while the segmentation network is trained
    to make target outputs look like source outputs to the discriminator.
    """

    def __init__(self, config, args):
        """Build loaders, networks, optimizers, losses and optionally resume.

        Args:
            config: hyper-parameter object; this class reads ``backbone``,
                ``out_stride``, ``sync_bn``, ``freeze_bn``, ``lr``,
                ``momentum``, ``weight_decay``, ``loss``, ``lr_scheduler``,
                ``epochs``, ``lr_step``, ``warmup_epochs``, ``lambda_adv``
                and ``batch_size`` from it.
            args: run-time options; this class reads ``visdom``, ``cuda``,
                ``resume``, ``start_epoch`` and ``save_folder`` from it.

        Raises:
            RuntimeError: if ``args.resume`` is set but is not a file.
        """
        self.args = args
        self.config = config
        self.visdom = args.visdom
        if args.visdom:
            self.vis = visdom.Visdom(env=os.getcwd().split('/')[-1],
                                     port=8888)

        # Define Dataloader
        (self.train_loader, self.val_loader, self.test_loader,
         self.nclass) = make_data_loader(config)
        (self.target_train_loader, self.target_val_loader,
         self.target_test_loader, _) = make_target_data_loader(config)

        # Define network
        self.model = DeepLab(num_classes=self.nclass,
                             backbone=config.backbone,
                             output_stride=config.out_stride,
                             sync_bn=config.sync_bn,
                             freeze_bn=config.freeze_bn)
        self.D = Discriminator(num_classes=self.nclass, ndf=16)

        # The second parameter group is trained with a 10x learning rate.
        train_params = [{
            'params': self.model.get_1x_lr_params(),
            'lr': config.lr
        }, {
            'params': self.model.get_10x_lr_params(),
            'lr': config.lr * 10
        }]

        # Define Optimizer
        self.optimizer = torch.optim.SGD(train_params,
                                         momentum=config.momentum,
                                         weight_decay=config.weight_decay)
        self.D_optimizer = torch.optim.Adam(self.D.parameters(),
                                            lr=config.lr,
                                            betas=(0.9, 0.99))

        # Define Criterion (no class-balanced weights)
        self.criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=config.loss)
        self.entropy_mini_loss = MinimizeEntropyLoss()
        self.bottleneck_loss = BottleneckLoss()

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Define lr scheduler
        self.scheduler = LR_Scheduler(config.lr_scheduler, config.lr,
                                      config.epochs, len(self.train_loader),
                                      config.lr_step, config.warmup_epochs)

        # Labels for adversarial training
        self.source_label = 0
        self.target_label = 1

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()
            self.D = torch.nn.DataParallel(self.D)
            patch_replication_callback(self.D)
            self.D = self.D.cuda()

        self.best_pred_source = 0.0
        self.best_pred_target = 0.0
        # Best (lowest) source/target feature-statistics gap seen so far.
        self.bn_loss = 100

        # Resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            # map_location is an argument of torch.load, not of
            # load_state_dict; passing it to the latter raises TypeError.
            # NOTE: torch.load unpickles the file; only load trusted
            # checkpoints.
            map_location = None if args.cuda else torch.device('cpu')
            checkpoint = torch.load(args.resume, map_location=map_location)
            self._unwrap_model(self.model).load_state_dict(checkpoint)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, args.start_epoch))

    @staticmethod
    def _unwrap_model(model):
        """Return ``model.module`` if present (DataParallel), else ``model``."""
        return getattr(model, 'module', model)

    def _plot(self, win, x, y, name=None, opts=None, update='append'):
        """Send a single (x, y) point to the visdom window ``win``."""
        self.vis.line(X=torch.tensor([x]),
                      Y=torch.tensor([y]),
                      win=win,
                      name=name,
                      opts=opts,
                      update=update)

    def training(self, epoch):
        """Run one epoch of joint segmentation / discriminator training.

        Args:
            epoch: zero-based epoch index, used for the lr schedule and as
                the x coordinate of the per-epoch visdom plots.
        """
        train_loss = 0.0
        seg_loss_sum, bn_loss_sum, adv_loss_sum, d_loss_sum = 0.0, 0.0, 0.0, 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        target_train_iterator = iter(self.target_train_loader)
        for i, sample in enumerate(tbar):
            itr = epoch * len(self.train_loader) + i
            if self.visdom:
                self._plot('lr',
                           itr,
                           self.optimizer.param_groups[0]['lr'],
                           opts=dict(title='lr', xlabel='iter', ylabel='lr'),
                           update='append' if itr > 0 else None)
            A_image, A_target = sample['image'], sample['label']

            # Get one batch from target domain; restart the target loader
            # when it is exhausted before the source loader.
            try:
                target_sample = next(target_train_iterator)
            except StopIteration:
                target_train_iterator = iter(self.target_train_loader)
                target_sample = next(target_train_iterator)
            B_image, B_target = target_sample['image'], target_sample['label']

            if self.args.cuda:
                A_image, A_target = A_image.cuda(), A_target.cuda()
                B_image, B_target = B_image.cuda(), B_target.cuda()

            self.scheduler(self.optimizer, i, epoch, self.best_pred_source,
                           self.best_pred_target)
            self.scheduler(self.D_optimizer, i, epoch, self.best_pred_source,
                           self.best_pred_target)

            A_output, A_feat, A_low_feat = self.model(A_image)
            B_output, B_feat, B_low_feat = self.model(B_image)

            self.optimizer.zero_grad()
            self.D_optimizer.zero_grad()

            # Train seg network: freeze D so the adversarial term only
            # produces gradients for the segmentation model.
            for param in self.D.parameters():
                param.requires_grad = False

            # Supervised loss
            seg_loss = self.criterion(A_output, A_target)

            # Unsupervised bn loss (logged only; not part of main_loss)
            bottleneck_loss = self.bottleneck_loss.loss(
                A_feat, B_feat) + self.bottleneck_loss.loss(
                    A_low_feat, B_low_feat)

            main_loss = seg_loss

            # Adversarial loss: target outputs are scored against the
            # source label.  dim=1 is the class axis (validation takes the
            # argmax over axis 1) and avoids the implicit-dim deprecation.
            D_out = self.D(F.softmax(B_output, dim=1))
            adv_loss = bce_loss(D_out, self.source_label)
            main_loss += self.config.lambda_adv * adv_loss
            main_loss.backward()

            # Train discriminator on detached segmentation outputs
            for param in self.D.parameters():
                param.requires_grad = True
            A_output_detach = A_output.detach()
            B_output_detach = B_output.detach()
            # source
            D_source = self.D(F.softmax(A_output_detach, dim=1))
            source_loss = bce_loss(D_source, self.source_label)
            source_loss = source_loss / 2
            # target
            D_target = self.D(F.softmax(B_output_detach, dim=1))
            target_loss = bce_loss(D_target, self.target_label)
            target_loss = target_loss / 2
            d_loss = source_loss + target_loss
            d_loss.backward()

            self.optimizer.step()
            self.D_optimizer.step()

            seg_loss_sum += seg_loss.item()
            bn_loss_sum += bottleneck_loss.item()
            adv_loss_sum += self.config.lambda_adv * adv_loss.item()
            d_loss_sum += d_loss.item()

            train_loss += seg_loss.item(
            ) + self.config.lambda_adv * adv_loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.config.batch_size + A_image.data.shape[0]))
        print('Seg Loss: %.3f' % seg_loss_sum)
        print('BN Loss: %.3f' % bn_loss_sum)
        print('Adv Loss: %.3f' % adv_loss_sum)
        print('Discriminator Loss: %.3f' % d_loss_sum)

        if self.visdom:
            curves = (('Seg_loss', seg_loss_sum), ('BN_loss', bn_loss_sum),
                      ('Adv_loss', adv_loss_sum), ('Dis_loss', d_loss_sum))
            for curve_name, value in curves:
                self._plot('train_loss',
                           epoch,
                           value,
                           name=curve_name,
                           opts=dict(title='loss',
                                     xlabel='epoch',
                                     ylabel='loss'),
                           update='append' if epoch > 0 else None)

    def validation(self, epoch):
        """Evaluate on source and target validation sets and checkpoint.

        A checkpoint (``<save_folder>/models/epoch<N>.pth``) and a metrics
        file (``<save_folder>/eval/epoch<N>.json``) are written when the
        source IoU improves or the source/target feature-statistics gap
        shrinks.

        Args:
            epoch: epoch index used in log output, plots and file names.
        """

        def get_metrics(tbar, if_source=False):
            """Evaluate one loader; return metrics and feature statistics."""
            self.evaluator.reset()
            test_loss = 0.0
            feat_mean, low_feat_mean, feat_var, low_feat_var = 0, 0, 0, 0
            adv_loss = 0.0
            for i, sample in enumerate(tbar):
                image, target = sample['image'], sample['label']
                if self.args.cuda:
                    image, target = image.cuda(), target.cuda()

                with torch.no_grad():
                    # NOTE(review): training unpacks the model outputs as
                    # (output, feat, low_feat); the names here are in the
                    # opposite order.  Both are folded symmetrically into
                    # bn_loss below, so the result does not depend on it.
                    output, low_feat, feat = self.model(image)
                    # No gradients are needed for evaluation.
                    d_output = self.D(F.softmax(output, dim=1))
                low_feat = low_feat.cpu().numpy()
                feat = feat.cpu().numpy()

                # Accumulate per-batch feature statistics; averaged after
                # the loop.
                feat_mean += feat.mean(axis=0).mean(axis=1).mean(axis=1)
                low_feat_mean += low_feat.mean(axis=0).mean(axis=1).mean(
                    axis=1)
                feat_var += feat.var(axis=0).var(axis=1).var(axis=1)
                low_feat_var += low_feat.var(axis=0).var(axis=1).var(axis=1)

                adv_loss += bce_loss(d_output, self.source_label).item()
                loss = self.criterion(output, target)
                test_loss += loss.item()
                tbar.set_description('Test loss: %.3f' % (test_loss /
                                                          (i + 1)))
                pred = output.data.cpu().numpy()
                target = target.cpu().numpy()
                pred = np.argmax(pred, axis=1)

                # Add batch sample into evaluator
                self.evaluator.add_batch(target, pred)

            feat_mean /= (i + 1)
            low_feat_mean /= (i + 1)
            feat_var /= (i + 1)
            low_feat_var /= (i + 1)
            adv_loss /= (i + 1)

            # Fast test during the training
            Acc = self.evaluator.Building_Acc()
            IoU = self.evaluator.Building_IoU()
            mIoU = self.evaluator.Mean_Intersection_over_Union()

            if if_source:
                print('Validation on source:')
            else:
                print('Validation on target:')
            print('[Epoch: %d, numImages: %5d]' %
                  (epoch, i * self.config.batch_size + image.data.shape[0]))
            print("Acc:{}, IoU:{}, mIoU:{}".format(Acc, IoU, mIoU))
            print('Loss: %.3f' % test_loss)
            print('Adv Loss: %.3f' % adv_loss)

            # Draw Visdom
            if if_source:
                names = ['source', 'source_acc', 'source_IoU', 'source_mIoU']
            else:
                names = ['target', 'target_acc', 'target_IoU', 'target_mIoU']

            if self.visdom:
                self._plot('val_loss', epoch, test_loss, name=names[0])
                self._plot('val_loss', epoch, adv_loss, name='adv_loss')
                self._plot('metrics',
                           epoch,
                           Acc,
                           name=names[1],
                           opts=dict(title='metrics',
                                     xlabel='epoch',
                                     ylabel='performance'),
                           update='append' if epoch > 0 else None)
                self._plot('metrics', epoch, IoU, name=names[2])
                self._plot('metrics', epoch, mIoU, name=names[3])

            return (Acc, IoU, mIoU, feat_mean, low_feat_mean, feat_var,
                    low_feat_var, adv_loss)

        self.model.eval()
        tbar_source = tqdm(self.val_loader, desc='\r')
        tbar_target = tqdm(self.target_val_loader, desc='\r')
        s_acc, s_iou, s_miou, s_m, s_lm, s_v, s_lv, s_adv = get_metrics(
            tbar_source, True)
        t_acc, t_iou, t_miou, t_m, t_lm, t_v, t_lv, t_adv = get_metrics(
            tbar_target, False)

        # Gap between source and target feature statistics.
        bn_loss = np.abs(s_m - t_m).mean() + np.abs(s_lm - t_lm).mean(
        ) + np.abs(s_v - t_v).mean() + np.abs(s_lv - t_lv).mean()
        bn_loss = bn_loss.astype('float64')

        if s_iou > self.best_pred_source or bn_loss < self.bn_loss:
            self.best_pred_source = max(s_iou, self.best_pred_source)
            self.bn_loss = min(bn_loss, self.bn_loss)
            print('Saving state, epoch:', epoch)
            # Unwrap DataParallel only when present so CPU runs can save
            # too; build the path like the JSON path below.
            torch.save(
                self._unwrap_model(self.model).state_dict(),
                os.path.join(self.args.save_folder, 'models',
                             'epoch' + str(epoch) + '.pth'))
            # NOTE(review): the original indentation was lost; the metrics
            # dump is assumed to accompany each saved checkpoint - confirm.
            loss_file = {
                's_Acc': s_acc,
                's_IoU': s_iou,
                's_mIoU': s_miou,
                't_Acc': t_acc,
                't_IoU': t_iou,
                't_mIoU': t_miou,
                'bn_loss': bn_loss,
                's_adv': s_adv,
                't_adv': t_adv
            }
            with open(
                    os.path.join(self.args.save_folder, 'eval',
                                 'epoch' + str(epoch) + '.json'), 'w') as f:
                json.dump(loss_file, f)
def __init__(self, config, args):
    """Build loaders, networks, optimizers, losses and optionally resume.

    Args:
        config: hyper-parameter object; reads ``backbone``, ``out_stride``,
            ``sync_bn``, ``freeze_bn``, ``lr``, ``lr_ratio``, ``momentum``,
            ``weight_decay``, ``loss``, ``lr_scheduler``, ``epochs``,
            ``lr_step`` and ``warmup_epochs``.
        args: run-time options; reads ``visdom``, ``cuda``, ``resume`` and
            ``start_epoch``.

    Raises:
        RuntimeError: if ``args.resume`` is set but is not a file.
    """
    self.args = args
    self.config = config
    self.visdom = args.visdom
    if args.visdom:
        self.vis = visdom.Visdom(env=os.getcwd().split('/')[-1], port=8888)

    # Define Dataloader
    (self.train_loader, self.val_loader, self.test_loader,
     self.nclass) = make_data_loader(config)
    (self.target_train_loader, self.target_val_loader,
     self.target_test_loader, _) = make_target_data_loader(config)

    # Define network
    self.model = DeepLab(num_classes=self.nclass,
                         backbone=config.backbone,
                         output_stride=config.out_stride,
                         sync_bn=config.sync_bn,
                         freeze_bn=config.freeze_bn)
    self.D = Discriminator(num_classes=self.nclass, ndf=16)

    # The second parameter group uses lr scaled by config.lr_ratio.
    train_params = [{
        'params': self.model.get_1x_lr_params(),
        'lr': config.lr
    }, {
        'params': self.model.get_10x_lr_params(),
        'lr': config.lr * config.lr_ratio
    }]

    # Define Optimizer
    self.optimizer = torch.optim.SGD(train_params,
                                     momentum=config.momentum,
                                     weight_decay=config.weight_decay)
    self.D_optimizer = torch.optim.Adam(self.D.parameters(),
                                        lr=config.lr,
                                        betas=(0.9, 0.99))

    # Define Criterion (no class-balanced weights)
    self.criterion = SegmentationLosses(
        weight=None, cuda=args.cuda).build_loss(mode=config.loss)
    self.entropy_mini_loss = MinimizeEntropyLoss()
    self.bottleneck_loss = BottleneckLoss()
    self.instance_loss = InstanceLoss()

    # Define Evaluator
    self.evaluator = Evaluator(self.nclass)

    # Define lr scheduler
    self.scheduler = LR_Scheduler(config.lr_scheduler, config.lr,
                                  config.epochs, len(self.train_loader),
                                  config.lr_step, config.warmup_epochs)
    self.summary = TensorboardSummary('./train_log')

    # Labels for adversarial training
    self.source_label = 0
    self.target_label = 1

    # Using cuda
    if args.cuda:
        self.model = torch.nn.DataParallel(self.model)
        patch_replication_callback(self.model)
        self.model = self.model.cuda()
        self.D = torch.nn.DataParallel(self.D)
        patch_replication_callback(self.D)
        self.D = self.D.cuda()

    self.best_pred_source = 0.0
    self.best_pred_target = 0.0

    # Resuming checkpoint
    if args.resume is not None:
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.resume))
        # map_location is an argument of torch.load, not of
        # load_state_dict; passing it to the latter raises TypeError.
        # NOTE: torch.load unpickles the file; only load trusted
        # checkpoints.
        map_location = None if args.cuda else torch.device('cpu')
        checkpoint = torch.load(args.resume, map_location=map_location)
        if args.cuda:
            # The model is wrapped in DataParallel on the CUDA path.
            self.model.module.load_state_dict(checkpoint)
        else:
            self.model.load_state_dict(checkpoint)
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, args.start_epoch))
class Trainer(object):
    """Train a DeepLab segmentation model with a feature-alignment loss.

    The network is supervised on the source loader and additionally
    penalised by ``BottleneckLoss`` between source and target features.
    """

    def __init__(self, config, args):
        """Build loaders, network, optimizer, losses and optionally resume.

        Args:
            config: hyper-parameter object; this class reads ``backbone``,
                ``out_stride``, ``sync_bn``, ``freeze_bn``, ``lr``,
                ``momentum``, ``weight_decay``, ``loss``, ``lr_scheduler``,
                ``epochs``, ``lr_step``, ``warmup_epochs`` and
                ``batch_size`` from it.
            args: run-time options; this class reads ``cuda``, ``resume``,
                ``start_epoch`` and ``save_folder`` from it.

        Raises:
            RuntimeError: if ``args.resume`` is set but is not a file.
        """
        self.args = args
        self.config = config
        self.vis = visdom.Visdom(env=os.getcwd().split('/')[-1], port=8888)

        # Define Dataloader
        (self.train_loader, self.val_loader, self.test_loader,
         self.nclass) = make_data_loader(config)
        (self.target_train_loader, self.target_val_loader,
         self.target_test_loader, _) = make_target_data_loader(config)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=config.backbone,
                        output_stride=config.out_stride,
                        sync_bn=config.sync_bn,
                        freeze_bn=config.freeze_bn)

        # The second parameter group is trained with a 10x learning rate.
        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': config.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': config.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=config.momentum,
                                    weight_decay=config.weight_decay)

        # Define Criterion (no class-balanced weights)
        self.criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=config.loss)
        self.model, self.optimizer = model, optimizer
        self.entropy_mini_loss = MinimizeEntropyLoss()
        self.bottleneck_loss = BottleneckLoss()

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Define lr scheduler
        self.scheduler = LR_Scheduler(config.lr_scheduler, config.lr,
                                      config.epochs, len(self.train_loader),
                                      config.lr_step, config.warmup_epochs)

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        self.best_pred_source = 0.0
        self.best_pred_target = 0.0

        # Resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            # map_location is an argument of torch.load, not of
            # load_state_dict; passing it to the latter raises TypeError.
            # NOTE: torch.load unpickles the file; only load trusted
            # checkpoints.
            map_location = None if args.cuda else torch.device('cpu')
            checkpoint = torch.load(args.resume, map_location=map_location)
            self._unwrap_model(self.model).load_state_dict(checkpoint)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, args.start_epoch))

    @staticmethod
    def _unwrap_model(model):
        """Return ``model.module`` if present (DataParallel), else ``model``."""
        return getattr(model, 'module', model)

    def _plot(self, win, x, y, name=None, opts=None, update='append'):
        """Send a single (x, y) point to the visdom window ``win``."""
        self.vis.line(X=torch.tensor([x]),
                      Y=torch.tensor([y]),
                      win=win,
                      name=name,
                      opts=opts,
                      update=update)

    def training(self, epoch):
        """Run one training epoch over the source loader.

        Args:
            epoch: zero-based epoch index, used for the lr schedule and as
                the x coordinate of the per-epoch visdom plots.
        """
        train_loss = 0.0
        seg_loss_sum, bn_loss_sum, entropy_loss_sum = 0.0, 0.0, 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        target_train_iterator = iter(self.target_train_loader)
        for i, sample in enumerate(tbar):
            itr = epoch * len(self.train_loader) + i
            self._plot('lr',
                       itr,
                       self.optimizer.param_groups[0]['lr'],
                       opts=dict(title='lr', xlabel='iter', ylabel='lr'),
                       update='append' if itr > 0 else None)
            A_image, A_target = sample['image'], sample['label']

            # Get one batch from target domain; restart the target loader
            # when it is exhausted before the source loader.
            try:
                target_sample = next(target_train_iterator)
            except StopIteration:
                target_train_iterator = iter(self.target_train_loader)
                target_sample = next(target_train_iterator)
            B_image, B_target = target_sample['image'], target_sample['label']

            if self.args.cuda:
                A_image, A_target = A_image.cuda(), A_target.cuda()
                B_image, B_target = B_image.cuda(), B_target.cuda()

            self.scheduler(self.optimizer, i, epoch, self.best_pred_source,
                           self.best_pred_target)

            # Supervised loss
            self.optimizer.zero_grad()
            A_output, A_feat, A_low_feat = self.model(A_image)
            seg_loss = self.criterion(A_output, A_target)

            # Unsupervised bn loss
            B_output, B_feat, B_low_feat = self.model(B_image)
            bottleneck_loss = self.bottleneck_loss.loss(
                A_feat, B_feat) + self.bottleneck_loss.loss(
                    A_low_feat, B_low_feat)

            # Entropy minimization loss is logged only; it is not part of
            # the optimised loss.
            entropy_mini_loss = self.entropy_mini_loss.loss(B_output)

            loss = seg_loss + bottleneck_loss
            loss.backward()
            self.optimizer.step()

            seg_loss_sum += seg_loss.item()
            bn_loss_sum += bottleneck_loss.item()
            entropy_loss_sum += entropy_mini_loss.item()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.config.batch_size + A_image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        curves = (('Seg_loss', seg_loss_sum), ('BN_loss', bn_loss_sum),
                  ('Entropy_loss', entropy_loss_sum))
        for curve_name, value in curves:
            self._plot('train_loss',
                       epoch,
                       value,
                       name=curve_name,
                       opts=dict(title='loss', xlabel='epoch', ylabel='loss'),
                       update='append' if epoch > 0 else None)

    def validation(self, epoch):
        """Evaluate on source and target validation sets and checkpoint.

        A checkpoint (``<save_folder>/models/epoch<N>.pth``) and a metrics
        file (``<save_folder>/eval/epoch<N>.json``) are written when the
        source or the target IoU improves on the best value seen so far.

        Args:
            epoch: epoch index used in log output, plots and file names.
        """

        def get_metrics(tbar, if_source=False):
            """Evaluate one loader and return ``(Acc, IoU, mIoU)``."""
            # Reset per loader: a single reset before both passes made the
            # target metrics include the accumulated source batches.
            self.evaluator.reset()
            test_loss = 0.0
            for i, sample in enumerate(tbar):
                image, target = sample['image'], sample['label']
                if self.args.cuda:
                    image, target = image.cuda(), target.cuda()

                with torch.no_grad():
                    output, _, _ = self.model(image)

                loss = self.criterion(output, target)
                test_loss += loss.item()
                tbar.set_description('Test loss: %.3f' % (test_loss /
                                                          (i + 1)))
                pred = output.data.cpu().numpy()
                target = target.cpu().numpy()
                pred = np.argmax(pred, axis=1)

                # Add batch sample into evaluator
                self.evaluator.add_batch(target, pred)

            # Fast test during the training
            Acc = self.evaluator.Building_Acc()
            IoU = self.evaluator.Building_IoU()
            mIoU = self.evaluator.Mean_Intersection_over_Union()

            if if_source:
                print('Validation on source:')
            else:
                print('Validation on target:')
            print('[Epoch: %d, numImages: %5d]' %
                  (epoch, i * self.config.batch_size + image.data.shape[0]))
            print("Acc:{}, IoU:{}, mIoU:{}".format(Acc, IoU, mIoU))
            print('Loss: %.3f' % test_loss)

            # Draw Visdom
            if if_source:
                names = ['source', 'source_acc', 'source_IoU', 'source_mIoU']
            else:
                names = ['target', 'target_acc', 'target_IoU', 'target_mIoU']

            self._plot('val_loss', epoch, test_loss, name=names[0])
            self._plot('metrics',
                       epoch,
                       Acc,
                       name=names[1],
                       opts=dict(title='metrics',
                                 xlabel='epoch',
                                 ylabel='performance'),
                       update='append' if epoch > 0 else None)
            self._plot('metrics', epoch, IoU, name=names[2])
            self._plot('metrics', epoch, mIoU, name=names[3])

            return Acc, IoU, mIoU

        self.model.eval()
        tbar_source = tqdm(self.val_loader, desc='\r')
        tbar_target = tqdm(self.target_val_loader, desc='\r')
        s_acc, s_iou, s_miou = get_metrics(tbar_source, True)
        t_acc, t_iou, t_miou = get_metrics(tbar_target, False)

        if s_iou > self.best_pred_source or t_iou > self.best_pred_target:
            self.best_pred_source = max(s_iou, self.best_pred_source)
            self.best_pred_target = max(t_iou, self.best_pred_target)
            print('Saving state, epoch:', epoch)
            # Unwrap DataParallel only when present so CPU runs can save
            # too; build the path like the JSON path below.
            torch.save(
                self._unwrap_model(self.model).state_dict(),
                os.path.join(self.args.save_folder, 'models',
                             'epoch' + str(epoch) + '.pth'))
            # NOTE(review): the original indentation was lost; the metrics
            # dump is assumed to accompany each saved checkpoint - confirm.
            loss_file = {
                's_Acc': s_acc,
                's_IoU': s_iou,
                's_mIoU': s_miou,
                't_Acc': t_acc,
                't_IoU': t_iou,
                't_mIoU': t_miou
            }
            with open(
                    os.path.join(self.args.save_folder, 'eval',
                                 'epoch' + str(epoch) + '.json'), 'w') as f:
                json.dump(loss_file, f)