def main(cfg):
    """Super-resolve one test video with a pretrained SOFVSR model and save PNGs.

    Frames are read from ``data/<cfg.video_name>`` through ``TestsetLoader``.
    For each sample the network upsamples the luminance (Y) input; the result
    is stacked with the chroma planes (Cb, Cr) returned by the loader,
    converted to RGB with ``ycbcr2rgb`` and written to
    ``results/<cfg.video_name>/sr_XX.png``.

    Args:
        cfg: Run configuration. Reads ``video_name``, ``upscale_factor`` and
            ``gpu_mode`` (truthy -> run on CUDA).
    """
    video_name = cfg.video_name
    upscale_factor = cfg.upscale_factor
    use_gpu = cfg.gpu_mode

    test_set = TestsetLoader('data/' + video_name, upscale_factor)
    test_loader = DataLoader(test_set, num_workers=1, batch_size=1, shuffle=False)

    net = SOFVSR(upscale_factor=upscale_factor)
    # map_location='cpu' lets a checkpoint saved from a GPU run load on a
    # CPU-only machine; net.cuda() below moves the weights when requested.
    ckpt = torch.load('./log/SOFVSR_x' + str(upscale_factor) + '.pth', map_location='cpu')
    net.load_state_dict(ckpt)
    if use_gpu:
        net.cuda()

    # Create the output directory (and a missing 'results/' parent) once,
    # instead of probing the filesystem on every frame.
    save_dir = 'results/' + video_name
    os.makedirs(save_dir, exist_ok=True)

    # Pure inference: skip autograd bookkeeping so memory does not grow per frame.
    with torch.no_grad():
        for idx_iter, (LR_y_cube, SR_cb, SR_cr) in enumerate(test_loader):
            if use_gpu:
                LR_y_cube = LR_y_cube.cuda()
            SR_y = net(LR_y_cube)
            # .cpu() is required: NumPy cannot read a CUDA tensor directly
            # (it is a no-op for a tensor already on the CPU).
            SR_y = np.array(SR_y.data.cpu())
            # Add a leading axis so Y can be concatenated with Cb/Cr along
            # axis 0 (assumes the network output has no leading channel axis
            # here -- TODO confirm against SOFVSR.forward).
            SR_y = SR_y[np.newaxis, :, :]
            # Stack channels-first YCbCr, then move channels last for ycbcr2rgb.
            SR_ycbcr = np.concatenate((SR_y, SR_cb, SR_cr), axis=0).transpose(1, 2, 0)
            # presumably ycbcr2rgb returns values in [0, 1]; scale to 8-bit range.
            SR_rgb = ycbcr2rgb(SR_ycbcr) * 255.0
            SR_rgb = np.clip(SR_rgb, 0, 255)
            # Round before casting: a bare astype(np.uint8) truncates, biasing
            # every pixel downward by up to one grey level.
            SR_rgb = ToPILImage()(np.round(SR_rgb).astype(np.uint8))
            # File index is offset by 2, as in the original naming scheme
            # (presumably the first output frame is the third input frame -- confirm).
            SR_rgb.save(save_dir + '/sr_' + str(idx_iter + 2).rjust(2, '0') + '.png')
