class Trainer():
    """Training/validation harness for a semantic-segmentation model.

    From the parsed command-line ``args`` this builds: ImageNet-normalized
    data transforms, train/val datasets and loaders, the segmentation model
    (with optional dilation/lateral/JPU/aux/SE-loss components), an SGD
    optimizer with per-module learning rates (backbone at ``lr``, new heads
    at ``10 * lr``), the composite segmentation criterion, and a per-iteration
    LR scheduler.  Supports checkpoint resume and fine-tuning.
    """

    def __init__(self, args):
        self.args = args
        # data transforms: ImageNet mean/std normalization
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])
        ])
        # dataset
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        trainset = get_segmentation_dataset(args.dataset,
                                            split=args.train_split,
                                            mode='train',
                                            **data_kwargs)
        testset = get_segmentation_dataset(args.dataset,
                                           split='val',
                                           mode='val',
                                           **data_kwargs)
        # dataloader; pinned memory only helps when transferring to CUDA
        kwargs = {'num_workers': args.workers, 'pin_memory': True} \
            if args.cuda else {}
        self.trainloader = data.DataLoader(trainset,
                                           batch_size=args.batch_size,
                                           drop_last=True,
                                           shuffle=True,
                                           **kwargs)
        self.valloader = data.DataLoader(testset,
                                         batch_size=args.batch_size,
                                         drop_last=False,
                                         shuffle=False,
                                         **kwargs)
        self.nclass = trainset.num_class
        # model
        model = get_segmentation_model(
            args.model,
            dataset=args.dataset,
            backbone=args.backbone,
            dilated=args.dilated,
            lateral=args.lateral,
            jpu=args.jpu,
            aux=args.aux,
            se_loss=args.se_loss,
            norm_layer=torch.nn.BatchNorm2d,  # plain BatchNorm2d (no sync-BN)
            base_size=args.base_size,
            crop_size=args.crop_size)
        print(model)
        # optimizer using different LR: pretrained backbone at base LR,
        # newly-initialized modules (jpu/head/auxlayer) at 10x LR
        params_list = [
            {
                'params': model.pretrained.parameters(),
                'lr': args.lr
            },
        ]
        if hasattr(model, 'jpu'):
            params_list.append({
                'params': model.jpu.parameters(),
                'lr': args.lr * 10
            })
        if hasattr(model, 'head'):
            params_list.append({
                'params': model.head.parameters(),
                'lr': args.lr * 10
            })
        if hasattr(model, 'auxlayer'):
            params_list.append({
                'params': model.auxlayer.parameters(),
                'lr': args.lr * 10
            })
        optimizer = torch.optim.SGD(params_list,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        # criterions
        self.criterion = SegmentationLosses(se_loss=args.se_loss,
                                            aux=args.aux,
                                            nclass=self.nclass,
                                            se_weight=args.se_weight,
                                            aux_weight=args.aux_weight)
        self.model, self.optimizer = model, optimizer
        # using cuda
        if args.cuda:
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
        # resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                # weights live under .module once wrapped in DataParallelModel
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        # clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
        # lr scheduler
        self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs,
                                            len(self.trainloader))
        self.best_pred = 0.0

    def training(self, epoch):
        """Run one training epoch; checkpoint every epoch when --no-val."""
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.trainloader)
        for i, (image, target) in enumerate(tbar):
            # per-iteration LR schedule (may depend on best metric so far)
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            if torch_ver == "0.3":
                image = Variable(image)
                target = Variable(target)
            outputs = self.model(image)
            loss = self.criterion(outputs, target)
            loss.backward()
            # NOTE: a large commented-out "modified loss" experiment
            # (JointEdgeSegLoss with seg/edge/att/dual terms) was removed
            # here; it was dead code referencing undefined names.
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, self.args, is_best)

    def validation(self, epoch):
        """Evaluate on the val split; checkpoint when (pixAcc+mIoU)/2 improves."""

        # Fast test during the training
        def eval_batch(model, image, target):
            outputs = model(image)
            outputs = gather(outputs, 0, dim=0)
            pred = outputs[0]
            target = target.cuda()
            correct, labeled = utils.batch_pix_accuracy(pred.data, target)
            inter, union = utils.batch_intersection_union(
                pred.data, target, self.nclass)
            return correct, labeled, inter, union

        is_best = False
        self.model.eval()
        total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
        tbar = tqdm(self.valloader, desc='\r')
        for i, (image, target) in enumerate(tbar):
            if torch_ver == "0.3":
                image = Variable(image, volatile=True)
                correct, labeled, inter, union = eval_batch(
                    self.model, image, target)
            else:
                with torch.no_grad():
                    correct, labeled, inter, union = eval_batch(
                        self.model, image, target)
            total_correct += correct
            total_label += labeled
            total_inter += inter
            total_union += union
            # np.spacing(1) guards against division by zero
            pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
            IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
            mIoU = IoU.mean()
            tbar.set_description('pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU))

        new_pred = (pixAcc + mIoU) / 2
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, self.args, is_best)
def train(cfg, logger, logdir):
    """Iteration-driven training loop for a (teacher-distilled) segmentation model.

    Args:
        cfg: nested config dict with "data", "model", "teacher", "training"
            and "validating" sections.
        logger: logger used for progress and metric reporting.
        logdir: directory where the best checkpoint (.pkl) is written.

    Trains until ``cfg["training"]["train_iters"]`` iterations, validating every
    ``val_interval`` iterations and keeping only the best-mIoU checkpoint.
    """
    # Setup seeds
    init_seed(11733, en_cudnn=False)
    # Setup Augmentations
    train_augmentations = cfg["training"].get("train_augmentations", None)
    t_data_aug = get_composed_augmentations(train_augmentations)
    val_augmentations = cfg["validating"].get("val_augmentations", None)
    v_data_aug = get_composed_augmentations(val_augmentations)
    # Setup Dataloader
    path_n = cfg["model"]["path_num"]  # number of multi-path branches in the model
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]
    t_loader = data_loader(data_path, split=cfg["data"]["train_split"],
                           augmentations=t_data_aug, path_num=path_n)
    v_loader = data_loader(data_path, split=cfg["data"]["val_split"],
                           augmentations=v_data_aug, path_num=path_n)
    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg["training"]["batch_size"],
                                  num_workers=cfg["training"]["n_workers"],
                                  shuffle=True,
                                  drop_last=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["validating"]["batch_size"],
                                num_workers=cfg["validating"]["n_workers"])
    logger.info("Using training seting {}".format(cfg["training"]))
    # Setup Metrics
    running_metrics_val = runningScore(t_loader.n_classes)
    # Setup Model and Loss
    loss_fn = get_loss_function(cfg["training"])
    teacher = get_model(cfg["teacher"], t_loader.n_classes)
    # the student model receives the loss and the teacher for distillation
    model = get_model(cfg["model"], t_loader.n_classes, loss_fn,
                      cfg["training"]["resume"], teacher)
    logger.info("Using loss {}".format(loss_fn))
    # Setup optimizer
    optimizer = get_optimizer(cfg["training"], model)
    # Setup Multi-GPU
    model = DataParallelModel(model).cuda()
    # Initialize training param
    cnt_iter = 0
    best_iou = 0.0
    time_meter = averageMeter()
    # outer while restarts the dataloader (epochs) until train_iters is reached;
    # note the inner for-loop always finishes the current epoch
    while cnt_iter <= cfg["training"]["train_iters"]:
        for (f_img, labels) in trainloader:
            cnt_iter += 1
            model.train()
            optimizer.zero_grad()
            start_ts = time.time()
            # cycle pos_id through the model's path branches per iteration;
            # the wrapped model returns the loss directly
            outputs = model(f_img, labels, pos_id=cnt_iter % path_n)
            seg_loss = gather(outputs, 0)  # collect per-GPU losses to device 0
            seg_loss = torch.mean(seg_loss)
            seg_loss.backward()
            time_meter.update(time.time() - start_ts)
            optimizer.step()
            if (cnt_iter + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}] Loss: {:.4f} Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    cnt_iter + 1,
                    cfg["training"]["train_iters"],
                    seg_loss.item(),
                    time_meter.avg / cfg["training"]["batch_size"],
                )
                print(print_str)
                logger.info(print_str)
                time_meter.reset()
            if (cnt_iter + 1) % cfg["training"]["val_interval"] == 0 or (cnt_iter + 1) == cfg["training"]["train_iters"]:
                model.eval()
                with torch.no_grad():
                    for i_val, (f_img_val, labels_val) in tqdm(enumerate(valloader)):
                        outputs = model(f_img_val, pos_id=i_val % path_n)
                        outputs = gather(outputs, 0, dim=0)
                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()
                        running_metrics_val.update(gt, pred)
                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info("{}: {}".format(k, v))
                for k, v in class_iou.items():
                    logger.info("{}: {}".format(k, v))
                running_metrics_val.reset()
                # keep only the best-mIoU checkpoint; the odd dict key is the
                # label string produced by runningScore.get_scores()
                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": cnt_iter + 1,
                        # strip the frozen teacher's weights before saving
                        "model_state": clean_state_dict(model.module.state_dict(), 'teacher'),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        logdir,
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"],
                                                      cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)
