def __init__(self, args):
    """Build all state needed for a training run from the parsed options.

    Creates the logger and summary writer, the train/val datasets and
    loaders, the segmentation model, an SGD optimizer with per-module
    learning rates, the loss criterion, then optionally loads fine-tune
    and/or resume checkpoints and finally builds the LR scheduler.

    Args:
        args: Option namespace. Reads ``checkname``, ``log_root``,
            ``base_size``, ``crop_size``, ``scale``, ``dataset``,
            ``workers``, ``cuda``, ``batch_size``, ``model``, ``backbone``,
            ``aux``, ``se_loss``, ``dilated``, ``multi_grid``,
            ``multi_dilation``, ``lr``, ``momentum``, ``weight_decay``,
            ``ft``, ``ft_resume``, ``resume``, ``lr_scheduler``, ``epochs``
            and ``lr_step``. Writes ``log_name`` and, when a checkpoint is
            loaded, ``start_epoch``.

    Raises:
        RuntimeError: If ``args.resume`` is set but does not point to an
            existing file.
    """
    self.args = args
    args.log_name = str(args.checkname)
    self.logger = utils.create_logger(args.log_root, args.log_name)
    # One summary directory per launch, named after the start time.
    run_stamp = time.strftime("%Y-%m-%d-%H-%M", time.localtime())
    self.writer = SummaryWriter(
        log_dir=os.path.join(args.log_root, args.log_name, run_stamp))

    # data transforms: per-channel mean 0.5 / std 0.5 normalisation
    # (the [.485, .456, .406] / [.229, .224, .225] variant is disabled).
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # dataset
    data_kwargs = {
        'transform': input_transform,
        'base_size': args.base_size,
        'crop_size': args.crop_size,
        'logger': self.logger,
        'scale': args.scale,
    }
    trainset = get_segmentation_dataset(
        args.dataset, split='train', mode='train', **data_kwargs)
    testset = get_segmentation_dataset(
        args.dataset, split='val', mode='val', **data_kwargs)

    # dataloader
    kwargs = {'num_workers': args.workers, 'pin_memory': True} if args.cuda else {}
    self.trainloader = data.DataLoader(
        trainset, batch_size=args.batch_size, drop_last=True, shuffle=True,
        **kwargs)
    # Validation always runs with batch size 1, not args.batch_size.
    self.valloader = data.DataLoader(
        testset, batch_size=1, drop_last=False, shuffle=False, **kwargs)
    self.nclass = trainset.num_class

    # model
    model = get_segmentation_model(
        args.model,
        dataset=args.dataset,
        backbone=args.backbone,
        aux=args.aux,
        se_loss=args.se_loss,
        dilated=args.dilated,
        # norm_layer=BatchNorm2d,  # for multi-gpu
        base_size=args.base_size,
        crop_size=args.crop_size,
        multi_grid=args.multi_grid,
        multi_dilation=args.multi_dilation)
    self.logger.info(model)

    # optimizer using different LR: the pretrained backbone trains at the
    # base rate, while the head and auxiliary layer (when the model has
    # them) train at 10x the base rate.
    params_list = [{'params': model.pretrained.parameters(), 'lr': args.lr}]
    if hasattr(model, 'head'):
        params_list.append(
            {'params': model.head.parameters(), 'lr': args.lr * 10})
    if hasattr(model, 'auxlayer'):
        params_list.append(
            {'params': model.auxlayer.parameters(), 'lr': args.lr * 10})
    optimizer = torch.optim.SGD(
        params_list, lr=args.lr, momentum=args.momentum,
        weight_decay=args.weight_decay)

    # The loss criterion depends on the model family.
    if args.model == 'danet':
        self.criterion = SegmentationMultiLosses(nclass=self.nclass)
    elif args.model == 'fcn':
        self.criterion = SegmentationLosses(
            se_loss=args.se_loss, aux=args.aux, nclass=self.nclass)
    else:
        self.criterion = torch.nn.CrossEntropyLoss()

    self.model, self.optimizer = model, optimizer

    # using cuda: after wrapping, the bare network is at self.model.module.
    if args.cuda:
        self.model = DataParallelModel(self.model).cuda()
        self.criterion = DataParallelCriterion(self.criterion).cuda()

    # Default best score. This must be set BEFORE the resume branch below;
    # it used to be assigned at the very end of __init__, which silently
    # threw away the best_pred restored from a resumed checkpoint.
    self.best_pred = 0.0

    # finetune from a trained model (non-strict: mismatched keys are ignored)
    if args.ft:
        args.start_epoch = 0
        checkpoint = torch.load(args.ft_resume)
        if args.cuda:
            self.model.module.load_state_dict(checkpoint['state_dict'], strict=False)
        else:
            self.model.load_state_dict(checkpoint['state_dict'], strict=False)
        self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            args.ft_resume, checkpoint['epoch']))

    # resuming checkpoint
    # NOTE(review): when both args.ft and args.resume are set, this branch
    # overwrites the start_epoch and the weights loaded by the fine-tune
    # branch above — confirm that combination is intended.
    if args.resume:
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        if args.cuda:
            self.model.module.load_state_dict(checkpoint['state_dict'])
        else:
            self.model.load_state_dict(checkpoint['state_dict'])
        # Optimizer state and best score are only restored for a plain
        # resume, not when fine-tuning.
        if not args.ft:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
        self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))

    # lr scheduler
    self.scheduler = utils.LR_Scheduler(
        args.lr_scheduler, args.lr, args.epochs, len(self.trainloader),
        logger=self.logger, lr_step=args.lr_step)
    self.logger.info(self.args)
