def construct_model(model_path, device, nframes):
    """Build an RBPN network, load checkpoint weights into it and put it in eval mode.

    Args:
        model_path: Path to a checkpoint file containing an RBPN state dict.
            Keys saved from an ``nn.DataParallel`` wrapper (prefixed with
            ``'module.'``) are accepted.
        device: Device (``torch.device`` or device string) the model is moved
            to. The checkpoint tensors are also mapped onto this device when
            loaded.
        nframes: Number of input frames the network is built for (``nFrames``).

    Returns:
        The RBPN model on ``device``, with the checkpoint weights loaded and
        ``eval()`` already called.
    """
    from rbpn import Net as RBPN

    model = RBPN(num_channels=3, base_filter=256, feat=64, num_stages=3,
                 n_resblock=5, nFrames=nframes, scale_factor=4)

    # Map tensors straight onto the requested device instead of a hard-coded
    # 'cuda:0', so loading also works on CPU-only hosts and on other GPUs.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    ckpt = torch.load(model_path, map_location=device)

    # nn.DataParallel prefixes every parameter name with 'module.'; strip that
    # exact prefix so the keys match the bare (unwrapped) model. Matching the
    # full 'module.' (not just 'module') avoids mangling keys like 'modules.0.w'.
    prefix = 'module.'
    new_ckpt = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in ckpt.items()
    }

    model = model.to(device)
    model.load_state_dict(new_ckpt)
    model.eval()
    return model
def main():
    """Let's begin the training process!

    Parses the command-line arguments, builds the training data loader, the
    RBPN generator (optionally wrapped in ``torch.nn.DataParallel``), the
    SRGAN-style discriminator, the generator loss and one Adam optimizer per
    network. It optionally loads pretrained generator weights, then runs the
    epoch loop via ``trainModel`` and periodically saves parameters via
    ``saveModelParams``.
    """
    args = parser.parse_args()

    # Initialize Logger
    logger.initLogger(args.debug)

    # Load dataset
    logger.info('==> Loading datasets')
    train_set = get_training_set(args.data_dir, args.nFrames, args.upscale_factor,
                                 args.data_augmentation, args.file_list,
                                 args.other_dataset, args.patch_size,
                                 args.future_frame)
    training_data_loader = DataLoader(dataset=train_set, num_workers=args.threads,
                                      batch_size=args.batchSize, shuffle=True)

    # Use generator as RBPN
    netG = RBPN(num_channels=3, base_filter=256, feat=64, num_stages=3,
                n_resblock=5, nFrames=args.nFrames,
                scale_factor=args.upscale_factor)
    logger.info('# of Generator parameters: %s',
                sum(param.numel() for param in netG.parameters()))

    # Use DataParallel?
    if args.useDataParallel:
        gpus_list = list(range(args.gpus))
        netG = torch.nn.DataParallel(netG, device_ids=gpus_list)

    # Use discriminator from SRGAN
    netD = Discriminator()
    logger.info('# of Discriminator parameters: %s',
                sum(param.numel() for param in netD.parameters()))

    # Generator loss
    generatorCriterion = nn.L1Loss() if not args.APITLoss else GeneratorLoss()

    # Specify device: compute the CUDA decision once so the device choice and
    # the "move to GPU" branch below can never disagree.
    use_cuda = args.gpu_mode and torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    if use_cuda:
        utils.printCUDAStats()
        # nn.Module.to() moves parameters/buffers in place, so one call per
        # module is enough; it also puts the loss on the same device as the
        # networks (a bare .cuda() targets the *current* CUDA device instead).
        netG.to(device)
        netD.to(device)
        generatorCriterion.to(device)

    # Use Adam optimizer
    optimizerG = optim.Adam(netG.parameters(), lr=args.lr,
                            betas=(0.9, 0.999), eps=1e-8)
    optimizerD = optim.Adam(netD.parameters(), lr=args.lr,
                            betas=(0.9, 0.999), eps=1e-8)

    if args.APITLoss:
        logger.info(
            "Generator Loss: Adversarial Loss + Perception Loss + Image Loss + TV Loss"
        )
    else:
        logger.info("Generator Loss: L1 Loss")

    # print iSeeBetter architecture
    utils.printNetworkArch(netG, netD)

    if args.pretrained:
        # Join as two path components so save_folder does not need to end
        # with a path separator.
        modelPath = os.path.join(args.save_folder, args.pretrained_sr)
        utils.loadPreTrainedModel(gpuMode=args.gpu_mode, model=netG,
                                  modelPath=modelPath)

    for epoch in range(args.start_epoch, args.nEpochs + 1):
        runningResults = trainModel(epoch, training_data_loader, netG, netD,
                                    optimizerD, optimizerG, generatorCriterion,
                                    device, args)

        # NOTE(review): this saves when (epoch + 1) is a multiple of
        # args.snapshots, i.e. at epochs snapshots-1, 2*snapshots-1, ...
        # -- confirm this cadence is intended.
        if (epoch + 1) % args.snapshots == 0:
            saveModelParams(epoch, runningResults, netG, netD)
shuffle=True, num_workers=args.workers, pin_memory=True) device = torch.device("cuda:0") print("constructing model ....") model = RBPN(num_channels=3, base_filter=256, feat=64, num_stages=3, n_resblock=5, nFrames=args.nframes, scale_factor=4) model = nn.DataParallel(model.to(device), gpuids) if args.resume: ckpt = torch.load(args.model_path) new_ckpt = {} for key in ckpt: if not key.startswith('module'): new_key = 'module.' + key else: new_key = key new_ckpt[new_key] = ckpt[key] model.load_state_dict(new_ckpt, strict=False) print("model constructed") # for key, value in model.named_parameters(): # if not ('pre_deblur' in key):