class Trainer():
    """Trainer for Pascal VOC: concatenates the 'pascal_voc' and 'pascal_aug'
    train splits, trains with Adam (per-module LRs), validates on VOC val.
    """

    def __init__(self, args):
        self.args = args
        # data transforms (ImageNet mean/std)
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])
        ])
        # dataset
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        # NOTE(review): the dataset root is hard-coded to a user-specific path
        trainset_1 = get_dataset(
            'pascal_voc',
            root=os.path.expanduser('/fast/users/a1675776/data/encoding/data'),
            split='train',
            mode='train',
            **data_kwargs)
        trainset_2 = get_dataset(
            'pascal_aug',
            root=os.path.expanduser('/fast/users/a1675776/data/encoding/data'),
            split='train',
            mode='train',
            **data_kwargs)
        testset = get_dataset(
            'pascal_voc',
            root=os.path.expanduser('/fast/users/a1675776/data/encoding/data'),
            split='val',
            mode='val',
            **data_kwargs)
        # train on VOC + augmented VOC combined
        concatenate_trainset = torch.utils.data.ConcatDataset(
            [trainset_1, trainset_2])
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True} \
            if args.cuda else {}
        self.trainloader = data.DataLoader(concatenate_trainset,
                                           batch_size=args.batch_size,
                                           drop_last=True,
                                           shuffle=True,
                                           **kwargs)
        self.valloader = data.DataLoader(testset,
                                         batch_size=args.batch_size,
                                         drop_last=False,
                                         shuffle=False,
                                         **kwargs)
        self.nclass = trainset_1.num_class
        # model (SyncBatchNorm for multi-GPU training)
        model = get_segmentation_model(args.model,
                                       dataset=args.dataset,
                                       backbone=args.backbone,
                                       aux=args.aux,
                                       se_loss=args.se_loss,
                                       norm_layer=SyncBatchNorm,
                                       base_size=args.base_size,
                                       crop_size=args.crop_size)
        # print(model)
        # optimizer using different LR: backbone at lr, heads at 10x lr
        params_list = [
            {
                'params': model.pretrained.parameters(),
                'lr': args.lr
            },
        ]
        if hasattr(model, 'head'):
            params_list.append({
                'params': model.head.parameters(),
                'lr': args.lr * 10
            })
        if hasattr(model, 'auxlayer'):
            params_list.append({
                'params': model.auxlayer.parameters(),
                'lr': args.lr * 10
            })
        optimizer = torch.optim.Adam(params_list,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
        # criterions
        self.criterion = SegmentationLosses(se_loss=args.se_loss,
                                            aux=args.aux,
                                            nclass=self.nclass,
                                            se_weight=args.se_weight,
                                            aux_weight=args.aux_weight)
        self.model, self.optimizer = model, optimizer
        # using cuda
        if args.cuda:
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
        # resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        # clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
        # lr scheduler
        self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs,
                                            len(self.trainloader))
        self.best_pred = 0.0

    def training(self, epoch):
        """Run one training epoch; checkpoint every epoch when --no-val."""
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.trainloader)
        for i, (image, target) in enumerate(tbar):
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            if torch_ver == "0.3":
                image = Variable(image)
                target = Variable(target)
            outputs = self.model(image)
            loss = self.criterion(outputs, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, self.args, is_best)

    def validation(self, epoch):
        """Evaluate on the val split; log mIoU history and checkpoint on improvement."""

        # Fast test during the training
        def eval_batch(model, image, target):
            outputs = model(image)
            outputs = gather(outputs, 0, dim=0)
            pred = outputs[0]
            target = target.cuda()
            correct, labeled = utils.batch_pix_accuracy(pred.data, target)
            inter, union = utils.batch_intersection_union(
                pred.data, target, self.nclass)
            return correct, labeled, inter, union

        is_best = False
        self.model.eval()
        total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
        tbar = tqdm(self.valloader, desc='\r')
        for i, (image, target) in enumerate(tbar):
            if torch_ver == "0.3":
                image = Variable(image, volatile=True)
                correct, labeled, inter, union = eval_batch(
                    self.model, image, target)
            else:
                with torch.no_grad():
                    correct, labeled, inter, union = eval_batch(
                        self.model, image, target)
            total_correct += correct
            total_label += labeled
            total_inter += inter
            total_union += union
            # np.spacing(1) guards against division by zero
            pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
            IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
            mIoU = IoU.mean()
            tbar.set_description('pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU))

        # NOTE(review): `test_record` is not defined in this class — presumably a
        # module-level list defined elsewhere in the file; verify it exists.
        test_record.append(mIoU)
        np.save('test_record.npy', test_record)
        # NOTE(review): new_pred halves mIoU (pixAcc is not included, unlike the
        # sibling trainers' (pixAcc + mIoU) / 2); the comparison is still
        # monotonic in mIoU, but the stored best_pred is half the real mIoU —
        # confirm this is intentional.
        new_pred = (mIoU) / 2
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, self.args, is_best)
gdal.AllRegister() # Get the model checkpoint = torch.load( r'E:\Project\PyTorch-Encoding\runs\arcs\deeplab\resnest269\model_best.pth.tar\model_best.pth.tar' ) model = get_segmentation_model("deeplab", dataset="arcs", backbone="resnest269", aux=True, se_loss=False, norm_layer=SyncBatchNorm, base_size=128, crop_size=128) model = DataParallelModel(model).cuda() model.module.load_state_dict(checkpoint['state_dict']) model.eval() def processData(tmpName): oriTileDir = "F:\\色林错\\dataSet\\" + str(tmpName) + r"\OriginTileData" # maskTileDir = "F:\\色林错\\dataSet\\" + str(tmpName) + r"\MaskTileData" tmpDir = "F:\\色林错\\dataSet\\" + str(tmpName) + r"\tmpTrainTest" if not os.path.exists(tmpDir): os.makedirs(tmpDir) length = dataSet[tmpName]["length"] for i in range(length): filename = oriTileDir + "\\" + str(i) + ".tif" img = encoding.utils.load_image(filename) img = img.cuda().unsqueeze(0)
class Trainer():
    """Trainer variant with optional differential LR (``args.diflr``) and a
    tolerant checkpoint loader that copies only matching parameters.
    """

    def __init__(self, args):
        # tag the run name when SE-loss is enabled
        if args.se_loss:
            args.checkname = args.checkname + "_se"
        self.args = args
        # data transforms (ImageNet mean/std)
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])])
        # dataset
        data_kwargs = {'transform': input_transform, 'base_size': args.base_size,
                       'crop_size': args.crop_size}
        trainset = get_segmentation_dataset(args.dataset, split='train', mode='train',
                                            **data_kwargs)
        testset = get_segmentation_dataset(args.dataset, split='val', mode='val',
                                           **data_kwargs)
        # dataloader (pin_memory deliberately disabled here)
        kwargs = {'num_workers': args.workers, 'pin_memory': False} \
            if args.cuda else {}
        self.trainloader = data.DataLoader(trainset, batch_size=args.batch_size,
                                           drop_last=True, shuffle=True, **kwargs)
        self.valloader = data.DataLoader(testset, batch_size=args.batch_size,
                                         drop_last=False, shuffle=False, **kwargs)
        self.nclass = trainset.num_class
        # model
        model = get_segmentation_model(args.model, dataset=args.dataset,
                                       backbone=args.backbone, aux=args.aux,
                                       se_loss=args.se_loss, norm_layer=BatchNorm2d,
                                       base_size=args.base_size, crop_size=args.crop_size)
        print(model)
        # count parameter number
        pytorch_total_params = sum(p.numel() for p in model.parameters())
        print("Total number of parameters: %d" % pytorch_total_params)
        # optimizer using different LR: heads get 10x LR only when --diflr is set
        params_list = [{'params': model.pretrained.parameters(), 'lr': args.lr}, ]
        if hasattr(model, 'head'):
            if args.diflr:
                params_list.append({'params': model.head.parameters(), 'lr': args.lr * 10})
            else:
                params_list.append({'params': model.head.parameters(), 'lr': args.lr})
        if hasattr(model, 'auxlayer'):
            if args.diflr:
                params_list.append({'params': model.auxlayer.parameters(), 'lr': args.lr * 10})
            else:
                params_list.append({'params': model.auxlayer.parameters(), 'lr': args.lr})
        optimizer = torch.optim.SGD(params_list,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        # alternative optimizer kept for reference:
        # optimizer = torch.optim.ASGD(params_list,
        #                              lr=args.lr,
        #                              weight_decay=args.weight_decay)
        # criterions
        self.criterion = SegmentationLosses(se_loss=args.se_loss,
                                            aux=args.aux,
                                            nclass=self.nclass)
        self.model, self.optimizer = model, optimizer
        # using cuda
        if args.cuda:
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
        # resuming checkpoint (empty string disables resume)
        if args.resume is not None and len(args.resume) > 0:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'"
                                   .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                # load weights for the same model
                # self.model.module.load_state_dict(checkpoint['state_dict'])
                # model and checkpoint have different strucutures: copy only
                # the parameters whose names exist in the current model
                pretrained_dict = checkpoint['state_dict']
                model_dict = self.model.module.state_dict()
                for name, param in pretrained_dict.items():
                    if name not in model_dict:
                        continue
                    if isinstance(param, Parameter):
                        # backwards compatibility for serialized parameters
                        param = param.data
                    model_dict[name].copy_(param)
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        # clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
        # lr scheduler
        self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.trainloader),
                                            lr_step=args.lr_step)
        self.best_pred = 0.0

    def training(self, epoch):
        """Run one training epoch; checkpoint every epoch when --no-val."""
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.trainloader)
        for i, (image, target) in enumerate(tbar):
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            if torch_ver == "0.3":
                image = Variable(image)
                target = Variable(target)
            outputs = self.model(image)
            loss = self.criterion(outputs, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            utils.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, self.args, is_best)

    def validation(self, epoch):
        """Evaluate on the val split; checkpoint when (pixAcc+mIoU)/2 improves."""

        # Fast test during the training
        def eval_batch(model, image, target):
            outputs = model(image)
            outputs = gather(outputs, 0, dim=0)
            pred = outputs[0]
            target = target.cuda()
            correct, labeled = utils.batch_pix_accuracy(pred.data, target)
            inter, union = utils.batch_intersection_union(pred.data, target,
                                                          self.nclass)
            return correct, labeled, inter, union

        is_best = False
        self.model.eval()
        total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
        tbar = tqdm(self.valloader, desc='\r')
        for i, (image, target) in enumerate(tbar):
            if torch_ver == "0.3":
                image = Variable(image, volatile=True)
                correct, labeled, inter, union = eval_batch(self.model, image, target)
            else:
                with torch.no_grad():
                    correct, labeled, inter, union = eval_batch(self.model, image, target)
            total_correct += correct
            total_label += labeled
            total_inter += inter
            total_union += union
            # np.spacing(1) guards against division by zero
            pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
            IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
            mIoU = IoU.mean()
            tbar.set_description(
                'pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU))

        new_pred = (pixAcc + mIoU) / 2
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            utils.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, self.args, is_best)
class Trainer():
    """Trainer variant with a file logger and multi-grid/multi-dilation model
    options; supports both fine-tuning (--ft/--ft-resume) and plain resume.
    """

    def __init__(self, args):
        self.args = args
        args.log_name = str(args.checkname)
        self.logger = utils.create_logger(args.log_root, args.log_name)
        # data transforms (ImageNet mean/std)
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])
        ])
        # dataset (the dataset also receives the logger and scale flag)
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size,
            'logger': self.logger,
            'scale': args.scale
        }
        trainset = get_segmentation_dataset(args.dataset,
                                            split='train',
                                            mode='train',
                                            **data_kwargs)
        testset = get_segmentation_dataset(args.dataset,
                                           split='val',
                                           mode='val',
                                           **data_kwargs)
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True} \
            if args.cuda else {}
        self.trainloader = data.DataLoader(trainset,
                                           batch_size=args.batch_size,
                                           drop_last=True,
                                           shuffle=True,
                                           **kwargs)
        self.valloader = data.DataLoader(testset,
                                         batch_size=args.batch_size,
                                         drop_last=False,
                                         shuffle=False,
                                         **kwargs)
        self.nclass = trainset.num_class
        # model
        model = get_segmentation_model(args.model,
                                       dataset=args.dataset,
                                       backbone=args.backbone,
                                       aux=args.aux,
                                       se_loss=args.se_loss,
                                       norm_layer=BatchNorm2d,
                                       base_size=args.base_size,
                                       crop_size=args.crop_size,
                                       multi_grid=args.multi_grid,
                                       multi_dilation=args.multi_dilation)
        # print(model)
        self.logger.info(model)
        # optimizer using different LR: backbone at lr, heads at 10x lr
        params_list = [
            {
                'params': model.pretrained.parameters(),
                'lr': args.lr
            },
        ]
        if hasattr(model, 'head'):
            params_list.append({
                'params': model.head.parameters(),
                'lr': args.lr * 10
            })
        if hasattr(model, 'auxlayer'):
            params_list.append({
                'params': model.auxlayer.parameters(),
                'lr': args.lr * 10
            })
        optimizer = torch.optim.SGD(params_list,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        self.criterion = SegmentationMultiLosses(nclass=self.nclass)
        # self.criterion = SegmentationLosses(se_loss=args.se_loss, aux=args.aux,nclass=self.nclass)
        self.model, self.optimizer = model, optimizer
        # using cuda
        if args.cuda:
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
        # finetune from a trained model (non-strict load tolerates head mismatches)
        if args.ft:
            args.start_epoch = 0
            checkpoint = torch.load(args.ft_resume)
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'], strict=False)
            else:
                self.model.load_state_dict(checkpoint['state_dict'], strict=False)
            self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.ft_resume, checkpoint['epoch']))
        # resuming checkpoint
        if args.resume:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                self.best_pred = checkpoint['best_pred']
            self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        # lr scheduler
        self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.trainloader),
                                            logger=self.logger,
                                            lr_step=args.lr_step)
        self.best_pred = 0.0

    def training(self, epoch):
        """Run one training epoch; with --no-val, checkpoint every 5th epoch
        once past epoch 99."""
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.trainloader)
        for i, (image, target) in enumerate(tbar):
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            if torch_ver == "0.3":
                image = Variable(image)
                target = Variable(target)
            outputs = self.model(image)
            loss = self.criterion(outputs, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
        self.logger.info('Train loss: %.3f' % (train_loss / (i + 1)))

        if self.args.no_val:
            # save a checkpoint every 5 epochs, but only after epoch 99
            filename = "checkpoint_%s.pth.tar" % (epoch + 1)
            is_best = False
            if epoch > 99:
                if not epoch % 5:
                    utils.save_checkpoint(
                        {
                            'epoch': epoch + 1,
                            'state_dict': self.model.module.state_dict(),
                            'optimizer': self.optimizer.state_dict(),
                            'best_pred': self.best_pred,
                        }, self.args, is_best, filename)

    def validation(self, epoch):
        """Evaluate on the val split; checkpoint when (pixAcc+mIoU)/2 improves."""

        # Fast test during the training
        def eval_batch(model, image, target):
            outputs = model(image)
            outputs = gather(outputs, 0, dim=0)
            pred = outputs[0]
            target = target.cuda()
            correct, labeled = utils.batch_pix_accuracy(pred.data, target)
            inter, union = utils.batch_intersection_union(
                pred.data, target, self.nclass)
            return correct, labeled, inter, union

        is_best = False
        self.model.eval()
        total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
        tbar = tqdm(self.valloader, desc='\r')
        for i, (image, target) in enumerate(tbar):
            if torch_ver == "0.3":
                image = Variable(image, volatile=True)
                correct, labeled, inter, union = eval_batch(
                    self.model, image, target)
            else:
                with torch.no_grad():
                    correct, labeled, inter, union = eval_batch(
                        self.model, image, target)
            total_correct += correct
            total_label += labeled
            total_inter += inter
            total_union += union
            # np.spacing(1) guards against division by zero
            pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
            IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
            mIoU = IoU.mean()
            tbar.set_description('pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU))
        self.logger.info('pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU))

        new_pred = (pixAcc + mIoU) / 2
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, self.args, is_best)
class Tester():
    """Test-time predictor with test-time augmentation (TTA).

    Loads an EncNet checkpoint, runs it over a test CSV's images (each sample
    carries a TTA ``mode`` id), reverts the augmentation on the predicted
    masks, saves per-image .npy masks and a submission CSV of classifier
    scores.
    """

    def __init__(self, args):
        self.args = args
        self.args.start_epoch = 0
        self.args.cuda = True
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.490, .490, .490], [.247, .247, .247])])  # TODO: change mean and std
        # dataset
        testset = SegmentationDataset(
            os.path.join(args.imagelist_path, 'test_stage2.csv'),
            args.image_path,
            args.masks_path,
            input_transform=input_transform,
            transform_chain=Compose([Resize(self.args.size, self.args.size)], p=1),
            base_size=480,
            is_flip=True,
            is_clahe=True,
            is_sh_sc_ro=True
        )
        # dataloader
        kwargs = {'num_workers': args.workers}  # , 'pin_memory': True}
        self.testloader = data.DataLoader(testset, batch_size=args.batch_size,
                                          drop_last=False, shuffle=False,
                                          **kwargs)
        self.nclass = 1  # binary segmentation
        model = EncNet(
            nclass=self.nclass,
            backbone=args.backbone,
            aux=args.aux,
            se_loss=args.se_loss,
            norm_layer=SyncBatchNorm
        )
        print(model)
        self.model = model
        # resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'"
                                   .format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            # strip the 7-char 'module.' prefix left by DataParallel wrapping
            state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items()}
            self.model.load_state_dict(state_dict)
            self.best_pred = checkpoint['best_pred']
            if 'best_loss' in checkpoint.keys():
                self.best_loss = checkpoint['best_loss']
            else:
                self.best_loss = 0
            print("=> loaded checkpoint '{}' (epoch {}, best pred: {}, best loss, {})"
                  .format(args.resume, checkpoint['epoch'], self.best_pred, self.best_loss))
        self.model = DataParallelModel(self.model).cuda()
        # map each TTA mode id to the function that reverts its augmentation
        # on a predicted mask; modes 0 and 2 are identity, 1 is h-flip, 3-10
        # revert shift/scale/rotate at various angles
        self.mode2func = {
            0: lambda x, y: (x, y),
            1: apply_hflip,
            2: lambda x, y: (x, y),
            3: lambda x, y: apply_revert_shscro(x, y, angle=5, scale=0.9, dx=0, dy=0),
            4: lambda x, y: apply_revert_shscro(x, y, angle=10, scale=0.9, dx=0, dy=0),
            5: lambda x, y: apply_revert_shscro(x, y, angle=15, scale=0.9, dx=0, dy=0),
            6: lambda x, y: apply_revert_shscro(x, y, angle=20, scale=0.9, dx=0, dy=0),
            7: lambda x, y: apply_revert_shscro(x, y, angle=-5, scale=0.9, dx=0, dy=0),
            8: lambda x, y: apply_revert_shscro(x, y, angle=-10, scale=0.9, dx=0, dy=0),
            9: lambda x, y: apply_revert_shscro(x, y, angle=-15, scale=0.9, dx=0, dy=0),
            10: lambda x, y: apply_revert_shscro(x, y, angle=-20, scale=0.9, dx=0, dy=0),
        }

    def predict(self):
        """Run inference over the test loader, saving reverted masks as .npy
        (modes < 2 only) and a classifier-score submission CSV."""
        train_loss = 0.0  # NOTE(review): unused in this method
        self.model.eval()
        tbar = tqdm(self.testloader)
        img_ids = []
        encode_pixels = []
        for i, (img_id, image, _, mode) in enumerate(tbar):
            image = image.cuda()
            with torch.no_grad():
                outputs = self.model(image)
            # each per-GPU output is (mask_logits, cls_logits); collect both
            preds_ten = [v[0].data.cpu() for v in outputs]
            cls_preds_ten = [v[1].data.cpu() for v in outputs]
            preds_ten = torch.cat(preds_ten)
            cls_preds_ten = torch.cat(cls_preds_ten)
            preds = torch.sigmoid(preds_ten).data.cpu().numpy()[:, 0, :, :]
            mask_pred = torch.sigmoid(cls_preds_ten).data.cpu().numpy().reshape(-1)
            l_img_id = list(img_id)
            img_ids += l_img_id
            for k, imid in enumerate(l_img_id):
                npy_file = os.path.join(self.args.pred_path,
                                        str(imid) + f'_{mode[k].item()}.npy')
                if mode[k].item() < 2:
                    # revert the TTA transform, resize to full resolution, save
                    np.save(npy_file,
                            cv2.resize(self.mode2func[mode[k].item()](preds[k], None)[0],
                                       (1024, 1024)))
                encode_pixels.append(mask_pred[k])
        pd.DataFrame({'ImageId': img_ids, 'EncodedPixels': encode_pixels}).to_csv(
            os.path.join(self.args.pred_path, 'stage2_new_model_submit_16.csv'),
            index=None)

    def __del__(self):
        # release the GPU model when the tester is garbage-collected
        del self.model
        gc.collect()
