opt_seg.ft_resume = "/home/chenwy/DynamicLightEnlighten/bdd/bdd100k_seg/fcn_model/res50_di_180px_L100.255/model_best.pth.tar" opt_seg.eval = True seg = get_segmentation_model( opt_seg.model, dataset=opt_seg.dataset, backbone=opt_seg.backbone, aux=False, se_loss=False, dilated=opt_seg.dilated, # norm_layer=BatchNorm2d, # for multi-gpu base_size=720, crop_size=180, multi_grid=False, multi_dilation=False) seg = DataParallelModel(seg).cuda() seg.eval() if opt_seg.ft: checkpoint = torch.load(opt_seg.ft_resume) seg.module.load_state_dict(checkpoint['state_dict'], strict=False) # self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.ft_resume, checkpoint['epoch'])) #################################################### opt = TestOptions().parse() opt.nThreads = 1 # test code only supports nThreads = 1 opt.batchSize = 1 # test code only supports batchSize = 1 opt.serial_batches = True # no shuffle opt.no_flip = True # no flip data_loader = CreateDataLoader(opt) dataset = data_loader.load_data()
def __init__(self, args): self.args = args args.log_name = str(args.checkname) self.logger = utils.create_logger(args.log_root, args.log_name) self.writer = SummaryWriter(log_dir=os.path.join(args.log_root, args.log_name, time.strftime("%Y-%m-%d-%H-%M",time.localtime()))) # data transforms input_transform = transform.Compose([ transform.ToTensor(), # transform.Normalize([.485, .456, .406], [.229, .224, .225]) transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]) # dataset data_kwargs = {'transform': input_transform, 'base_size': args.base_size, 'crop_size': args.crop_size, 'logger': self.logger, 'scale': args.scale} trainset = get_segmentation_dataset(args.dataset, split='train', mode='train', **data_kwargs) testset = get_segmentation_dataset(args.dataset, split='val', mode='val', **data_kwargs) # dataloader kwargs = {'num_workers': args.workers, 'pin_memory': True} if args.cuda else {} self.trainloader = data.DataLoader(trainset, batch_size=args.batch_size, drop_last=True, shuffle=True, **kwargs) # self.valloader = data.DataLoader(testset, batch_size=args.batch_size, drop_last=False, shuffle=False, **kwargs) self.valloader = data.DataLoader(testset, batch_size=1, drop_last=False, shuffle=False, **kwargs) self.nclass = trainset.num_class # model model = get_segmentation_model(args.model, dataset=args.dataset, backbone=args.backbone, aux=args.aux, se_loss=args.se_loss, dilated=args.dilated, # norm_layer=BatchNorm2d, # for multi-gpu base_size=args.base_size, crop_size=args.crop_size, multi_grid=args.multi_grid, multi_dilation=args.multi_dilation) ##################################################################### self.logger.info(model) # optimizer using different LR params_list = [{'params': model.pretrained.parameters(), 'lr': 1 * args.lr},] if hasattr(model, 'head'): params_list.append({'params': model.head.parameters(), 'lr': 1 * args.lr*10}) if hasattr(model, 'auxlayer'): params_list.append({'params': model.auxlayer.parameters(), 'lr': 1 * args.lr*10}) optimizer = torch.optim.SGD(params_list, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) if args.model == 'danet': self.criterion = SegmentationMultiLosses(nclass=self.nclass) elif args.model == 'fcn': self.criterion = SegmentationLosses(se_loss=args.se_loss, aux=args.aux, nclass=self.nclass) else: self.criterion = torch.nn.CrossEntropyLoss() ##################################################################### self.model, self.optimizer = model, optimizer # using cuda if args.cuda: self.model = DataParallelModel(self.model).cuda() self.criterion = DataParallelCriterion(self.criterion).cuda() # finetune from a trained model if args.ft: args.start_epoch = 0 checkpoint = torch.load(args.ft_resume) if args.cuda: self.model.module.load_state_dict(checkpoint['state_dict'], strict=False) else: self.model.load_state_dict(checkpoint['state_dict'], strict=False) self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.ft_resume, checkpoint['epoch'])) # resuming checkpoint if args.resume: if not os.path.isfile(args.resume): raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume)) checkpoint = torch.load(args.resume) args.start_epoch = checkpoint['epoch'] if args.cuda: self.model.module.load_state_dict(checkpoint['state_dict']) else: self.model.load_state_dict(checkpoint['state_dict']) if not args.ft: self.optimizer.load_state_dict(checkpoint['optimizer']) self.best_pred = checkpoint['best_pred'] self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])) # lr scheduler self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr, args.epochs, len(self.trainloader), logger=self.logger, lr_step=args.lr_step) self.best_pred = 0.0 self.logger.info(self.args)
class Trainer(): def __init__(self, args): self.args = args args.log_name = str(args.checkname) self.logger = utils.create_logger(args.log_root, args.log_name) # data transforms input_transform = transform.Compose([ transform.ToTensor(), # transform.Normalize([.485, .456, .406], [.229, .224, .225]) transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]) # dataset data_kwargs = { 'transform': input_transform, 'base_size': args.base_size, 'crop_size': args.crop_size, 'logger': self.logger, 'scale': args.scale } trainset = get_segmentation_dataset(args.dataset, split='train', mode='train', **data_kwargs) testset = get_segmentation_dataset(args.dataset, split='val', mode='val', **data_kwargs) # dataloader kwargs = { 'num_workers': args.workers, 'pin_memory': True } if args.cuda else {} self.trainloader = data.DataLoader(trainset, batch_size=args.batch_size, drop_last=True, shuffle=True, **kwargs) self.valloader = data.DataLoader(testset, batch_size=args.batch_size, drop_last=False, shuffle=False, **kwargs) self.nclass = trainset.num_class # model model = get_segmentation_model( args.model, dataset=args.dataset, backbone=args.backbone, aux=args.aux, se_loss=args.se_loss, dilated=args.dilated, # norm_layer=BatchNorm2d, # for multi-gpu base_size=args.base_size, crop_size=args.crop_size, multi_grid=args.multi_grid, multi_dilation=args.multi_dilation) ##################################################################### self.logger.info(model) # optimizer using different LR params_list = [ { 'params': model.pretrained.parameters(), 'lr': 1 * args.lr }, ] if hasattr(model, 'head'): params_list.append({ 'params': model.head.parameters(), 'lr': 1 * args.lr * 10 }) if hasattr(model, 'auxlayer'): params_list.append({ 'params': model.auxlayer.parameters(), 'lr': 1 * args.lr * 10 }) optimizer = torch.optim.SGD(params_list, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) # self.criterion = SegmentationMultiLosses(nclass=self.nclass) self.criterion = SegmentationLosses(se_loss=args.se_loss, aux=args.aux, nclass=self.nclass) # self.criterion = torch.nn.CrossEntropyLoss() ##################################################################### self.model, self.optimizer = model, optimizer # using cuda if args.cuda: self.model = DataParallelModel(self.model).cuda() self.criterion = DataParallelCriterion(self.criterion).cuda() # finetune from a trained model if args.ft: args.start_epoch = 0 checkpoint = torch.load(args.ft_resume) if args.cuda: self.model.module.load_state_dict(checkpoint['state_dict'], strict=False) else: self.model.load_state_dict(checkpoint['state_dict'], strict=False) self.logger.info("=> loaded checkpoint '{}' (epoch {})".format( args.ft_resume, checkpoint['epoch'])) # resuming checkpoint if args.resume: if not os.path.isfile(args.resume): raise RuntimeError("=> no checkpoint found at '{}'".format( args.resume)) checkpoint = torch.load(args.resume) args.start_epoch = checkpoint['epoch'] if args.cuda: self.model.module.load_state_dict(checkpoint['state_dict']) else: self.model.load_state_dict(checkpoint['state_dict']) if not args.ft: self.optimizer.load_state_dict(checkpoint['optimizer']) self.best_pred = checkpoint['best_pred'] self.logger.info("=> loaded checkpoint '{}' (epoch {})".format( args.resume, checkpoint['epoch'])) # lr scheduler self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr, args.epochs, len(self.trainloader), logger=self.logger, lr_step=args.lr_step) self.best_pred = 0.0 self.logger.info(self.args) def training(self, epoch): train_loss = 0.0 ################################################ self.model.train() ################################################ tbar = tqdm(self.trainloader) # for i, (image, target, weather, timeofday, scene, name) in enumerate(tbar): for i, (image, target, name) in enumerate(tbar): # weather = weather.cuda(); timeofday = timeofday.cuda() ################################################ self.scheduler(self.optimizer, i, epoch, self.best_pred) ################################################ self.optimizer.zero_grad() if torch_ver == "0.3": image = Variable(image) target = Variable(target) # outputs, weather_o, timeofday_o = self.model(image) outputs = self.model(image) # create weather / timeofday target mask ####################### # b, _, h, w = weather_o.size() # weather_t = torch.ones((b, h, w)).long().cuda() # for bi in range(b): weather_t[bi] *= weather[bi] # timeofday_t = torch.ones((b, h, w)).long().cuda() # for bi in range(b): timeofday_t[bi] *= timeofday[bi] ################################################################ # loss = self.criterion(weather_o, weather_t) + self.criterion(timeofday_o, timeofday_t) loss = self.criterion(outputs, target) loss.backward() self.optimizer.step() train_loss += loss.item() tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1))) self.logger.info('Train loss: %.3f' % (train_loss / (i + 1))) # save checkpoint every 5 epoch # is_best = False # if epoch % 5 == 0: # # filename = "checkpoint_%s.pth.tar"%(epoch+1) # filename = "checkpoint_%s.%s.%s.%s.pth.tar"%(self.args.log_root, self.args.checkname, self.args.model, epoch+1) # utils.save_checkpoint({ # 'epoch': epoch + 1, # 'state_dict': self.model.module.state_dict(), # 'optimizer': self.optimizer.state_dict(), # 'best_pred': self.best_pred, # }, self.args, is_best, filename) def validation(self, epoch=None): # Fast test during the training # def eval_batch(model, image, target, weather, timeofday, scene): def eval_batch(model, image, target): # outputs, weather_o, timeofday_o = model(image) outputs = model(image) # Gathers tensors from different GPUs on a specified device # outputs = gather(outputs, 0, dim=0) pred = outputs[0] # create weather / timeofday target mask ####################### # b, _, h, w = weather_o.size() # weather_t = torch.ones((b, h, w)).long() # for bi in range(b): weather_t[bi] *= weather[bi] # timeofday_t = torch.ones((b, h, w)).long() # for bi in range(b): timeofday_t[bi] *= timeofday[bi] ################################################################ # self.confusion_matrix_weather.update([ m.astype(np.int64) for m in weather_t.numpy() ], weather_o.cpu().numpy().argmax(1)) # self.confusion_matrix_timeofday.update([ m.astype(np.int64) for m in timeofday_t.numpy() ], timeofday_o.cpu().numpy().argmax(1)) correct, labeled = utils.batch_pix_accuracy(pred.data, target) inter, union = utils.batch_intersection_union( pred.data, target, self.nclass) # correct_weather, labeled_weather = utils.batch_pix_accuracy(weather_o.data, weather_t) # correct_timeofday, labeled_timeofday = utils.batch_pix_accuracy(timeofday_o.data, timeofday_t) # return correct, labeled, inter, union, correct_weather, labeled_weather, correct_timeofday, labeled_timeofday return correct, labeled, inter, union is_best = False self.model.eval() total_inter, total_union, total_correct, total_label = 0, 0, 0, 0 total_correct_weather = 0 total_label_weather = 0 total_correct_timeofday = 0 total_label_timeofday = 0 name2inter = {} name2union = {} tbar = tqdm(self.valloader, desc='\r') # for i, (image, target, weather, timeofday, scene, name) in enumerate(tbar): for i, (image, target, name) in enumerate(tbar): if torch_ver == "0.3": image = Variable(image, volatile=True) # correct, labeled, inter, union, correct_weather, labeled_weather, correct_timeofday, labeled_timeofday = eval_batch(self.model, image, target, weather, timeofday, scene) correct, labeled, inter, union = eval_batch( self.model, image, target) else: with torch.no_grad(): # correct, labeled, inter, union, correct_weather, labeled_weather, correct_timeofday, labeled_timeofday = eval_batch(self.model, image, target, weather, timeofday, scene) correct, labeled, inter, union = eval_batch( self.model, image, target) total_correct += correct total_label += labeled total_inter += inter total_union += union pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label) IoU = 1.0 * total_inter / (np.spacing(1) + total_union) mIoU = IoU.mean() name2inter[name[0]] = inter.tolist() name2union[name[0]] = union.tolist() # total_correct_weather += correct_weather # total_label_weather += labeled_weather # pixAcc_weather = 1.0 * total_correct_weather / (np.spacing(1) + total_label_weather) # total_correct_timeofday += correct_timeofday # total_label_timeofday += labeled_timeofday # pixAcc_timeofday = 1.0 * total_correct_timeofday / (np.spacing(1) + total_label_timeofday) # tbar.set_description('pixAcc: %.2f, mIoU: %.2f, weather: %.2f, timeofday: %.2f' % (pixAcc, mIoU, pixAcc_weather, pixAcc_timeofday)) tbar.set_description('pixAcc: %.2f, mIoU: %.2f' % (pixAcc, mIoU)) # self.logger.info('pixAcc: %.3f, mIoU: %.3f, pixAcc_weather: %.3f, pixAcc_timeofday: %.3f' % (pixAcc, mIoU, pixAcc_weather, pixAcc_timeofday)) self.logger.info('pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU)) with open("name2inter", 'w') as fp: json.dump(name2inter, fp) with open("name2union", 'w') as fp: json.dump(name2union, fp) # cm = self.confusion_matrix_weather.get_scores()['cm'] # self.logger.info(str(cm)) # self.confusion_matrix_weather.reset() # cm = self.confusion_matrix_timeofday.get_scores()['cm'] # self.logger.info(str(cm)) # self.confusion_matrix_timeofday.reset() if epoch is not None: new_pred = (pixAcc + mIoU) / 2 if new_pred > self.best_pred: is_best = True self.best_pred = new_pred utils.save_checkpoint( { 'epoch': epoch + 1, 'state_dict': self.model.module.state_dict(), 'optimizer': self.optimizer.state_dict(), 'best_pred': self.best_pred, }, self.args, is_best)
class Trainer(): def __init__(self, args): self.args = args args.log_name = str(args.checkname) self.logger = utils.create_logger(args.log_root, args.log_name) self.writer = SummaryWriter(log_dir=os.path.join(args.log_root, args.log_name, time.strftime("%Y-%m-%d-%H-%M",time.localtime()))) # data transforms input_transform = transform.Compose([ transform.ToTensor(), # transform.Normalize([.485, .456, .406], [.229, .224, .225]) transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]) # dataset data_kwargs = {'transform': input_transform, 'base_size': args.base_size, 'crop_size': args.crop_size, 'logger': self.logger, 'scale': args.scale} trainset = get_segmentation_dataset(args.dataset, split='train', mode='train', **data_kwargs) testset = get_segmentation_dataset(args.dataset, split='val', mode='val', **data_kwargs) # dataloader kwargs = {'num_workers': args.workers, 'pin_memory': True} if args.cuda else {} self.trainloader = data.DataLoader(trainset, batch_size=args.batch_size, drop_last=True, shuffle=True, **kwargs) # self.valloader = data.DataLoader(testset, batch_size=args.batch_size, drop_last=False, shuffle=False, **kwargs) self.valloader = data.DataLoader(testset, batch_size=1, drop_last=False, shuffle=False, **kwargs) self.nclass = trainset.num_class # model model = get_segmentation_model(args.model, dataset=args.dataset, backbone=args.backbone, aux=args.aux, se_loss=args.se_loss, dilated=args.dilated, # norm_layer=BatchNorm2d, # for multi-gpu base_size=args.base_size, crop_size=args.crop_size, multi_grid=args.multi_grid, multi_dilation=args.multi_dilation) ##################################################################### self.logger.info(model) # optimizer using different LR params_list = [{'params': model.pretrained.parameters(), 'lr': 1 * args.lr},] if hasattr(model, 'head'): params_list.append({'params': model.head.parameters(), 'lr': 1 * args.lr*10}) if hasattr(model, 'auxlayer'): params_list.append({'params': model.auxlayer.parameters(), 'lr': 1 * args.lr*10}) optimizer = torch.optim.SGD(params_list, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) if args.model == 'danet': self.criterion = SegmentationMultiLosses(nclass=self.nclass) elif args.model == 'fcn': self.criterion = SegmentationLosses(se_loss=args.se_loss, aux=args.aux, nclass=self.nclass) else: self.criterion = torch.nn.CrossEntropyLoss() ##################################################################### self.model, self.optimizer = model, optimizer # using cuda if args.cuda: self.model = DataParallelModel(self.model).cuda() self.criterion = DataParallelCriterion(self.criterion).cuda() # finetune from a trained model if args.ft: args.start_epoch = 0 checkpoint = torch.load(args.ft_resume) if args.cuda: self.model.module.load_state_dict(checkpoint['state_dict'], strict=False) else: self.model.load_state_dict(checkpoint['state_dict'], strict=False) self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.ft_resume, checkpoint['epoch'])) # resuming checkpoint if args.resume: if not os.path.isfile(args.resume): raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume)) checkpoint = torch.load(args.resume) args.start_epoch = checkpoint['epoch'] if args.cuda: self.model.module.load_state_dict(checkpoint['state_dict']) else: self.model.load_state_dict(checkpoint['state_dict']) if not args.ft: self.optimizer.load_state_dict(checkpoint['optimizer']) self.best_pred = checkpoint['best_pred'] self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])) # lr scheduler self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr, args.epochs, len(self.trainloader), logger=self.logger, lr_step=args.lr_step) self.best_pred = 0.0 self.logger.info(self.args) def training(self, epoch): train_loss = 0.0 ################################################ self.model.train() ################################################ tbar = tqdm(self.trainloader) self.optimizer.zero_grad() for i, (image, target, name, class_freq) in enumerate(tbar): ################################################ self.scheduler(self.optimizer, i, epoch, self.best_pred) ################################################ if torch_ver == "0.3": image = Variable(image) target = Variable(target) r,g,b = image[:, 0, :, :]+1, image[:, 1, :, :]+1, image[:, 2, :, :]+1 gray = 1. - (0.299*r+0.587*g+0.114*b)/2. # h, w gray = gray.unsqueeze(1) with torch.no_grad(): fake_B, _, _ = gan.netG_A.forward(image, gray) outputs = self.model(fake_B.clamp(-1, 1)) loss = self.criterion(outputs, target) loss.backward() if epoch % self.args.late_update == 0: self.optimizer.step() self.optimizer.zero_grad() train_loss += loss.item() tbar.set_description('Train loss: %.3f; ' % (train_loss / (i + 1)) + '[' + ' '.join([ "%.2f"%p for p in np.round(class_freq[0], 2)]) + ']') self.logger.info('Train loss: %.3f' % (train_loss / (i + 1))) # save checkpoint every 5 epoch # is_best = False # if epoch % 5 == 0: # # filename = "checkpoint_%s.pth.tar"%(epoch+1) # filename = "checkpoint_%s.%s.%s.%s.pth.tar"%(self.args.log_root, self.args.checkname, self.args.model, epoch+1) # utils.save_checkpoint({ # 'epoch': epoch + 1, # 'state_dict': self.model.module.state_dict(), # 'optimizer': self.optimizer.state_dict(), # 'best_pred': self.best_pred, # }, self.args, is_best, filename) def validation(self, epoch=None): # Fast test during the training def eval_batch(model, image, target): r,g,b = image[:, 0, :, :]+1, image[:, 1, :, :]+1, image[:, 2, :, :]+1 gray = 1. - (0.299*r+0.587*g+0.114*b)/2. # h, w gray = gray.unsqueeze(1) with torch.no_grad(): fake_B, _, _ = gan.netG_A.forward(image, gray) outputs = self.model(fake_B.clamp(-1, 1)) # Gathers tensors from different GPUs on a specified device # outputs = gather(outputs, 0, dim=0) pred = outputs[0] pred = F.upsample(pred, size=(target.size(1), target.size(2)), mode='bilinear') correct, labeled = utils.batch_pix_accuracy(pred.data, target) inter, union = utils.batch_intersection_union(pred.data, target, self.nclass) return correct, labeled, inter, union is_best = False self.model.eval() total_inter, total_union, total_correct, total_label = 0, 0, 0, 0 name2inter = {}; name2union = {} tbar = tqdm(self.valloader, desc='\r') for i, (image, target, name, class_freq) in enumerate(tbar): if torch_ver == "0.3": image = Variable(image, volatile=True) correct, labeled, inter, union = eval_batch(self.model, image, target) else: with torch.no_grad(): correct, labeled, inter, union = eval_batch(self.model, image, target) total_correct += correct total_label += labeled total_inter += inter total_union += union pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label) IoU = 1.0 * total_inter / (np.spacing(1) + total_union) mIoU = IoU.mean() name2inter[name[0]] = inter.tolist() name2union[name[0]] = union.tolist() tbar.set_description('pixAcc: %.2f, mIoU: %.2f' % (pixAcc, mIoU)) self.logger.info('pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU)) self.writer.add_scalars('IoU', {'validation iou': mIoU}, epoch) with open("name2inter", 'w') as fp: json.dump(name2inter, fp) with open("name2union", 'w') as fp: json.dump(name2union, fp) # cm = self.confusion_matrix_weather.get_scores()['cm'] # self.logger.info(str(cm)) # self.confusion_matrix_weather.reset() # cm = self.confusion_matrix_timeofday.get_scores()['cm'] # self.logger.info(str(cm)) # self.confusion_matrix_timeofday.reset() if epoch is not None: new_pred = (pixAcc + mIoU) / 2 if new_pred > self.best_pred: is_best = True self.best_pred = new_pred utils.save_checkpoint({ 'epoch': epoch + 1, 'state_dict': self.model.module.state_dict(), 'optimizer': self.optimizer.state_dict(), 'best_pred': self.best_pred, }, self.args, is_best)
class Trainer(): def __init__(self, args): self.args = args args.log_name = str(args.checkname) self.logger = utils.create_logger(args.log_root, args.log_name) self.writer = SummaryWriter(log_dir=os.path.join( args.log_root, args.log_name, time.strftime("%Y-%m-%d-%H-%M", time.localtime()))) # data transforms input_transform = transform.Compose([ transform.ToTensor(), # transform.Normalize([.485, .456, .406], [.229, .224, .225]) transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]) # dataset data_kwargs = { 'transform': input_transform, 'base_size': args.base_size, 'crop_size': args.crop_size, 'logger': self.logger, 'scale': args.scale } trainset = get_segmentation_dataset(args.dataset, split='train', mode='train', **data_kwargs) testset = get_segmentation_dataset(args.dataset, split='val', mode='val', **data_kwargs) # dataloader kwargs = { 'num_workers': args.workers, 'pin_memory': True } if args.cuda else {} self.trainloader = data.DataLoader(trainset, batch_size=args.batch_size, drop_last=True, shuffle=True, **kwargs) # self.valloader = data.DataLoader(testset, batch_size=args.batch_size, drop_last=False, shuffle=False, **kwargs) self.valloader = data.DataLoader(testset, batch_size=1, drop_last=False, shuffle=False, **kwargs) self.nclass = trainset.num_class # model model = get_segmentation_model( args.model, dataset=args.dataset, backbone=args.backbone, aux=args.aux, se_loss=args.se_loss, dilated=args.dilated, # norm_layer=BatchNorm2d, # for multi-gpu base_size=args.base_size, crop_size=args.crop_size, multi_grid=args.multi_grid, multi_dilation=args.multi_dilation) ##################################################################### self.logger.info(model) # optimizer using different LR params_list = [ { 'params': model.pretrained.parameters(), 'lr': 1 * args.lr }, ] if hasattr(model, 'head'): params_list.append({ 'params': model.head.parameters(), 'lr': 1 * args.lr * 10 }) if hasattr(model, 'auxlayer'): params_list.append({ 'params': model.auxlayer.parameters(), 'lr': 1 * args.lr * 10 }) optimizer = torch.optim.SGD(params_list, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) if args.model == 'danet': self.criterion = SegmentationMultiLosses(nclass=self.nclass) elif args.model == 'fcn': self.criterion = SegmentationLosses(se_loss=args.se_loss, aux=args.aux, nclass=self.nclass) else: self.criterion = torch.nn.CrossEntropyLoss() ##################################################################### self.model, self.optimizer = model, optimizer # using cuda if args.cuda: self.model = DataParallelModel(self.model).cuda() self.criterion = DataParallelCriterion(self.criterion).cuda() # finetune from a trained model if args.ft: args.start_epoch = 0 checkpoint = torch.load(args.ft_resume) if args.cuda: self.model.module.load_state_dict(checkpoint['state_dict'], strict=False) else: self.model.load_state_dict(checkpoint['state_dict'], strict=False) self.logger.info("=> loaded checkpoint '{}' (epoch {})".format( args.ft_resume, checkpoint['epoch'])) # resuming checkpoint if args.resume: if not os.path.isfile(args.resume): raise RuntimeError("=> no checkpoint found at '{}'".format( args.resume)) checkpoint = torch.load(args.resume) args.start_epoch = checkpoint['epoch'] if args.cuda: self.model.module.load_state_dict(checkpoint['state_dict']) else: self.model.load_state_dict(checkpoint['state_dict']) if not args.ft: self.optimizer.load_state_dict(checkpoint['optimizer']) self.best_pred = checkpoint['best_pred'] self.logger.info("=> loaded checkpoint '{}' (epoch {})".format( args.resume, checkpoint['epoch'])) # lr scheduler self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr, args.epochs, len(self.trainloader), logger=self.logger, lr_step=args.lr_step) self.best_pred = 0.0 self.logger.info(self.args) def training(self, epoch): train_loss = 0.0 ################################################ self.model.train() ################################################ tbar = tqdm(self.trainloader) # for i, (image, target, weather, timeofday, scene, name) in enumerate(tbar): self.optimizer.zero_grad() for i, (image, target, name, class_freq) in enumerate(tbar): # weather = weather.cuda(); timeofday = timeofday.cuda() ################################################ self.scheduler(self.optimizer, i, epoch, self.best_pred) ################################################ if torch_ver == "0.3": image = Variable(image) target = Variable(target) # outputs, weather_o, timeofday_o = self.model(image) outputs = self.model(image) # create weather / timeofday target mask ####################### # b, _, h, w = weather_o.size() # weather_t = torch.ones((b, h, w)).long().cuda() # for bi in range(b): weather_t[bi] *= weather[bi] # timeofday_t = torch.ones((b, h, w)).long().cuda() # for bi in range(b): timeofday_t[bi] *= timeofday[bi] ################################################################ # loss = self.criterion(weather_o, weather_t) + self.criterion(timeofday_o, timeofday_t) loss = self.criterion(outputs, target) loss.backward() if epoch % self.args.late_update == 0: self.optimizer.step() self.optimizer.zero_grad() train_loss += loss.item() tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1))) self.logger.info( 'Train loss: %.3f; ' % (train_loss / (i + 1)) + '[' + ' '.join(["%.2f" % p for p in np.round(class_freq[0], 2)]) + ']') # save checkpoint every 5 epoch # is_best = False # if epoch % 5 == 0: # # filename = "checkpoint_%s.pth.tar"%(epoch+1) # filename = "checkpoint_%s.%s.%s.%s.pth.tar"%(self.args.log_root, self.args.checkname, self.args.model, epoch+1) # utils.save_checkpoint({ # 'epoch': epoch + 1, # 'state_dict': self.model.module.state_dict(), # 'optimizer': self.optimizer.state_dict(), # 'best_pred': self.best_pred, # }, self.args, is_best, filename) def validation(self, epoch=None): # Fast test during the training size_p = (1000, 1000) sub_batch_size = 5 def eval_batch(model, image, target): if image.size(2) * image.size(3) <= 2250000: # 1500x1500 outputs = model(image) # Gathers tensors from different GPUs on a specified device # outputs = gather(outputs, 0, dim=0) pred = outputs[0] pred = F.upsample( pred, size=(target.size(1), target.size(2)), mode='bilinear' ) # if you downsampled the input image due to large size correct, labeled = utils.batch_pix_accuracy(pred.data, target) inter, union = utils.batch_intersection_union( pred.data, target, self.nclass) return correct, labeled, inter, union else: patches, coordinates, sizes = global2patch(image, size_p) predicted_patches = [ torch.zeros(len(coordinates[i]), self.nclass, size_p[0], size_p[1]) for i in range(len(image)) ] for i in range(len(image)): j = 0 while j < len(coordinates[i]): outputs = model(patches[i][j:j + sub_batch_size])[0] predicted_patches[i][j:j + outputs.size()[0]] = outputs j += sub_batch_size pred = patch2global( predicted_patches, self.nclass, sizes, coordinates, size_p) # merge softmax scores from patches (overlaps) inter, union, correct, labeled = 0, 0, 0, 0 for i in range(len(image)): correct_tmp, labeled_tmp = utils.batch_pix_accuracy( pred[i].unsqueeze(0), target[i]) inter_tmp, union_tmp = utils.batch_intersection_union( pred[i].unsqueeze(0), target[i], self.nclass) correct += correct_tmp labeled += labeled_tmp inter += inter_tmp union += union_tmp return correct, labeled, inter, union is_best = False self.model.eval() total_inter, total_union, total_correct, total_label = 0, 0, 0, 0 name2inter = {} name2union = {} tbar = tqdm(self.valloader, desc='\r') for i, (image, target, name, class_freq) in enumerate(tbar): if torch_ver == "0.3": image = Variable(image, volatile=True) correct, labeled, inter, union = eval_batch( self.model, image, target) else: with torch.no_grad(): correct, labeled, inter, union = eval_batch( self.model, image, target) total_correct += correct total_label += labeled total_inter += inter total_union += union pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label) IoU = 1.0 * total_inter / (np.spacing(1) + total_union) mIoU = IoU.mean() name2inter[name[0]] = inter.tolist() name2union[name[0]] = union.tolist() tbar.set_description('pixAcc: %.2f, mIoU: %.2f' % (pixAcc, mIoU)) self.logger.info('pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU)) self.writer.add_scalars('IoU', {'validation iou': mIoU}, epoch) with open("name2inter", 'w') as fp: json.dump(name2inter, fp) with open("name2union", 'w') as fp: json.dump(name2union, fp) torch.cuda.empty_cache() if epoch is not None: new_pred = (pixAcc + mIoU) / 2 if new_pred > self.best_pred: is_best = True self.best_pred = new_pred utils.save_checkpoint( { 'epoch': epoch + 1, 'state_dict': self.model.module.state_dict(), 'optimizer': self.optimizer.state_dict(), 'best_pred': self.best_pred, }, self.args, is_best)