# ---------------------------------------------------------------------------
# Segmentation network setup: FCN with a dilated ResNet-50 backbone,
# configured for the "bdd100k_seg" dataset.
# NOTE: Options.parse() is not called here; the fields are assigned directly
# on the Options instance.
# ---------------------------------------------------------------------------
opt_seg = Options()
opt_seg.dataset = "bdd100k_seg"
opt_seg.model = "fcn"
opt_seg.backbone = "resnet50"
opt_seg.dilated = True
opt_seg.ft = True
# Alternative checkpoint (currently disabled):
# "/home/chenwy/DynamicLightEnlighten/bdd/bdd100k_seg/fcn_model/res50_di_360px_daytime/model_best.pth.tar"
opt_seg.ft_resume = "/home/chenwy/DynamicLightEnlighten/bdd/bdd100k_seg/fcn_model/res50_di_180px_L100.255/model_best.pth.tar"
opt_seg.eval = True

# aux / se_loss heads and multi-grid are switched off for this model.
seg = get_segmentation_model(
    opt_seg.model,
    dataset=opt_seg.dataset,
    backbone=opt_seg.backbone,
    aux=False,
    se_loss=False,
    dilated=opt_seg.dilated,
    # norm_layer=BatchNorm2d,  # for multi-gpu
    base_size=720,
    crop_size=180,
    multi_grid=False,
    multi_dilation=False,
)
# Wrap for data-parallel execution on GPU and switch to evaluation mode.
seg = DataParallelModel(seg).cuda()
seg.eval()

# Load the trained weights into the wrapped network; strict=False ignores
# keys that do not match between checkpoint and model.
if opt_seg.ft:
    checkpoint = torch.load(opt_seg.ft_resume)
    seg.module.load_state_dict(checkpoint['state_dict'], strict=False)

# ---------------------------------------------------------------------------
# Test options
# ---------------------------------------------------------------------------
opt = TestOptions().parse()
opt.nThreads = 1  # test code only supports nThreads = 1
# ---------------------------------------------------------------------------
# Segmentation network setup: FCN with a dilated ResNet-50 backbone,
# configured for the "cityscapes" dataset. Options are parsed first and the
# fields below are then overridden in code.
# ---------------------------------------------------------------------------
opt_seg = Options().parse()
opt_seg.dataset = "cityscapes"
opt_seg.model = "fcn"
opt_seg.backbone = "resnet50"
opt_seg.dilated = True
opt_seg.ft = True
opt_seg.ft_resume = "/home/chenwy/DynamicDeHaze/cityscapes/fcn_model/res50_di_mg_500px/model_best.pth.tar"
opt_seg.eval = True

# aux / se_loss come from the parsed options; multi-grid is enabled with
# dilation rates [2, 4, 8].
seg = get_segmentation_model(
    opt_seg.model,
    dataset=opt_seg.dataset,
    backbone=opt_seg.backbone,
    aux=opt_seg.aux,
    se_loss=opt_seg.se_loss,
    dilated=opt_seg.dilated,
    # norm_layer=BatchNorm2d,  # for multi-gpu
    base_size=720,
    crop_size=500,
    multi_grid=True,
    multi_dilation=[2, 4, 8],
)
# torch.nn.DataParallel is used here (the DataParallelModel wrapper is
# disabled); move to GPU and switch to evaluation mode.
seg = torch.nn.DataParallel(seg).cuda()
seg.eval()

# Load the trained weights into the wrapped network; strict=False ignores
# keys that do not match between checkpoint and model.
if opt_seg.ft:
    checkpoint = torch.load(opt_seg.ft_resume)
    seg.module.load_state_dict(checkpoint['state_dict'], strict=False)

####################################################
##### Policy #################################################