class Trainer():
    """Trainer for a segmentation model with auxiliary weather/time-of-day
    classification heads.

    Fix vs. original: ``self.best_pred`` was reset to ``0.0`` *after* the
    resume branch had restored it from the checkpoint, silently discarding the
    resumed best score. The default is now initialized before resuming.
    """

    def __init__(self, args):
        self.args = args
        args.log_name = str(args.checkname)
        self.logger = utils.create_logger(args.log_root, args.log_name)
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            # transform.Normalize([.485, .456, .406], [.229, .224, .225])
        ])
        # dataset
        data_kwargs = {'transform': input_transform, 'base_size': args.base_size,
                       'crop_size': args.crop_size, 'logger': self.logger,
                       'scale': args.scale}
        trainset = get_segmentation_dataset(args.dataset, split='train', mode='train',
                                            **data_kwargs)
        testset = get_segmentation_dataset(args.dataset, split='val', mode='val',
                                           **data_kwargs)
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True} if args.cuda else {}
        self.trainloader = data.DataLoader(trainset, batch_size=args.batch_size,
                                           drop_last=True, shuffle=True, **kwargs)
        self.valloader = data.DataLoader(testset, batch_size=args.batch_size,
                                         drop_last=False, shuffle=False, **kwargs)
        self.nclass = trainset.num_class
        # 7 weather classes / 4 time-of-day classes
        self.confusion_matrix_weather = utils.ConfusionMatrix(7)
        self.confusion_matrix_timeofday = utils.ConfusionMatrix(4)
        # model
        model = get_segmentation_model(args.model, dataset=args.dataset,
                                       backbone=args.backbone, aux=args.aux,
                                       se_loss=args.se_loss,
                                       # norm_layer=BatchNorm2d,  # for multi-gpu
                                       base_size=args.base_size, crop_size=args.crop_size,
                                       multi_grid=args.multi_grid,
                                       multi_dilation=args.multi_dilation)
        self.logger.info(model)
        # optimizer using different LR: backbone at base lr, heads at 10x,
        # the weather/time classifiers frozen (lr factor 0).
        params_list = [{'params': model.pretrained.parameters(), 'lr': 1 * args.lr},]
        if hasattr(model, 'head'):
            params_list.append({'params': model.head.parameters(), 'lr': 1 * args.lr*10})
        if hasattr(model, 'auxlayer'):
            params_list.append({'params': model.auxlayer.parameters(), 'lr': 1 * args.lr*10})
        params_list.append({'params': model.weather_classifier.parameters(), 'lr': 0 * args.lr*10})
        params_list.append({'params': model.time_classifier.parameters(), 'lr': 0 * args.lr*10})
        optimizer = torch.optim.SGD(params_list, lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        # self.criterion = SegmentationMultiLosses(nclass=self.nclass)
        self.criterion = SegmentationLosses(se_loss=args.se_loss, aux=args.aux,
                                            nclass=self.nclass)
        # self.criterion = torch.nn.CrossEntropyLoss()
        self.model, self.optimizer = model, optimizer
        # using cuda
        if args.cuda:
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
        # finetune from a trained model (non-strict: allows new heads)
        if args.ft:
            args.start_epoch = 0
            checkpoint = torch.load(args.ft_resume)
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'], strict=False)
            else:
                self.model.load_state_dict(checkpoint['state_dict'], strict=False)
            self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.ft_resume, checkpoint['epoch']))
        # FIX: initialize the default best score BEFORE the resume branch so a
        # resumed checkpoint's best_pred is not overwritten afterwards.
        self.best_pred = 0.0
        # resuming checkpoint
        if args.resume:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'"
                                   .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                self.best_pred = checkpoint['best_pred']
            self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        # lr scheduler
        self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                            len(self.trainloader), logger=self.logger,
                                            lr_step=args.lr_step)
        self.logger.info(self.args)

    def training(self, epoch):
        """Train one epoch on the segmentation loss; classifier heads stay in eval mode."""
        train_loss = 0.0
        self.model.train()
        # Keep the frozen classifier heads in eval mode (e.g. fixed BN stats).
        self.model.module.weather_classifier.eval()
        self.model.module.time_classifier.eval()
        # self.model.eval()
        # self.model.module.weather_classifier.train()
        # self.model.module.time_classifier.train()
        tbar = tqdm(self.trainloader)
        for i, (image, target, weather, timeofday, scene) in enumerate(tbar):
            weather = weather.cuda()
            timeofday = timeofday.cuda()
            # self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            if torch_ver == "0.3":
                image = Variable(image)
                target = Variable(target)
            outputs, weather_o, timeofday_o = self.model(image)
            # create weather / timeofday target mask: broadcast each sample's
            # scalar label over a full (h, w) map matching the classifier output.
            b, _, h, w = weather_o.size()
            weather_t = torch.ones((b, h, w)).long().cuda()
            for bi in range(b):
                weather_t[bi] *= weather[bi]
            timeofday_t = torch.ones((b, h, w)).long().cuda()
            for bi in range(b):
                timeofday_t[bi] *= timeofday[bi]
            loss = self.criterion(outputs, target)
            # loss = self.criterion(weather_o, weather_t) + self.criterion(timeofday_o, timeofday_t)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.logger.info('Train loss: %.3f' % (train_loss / (i + 1)))
        # save checkpoint every 5 epoch
        is_best = False
        if epoch % 5 == 0:
            # filename = "checkpoint_%s.pth.tar"%(epoch+1)
            filename = "checkpoint_%s.%s.%s.%s.pth.tar" % (self.args.log_root,
                                                           self.args.checkname,
                                                           self.args.model, epoch+1)
            utils.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, self.args, is_best, filename)

    def validation(self, epoch=None):
        """Evaluate segmentation (pixAcc/mIoU) plus weather and time-of-day
        classifier accuracies; saves a checkpoint if the score improved."""
        # Fast test during the training
        def eval_batch(model, image, target, weather, timeofday, scene):
            outputs, weather_o, timeofday_o = model(image)
            # Gathers tensors from different GPUs on a specified device
            # outputs = gather(outputs, 0, dim=0)
            pred = outputs[0]
            # Broadcast scalar labels to per-pixel target maps (as in training).
            b, _, h, w = weather_o.size()
            weather_t = torch.ones((b, h, w)).long()
            for bi in range(b):
                weather_t[bi] *= weather[bi]
            timeofday_t = torch.ones((b, h, w)).long()
            for bi in range(b):
                timeofday_t[bi] *= timeofday[bi]
            self.confusion_matrix_weather.update(
                [m.astype(np.int64) for m in weather_t.numpy()],
                weather_o.cpu().numpy().argmax(1))
            self.confusion_matrix_timeofday.update(
                [m.astype(np.int64) for m in timeofday_t.numpy()],
                timeofday_o.cpu().numpy().argmax(1))
            correct, labeled = utils.batch_pix_accuracy(pred.data, target)
            inter, union = utils.batch_intersection_union(pred.data, target, self.nclass)
            correct_weather, labeled_weather = utils.batch_pix_accuracy(weather_o.data, weather_t)
            correct_timeofday, labeled_timeofday = utils.batch_pix_accuracy(timeofday_o.data, timeofday_t)
            return correct, labeled, inter, union, correct_weather, labeled_weather, correct_timeofday, labeled_timeofday

        is_best = False
        self.model.eval()
        total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
        total_correct_weather = 0
        total_label_weather = 0
        total_correct_timeofday = 0
        total_label_timeofday = 0
        # per-image intersection/union, dumped to JSON for offline analysis
        name2inter = {}
        name2union = {}
        tbar = tqdm(self.valloader, desc='\r')
        for i, (image, target, weather, timeofday, scene, name) in enumerate(tbar):
            if torch_ver == "0.3":
                image = Variable(image, volatile=True)
                correct, labeled, inter, union, correct_weather, labeled_weather, correct_timeofday, labeled_timeofday = eval_batch(self.model, image, target, weather, timeofday, scene)
            else:
                with torch.no_grad():
                    correct, labeled, inter, union, correct_weather, labeled_weather, correct_timeofday, labeled_timeofday = eval_batch(self.model, image, target, weather, timeofday, scene)
            total_correct += correct
            total_label += labeled
            total_inter += inter
            total_union += union
            # np.spacing(1) guards against division by zero
            pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
            IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
            mIoU = IoU.mean()
            name2inter[name[0]] = inter.tolist()
            name2union[name[0]] = union.tolist()
            total_correct_weather += correct_weather
            total_label_weather += labeled_weather
            pixAcc_weather = 1.0 * total_correct_weather / (np.spacing(1) + total_label_weather)
            total_correct_timeofday += correct_timeofday
            total_label_timeofday += labeled_timeofday
            pixAcc_timeofday = 1.0 * total_correct_timeofday / (np.spacing(1) + total_label_timeofday)
            tbar.set_description('pixAcc: %.2f, mIoU: %.2f, weather: %.2f, timeofday: %.2f'
                                 % (pixAcc, mIoU, pixAcc_weather, pixAcc_timeofday))
            self.logger.info('pixAcc: %.3f, mIoU: %.3f, pixAcc_weather: %.3f, pixAcc_timeofday: %.3f'
                             % (pixAcc, mIoU, pixAcc_weather, pixAcc_timeofday))
        with open("name2inter", 'w') as fp:
            json.dump(name2inter, fp)
        with open("name2union", 'w') as fp:
            json.dump(name2union, fp)
        cm = self.confusion_matrix_weather.get_scores()['cm']
        self.logger.info(str(cm))
        self.confusion_matrix_weather.reset()
        cm = self.confusion_matrix_timeofday.get_scores()['cm']
        self.logger.info(str(cm))
        self.confusion_matrix_timeofday.reset()
        if epoch is not None:
            new_pred = (pixAcc + mIoU) / 2
            if new_pred > self.best_pred:
                is_best = True
                self.best_pred = new_pred
            utils.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, self.args, is_best)
