def main():
    """Evaluate a 3-level (1-2-4) multi-patch deblurring model on the GoPro test set.

    Builds encoder/decoder pairs for levels 1-3, loads any checkpoint found in
    ``./checkpoints/<METHOD>/<name>.pkl``, deblurs every test image and prints
    per-image and mean PSNR / CW-SSIM.
    """
    print("init data folders")
    print(METHOD)
    encoder_lv1 = models.Encoder().apply(weight_init).cuda(GPU)
    encoder_lv2 = models.Encoder().apply(weight_init).cuda(GPU)
    encoder_lv3 = models.Encoder().apply(weight_init).cuda(GPU)
    decoder_lv1 = models.Decoder().apply(weight_init).cuda(GPU)
    decoder_lv2 = models.Decoder().apply(weight_init).cuda(GPU)
    decoder_lv3 = models.Decoder().apply(weight_init).cuda(GPU)
    networks = [
        ("encoder_lv1", encoder_lv1),
        ("encoder_lv2", encoder_lv2),
        ("encoder_lv3", encoder_lv3),
        ("decoder_lv1", decoder_lv1),
        ("decoder_lv2", decoder_lv2),
        ("decoder_lv3", decoder_lv3),
    ]
    for name, net in networks:
        checkpoint = './checkpoints/' + METHOD + '/' + name + '.pkl'
        if os.path.exists(checkpoint):
            net.load_state_dict(torch.load(checkpoint))
            # Message is derived from the network name so it always matches
            # (decoder_lv1 used to be reported as "encoder_lv1").
            print("load " + name + " success")
        # Inference only: disable train-time behaviour in case models.Encoder /
        # models.Decoder contain dropout or batch-norm layers.
        net.eval()

    os.makedirs('./test_results/' + EXPDIR, exist_ok=True)

    print("Testing............")
    print("===========================")
    test_dataset = GoProDataset(
        blur_image_files='./datas/GoPro/test_blur_file.txt',
        sharp_image_files='./datas/GoPro/test_sharp_file.txt',
        root_dir='./datas/GoPro',
        transform=transforms.Compose([transforms.ToTensor()]))
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    def deblur(images_lv1):
        """Run the 1-2-4 patch hierarchy on a (batch, C, H, W) input shifted by -0.5."""
        H = images_lv1.size(2)
        W = images_lv1.size(3)
        h2, w2 = int(H / 2), int(W / 2)
        # Level 2: top / bottom halves. Level 3: each half split into left / right.
        images_lv2_1 = images_lv1[:, :, 0:h2, :]
        images_lv2_2 = images_lv1[:, :, h2:H, :]
        feature_lv3_1 = encoder_lv3(images_lv2_1[:, :, :, 0:w2])
        feature_lv3_2 = encoder_lv3(images_lv2_1[:, :, :, w2:W])
        feature_lv3_3 = encoder_lv3(images_lv2_2[:, :, :, 0:w2])
        feature_lv3_4 = encoder_lv3(images_lv2_2[:, :, :, w2:W])
        feature_lv3_top = torch.cat((feature_lv3_1, feature_lv3_2), 3)
        feature_lv3_bot = torch.cat((feature_lv3_3, feature_lv3_4), 3)
        feature_lv3 = torch.cat((feature_lv3_top, feature_lv3_bot), 2)
        residual_lv3_top = decoder_lv3(feature_lv3_top)
        residual_lv3_bot = decoder_lv3(feature_lv3_bot)
        # Level 2 refines each half with the residual predicted by level 3.
        feature_lv2_1 = encoder_lv2(images_lv2_1 + residual_lv3_top)
        feature_lv2_2 = encoder_lv2(images_lv2_2 + residual_lv3_bot)
        feature_lv2 = torch.cat((feature_lv2_1, feature_lv2_2), 2) + feature_lv3
        residual_lv2 = decoder_lv2(feature_lv2)
        # Level 1 works on the whole image plus the level-2 residual.
        feature_lv1 = encoder_lv1(images_lv1 + residual_lv2) + feature_lv2
        return decoder_lv1(feature_lv1)

    total_psnr = 0.0
    total_ssim = 0.0
    num_images = 0
    for iteration, images in enumerate(test_dataloader):
        with torch.no_grad():
            images_lv1 = (images['blur_image'] - 0.5).cuda(GPU)
            deblur_image = deblur(images_lv1)
            sharp = images['sharp_image'].numpy()[0]
            restored = deblur_image.detach().cpu().numpy()[0] + 0.5
            psnr = compare_psnr(sharp, restored)
            pilReal = Image.fromarray(
                (np.transpose(sharp, (1, 2, 0)) * 255).astype(np.uint8))
            # The network output is not bounded to [0, 1]; clip before the uint8
            # cast, otherwise out-of-range pixels wrap around and corrupt SSIM.
            pilFake = Image.fromarray(
                (np.clip(np.transpose(restored, (1, 2, 0)), 0.0, 1.0) * 255)
                .astype(np.uint8))
            ssim = SSIM(pilFake).cw_ssim_value(pilReal)
        total_psnr += psnr
        total_ssim += ssim
        num_images += 1
        print("testing...... iteration: ", iteration, ", psnr: ", psnr, ", ssim: ", ssim)

    # Average over the images actually evaluated rather than a hard-coded test-set
    # size (1111), so a different or partial file list still yields correct means.
    denom = max(num_images, 1)
    print("PSNR is: ", (total_psnr / denom))
    print("SSIM is: ", (total_ssim / denom))
