def __init__(self, num_channels, base_filter, feat, num_stages, n_resblock, nFrames, scale_factor):
    super(TouNet, self).__init__()
    # Two RBPN branches, one per temporal direction. The constructor arguments
    # are passed through (the original hard-coded num_stages=3, n_resblock=5
    # and silently ignored the parameters).
    self.forward_rbpn = RBPN(num_channels, base_filter, feat, num_stages=num_stages,
                             n_resblock=n_resblock, nFrames=nFrames, scale_factor=scale_factor)
    self.backward_rbpn = RBPN(num_channels, base_filter, feat, num_stages=num_stages,
                              n_resblock=n_resblock, nFrames=nFrames, scale_factor=scale_factor)
    # fuse the forward and backward results into a single output frame
    self.output = ConvBlock(num_channels * 2, output_size=num_channels, kernel_size=3,
                            stride=1, padding=1, bias=True, activation=None, norm=None)
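# A minimal forward() sketch for TouNet. This is an assumption, not the
# repository's actual forward(): it presumes RBPN's forward takes
# (x, neighbor, flow) and returns a num_channels image, and that reversing the
# neighbor/flow lists gives the backward-direction input. The two directional
# results are concatenated channel-wise and fused by self.output.
def forward(self, x, neighbor, flow):
    forward_sr = self.forward_rbpn(x, neighbor, flow)
    backward_sr = self.backward_rbpn(x, neighbor[::-1], flow[::-1])
    # fuse the two directional super-resolved frames
    return self.output(torch.cat((forward_sr, backward_sr), dim=1))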
def construct_model(model_path, device, nframes):
    from rbpn import Net as RBPN
    model = RBPN(num_channels=3, base_filter=256, feat=64, num_stages=3,
                 n_resblock=5, nFrames=nframes, scale_factor=4)
    # Map the checkpoint onto the requested device rather than hard-coding
    # 'cuda:0', so CPU-only inference also works.
    ckpt = torch.load(model_path, map_location=device)
    # Strip the 'module.' prefix (7 characters) that torch.nn.DataParallel
    # prepends to parameter names when a model was saved wrapped.
    new_ckpt = {}
    for key in ckpt:
        if key.startswith('module'):
            new_key = key[7:]
        else:
            new_key = key
        new_ckpt[new_key] = ckpt[key]
    model = model.to(device)
    model.load_state_dict(new_ckpt)
    model.eval()
    return model
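# Example usage of construct_model; the checkpoint path and frame count below
# are hypothetical placeholders, not files shipped with this snippet.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = construct_model('weights/rbpn_checkpoint.pth', device, nframes=7)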
train_set = get_training_set(opt.data_dir, opt.nFrames, opt.upscale_factor,
                             opt.data_augmentation, opt.file_list, opt.other_dataset,
                             opt.patch_size, opt.future_frame)
#test_set = get_eval_set(opt.test_dir, opt.nFrames, opt.upscale_factor, opt.data_augmentation)

training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads,
                                  batch_size=opt.batchSize, shuffle=True)
#testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=False)

print('===> Building model ', opt.model_type)
if opt.model_type == 'RBPN':
    model = RBPN(num_channels=3, base_filter=256, feat=64, num_stages=3,
                 n_resblock=5, nFrames=opt.nFrames, scale_factor=opt.upscale_factor)

model = torch.nn.DataParallel(model, device_ids=gpus_list)
criterion = nn.L1Loss()

print('---------- Networks architecture -------------')
print_network(model)
print('----------------------------------------------')

if opt.pretrained:
    # Join the folder and filename properly instead of concatenating strings
    # inside a single-argument os.path.join.
    model_name = os.path.join(opt.save_folder, opt.pretrained_sr)
    if os.path.exists(model_name):
        #model= torch.load(model_name, map_location=lambda storage, loc: storage)
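        # Hedged continuation: the snippet is truncated here. Mirroring the
        # commented-out torch.load line above, the usual next step loads the
        # state dict into the DataParallel-wrapped model.
        model.load_state_dict(torch.load(model_name, map_location=lambda storage, loc: storage))
        print('Pre-trained SR model is loaded.')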
def __init__(self, base_filter, feat, num_stages, n_resblock, scale_factor, pretrained=True, freeze=False):
    super(Net, self).__init__()

    # Deconvolution geometry for each supported upscaling factor.
    if scale_factor == 2:
        kernel = 6
        stride = 2
        padding = 2
    elif scale_factor == 4:
        kernel = 8
        stride = 4
        padding = 2
    elif scale_factor == 8:
        kernel = 12
        stride = 8
        padding = 2
    else:
        # Guard added: without it, an unsupported factor left kernel/stride/padding undefined.
        raise ValueError('Unsupported scale_factor: {}'.format(scale_factor))

    # Initial feature extraction
    self.motion_feat = ConvBlock(4, base_filter, 3, 1, 1, activation='lrelu', norm=None)

    ### INTERPOLATION
    # Interp_block
    motion_net = [ResnetBlock(base_filter, kernel_size=3, stride=1, padding=1, bias=True,
                              activation='lrelu', norm=None) for _ in range(2)]
    motion_net.append(ConvBlock(base_filter, feat, 3, 1, 1, activation='lrelu', norm=None))
    self.motion = nn.Sequential(*motion_net)

    t_net2 = [ConvBlock(feat * 3, feat, 1, 1, 0, bias=True, activation='lrelu', norm=None)]
    t_net2.append(PyramidModule(feat, activation='lrelu'))
    t_net2.append(ConvBlock(feat, feat, 3, 1, 1, activation='lrelu', norm=None))
    self.t_net_hr = nn.Sequential(*t_net2)

    self.upsample_layer = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=True)

    interp_b = [ResnetBlock(feat * 5, kernel_size=3, stride=1, padding=1, bias=True,
                            activation='lrelu', norm=None) for _ in range(n_resblock)]
    interp_b.append(DeconvBlock(feat * 5, feat, kernel, stride, padding, activation='lrelu', norm=None))
    self.interp_block = nn.Sequential(*interp_b)

    ### ITERATIVE REFINEMENT
    # Motion up, forward direction
    modules_up_f = [ResnetBlock(feat * 5, kernel_size=3, stride=1, padding=1, bias=True,
                                activation='lrelu', norm=None) for _ in range(n_resblock)]
    modules_up_f.append(DeconvBlock(feat * 5, feat, kernel, stride, padding, activation='lrelu', norm=None))
    self.motion_up_f = nn.Sequential(*modules_up_f)

    # Motion up, backward direction
    modules_up_b = [ResnetBlock(feat * 5, kernel_size=3, stride=1, padding=1, bias=True,
                                activation='lrelu', norm=None) for _ in range(n_resblock)]
    modules_up_b.append(DeconvBlock(feat * 5, feat, kernel, stride, padding, activation='lrelu', norm=None))
    self.motion_up_b = nn.Sequential(*modules_up_b)

    # Motion down
    modules_down = [ResnetBlock(feat, kernel_size=3, stride=1, padding=1, bias=True,
                                activation='lrelu', norm=None) for _ in range(2)]
    modules_down.append(ConvBlock(feat, feat * 2, kernel, stride, padding, activation='lrelu', norm=None))
    self.motion_down = nn.Sequential(*modules_down)

    self.relu_bp = torch.nn.LeakyReLU(negative_slope=0.1, inplace=True)  # alternative: torch.nn.PReLU()

    # Reconstruction
    self.reconstruction_l = ConvBlock(feat * 2, 3, 3, 1, 1, activation=None, norm=None)
    self.reconstruction_h = ConvBlock(feat, 3, 3, 1, 1, activation=None, norm=None)

    ### ALIGNMENT: RBPN
    self.RBPN = RBPN(num_channels=3, base_filter=base_filter, feat=feat, num_stages=num_stages,
                     n_resblock=5, nFrames=2, scale_factor=scale_factor)

    # Kaiming initialization for all convolutions and transposed convolutions
    for m in self.modules():
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            torch.nn.init.kaiming_normal_(m.weight)
            if m.bias is not None:
                m.bias.data.zero_()
        elif classname.find('ConvTranspose2d') != -1:
            torch.nn.init.kaiming_normal_(m.weight)
            if m.bias is not None:
                m.bias.data.zero_()

    if pretrained:
        if scale_factor == 4:
            self.RBPN.load_state_dict(torch.load("weights/pretrained/rbpn_pretrained_F2_4x.pth",
                                                 map_location=lambda storage, loc: storage))
        elif scale_factor == 2:
            self.RBPN.load_state_dict(torch.load("weights/pretrained/rbpn_pretrained_F2_2x.pth",
                                                 map_location=lambda storage, loc: storage))
        if freeze:
            self.freeze_model(self.RBPN)
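# Hedged instantiation sketch: the hyperparameters mirror the RBPN calls used
# elsewhere in this repo; pretrained=False avoids requiring the weight files
# under weights/pretrained/.
net = Net(base_filter=256, feat=64, num_stages=3, n_resblock=5,
          scale_factor=4, pretrained=False, freeze=False)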
def main():
    """Let's begin the training process!"""
    args = parser.parse_args()

    # Initialize logger
    logger.initLogger(args.debug)

    # Load dataset
    logger.info('==> Loading datasets')
    train_set = get_training_set(args.data_dir, args.nFrames, args.upscale_factor,
                                 args.data_augmentation, args.file_list, args.other_dataset,
                                 args.patch_size, args.future_frame)
    training_data_loader = DataLoader(dataset=train_set, num_workers=args.threads,
                                      batch_size=args.batchSize, shuffle=True)

    # Use RBPN as the generator
    netG = RBPN(num_channels=3, base_filter=256, feat=64, num_stages=3,
                n_resblock=5, nFrames=args.nFrames, scale_factor=args.upscale_factor)
    logger.info('# of Generator parameters: %s', sum(param.numel() for param in netG.parameters()))

    # Use DataParallel?
    if args.useDataParallel:
        gpus_list = range(args.gpus)
        netG = torch.nn.DataParallel(netG, device_ids=gpus_list)

    # Use the discriminator from SRGAN
    netD = Discriminator()
    logger.info('# of Discriminator parameters: %s', sum(param.numel() for param in netD.parameters()))

    # Generator loss
    generatorCriterion = nn.L1Loss() if not args.APITLoss else GeneratorLoss()

    # Specify device
    device = torch.device("cuda:0" if torch.cuda.is_available() and args.gpu_mode else "cpu")

    if args.gpu_mode and torch.cuda.is_available():
        utils.printCUDAStats()
        # Move the networks and the loss to the GPU
        netG.to(device)
        netD.to(device)
        generatorCriterion.to(device)

    # Use Adam optimizers
    optimizerG = optim.Adam(netG.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-8)
    optimizerD = optim.Adam(netD.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-8)

    if args.APITLoss:
        logger.info("Generator Loss: Adversarial Loss + Perception Loss + Image Loss + TV Loss")
    else:
        logger.info("Generator Loss: L1 Loss")

    # Print the iSeeBetter architecture
    utils.printNetworkArch(netG, netD)

    if args.pretrained:
        modelPath = os.path.join(args.save_folder, args.pretrained_sr)
        utils.loadPreTrainedModel(gpuMode=args.gpu_mode, model=netG, modelPath=modelPath)

    for epoch in range(args.start_epoch, args.nEpochs + 1):
        runningResults = trainModel(epoch, training_data_loader, netG, netD, optimizerD,
                                    optimizerG, generatorCriterion, device, args)
        if (epoch + 1) % (args.snapshots) == 0:
            saveModelParams(epoch, runningResults, netG, netD)
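# Entry-point guard so the training script can be run directly
# (assumed; the original snippet ends inside main()).
if __name__ == '__main__':
    main()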
# Training code
if args.phase == 'train':
    # Note: DataLoader here is the project's own dataset class, wrapped in
    # torch.utils.data.DataLoader (imported as `data`) for batching.
    dataloaders = data.DataLoader(DataLoader(args), batch_size=args.batch_size,
                                  shuffle=True, num_workers=args.workers, pin_memory=True)
    device = torch.device("cuda:0")

    print("constructing model ....")
    model = RBPN(num_channels=3, base_filter=256, feat=64, num_stages=3,
                 n_resblock=5, nFrames=args.nframes, scale_factor=4)
    model = nn.DataParallel(model.to(device), gpuids)

    if args.resume:
        ckpt = torch.load(args.model_path)
        # Add the 'module.' prefix DataParallel expects when the checkpoint
        # was saved from a bare (non-DataParallel) model.
        new_ckpt = {}
        for key in ckpt:
            if not key.startswith('module'):
                new_key = 'module.' + key
            else:
                new_key = key
            new_ckpt[new_key] = ckpt[key]
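        # Hedged completion: the snippet ends mid-block; the natural next step
        # restores the remapped weights into the DataParallel-wrapped model.
        model.load_state_dict(new_ckpt)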