class Trainer():
    """Trainer with single/multi-GPU support, multi-scale evaluation,
    visualization, and directory inference.

    Fix vs. original: ``self.best_pred = 0.0`` sat at the very end of
    ``__init__`` and overwrote the value restored from a resumed checkpoint,
    so ``save_ckpt`` would mark a worse model as best. The default is now set
    before the resume branch.
    """

    def __init__(self, args):
        self.args = args
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])
        ])
        # dataset
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        trainset = get_dataset(args.dataset, split=args.train_split, mode='train',
                               **data_kwargs)
        valset = get_dataset(
            args.dataset,
            split='val',
            mode='ms_val' if args.multi_scale_eval else 'fast_val',
            **data_kwargs)
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.trainloader = data.DataLoader(trainset, batch_size=args.batch_size,
                                           drop_last=True, shuffle=True, **kwargs)
        if self.args.multi_scale_eval:
            # Multi-scale images differ in size, so batches must stay lists.
            kwargs['collate_fn'] = test_batchify_fn
        self.valloader = data.DataLoader(valset, batch_size=args.test_batch_size,
                                         drop_last=False, shuffle=False, **kwargs)
        self.nclass = trainset.num_class
        # model
        if args.norm_layer == 'bn':
            norm_layer = BatchNorm2d
        elif args.norm_layer == 'sync_bn':
            assert args.multi_gpu, "SyncBatchNorm can only be used when multi GPUs are available!"
            norm_layer = SyncBatchNorm
        else:
            raise ValueError('Invalid norm_layer {}'.format(args.norm_layer))
        model = get_segmentation_model(
            args.model,
            dataset=args.dataset,
            backbone=args.backbone,
            aux=args.aux,
            se_loss=args.se_loss,
            norm_layer=norm_layer,
            base_size=args.base_size,
            crop_size=args.crop_size,
            multi_grid=True,
            multi_dilation=[2, 4, 8],
            only_pam=True,
        )
        print(model)
        # optimizer using different LR
        params_list = [
            {'params': model.pretrained.parameters(), 'lr': args.lr},
        ]
        if hasattr(model, 'head'):
            params_list.append({'params': model.head.parameters(), 'lr': args.lr})
        if hasattr(model, 'auxlayer'):
            params_list.append({'params': model.auxlayer.parameters(), 'lr': args.lr})
        optimizer = torch.optim.SGD(params_list, lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        # criterions
        self.criterion = SegmentationMultiLosses()
        self.model, self.optimizer = model, optimizer
        # using cuda
        if args.multi_gpu:
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
        else:
            self.model = self.model.cuda()
            self.criterion = self.criterion.cuda()
        # Unwrapped model for state_dict/load_state_dict and evaluators.
        self.single_device_model = self.model.module if self.args.multi_gpu else self.model
        # FIX: default best score set BEFORE resuming so the checkpoint's
        # best_pred survives __init__.
        self.best_pred = 0.0
        # resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            self.single_device_model.load_state_dict(checkpoint['state_dict'])
            # Optimizer state / best score only matter when continuing training.
            if not args.ft and not (args.only_val or args.only_vis or args.only_infer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {}), best_pred {}".format(
                args.resume, checkpoint['epoch'], checkpoint['best_pred']))
        # clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
        # lr scheduler
        self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, 0.6)

    def save_ckpt(self, epoch, score):
        """Persist a checkpoint; flag it as best when `score` matches/beats the record."""
        is_best = False
        if score >= self.best_pred:
            is_best = True
            self.best_pred = score
        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': self.single_device_model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, self.args, is_best)

    def training(self, epoch):
        """Train one epoch; the exponential LR scheduler steps once per epoch."""
        train_loss = 0.0
        self.model.train()
        self.lr_scheduler.step()
        tbar = tqdm(self.trainloader, miniters=20)
        for i, (image, target) in enumerate(tbar):
            if not self.args.multi_gpu:
                image = image.cuda()
                target = target.cuda()
            self.optimizer.zero_grad()
            if torch_ver == "0.3":
                image = Variable(image)
                target = Variable(target)
            outputs = self.model(image)
            if self.args.multi_gpu:
                loss = self.criterion(outputs, target)
            else:
                loss = self.criterion(*(outputs + (target, )))
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            ep_log = 'ep {}'.format(epoch + 1)
            lr_log = 'lr ' + '{:.6f}'.format(
                self.optimizer.param_groups[0]['lr']).rstrip('0')
            loss_log = 'loss {:.3f}'.format(train_loss / (i + 1))
            tbar.set_description(', '.join([ep_log, lr_log, loss_log]))

    def validation(self, epoch):
        """Evaluate on the val loader; returns (pixAcc, mIoU)."""
        def _get_pred(batch_im):
            with torch.no_grad():
                # metric.update also accepts list, so no need to gather results from multi gpus
                if self.args.multi_scale_eval:
                    assert len(batch_im) <= torch.cuda.device_count(
                    ), "Multi-scale testing only allows batch size <= number of GPUs"
                    scattered_pred = self.ms_evaluator.parallel_forward(
                        batch_im)
                else:
                    outputs = self.model(batch_im)
                    scattered_pred = [
                        out[0] for out in outputs
                    ] if self.args.multi_gpu else [outputs[0]]
            return scattered_pred

        # Lazy creation
        if not hasattr(self, 'ms_evaluator'):
            self.ms_evaluator = MultiEvalModule(self.single_device_model,
                                                self.nclass,
                                                scales=self.args.eval_scales,
                                                crop=self.args.crop_eval)
            self.metric = utils.SegmentationMetric(self.nclass)
        self.model.eval()
        tbar = tqdm(self.valloader, desc='\r')
        for i, (batch_im, target) in enumerate(tbar):
            # No need to put target to GPU, since the metrics are calculated by numpy.
            # And no need to put data to GPU manually if we use data parallel.
            if not self.args.multi_gpu and not isinstance(
                    batch_im, (list, tuple)):
                batch_im = batch_im.cuda()
            scattered_pred = _get_pred(batch_im)
            # Split the target to match the per-GPU prediction chunks.
            scattered_target = []
            ind = 0
            for p in scattered_pred:
                target_tmp = target[ind:ind + len(p)]
                # Multi-scale testing. In fact, len(target_tmp) == 1
                if isinstance(target_tmp, (list, tuple)):
                    assert len(target_tmp) == 1
                    target_tmp = torch.stack(target_tmp)
                scattered_target.append(target_tmp)
                ind += len(p)
            self.metric.update(scattered_target, scattered_pred)
            pixAcc, mIoU = self.metric.get()
            tbar.set_description('ep {}, pixAcc: {:.4f}, mIoU: {:.4f}'.format(
                epoch + 1, pixAcc, mIoU))
        return self.metric.get()

    def visualize(self, epoch):
        """Render predictions for a fixed set of images into one grid image."""
        if (self.args.dir_of_im_to_vis == 'None') and (self.args.im_list_file_to_vis == 'None'):
            return
        if not hasattr(self, 'vis_im_paths'):
            if self.args.dir_of_im_to_vis != 'None':
                print('=> Visualize Dir {}'.format(self.args.dir_of_im_to_vis))
                im_paths = list(
                    walkdir(self.args.dir_of_im_to_vis, exts=['.jpg', '.png']))
            else:
                print('=> Visualize Image List {}'.format(
                    self.args.im_list_file_to_vis))
                im_paths = read_lines(self.args.im_list_file_to_vis)
            print('=> Save Dir {}'.format(self.args.vis_save_dir))
            im_paths = sorted(im_paths)
            # np.random.RandomState(seed=1).shuffle(im_paths)
            self.vis_im_paths = im_paths[:self.args.max_num_vis]
        cfg = {
            'save_path': os.path.join(self.args.vis_save_dir,
                                      'vis_epoch{}.png'.format(epoch)),
            'multi_scale': self.args.multi_scale_eval,
            'crop': self.args.crop_eval,
            'num_class': self.nclass,
            'scales': self.args.eval_scales,
            'base_size': self.args.base_size,
        }
        vis_im_list(self.single_device_model, self.vis_im_paths, cfg)

    def infer_and_save(self, infer_dir, infer_save_dir):
        """Run inference over every image under `infer_dir`, mirroring the
        directory layout under `infer_save_dir`."""
        print('=> Infer Dir {}'.format(infer_dir))
        print('=> Save Dir {}'.format(infer_save_dir))
        sub_im_paths = list(
            walkdir(infer_dir, exts=['.jpg', '.png'], sub_path=True))
        im_paths = [os.path.join(infer_dir, p) for p in sub_im_paths]
        # NOTE: Don't save result as JPEG, since it causes aliasing.
        save_paths = [
            os.path.join(infer_save_dir, p.replace('.jpg', '.png'))
            for p in sub_im_paths
        ]
        cfg = {
            'multi_scale': self.args.multi_scale_eval,
            'crop': self.args.crop_eval,
            'num_class': self.nclass,
            'scales': self.args.eval_scales,
            'base_size': self.args.base_size,
        }
        infer_and_save_im_list(self.single_device_model, im_paths, save_paths, cfg)