def main():
    """Train the 4-level (1-2-4-8) multi-patch deblurring model on GoPro.

    Each epoch runs one pass over cropped training pairs with an MSE loss, steps
    every StepLR scheduler, and overwrites the latest checkpoints in
    ``./checkpoints/<METHOD>/``. Every 100th epoch the model is also run over the
    test set (outputs written via ``save_deblur_images``) and a snapshot is stored
    in ``./checkpoints/<METHOD>/epoch<N>/``.
    """
    print("init data folders")
    torch.backends.cudnn.benchmark = True
    torch.cuda.empty_cache()

    # Construct every network before initialising any, as the original did, so
    # the random-number sequence consumed during initialisation is unchanged.
    encoder_lv1 = models.Encoder()
    encoder_lv2 = models.Encoder()
    encoder_lv3 = models.Encoder()
    encoder_lv4 = models.Encoder()
    decoder_lv1 = models.Decoder()
    decoder_lv2 = models.Decoder()
    decoder_lv3 = models.Decoder()
    decoder_lv4 = models.Decoder()
    networks = [
        ("encoder_lv1", encoder_lv1),
        ("encoder_lv2", encoder_lv2),
        ("encoder_lv3", encoder_lv3),
        ("encoder_lv4", encoder_lv4),
        ("decoder_lv1", decoder_lv1),
        ("decoder_lv2", decoder_lv2),
        ("decoder_lv3", decoder_lv3),
        ("decoder_lv4", decoder_lv4),
    ]
    for _, net in networks:
        net.apply(weight_init).cuda(GPU)

    optimizers = []
    schedulers = []
    for _, net in networks:
        optimizer = RAdam(net.parameters(), lr=LEARNING_RATE)
        optimizers.append(optimizer)
        schedulers.append(StepLR(optimizer, step_size=1000, gamma=0.1))

    # load_state_dict copies into the existing parameters, so the optimizers
    # created above keep pointing at the right tensors.
    for name, net in networks:
        checkpoint = './checkpoints/' + METHOD + '/' + name + '.pkl'
        if os.path.exists(checkpoint):
            net.load_state_dict(torch.load(checkpoint))
            # Name-derived message (decoder_lv1 used to be reported as encoder_lv1).
            print("load " + name + " success")

    os.makedirs('./checkpoints/' + METHOD, exist_ok=True)

    def forward(images_lv1):
        """Deblur a (batch, C, H, W) input that has already been shifted by -0.5.

        Level 1 sees the whole image, level 2 its top/bottom halves, level 3 the
        four quadrants, and level 4 the upper/lower halves of each quadrant.
        Features and residuals flow from level 4 back up to level 1.
        """
        H = images_lv1.size(2)
        W = images_lv1.size(3)
        h2, h4, w2 = int(H / 2), int(H / 4), int(W / 2)

        images_lv2_1 = images_lv1[:, :, 0:h2, :]
        images_lv2_2 = images_lv1[:, :, h2:H, :]
        # Quadrants in the order top-left, top-right, bottom-left, bottom-right.
        images_lv3 = [
            images_lv2_1[:, :, :, 0:w2],
            images_lv2_1[:, :, :, w2:W],
            images_lv2_2[:, :, :, 0:w2],
            images_lv2_2[:, :, :, w2:W],
        ]

        # Level 4: encode the upper and lower half of each quadrant and stitch
        # the two feature maps back together along the height axis.
        feature_lv4 = []
        for quadrant in images_lv3:
            upper = encoder_lv4(quadrant[:, :, 0:h4, :])
            lower = encoder_lv4(quadrant[:, :, h4:h2, :])
            feature_lv4.append(torch.cat((upper, lower), 2))
        residual_lv4 = [decoder_lv4(feature) for feature in feature_lv4]
        feature_lv4_top = torch.cat((feature_lv4[0], feature_lv4[1]), 3)
        feature_lv4_bot = torch.cat((feature_lv4[2], feature_lv4[3]), 3)

        # Level 3: each quadrant plus its level-4 residual.
        feature_lv3 = [encoder_lv3(quadrant + residual)
                       for quadrant, residual in zip(images_lv3, residual_lv4)]
        feature_lv3_top = torch.cat((feature_lv3[0], feature_lv3[1]), 3) + feature_lv4_top
        feature_lv3_bot = torch.cat((feature_lv3[2], feature_lv3[3]), 3) + feature_lv4_bot
        residual_lv3_top = decoder_lv3(feature_lv3_top)
        residual_lv3_bot = decoder_lv3(feature_lv3_bot)

        # Level 2: top/bottom halves plus their level-3 residuals.
        feature_lv2_1 = encoder_lv2(images_lv2_1 + residual_lv3_top)
        feature_lv2_2 = encoder_lv2(images_lv2_2 + residual_lv3_bot)
        feature_lv2 = (torch.cat((feature_lv2_1, feature_lv2_2), 2)
                       + torch.cat((feature_lv3_top, feature_lv3_bot), 2))
        residual_lv2 = decoder_lv2(feature_lv2)

        # Level 1: the full image plus the level-2 residual.
        feature_lv1 = encoder_lv1(images_lv1 + residual_lv2) + feature_lv2
        return decoder_lv1(feature_lv1)

    # Loop-invariant: build the loss and the training data pipeline once.
    mse = nn.MSELoss().cuda(GPU)
    train_dataset = GoProDataset(
        blur_image_files='./datas/GoPro/train_blur_file.txt',
        sharp_image_files='./datas/GoPro/train_sharp_file.txt',
        root_dir='./datas/GoPro',
        crop=True,
        crop_size=IMAGE_SIZE,
        transform=transforms.Compose([transforms.ToTensor()]))
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                                  num_workers=8, pin_memory=True)

    # Epochs are numbered start_epoch + 1 .. EPOCHS (the original incremented the
    # loop variable at the top of the body).
    for epoch in range(args.start_epoch + 1, EPOCHS + 1):
        print("Training...")
        # Read the rate straight from the optimizer: StepLR.get_lr() is meant to
        # be called from step() and, on recent PyTorch, can report an
        # already-decayed value (and warns) when called here.
        print('lr:', [group['lr'] for group in optimizers[0].param_groups])

        # Start the interval clock now; it used to start at 0, so the first
        # logged time of every epoch was seconds since the Unix epoch.
        start = time.time()
        for iteration, images in enumerate(train_dataloader):
            gt = (images['sharp_image'] - 0.5).cuda(GPU)
            deblur_image = forward((images['blur_image'] - 0.5).cuda(GPU))
            loss = mse(deblur_image, gt)

            for _, net in networks:
                net.zero_grad()
            loss.backward()
            for optimizer in optimizers:
                optimizer.step()

            if (iteration + 1) % 10 == 0:
                stop = time.time()
                print(METHOD + " epoch:", epoch, "iteration:", iteration + 1,
                      "loss:%.4f" % loss.item(), 'time:%.4f' % (stop - start))
                start = time.time()

        # Explicit epoch keeps the closed-form StepLR schedule aligned with
        # args.start_epoch when resuming.
        for scheduler in schedulers:
            scheduler.step(epoch)

        if epoch % 100 == 0:
            epoch_dir = './checkpoints/' + METHOD + '/epoch' + str(epoch)
            os.makedirs(epoch_dir, exist_ok=True)
            print("Testing...")
            test_dataset = GoProDataset(
                blur_image_files='./datas/GoPro/test_blur_file.txt',
                sharp_image_files='./datas/GoPro/test_sharp_file.txt',
                root_dir='./datas/GoPro',
                transform=transforms.Compose([transforms.ToTensor()]))
            test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False,
                                         num_workers=8, pin_memory=True)
            # Evaluate in eval mode and restore train mode afterwards, in case the
            # models contain dropout or batch-norm layers.
            for _, net in networks:
                net.eval()
            test_time = 0
            with torch.no_grad():
                for iteration, images in enumerate(test_dataloader):
                    start = time.time()
                    deblur_image = forward((images['blur_image'] - 0.5).cuda(GPU))
                    # CUDA kernels run asynchronously; wait for them so the
                    # measured runtime includes the actual GPU work.
                    torch.cuda.synchronize(GPU)
                    stop = time.time()
                    test_time += stop - start
                    print('RunTime:%.4f' % (stop - start),
                          ' Average Runtime:%.4f' % (test_time / (iteration + 1)))
                    save_deblur_images(deblur_image.data + 0.5, iteration, epoch)
            for _, net in networks:
                net.train()

            for name, net in networks:
                torch.save(net.state_dict(), epoch_dir + '/' + name + '.pkl')

        # Always refresh the "latest" checkpoints that the loader above reads.
        for name, net in networks:
            torch.save(net.state_dict(), './checkpoints/' + METHOD + '/' + name + '.pkl')
