import torch


def construct_model(model_path, device, nframes):
    from rbpn import Net as RBPN
    model = RBPN(num_channels=3, base_filter=256, feat=64, num_stages=3,
                 n_resblock=5, nFrames=nframes, scale_factor=4)
    # Load the checkpoint onto the requested device instead of a hard-coded GPU,
    # so CPU-only machines can still construct the model.
    ckpt = torch.load(model_path, map_location=device)
    # Strip the 'module.' prefix that DataParallel adds to parameter names.
    new_ckpt = {}
    for key in ckpt:
        if key.startswith('module'):
            new_key = key[7:]
        else:
            new_key = key
        new_ckpt[new_key] = ckpt[key]
    model = model.to(device)
    model.load_state_dict(new_ckpt)
    model.eval()
    return model
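# Example usage of construct_model (a minimal sketch; the frame count is an
# illustrative placeholder, and the weight file is assumed to follow the
# 'weights/RBPN_<scale>x.pth' naming used later in this repo):
if __name__ == '__main__':
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = construct_model('weights/RBPN_4x.pth', device, nframes=7)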
model = RBPN(num_channels=3, base_filter=256, feat=64, num_stages=3,
             n_resblock=5, nFrames=opt.nFrames, scale_factor=opt.upscale_factor)
model = torch.nn.DataParallel(model, device_ids=gpus_list)
criterion = nn.L1Loss()

print('---------- Networks architecture -------------')
print_network(model)
print('----------------------------------------------')

if opt.pretrained:
    model_name = os.path.join(opt.save_folder, opt.pretrained_sr)
    if os.path.exists(model_name):
        # model = torch.load(model_name, map_location=lambda storage, loc: storage)
        model.load_state_dict(
            torch.load(model_name, map_location=lambda storage, loc: storage))
        print('Pre-trained SR model is loaded.')

if cuda:
    model = model.cuda(gpus_list[0])
    criterion = criterion.cuda(gpus_list[0])

optimizer = optim.Adam(model.parameters(), lr=opt.lr, betas=(0.9, 0.999), eps=1e-8)

for epoch in range(opt.start_epoch, opt.nEpochs + 1):
    train(epoch)
    # test()
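# Sketch of the train() step invoked above (a hypothetical reconstruction; in
# the original script the function is defined before the epoch loop, and the
# real version likely also logs per-iteration loss and saves checkpoints):
def train(epoch):
    model.train()
    epoch_loss = 0.0
    for iteration, batch in enumerate(training_data_loader, 1):
        input, target, neigbor, flow, bicubic = batch[0], batch[1], batch[2], batch[3], batch[4]
        if cuda:
            input = input.cuda(gpus_list[0])
            target = target.cuda(gpus_list[0])
            neigbor = [j.cuda(gpus_list[0]) for j in neigbor]
            flow = [j.cuda(gpus_list[0]).float() for j in flow]
        optimizer.zero_grad()
        prediction = model(input, neigbor, flow)
        loss = criterion(prediction, target)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print('===> Epoch {}: avg loss {:.4f}'.format(epoch, epoch_loss / len(training_data_loader)))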
testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads,
                                 batch_size=opt.testBatchSize, shuffle=False)

print('===> Building model ', opt.model_type)
if opt.model_type == 'RBPN':
    model = RBPN(num_channels=3, base_filter=256, feat=64, num_stages=3,
                 n_resblock=5, nFrames=opt.nFrames, scale_factor=opt.upscale_factor)

if cuda:
    model = torch.nn.DataParallel(model, device_ids=gpus_list)

model.load_state_dict(
    torch.load(opt.model, map_location=lambda storage, loc: storage))
print('Pre-trained SR model is loaded.')

if cuda:
    model = model.cuda(gpus_list[0])


def eval():
    model.eval()
    count = 1
    avg_psnr_predicted = 0.0
    for batch in testing_data_loader:
        input, target, neigbor, flow, bicubic = batch[0], batch[1], batch[2], batch[3], batch[4]
        with torch.no_grad():
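            # The original snippet is truncated at this point; the lines below
            # are a plausible continuation modeled on the standard RBPN eval
            # loop (the device moves and the residual step are assumptions):
            input = input.cuda(gpus_list[0])
            bicubic = bicubic.cuda(gpus_list[0])
            neigbor = [j.cuda(gpus_list[0]) for j in neigbor]
            flow = [j.cuda(gpus_list[0]).float() for j in flow]
            prediction = model(input, neigbor, flow)
            if opt.residual:
                prediction = prediction + bicubic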
class Net(nn.Module):
    def __init__(self, base_filter, feat, num_stages, n_resblock, scale_factor,
                 pretrained=True, freeze=False):
        super(Net, self).__init__()
        # Deconvolution geometry per upscaling factor.
        if scale_factor == 2:
            kernel = 6
            stride = 2
            padding = 2
        elif scale_factor == 4:
            kernel = 8
            stride = 4
            padding = 2
        elif scale_factor == 8:
            kernel = 12
            stride = 8
            padding = 2

        # Initial feature extraction: 4 input channels from the concatenated
        # forward and backward 2-channel optical flows.
        self.motion_feat = ConvBlock(4, base_filter, 3, 1, 1, activation='lrelu', norm=None)

        ### INTERPOLATION
        # Interp_block
        motion_net = [ResnetBlock(base_filter, kernel_size=3, stride=1, padding=1,
                                  bias=True, activation='lrelu', norm=None)
                      for _ in range(2)]
        motion_net.append(ConvBlock(base_filter, feat, 3, 1, 1, activation='lrelu', norm=None))
        self.motion = nn.Sequential(*motion_net)

        t_net2 = [ConvBlock(feat * 3, feat, 1, 1, 0, bias=True, activation='lrelu', norm=None)]
        t_net2.append(PyramidModule(feat, activation='lrelu'))
        t_net2.append(ConvBlock(feat, feat, 3, 1, 1, activation='lrelu', norm=None))
        self.t_net_hr = nn.Sequential(*t_net2)

        self.upsample_layer = nn.Upsample(scale_factor=scale_factor, mode='bilinear',
                                          align_corners=True)

        interp_b = [ResnetBlock(feat * 5, kernel_size=3, stride=1, padding=1,
                                bias=True, activation='lrelu', norm=None)
                    for _ in range(n_resblock)]
        interp_b.append(DeconvBlock(feat * 5, feat, kernel, stride, padding,
                                    activation='lrelu', norm=None))
        self.interp_block = nn.Sequential(*interp_b)

        ### ITERATIVE REFINEMENT
        # Motion up, forward direction
        modules_up_f = [ResnetBlock(feat * 5, kernel_size=3, stride=1, padding=1,
                                    bias=True, activation='lrelu', norm=None)
                        for _ in range(n_resblock)]
        modules_up_f.append(DeconvBlock(feat * 5, feat, kernel, stride, padding,
                                        activation='lrelu', norm=None))
        self.motion_up_f = nn.Sequential(*modules_up_f)

        # Motion up, backward direction
        modules_up_b = [ResnetBlock(feat * 5, kernel_size=3, stride=1, padding=1,
                                    bias=True, activation='lrelu', norm=None)
                        for _ in range(n_resblock)]
        modules_up_b.append(DeconvBlock(feat * 5, feat, kernel, stride, padding,
                                        activation='lrelu', norm=None))
        self.motion_up_b = nn.Sequential(*modules_up_b)

        # Motion down
        modules_down = [ResnetBlock(feat, kernel_size=3, stride=1, padding=1,
                                    bias=True, activation='lrelu', norm=None)
                        for _ in range(2)]
        modules_down.append(ConvBlock(feat, feat * 2, kernel, stride, padding,
                                      activation='lrelu', norm=None))
        self.motion_down = nn.Sequential(*modules_down)

        self.relu_bp = torch.nn.LeakyReLU(negative_slope=0.1, inplace=True)  # alternative: torch.nn.PReLU()

        # Reconstruction
        self.reconstruction_l = ConvBlock(feat * 2, 3, 3, 1, 1, activation=None, norm=None)
        self.reconstruction_h = ConvBlock(feat, 3, 3, 1, 1, activation=None, norm=None)

        ### ALIGNMENT
        ### RBPN
        self.RBPN = RBPN(num_channels=3, base_filter=base_filter, feat=feat,
                         num_stages=num_stages, n_resblock=5, nFrames=2,
                         scale_factor=scale_factor)

        # Kaiming initialization for all conv / deconv layers.
        for m in self.modules():
            classname = m.__class__.__name__
            if classname.find('Conv2d') != -1:
                torch.nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif classname.find('ConvTranspose2d') != -1:
                torch.nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()

        if pretrained:
            if scale_factor == 4:
                self.RBPN.load_state_dict(
                    torch.load("weights/pretrained/rbpn_pretrained_F2_4x.pth",
                               map_location=lambda storage, loc: storage))
            elif scale_factor == 2:
                self.RBPN.load_state_dict(
                    torch.load("weights/pretrained/rbpn_pretrained_F2_2x.pth",
                               map_location=lambda storage, loc: storage))
        if freeze:
            self.freeze_model(self.RBPN)

    def freeze_model(self, model):
        for child in model.children():
            for param in child.parameters():
                param.requires_grad = False
    def forward(self, t_im1, t_im2, t_flow_f, t_flow_b, train=True):
        result_l = []
        result_h1 = []
        result_ht = []
        result_h2 = []

        ### ALIGNMENT
        aux_H1, H1 = self.RBPN(t_im1, [t_im2], [t_flow_f])
        aux_H2, H2 = self.RBPN(t_im2, [t_im1], [t_flow_b])
        L1 = self.motion_down(H1)
        L2 = self.motion_down(H2)

        ### MOTION & DEPTH
        motion_feat0 = self.motion_feat(torch.cat((t_flow_f, t_flow_b), 1))
        M = self.motion(motion_feat0)
        motion_feat1 = self.motion_feat(torch.cat((t_flow_f / 2.0, t_flow_b / 2.0), 1))
        M_half = self.motion(motion_feat1)

        ### INTERPOLATION
        Ht = self.interp_block(torch.cat((L1, L2, M), 1))
        Ht = Ht + self.relu_bp(Ht - self.t_net_hr(torch.cat((H1, H2, self.upsample_layer(M)), 1)))
        L = self.motion_down(Ht)

        aux_Ht = self.reconstruction_h(Ht)
        aux_L = self.reconstruction_l(L)
        result_l.append(aux_L)
        result_h1.append(aux_H1)
        result_ht.append(aux_Ht)
        result_h2.append(aux_H2)

        #### Projection
        backward1 = torch.cat((L1, L, M_half), 1)
        H_b = self.motion_up_b(backward1)
        H1 = H1 + self.relu_bp(H1 - H_b)
        L1 = L1 + self.relu_bp(L1 - self.motion_down(H_b))

        forwardd2 = torch.cat((L, L2, M_half), 1)
        H_f = self.motion_up_f(forwardd2)
        H2 = H2 + self.relu_bp(H2 - H_f)
        L2 = L2 + self.relu_bp(L2 - self.motion_down(H_f))

        forwardd = torch.cat((L1, L, M_half), 1)
        H_t_f = self.motion_up_f(forwardd)
        Ht = Ht + self.relu_bp(Ht - H_t_f)
        L = L + self.relu_bp(L - self.motion_down(H_t_f))

        backward = torch.cat((L, L2, M_half), 1)
        H_t_b = self.motion_up_b(backward)
        Ht = Ht + self.relu_bp(Ht - H_t_b)
        L = L + self.relu_bp(L - self.motion_down(H_t_b))

        output_ht = self.reconstruction_h(Ht)
        output_h1 = self.reconstruction_h(H1)
        output_h2 = self.reconstruction_h(H2)
        output_l = self.reconstruction_l(L)
        result_l.append(output_l)
        result_h1.append(output_h1)
        result_ht.append(output_ht)
        result_h2.append(output_h2)

        if train:
            return result_ht, result_h1, result_h2, result_l
        else:
            return output_ht, output_h1, output_h2, output_l
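# Smoke test for Net.forward (a minimal sketch, assuming the repository's
# ConvBlock/ResnetBlock/PyramidModule/DeconvBlock and its two-output RBPN
# variant are importable; the 64x64 frame size and batch size 1 are arbitrary,
# and pretrained=False avoids needing the RBPN weight files):
if __name__ == '__main__':
    net = Net(base_filter=256, feat=64, num_stages=3, n_resblock=5,
              scale_factor=4, pretrained=False)
    net.eval()
    t_im1 = torch.randn(1, 3, 64, 64)     # frame t
    t_im2 = torch.randn(1, 3, 64, 64)     # frame t+1
    t_flow_f = torch.randn(1, 2, 64, 64)  # forward optical flow
    t_flow_b = torch.randn(1, 2, 64, 64)  # backward optical flow
    with torch.no_grad():
        out_ht, out_h1, out_h2, out_l = net(t_im1, t_im2, t_flow_f, t_flow_b, train=False)
    print(out_ht.shape)  # expected: torch.Size([1, 3, 256, 256]) for 4x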
model = RBPN(num_channels=3, base_filter=256, feat=64, num_stages=3,
             n_resblock=5, nFrames=args.nframes, scale_factor=4)
model = nn.DataParallel(model.to(device), gpuids)

if args.resume:
    ckpt = torch.load(args.model_path)
    # Ensure every key carries the 'module.' prefix expected by DataParallel.
    new_ckpt = {}
    for key in ckpt:
        if not key.startswith('module'):
            new_key = 'module.' + key
        else:
            new_key = key
        new_ckpt[new_key] = ckpt[key]
    model.load_state_dict(new_ckpt, strict=False)
print("model constructed")

# for key, value in model.named_parameters():
#     if not ('pre_deblur' in key):
#         value.requires_grad = False

summary_writer = SummaryWriter(args.log_dir)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
scheduler = ExponentialLR(optimizer, gamma=args.gamma)

train_model(model, optimizer, scheduler, dataloaders, summary_writer, device, args)
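# Sketch of the train_model() called above (entirely hypothetical; the real
# routine presumably also validates and checkpoints, and the batch layout and
# args.epochs attribute below are assumptions, not the repository's code):
def train_model(model, optimizer, scheduler, dataloaders, summary_writer, device, args):
    criterion = nn.L1Loss()
    step = 0
    for epoch in range(args.epochs):
        model.train()
        for input, target, neigbor, flow, _ in dataloaders['train']:
            input, target = input.to(device), target.to(device)
            neigbor = [n.to(device) for n in neigbor]
            flow = [f.to(device).float() for f in flow]
            optimizer.zero_grad()
            prediction = model(input, neigbor, flow)
            loss = criterion(prediction, target)
            loss.backward()
            optimizer.step()
            summary_writer.add_scalar('train/l1_loss', loss.item(), step)
            step += 1
        scheduler.step()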
model = torch.nn.DataParallel(model, device_ids=gpus_list)

model_name = opt.model
if not model_name:
    # Loading the NTIRE2019 model doesn't actually work (size mismatch),
    # so that branch is disabled here.
    if False and opt.upscale_factor == 4:
        model_name = 'weights/RBPN_4x_F11_NTIRE2019.pth'
    else:
        model_name = 'weights/RBPN_' + str(opt.upscale_factor) + 'x.pth'

print('===> Using pretrained model', model_name)
model_state_dict = torch.load(model_name, map_location=lambda storage, loc: storage)
print('===> Loading model dict')
model.load_state_dict(model_state_dict)
print('===> Pre-trained SR model is loaded.')
print('')

if cuda:
    model = model.cuda(gpus_list[0])


def eval():
    model.eval()
    count = 1
    avg_psnr_predicted = 0.0
    print('===> Starting eval')
    for batch in testing_data_loader:
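        # The original snippet is truncated at this point; the body below is a
        # sketch mirroring the eval loop shown earlier (compute_psnr is a
        # hypothetical stand-in for the repository's own metric code):
        input, target, neigbor, flow, bicubic = batch[0], batch[1], batch[2], batch[3], batch[4]
        with torch.no_grad():
            input = input.cuda(gpus_list[0])
            neigbor = [j.cuda(gpus_list[0]) for j in neigbor]
            flow = [j.cuda(gpus_list[0]).float() for j in flow]
            prediction = model(input, neigbor, flow)
        avg_psnr_predicted += compute_psnr(prediction.cpu().numpy(), target.numpy())
        count += 1
    print('===> Avg PSNR predicted = {:.4f}'.format(avg_psnr_predicted / len(testing_data_loader)))


# Hypothetical PSNR helper used in the sketch above (simple full-frame PSNR on
# [0, 1]-range arrays; the repo's metric may instead crop borders or operate
# on the Y channel only):
import numpy as np

def compute_psnr(pred, gt, max_val=1.0):
    mse = np.mean((pred.astype(np.float64) - gt.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    return 10.0 * np.log10((max_val ** 2) / mse)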