class Trainer():
    """Segmentation trainer with optional per-dataset class-balance weights.

    Supplies hand-computed class-frequency weights for pcontext60 / ade20k /
    cocostuff to the segmentation loss; otherwise trains unweighted.
    """

    def __init__(self, args):
        self.args = args
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])
        ])
        # dataset
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        trainset = get_segmentation_dataset(args.dataset,
                                            split=args.train_split,
                                            mode='train',
                                            **data_kwargs)
        testset = get_segmentation_dataset(args.dataset,
                                           split='val',
                                           mode='val',
                                           **data_kwargs)
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True} \
            if args.cuda else {}
        self.trainloader = data.DataLoader(trainset,
                                           batch_size=args.batch_size,
                                           drop_last=False,
                                           shuffle=True,
                                           **kwargs)
        self.valloader = data.DataLoader(testset,
                                         batch_size=args.batch_size,
                                         drop_last=False,
                                         shuffle=False,
                                         **kwargs)
        self.nclass = trainset.num_class
        # model
        model = get_segmentation_model(args.model,
                                       dataset=args.dataset,
                                       backbone=args.backbone,
                                       dilated=args.dilated,
                                       multi_grid=args.multi_grid,
                                       stride=args.stride,
                                       lateral=args.lateral,
                                       jpu=args.jpu,
                                       aux=args.aux,
                                       se_loss=args.se_loss,
                                       norm_layer=SyncBatchNorm,
                                       base_size=args.base_size,
                                       crop_size=args.crop_size)
        # print(model)
        # optimizer using different LR: backbone at base lr, attached modules at 10x
        params_list = [
            {
                'params': model.pretrained.parameters(),
                'lr': args.lr
            },
        ]
        # `and model.jpu` guards against jpu=False (attribute exists but is falsy)
        if hasattr(model, 'jpu') and model.jpu:
            params_list.append({
                'params': model.jpu.parameters(),
                'lr': args.lr * 10
            })
        if hasattr(model, 'head'):
            params_list.append({
                'params': model.head.parameters(),
                'lr': args.lr * 10
            })
        if hasattr(model, 'auxlayer'):
            params_list.append({
                'params': model.auxlayer.parameters(),
                'lr': args.lr * 10
            })
        optimizer = torch.optim.SGD(params_list,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        # NOTE(review): the string 'None' is used as a sentinel for "no
        # weights" and passed straight to SegmentationLosses(weight=...) —
        # presumably handled there; confirm it is not mistaken for a tensor.
        class_balance_weight = 'None'
        if args.dataset == "pcontext60":
            # inverse-frequency weights for the 60 Pascal-Context classes
            class_balance_weight = torch.tensor([
                1.3225e-01, 2.0757e+00, 1.8146e+01, 5.5052e+00, 2.2060e+00,
                2.8054e+01, 2.0566e+00, 1.8598e+00, 2.4027e+00, 9.3435e+00,
                3.5990e+00, 2.7487e-01, 1.4216e+00, 2.4986e+00, 7.7258e-01,
                4.9020e-01, 2.9067e+00, 1.2197e+00, 2.2744e+00, 2.0444e+01,
                3.0057e+00, 1.8167e+01, 3.7405e+00, 5.6749e-01, 3.2631e+00,
                1.5007e+00, 5.5519e-01, 1.0056e+01, 1.8952e+01, 2.6792e-01,
                2.7479e-01, 1.8309e+00, 2.0428e+01, 1.4788e+01, 1.4908e+00,
                1.9113e+00, 2.6166e+02, 2.3233e-01, 1.9096e+01, 6.7025e+00,
                2.8756e+00, 6.8804e-01, 4.4140e+00, 2.5621e+00, 4.4409e+00,
                4.3821e+00, 1.3774e+01, 1.9803e-01, 3.6944e+00, 1.0397e+00,
                2.0601e+00, 5.5811e+00, 1.3242e+00, 3.0088e-01, 1.7344e+01,
                2.1569e+00, 2.7216e-01, 5.8731e-01, 1.9956e+00, 4.4004e+00
            ])
        elif args.dataset == "ade20k":
            # weights for the 150 ADE20K classes
            class_balance_weight = torch.tensor([
                0.0772, 0.0431, 0.0631, 0.0766, 0.1095, 0.1399, 0.1502, 0.1702,
                0.2958, 0.3400, 0.3738, 0.3749, 0.4059, 0.4266, 0.4524, 0.5725,
                0.6145, 0.6240, 0.6709, 0.6517, 0.6591, 0.6818, 0.9203, 0.9965,
                1.0272, 1.0967, 1.1202, 1.2354, 1.2900, 1.5038, 1.5160, 1.5172,
                1.5036, 2.0746, 2.1426, 2.3159, 2.2792, 2.6468, 2.8038, 2.8777,
                2.9525, 2.9051, 3.1050, 3.1785, 3.3533, 3.5300, 3.6120, 3.7006,
                3.6790, 3.8057, 3.7604, 3.8043, 3.6610, 3.8268, 4.0644, 4.2698,
                4.0163, 4.0272, 4.1626, 4.3702, 4.3144, 4.3612, 4.4389, 4.5612,
                5.1537, 4.7653, 4.8421, 4.6813, 5.1037, 5.0729, 5.2657, 5.6153,
                5.8240, 5.5360, 5.6373, 6.6972, 6.4561, 6.9555, 7.9239, 7.3265,
                7.7501, 7.7900, 8.0528, 8.5415, 8.1316, 8.6557, 9.0550, 9.0081,
                9.3262, 9.1391, 9.7237, 9.3775, 9.4592, 9.7883, 10.6705, 10.2113,
                10.5845, 10.9667, 10.8754, 10.8274, 11.6427, 11.0687, 10.8417,
                11.0287, 12.2030, 12.8830, 12.5082, 13.0703, 13.8410, 12.3264,
                12.9048, 12.9664, 12.3523, 13.9830, 13.8105, 14.0345, 15.0054,
                13.9801, 14.1048, 13.9025, 13.6179, 17.0577, 15.8351, 17.7102,
                17.3153, 19.4640, 17.7629, 19.9093, 16.9529, 19.3016, 17.6671,
                19.4525, 20.0794, 18.3574, 19.1219, 19.5089, 19.2417, 20.2534,
                20.0332, 21.7496, 21.5427, 20.3008, 21.1942, 22.7051, 23.3359,
                22.4300, 20.9934, 26.9073, 31.7362, 30.0784
            ])
        elif args.dataset == "cocostuff":
            # weights for the COCO-Stuff classes
            class_balance_weight = torch.tensor([
                4.8557e-02, 6.4709e-02, 3.9255e+00, 9.4797e-01, 1.2703e+00,
                1.4151e+00, 7.9733e-01, 8.4903e-01, 1.0751e+00, 2.4001e+00,
                8.9736e+00, 5.3036e+00, 6.0410e+00, 9.3285e+00, 1.5952e+00,
                3.6090e+00, 9.8772e-01, 1.2319e+00, 1.9194e+00, 2.7624e+00,
                2.0548e+00, 1.2058e+00, 3.6424e+00, 2.0789e+00, 1.7851e+00,
                6.7138e+00, 2.1315e+00, 6.9813e+00, 1.2679e+02, 2.0357e+00,
                2.2933e+01, 2.3198e+01, 1.7439e+01, 4.1294e+01, 7.8678e+00,
                4.3444e+01, 6.7543e+01, 1.0066e+01, 6.7520e+00, 1.3174e+01,
                3.3499e+00, 6.9737e+00, 2.1482e+00, 1.9428e+01, 1.3240e+01,
                1.9218e+01, 7.6836e-01, 2.6041e+00, 6.1822e+00, 1.4070e+00,
                4.4074e+00, 5.7792e+00, 1.0321e+01, 4.9922e+00, 6.7408e-01,
                3.1554e+00, 1.5832e+00, 8.9685e-01, 1.1686e+00, 2.6487e+00,
                6.5354e-01, 2.3801e-01, 1.9536e+00, 1.5862e+00, 1.7797e+00,
                2.7385e+01, 1.2419e+01, 3.9287e+00, 7.8897e+00, 7.5737e+00,
                1.9758e+00, 8.1962e+01, 3.6922e+00, 2.0039e+00, 2.7333e+00,
                5.4717e+00, 3.9048e+00, 1.9184e+01, 2.2689e+00, 2.6091e+02,
                4.7366e+01, 2.3844e+00, 8.3310e+00, 1.4857e+01, 6.5076e+00,
                2.0854e-01, 1.0425e+00, 1.7386e+00, 1.1973e+01, 5.2862e+00,
                1.7341e+00, 8.6124e-01, 9.3702e+00, 2.8545e+00, 6.0123e+00,
                1.7560e-01, 1.8128e+00, 1.3784e+00, 1.3699e+00, 2.3728e+00,
                6.2819e-01, 1.3097e+00, 4.7892e-01, 1.0268e+01, 1.2307e+00,
                5.5662e+00, 1.2867e+00, 1.2745e+00, 4.7505e+00, 8.4029e+00,
                1.8679e+00, 1.0519e+01, 1.1240e+00, 1.4975e-01, 2.3146e+00,
                4.1265e-01, 2.5896e+00, 1.4537e+00, 4.5575e+00, 7.8143e+00,
                1.4603e+01, 2.8812e+00, 1.8868e+00, 7.8131e+01, 1.9323e+00,
                7.4980e+00, 1.2446e+01, 2.1856e+00, 3.0973e+00, 4.1270e-01,
                4.9016e+01, 7.1001e-01, 7.4035e+00, 2.3395e+00, 2.9207e-01,
                2.4156e+00, 3.3211e+00, 2.1300e+00, 2.4533e-01, 1.7081e+00,
                4.6621e+00, 2.9199e+00, 1.0407e+01, 7.6207e-01, 2.7806e-01,
                3.7711e+00, 1.1852e-01, 8.8280e+00, 3.1700e-01, 6.3765e+01,
                6.6032e+00, 5.2177e+00, 4.3596e+00, 6.2965e-01, 1.0207e+00,
                1.1731e+01, 2.3935e+00, 9.2767e+00, 1.1023e-01, 3.6947e+00,
                1.3943e+00, 2.3407e+00, 1.2112e-01, 2.8518e+00, 2.8195e+00,
                1.0078e+00, 1.6614e+00, 6.5307e-01, 1.9070e+01, 2.7231e+00,
                6.0769e-01
            ])
        # criterions
        self.criterion = SegmentationLosses(se_loss=args.se_loss,
                                            aux=args.aux,
                                            nclass=self.nclass,
                                            se_weight=args.se_weight,
                                            aux_weight=args.aux_weight,
                                            weight=class_balance_weight)
        self.model, self.optimizer = model, optimizer
        # using cuda
        if args.cuda:
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
        # resuming checkpoint
        # (default set before the resume branch so a resumed best_pred survives)
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        # clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
        # lr scheduler
        self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.trainloader))

    def training(self, epoch):
        """Run one training epoch; optionally checkpoint when validation is disabled."""
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.trainloader)
        for i, (image, target) in enumerate(tbar):
            # per-iteration LR schedule
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            if torch_ver == "0.3":
                image = Variable(image)
                target = Variable(target)
            outputs = self.model(image)
            loss = self.criterion(outputs, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                },
                self.args,
                is_best,
                filename='checkpoint_{}.pth.tar'.format(epoch))

    def validation(self, epoch):
        """Evaluate pixAcc/mIoU on the val set and checkpoint (flagging best)."""
        # Fast test during the training
        def eval_batch(model, image, target):
            outputs = model(image)
            # collect per-GPU outputs onto device 0
            outputs = gather(outputs, 0, dim=0)
            pred = outputs[0]
            target = target.cuda()
            correct, labeled = utils.batch_pix_accuracy(pred.data, target)
            inter, union = utils.batch_intersection_union(
                pred.data, target, self.nclass)
            return correct, labeled, inter, union

        is_best = False
        self.model.eval()
        total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
        tbar = tqdm(self.valloader, desc='\r')
        for i, (image, target) in enumerate(tbar):
            if torch_ver == "0.3":
                image = Variable(image, volatile=True)
                correct, labeled, inter, union = eval_batch(
                    self.model, image, target)
            else:
                with torch.no_grad():
                    correct, labeled, inter, union = eval_batch(
                        self.model, image, target)
            total_correct += correct
            total_label += labeled
            total_inter += inter
            total_union += union
            # np.spacing(1) avoids division by zero on empty accumulators
            pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
            IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
            mIoU = IoU.mean()
            tbar.set_description('pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU))
        # score used for "best" selection: mean of pixel accuracy and mIoU
        new_pred = (pixAcc + mIoU) / 2
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': new_pred,
            }, self.args, is_best)