def main():
    """Run the 4-level (1-2-4-8) model over the GoPro test set and dump residuals.

    Loads checkpoints from ``./checkpoints/<METHOD>/<name>.pkl`` when present,
    reports per-image and average deblurring runtime, and saves the deblurred
    image plus whole-image residuals of levels 4, 3 and 2 for inspection.
    """
    print("init data folders")
    encoder_lv1 = models.Encoder().apply(weight_init).cuda(GPU)
    encoder_lv2 = models.Encoder().apply(weight_init).cuda(GPU)
    encoder_lv3 = models.Encoder().apply(weight_init).cuda(GPU)
    encoder_lv4 = models.Encoder().apply(weight_init).cuda(GPU)
    decoder_lv1 = models.Decoder().apply(weight_init).cuda(GPU)
    decoder_lv2 = models.Decoder().apply(weight_init).cuda(GPU)
    decoder_lv3 = models.Decoder().apply(weight_init).cuda(GPU)
    decoder_lv4 = models.Decoder().apply(weight_init).cuda(GPU)
    networks = [
        ("encoder_lv1", encoder_lv1),
        ("encoder_lv2", encoder_lv2),
        ("encoder_lv3", encoder_lv3),
        ("encoder_lv4", encoder_lv4),
        ("decoder_lv1", decoder_lv1),
        ("decoder_lv2", decoder_lv2),
        ("decoder_lv3", decoder_lv3),
        ("decoder_lv4", decoder_lv4),
    ]
    for name, net in networks:
        checkpoint = './checkpoints/' + METHOD + '/' + name + '.pkl'
        if os.path.exists(checkpoint):
            net.load_state_dict(torch.load(checkpoint))
            # Name-derived message (decoder_lv1 used to be reported as encoder_lv1).
            print("load " + name + " success")
        # Inference only, in case the models contain dropout or batch-norm layers.
        net.eval()

    os.makedirs('./test_results/' + EXPDIR, exist_ok=True)

    print("Testing...")
    test_dataset = GoProDataset(
        blur_image_files='./datas/GoPro/test_blur_file.txt',
        sharp_image_files='./datas/GoPro/test_sharp_file.txt',
        root_dir='./datas/GoPro',
        transform=transforms.Compose([transforms.ToTensor()]))
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    def forward(images_lv1):
        """Deblur a (batch, C, H, W) input shifted by -0.5.

        Returns ``(deblur_image, feature_lv4, feature_lv3, residual_lv2)``, where
        the two feature tensors are the whole-image level-4 and level-3 feature
        maps. They are returned so the caller can decode them for visualisation.
        """
        H = images_lv1.size(2)
        W = images_lv1.size(3)
        h2, h4, w2 = int(H / 2), int(H / 4), int(W / 2)

        images_lv2_1 = images_lv1[:, :, 0:h2, :]
        images_lv2_2 = images_lv1[:, :, h2:H, :]
        # Quadrants in the order top-left, top-right, bottom-left, bottom-right.
        images_lv3 = [
            images_lv2_1[:, :, :, 0:w2],
            images_lv2_1[:, :, :, w2:W],
            images_lv2_2[:, :, :, 0:w2],
            images_lv2_2[:, :, :, w2:W],
        ]

        # Level 4: upper/lower half of every quadrant, stitched back along height.
        feature_lv4_quad = []
        for quadrant in images_lv3:
            upper = encoder_lv4(quadrant[:, :, 0:h4, :])
            lower = encoder_lv4(quadrant[:, :, h4:h2, :])
            feature_lv4_quad.append(torch.cat((upper, lower), 2))
        residual_lv4_quad = [decoder_lv4(feature) for feature in feature_lv4_quad]
        feature_lv4_top = torch.cat((feature_lv4_quad[0], feature_lv4_quad[1]), 3)
        feature_lv4_bot = torch.cat((feature_lv4_quad[2], feature_lv4_quad[3]), 3)

        # Level 3: each quadrant plus its level-4 residual.
        feature_lv3_quad = [encoder_lv3(quadrant + residual)
                            for quadrant, residual in zip(images_lv3, residual_lv4_quad)]
        feature_lv3_top = torch.cat((feature_lv3_quad[0], feature_lv3_quad[1]), 3) + feature_lv4_top
        feature_lv3_bot = torch.cat((feature_lv3_quad[2], feature_lv3_quad[3]), 3) + feature_lv4_bot
        residual_lv3_top = decoder_lv3(feature_lv3_top)
        residual_lv3_bot = decoder_lv3(feature_lv3_bot)
        feature_lv3 = torch.cat((feature_lv3_top, feature_lv3_bot), 2)

        # Level 2: top/bottom halves plus their level-3 residuals.
        feature_lv2_1 = encoder_lv2(images_lv2_1 + residual_lv3_top)
        feature_lv2_2 = encoder_lv2(images_lv2_2 + residual_lv3_bot)
        feature_lv2 = torch.cat((feature_lv2_1, feature_lv2_2), 2) + feature_lv3
        residual_lv2 = decoder_lv2(feature_lv2)

        # Level 1: the full image plus the level-2 residual.
        feature_lv1 = encoder_lv1(images_lv1 + residual_lv2) + feature_lv2
        deblur_image = decoder_lv1(feature_lv1)
        feature_lv4 = torch.cat((feature_lv4_top, feature_lv4_bot), 2)
        return deblur_image, feature_lv4, feature_lv3, residual_lv2

    test_time = 0
    for iteration, images in enumerate(test_dataloader):
        with torch.no_grad():
            start = time.time()
            images_lv1 = (images['blur_image'] - 0.5).cuda(GPU)
            deblur_image, feature_lv4, feature_lv3, residual_lv2 = forward(images_lv1)
            # CUDA kernels run asynchronously; wait for them so the measured
            # runtime includes the actual GPU work.
            torch.cuda.synchronize(GPU)
            stop = time.time()
            test_time += stop - start
            print('RunTime:%.4f' % (stop - start),
                  ' Average Runtime:%.4f' % (test_time / (iteration + 1)))

            # Whole-image level-4/level-3 residuals are only for visualisation,
            # so they are decoded outside the timed region.
            residual_lv4 = decoder_lv4(feature_lv4)
            residual_lv3 = decoder_lv3(feature_lv3)
            # Save the residual maps for inspection.
            save_residual_images(residual_lv4.data, iteration, "residual_lv4")
            save_residual_images(residual_lv3.data, iteration, "residual_lv3")
            save_residual_images(residual_lv2.data, iteration, "residual_lv2")
            save_images(deblur_image.data + 0.5, iteration)
