Code Example #1
# Imports and the enclosing class were missing from this snippet; the names
# below assume the DANet / PyTorch-Encoding style codebase it comes from.
import os
import time
import torch
from torch.utils import data
import torchvision.transforms as transform
from tensorboardX import SummaryWriter

import encoding.utils as utils
from encoding.nn import SegmentationLosses, SegmentationMultiLosses
from encoding.parallel import DataParallelModel, DataParallelCriterion
from encoding.datasets import get_segmentation_dataset
from encoding.models import get_segmentation_model


class Trainer():
    def __init__(self, args):
        self.args = args
        args.log_name = str(args.checkname)
        self.logger = utils.create_logger(args.log_root, args.log_name)
        self.writer = SummaryWriter(log_dir=os.path.join(
            args.log_root, args.log_name,
            time.strftime("%Y-%m-%d-%H-%M", time.localtime())))
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            # transform.Normalize([.485, .456, .406], [.229, .224, .225])
            transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
        # dataset
        data_kwargs = {'transform': input_transform, 'base_size': args.base_size,
                       'crop_size': args.crop_size, 'logger': self.logger,
                       'scale': args.scale}
        trainset = get_segmentation_dataset(args.dataset, split='train', mode='train', **data_kwargs)
        testset = get_segmentation_dataset(args.dataset, split='val', mode='val', **data_kwargs)
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True} if args.cuda else {}
        self.trainloader = data.DataLoader(trainset, batch_size=args.batch_size, drop_last=True, shuffle=True, **kwargs)
        # self.valloader = data.DataLoader(testset, batch_size=args.batch_size, drop_last=False, shuffle=False, **kwargs)
        self.valloader = data.DataLoader(testset, batch_size=1, drop_last=False, shuffle=False, **kwargs)
        self.nclass = trainset.num_class

        # model
        model = get_segmentation_model(args.model, dataset=args.dataset, backbone=args.backbone, aux=args.aux, se_loss=args.se_loss,
                                       dilated=args.dilated,
                                       # norm_layer=BatchNorm2d, # for multi-gpu
                                       base_size=args.base_size, crop_size=args.crop_size, multi_grid=args.multi_grid, multi_dilation=args.multi_dilation)

        #####################################################################
        self.logger.info(model)
        # optimizer with differential learning rates:
        # the pretrained backbone trains at the base LR, head and aux layers at 10x
        params_list = [{'params': model.pretrained.parameters(), 'lr': args.lr}]
        if hasattr(model, 'head'):
            params_list.append({'params': model.head.parameters(), 'lr': args.lr * 10})
        if hasattr(model, 'auxlayer'):
            params_list.append({'params': model.auxlayer.parameters(), 'lr': args.lr * 10})
        optimizer = torch.optim.SGD(params_list, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

        if args.model == 'danet':
            self.criterion = SegmentationMultiLosses(nclass=self.nclass)
        elif args.model == 'fcn':
            self.criterion = SegmentationLosses(se_loss=args.se_loss, aux=args.aux, nclass=self.nclass)
        else:
            self.criterion = torch.nn.CrossEntropyLoss()
        #####################################################################

        self.model, self.optimizer = model, optimizer
        # using cuda
        if args.cuda:
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
        # finetune from a trained model
        if args.ft:
            args.start_epoch = 0
            checkpoint = torch.load(args.ft_resume)
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'], strict=False)
            else:
                self.model.load_state_dict(checkpoint['state_dict'], strict=False)
            self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.ft_resume, checkpoint['epoch']))
        self.best_pred = 0.0  # default; overwritten below when resuming
        # resuming checkpoint
        if args.resume:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        # lr scheduler
        # (self.best_pred is initialized before the resume block above, so a
        # resumed checkpoint's best score is no longer clobbered here)
        self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                            len(self.trainloader), logger=self.logger,
                                            lr_step=args.lr_step)

        self.logger.info(self.args)
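A minimal driver sketch for this Trainer, modeled on the DANet-style training scripts this __init__ appears to come from; Options, args.seed, args.no_val, and the training/validation methods are assumptions about the surrounding project, not shown above.

if __name__ == "__main__":
    # Hypothetical entry point: Options (the project's CLI parser), args.seed,
    # args.no_val, and Trainer.training/.validation are assumed to exist
    # elsewhere in the codebase.
    args = Options().parse()
    torch.manual_seed(args.seed)
    trainer = Trainer(args)
    for epoch in range(args.start_epoch, args.epochs):
        trainer.training(epoch)        # assumed per-epoch training loop
        if not args.no_val:
            trainer.validation(epoch)  # assumed per-epoch validation pass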
Code Example #2
opt_seg = Options()  # .parse() skipped here; the fields are set manually below
opt_seg.dataset = "bdd100k_seg"
opt_seg.model = "fcn"
opt_seg.backbone = "resnet50"
opt_seg.dilated = True
opt_seg.ft = True
# opt_seg.ft_resume = "/home/chenwy/DynamicLightEnlighten/bdd/bdd100k_seg/fcn_model/res50_di_360px_daytime/model_best.pth.tar"
opt_seg.ft_resume = "/home/chenwy/DynamicLightEnlighten/bdd/bdd100k_seg/fcn_model/res50_di_180px_L100.255/model_best.pth.tar"
opt_seg.eval = True

seg = get_segmentation_model(
    opt_seg.model,
    dataset=opt_seg.dataset,
    backbone=opt_seg.backbone,
    aux=False,
    se_loss=False,
    dilated=opt_seg.dilated,
    # norm_layer=BatchNorm2d, # for multi-gpu
    base_size=720,
    crop_size=180,
    multi_grid=False,
    multi_dilation=False)
seg = DataParallelModel(seg).cuda()
seg.eval()
if opt_seg.ft:
    checkpoint = torch.load(opt_seg.ft_resume)
    seg.module.load_state_dict(checkpoint['state_dict'], strict=False)
    # self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.ft_resume, checkpoint['epoch']))
####################################################
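
A hedged inference sketch for the model loaded above. "example.jpg" is a placeholder path, and the normalization mirrors the training transform in Code Example #1; confirm both against what the checkpoint was actually trained with.

from PIL import Image
import torch
import torchvision.transforms as transform

input_transform = transform.Compose([
    transform.ToTensor(),
    transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
img = input_transform(Image.open("example.jpg").convert("RGB")).unsqueeze(0).cuda()
with torch.no_grad():
    outputs = seg(img)
# On multiple GPUs DataParallelModel returns a list with one entry per GPU;
# on a single GPU the module's raw output comes back directly. FCN-style heads
# in this codebase return a tuple whose first element is the main logits.
if isinstance(outputs, list):
    outputs = outputs[0]
logits = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
pred = logits.argmax(dim=1)  # (1, H, W) map of per-pixel class indices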

opt = TestOptions().parse()
opt.nThreads = 1  # test code only supports nThreads = 1
Code Example #3
opt_seg = Options().parse()
opt_seg.dataset = "cityscapes"
opt_seg.model = "fcn"
opt_seg.backbone = "resnet50"
opt_seg.dilated = True
opt_seg.ft = True
opt_seg.ft_resume = "/home/chenwy/DynamicDeHaze/cityscapes/fcn_model/res50_di_mg_500px/model_best.pth.tar"
opt_seg.eval = True

seg = get_segmentation_model(
    opt_seg.model,
    dataset=opt_seg.dataset,
    backbone=opt_seg.backbone,
    aux=opt_seg.aux,
    se_loss=opt_seg.se_loss,
    dilated=opt_seg.dilated,
    # norm_layer=BatchNorm2d, # for multi-gpu
    base_size=720,
    crop_size=500,
    multi_grid=True,
    multi_dilation=[2, 4, 8])
# seg = DataParallelModel(seg).cuda()
seg = torch.nn.DataParallel(seg).cuda()
seg.eval()
if opt_seg.ft:
    checkpoint = torch.load(opt_seg.ft_resume)
    seg.module.load_state_dict(checkpoint['state_dict'], strict=False)
    # self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.ft_resume, checkpoint['epoch']))
####################################################
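
For comparison with Code Example #2, a hedged inference sketch for this torch.nn.DataParallel-wrapped model; img is assumed to be a normalized (1, 3, H, W) CUDA tensor prepared as in the sketch above.

import torch

with torch.no_grad():
    output = seg(img)  # nn.DataParallel gathers per-GPU results into one output
# FCN-style heads may return a tuple whose first element is the main logits.
logits = output[0] if isinstance(output, (list, tuple)) else output
pred = logits.argmax(dim=1)  # per-pixel class indices (19 train classes on Cityscapes)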

##### Policy #################################################