class Trainer():
    """Knowledge-distillation trainer: a student segmentation model is trained
    on the ground truth plus a KD loss against a frozen EncNet teacher.

    Fixes vs. original: (1) ``self.best_pred`` was reset to 0.0 *after* the
    resume branch restored it from the checkpoint, losing the resumed best
    score — the default is now set first; (2) removed the dead store
    ``loss_seg = 0`` that was immediately overwritten.
    """

    def __init__(self, args):
        self.args = args
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])
        ])
        # dataset
        trainset = get_segmentation_dataset(args.dataset, split='train',
                                            transform=input_transform)
        testset = get_segmentation_dataset(args.dataset, split='val',
                                           transform=input_transform)
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True} \
            if args.cuda else {}
        self.trainloader = data.DataLoader(trainset, batch_size=args.batch_size,
                                           drop_last=True, shuffle=True, **kwargs)
        self.valloader = data.DataLoader(testset, batch_size=args.batch_size,
                                         drop_last=False, shuffle=False, **kwargs)
        self.nclass = trainset.num_class
        # student model
        model = get_segmentation_model(args.model, dataset=args.dataset,
                                       backbone=args.backbone, aux=args.aux,
                                       se_loss=args.se_loss,
                                       norm_layer=BatchNorm2d)
        #print(model)
        # frozen teacher: a pretrained EncNet/ResNet50
        teacher_model = get_segmentation_model('encnet', dataset=args.dataset,
                                               backbone='resnet50', aux=True,
                                               se_loss=True,
                                               norm_layer=BatchNorm2d)
        #print(teacher_model)
        checkpoint = torch.load(args.resume_teacher)
        teacher_model.load_state_dict(checkpoint)
        self.teacher_model = teacher_model
        self.teacher_model.eval()
        # optimizer using different LR: backbone at base lr, heads at 10x
        params_list = [
            {
                'params': model.pretrained.parameters(),
                'lr': args.lr
            },
        ]
        if hasattr(model, 'head'):
            params_list.append({
                'params': model.head.parameters(),
                'lr': args.lr * 10
            })
        if hasattr(model, 'auxlayer'):
            params_list.append({
                'params': model.auxlayer.parameters(),
                'lr': args.lr * 10
            })
        optimizer = torch.optim.SGD(params_list, lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        # clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
        # criterions: supervised segmentation loss + distillation loss
        self.criterion = SegmentationLosses(se_loss=args.se_loss, aux=args.aux,
                                            nclass=self.nclass)
        self.criterion_kd = KDLosses(se_loss=args.se_loss, aux=args.aux,
                                     nclass=self.nclass)
        #self.criterion_kd = torch.nn.L1Loss()
        self.model, self.optimizer = model, optimizer
        # using cuda
        if args.cuda:
            self.model = DataParallelModel(self.model).cuda()
            self.teacher_model = DataParallelModel(self.teacher_model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
            self.criterion_kd = DataParallelCriterionKD(
                self.criterion_kd).cuda()
        # FIX: initialize the default best score BEFORE the resume branch so a
        # resumed checkpoint's best_pred is not overwritten afterwards.
        self.best_pred = 0.0
        # resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        # lr scheduler
        self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.trainloader))

    def training(self, epoch):
        """One epoch of distillation training.

        Backpropagates the supervised loss (with retain_graph) and the KD loss
        separately, then takes a single optimizer step — equivalent to one
        backward pass of their sum.
        """
        train_loss = 0.0
        teacher_loss = 0.0
        self.model.train()
        tbar = tqdm(self.trainloader)
        for i, (image, target) in enumerate(tbar):
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            if torch_ver == "0.3":
                image = Variable(image)
                target = Variable(target)
            outputs = self.model(image)
            with torch.no_grad():
                teacher_outputs = self.teacher_model(image)
            # Collect the main prediction from each per-GPU teacher output.
            teacher_targets = []
            for teacher_output in teacher_outputs:
                pred1, se_pred, pred2 = tuple(teacher_output)
                teacher_targets.append(pred1)
            teacher_target = torch.cat(tuple(teacher_targets), 0)
            teacher_target = teacher_target.detach()
            # FIX: dropped dead `loss_seg = 0` (was overwritten immediately).
            loss_seg = self.criterion(outputs, target)
            loss_seg.backward(retain_graph=True)
            train_loss += loss_seg.item()
            loss_kd = self.criterion_kd(outputs, teacher_target)
            loss_kd.backward()
            teacher_loss += loss_kd.item()
            loss = loss_seg + loss_kd
            #loss.backward()
            self.optimizer.step()
            tbar.set_description('Train loss: %.3f, Teacher loss: %.3f' %
                                 (train_loss / (i + 1), teacher_loss / (i + 1)))
        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, self.args, is_best)

    def validation(self, epoch):
        """Evaluate pixAcc/mIoU on the val set and checkpoint (flagging best)."""
        # Fast test during the training
        def eval_batch(model, image, target):
            outputs = model(image)
            outputs = gather(outputs, 0, dim=0)
            pred = outputs[0]
            target = target.cuda()
            correct, labeled = utils.batch_pix_accuracy(pred.data, target)
            inter, union = utils.batch_intersection_union(
                pred.data, target, self.nclass)
            return correct, labeled, inter, union

        is_best = False
        self.model.eval()
        total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
        tbar = tqdm(self.valloader, desc='\r')
        for i, (image, target) in enumerate(tbar):
            if torch_ver == "0.3":
                image = Variable(image, volatile=True)
                correct, labeled, inter, union = eval_batch(
                    self.model, image, target)
            else:
                with torch.no_grad():
                    correct, labeled, inter, union = eval_batch(
                        self.model, image, target)
            total_correct += correct
            total_label += labeled
            total_inter += inter
            total_union += union
            # np.spacing(1) avoids division by zero
            pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
            IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
            mIoU = IoU.mean()
            tbar.set_description('pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU))
        new_pred = (pixAcc + mIoU) / 2
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, self.args, is_best)
def test(args):
    """Run edge-detection inference from a checkpoint.

    In ``--eval`` mode the 'val' split is processed and per-class side5/fuse
    probability maps are written as images under per-class output folders;
    otherwise the 'vis' split is processed and ground truth / side5 / fuse
    predictions are rendered with ``visualize_prediction``.

    Raises:
        RuntimeError: if ``args.resume`` is missing or not a file.
    """
    # data transforms (ImageNet mean/std)
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225])
    ])
    # dataset
    if args.eval:
        # set split='val' for validation set testing
        testset = get_edge_dataset(args.dataset,
                                   split='val',
                                   mode='testval',
                                   transform=input_transform,
                                   crop_size=args.crop_size)
    else:
        # set split='vis' for visualization
        testset = get_edge_dataset(args.dataset,
                                   split='vis',
                                   mode='vis',
                                   transform=input_transform,
                                   crop_size=args.crop_size)
    # output folder(s)
    if args.eval:
        outdir_list_side5 = []
        outdir_list_fuse = []
        for class_idx in range(testset.num_class):
            outdir_side5 = '%s/%s/%s_val/side5/class_%03d' % (
                args.dataset, args.model, args.checkname, class_idx + 1)
            if not os.path.exists(outdir_side5):
                os.makedirs(outdir_side5)
            outdir_list_side5.append(outdir_side5)
            outdir_fuse = '%s/%s/%s_val/fuse/class_%03d' % (
                args.dataset, args.model, args.checkname, class_idx + 1)
            if not os.path.exists(outdir_fuse):
                os.makedirs(outdir_fuse)
            outdir_list_fuse.append(outdir_fuse)
    else:
        outdir = '%s/%s/%s_vis' % (args.dataset, args.model, args.checkname)
        if not os.path.exists(outdir):
            os.makedirs(outdir)
    # dataloader
    loader_kwargs = {'num_workers': args.workers, 'pin_memory': True} \
        if args.cuda else {}
    test_data = data.DataLoader(testset,
                                batch_size=args.test_batch_size,
                                drop_last=False,
                                shuffle=False,
                                collate_fn=test_batchify_fn,
                                **loader_kwargs)
    model = get_edge_model(
        args.model,
        dataset=args.dataset,
        backbone=args.backbone,
        norm_layer=BatchNorm2d,
        crop_size=args.crop_size,
    )
    # resuming checkpoint
    if args.resume is None or not os.path.isfile(args.resume):
        raise RuntimeError("=> no checkpoint found at '{}'".format(
            args.resume))
    checkpoint = torch.load(args.resume)
    # strict=False, so that it is compatible with old pytorch saved models
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    if args.cuda:
        model = DataParallelModel(model).cuda()
    print(model)
    model.eval()

    def run_batch(images, im_sizes):
        """Forward one batch and split the (per-GPU) outputs into cropped,
        sigmoid-activated side5/fuse numpy maps.

        Previously this logic was duplicated verbatim in both the eval and
        vis branches below.
        """
        batch = torch.cat([image.unsqueeze(0) for image in images], 0)
        outputs = model(batch.float())
        # NOTE(review): assumes CUDA_VISIBLE_DEVICES is set in the
        # environment — confirm against how the script is launched.
        num_gpus = len(os.environ['CUDA_VISIBLE_DEVICES'].split(','))
        if num_gpus == 1:
            # single-GPU DataParallelModel returns a plain output, not a list
            outputs = [outputs]
        side5_list = []
        fuse_list = []
        for gpu_idx in range(len(outputs)):  # iterate per-GPU chunks
            im_size = tuple(im_sizes[gpu_idx].numpy())
            output = outputs[gpu_idx]
            side5 = output[0].squeeze_()
            side5 = side5.sigmoid_().cpu().numpy()
            # crop padding back to the original image size
            side5 = side5[:, 0:im_size[1], 0:im_size[0]]
            fuse = output[1].squeeze_()
            fuse = fuse.sigmoid_().cpu().numpy()
            fuse = fuse[:, 0:im_size[1], 0:im_size[0]]
            side5_list.append(side5)
            fuse_list.append(fuse)
        return side5_list, fuse_list

    tbar = tqdm(test_data)
    if args.eval:
        for images, im_paths, im_sizes in tbar:
            with torch.no_grad():
                side5_list, fuse_list = run_batch(images, im_sizes)
                # save one image per class for each head
                for predict, impath in zip(side5_list, im_paths):
                    for cls_idx in range(predict.shape[0]):
                        predict_c = predict[cls_idx]
                        path = os.path.join(outdir_list_side5[cls_idx], impath)
                        io.imsave(path, predict_c)
                for predict, impath in zip(fuse_list, im_paths):
                    for cls_idx in range(predict.shape[0]):
                        predict_c = predict[cls_idx]
                        path = os.path.join(outdir_list_fuse[cls_idx], impath)
                        io.imsave(path, predict_c)
    else:
        for images, masks, im_paths, im_sizes in tbar:
            with torch.no_grad():
                side5_list, fuse_list = run_batch(images, im_sizes)
                # visualize ground truth
                for gt, impath in zip(masks, im_paths):
                    outname = os.path.splitext(impath)[0] + '_gt.png'
                    path = os.path.join(outdir, outname)
                    visualize_prediction(args.dataset, path, gt)
                # visualize side5 output
                for predict, impath in zip(side5_list, im_paths):
                    outname = os.path.splitext(impath)[0] + '_side5.png'
                    path = os.path.join(outdir, outname)
                    visualize_prediction(args.dataset, path, predict)
                # visualize fuse output
                for predict, impath in zip(fuse_list, im_paths):
                    outname = os.path.splitext(impath)[0] + '_fuse.png'
                    path = os.path.join(outdir, outname)
                    visualize_prediction(args.dataset, path, predict)