def main(): print("init data folders") torch.backends.cudnn.benchmark = True torch.cuda.empty_cache() encoder = {} decoder = {} encoder_optim = {} decoder_optim = {} encoder_scheduler = {} decoder_scheduler = {} for s in ['s1', 's2', 's3', 's4']: encoder[s] = {} decoder[s] = {} encoder_optim[s] = {} decoder_optim[s] = {} encoder_scheduler[s] = {} decoder_scheduler[s] = {} for lv in ['lv1', 'lv2', 'lv3']: encoder[s][lv] = models.Encoder() decoder[s][lv] = models.Decoder() encoder[s][lv].apply(weight_init).cuda(GPU) decoder[s][lv].apply(weight_init).cuda(GPU) encoder_optim[s][lv] = RAdam(encoder[s][lv].parameters(), lr=LEARNING_RATE) decoder_optim[s][lv] = RAdam(decoder[s][lv].parameters(), lr=LEARNING_RATE) encoder_scheduler[s][lv] = StepLR(encoder_optim[s][lv], step_size=1000, gamma=0.1) decoder_scheduler[s][lv] = StepLR(decoder_optim[s][lv], step_size=1000, gamma=0.1) if os.path.exists( str('./checkpoints/' + METHOD + "/encoder_" + s + "_" + lv + ".pkl")): encoder[s][lv].load_state_dict( torch.load( str('./checkpoints/' + METHOD + "/encoder_" + s + "_" + lv + ".pkl"))) print("load encoder_" + s + "_" + lv + " successfully!") if os.path.exists( str('./checkpoints/' + METHOD + "/decoder_" + s + "_" + lv + ".pkl")): decoder[s][lv].load_state_dict( torch.load( str('./checkpoints/' + METHOD + "/decoder_" + s + "_" + lv + ".pkl"))) print("load decoder_" + s + "_" + lv + " successfully!") if os.path.exists('./checkpoints/' + METHOD) == False: os.system('mkdir ./checkpoints/' + METHOD) for epoch in range(args.start_epoch, EPOCHS): torch.cuda.empty_cache() print("Training...") print('lr:', encoder_scheduler[s][lv].get_lr()) train_dataset = GoProDataset( blur_image_files='./datas/GoPro/train_blur_file.txt', sharp_image_files='./datas/GoPro/train_sharp_file.txt', root_dir='./datas/GoPro/', crop=True, crop_size=CROP_SIZE, rotation=True, color_augment=True, transform=transforms.Compose([transforms.ToTensor()])) train_dataloader = DataLoader(train_dataset, 
batch_size=BATCH_SIZE, shuffle=True, num_workers=16, pin_memory=True) start = 0 for iteration, inputs in enumerate(train_dataloader): mse = nn.MSELoss().cuda(GPU) images = {} feature = {} residual = {} for s in ['s1', 's2', 's3', 's4']: feature[s] = {} residual[s] = {} images['gt'] = Variable(inputs['sharp_image'] - 0.5).cuda(GPU) images['lv1'] = Variable(inputs['blur_image'] - 0.5).cuda(GPU) H = images['lv1'].size(2) W = images['lv1'].size(3) images['lv2_1'] = images['lv1'][:, :, 0:int(H / 2), :] images['lv2_2'] = images['lv1'][:, :, int(H / 2):H, :] images['lv3_1'] = images['lv2_1'][:, :, :, 0:int(W / 2)] images['lv3_2'] = images['lv2_1'][:, :, :, int(W / 2):W] images['lv3_3'] = images['lv2_2'][:, :, :, 0:int(W / 2)] images['lv3_4'] = images['lv2_2'][:, :, :, int(W / 2):W] s = 's1' feature[s]['lv3_1'] = encoder[s]['lv3'](images['lv3_1']) feature[s]['lv3_2'] = encoder[s]['lv3'](images['lv3_2']) feature[s]['lv3_3'] = encoder[s]['lv3'](images['lv3_3']) feature[s]['lv3_4'] = encoder[s]['lv3'](images['lv3_4']) feature[s]['lv3_top'] = torch.cat( (feature[s]['lv3_1'], feature[s]['lv3_2']), 3) feature[s]['lv3_bot'] = torch.cat( (feature[s]['lv3_3'], feature[s]['lv3_4']), 3) residual[s]['lv3_top'] = decoder[s]['lv3'](feature[s]['lv3_top']) residual[s]['lv3_bot'] = decoder[s]['lv3'](feature[s]['lv3_bot']) feature[s]['lv2_1'] = encoder[s]['lv2']( images['lv2_1'] + residual[s]['lv3_top']) + feature[s]['lv3_top'] feature[s]['lv2_2'] = encoder[s]['lv2']( images['lv2_2'] + residual[s]['lv3_bot']) + feature[s]['lv3_bot'] feature[s]['lv2'] = torch.cat( (feature[s]['lv2_1'], feature[s]['lv2_2']), 2) residual[s]['lv2'] = decoder[s]['lv2'](feature[s]['lv2']) feature[s]['lv1'] = encoder[s]['lv1']( images['lv1'] + residual[s]['lv2']) + feature[s]['lv2'] residual[s]['lv1'] = decoder[s]['lv1'](feature[s]['lv1']) s = 's2' ps = 's1' feature[s]['lv3_1'] = encoder[s]['lv3']( residual[ps]['lv1'][:, :, 0:int(H / 2), 0:int(W / 2)]) feature[s]['lv3_2'] = encoder[s]['lv3']( 
residual[ps]['lv1'][:, :, 0:int(H / 2), int(W / 2):W]) feature[s]['lv3_3'] = encoder[s]['lv3']( residual[ps]['lv1'][:, :, int(H / 2):H, 0:int(W / 2)]) feature[s]['lv3_4'] = encoder[s]['lv3']( residual[ps]['lv1'][:, :, int(H / 2):H, int(W / 2):W]) feature[s]['lv3_top'] = torch.cat( (feature[s]['lv3_1'], feature[s]['lv3_2']), 3) + feature[ps]['lv3_top'] feature[s]['lv3_bot'] = torch.cat( (feature[s]['lv3_3'], feature[s]['lv3_4']), 3) + feature[ps]['lv3_bot'] residual[s]['lv3_top'] = decoder[s]['lv3'](feature[s]['lv3_top']) residual[s]['lv3_bot'] = decoder[s]['lv3'](feature[s]['lv3_bot']) feature[s]['lv2_1'] = encoder[s]['lv2']( residual[ps]['lv1'][:, :, 0:int(H / 2), :] + residual[s] ['lv3_top']) + feature[s]['lv3_top'] + feature[ps]['lv2_1'] feature[s]['lv2_2'] = encoder[s]['lv2']( residual[ps]['lv1'][:, :, int(H / 2):H, :] + residual[s] ['lv3_bot']) + feature[s]['lv3_bot'] + feature[ps]['lv2_2'] feature[s]['lv2'] = torch.cat( (feature[s]['lv2_1'], feature[s]['lv2_2']), 2) residual[s]['lv2'] = decoder[s]['lv2']( feature[s]['lv2']) + residual['s1']['lv1'] feature[s]['lv1'] = encoder[s]['lv1']( residual[ps]['lv1'] + residual[s]['lv2']) + feature[s]['lv2'] + feature[ps]['lv1'] residual[s]['lv1'] = decoder[s]['lv1'](feature[s]['lv1']) s = 's3' ps = 's2' feature[s]['lv3_1'] = encoder[s]['lv3']( residual[ps]['lv1'][:, :, 0:int(H / 2), 0:int(W / 2)]) feature[s]['lv3_2'] = encoder[s]['lv3']( residual[ps]['lv1'][:, :, 0:int(H / 2), int(W / 2):W]) feature[s]['lv3_3'] = encoder[s]['lv3']( residual[ps]['lv1'][:, :, int(H / 2):H, 0:int(W / 2)]) feature[s]['lv3_4'] = encoder[s]['lv3']( residual[ps]['lv1'][:, :, int(H / 2):H, int(W / 2):W]) feature[s]['lv3_top'] = torch.cat( (feature[s]['lv3_1'], feature[s]['lv3_2']), 3) + feature[ps]['lv3_top'] feature[s]['lv3_bot'] = torch.cat( (feature[s]['lv3_3'], feature[s]['lv3_4']), 3) + feature[ps]['lv3_bot'] residual[s]['lv3_top'] = decoder[s]['lv3'](feature[s]['lv3_top']) residual[s]['lv3_bot'] = decoder[s]['lv3'](feature[s]['lv3_bot']) 
feature[s]['lv2_1'] = encoder[s]['lv2']( residual[ps]['lv1'][:, :, 0:int(H / 2), :] + residual[s] ['lv3_top']) + feature[s]['lv3_top'] + feature[ps]['lv2_1'] feature[s]['lv2_2'] = encoder[s]['lv2']( residual[ps]['lv1'][:, :, int(H / 2):H, :] + residual[s] ['lv3_bot']) + feature[s]['lv3_bot'] + feature[ps]['lv2_2'] feature[s]['lv2'] = torch.cat( (feature[s]['lv2_1'], feature[s]['lv2_2']), 2) residual[s]['lv2'] = decoder[s]['lv2']( feature[s]['lv2']) + residual['s1']['lv1'] feature[s]['lv1'] = encoder[s]['lv1']( residual[ps]['lv1'] + residual[s]['lv2']) + feature[s]['lv2'] + feature[ps]['lv1'] residual[s]['lv1'] = decoder[s]['lv1'](feature[s]['lv1']) s = 's4' ps = 's3' feature[s]['lv3_1'] = encoder[s]['lv3']( residual[ps]['lv1'][:, :, 0:int(H / 2), 0:int(W / 2)]) feature[s]['lv3_2'] = encoder[s]['lv3']( residual[ps]['lv1'][:, :, 0:int(H / 2), int(W / 2):W]) feature[s]['lv3_3'] = encoder[s]['lv3']( residual[ps]['lv1'][:, :, int(H / 2):H, 0:int(W / 2)]) feature[s]['lv3_4'] = encoder[s]['lv3']( residual[ps]['lv1'][:, :, int(H / 2):H, int(W / 2):W]) feature[s]['lv3_top'] = torch.cat( (feature[s]['lv3_1'], feature[s]['lv3_2']), 3) + feature[ps]['lv3_top'] feature[s]['lv3_bot'] = torch.cat( (feature[s]['lv3_3'], feature[s]['lv3_4']), 3) + feature[ps]['lv3_bot'] residual[s]['lv3_top'] = decoder[s]['lv3'](feature[s]['lv3_top']) residual[s]['lv3_bot'] = decoder[s]['lv3'](feature[s]['lv3_bot']) feature[s]['lv2_1'] = encoder[s]['lv2']( residual[ps]['lv1'][:, :, 0:int(H / 2), :] + residual[s] ['lv3_top']) + feature[s]['lv3_top'] + feature[ps]['lv2_1'] feature[s]['lv2_2'] = encoder[s]['lv2']( residual[ps]['lv1'][:, :, int(H / 2):H, :] + residual[s] ['lv3_bot']) + feature[s]['lv3_bot'] + feature[ps]['lv2_2'] feature[s]['lv2'] = torch.cat( (feature[s]['lv2_1'], feature[s]['lv2_2']), 2) residual[s]['lv2'] = decoder[s]['lv2']( feature[s]['lv2']) + residual['s1']['lv1'] feature[s]['lv1'] = encoder[s]['lv1']( residual[ps]['lv1'] + residual[s]['lv2']) + feature[s]['lv2'] + 
feature[ps]['lv1'] residual[s]['lv1'] = decoder[s]['lv1'](feature[s]['lv1']) loss = mse(residual['s4']['lv1'], images['gt']) + mse( residual['s3']['lv1'], images['gt']) + mse( residual['s2']['lv1'], images['gt']) + mse( residual['s1']['lv1'], images['gt']) for s in ['s1', 's2', 's3', 's4']: for lv in ['lv1', 'lv2', 'lv3']: encoder[s][lv].zero_grad() decoder[s][lv].zero_grad() loss.backward() for s in ['s1', 's2', 's3', 's4']: for lv in ['lv1', 'lv2', 'lv3']: encoder_optim[s][lv].step() decoder_optim[s][lv].step() if (iteration + 1) % 10 == 0: stop = time.time() print(METHOD + " epoch:", epoch, "iteration:", iteration + 1, "loss:%.4f" % loss.item(), 'time:%.4f' % (stop - start)) start = time.time() for s in ['s1', 's2', 's3', 's4']: for lv in ['lv1', 'lv2', 'lv3']: encoder_scheduler[s][lv].step(epoch) decoder_scheduler[s][lv].step(epoch) if (epoch) % 100 == 0: if os.path.exists('./checkpoints/' + METHOD + '/epoch' + str(epoch)) == False: os.system('mkdir ./checkpoints/' + METHOD + '/epoch' + str(epoch)) for s in ['s1', 's2', 's3', 's4']: for lv in ['lv1', 'lv2', 'lv3']: torch.save( encoder[s][lv].state_dict(), str('./checkpoints/' + METHOD + '/epoch' + str(epoch) + "/encoder_" + s + "_" + lv + ".pkl")) torch.save( decoder[s][lv].state_dict(), str('./checkpoints/' + METHOD + '/epoch' + str(epoch) + "/decoder_" + s + "_" + lv + ".pkl")) print("Testing...") test_dataset = GoProDataset( blur_image_files='./datas/GoPro/test_blur_file.txt', sharp_image_files='./datas/GoPro/test_sharp_file.txt', root_dir='./datas/GoPro/', transform=transforms.Compose([transforms.ToTensor()])) test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=16, pin_memory=True) for iteration, inputs in enumerate(test_dataloader): with torch.no_grad(): images['lv1'] = Variable(inputs['blur_image'] - 0.5).cuda(GPU) H = images['lv1'].size(2) W = images['lv1'].size(3) images['lv2_1'] = images['lv1'][:, :, 0:int(H / 2), :] images['lv2_2'] = images['lv1'][:, :, int(H / 2):H, :] 
images['lv3_1'] = images['lv2_1'][:, :, :, 0:int(W / 2)] images['lv3_2'] = images['lv2_1'][:, :, :, int(W / 2):W] images['lv3_3'] = images['lv2_2'][:, :, :, 0:int(W / 2)] images['lv3_4'] = images['lv2_2'][:, :, :, int(W / 2):W] s = 's1' feature[s]['lv3_1'] = encoder[s]['lv3'](images['lv3_1']) feature[s]['lv3_2'] = encoder[s]['lv3'](images['lv3_2']) feature[s]['lv3_3'] = encoder[s]['lv3'](images['lv3_3']) feature[s]['lv3_4'] = encoder[s]['lv3'](images['lv3_4']) feature[s]['lv3_top'] = torch.cat( (feature[s]['lv3_1'], feature[s]['lv3_2']), 3) feature[s]['lv3_bot'] = torch.cat( (feature[s]['lv3_3'], feature[s]['lv3_4']), 3) residual[s]['lv3_top'] = decoder[s]['lv3']( feature[s]['lv3_top']) residual[s]['lv3_bot'] = decoder[s]['lv3']( feature[s]['lv3_bot']) feature[s]['lv2_1'] = encoder[s]['lv2']( images['lv2_1'] + residual[s]['lv3_top']) + feature[s]['lv3_top'] feature[s]['lv2_2'] = encoder[s]['lv2']( images['lv2_2'] + residual[s]['lv3_bot']) + feature[s]['lv3_bot'] feature[s]['lv2'] = torch.cat( (feature[s]['lv2_1'], feature[s]['lv2_2']), 2) residual[s]['lv2'] = decoder[s]['lv2'](feature[s]['lv2']) feature[s]['lv1'] = encoder[s]['lv1']( images['lv1'] + residual[s]['lv2']) + feature[s]['lv2'] residual[s]['lv1'] = decoder[s]['lv1'](feature[s]['lv1']) s = 's2' ps = 's1' feature[s]['lv3_1'] = encoder[s]['lv3']( residual[ps]['lv1'][:, :, 0:int(H / 2), 0:int(W / 2)]) feature[s]['lv3_2'] = encoder[s]['lv3']( residual[ps]['lv1'][:, :, 0:int(H / 2), int(W / 2):W]) feature[s]['lv3_3'] = encoder[s]['lv3']( residual[ps]['lv1'][:, :, int(H / 2):H, 0:int(W / 2)]) feature[s]['lv3_4'] = encoder[s]['lv3']( residual[ps]['lv1'][:, :, int(H / 2):H, int(W / 2):W]) feature[s]['lv3_top'] = torch.cat( (feature[s]['lv3_1'], feature[s]['lv3_2']), 3) + feature[ps]['lv3_top'] feature[s]['lv3_bot'] = torch.cat( (feature[s]['lv3_3'], feature[s]['lv3_4']), 3) + feature[ps]['lv3_bot'] residual[s]['lv3_top'] = decoder[s]['lv3']( feature[s]['lv3_top']) residual[s]['lv3_bot'] = decoder[s]['lv3']( 
feature[s]['lv3_bot']) feature[s]['lv2_1'] = encoder[s]['lv2']( residual[ps]['lv1'][:, :, 0:int(H / 2), :] + residual[s]['lv3_top'] ) + feature[s]['lv3_top'] + feature[ps]['lv2_1'] feature[s]['lv2_2'] = encoder[s]['lv2']( residual[ps]['lv1'][:, :, int(H / 2):H, :] + residual[s]['lv3_bot'] ) + feature[s]['lv3_bot'] + feature[ps]['lv2_2'] feature[s]['lv2'] = torch.cat( (feature[s]['lv2_1'], feature[s]['lv2_2']), 2) residual[s]['lv2'] = decoder[s]['lv2'](feature[s]['lv2']) feature[s]['lv1'] = encoder[s]['lv1']( residual[ps]['lv1'] + residual[s]['lv2'] ) + feature[s]['lv2'] + feature[ps]['lv1'] residual[s]['lv1'] = decoder[s]['lv1'](feature[s]['lv1']) s = 's3' ps = 's2' feature[s]['lv3_1'] = encoder[s]['lv3']( residual[ps]['lv1'][:, :, 0:int(H / 2), 0:int(W / 2)]) feature[s]['lv3_2'] = encoder[s]['lv3']( residual[ps]['lv1'][:, :, 0:int(H / 2), int(W / 2):W]) feature[s]['lv3_3'] = encoder[s]['lv3']( residual[ps]['lv1'][:, :, int(H / 2):H, 0:int(W / 2)]) feature[s]['lv3_4'] = encoder[s]['lv3']( residual[ps]['lv1'][:, :, int(H / 2):H, int(W / 2):W]) feature[s]['lv3_top'] = torch.cat( (feature[s]['lv3_1'], feature[s]['lv3_2']), 3) + feature[ps]['lv3_top'] feature[s]['lv3_bot'] = torch.cat( (feature[s]['lv3_3'], feature[s]['lv3_4']), 3) + feature[ps]['lv3_bot'] residual[s]['lv3_top'] = decoder[s]['lv3']( feature[s]['lv3_top']) residual[s]['lv3_bot'] = decoder[s]['lv3']( feature[s]['lv3_bot']) feature[s]['lv2_1'] = encoder[s]['lv2']( residual[ps]['lv1'][:, :, 0:int(H / 2), :] + residual[s]['lv3_top'] ) + feature[s]['lv3_top'] + feature[ps]['lv2_1'] feature[s]['lv2_2'] = encoder[s]['lv2']( residual[ps]['lv1'][:, :, int(H / 2):H, :] + residual[s]['lv3_bot'] ) + feature[s]['lv3_bot'] + feature[ps]['lv2_2'] feature[s]['lv2'] = torch.cat( (feature[s]['lv2_1'], feature[s]['lv2_2']), 2) residual[s]['lv2'] = decoder[s]['lv2'](feature[s]['lv2']) feature[s]['lv1'] = encoder[s]['lv1']( residual[ps]['lv1'] + residual[s]['lv2'] ) + feature[s]['lv2'] + feature[ps]['lv1'] 
residual[s]['lv1'] = decoder[s]['lv1'](feature[s]['lv1']) s = 's4' ps = 's3' feature[s]['lv3_1'] = encoder[s]['lv3']( residual[ps]['lv1'][:, :, 0:int(H / 2), 0:int(W / 2)]) feature[s]['lv3_2'] = encoder[s]['lv3']( residual[ps]['lv1'][:, :, 0:int(H / 2), int(W / 2):W]) feature[s]['lv3_3'] = encoder[s]['lv3']( residual[ps]['lv1'][:, :, int(H / 2):H, 0:int(W / 2)]) feature[s]['lv3_4'] = encoder[s]['lv3']( residual[ps]['lv1'][:, :, int(H / 2):H, int(W / 2):W]) feature[s]['lv3_top'] = torch.cat( (feature[s]['lv3_1'], feature[s]['lv3_2']), 3) + feature[ps]['lv3_top'] feature[s]['lv3_bot'] = torch.cat( (feature[s]['lv3_3'], feature[s]['lv3_4']), 3) + feature[ps]['lv3_bot'] residual[s]['lv3_top'] = decoder[s]['lv3']( feature[s]['lv3_top']) residual[s]['lv3_bot'] = decoder[s]['lv3']( feature[s]['lv3_bot']) feature[s]['lv2_1'] = encoder[s]['lv2']( residual[ps]['lv1'][:, :, 0:int(H / 2), :] + residual[s]['lv3_top'] ) + feature[s]['lv3_top'] + feature[ps]['lv2_1'] feature[s]['lv2_2'] = encoder[s]['lv2']( residual[ps]['lv1'][:, :, int(H / 2):H, :] + residual[s]['lv3_bot'] ) + feature[s]['lv3_bot'] + feature[ps]['lv2_2'] feature[s]['lv2'] = torch.cat( (feature[s]['lv2_1'], feature[s]['lv2_2']), 2) residual[s]['lv2'] = decoder[s]['lv2'](feature[s]['lv2']) feature[s]['lv1'] = encoder[s]['lv1']( residual[ps]['lv1'] + residual[s]['lv2'] ) + feature[s]['lv2'] + feature[ps]['lv1'] residual[s]['lv1'] = decoder[s]['lv1'](feature[s]['lv1']) deblurred_image = residual[s]['lv1'] save_deblur_images(deblurred_image.data + 0.5, iteration, epoch) for s in ['s1', 's2', 's3', 's4']: for lv in ['lv1', 'lv2', 'lv3']: torch.save( encoder[s][lv].state_dict(), str('./checkpoints/' + METHOD + '/epoch' + str(epoch) + "/encoder_" + s + "_" + lv + ".pkl")) torch.save( decoder[s][lv].state_dict(), str('./checkpoints/' + METHOD + '/epoch' + str(epoch) + "/decoder_" + s + "_" + lv + ".pkl")) for s in ['s1', 's2', 's3', 's4']: for lv in ['lv1', 'lv2', 'lv3']: torch.save( encoder[s][lv].state_dict(), 
str('./checkpoints/' + METHOD + "/encoder_" + s + "_" + lv + ".pkl")) torch.save( decoder[s][lv].state_dict(), str('./checkpoints/' + METHOD + "/decoder_" + s + "_" + lv + ".pkl"))
def _dmphn_124_forward(images_lv1, encoder_lv1, encoder_lv2, encoder_lv3,
                       decoder_lv1, decoder_lv2, decoder_lv3):
    """Run the 1-2-4 multi-patch encoder/decoder hierarchy on one batch.

    The batch is split into two horizontal halves (level 2), and each half
    into left/right quarters (level 3). Each coarser level's decoder output
    is added to the next finer level's input, and each coarser level's
    features are added to the next finer level's features.

    Args:
        images_lv1: Input batch indexed as ``[:, :, H, W]`` (dims 2 and 3
            are halved). Callers pass the blurred image shifted by -0.5.
        encoder_lv1, encoder_lv2, encoder_lv3: Per-level encoders.
        decoder_lv1, decoder_lv2, decoder_lv3: Per-level decoders.

    Returns:
        The output of ``decoder_lv1`` on the full-resolution level.
    """
    H = images_lv1.size(2)
    W = images_lv1.size(3)

    # Level 2: top/bottom halves. Level 3: each half split into left/right.
    images_lv2_1 = images_lv1[:, :, 0:int(H / 2), :]
    images_lv2_2 = images_lv1[:, :, int(H / 2):H, :]
    images_lv3_1 = images_lv2_1[:, :, :, 0:int(W / 2)]
    images_lv3_2 = images_lv2_1[:, :, :, int(W / 2):W]
    images_lv3_3 = images_lv2_2[:, :, :, 0:int(W / 2)]
    images_lv3_4 = images_lv2_2[:, :, :, int(W / 2):W]

    feature_lv3_1 = encoder_lv3(images_lv3_1)
    feature_lv3_2 = encoder_lv3(images_lv3_2)
    feature_lv3_3 = encoder_lv3(images_lv3_3)
    feature_lv3_4 = encoder_lv3(images_lv3_4)
    # Re-stitch quarter features along width (dim 3) into halves, then
    # along height (dim 2) into a full-frame feature map.
    feature_lv3_top = torch.cat((feature_lv3_1, feature_lv3_2), 3)
    feature_lv3_bot = torch.cat((feature_lv3_3, feature_lv3_4), 3)
    feature_lv3 = torch.cat((feature_lv3_top, feature_lv3_bot), 2)
    residual_lv3_top = decoder_lv3(feature_lv3_top)
    residual_lv3_bot = decoder_lv3(feature_lv3_bot)

    feature_lv2_1 = encoder_lv2(images_lv2_1 + residual_lv3_top)
    feature_lv2_2 = encoder_lv2(images_lv2_2 + residual_lv3_bot)
    feature_lv2 = torch.cat((feature_lv2_1, feature_lv2_2), 2) + feature_lv3
    residual_lv2 = decoder_lv2(feature_lv2)

    feature_lv1 = encoder_lv1(images_lv1 + residual_lv2) + feature_lv2
    return decoder_lv1(feature_lv1)