def main(cfg):
    """Train SOFVSR with an SR loss plus optical-flow-reconstruction (OFR) losses.

    Each batch is an ``(LR, HR)`` pair from ``TrainsetLoader``. Frame 1 of
    ``HR`` is the SR target and the reference frame; frames 0 and 2 are warped
    onto it with the predicted finest-level flows. Training is a single pass
    over the loader. Every 5000 iterations (starting at iteration 0) the mean
    loss since the last report is printed, and the weights are saved to
    ``log/BI_x<upscale_factor>_iter<idx>.pth``.

    Args:
        cfg: Run configuration. Reads ``gpu_mode``, ``upscale_factor``,
            ``trainset_dir``, ``patch_size``, ``n_iters`` and ``batch_size``.
    """
    use_gpu = cfg.gpu_mode
    net = SOFVSR(cfg.upscale_factor, is_training=True)
    if use_gpu:
        net.cuda()
    cudnn.benchmark = True

    # The last argument is presumably the dataset length, making one pass over
    # the loader equal to cfg.n_iters iterations -- confirm in TrainsetLoader.
    train_set = TrainsetLoader(cfg.trainset_dir, cfg.upscale_factor, cfg.patch_size,
                               cfg.n_iters * cfg.batch_size)
    train_loader = DataLoader(train_set, num_workers=4, batch_size=cfg.batch_size, shuffle=True)

    # train
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
    criterion_L2 = torch.nn.MSELoss()
    if use_gpu:
        criterion_L2 = criterion_L2.cuda()
    milestones = [50000, 100000, 150000, 200000, 250000]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.5)

    # torch.save does not create directories; make sure 'log/' exists up front.
    os.makedirs('log', exist_ok=True)

    loss_list = []
    for idx_iter, (LR, HR) in enumerate(train_loader):
        if use_gpu:
            LR = LR.cuda()
            HR = HR.cuda()

        # net returns per-neighbour tuples (0->1 and 2->1) of residuals and
        # flows at three pyramid levels, plus the SR output.
        outputs_01, outputs_21, SR = net(LR)
        res_01_L1, res_01_L2, flow_01_L1, flow_01_L2, flow_01_L3 = outputs_01
        res_21_L1, res_21_L2, flow_21_L1, flow_21_L2, flow_21_L3 = outputs_21

        # Reference (centre) HR frame with an explicit channel axis.
        HR_center = torch.unsqueeze(HR[:, 1, :, :], 1)
        warped_01 = optical_flow_warp(torch.unsqueeze(HR[:, 0, :, :], dim=1), flow_01_L3)
        warped_21 = optical_flow_warp(torch.unsqueeze(HR[:, 2, :, :], dim=1), flow_21_L3)

        # losses: levels L3 / L2 / L1 are weighted 1 / 0.25 / 0.125; each adds
        # a 0.01-weighted L1_regularization of that level's flow.
        loss_SR = criterion_L2(SR, HR_center)
        loss_OFR_1 = ((criterion_L2(warped_01, HR_center) + 0.01 * L1_regularization(flow_01_L3))
                      + 0.25 * (torch.mean(res_01_L2 ** 2) + 0.01 * L1_regularization(flow_01_L2))
                      + 0.125 * (torch.mean(res_01_L1 ** 2) + 0.01 * L1_regularization(flow_01_L1)))
        loss_OFR_2 = ((criterion_L2(warped_21, HR_center) + 0.01 * L1_regularization(flow_21_L3))
                      + 0.25 * (torch.mean(res_21_L2 ** 2) + 0.01 * L1_regularization(flow_21_L2))
                      + 0.125 * (torch.mean(res_21_L1 ** 2) + 0.01 * L1_regularization(flow_21_L1)))
        loss = loss_SR + 0.01 * (loss_OFR_1 + loss_OFR_2) / 2
        # Store a plain float so the history keeps no tensors alive.
        loss_list.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Since PyTorch 1.1 the LR scheduler must be stepped AFTER the optimizer;
        # stepping it first skips the first value of the schedule.
        scheduler.step()

        # save checkpoint
        if idx_iter % 5000 == 0:
            print('Iteration---%6d, loss---%f' % (idx_iter + 1, np.array(loss_list).mean()))
            torch.save(net.state_dict(), 'log/BI_x' + str(cfg.upscale_factor) + '_iter' + str(idx_iter) + '.pth')
            loss_list = []