class Trainer():
    """Segmentation training/validation driver (SyncBatchNorm variant).

    Builds datasets, model, differential-LR SGD optimizer, criterion and LR
    scheduler from ``args``; ``training``/``validation`` run one epoch each
    and checkpoint via ``utils.save_checkpoint``.
    """

    def __init__(self, args):
        self.args = args
        # NOTE(review): the flag reads inverted — the TensorBoard writer is
        # created (and later written to) only when ``args.tblogger`` is
        # falsy. Kept as-is to preserve behavior; confirm intended meaning.
        if not self.args.tblogger:
            self.tblogger = SummaryWriter('./tensorboardX/')
        # data transforms (ImageNet mean/std)
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])
        ])
        # dataset
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        trainset = get_segmentation_dataset(args.dataset,
                                            split=args.train_split,
                                            mode='train',
                                            **data_kwargs)
        testset = get_segmentation_dataset(args.dataset,
                                           split='val',
                                           mode='val',
                                           **data_kwargs)
        print('trainset:%d' % len(trainset))
        print('testset:%d' % len(testset))
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True} \
            if args.cuda else {}
        self.trainloader = data.DataLoader(trainset,
                                           batch_size=args.batch_size,
                                           drop_last=True,
                                           shuffle=True,
                                           **kwargs)
        self.valloader = data.DataLoader(testset,
                                         batch_size=args.batch_size,
                                         drop_last=False,
                                         shuffle=False,
                                         **kwargs)
        self.nclass = trainset.num_class
        # model
        model = get_segmentation_model(args.model,
                                       dataset=args.dataset,
                                       backbone=args.backbone,
                                       dilated=args.dilated,
                                       lateral=args.lateral,
                                       jpu=args.jpu,
                                       aux=args.aux,
                                       se_loss=args.se_loss,
                                       norm_layer=SyncBatchNorm,
                                       base_size=args.base_size,
                                       crop_size=args.crop_size)
        print(model)
        # optimizer using different LR: the pretrained backbone keeps the
        # base LR, newly-initialized modules train 10x faster
        params_list = [
            {
                'params': model.pretrained.parameters(),
                'lr': args.lr
            },
        ]
        if hasattr(model, 'jpu'):
            params_list.append({
                'params': model.jpu.parameters(),
                'lr': args.lr * 10
            })
        if hasattr(model, 'head'):
            params_list.append({
                'params': model.head.parameters(),
                'lr': args.lr * 10
            })
        if hasattr(model, 'auxlayer'):
            params_list.append({
                'params': model.auxlayer.parameters(),
                'lr': args.lr * 10
            })
        optimizer = torch.optim.SGD(params_list,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        # criterions
        self.criterion = SegmentationLosses(se_loss=args.se_loss,
                                            aux=args.aux,
                                            nclass=self.nclass,
                                            se_weight=args.se_weight,
                                            aux_weight=args.aux_weight)
        self.model, self.optimizer = model, optimizer
        # using cuda
        if args.cuda:
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
        # resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                # unwrap the DataParallel container when loading
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                # fine-tuning restarts the optimizer from scratch
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        # clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
        # lr scheduler
        self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs,
                                            len(self.trainloader))
        # kept for external readers; not updated inside this class
        self.total_loss = 0

    def training(self, epoch):
        """Run one training epoch; log the running loss and, when
        ``args.no_val`` is set, save a per-epoch checkpoint."""
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.trainloader)
        for i, (image, target) in enumerate(tbar):
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            if torch_ver == "0.3":
                image = Variable(image)
                target = Variable(target)
            outputs = self.model(image)
            loss = self.criterion(outputs, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            if not self.args.tblogger and i % 100 == 0:
                # BUG FIX: use a cumulative global step — the previous
                # ``i + 1`` restarted at every epoch, so TensorBoard points
                # from successive epochs overwrote each other.
                global_step = epoch * len(self.trainloader) + i + 1
                self.tblogger.add_scalar('Train loss', (train_loss / (i + 1)),
                                         global_step)
        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                },
                self.args,
                is_best,
                filename='checkpoint_{}.pth.tar'.format(epoch))

    def validation(self, epoch):
        """Evaluate on the val split and checkpoint, flagging a new best
        mIoU (the selection metric is mIoU alone)."""

        # Fast test during the training
        def eval_batch(model, image, target):
            outputs = model(image)
            outputs = gather(outputs, 0, dim=0)
            pred = outputs[0]
            target = target.cuda()
            correct, labeled = utils.batch_pix_accuracy(pred.data, target)
            inter, union = utils.batch_intersection_union(
                pred.data, target, self.nclass)
            return correct, labeled, inter, union

        is_best = False
        self.model.eval()
        total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
        tbar = tqdm(self.valloader, desc='\r')
        for i, (image, target) in enumerate(tbar):
            if torch_ver == "0.3":
                image = Variable(image, volatile=True)
                correct, labeled, inter, union = eval_batch(
                    self.model, image, target)
            else:
                with torch.no_grad():
                    correct, labeled, inter, union = eval_batch(
                        self.model, image, target)
            total_correct += correct
            total_label += labeled
            total_inter += inter
            total_union += union
            pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
            IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
            mIoU = IoU.mean()
            tbar.set_description('pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU))
        # new_pred = (pixAcc + mIoU)/2
        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
        # BUG FIX: persist self.best_pred (not the current epoch's score) so
        # a resumed run cannot inherit a worse 'best_pred'.
        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, self.args, is_best)
