def __init__(self, args):
    if args.se_loss:
        args.checkname = args.checkname + "_se"
    self.args = args
    # data transforms
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225])])
    # dataset
    data_kwargs = {'transform': input_transform, 'base_size': args.base_size,
                   'crop_size': args.crop_size}
    trainset = get_segmentation_dataset(args.dataset, split='train', mode='train',
                                        **data_kwargs)
    testset = get_segmentation_dataset(args.dataset, split='val', mode='val',
                                       **data_kwargs)
    # dataloader
    kwargs = {'num_workers': args.workers, 'pin_memory': False} \
        if args.cuda else {}
    self.trainloader = data.DataLoader(trainset, batch_size=args.batch_size,
                                       drop_last=True, shuffle=True, **kwargs)
    self.valloader = data.DataLoader(testset, batch_size=args.batch_size,
                                     drop_last=False, shuffle=False, **kwargs)
    self.nclass = trainset.num_class
    # model
    model = get_segmentation_model(args.model, dataset=args.dataset,
                                   backbone=args.backbone, aux=args.aux,
                                   se_loss=args.se_loss, norm_layer=BatchNorm2d,
                                   base_size=args.base_size, crop_size=args.crop_size)
    print(model)
    # count parameter number
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print("Total number of parameters: %d" % pytorch_total_params)
    # optimizer using different LR
    params_list = [{'params': model.pretrained.parameters(), 'lr': args.lr}]
    if hasattr(model, 'head'):
        if args.diflr:
            params_list.append({'params': model.head.parameters(), 'lr': args.lr * 10})
        else:
            params_list.append({'params': model.head.parameters(), 'lr': args.lr})
    if hasattr(model, 'auxlayer'):
        if args.diflr:
            params_list.append({'params': model.auxlayer.parameters(), 'lr': args.lr * 10})
        else:
            params_list.append({'params': model.auxlayer.parameters(), 'lr': args.lr})
    optimizer = torch.optim.SGD(params_list, lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # optimizer = torch.optim.ASGD(params_list,
    #                              lr=args.lr,
    #                              weight_decay=args.weight_decay)
    # criterions
    self.criterion = SegmentationLosses(se_loss=args.se_loss, aux=args.aux,
                                        nclass=self.nclass)
    self.model, self.optimizer = model, optimizer
    # using cuda
    if args.cuda:
        self.model = DataParallelModel(self.model).cuda()
        self.criterion = DataParallelCriterion(self.criterion).cuda()
    # resuming checkpoint
    if args.resume is not None and len(args.resume) > 0:
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        if args.cuda:
            # load weights for the same model:
            # self.model.module.load_state_dict(checkpoint['state_dict'])
            # here the model and the checkpoint may have different structures,
            # so copy only the parameters whose names match
            pretrained_dict = checkpoint['state_dict']
            model_dict = self.model.module.state_dict()
            for name, param in pretrained_dict.items():
                if name not in model_dict:
                    continue
                if isinstance(param, Parameter):
                    # backwards compatibility for serialized parameters
                    param = param.data
                model_dict[name].copy_(param)
        else:
            self.model.load_state_dict(checkpoint['state_dict'])
        if not args.ft:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.best_pred = checkpoint['best_pred']
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.resume, checkpoint['epoch']))
    # clear start epoch if fine-tuning
    if args.ft:
        args.start_epoch = 0
    # lr scheduler
    self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                        len(self.trainloader), lr_step=args.lr_step)
    self.best_pred = 0.0
def __init__(self, args):
    self.args = args
    # data transforms
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225])
    ])
    # dataset
    data_kwargs = {
        'transform': input_transform,
        'base_size': args.base_size,
        'crop_size': args.crop_size
    }
    trainset = get_segmentation_dataset(args.dataset, split=args.train_split,
                                        mode='train', **data_kwargs)
    testset = get_segmentation_dataset(args.dataset, split='val',
                                       mode='val', **data_kwargs)
    # dataloader
    kwargs = {'num_workers': args.workers, 'pin_memory': True} \
        if args.cuda else {}
    self.trainloader = data.DataLoader(trainset, batch_size=args.batch_size,
                                       drop_last=True, shuffle=True, **kwargs)
    self.valloader = data.DataLoader(testset, batch_size=args.batch_size,
                                     drop_last=False, shuffle=False, **kwargs)
    self.nclass = trainset.num_class
    # model
    model = get_segmentation_model(
        args.model,
        dataset=args.dataset,
        backbone=args.backbone,
        dilated=args.dilated,
        lateral=args.lateral,
        jpu=args.jpu,
        aux=args.aux,
        se_loss=args.se_loss,
        norm_layer=torch.nn.BatchNorm2d,  # BatchNorm2d
        base_size=args.base_size,
        crop_size=args.crop_size)
    print(model)
    # optimizer using different LR
    params_list = [{'params': model.pretrained.parameters(), 'lr': args.lr}]
    if hasattr(model, 'jpu'):
        params_list.append({'params': model.jpu.parameters(), 'lr': args.lr * 10})
    if hasattr(model, 'head'):
        params_list.append({'params': model.head.parameters(), 'lr': args.lr * 10})
    if hasattr(model, 'auxlayer'):
        params_list.append({'params': model.auxlayer.parameters(), 'lr': args.lr * 10})
    optimizer = torch.optim.SGD(params_list, lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # criterions
    self.criterion = SegmentationLosses(se_loss=args.se_loss,
                                        aux=args.aux,
                                        nclass=self.nclass,
                                        se_weight=args.se_weight,
                                        aux_weight=args.aux_weight)
    self.model, self.optimizer = model, optimizer
    # using cuda
    if args.cuda:
        self.model = DataParallelModel(self.model).cuda()
        self.criterion = DataParallelCriterion(self.criterion).cuda()
    # resuming checkpoint
    if args.resume is not None:
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        if args.cuda:
            self.model.module.load_state_dict(checkpoint['state_dict'])
        else:
            self.model.load_state_dict(checkpoint['state_dict'])
        if not args.ft:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.best_pred = checkpoint['best_pred']
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))
    # clear start epoch if fine-tuning
    if args.ft:
        args.start_epoch = 0
    # lr scheduler
    self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                        len(self.trainloader))
    self.best_pred = 0.0
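# The criterion above combines the main cross-entropy loss with an auxiliary loss
# (aux_weight) and a semantic-encoding loss (se_weight). A minimal sketch of that
# weighting, assuming the standard EncNet-style formulation (illustrative only; the
# actual SegmentationLosses implementation lives in the encoding package):
import torch.nn.functional as F

def combined_seg_loss(pred, aux_pred, se_pred, target, se_target,
                      aux_weight=0.2, se_weight=0.2):
    # loss = CE(main head) + aux_weight * CE(aux head)
    #        + se_weight * BCE(SE logits vs. multi-hot vector of classes present)
    return (F.cross_entropy(pred, target)
            + aux_weight * F.cross_entropy(aux_pred, target)
            + se_weight * F.binary_cross_entropy_with_logits(se_pred, se_target))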
class Trainer():
    def __init__(self, args):
        self.args = args
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])
        ])
        # dataset
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        trainset = get_dataset(args.dataset, split=args.train_split,
                               mode='train', **data_kwargs)
        valset = get_dataset(
            args.dataset,
            split='val',
            mode='ms_val' if args.multi_scale_eval else 'fast_val',
            **data_kwargs)
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.trainloader = data.DataLoader(trainset, batch_size=args.batch_size,
                                           drop_last=True, shuffle=True, **kwargs)
        if self.args.multi_scale_eval:
            kwargs['collate_fn'] = test_batchify_fn
        self.valloader = data.DataLoader(valset, batch_size=args.test_batch_size,
                                         drop_last=False, shuffle=False, **kwargs)
        self.nclass = trainset.num_class
        # model
        if args.norm_layer == 'bn':
            norm_layer = BatchNorm2d
        elif args.norm_layer == 'sync_bn':
            assert args.multi_gpu, "SyncBatchNorm can only be used when multi GPUs are available!"
            norm_layer = SyncBatchNorm
        else:
            raise ValueError('Invalid norm_layer {}'.format(args.norm_layer))
        model = get_segmentation_model(
            args.model,
            dataset=args.dataset,
            backbone=args.backbone,
            aux=args.aux,
            se_loss=args.se_loss,
            norm_layer=norm_layer,
            base_size=args.base_size,
            crop_size=args.crop_size,
            multi_grid=True,
            multi_dilation=[2, 4, 8],
            only_pam=True,
        )
        print(model)
        # optimizer using different LR
        params_list = [{'params': model.pretrained.parameters(), 'lr': args.lr}]
        if hasattr(model, 'head'):
            params_list.append({'params': model.head.parameters(), 'lr': args.lr})
        if hasattr(model, 'auxlayer'):
            params_list.append({'params': model.auxlayer.parameters(), 'lr': args.lr})
        optimizer = torch.optim.SGD(params_list, lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        # criterions
        self.criterion = SegmentationMultiLosses()
        self.model, self.optimizer = model, optimizer
        # using cuda
        if args.multi_gpu:
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
        else:
            self.model = self.model.cuda()
            self.criterion = self.criterion.cuda()
        self.single_device_model = self.model.module if self.args.multi_gpu else self.model
        # resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            self.single_device_model.load_state_dict(checkpoint['state_dict'])
            if not args.ft and not (args.only_val or args.only_vis or args.only_infer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {}), best_pred {}".format(
                args.resume, checkpoint['epoch'], checkpoint['best_pred']))
        # clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
        # lr scheduler
        self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.6)
        self.best_pred = 0.0

    def save_ckpt(self, epoch, score):
        is_best = False
        if score >= self.best_pred:
            is_best = True
            self.best_pred = score
        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': self.single_device_model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, self.args, is_best)

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        self.lr_scheduler.step()
        tbar = tqdm(self.trainloader, miniters=20)
        for i, (image, target) in enumerate(tbar):
            if not self.args.multi_gpu:
                image = image.cuda()
                target = target.cuda()
            self.optimizer.zero_grad()
            if torch_ver == "0.3":
                image = Variable(image)
                target = Variable(target)
            outputs = self.model(image)
            if self.args.multi_gpu:
                loss = self.criterion(outputs, target)
            else:
                loss = self.criterion(*(outputs + (target, )))
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            ep_log = 'ep {}'.format(epoch + 1)
            lr_log = 'lr ' + '{:.6f}'.format(
                self.optimizer.param_groups[0]['lr']).rstrip('0')
            loss_log = 'loss {:.3f}'.format(train_loss / (i + 1))
            tbar.set_description(', '.join([ep_log, lr_log, loss_log]))

    def validation(self, epoch):
        def _get_pred(batch_im):
            with torch.no_grad():
                # metric.update also accepts a list, so there is no need to gather
                # results from multiple GPUs
                if self.args.multi_scale_eval:
                    assert len(batch_im) <= torch.cuda.device_count(), \
                        "Multi-scale testing only allows batch size <= number of GPUs"
                    scattered_pred = self.ms_evaluator.parallel_forward(batch_im)
                else:
                    outputs = self.model(batch_im)
                    scattered_pred = [out[0] for out in outputs] \
                        if self.args.multi_gpu else [outputs[0]]
            return scattered_pred

        # Lazy creation
        if not hasattr(self, 'ms_evaluator'):
            self.ms_evaluator = MultiEvalModule(self.single_device_model,
                                                self.nclass,
                                                scales=self.args.eval_scales,
                                                crop=self.args.crop_eval)
            self.metric = utils.SegmentationMetric(self.nclass)
        self.model.eval()
        tbar = tqdm(self.valloader, desc='\r')
        for i, (batch_im, target) in enumerate(tbar):
            # No need to put the target on GPU, since the metrics are calculated with numpy.
            # And no need to move data to GPU manually when using data parallel.
            if not self.args.multi_gpu and not isinstance(batch_im, (list, tuple)):
                batch_im = batch_im.cuda()
            scattered_pred = _get_pred(batch_im)
            scattered_target = []
            ind = 0
            for p in scattered_pred:
                target_tmp = target[ind:ind + len(p)]
                # Multi-scale testing; in fact, len(target_tmp) == 1
                if isinstance(target_tmp, (list, tuple)):
                    assert len(target_tmp) == 1
                    target_tmp = torch.stack(target_tmp)
                scattered_target.append(target_tmp)
                ind += len(p)
            self.metric.update(scattered_target, scattered_pred)
            pixAcc, mIoU = self.metric.get()
            tbar.set_description('ep {}, pixAcc: {:.4f}, mIoU: {:.4f}'.format(
                epoch + 1, pixAcc, mIoU))
        return self.metric.get()

    def visualize(self, epoch):
        if (self.args.dir_of_im_to_vis == 'None') and (self.args.im_list_file_to_vis == 'None'):
            return
        if not hasattr(self, 'vis_im_paths'):
            if self.args.dir_of_im_to_vis != 'None':
                print('=> Visualize Dir {}'.format(self.args.dir_of_im_to_vis))
                im_paths = list(
                    walkdir(self.args.dir_of_im_to_vis, exts=['.jpg', '.png']))
            else:
                print('=> Visualize Image List {}'.format(self.args.im_list_file_to_vis))
                im_paths = read_lines(self.args.im_list_file_to_vis)
            print('=> Save Dir {}'.format(self.args.vis_save_dir))
            im_paths = sorted(im_paths)
            # np.random.RandomState(seed=1).shuffle(im_paths)
            self.vis_im_paths = im_paths[:self.args.max_num_vis]
        cfg = {
            'save_path': os.path.join(self.args.vis_save_dir,
                                      'vis_epoch{}.png'.format(epoch)),
            'multi_scale': self.args.multi_scale_eval,
            'crop': self.args.crop_eval,
            'num_class': self.nclass,
            'scales': self.args.eval_scales,
            'base_size': self.args.base_size,
        }
        vis_im_list(self.single_device_model, self.vis_im_paths, cfg)

    def infer_and_save(self, infer_dir, infer_save_dir):
        print('=> Infer Dir {}'.format(infer_dir))
        print('=> Save Dir {}'.format(infer_save_dir))
        sub_im_paths = list(
            walkdir(infer_dir, exts=['.jpg', '.png'], sub_path=True))
        im_paths = [os.path.join(infer_dir, p) for p in sub_im_paths]
        # NOTE: Don't save the result as JPEG, since that causes aliasing.
        save_paths = [
            os.path.join(infer_save_dir, p.replace('.jpg', '.png'))
            for p in sub_im_paths
        ]
        cfg = {
            'multi_scale': self.args.multi_scale_eval,
            'crop': self.args.crop_eval,
            'num_class': self.nclass,
            'scales': self.args.eval_scales,
            'base_size': self.args.base_size,
        }
        infer_and_save_im_list(self.single_device_model, im_paths, save_paths, cfg)
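# A minimal sketch of how this Trainer is typically driven (hypothetical driver, not
# part of the original file; Options().parse() and args.epochs are assumptions, the
# rest only calls methods defined on the class above):
if __name__ == '__main__':
    args = Options().parse()  # assumed argument parser, as in similar repos
    trainer = Trainer(args)
    if args.only_val:
        trainer.validation(args.start_epoch)
    else:
        for epoch in range(args.start_epoch, args.epochs):
            trainer.training(epoch)
            pixAcc, mIoU = trainer.validation(epoch)
            trainer.save_ckpt(epoch, mIoU)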
def __init__(self, args):
    self.args = args
    args.log_name = str(args.checkname)
    self.logger = utils.create_logger(args.log_root, args.log_name)
    # data transforms
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225])
    ])
    # dataset
    data_kwargs = {
        'transform': input_transform,
        'base_size': args.base_size,
        'crop_size': args.crop_size,
        'logger': self.logger,
        'scale': args.scale
    }
    trainset = get_segmentation_dataset(args.dataset, split='train',
                                        mode='train', **data_kwargs)
    testset = get_segmentation_dataset(args.dataset, split='val',
                                       mode='val', **data_kwargs)
    # dataloader
    kwargs = {'num_workers': args.workers, 'pin_memory': True} \
        if args.cuda else {}
    self.trainloader = data.DataLoader(trainset, batch_size=args.batch_size,
                                       drop_last=True, shuffle=True, **kwargs)
    self.valloader = data.DataLoader(testset, batch_size=args.batch_size,
                                     drop_last=False, shuffle=False, **kwargs)
    self.nclass = trainset.num_class
    # model
    model = get_segmentation_model(args.model,
                                   dataset=args.dataset,
                                   backbone=args.backbone,
                                   aux=args.aux,
                                   se_loss=args.se_loss,
                                   norm_layer=BatchNorm2d,
                                   base_size=args.base_size,
                                   crop_size=args.crop_size,
                                   multi_grid=args.multi_grid,
                                   multi_dilation=args.multi_dilation)
    # print(model)
    self.logger.info(model)
    # optimizer using different LR
    params_list = [{'params': model.pretrained.parameters(), 'lr': args.lr}]
    if hasattr(model, 'head'):
        params_list.append({'params': model.head.parameters(), 'lr': args.lr * 10})
    if hasattr(model, 'auxlayer'):
        params_list.append({'params': model.auxlayer.parameters(), 'lr': args.lr * 10})
    # per-class weights for the 19 Cityscapes classes
    cityscape_weight = torch.FloatTensor([
        0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 0.8786,
        1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, 1.0865, 1.1529,
        1.0507
    ])
    optimizer = torch.optim.SGD(params_list, lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # weight for class imbalance
    # self.criterion = SegmentationMultiLosses(nclass=self.nclass, weight=cityscape_weight)
    self.criterion = SegmentationMultiLosses(nclass=self.nclass)
    # self.criterion = SegmentationLosses(se_loss=args.se_loss, aux=args.aux, nclass=self.nclass)
    self.model, self.optimizer = model, optimizer
    # using cuda
    if args.cuda:
        self.model = DataParallelModel(self.model).cuda()
        self.criterion = DataParallelCriterion(self.criterion).cuda()
    # finetune from a trained model
    if args.ft:
        args.start_epoch = 0
        checkpoint = torch.load(args.ft_resume)
        if args.cuda:
            self.model.module.load_state_dict(checkpoint['state_dict'], strict=False)
        else:
            self.model.load_state_dict(checkpoint['state_dict'], strict=False)
        self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            args.ft_resume, checkpoint['epoch']))
    # resuming checkpoint
    if args.resume:
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        if args.cuda:
            self.model.module.load_state_dict(checkpoint['state_dict'])
        else:
            self.model.load_state_dict(checkpoint['state_dict'])
        if not args.ft:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.best_pred = checkpoint['best_pred']
        self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))
    # lr scheduler
    self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                        len(self.trainloader),
                                        logger=self.logger,
                                        lr_step=args.lr_step)
    self.best_pred = 0.0
def __init__(self, args):
    self.args = args
    args.log_name = str(args.checkname)
    args.log_root = os.path.join(args.dataset, args.log_root)  # dataset/log/
    self.logger = utils.create_logger(args.log_root, args.log_name)
    # data transforms
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225])])
    # dataset
    data_kwargs = {'transform': input_transform, 'base_size': args.base_size,
                   'crop_size': args.crop_size, 'logger': self.logger,
                   'scale': args.scale}
    trainset = get_segmentation_dataset(args.dataset, split='trainval',
                                        mode='trainval', **data_kwargs)
    testset = get_segmentation_dataset(args.dataset, split='val',
                                       mode='val',  # crop fixed size as model input
                                       **data_kwargs)
    # dataloader
    kwargs = {'num_workers': args.workers, 'pin_memory': True} \
        if args.cuda else {}
    self.trainloader = data.DataLoader(trainset, batch_size=args.batch_size,
                                       drop_last=True, shuffle=True, **kwargs)
    self.valloader = data.DataLoader(testset, batch_size=args.batch_size,
                                     drop_last=False, shuffle=False, **kwargs)
    self.nclass = trainset.num_class
    # model
    model = get_segmentation_model(args.model,
                                   dataset=args.dataset,
                                   backbone=args.backbone,
                                   norm_layer=BatchNorm2d,
                                   base_size=args.base_size,
                                   crop_size=args.crop_size,
                                   )
    # print(model)
    self.logger.info(model)
    # optimizer using different LR
    params_list = [{'params': model.pretrained.parameters(), 'lr': args.lr}]
    if hasattr(model, 'head'):
        print("this model has a 'head' attribute")
        params_list.append({'params': model.head.parameters(), 'lr': args.lr * 10})
    optimizer = torch.optim.SGD(params_list, lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    self.criterion = SegmentationLosses(nclass=self.nclass)
    self.model, self.optimizer = model, optimizer
    # using cuda
    if args.cuda:
        self.model = DataParallelModel(self.model).cuda()
        self.criterion = DataParallelCriterion(self.criterion).cuda()
    # resuming checkpoint
    if args.resume:
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        if args.cuda:
            self.model.module.load_state_dict(checkpoint['state_dict'])
        else:
            self.model.load_state_dict(checkpoint['state_dict'])
        if not args.ft:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.best_pred = checkpoint['best_pred']
        self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))
    # lr scheduler
    self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                        len(self.trainloader),
                                        logger=self.logger,
                                        lr_step=args.lr_step)
    self.best_pred = 0.0
def main(args):
    writer = SummaryWriter(log_dir=args.tensorboard_log_dir)
    w, h = map(int, args.input_size.split(','))
    w_target, h_target = map(int, args.input_size_target.split(','))

    joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    normalize = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*normalize),
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.ToPILImage()

    if '5' in args.data_dir:
        dataset = GTA5DataSetLMDB(
            args.data_dir, args.data_list,
            joint_transform=joint_transform,
            transform=input_transform,
            target_transform=target_transform,
        )
    else:
        dataset = CityscapesDataSetLMDB(
            args.data_dir, args.data_list,
            joint_transform=joint_transform,
            transform=input_transform,
            target_transform=target_transform,
        )
    loader = data.DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True
    )
    val_dataset = CityscapesDataSetLMDB(
        args.data_dir_target, args.data_list_target,
        # joint_transform=joint_transform,
        transform=input_transform,
        target_transform=target_transform
    )
    val_loader = data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True
    )

    upsample = nn.Upsample(size=(h_target, w_target),
                           mode='bilinear', align_corners=True)

    net = PSP(
        nclass=args.n_classes, backbone='resnet101',
        root=args.model_path_prefix, norm_layer=BatchNorm2d,
    )

    params_list = [
        {'params': net.pretrained.parameters(), 'lr': args.learning_rate},
        {'params': net.head.parameters(), 'lr': args.learning_rate * 10},
        {'params': net.auxlayer.parameters(), 'lr': args.learning_rate * 10},
    ]
    optimizer = torch.optim.SGD(params_list,
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    criterion = SegmentationLosses(nclass=args.n_classes, aux=True, ignore_index=255)
    # criterion = SegmentationMultiLosses(nclass=args.n_classes, ignore_index=255)

    net = DataParallelModel(net).cuda()
    criterion = DataParallelCriterion(criterion).cuda()

    logger = utils.create_logger(args.tensorboard_log_dir, 'PSP_train')
    scheduler = utils.LR_Scheduler(args.lr_scheduler, args.learning_rate,
                                   args.num_epoch, len(loader),
                                   logger=logger, lr_step=args.lr_step)

    net_eval = Eval(net)

    num_batches = len(loader)
    best_pred = 0.0
    for epoch in range(args.num_epoch):
        loss_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for batch_index, batch_data in enumerate(loader):
            scheduler(optimizer, batch_index, epoch, best_pred)
            show_fig = (batch_index + 1) % args.show_img_freq == 0
            iteration = batch_index + 1 + epoch * num_batches

            net.train()
            img, label, name = batch_data
            img = img.cuda()
            label_cuda = label.cuda()
            data_time_rec.update(time.time() - tem_time)

            output = net(img)
            loss = criterion(output, label_cuda)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_rec.update(loss.item())
            writer.add_scalar('A_seg_loss', loss.item(), iteration)
            batch_time_rec.update(time.time() - tem_time)
            tem_time = time.time()

            if (batch_index + 1) % args.print_freq == 0:
                print(
                    f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                    f'Time: {batch_time_rec.avg:.2f} '
                    f'Data: {data_time_rec.avg:.2f} '
                    f'Loss: {loss_rec.avg:.2f}'
                )
            # if show_fig:
            #     base_lr = optimizer.param_groups[0]["lr"]
            #     output = torch.argmax(output[0][0], dim=1).detach()[0, ...].cpu()
            #     fig, axes = plt.subplots(2, 1, figsize=(12, 14))
            #     axes = axes.flat
            #     axes[0].imshow(colorize_mask(output.numpy()))
            #     axes[0].set_title(name[0])
            #     axes[1].imshow(colorize_mask(label[0, ...].numpy()))
            #     axes[1].set_title(f'seg_true_{base_lr:.6f}')
            #     writer.add_figure('A_seg', fig, iteration)
            #     output_mask = np.asarray(colorize_mask(output.numpy()))
            #     label = np.asarray(colorize_mask(label[0, ...].numpy()))
            #     image_out = np.concatenate([output_mask, label])
            #     writer.add_image('A_seg', image_out, iteration)

        mean_iu = test_miou(net_eval, val_loader, upsample,
                            './style_seg/dataset/info.json')
        torch.save(
            net.module.state_dict(),
            os.path.join(args.save_path_prefix, f'{epoch:d}_{mean_iu*100:.0f}.pth')
        )

    writer.close()
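# AverageMeter is used above for loss/time bookkeeping but is not defined in this
# snippet. A minimal sketch of the conventional implementation, assuming the usual
# update(value) / .avg interface relied on by the loop above:
class AverageMeter:
    """Tracks the running average of a scalar (e.g. loss or batch time)."""
    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count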
def __init__(self, args):
    self.args = args
    # data transforms
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225])
    ])
    # dataset
    data_kwargs = {
        'transform': input_transform,
        'base_size': args.base_size,
        'crop_size': args.crop_size
    }
    trainset = get_segmentation_dataset(args.dataset, split=args.train_split,
                                        mode='train', **data_kwargs)
    testset = get_segmentation_dataset(args.dataset, split='val',
                                       mode='val', **data_kwargs)
    # dataloader
    kwargs = {'num_workers': args.workers, 'pin_memory': True} \
        if args.cuda else {}
    self.trainloader = data.DataLoader(trainset, batch_size=args.batch_size,
                                       drop_last=False, shuffle=True, **kwargs)
    self.valloader = data.DataLoader(testset, batch_size=args.batch_size,
                                     drop_last=False, shuffle=False, **kwargs)
    self.nclass = trainset.num_class
    # model
    model = get_segmentation_model(args.model,
                                   dataset=args.dataset,
                                   backbone=args.backbone,
                                   dilated=args.dilated,
                                   multi_grid=args.multi_grid,
                                   stride=args.stride,
                                   lateral=args.lateral,
                                   jpu=args.jpu,
                                   aux=args.aux,
                                   se_loss=args.se_loss,
                                   norm_layer=SyncBatchNorm,
                                   base_size=args.base_size,
                                   crop_size=args.crop_size)
    # print(model)
    # optimizer using different LR
    params_list = [{'params': model.pretrained.parameters(), 'lr': args.lr}]
    if hasattr(model, 'jpu') and model.jpu:
        params_list.append({'params': model.jpu.parameters(), 'lr': args.lr * 10})
    if hasattr(model, 'head'):
        params_list.append({'params': model.head.parameters(), 'lr': args.lr * 10})
    if hasattr(model, 'auxlayer'):
        params_list.append({'params': model.auxlayer.parameters(), 'lr': args.lr * 10})
    optimizer = torch.optim.SGD(params_list, lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    class_balance_weight = 'None'
    if args.dataset == "pcontext60":
        class_balance_weight = torch.tensor([
            1.3225e-01, 2.0757e+00, 1.8146e+01, 5.5052e+00, 2.2060e+00, 2.8054e+01,
            2.0566e+00, 1.8598e+00, 2.4027e+00, 9.3435e+00, 3.5990e+00, 2.7487e-01,
            1.4216e+00, 2.4986e+00, 7.7258e-01, 4.9020e-01, 2.9067e+00, 1.2197e+00,
            2.2744e+00, 2.0444e+01, 3.0057e+00, 1.8167e+01, 3.7405e+00, 5.6749e-01,
            3.2631e+00, 1.5007e+00, 5.5519e-01, 1.0056e+01, 1.8952e+01, 2.6792e-01,
            2.7479e-01, 1.8309e+00, 2.0428e+01, 1.4788e+01, 1.4908e+00, 1.9113e+00,
            2.6166e+02, 2.3233e-01, 1.9096e+01, 6.7025e+00, 2.8756e+00, 6.8804e-01,
            4.4140e+00, 2.5621e+00, 4.4409e+00, 4.3821e+00, 1.3774e+01, 1.9803e-01,
            3.6944e+00, 1.0397e+00, 2.0601e+00, 5.5811e+00, 1.3242e+00, 3.0088e-01,
            1.7344e+01, 2.1569e+00, 2.7216e-01, 5.8731e-01, 1.9956e+00, 4.4004e+00
        ])
    elif args.dataset == "ade20k":
        class_balance_weight = torch.tensor([
            0.0772, 0.0431, 0.0631, 0.0766, 0.1095, 0.1399, 0.1502, 0.1702,
            0.2958, 0.3400, 0.3738, 0.3749, 0.4059, 0.4266, 0.4524, 0.5725,
            0.6145, 0.6240, 0.6709, 0.6517, 0.6591, 0.6818, 0.9203, 0.9965,
            1.0272, 1.0967, 1.1202, 1.2354, 1.2900, 1.5038, 1.5160, 1.5172,
            1.5036, 2.0746, 2.1426, 2.3159, 2.2792, 2.6468, 2.8038, 2.8777,
            2.9525, 2.9051, 3.1050, 3.1785, 3.3533, 3.5300, 3.6120, 3.7006,
            3.6790, 3.8057, 3.7604, 3.8043, 3.6610, 3.8268, 4.0644, 4.2698,
            4.0163, 4.0272, 4.1626, 4.3702, 4.3144, 4.3612, 4.4389, 4.5612,
            5.1537, 4.7653, 4.8421, 4.6813, 5.1037, 5.0729, 5.2657, 5.6153,
            5.8240, 5.5360, 5.6373, 6.6972, 6.4561, 6.9555, 7.9239, 7.3265,
            7.7501, 7.7900, 8.0528, 8.5415, 8.1316, 8.6557, 9.0550, 9.0081,
            9.3262, 9.1391, 9.7237, 9.3775, 9.4592, 9.7883, 10.6705, 10.2113,
            10.5845, 10.9667, 10.8754, 10.8274, 11.6427, 11.0687, 10.8417, 11.0287,
            12.2030, 12.8830, 12.5082, 13.0703, 13.8410, 12.3264, 12.9048, 12.9664,
            12.3523, 13.9830, 13.8105, 14.0345, 15.0054, 13.9801, 14.1048, 13.9025,
            13.6179, 17.0577, 15.8351, 17.7102, 17.3153, 19.4640, 17.7629, 19.9093,
            16.9529, 19.3016, 17.6671, 19.4525, 20.0794, 18.3574, 19.1219, 19.5089,
            19.2417, 20.2534, 20.0332, 21.7496, 21.5427, 20.3008, 21.1942, 22.7051,
            23.3359, 22.4300, 20.9934, 26.9073, 31.7362, 30.0784
        ])
    elif args.dataset == "cocostuff":
        class_balance_weight = torch.tensor([
            4.8557e-02, 6.4709e-02, 3.9255e+00, 9.4797e-01, 1.2703e+00, 1.4151e+00,
            7.9733e-01, 8.4903e-01, 1.0751e+00, 2.4001e+00, 8.9736e+00, 5.3036e+00,
            6.0410e+00, 9.3285e+00, 1.5952e+00, 3.6090e+00, 9.8772e-01, 1.2319e+00,
            1.9194e+00, 2.7624e+00, 2.0548e+00, 1.2058e+00, 3.6424e+00, 2.0789e+00,
            1.7851e+00, 6.7138e+00, 2.1315e+00, 6.9813e+00, 1.2679e+02, 2.0357e+00,
            2.2933e+01, 2.3198e+01, 1.7439e+01, 4.1294e+01, 7.8678e+00, 4.3444e+01,
            6.7543e+01, 1.0066e+01, 6.7520e+00, 1.3174e+01, 3.3499e+00, 6.9737e+00,
            2.1482e+00, 1.9428e+01, 1.3240e+01, 1.9218e+01, 7.6836e-01, 2.6041e+00,
            6.1822e+00, 1.4070e+00, 4.4074e+00, 5.7792e+00, 1.0321e+01, 4.9922e+00,
            6.7408e-01, 3.1554e+00, 1.5832e+00, 8.9685e-01, 1.1686e+00, 2.6487e+00,
            6.5354e-01, 2.3801e-01, 1.9536e+00, 1.5862e+00, 1.7797e+00, 2.7385e+01,
            1.2419e+01, 3.9287e+00, 7.8897e+00, 7.5737e+00, 1.9758e+00, 8.1962e+01,
            3.6922e+00, 2.0039e+00, 2.7333e+00, 5.4717e+00, 3.9048e+00, 1.9184e+01,
            2.2689e+00, 2.6091e+02, 4.7366e+01, 2.3844e+00, 8.3310e+00, 1.4857e+01,
            6.5076e+00, 2.0854e-01, 1.0425e+00, 1.7386e+00, 1.1973e+01, 5.2862e+00,
            1.7341e+00, 8.6124e-01, 9.3702e+00, 2.8545e+00, 6.0123e+00, 1.7560e-01,
            1.8128e+00, 1.3784e+00, 1.3699e+00, 2.3728e+00, 6.2819e-01, 1.3097e+00,
            4.7892e-01, 1.0268e+01, 1.2307e+00, 5.5662e+00, 1.2867e+00, 1.2745e+00,
            4.7505e+00, 8.4029e+00, 1.8679e+00, 1.0519e+01, 1.1240e+00, 1.4975e-01,
            2.3146e+00, 4.1265e-01, 2.5896e+00, 1.4537e+00, 4.5575e+00, 7.8143e+00,
            1.4603e+01, 2.8812e+00, 1.8868e+00, 7.8131e+01, 1.9323e+00, 7.4980e+00,
            1.2446e+01, 2.1856e+00, 3.0973e+00, 4.1270e-01, 4.9016e+01, 7.1001e-01,
            7.4035e+00, 2.3395e+00, 2.9207e-01, 2.4156e+00, 3.3211e+00, 2.1300e+00,
            2.4533e-01, 1.7081e+00, 4.6621e+00, 2.9199e+00, 1.0407e+01, 7.6207e-01,
            2.7806e-01, 3.7711e+00, 1.1852e-01, 8.8280e+00, 3.1700e-01, 6.3765e+01,
            6.6032e+00, 5.2177e+00, 4.3596e+00, 6.2965e-01, 1.0207e+00, 1.1731e+01,
            2.3935e+00, 9.2767e+00, 1.1023e-01, 3.6947e+00, 1.3943e+00, 2.3407e+00,
            1.2112e-01, 2.8518e+00, 2.8195e+00, 1.0078e+00, 1.6614e+00, 6.5307e-01,
            1.9070e+01, 2.7231e+00, 6.0769e-01
        ])

    # criterions
    self.criterion = SegmentationLosses(se_loss=args.se_loss,
                                        aux=args.aux,
                                        nclass=self.nclass,
                                        se_weight=args.se_weight,
                                        aux_weight=args.aux_weight,
                                        weight=class_balance_weight)
    self.model, self.optimizer = model, optimizer
    # using cuda
    if args.cuda:
        self.model = DataParallelModel(self.model).cuda()
        self.criterion = DataParallelCriterion(self.criterion).cuda()
    # resuming checkpoint
    self.best_pred = 0.0
    if args.resume is not None:
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        if args.cuda:
            self.model.module.load_state_dict(checkpoint['state_dict'])
        else:
            self.model.load_state_dict(checkpoint['state_dict'])
        if not args.ft:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.best_pred = checkpoint['best_pred']
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))
    # clear start epoch if fine-tuning
    if args.ft:
        args.start_epoch = 0
    # lr scheduler
    self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                        len(self.trainloader))
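# The hard-coded class_balance_weight tensors above are per-class weights passed to the
# cross-entropy loss. A minimal sketch of one common way such weights can be derived from
# label pixel frequencies (illustrative only; the exact recipe used to produce the numbers
# above is not given in this snippet):
import numpy as np

def inverse_frequency_weights(pixel_counts, eps=1e-12):
    """Weight each class by (median class frequency) / (class frequency)."""
    freqs = np.asarray(pixel_counts, dtype=np.float64)
    freqs = freqs / max(freqs.sum(), eps)
    return np.median(freqs) / np.maximum(freqs, eps)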
def train(args):
    weight_dir = args.log_root  # os.path.join(args.log_root, 'weights')
    log_dir = os.path.join(
        args.log_root, 'logs',
        'SS-Net-{}'.format(time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())))
    data_dir = os.path.join(args.data_root, args.dataset)

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 1. Setup DataLoader
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 0. Setting up DataLoader...")
    net_h, net_w = int(args.img_row * args.crop_ratio), int(args.img_col * args.crop_ratio)

    augment_train = Compose([
        RandomHorizontallyFlip(),
        RandomSized((0.5, 0.75)),
        RandomRotate(5),
        RandomCrop((net_h, net_w))
    ])
    augment_valid = Compose([
        RandomHorizontallyFlip(),
        Scale((args.img_row, args.img_col)),
        CenterCrop((net_h, net_w))
    ])

    train_loader = CityscapesLoader(data_dir, gt='gtFine', split='train',
                                    img_size=(args.img_row, args.img_col),
                                    is_transform=True,
                                    augmentations=augment_train)
    valid_loader = CityscapesLoader(data_dir, gt='gtFine', split='val',
                                    img_size=(args.img_row, args.img_col),
                                    is_transform=True,
                                    augmentations=augment_valid)
    num_classes = train_loader.n_classes

    tra_loader = data.DataLoader(train_loader, batch_size=args.batch_size,
                                 num_workers=int(multiprocessing.cpu_count() / 2),
                                 shuffle=True)
    val_loader = data.DataLoader(valid_loader, batch_size=args.batch_size,
                                 num_workers=int(multiprocessing.cpu_count() / 2))

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 2. Setup Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 1. Setting up Model...")
    model = RetinaNet(num_classes=num_classes, input_size=(net_h, net_w))
    # model = torch.nn.DataParallel(model, device_ids=[0, 1, 2]).cuda()
    model = DataParallelModel(model, device_ids=args.device_ids).cuda()  # multi-gpu

    # 2.1 Setup Optimizer
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # Check if the model has a custom optimizer
    if hasattr(model.module, 'optimizer'):
        print('> Using custom optimizer')
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate,
                                    momentum=0.90, weight_decay=5e-4,
                                    nesterov=True)
        # optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=1e-5)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)
    # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    # 2.2 Setup Loss
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    class_weight = np.array([
        0.05570516, 0.32337477, 0.08998544, 1.03602707, 1.03413147, 1.68195437,
        5.58540548, 3.56563995, 0.12704978, 1., 0.46783719, 1.34551528,
        5.29974114, 0.28342531, 0.9396095, 0.81551811, 0.42679146, 3.6399074,
        2.78376194
    ], dtype=float)
    class_weight = torch.from_numpy(class_weight).float().cuda()

    sem_loss = bootstrapped_cross_entropy2d
    sem_loss = DataParallelCriterion(sem_loss, device_ids=args.device_ids)
    se_loss = SemanticEncodingLoss(num_classes=19, ignore_label=250, alpha=0.50).cuda()
    se_loss_parallel = DataParallelCriterion(se_loss, device_ids=args.device_ids)
    """
    # multi-gpu
    bootstrapped_cross_entropy2d = ContextBootstrappedCELoss2D(num_classes=num_classes,
                                                               ignore=250, kernel_size=5,
                                                               padding=4, dilate=2,
                                                               use_gpu=True)
    loss_sem = DataParallelCriterion(bootstrapped_cross_entropy2d, device_ids=[0, 1])
    """

    # 2.3 Setup Metrics
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # !!!!! Here Metrics !!!!!
    metrics = RunningScore(num_classes)

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 3. Resume Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 2. Model state init or resume...")
    args.start_epoch = 1
    args.start_iter = 0
    beat_map = 0.  # best mean IoU seen so far
    if args.resume is not None:
        full_path = os.path.join(os.path.join(weight_dir, 'train_model'), args.resume)
        if os.path.isfile(full_path):
            print("> Loading model and optimizer from checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(full_path)
            args.start_epoch = checkpoint['epoch']
            args.start_iter = checkpoint['iter']
            beat_map = checkpoint['beat_map']
            model.load_state_dict(checkpoint['model_state'])          # weights
            optimizer.load_state_dict(checkpoint['optimizer_state'])  # gradient state
            del checkpoint
            print("> Loaded checkpoint '{}' (epoch {}, iter {})".format(
                args.resume, args.start_epoch, args.start_iter))
        else:
            print("> No checkpoint found at '{}'".format(full_path))
            raise Exception("> No checkpoint found at '{}'".format(full_path))
    else:
        # init_weights(model, pi=0.01,
        #              pre_trained=os.path.join(args.log_root, 'resnet50_imagenet.pth'))
        if args.pre_trained is not None:
            print("> Loading weights from pre-trained model '{}'".format(args.pre_trained))
            full_path = os.path.join(args.log_root, args.pre_trained)
            pre_weight = torch.load(full_path)
            prefix = "module.fpn.base_net."
            model_dict = model.state_dict()
            pretrained_dict = {(prefix + k): v for k, v in pre_weight.items()
                               if (prefix + k) in model_dict}
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)
            del pre_weight
            del model_dict
            del pretrained_dict

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 4. Train Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 4.0 Setup tensor-board for visualization
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    writer = None
    if args.tensor_board:
        writer = SummaryWriter(log_dir=log_dir, comment="SSnet_Cityscapes")
        # dummy_input = Variable(torch.rand(1, 3, args.img_row, args.img_col).cuda(), requires_grad=True)
        # writer.add_graph(model, dummy_input)

    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 3. Model Training start...")
    topk_init = 512
    num_batches = int(
        math.ceil(len(tra_loader.dataset.files[tra_loader.dataset.split]) /
                  float(tra_loader.batch_size)))
    # lr_period = 20 * num_batches

    for epoch in np.arange(args.start_epoch - 1, args.num_epochs):
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 4.1 Mini-Batch Training
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        model.train()
        topk_base = topk_init

        if epoch == args.start_epoch - 1:
            pbar = tqdm(np.arange(args.start_iter, num_batches))
            start_iter = args.start_iter
        else:
            pbar = tqdm(np.arange(num_batches))
            start_iter = 0

        lr = args.learning_rate
        # lr = adjust_learning_rate(optimizer, init_lr=args.learning_rate, decay_rate=0.1,
        #                           curr_epoch=epoch, epoch_step=20,
        #                           start_decay_at_epoch=args.start_decay_at_epoch,
        #                           total_epoch=args.num_epochs, mode='exp')
        # scheduler.step()

        # for train_i, (images, gt_masks) in enumerate(tra_loader):
        # one mini-batch of data per iteration
        for train_i, (images, gt_masks) in zip(range(start_iter, num_batches), tra_loader):
            full_iter = (epoch * num_batches) + train_i + 1
            lr = poly_lr_scheduler(optimizer, init_lr=args.learning_rate,
                                   iter=full_iter, lr_decay_iter=1,
                                   max_iter=args.num_epochs * num_batches,
                                   power=0.9)
            # lr = args.learning_rate * cosine_annealing_lr(lr_period, full_iter)
            # optimizer = set_optimizer_lr(optimizer, lr)

            images = images.cuda().requires_grad_()
            se_labels = se_loss.unique_encode(gt_masks)
            se_labels = se_labels.cuda()
            gt_masks = gt_masks.cuda()

            topk_base = poly_topk_scheduler(init_topk=topk_init, iter=full_iter,
                                            topk_decay_iter=1,
                                            max_iter=args.num_epochs * num_batches,
                                            power=0.95)

            optimizer.zero_grad()
            se, sem_seg_pred = model(images)

            # --------------------------------------------------- #
            # Compute loss
            # --------------------------------------------------- #
            topk = topk_base * 512
            train_loss = sem_loss(input=sem_seg_pred, target=gt_masks, K=topk, weight=None)
            train_se_loss = se_loss_parallel(predicts=se, enc_cls_target=se_labels,
                                             size_average=True,
                                             reduction='elementwise_mean')
            loss = train_loss + args.alpha * train_se_loss

            loss.backward()  # back-propagation
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1e3)
            optimizer.step()  # parameter update based on the current gradient

            pbar.update(1)
            pbar.set_description("> Epoch [%d/%d]" % (epoch + 1, args.num_epochs))
            pbar.set_postfix(Train_Loss=train_loss.item(),
                             Train_SE_Loss=train_se_loss.item(),
                             TopK=topk_base)
            # pbar.set_postfix(Train_Loss=train_loss.item(), TopK=topk_base)

            # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
            # 4.1.1 Verbose training process
            # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
            if (train_i + 1) % args.verbose_interval == 0:
                # ---------------------------------------- #
                # 1. Training Losses
                # ---------------------------------------- #
                loss_log = "Epoch [%d/%d], Iter: %d Loss1: \t %.4f " % (
                    epoch + 1, args.num_epochs, train_i + 1, loss.item())
                # ---------------------------------------- #
                # 2. Training Metrics
                # ---------------------------------------- #
                sem_seg_pred = F.softmax(sem_seg_pred, dim=1)
                pred = sem_seg_pred.data.max(1)[1].cpu().numpy()
                gt = gt_masks.data.cpu().numpy()
                metrics.update(gt, pred)  # accumulate the metrics (confusion matrix and IoUs)
                score, _ = metrics.get_scores()
                metric_log = ""
                for k, v in score.items():
                    metric_log += " {}: \t %.4f, ".format(k) % v
                metrics.reset()  # reset the metrics after each logging interval
                logs = loss_log + metric_log
                if args.tensor_board:
                    writer.add_scalar('Training/Train_Loss', train_loss.item(), full_iter)
                    writer.add_scalar('Training/Train_SE_Loss', train_se_loss.item(), full_iter)
                    writer.add_scalar('Training/Loss', loss.item(), full_iter)
                    writer.add_scalar('Training/Lr', lr, full_iter)
                    writer.add_scalars('Training/Metrics', score, full_iter)
                    writer.add_text('Training/Text', logs, full_iter)
                    for name, param in model.named_parameters():
                        writer.add_histogram(name, param.clone().cpu().data.numpy(), full_iter)
            """
            # save the model every args.iter_interval_save_model iterations
            if (train_i + 1) % args.iter_interval_save_model == 0:
                pbar.set_postfix(Loss=train_loss.item(), lr=lr)
                state = {"epoch": epoch + 1,
                         "iter": train_i + 1,
                         'beat_map': beat_map,
                         "model_state": model.state_dict(),
                         "optimizer_state": optimizer.state_dict()}
                save_dir = os.path.join(os.path.join(weight_dir, 'train_model'),
                                        "ssnet_model_sem_se_{}epoch_{}iter.pkl".format(epoch + 1, train_i + 1))
                torch.save(state, save_dir)
            """

        # end of this training phase
        state = {
            "epoch": epoch + 1,
            "iter": num_batches,
            'beat_map': beat_map,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict()
        }
        save_dir = os.path.join(
            os.path.join(args.log_root, 'train_model'),
            "ssnet_model_sem_se_{}_{}epoch_{}iter.pkl".format(
                args.model_details, epoch + 1, num_batches))
        torch.save(state, save_dir)

        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 4.2 Mini-Batch Validation
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        model.eval()
        val_loss = 0.0
        vali_count = 0
        with torch.no_grad():
            for i_val, (images_val, gt_masks_val) in enumerate(val_loader):
                vali_count += 1
                images_val = images_val.cuda()
                se_labels_val = se_loss.unique_encode(gt_masks_val)
                se_labels_val = se_labels_val.cuda()
                gt_masks_val = gt_masks_val.cuda()

                se_val, sem_seg_pred_val = model(images_val)

                # !!!!!! Loss !!!!!!
                topk_val = topk_base * 512
                loss = sem_loss(sem_seg_pred_val, gt_masks_val, topk_val, weight=None) + \
                    args.alpha * se_loss_parallel(predicts=se_val,
                                                  enc_cls_target=se_labels_val,
                                                  size_average=True,
                                                  reduction='elementwise_mean')
                val_loss += loss.item()

                # accumulate the confusion matrix and IoUs
                sem_seg_pred_val = F.softmax(sem_seg_pred_val, dim=1)
                pred = sem_seg_pred_val.data.max(1)[1].cpu().numpy()
                gt = gt_masks_val.data.cpu().numpy()
                metrics.update(gt, pred)

        # ---------------------------------------- #
        # 1. Validation Losses
        # ---------------------------------------- #
        val_loss /= vali_count
        loss_log = "Epoch [%d/%d], Loss: \t %.4f" % (epoch + 1, args.num_epochs, val_loss)

        # ---------------------------------------- #
        # 2. Validation Metrics
        # ---------------------------------------- #
        metric_log = ""
        score, _ = metrics.get_scores()
        for k, v in score.items():
            metric_log += " {}: \t %.4f, ".format(k) % v
        metrics.reset()  # reset the metrics

        logs = loss_log + metric_log
        pbar.set_postfix(Vali_Loss=val_loss, Lr=lr,
                         Vali_mIoU=score['Mean_IoU'])  # Train_Loss=train_loss.item()

        if args.tensor_board:
            writer.add_scalar('Validation/Loss', val_loss, epoch)
            writer.add_scalars('Validation/Metrics', score, epoch)
            writer.add_text('Validation/Text', logs, epoch)
            for name, param in model.named_parameters():
                writer.add_histogram(name, param.clone().cpu().data.numpy(), epoch)

        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 4.3 End of one Epoch
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # !!!!! Choose a suitable metric for best-model selection here !!!!!
        if score['Mean_IoU'] >= beat_map:
            beat_map = score['Mean_IoU']
            state = {
                "epoch": epoch + 1,
                "beat_map": beat_map,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict()
            }
            save_dir = os.path.join(
                weight_dir,
                "SSnet_best_sem_se_{}_model.pkl".format(args.model_details))
            torch.save(state, save_dir)

        # Note that step should be called after validate()
        pbar.close()

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 4.4 End of Training process
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    if args.tensor_board:
        # export scalar data to JSON for external processing
        # writer.export_scalars_to_json("{}/all_scalars.json".format(log_dir))
        writer.close()

    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> Training Done!!!")
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
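# poly_lr_scheduler and poly_topk_scheduler are called in the loop above but not defined
# in this snippet. A minimal sketch of a typical polynomial LR decay, assuming the call
# signature used above (illustrative; the original helper may differ in detail):
def poly_lr_scheduler(optimizer, init_lr, iter, lr_decay_iter=1,
                      max_iter=30000, power=0.9):
    """Decay every param group's lr as init_lr * (1 - iter / max_iter) ** power."""
    if iter % lr_decay_iter != 0 or iter > max_iter:
        return optimizer.param_groups[0]['lr']
    lr = init_lr * (1 - iter / max_iter) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr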
def __init__(self, args):
    self.args = args
    args.log_name = str(args.checkname)
    root_dir = getattr(args, "data_root", '../datasets')
    wo_head = getattr(args, "resume_wo_head", False)
    self.logger = utils.create_logger(args.log_root, args.log_name)
    # data transforms
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225])
    ])
    # dataset
    data_kwargs = {
        'transform': input_transform,
        'base_size': args.base_size,
        'crop_size': args.crop_size,
        'logger': self.logger,
        'scale': args.scale
    }
    trainset = get_segmentation_dataset(args.dataset, split='train',
                                        mode='train', root=root_dir, **data_kwargs)
    testset = get_segmentation_dataset(args.dataset, split='val',
                                       mode='val', root=root_dir, **data_kwargs)
    # dataloader
    kwargs = {'num_workers': args.workers, 'pin_memory': True} \
        if args.cuda else {}
    self.trainloader = data.DataLoader(trainset, batch_size=args.batch_size,
                                       drop_last=True, shuffle=True, **kwargs)
    self.valloader = data.DataLoader(testset, batch_size=args.batch_size,
                                     drop_last=False, shuffle=False, **kwargs)
    self.nclass = trainset.num_class
    # model
    model = get_segmentation_model(args.model,
                                   dataset=args.dataset,
                                   backbone=args.backbone,
                                   aux=args.aux,
                                   se_loss=args.se_loss,
                                   norm_layer=BatchNorm2d,
                                   base_size=args.base_size,
                                   crop_size=args.crop_size,
                                   multi_grid=args.multi_grid,
                                   multi_dilation=args.multi_dilation)
    # print(model)
    self.logger.info(model)
    # optimizer using different LR
    if not args.wo_backbone:
        params_list = [{'params': model.pretrained.parameters(), 'lr': args.lr}]
    else:
        params_list = []
    if hasattr(model, 'head'):
        params_list.append({'params': model.head.parameters(), 'lr': args.lr * 10})
    if hasattr(model, 'auxlayer'):
        params_list.append({'params': model.auxlayer.parameters(), 'lr': args.lr * 10})
    optimizer = torch.optim.SGD(params_list, lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    self.criterion = SegmentationMultiLosses(nclass=self.nclass)
    # self.criterion = SegmentationLosses(se_loss=args.se_loss, aux=args.aux, nclass=self.nclass)
    self.model, self.optimizer = model, optimizer
    # using cuda
    if args.cuda:
        self.model = DataParallelModel(self.model).cuda()
        self.criterion = DataParallelCriterion(self.criterion).cuda()
    # finetune from a trained model
    if args.ft:
        args.start_epoch = 0
        checkpoint = torch.load(args.ft_resume)
        if wo_head:
            print("WITHout HEAD !!!!!!!!!!")
            from collections import OrderedDict
            new = OrderedDict()
            for k, v in checkpoint['state_dict'].items():
                if not k.startswith("head"):
                    new[k] = v
            checkpoint['state_dict'] = new
        else:
            print("With HEAD !!!!!!!!!!")
        if args.cuda:
            self.model.module.load_state_dict(checkpoint['state_dict'], strict=False)
        else:
            self.model.load_state_dict(checkpoint['state_dict'], strict=False)
        # self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.ft_resume, checkpoint['epoch']))
    # resuming checkpoint
    if args.resume:
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        if args.cuda:
            self.model.module.load_state_dict(checkpoint['state_dict'])
        else:
            self.model.load_state_dict(checkpoint['state_dict'])
        if not args.ft:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.best_pred = checkpoint['best_pred']
        self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))
    # lr scheduler
    self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                        len(self.trainloader),
                                        logger=self.logger,
                                        lr_step=args.lr_step)
    self.best_pred = 0.0
def __init__(self, args):
    self.args = args
    self.args.start_epoch = 0
    self.args.cuda = True
    # data transforms
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.490, .490, .490], [.247, .247, .247])
    ])  # TODO: change mean and std

    # dataset
    train_chain = Compose([
        HorizontalFlip(p=0.5),
        OneOf([
            ElasticTransform(alpha=300, sigma=300 * 0.05, alpha_affine=300 * 0.03),
            GridDistortion(),
            OpticalDistortion(distort_limit=2, shift_limit=0.5),
        ], p=0.3),
        RandomSizedCrop(min_max_height=(900, 1024), height=1024, width=1024, p=0.5),
        ShiftScaleRotate(rotate_limit=20, p=0.5),
        Resize(self.args.size, self.args.size)
    ], p=1)
    val_chain = Compose([Resize(self.args.size, self.args.size)], p=1)

    num_fold = self.args.num_fold
    df_train = pd.read_csv(os.path.join(args.imagelist_path, 'train.csv'))
    df_val = pd.read_csv(os.path.join(args.imagelist_path, 'val.csv'))
    df_full = pd.concat((df_train, df_val), ignore_index=True, axis=0)
    df_full['lbl'] = (df_full['mask_name'].astype(str) == '-1').astype(int)
    skf = StratifiedKFold(8, shuffle=True, random_state=777)
    train_ids, val_ids = list(skf.split(df_full['mask_name'], df_full['lbl']))[num_fold]

    df_test = pd.read_csv(os.path.join(args.imagelist_path, 'test_true.csv'))
    df_new_train = pd.concat((df_full.iloc[train_ids], df_test),
                             ignore_index=True, axis=0, sort=False)
    df_new_val = df_full.iloc[val_ids]
    df_new_train.to_csv(f'/tmp/train_new_pneumo_{num_fold}.csv')
    df_new_val.to_csv(f'/tmp/val_new_pneumo_{num_fold}.csv')

    trainset = SegmentationDataset(f'/tmp/train_new_pneumo_{num_fold}.csv',
                                   args.image_path,
                                   args.masks_path,
                                   input_transform=input_transform,
                                   transform_chain=train_chain,
                                   base_size=1024)
    testset = SegmentationDataset(f'/tmp/val_new_pneumo_{num_fold}.csv',
                                  args.image_path,
                                  args.masks_path,
                                  input_transform=input_transform,
                                  transform_chain=val_chain,
                                  base_size=1024)

    imgs = trainset.mask_img_map[:, [0, 3]]
    weights = make_weights_for_balanced_classes(imgs, 2)
    weights = torch.DoubleTensor(weights)
    train_sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))

    # dataloader
    kwargs = {'num_workers': args.workers, 'pin_memory': True}
    self.trainloader = data.DataLoader(trainset,
                                       batch_size=args.batch_size,
                                       drop_last=True,
                                       sampler=train_sampler,
                                       # shuffle=True,
                                       **kwargs)
    self.valloader = data.DataLoader(testset, batch_size=args.batch_size,
                                     drop_last=False, shuffle=False, **kwargs)
    self.nclass = 1

    if self.args.model == 'unet':
        model = UNet(n_classes=self.nclass, norm_layer=SyncBatchNorm)
        params_list = [{'params': model.parameters(), 'lr': args.lr}]
    elif self.args.model == 'encnet':
        model = EncNet(
            nclass=self.nclass,
            backbone=args.backbone,
            aux=args.aux,
            se_loss=args.se_loss,
            norm_layer=SyncBatchNorm  # nn.BatchNorm2d
        )
        # optimizer using different LR
        params_list = [{'params': model.pretrained.parameters(), 'lr': args.lr}]
        if hasattr(model, 'head'):
            params_list.append({'params': model.head.parameters(), 'lr': args.lr * 10})
        if hasattr(model, 'auxlayer'):
            params_list.append({'params': model.auxlayer.parameters(), 'lr': args.lr * 10})
    print(model)

    optimizer = torch.optim.SGD(params_list, lr=args.lr,
                                momentum=0.9, weight_decay=args.wd)
    # criterions
    if self.nclass == 1:
        self.criterion = SegmentationLossesBCE(se_loss=args.se_loss,
                                               aux=args.aux,
                                               nclass=self.nclass,
                                               se_weight=args.se_weight,
                                               aux_weight=args.aux_weight,
                                               use_dice=args.use_dice)
    else:
        self.criterion = SegmentationLosses(
            se_loss=args.se_loss,
            aux=args.aux,
            nclass=self.nclass,
            se_weight=args.se_weight,
            aux_weight=args.aux_weight,
        )

    self.model, self.optimizer = model, optimizer
    self.best_pred = 0.0
    self.model = DataParallelModel(self.model).cuda()
    self.criterion = DataParallelCriterion(self.criterion).cuda()

    # resuming checkpoint
    if args.resume is not None:
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)  # , map_location='cpu'
        self.args.start_epoch = checkpoint['epoch']
        state_dict = {k: v for k, v in checkpoint['state_dict'].items()}
        self.model.load_state_dict(state_dict)
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        for g in self.optimizer.param_groups:
            g['lr'] = args.lr
        self.best_pred = checkpoint['best_pred']
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))
        print(f'Best dice: {checkpoint["best_pred"]}')
        print(f'LR: {get_lr(self.optimizer):.5f}')

    self.scheduler = ReduceLROnPlateau(self.optimizer,
                                       mode='min',
                                       factor=0.8,
                                       patience=4,
                                       threshold=0.001,
                                       threshold_mode='abs',
                                       min_lr=0.00001)
    self.logger = Logger(args.logger_dir)
    self.step_train = 0
    self.best_loss = 20
    self.step_val = 0
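# make_weights_for_balanced_classes is used above to build the WeightedRandomSampler but
# is not defined in this snippet. A minimal sketch under the assumption that `images` is
# a sequence of (item, class_index) pairs, following the common torchvision recipe:
def make_weights_for_balanced_classes(images, nclasses):
    """Give each sample a weight inversely proportional to its class frequency."""
    count = [0] * nclasses
    for _, lbl in images:
        count[int(lbl)] += 1
    weight_per_class = [float(sum(count)) / max(c, 1) for c in count]
    return [weight_per_class[int(lbl)] for _, lbl in images]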