def main(cfg):
    """Super-resolve one test video (optionally patch-wise) and save PNG frames.

    Frames are read from ``data/test/<cfg.video_name>`` through
    ``TestsetLoader``. On the GPU path with ``cfg.chop_forward`` set, the
    inputs are cropped to multiples of 16 and processed by ``chop_forward``;
    otherwise the whole frame goes through the network. The Y output is
    stacked with the loader's Cb/Cr planes, converted to RGB and written to
    ``results/<cfg.video_name>/sr_XX.png``.

    Args:
        cfg: Run configuration. Reads ``video_name``, ``upscale_factor``,
            ``gpu_mode`` and ``chop_forward``.
    """
    video_name = cfg.video_name
    upscale_factor = cfg.upscale_factor
    use_gpu = cfg.gpu_mode

    test_set = TestsetLoader('data/test/' + video_name, upscale_factor)
    test_loader = DataLoader(test_set, num_workers=1, batch_size=1, shuffle=False)

    net = SOFVSR(upscale_factor=upscale_factor)
    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only
    # machine; net.cuda() below moves the weights when requested.
    ckpt = torch.load('./log/SOFVSR_x' + str(upscale_factor) + '.pth', map_location='cpu')
    net.load_state_dict(ckpt)
    if use_gpu:
        net.cuda()

    # Create the output directory (and a missing 'results/' parent) once.
    save_dir = 'results/' + video_name
    os.makedirs(save_dir, exist_ok=True)

    # Pure inference: no autograd graph, so memory does not grow per frame.
    with torch.no_grad():
        for idx_iter, (LR_y_cube, SR_cb, SR_cr) in enumerate(test_loader):
            if use_gpu:
                LR_y_cube = LR_y_cube.cuda()
                if cfg.chop_forward:
                    # Crop H and W down to multiples of 16 so the patches that
                    # chop_forward produces stay evenly divisible (presumably it
                    # splits the frame recursively -- confirm in chop_forward).
                    # Cb/Cr are cropped to the matching super-resolved extent.
                    _, _, h, w = LR_y_cube.size()
                    h = int(h // 16) * 16
                    w = int(w // 16) * 16
                    LR_y_cube = LR_y_cube[:, :, :h, :w]
                    SR_cb = SR_cb[:, :h * upscale_factor, :w * upscale_factor]
                    SR_cr = SR_cr[:, :h * upscale_factor, :w * upscale_factor]
                    SR_y = chop_forward(LR_y_cube, net, cfg.upscale_factor)
                else:
                    SR_y = net(LR_y_cube)
            else:
                SR_y = net(LR_y_cube)
            # Bring the result to host memory whichever path produced it;
            # .cpu() is a no-op for a tensor already on the CPU.
            SR_y = np.array(SR_y.data.cpu())
            # Add a leading axis so Y concatenates with Cb/Cr along axis 0
            # (assumes no leading channel axis here -- TODO confirm).
            SR_y = SR_y[np.newaxis, :, :]
            SR_ycbcr = np.concatenate((SR_y, SR_cb, SR_cr), axis=0).transpose(1, 2, 0)
            # presumably ycbcr2rgb returns values in [0, 1]; scale to 8-bit range.
            SR_rgb = ycbcr2rgb(SR_ycbcr) * 255.0
            SR_rgb = np.clip(SR_rgb, 0, 255)
            # Round instead of truncating when casting to 8-bit.
            SR_rgb = ToPILImage()(np.round(SR_rgb).astype(np.uint8))
            # File index offset of 2 kept from the original naming scheme.
            SR_rgb.save(save_dir + '/sr_' + str(idx_iter + 2).rjust(2, '0') + '.png')
def main(cfg):
    """Super-resolve every video in ``cfg.testset_dir`` and save RGB PNG frames.

    Loads ``./log/<degradation>_x<scale>.pth`` into a SOFVSR model. Each video
    (one entry of ``cfg.testset_dir``) is read through ``TestsetLoader``. The
    Y output of the network is stacked with the loader's Cb/Cr planes,
    converted to RGB and written to
    ``results/Vid4/<degradation>_x<scale>/<video_name>/sr_XX.png``.

    Args:
        cfg: Run configuration. Reads ``degradation``, ``scale``,
            ``gpu_mode``, ``chop_forward`` and ``testset_dir``; it is also
            passed whole to ``SOFVSR`` and ``TestsetLoader``.
    """
    # model
    net = SOFVSR(cfg, is_training=False)
    # map_location='cpu' lets a GPU-saved checkpoint load when gpu_mode is off.
    ckpt = torch.load('./log/' + cfg.degradation + '_x' + str(cfg.scale) + '.pth', map_location='cpu')
    net.load_state_dict(ckpt)
    if cfg.gpu_mode:
        net.cuda()

    result_root = 'results/Vid4/' + cfg.degradation + '_x' + str(cfg.scale)

    with torch.no_grad():
        for video_name in os.listdir(cfg.testset_dir):
            # One recursive makedirs per video replaces three per-frame
            # os.mkdir probes (and no longer fails when 'results/' is missing).
            save_dir = result_root + '/' + video_name
            os.makedirs(save_dir, exist_ok=True)

            # dataloader
            test_set = TestsetLoader(cfg, video_name)
            test_loader = DataLoader(test_set, num_workers=1, batch_size=1, shuffle=False)

            for idx_iter, (LR_y_cube, SR_cb, SR_cr) in enumerate(test_loader):
                # data: insert a singleton channel axis -> (b, frames, 1, h, w)
                b, _, h_lr, w_lr = LR_y_cube.size()
                LR_y_cube = LR_y_cube.view(b, -1, 1, h_lr, w_lr)
                if cfg.gpu_mode:
                    LR_y_cube = LR_y_cube.cuda()
                    if cfg.chop_forward:
                        # Crop H and W to multiples of 16 so the patches made
                        # by chop_forward stay evenly divisible; crop Cb/Cr to
                        # the matching super-resolved extent.
                        _, _, _, h, w = LR_y_cube.size()
                        h = int(h // 16) * 16
                        w = int(w // 16) * 16
                        LR_y_cube = LR_y_cube[:, :, :, :h, :w]
                        SR_cb = SR_cb[:, :h * cfg.scale, :w * cfg.scale]
                        SR_cr = SR_cr[:, :h * cfg.scale, :w * cfg.scale]
                        SR_y = chop_forward(LR_y_cube, net, cfg.scale).squeeze(0)
                    else:
                        SR_y = net(LR_y_cube).squeeze(0)
                else:
                    SR_y = net(LR_y_cube).squeeze(0)
                SR_y = np.array(SR_y.data.cpu())

                # Stack channels-first YCbCr, then move channels last for ycbcr2rgb.
                SR_ycbcr = np.concatenate((SR_y, SR_cb, SR_cr), axis=0).transpose(1, 2, 0)
                # presumably ycbcr2rgb returns values in [0, 1]; scale to 8-bit range.
                SR_rgb = ycbcr2rgb(SR_ycbcr) * 255.0
                SR_rgb = np.clip(SR_rgb, 0, 255)
                SR_rgb = ToPILImage()(np.round(SR_rgb).astype(np.uint8))
                # File index offset of 2 kept from the original naming scheme.
                SR_rgb.save(save_dir + '/sr_' + str(idx_iter + 2).rjust(2, '0') + '.png')
def main(cfg):
    """Train SOFVSR on multi-frame clips with an SR loss plus pyramid OFR losses.

    Each ``(LR, HR)`` batch is reshaped to ``(b, frames, 1, h, w)``. The centre
    frame ``(n_frames - 1) // 2`` is the SR target, and every other frame
    contributes an optical-flow-reconstruction loss at three levels: L1 on
    2x average-pooled LR, L2 on LR and L3 on HR. Every 5000 iterations
    (starting at 0) the mean loss since the last report is printed and the
    weights are saved under ``log/<degradation>_x<scale>/``.

    Args:
        cfg: Run configuration. Reads ``gpu_mode``, ``batch_size``, ``scale``
            and ``degradation``; it is also passed whole to ``SOFVSR`` and
            ``TrainsetLoader``.
    """
    # model
    net = SOFVSR(cfg, is_training=True)
    if cfg.gpu_mode:
        net.cuda()
    cudnn.benchmark = True

    # dataloader
    train_set = TrainsetLoader(cfg)
    train_loader = DataLoader(train_set, num_workers=4, batch_size=cfg.batch_size, shuffle=True)

    # train
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    milestones = [80000, 160000]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
    criterion = torch.nn.MSELoss()

    # Create the checkpoint directory (including a missing 'log/' parent) once.
    save_path = 'log/' + cfg.degradation + '_x' + str(cfg.scale)
    os.makedirs(save_path, exist_ok=True)

    loss_list = []
    for idx_iter, (LR, HR) in enumerate(train_loader):
        # data
        b, n_frames, h_lr, w_lr = LR.size()
        idx_center = (n_frames - 1) // 2
        if cfg.gpu_mode:
            LR = LR.cuda()
            HR = HR.cuda()
        LR = LR.view(b, -1, 1, h_lr, w_lr)
        HR = HR.view(b, -1, 1, h_lr * cfg.scale, w_lr * cfg.scale)

        # inference
        flow_L1, flow_L2, flow_L3, SR = net(LR)

        # loss
        LR_center = LR[:, idx_center, :, :, :]
        HR_center = HR[:, idx_center, :, :, :]
        # Loop-invariant: the pooled centre frame is the same for every neighbour.
        LR_center_pooled = F.avg_pool2d(LR_center, kernel_size=2)

        loss_SR = criterion(SR, HR_center)
        # Accumulator lives on the data's device; the previous hard-coded
        # .cuda() crashed whenever cfg.gpu_mode was False.
        loss_OFR = torch.zeros(1, device=LR.device)
        for i in range(n_frames):
            if i == idx_center:
                continue
            loss_L1 = OFR_loss(F.avg_pool2d(LR[:, i, :, :, :], kernel_size=2), LR_center_pooled, flow_L1[i])
            loss_L2 = OFR_loss(LR[:, i, :, :, :], LR_center, flow_L2[i])
            loss_L3 = OFR_loss(HR[:, i, :, :, :], HR_center, flow_L3[i])
            loss_OFR = loss_OFR + loss_L3 + 0.2 * loss_L2 + 0.1 * loss_L1

        # Average the OFR term over the (n_frames - 1) neighbouring frames.
        loss = loss_SR + 0.01 * loss_OFR / (n_frames - 1)
        # Store a plain float so the history keeps no tensors alive.
        loss_list.append(loss.item())

        # backwards
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Since PyTorch 1.1 the scheduler must be stepped AFTER the optimizer;
        # stepping it first skips the first value of the LR schedule.
        scheduler.step()

        # save checkpoint
        if idx_iter % 5000 == 0:
            print('Iteration---%6d, loss---%f' % (idx_iter + 1, np.array(loss_list).mean()))
            save_name = cfg.degradation + '_x' + str(cfg.scale) + '_iter' + str(idx_iter) + '.pth'
            torch.save(net.state_dict(), save_path + '/' + save_name)
            loss_list = []