class Trainer():
    """Pneumothorax-segmentation training driver with stratified K-fold
    splits, a class-balanced sampler, and DICE-based model selection.

    ``training``/``validation`` run one epoch each; best-DICE, best-loss and
    last checkpoints are written under ``args.ckpt_name``.
    """

    def __init__(self, args):
        self.args = args
        self.args.start_epoch = 0
        self.args.cuda = True
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.490, .490, .490], [.247, .247, .247])
        ])  # TODO: change mean and std
        # augmentation chains (albumentations)
        train_chain = Compose([
            HorizontalFlip(p=0.5),
            OneOf([
                ElasticTransform(
                    alpha=300, sigma=300 * 0.05, alpha_affine=300 * 0.03),
                GridDistortion(),
                OpticalDistortion(distort_limit=2, shift_limit=0.5),
            ], p=0.3),
            RandomSizedCrop(
                min_max_height=(900, 1024), height=1024, width=1024, p=0.5),
            ShiftScaleRotate(rotate_limit=20, p=0.5),
            Resize(self.args.size, self.args.size)
        ], p=1)
        val_chain = Compose([Resize(self.args.size, self.args.size)], p=1)
        # build fold ``num_fold`` over train+val, then append the true test
        # set to the new training split
        num_fold = self.args.num_fold
        df_train = pd.read_csv(os.path.join(args.imagelist_path, 'train.csv'))
        df_val = pd.read_csv(os.path.join(args.imagelist_path, 'val.csv'))
        df_full = pd.concat((df_train, df_val), ignore_index=True, axis=0)
        # binary label: 1 when the image has no mask ('-1')
        df_full['lbl'] = (df_full['mask_name'].astype(str) == '-1').astype(int)
        skf = StratifiedKFold(8, shuffle=True, random_state=777)
        train_ids, val_ids = list(
            skf.split(df_full['mask_name'], df_full['lbl']))[num_fold]
        df_test = pd.read_csv(
            os.path.join(args.imagelist_path, 'test_true.csv'))
        df_new_train = pd.concat((df_full.iloc[train_ids], df_test),
                                 ignore_index=True, axis=0, sort=False)
        df_new_val = df_full.iloc[val_ids]
        df_new_train.to_csv(f'/tmp/train_new_pneumo_{num_fold}.csv')
        df_new_val.to_csv(f'/tmp/val_new_pneumo_{num_fold}.csv')
        trainset = SegmentationDataset(f'/tmp/train_new_pneumo_{num_fold}.csv',
                                       args.image_path,
                                       args.masks_path,
                                       input_transform=input_transform,
                                       transform_chain=train_chain,
                                       base_size=1024)
        testset = SegmentationDataset(f'/tmp/val_new_pneumo_{num_fold}.csv',
                                      args.image_path,
                                      args.masks_path,
                                      input_transform=input_transform,
                                      transform_chain=val_chain,
                                      base_size=1024)
        # class-balanced sampling (replaces shuffle=True)
        imgs = trainset.mask_img_map[:, [0, 3]]
        weights = make_weights_for_balanced_classes(imgs, 2)
        weights = torch.DoubleTensor(weights)
        train_sampler = (torch.utils.data.sampler.WeightedRandomSampler(
            weights, len(weights)))
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.trainloader = data.DataLoader(
            trainset,
            batch_size=args.batch_size,
            drop_last=True,
            sampler=train_sampler,
            **kwargs)
        self.valloader = data.DataLoader(testset,
                                         batch_size=args.batch_size,
                                         drop_last=False,
                                         shuffle=False,
                                         **kwargs)
        self.nclass = 1
        if self.args.model == 'unet':
            model = UNet(n_classes=self.nclass, norm_layer=SyncBatchNorm)
            params_list = [
                {
                    'params': model.parameters(),
                    'lr': args.lr
                },
            ]
        elif self.args.model == 'encnet':
            model = EncNet(
                nclass=self.nclass,
                backbone=args.backbone,
                aux=args.aux,
                se_loss=args.se_loss,
                norm_layer=SyncBatchNorm
            )
            # optimizer using different LR: backbone at base LR, heads 10x
            params_list = [
                {
                    'params': model.pretrained.parameters(),
                    'lr': args.lr
                },
            ]
            if hasattr(model, 'head'):
                params_list.append({
                    'params': model.head.parameters(),
                    'lr': args.lr * 10
                })
            if hasattr(model, 'auxlayer'):
                params_list.append({
                    'params': model.auxlayer.parameters(),
                    'lr': args.lr * 10
                })
        else:
            # BUG FIX: previously any other value crashed later with a
            # NameError on ``model``; fail fast with a clear message.
            raise ValueError(
                "unsupported model '{}': expected 'unet' or 'encnet'".format(
                    self.args.model))
        print(model)
        optimizer = torch.optim.SGD(params_list,
                                    lr=args.lr,
                                    momentum=0.9,
                                    weight_decay=args.wd)
        # criterions: binary (BCE/dice) head vs multi-class
        if self.nclass == 1:
            self.criterion = SegmentationLossesBCE(se_loss=args.se_loss,
                                                   aux=args.aux,
                                                   nclass=self.nclass,
                                                   se_weight=args.se_weight,
                                                   aux_weight=args.aux_weight,
                                                   use_dice=args.use_dice)
        else:
            self.criterion = SegmentationLosses(
                se_loss=args.se_loss,
                aux=args.aux,
                nclass=self.nclass,
                se_weight=args.se_weight,
                aux_weight=args.aux_weight,
            )
        self.model, self.optimizer = model, optimizer
        self.best_pred = 0.0
        self.model = DataParallelModel(self.model).cuda()
        self.criterion = DataParallelCriterion(self.criterion).cuda()
        # resuming checkpoint (note: loads into the DataParallel wrapper,
        # matching how checkpoints are saved below)
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            self.args.start_epoch = checkpoint['epoch']
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            # reset LR to the requested value after restoring optimizer state
            for g in self.optimizer.param_groups:
                g['lr'] = args.lr
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            print(f'Best dice: {checkpoint["best_pred"]}')
            print(f'LR: {get_lr(self.optimizer):.5f}')
        self.scheduler = ReduceLROnPlateau(self.optimizer,
                                           mode='min',
                                           factor=0.8,
                                           patience=4,
                                           threshold=0.001,
                                           threshold_mode='abs',
                                           min_lr=0.00001)
        self.logger = Logger(args.logger_dir)
        self.step_train = 0
        self.best_loss = 20
        self.step_val = 0

    def logging(self,
                loss,
                running_acc,
                total,
                is_train,
                step,
                is_per_epoch,
                inputs=None,
                pred_masks=None,
                true_masks=None):
        """TensorBoard logging: scalars, parameter/grad histograms and (when
        ``inputs`` is given) up to 10 images blended with pred/true masks."""
        # Log the scalar values
        accuracy = 100.0 * running_acc / total
        loss_str = 'Loss per epoch' if is_per_epoch else 'Loss per step'
        accuracy_str = 'Accuracy per epoch' if is_per_epoch else 'Accuracy per step'
        if is_per_epoch:
            loss = loss / len(self.trainloader) if is_train else loss / len(
                self.valloader)
        info = {loss_str: loss, accuracy_str: accuracy}
        for tag, value in info.items():
            self.logger.scalar_summary(tag, value, step, is_train)
        # Log values and gradients of the parameters (histogram)
        for tag, value in filter(lambda p: p[1].requires_grad,
                                 self.model.named_parameters()):
            tag = tag.replace('.', '/')
            self.logger.histo_summary(tag, to_np(value), step, 1000, is_train)
            if value.grad is not None:
                self.logger.histo_summary(tag + '/grad', to_np(value.grad),
                                          step, 1000, is_train)
        if inputs is not None:
            # Log the images: de-normalize, blend pred mask into the red
            # channel and true mask into the green channel
            inputs = to_np(inputs)[:10].transpose(0, 2, 3, 1)
            for i in range(inputs.shape[0]):
                inputs[i] *= np.array([.247, .247, .247])
                inputs[i] += np.array([.490, .490, .490])
            inputs = (255 * inputs).astype(np.uint8)
            pred_masks = (255 * pred_masks)[:10].astype(np.uint8)
            true_masks = (255 * true_masks)[:10].astype(np.uint8)
            inputs[..., 0] = (0.5 * inputs[..., 0] + 0.5 * pred_masks)
            inputs[..., 1] = (0.5 * inputs[..., 1] + 0.5 * true_masks).astype(
                np.uint8)
            info = {
                'images': inputs,
            }
            # BUG FIX: iterate with a fresh name — the original reused
            # ``inputs`` here, clobbering the parameter mid-method.
            for tag, images in info.items():
                self.logger.image_summary(tag, images, step, is_train)

    def training(self, epoch):
        """Run one training epoch, printing running loss/DICE per batch and
        logging images every 10 batches."""
        self.model.train()
        total_score = 0
        total_score_simple = 0
        total_count = 0
        total_loss = 0
        for i, (_, image, target) in enumerate(self.trainloader):
            start_t = dtm.now()
            torch.cuda.empty_cache()
            if torch_ver == "0.3":
                image = Variable(image)
                target = Variable(target)
            outputs = self.model(image)
            loss = self.criterion(outputs, target)
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            total_loss += loss.item()
            # gather per-GPU outputs on CPU
            preds_ten = [v[0].data.cpu() for v in outputs]
            preds_ten = torch.cat(preds_ten)
            if self.args.model == 'encnet':
                # classification head gates the mask: cls_mask==True means
                # "predicted empty"
                cls_preds_ten = [v[1].data.cpu() for v in outputs]
                cls_preds_ten = torch.cat(cls_preds_ten)
                cls_mask = torch.sigmoid(cls_preds_ten).numpy().reshape(
                    -1) < 0.5
                preds = torch.sigmoid(preds_ten).numpy()[:, 0, :, :]
            elif self.args.model == 'unet':
                cls_mask = np.zeros(preds_ten.size(0))
                preds = preds_ten.numpy()[:, 0, :, :]
            trues = target.numpy()
            local_score = dice_loss(trues, preds, cls_mask=cls_mask)
            local_score_simple = dice_loss(trues, preds)
            batch_size = preds.shape[0]
            total_score += local_score
            total_score_simple += local_score_simple
            total_count += batch_size
            print((
                f'Epoch: {epoch}, Batch: {i + 1} / {len(self.trainloader)}, '
                f'loss: {total_loss / (i + 1):.3f}, batch loss: {loss.item():.3f}'
                f', batch simple DICE: {local_score_simple / batch_size:.3f}'
                f', total simple DICE: {total_score_simple / total_count:.3f}'
                f', batch DICE: {local_score / batch_size:.3f}'
                f', total DICE: {total_score / total_count:.5f}'
                f', lr: {get_lr(self.optimizer):.5f}'
                f', time: {(dtm.now() - start_t).total_seconds():.2f}'))
            self.step_train += 1
            if i % 10 == 0:
                sys.stdout.flush()
                pred_masks = np.array(
                    [preds[j] * cls_mask[j] for j in range(len(cls_mask))])
                pred_masks = (pred_masks > 0.5).astype(int)
                self.logging(loss.item(),
                             total_score,
                             total_count,
                             is_train=True,
                             step=self.step_train,
                             is_per_epoch=False,
                             inputs=image,
                             pred_masks=pred_masks,
                             true_masks=trues)

    def validation(self, epoch):
        """Evaluate one epoch; save best-DICE, best-loss and last
        checkpoints. Returns the mean validation loss (fed to the
        ReduceLROnPlateau scheduler by the caller)."""

        # Fast test during the training
        def eval_batch(model, image, target):
            outputs = model(image)
            loss = self.criterion(outputs, target)
            preds_ten = [v[0].data.cpu() for v in outputs]
            preds_ten = torch.cat(preds_ten)
            if self.args.model == 'encnet':
                cls_preds_ten = [v[1].data.cpu() for v in outputs]
                cls_preds_ten = torch.cat(cls_preds_ten)
                cls_mask = torch.sigmoid(cls_preds_ten).numpy().reshape(
                    -1) < 0.5
                preds = torch.sigmoid(preds_ten).numpy()[:, 0, :, :]
            elif self.args.model == 'unet':
                cls_mask = np.zeros(preds_ten.size(0))
                preds = preds_ten.numpy()[:, 0, :, :]
            trues = target.numpy()
            local_score = dice_loss(trues, preds, cls_mask=cls_mask)
            batch_size = preds.shape[0]
            return preds, trues, cls_mask, local_score, batch_size, loss

        self.model.eval()
        total_score = 0
        total_loss = 0
        total_count = 0
        for i, (_, image, target) in enumerate(self.valloader):
            if torch_ver == "0.3":
                image = Variable(image, volatile=True)
                # BUG FIX: the original unpacked 4 values here although
                # eval_batch returns 6, which raised ValueError at runtime.
                preds, trues, cls_mask, local_score, batch_size, loss = (
                    eval_batch(self.model, image, target))
            else:
                with torch.no_grad():
                    preds, trues, cls_mask, local_score, batch_size, loss = (
                        eval_batch(self.model, image, target))
            total_score += local_score
            total_loss += loss.item()
            total_count += batch_size
            dice = total_score / total_count
            print(
                f'val epoch: {epoch}, batch: {i + 1} / {len(self.valloader)}, DICE: {dice:.5f}'
            )
            self.step_val += 1
            if i % 10 == 0:
                sys.stdout.flush()
                pred_masks = np.array(
                    [preds[j] * cls_mask[j] for j in range(len(cls_mask))])
                pred_masks = (pred_masks > 0.5).astype(int)
                self.logging(loss.item(),
                             total_score,
                             total_count,
                             is_train=False,
                             step=self.step_val,
                             is_per_epoch=False,
                             inputs=image,
                             pred_masks=pred_masks,
                             true_masks=trues)
        new_pred = dice
        if new_pred > self.best_pred:
            self.best_pred = new_pred
            torch.save(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                    'best_loss': self.best_loss
                }, self.args.ckpt_name[:-4] + '_best.pth')
        new_loss = total_loss / total_count
        if new_loss < self.best_loss:
            self.best_loss = new_loss
            torch.save(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                    'best_loss': self.best_loss
                }, self.args.ckpt_name[:-4] + '_best_loss.pth')
        print(f'Validation DICE: {dice:.5f}, loss: {new_loss:.5f}')
        print(
            f'Validation best DICE: {self.best_pred:.5f}, best loss: {self.best_loss:.5f}'
        )
        torch.save(
            {
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': dice,
            }, self.args.ckpt_name[:-4] + '_last.pth')
        return new_loss

    def __del__(self):
        # drop the (GPU-resident) model explicitly when the trainer dies
        del self.model
        gc.collect()