def __init__(self):
    """Build the deblur -> super-resolution -> gate -> reconstruct network.

    Attribute names keep their historical spellings ("Moudle", "gete").
    Assuming Net is a torch.nn.Module, these names become the state_dict key
    prefixes, so renaming them would break existing checkpoints. Submodules
    are created in the original order, because randomly initialised
    parameters may depend on the order in which the RNG is consumed.
    """
    super(Net, self).__init__()
    build = self._make_net
    self.deblurMoudle = build(_DeblurringMoudle)
    # Backbones tried earlier and left disabled: SRNDeblurNet() for the
    # deblur stage; self._make_net(_SRMoudle) or
    # RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, norm_type=None)
    # for the SR stage.
    self.srMoudle = MSRN()
    self.geteMoudle = build(_GateMoudle)
    self.reconstructMoudle = build(_ReconstructMoudle)
# Training setup: parse options, seed the RNGs, build or resume the model,
# then create the loss, perceptual feature extractor and optimizer.
opt = parser.parse_args()
# A fresh seed is drawn for every run and stored on opt for later inspection.
opt.seed = random.randint(1, 1200)
torch.manual_seed(opt.seed)
torch.cuda.manual_seed(opt.seed)

if opt.resume:
    if not os.path.isfile(opt.resume):
        # Without this check, `model` stayed unbound and the script failed
        # later with an unrelated NameError.
        raise FileNotFoundError("No checkpoint found at {}".format(opt.resume))
    print("Loading from checkpoint {}".format(opt.resume))
    # NOTE(review): torch.load unpickles a whole model object, so only load
    # trusted checkpoints. On PyTorch >= 2.6 it defaults to weights_only=True,
    # which rejects full-model pickles; pass weights_only=False there if needed.
    model = torch.load(opt.resume)
    # The former `model.load_state_dict(model.state_dict())` reloaded the
    # model's own weights into itself. It was a no-op and has been removed.
    opt.start_training_step, opt.start_epoch = which_trainingstep_epoch(
        opt.resume)
else:
    model = MSRN()
model = model.to(device)

# `size_average=True` is deprecated; `reduction="mean"` is the equivalent.
criterion = torch.nn.L1Loss(reduction="mean").to(device)
cri_perception = VGGFeatureExtractor().to(device)
# Only parameters that require gradients are handed to the optimizer.
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()), 0.0001)
print('# GFN_deblur parameters:', sum(param.numel() for param in model.parameters()))
print()

# NOTE(review): these assignments override the epoch restored from a
# checkpoint above and any nEpochs given on the command line. Kept as-is;
# confirm this is intended before relying on --resume.
opt.start_epoch = 1
opt.nEpochs = 1000
for epoch in range(opt.start_epoch, opt.nEpochs + 1):
    trainloader = CreateDataLoader(opt)
def __init__(self):
    """Build the deblur -> super-resolution -> reconstruct network (no gate).

    The deblurring stage is instantiated directly, while the reconstruction
    stage goes through ``self._make_net``. This SOURCE does not show whether
    ``_make_net`` wraps its block, which would affect state_dict key names,
    so both construction paths and all attribute names are kept exactly as
    they were.
    """
    super(Net, self).__init__()
    self.deblurMoudle = _DeblurringMoudle()
    self.srMoudle = MSRN()
    reconstruct = self._make_net(_ReconstructMoudle)
    self.reconstructMoudle = reconstruct