def _dmphn_124_test(nets, epoch):
    """Evaluate *nets* on the GoPro test split and save every deblurred frame.

    Args:
        nets: ``(encoder_lv1, encoder_lv2, encoder_lv3,
            decoder_lv1, decoder_lv2, decoder_lv3)``.
        epoch: Current epoch, forwarded to ``save_deblur_images``.

    Returns:
        ``(average_psnr, average_seconds_per_image)``, or ``None`` when the
        test set is empty.
    """
    test_dataset = GoProDataset(
        blur_image_files='./datas/GoPro/test_blur_file.txt',
        sharp_image_files='./datas/GoPro/test_sharp_file.txt',
        root_dir='./datas/GoPro/',
        transform=transforms.Compose([transforms.ToTensor()]))
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    test_time = 0.0
    total_psnr = 0
    num_images = 0
    with torch.no_grad():
        for iteration, images in enumerate(test_dataloader):
            images_lv1 = Variable(images['blur_image'] - 0.5).cuda(GPU)
            start = time.time()
            deblur_image = _dmphn_124_forward(images_lv1, *nets)
            # CUDA kernels run asynchronously; wait for them so the timing
            # covers the forward pass rather than just the kernel launches.
            torch.cuda.synchronize(GPU)
            test_time += time.time() - start

            # Undo the -0.5 input shift before comparing against [0, 1] targets.
            psnr = compare_psnr(
                images['sharp_image'].numpy()[0],
                deblur_image.detach().cpu().numpy()[0] + 0.5)
            total_psnr += psnr
            num_images = iteration + 1
            if num_images % 50 == 0:
                print(
                    'PSNR:%.4f' % (psnr),
                    ' Average PSNR:%.4f' % (total_psnr / num_images))
            save_deblur_images(deblur_image.data + 0.5, iteration, epoch)

    if num_images == 0:
        return None
    return total_psnr / num_images, test_time / num_images


def main():
    """Train the 1-2-4 multi-patch deblurring hierarchy on GoPro.

    Builds three encoder/decoder pairs and resumes from any per-level
    checkpoint found under ``./checkpoints/<METHOD>/``. For each epoch in
    ``range(args.start_epoch, EPOCHS)`` it:

    - trains on cropped GoPro training pairs with an MSE loss;
    - every 100 epochs, evaluates PSNR on the test split;
    - overwrites the per-level checkpoints.

    Relies on module-level configuration (``METHOD``, ``GPU``,
    ``LEARNING_RATE``, ``BATCH_SIZE``, ``IMAGE_SIZE``, ``EPOCHS``, ``args``).
    """
    print("init data folders")
    psnr_list = []

    # Construct all modules first, then initialise them, so random-number
    # consumption happens in the same order as before.
    encoder_lv1 = models.Encoder()
    encoder_lv2 = models.Encoder()
    encoder_lv3 = models.Encoder()
    decoder_lv1 = models.Decoder()
    decoder_lv2 = models.Decoder()
    decoder_lv3 = models.Decoder()
    nets = (encoder_lv1, encoder_lv2, encoder_lv3,
            decoder_lv1, decoder_lv2, decoder_lv3)
    names = ("encoder_lv1", "encoder_lv2", "encoder_lv3",
             "decoder_lv1", "decoder_lv2", "decoder_lv3")

    for net in nets:
        net.apply(weight_init).cuda(GPU)

    optimizers = [torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)
                  for net in nets]
    # Stepped once per epoch below, so the LR drops 10x every 1000 epochs.
    schedulers = [StepLR(optim, step_size=1000, gamma=0.1)
                  for optim in optimizers]

    checkpoint_dir = './checkpoints/' + METHOD
    for name, net in zip(names, nets):
        path = checkpoint_dir + "/" + name + ".pkl"
        if os.path.exists(path):
            net.load_state_dict(torch.load(path))
            # Report the module actually loaded (decoder_lv1 was previously
            # reported as "encoder_lv1").
            print("load " + name + " success")

    # Unlike a bare `mkdir` shell call, this also creates ./checkpoints when
    # it is missing and involves no shell.
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Loop-invariant: build the loss module once instead of per iteration.
    mse = nn.MSELoss().cuda(GPU)

    for epoch in range(args.start_epoch, EPOCHS):
        # The explicit epoch keeps the schedule aligned when resuming from
        # args.start_epoch (this form is deprecated in newer PyTorch).
        for scheduler in schedulers:
            scheduler.step(epoch)

        print("Training...")
        train_dataset = GoProDataset(
            blur_image_files='./datas/GoPro/train_blur_file.txt',
            sharp_image_files='./datas/GoPro/train_sharp_file.txt',
            root_dir='./datas/GoPro/',
            crop=True,
            crop_size=IMAGE_SIZE,
            transform=transforms.Compose([transforms.ToTensor()]))
        train_dataloader = DataLoader(
            train_dataset, batch_size=BATCH_SIZE, shuffle=True)

        # Use a real timestamp so the first reported interval is meaningful
        # (it was previously measured from 0, i.e. the Unix epoch).
        start = time.time()
        for iteration, images in enumerate(train_dataloader):
            # Inputs and targets are shifted by -0.5, matching the test path.
            gt = Variable(images['sharp_image'] - 0.5).cuda(GPU)
            images_lv1 = Variable(images['blur_image'] - 0.5).cuda(GPU)

            deblur_image = _dmphn_124_forward(images_lv1, *nets)
            loss = mse(deblur_image, gt)

            for net in nets:
                net.zero_grad()
            loss.backward()
            for optim in optimizers:
                optim.step()

            if (iteration + 1) % 50 == 0:
                stop = time.time()
                print("epoch:", epoch, "iteration:", iteration + 1,
                      "loss:%.4f" % loss.item(),
                      'time:%.4f' % (stop - start))
                start = time.time()

        if epoch % 100 == 0:
            # NOTE(review): presumably save_deblur_images writes into this
            # per-epoch directory — verify against its definition.
            os.makedirs(checkpoint_dir + '/epoch' + str(epoch), exist_ok=True)

            print("Testing...")
            result = _dmphn_124_test(nets, epoch)
            if result is not None:
                avg_psnr, avg_time = result
                psnr_list.append(avg_psnr)
                print("Average test time:%.4f" % avg_time)
            print("PSNR list:")
            print(psnr_list)

        # Overwrite the top-level checkpoints that the resume logic reads.
        for name, net in zip(names, nets):
            torch.save(net.state_dict(), checkpoint_dir + "/" + name + ".pkl")