Esempio n. 1
0
def main():
    """Evaluate the 3-level multi-patch deblurring network on the GoPro test set.

    Builds three encoder/decoder pairs on device ``GPU``, loads any weights
    found in ``./checkpoints/<METHOD>/``, deblurs every test image with the
    4/2/1-patch hierarchy and prints per-image PSNR/SSIM, followed by the mean
    PSNR, mean SSIM and mean forward-pass runtime over the images evaluated.

    Relies on module-level names defined elsewhere in the file: ``METHOD``,
    ``EXPDIR``, ``GPU``, ``models``, ``weight_init``, ``GoProDataset``,
    ``compare_psnr`` and ``SSIM``.
    """
    print("init data folders")
    print(METHOD)
    checkpoint_dir = './checkpoints/' + METHOD

    encoder_lv1 = models.Encoder().apply(weight_init).cuda(GPU)
    encoder_lv2 = models.Encoder().apply(weight_init).cuda(GPU)
    encoder_lv3 = models.Encoder().apply(weight_init).cuda(GPU)

    decoder_lv1 = models.Decoder().apply(weight_init).cuda(GPU)
    decoder_lv2 = models.Decoder().apply(weight_init).cuda(GPU)
    decoder_lv3 = models.Decoder().apply(weight_init).cuda(GPU)

    # Load saved weights when present. The checkpoint file stem doubles as the
    # log label, so every message names the network that was actually loaded.
    nets = {
        "encoder_lv1": encoder_lv1,
        "encoder_lv2": encoder_lv2,
        "encoder_lv3": encoder_lv3,
        "decoder_lv1": decoder_lv1,
        "decoder_lv2": decoder_lv2,
        "decoder_lv3": decoder_lv3,
    }
    for name, net in nets.items():
        path = checkpoint_dir + "/" + name + ".pkl"
        if os.path.exists(path):
            net.load_state_dict(torch.load(path))
            print("load " + name + " success")

    os.makedirs('./test_results/' + EXPDIR, exist_ok=True)

    test_time = 0.0
    total_psnr = 0.0
    total_ssim = 0.0
    num_images = 0
    print("Testing............")
    print("===========================")
    test_dataset = GoProDataset(
        blur_image_files='./datas/GoPro/test_blur_file.txt',
        sharp_image_files='./datas/GoPro/test_sharp_file.txt',
        root_dir='./datas/GoPro',
        transform=transforms.Compose([transforms.ToTensor()]))
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    with torch.no_grad():
        for iteration, images in enumerate(test_dataloader):
            # The network operates on inputs shifted by -0.5; the shift is
            # undone on the output below.
            images_lv1 = Variable(images['blur_image'] - 0.5).cuda(GPU)
            start = time.time()
            H = images_lv1.size(2)
            W = images_lv1.size(3)
            # Level 2: top / bottom halves; level 3: quarters of those halves.
            images_lv2_1 = images_lv1[:, :, 0:int(H / 2), :]
            images_lv2_2 = images_lv1[:, :, int(H / 2):H, :]
            images_lv3_1 = images_lv2_1[:, :, :, 0:int(W / 2)]
            images_lv3_2 = images_lv2_1[:, :, :, int(W / 2):W]
            images_lv3_3 = images_lv2_2[:, :, :, 0:int(W / 2)]
            images_lv3_4 = images_lv2_2[:, :, :, int(W / 2):W]

            feature_lv3_1 = encoder_lv3(images_lv3_1)
            feature_lv3_2 = encoder_lv3(images_lv3_2)
            feature_lv3_3 = encoder_lv3(images_lv3_3)
            feature_lv3_4 = encoder_lv3(images_lv3_4)
            feature_lv3_top = torch.cat((feature_lv3_1, feature_lv3_2), 3)
            feature_lv3_bot = torch.cat((feature_lv3_3, feature_lv3_4), 3)
            feature_lv3 = torch.cat((feature_lv3_top, feature_lv3_bot), 2)
            residual_lv3_top = decoder_lv3(feature_lv3_top)
            residual_lv3_bot = decoder_lv3(feature_lv3_bot)

            feature_lv2_1 = encoder_lv2(images_lv2_1 + residual_lv3_top)
            feature_lv2_2 = encoder_lv2(images_lv2_2 + residual_lv3_bot)
            feature_lv2 = torch.cat(
                (feature_lv2_1, feature_lv2_2), 2) + feature_lv3
            residual_lv2 = decoder_lv2(feature_lv2)

            feature_lv1 = encoder_lv1(images_lv1 + residual_lv2) + feature_lv2
            deblur_image = decoder_lv1(feature_lv1)
            # NOTE: CUDA kernels are launched asynchronously, so without
            # torch.cuda.synchronize() this may under-report GPU time.
            stop = time.time()
            test_time += stop - start

            sharp = images['sharp_image'].numpy()[0]
            deblurred = deblur_image.detach().cpu().numpy()[0] + 0.5
            psnr = compare_psnr(sharp, deblurred)

            # Clip before the uint8 cast: the network output can leave [0, 1]
            # and float -> uint8 casting of out-of-range values is not
            # guaranteed to saturate, which would corrupt the SSIM input.
            pil_real = Image.fromarray(
                (np.clip(np.transpose(sharp, (1, 2, 0)), 0, 1) *
                 255).astype(np.uint8))
            pil_fake = Image.fromarray(
                (np.clip(np.transpose(deblurred, (1, 2, 0)), 0, 1) *
                 255).astype(np.uint8))
            ssim = SSIM(pil_fake).cw_ssim_value(pil_real)

            total_psnr += psnr
            total_ssim += ssim
            num_images += 1
            print("testing...... iteration: ", iteration, ", psnr: ", psnr,
                  ", ssim: ", ssim)

    if num_images == 0:
        print("No test images were found.")
        return
    # Average over the images actually evaluated rather than a fixed count.
    print("PSNR is: ", (total_psnr / num_images))
    print("SSIM is: ", (total_ssim / num_images))
    print("Average Runtime: %.4f" % (test_time / num_images))
Esempio n. 2
0
def main():
    """Train the 4-level (1/2/4/8-patch) multi-patch deblurring network on GoPro.

    Every encoder/decoder gets its own RAdam optimizer and a StepLR schedule
    (step_size=1000, gamma=0.1) that is stepped once per epoch. Existing
    weights in ``./checkpoints/<METHOD>/`` are loaded before training.

    Every 100 epochs the test set is deblurred (results go through
    ``save_deblur_images``) and a weight snapshot is written to
    ``./checkpoints/<METHOD>/epoch<N>/``. After every epoch the latest weights
    overwrite those in ``./checkpoints/<METHOD>/``.

    Relies on module-level names defined elsewhere in the file, e.g.
    ``METHOD``, ``GPU``, ``LEARNING_RATE``, ``EPOCHS``, ``BATCH_SIZE``,
    ``IMAGE_SIZE``, ``args``, ``models``, ``weight_init``, ``RAdam``,
    ``GoProDataset`` and ``save_deblur_images``.
    """
    print("init data folders")
    torch.backends.cudnn.benchmark = True
    torch.cuda.empty_cache()
    checkpoint_dir = './checkpoints/' + METHOD
    levels = ['lv1', 'lv2', 'lv3', 'lv4']

    # Construct all encoders, then all decoders (same order as before, so the
    # random initialisation sequence is unchanged).
    encoder = {lv: models.Encoder() for lv in levels}
    decoder = {lv: models.Decoder() for lv in levels}

    # Every network keyed by its checkpoint file stem, e.g. "decoder_lv3".
    nets = {}
    for kind, group in (("encoder", encoder), ("decoder", decoder)):
        for lv in levels:
            nets[kind + "_" + lv] = group[lv]

    for net in nets.values():
        net.apply(weight_init).cuda(GPU)

    optimizers = {
        name: RAdam(net.parameters(), lr=LEARNING_RATE)
        for name, net in nets.items()
    }
    schedulers = {
        name: StepLR(optim, step_size=1000, gamma=0.1)
        for name, optim in optimizers.items()
    }

    for name, net in nets.items():
        path = checkpoint_dir + "/" + name + ".pkl"
        if os.path.exists(path):
            net.load_state_dict(torch.load(path))
            print("load " + name + " success")

    # To fine-tune only some levels, set ``requires_grad = False`` on the
    # parameters of the networks to freeze here (after loading weights).

    os.makedirs(checkpoint_dir, exist_ok=True)

    def forward(images_lv1):
        """Deblur a batch with the 4-level hierarchy.

        ``images_lv1`` is the blurred batch shifted by -0.5; it is indexed as
        (N, C, H, W). Returns the decoder_lv1 output, also in the shifted range.
        """
        H = images_lv1.size(2)
        W = images_lv1.size(3)
        h2, h4, w2 = int(H / 2), int(H / 4), int(W / 2)

        # Level 2: top and bottom halves.
        images_lv2_1 = images_lv1[:, :, 0:h2, :]
        images_lv2_2 = images_lv1[:, :, h2:H, :]
        # Level 3: quarters in top-left, top-right, bottom-left,
        # bottom-right order.
        images_lv3 = [
            images_lv2_1[:, :, :, 0:w2],
            images_lv2_1[:, :, :, w2:W],
            images_lv2_2[:, :, :, 0:w2],
            images_lv2_2[:, :, :, w2:W],
        ]

        # Level 4: each quarter is split into an upper and lower part; their
        # features are re-joined along the height axis.
        feature_lv4 = []
        for patch in images_lv3:
            upper = encoder['lv4'](patch[:, :, 0:h4, :])
            lower = encoder['lv4'](patch[:, :, h4:h2, :])
            feature_lv4.append(torch.cat((upper, lower), 2))
        feature_lv4_top = torch.cat((feature_lv4[0], feature_lv4[1]), 3)
        feature_lv4_bot = torch.cat((feature_lv4[2], feature_lv4[3]), 3)
        residual_lv4 = [decoder['lv4'](f) for f in feature_lv4]

        # Level 3: quarters refined with level-4 residuals, plus level-4
        # features as a skip connection.
        feature_lv3 = [
            encoder['lv3'](img + res)
            for img, res in zip(images_lv3, residual_lv4)
        ]
        feature_lv3_top = torch.cat(
            (feature_lv3[0], feature_lv3[1]), 3) + feature_lv4_top
        feature_lv3_bot = torch.cat(
            (feature_lv3[2], feature_lv3[3]), 3) + feature_lv4_bot
        residual_lv3_top = decoder['lv3'](feature_lv3_top)
        residual_lv3_bot = decoder['lv3'](feature_lv3_bot)

        # Level 2: halves refined with level-3 residuals.
        feature_lv2_1 = encoder['lv2'](images_lv2_1 + residual_lv3_top)
        feature_lv2_2 = encoder['lv2'](images_lv2_2 + residual_lv3_bot)
        feature_lv2 = torch.cat((feature_lv2_1, feature_lv2_2), 2) + torch.cat(
            (feature_lv3_top, feature_lv3_bot), 2)
        residual_lv2 = decoder['lv2'](feature_lv2)

        # Level 1: the whole image.
        feature_lv1 = encoder['lv1'](images_lv1 + residual_lv2) + feature_lv2
        return decoder['lv1'](feature_lv1)

    # The loss module is stateless; build it once instead of every iteration.
    mse = nn.MSELoss().cuda(GPU)

    for epoch in range(args.start_epoch, EPOCHS):
        epoch += 1  # report and checkpoint epochs 1-based

        print("Training...")
        print('lr:', schedulers['encoder_lv1'].get_lr())

        train_dataset = GoProDataset(
            blur_image_files='./datas/GoPro/train_blur_file.txt',
            sharp_image_files='./datas/GoPro/train_sharp_file.txt',
            root_dir='./datas/GoPro',
            crop=True,
            crop_size=IMAGE_SIZE,
            transform=transforms.Compose([transforms.ToTensor()]))

        train_dataloader = DataLoader(train_dataset,
                                      batch_size=BATCH_SIZE,
                                      shuffle=True,
                                      num_workers=8,
                                      pin_memory=True)
        # Start the clock now; starting from 0 made the first report
        # print seconds since the Unix epoch.
        start = time.time()

        for iteration, images in enumerate(train_dataloader):
            gt = Variable(images['sharp_image'] - 0.5).cuda(GPU)
            images_lv1 = Variable(images['blur_image'] - 0.5).cuda(GPU)

            deblur_image = forward(images_lv1)
            loss = mse(deblur_image, gt)

            for net in nets.values():
                net.zero_grad()
            loss.backward()
            for optim in optimizers.values():
                optim.step()

            if (iteration + 1) % 10 == 0:
                stop = time.time()
                print(METHOD + " epoch:", epoch, "iteration:", iteration + 1,
                      "loss:%.4f" % loss.item(), 'time:%.4f' % (stop - start))
                start = time.time()

        for scheduler in schedulers.values():
            scheduler.step(epoch)

        if epoch % 100 == 0:
            epoch_dir = checkpoint_dir + '/epoch' + str(epoch)
            os.makedirs(epoch_dir, exist_ok=True)

            print("Testing...")
            test_dataset = GoProDataset(
                blur_image_files='./datas/GoPro/test_blur_file.txt',
                sharp_image_files='./datas/GoPro/test_sharp_file.txt',
                root_dir='./datas/GoPro',
                transform=transforms.Compose([transforms.ToTensor()]))
            test_dataloader = DataLoader(test_dataset,
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=8,
                                         pin_memory=True)
            test_time = 0
            with torch.no_grad():
                for iteration, images in enumerate(test_dataloader):
                    start = time.time()
                    images_lv1 = Variable(images['blur_image'] - 0.5).cuda(GPU)
                    deblur_image = forward(images_lv1)
                    stop = time.time()
                    test_time += stop - start
                    print(
                        'RunTime:%.4f' % (stop - start),
                        '  Average Runtime:%.4f' % (test_time /
                                                    (iteration + 1)))
                    save_deblur_images(deblur_image.data + 0.5, iteration,
                                       epoch)

            # Snapshot once per test pass. Previously this ran inside the
            # per-image loop and rewrote all 8 files for every test image.
            for name, net in nets.items():
                torch.save(net.state_dict(), epoch_dir + "/" + name + ".pkl")

        for name, net in nets.items():
            torch.save(net.state_dict(), checkpoint_dir + "/" + name + ".pkl")
def main():
    """Run the 4-level network on the GoPro test set and dump intermediate outputs.

    Loads any weights found in ``./checkpoints/<METHOD>/``. For each test image
    it saves the deblurred result (``save_images``) and the residual maps
    decoded at levels 4, 3 and 2 (``save_residual_images``), so each level's
    contribution can be inspected. It also prints the per-image and running
    average forward-pass time.

    Relies on module-level names defined elsewhere in the file, e.g.
    ``METHOD``, ``EXPDIR``, ``GPU``, ``models``, ``weight_init``,
    ``GoProDataset``, ``save_residual_images`` and ``save_images``.
    """
    print("init data folders")
    levels = ['lv1', 'lv2', 'lv3', 'lv4']
    checkpoint_dir = './checkpoints/' + METHOD

    encoder = {lv: models.Encoder().apply(weight_init).cuda(GPU) for lv in levels}
    decoder = {lv: models.Decoder().apply(weight_init).cuda(GPU) for lv in levels}

    # The checkpoint file stem doubles as the log label, so each message
    # names the network that was actually loaded.
    for kind, group in (("encoder", encoder), ("decoder", decoder)):
        for lv in levels:
            name = kind + "_" + lv
            path = checkpoint_dir + "/" + name + ".pkl"
            if os.path.exists(path):
                group[lv].load_state_dict(torch.load(path))
                print("load " + name + " success")

    os.makedirs('./test_results/' + EXPDIR, exist_ok=True)

    print("Testing...")
    test_dataset = GoProDataset(
        blur_image_files='./datas/GoPro/test_blur_file.txt',
        sharp_image_files='./datas/GoPro/test_sharp_file.txt',
        root_dir='./datas/GoPro',
        transform=transforms.Compose([
            transforms.ToTensor()
        ]))
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    test_time = 0
    with torch.no_grad():
        for iteration, images in enumerate(test_dataloader):
            start = time.time()
            images_lv1 = Variable(images['blur_image'] - 0.5).cuda(GPU)
            H = images_lv1.size(2)
            W = images_lv1.size(3)
            h2, h4, w2 = int(H / 2), int(H / 4), int(W / 2)

            # Level 2: top and bottom halves.
            images_lv2_1 = images_lv1[:, :, 0:h2, :]
            images_lv2_2 = images_lv1[:, :, h2:H, :]
            # Level 3: quarters in top-left, top-right, bottom-left,
            # bottom-right order.
            images_lv3 = [
                images_lv2_1[:, :, :, 0:w2],
                images_lv2_1[:, :, :, w2:W],
                images_lv2_2[:, :, :, 0:w2],
                images_lv2_2[:, :, :, w2:W],
            ]

            # Level 4: each quarter is split into an upper and lower part;
            # their features are re-joined along the height axis.
            feature_lv4 = []
            for patch in images_lv3:
                upper = encoder['lv4'](patch[:, :, 0:h4, :])
                lower = encoder['lv4'](patch[:, :, h4:h2, :])
                feature_lv4.append(torch.cat((upper, lower), 2))
            feature_lv4_top = torch.cat((feature_lv4[0], feature_lv4[1]), 3)
            feature_lv4_bot = torch.cat((feature_lv4[2], feature_lv4[3]), 3)
            residual_lv4_parts = [decoder['lv4'](f) for f in feature_lv4]

            # Level 3: quarters refined with level-4 residuals, plus level-4
            # features as a skip connection.
            feature_lv3 = [
                encoder['lv3'](img + res)
                for img, res in zip(images_lv3, residual_lv4_parts)
            ]
            feature_lv3_top = torch.cat(
                (feature_lv3[0], feature_lv3[1]), 3) + feature_lv4_top
            feature_lv3_bot = torch.cat(
                (feature_lv3[2], feature_lv3[3]), 3) + feature_lv4_bot
            residual_lv3_top = decoder['lv3'](feature_lv3_top)
            residual_lv3_bot = decoder['lv3'](feature_lv3_bot)

            # Level 2: halves refined with level-3 residuals.
            feature_lv2_1 = encoder['lv2'](images_lv2_1 + residual_lv3_top)
            feature_lv2_2 = encoder['lv2'](images_lv2_2 + residual_lv3_bot)
            feature_lv2 = torch.cat((feature_lv2_1, feature_lv2_2), 2) + torch.cat(
                (feature_lv3_top, feature_lv3_bot), 2)
            residual_lv2 = decoder['lv2'](feature_lv2)

            feature_lv1 = encoder['lv1'](images_lv1 + residual_lv2) + feature_lv2
            deblur_image = decoder['lv1'](feature_lv1)
            # NOTE: CUDA kernels are launched asynchronously, so without
            # torch.cuda.synchronize() this may under-report GPU time.
            stop = time.time()
            test_time += stop - start
            print('RunTime:%.4f' % (stop - start), '  Average Runtime:%.4f' % (test_time / (iteration + 1)))

            # Whole-image residuals of levels 4 and 3 are only needed for
            # visualisation, so they are decoded after the timed region.
            residual_lv4 = decoder['lv4'](torch.cat((feature_lv4_top, feature_lv4_bot), 2))
            residual_lv3 = decoder['lv3'](torch.cat((feature_lv3_top, feature_lv3_bot), 2))

            # Save the per-level residual maps for inspection.
            save_residual_images(residual_lv4.data, iteration, "residual_lv4")
            save_residual_images(residual_lv3.data, iteration, "residual_lv3")
            save_residual_images(residual_lv2.data, iteration, "residual_lv2")
            save_images(deblur_image.data + 0.5, iteration)
Esempio n. 4
0
# Stage and level keys of the stacked DMPHN used by the helpers and main() below.
_STACK_STAGES = ('s1', 's2', 's3', 's4')
_STACK_LEVELS = ('lv1', 'lv2', 'lv3')


def _stack_dmphn_stage(enc, dec, x, H, W, prev_feature=None,
                       stage1_output=None):
    """Run one 1-2-4 patch stage on ``x`` and return ``(feature, output)``.

    Args:
        enc: dict 'lv1'/'lv2'/'lv3' -> encoder module for this stage.
        dec: dict 'lv1'/'lv2'/'lv3' -> decoder module for this stage.
        x: stage input; the shifted blurred image for the first stage,
            the previous stage's output for later stages.
        H, W: height/width used to split ``x`` into halves and quadrants.
        prev_feature: feature dict returned by the previous stage, added at
            every level; ``None`` for the first stage.
        stage1_output: when given, added to the level-2 residual.

    Returns:
        ``(feature, output)``: the features needed by the next stage and the
        level-1 decoder output of this stage.

    Assumes each decoder output matches the spatial size of the patch its
    features came from (the additions below require it) -- TODO confirm.
    """
    h = int(H / 2)
    w = int(W / 2)
    feature = {}

    # Level 3: four quadrant patches (top-left, top-right, bottom-left,
    # bottom-right); features re-joined along width (dim 3).
    feature['lv3_top'] = torch.cat(
        (enc['lv3'](x[:, :, 0:h, 0:w]), enc['lv3'](x[:, :, 0:h, w:W])), 3)
    feature['lv3_bot'] = torch.cat(
        (enc['lv3'](x[:, :, h:H, 0:w]), enc['lv3'](x[:, :, h:H, w:W])), 3)
    if prev_feature is not None:
        feature['lv3_top'] = feature['lv3_top'] + prev_feature['lv3_top']
        feature['lv3_bot'] = feature['lv3_bot'] + prev_feature['lv3_bot']
    residual_lv3_top = dec['lv3'](feature['lv3_top'])
    residual_lv3_bot = dec['lv3'](feature['lv3_bot'])

    # Level 2: top/bottom halves, refined by the level-3 residuals.
    feature['lv2_1'] = enc['lv2'](
        x[:, :, 0:h, :] + residual_lv3_top) + feature['lv3_top']
    feature['lv2_2'] = enc['lv2'](
        x[:, :, h:H, :] + residual_lv3_bot) + feature['lv3_bot']
    if prev_feature is not None:
        feature['lv2_1'] = feature['lv2_1'] + prev_feature['lv2_1']
        feature['lv2_2'] = feature['lv2_2'] + prev_feature['lv2_2']
    feature['lv2'] = torch.cat((feature['lv2_1'], feature['lv2_2']), 2)
    residual_lv2 = dec['lv2'](feature['lv2'])
    if stage1_output is not None:
        # NOTE(review): the original training pass added the *first* stage's
        # output here for stages s2-s4 (the original test pass did not).
        # Kept so trained weights and evaluation use the same graph; confirm
        # whether this term (vs. the previous stage's output, or nothing)
        # is intended.
        residual_lv2 = residual_lv2 + stage1_output

    # Level 1: the full-size input.
    feature['lv1'] = enc['lv1'](x + residual_lv2) + feature['lv2']
    if prev_feature is not None:
        feature['lv1'] = feature['lv1'] + prev_feature['lv1']
    return feature, dec['lv1'](feature['lv1'])


def _stack_dmphn_forward(encoder, decoder, blur):
    """Run all stacked stages on ``blur`` and return ``{stage: output}``.

    Stage s1 consumes ``blur``; every later stage consumes the previous
    stage's output plus its features. The same function is used for training
    and testing so both evaluate an identical graph.
    """
    H = blur.size(2)
    W = blur.size(3)
    outputs = {}
    prev_feature = None
    x = blur
    for s in _STACK_STAGES:
        feature, out = _stack_dmphn_stage(
            encoder[s], decoder[s], x, H, W,
            prev_feature=prev_feature,
            # None while computing s1 itself; s1's output for later stages.
            stage1_output=outputs.get('s1'))
        outputs[s] = out
        prev_feature = feature
        x = out
    return outputs


def _save_stack_checkpoints(encoder, decoder, directory):
    """Save every encoder/decoder state_dict as ``<directory>/<kind>_<s>_<lv>.pkl``."""
    for s in _STACK_STAGES:
        for lv in _STACK_LEVELS:
            torch.save(encoder[s][lv].state_dict(),
                       directory + "/encoder_" + s + "_" + lv + ".pkl")
            torch.save(decoder[s][lv].state_dict(),
                       directory + "/decoder_" + s + "_" + lv + ".pkl")


def main():
    """Train the four-stage stacked DMPHN on GoPro, periodically testing it.

    Builds one encoder/decoder pair per (stage, level), each with its own
    RAdam optimizer and StepLR(step_size=1000, gamma=0.1) schedule, and
    resumes from ``./checkpoints/<METHOD>/`` when files exist. The loss is
    the sum of the MSE between every stage's output and the sharp image
    (both shifted by -0.5). Every 100 epochs the weights are also saved to
    ``./checkpoints/<METHOD>/epoch<N>/`` and the GoPro test set is deblurred
    with the last stage's output.
    """
    print("init data folders")
    torch.backends.cudnn.benchmark = True
    torch.cuda.empty_cache()
    checkpoint_dir = './checkpoints/' + METHOD

    encoder = {}
    decoder = {}
    encoder_optim = {}
    decoder_optim = {}
    encoder_scheduler = {}
    decoder_scheduler = {}
    for s in _STACK_STAGES:
        encoder[s] = {}
        decoder[s] = {}
        encoder_optim[s] = {}
        decoder_optim[s] = {}
        encoder_scheduler[s] = {}
        decoder_scheduler[s] = {}
        for lv in _STACK_LEVELS:
            encoder[s][lv] = models.Encoder()
            decoder[s][lv] = models.Decoder()
            encoder[s][lv].apply(weight_init).cuda(GPU)
            decoder[s][lv].apply(weight_init).cuda(GPU)
            encoder_optim[s][lv] = RAdam(encoder[s][lv].parameters(),
                                         lr=LEARNING_RATE)
            decoder_optim[s][lv] = RAdam(decoder[s][lv].parameters(),
                                         lr=LEARNING_RATE)
            encoder_scheduler[s][lv] = StepLR(encoder_optim[s][lv],
                                              step_size=1000,
                                              gamma=0.1)
            decoder_scheduler[s][lv] = StepLR(decoder_optim[s][lv],
                                              step_size=1000,
                                              gamma=0.1)
            # Resume from the latest (non-epoch) checkpoints when present.
            for kind, net in (("encoder", encoder[s][lv]),
                              ("decoder", decoder[s][lv])):
                name = kind + "_" + s + "_" + lv
                path = checkpoint_dir + "/" + name + ".pkl"
                if os.path.exists(path):
                    net.load_state_dict(torch.load(path))
                    print("load " + name + " successfully!")

    # makedirs also creates ./checkpoints if it is missing (mkdir did not).
    os.makedirs(checkpoint_dir, exist_ok=True)

    mse = nn.MSELoss().cuda(GPU)
    for epoch in range(args.start_epoch, EPOCHS):
        torch.cuda.empty_cache()

        print("Training...")
        # All optimizers follow the same schedule; report the current lr of
        # one of them (param_groups holds the value actually in use).
        print('lr:', [group['lr']
                      for group in encoder_optim['s4']['lv3'].param_groups])

        train_dataset = GoProDataset(
            blur_image_files='./datas/GoPro/train_blur_file.txt',
            sharp_image_files='./datas/GoPro/train_sharp_file.txt',
            root_dir='./datas/GoPro/',
            crop=True,
            crop_size=CROP_SIZE,
            rotation=True,
            color_augment=True,
            transform=transforms.Compose([transforms.ToTensor()]))
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=BATCH_SIZE,
                                      shuffle=True,
                                      num_workers=16,
                                      pin_memory=True)
        start = time.time()
        for iteration, inputs in enumerate(train_dataloader):
            gt = Variable(inputs['sharp_image'] - 0.5).cuda(GPU)
            blur = Variable(inputs['blur_image'] - 0.5).cuda(GPU)

            outputs = _stack_dmphn_forward(encoder, decoder, blur)
            loss = (mse(outputs['s4'], gt) + mse(outputs['s3'], gt) +
                    mse(outputs['s2'], gt) + mse(outputs['s1'], gt))

            for s in _STACK_STAGES:
                for lv in _STACK_LEVELS:
                    encoder[s][lv].zero_grad()
                    decoder[s][lv].zero_grad()

            loss.backward()

            for s in _STACK_STAGES:
                for lv in _STACK_LEVELS:
                    encoder_optim[s][lv].step()
                    decoder_optim[s][lv].step()

            if (iteration + 1) % 10 == 0:
                stop = time.time()
                print(METHOD + "   epoch:", epoch, "iteration:", iteration + 1,
                      "loss:%.4f" % loss.item(), 'time:%.4f' % (stop - start))
                start = time.time()

        for s in _STACK_STAGES:
            for lv in _STACK_LEVELS:
                encoder_scheduler[s][lv].step(epoch)
                decoder_scheduler[s][lv].step(epoch)

        if epoch % 100 == 0:
            epoch_dir = checkpoint_dir + '/epoch' + str(epoch)
            os.makedirs(epoch_dir, exist_ok=True)
            _save_stack_checkpoints(encoder, decoder, epoch_dir)

            print("Testing...")
            test_dataset = GoProDataset(
                blur_image_files='./datas/GoPro/test_blur_file.txt',
                sharp_image_files='./datas/GoPro/test_sharp_file.txt',
                root_dir='./datas/GoPro/',
                transform=transforms.Compose([transforms.ToTensor()]))
            test_dataloader = DataLoader(test_dataset,
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=16,
                                         pin_memory=True)

            # Weights do not change while testing, so the epoch checkpoints
            # saved above are not rewritten per image.
            with torch.no_grad():
                for iteration, inputs in enumerate(test_dataloader):
                    blur = Variable(inputs['blur_image'] - 0.5).cuda(GPU)
                    outputs = _stack_dmphn_forward(encoder, decoder, blur)
                    deblurred_image = outputs['s4']
                    save_deblur_images(deblurred_image.data + 0.5, iteration,
                                       epoch)

        _save_stack_checkpoints(encoder, decoder, checkpoint_dir)
Esempio n. 5
0
def main():
    print("init data folders")

    psnr_list = []
    encoder_lv1 = models.Encoder()
    encoder_lv2 = models.Encoder()
    encoder_lv3 = models.Encoder()

    decoder_lv1 = models.Decoder()
    decoder_lv2 = models.Decoder()
    decoder_lv3 = models.Decoder()

    encoder_lv1.apply(weight_init).cuda(GPU)
    encoder_lv2.apply(weight_init).cuda(GPU)
    encoder_lv3.apply(weight_init).cuda(GPU)

    decoder_lv1.apply(weight_init).cuda(GPU)
    decoder_lv2.apply(weight_init).cuda(GPU)
    decoder_lv3.apply(weight_init).cuda(GPU)

    encoder_lv1_optim = torch.optim.Adam(encoder_lv1.parameters(),
                                         lr=LEARNING_RATE)
    encoder_lv1_scheduler = StepLR(encoder_lv1_optim,
                                   step_size=1000,
                                   gamma=0.1)
    encoder_lv2_optim = torch.optim.Adam(encoder_lv2.parameters(),
                                         lr=LEARNING_RATE)
    encoder_lv2_scheduler = StepLR(encoder_lv2_optim,
                                   step_size=1000,
                                   gamma=0.1)
    encoder_lv3_optim = torch.optim.Adam(encoder_lv3.parameters(),
                                         lr=LEARNING_RATE)
    encoder_lv3_scheduler = StepLR(encoder_lv3_optim,
                                   step_size=1000,
                                   gamma=0.1)

    decoder_lv1_optim = torch.optim.Adam(decoder_lv1.parameters(),
                                         lr=LEARNING_RATE)
    decoder_lv1_scheduler = StepLR(decoder_lv1_optim,
                                   step_size=1000,
                                   gamma=0.1)
    decoder_lv2_optim = torch.optim.Adam(decoder_lv2.parameters(),
                                         lr=LEARNING_RATE)
    decoder_lv2_scheduler = StepLR(decoder_lv2_optim,
                                   step_size=1000,
                                   gamma=0.1)
    decoder_lv3_optim = torch.optim.Adam(decoder_lv3.parameters(),
                                         lr=LEARNING_RATE)
    decoder_lv3_scheduler = StepLR(decoder_lv3_optim,
                                   step_size=1000,
                                   gamma=0.1)

    if os.path.exists(str('./checkpoints/' + METHOD + "/encoder_lv1.pkl")):
        encoder_lv1.load_state_dict(
            torch.load(str('./checkpoints/' + METHOD + "/encoder_lv1.pkl")))
        print("load encoder_lv1 success")
    if os.path.exists(str('./checkpoints/' + METHOD + "/encoder_lv2.pkl")):
        encoder_lv2.load_state_dict(
            torch.load(str('./checkpoints/' + METHOD + "/encoder_lv2.pkl")))
        print("load encoder_lv2 success")
    if os.path.exists(str('./checkpoints/' + METHOD + "/encoder_lv3.pkl")):
        encoder_lv3.load_state_dict(
            torch.load(str('./checkpoints/' + METHOD + "/encoder_lv3.pkl")))
        print("load encoder_lv3 success")

    if os.path.exists(str('./checkpoints/' + METHOD + "/decoder_lv1.pkl")):
        decoder_lv1.load_state_dict(
            torch.load(str('./checkpoints/' + METHOD + "/decoder_lv1.pkl")))
        print("load encoder_lv1 success")
    if os.path.exists(str('./checkpoints/' + METHOD + "/decoder_lv2.pkl")):
        decoder_lv2.load_state_dict(
            torch.load(str('./checkpoints/' + METHOD + "/decoder_lv2.pkl")))
        print("load decoder_lv2 success")
    if os.path.exists(str('./checkpoints/' + METHOD + "/decoder_lv3.pkl")):
        decoder_lv3.load_state_dict(
            torch.load(str('./checkpoints/' + METHOD + "/decoder_lv3.pkl")))
        print("load decoder_lv3 success")

    if os.path.exists('./checkpoints/' + METHOD) == False:
        os.system('mkdir ./checkpoints/' + METHOD)

    for epoch in range(args.start_epoch, EPOCHS):
        encoder_lv1_scheduler.step(epoch)
        encoder_lv2_scheduler.step(epoch)
        encoder_lv3_scheduler.step(epoch)

        decoder_lv1_scheduler.step(epoch)
        decoder_lv2_scheduler.step(epoch)
        decoder_lv3_scheduler.step(epoch)

        print("Training...")

        train_dataset = GoProDataset(
            blur_image_files='./datas/GoPro/train_blur_file.txt',
            sharp_image_files='./datas/GoPro/train_sharp_file.txt',
            root_dir='./datas/GoPro/',
            crop=True,
            crop_size=IMAGE_SIZE,
            transform=transforms.Compose([transforms.ToTensor()]))
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=BATCH_SIZE,
                                      shuffle=True)
        start = 0

        for iteration, images in enumerate(train_dataloader):
            mse = nn.MSELoss().cuda(GPU)

            gt = Variable(images['sharp_image'] - 0.5).cuda(GPU)
            H = gt.size(2)
            W = gt.size(3)

            images_lv1 = Variable(images['blur_image'] - 0.5).cuda(GPU)
            images_lv2_1 = images_lv1[:, :, 0:int(H / 2), :]
            images_lv2_2 = images_lv1[:, :, int(H / 2):H, :]
            images_lv3_1 = images_lv2_1[:, :, :, 0:int(W / 2)]
            images_lv3_2 = images_lv2_1[:, :, :, int(W / 2):W]
            images_lv3_3 = images_lv2_2[:, :, :, 0:int(W / 2)]
            images_lv3_4 = images_lv2_2[:, :, :, int(W / 2):W]

            feature_lv3_1 = encoder_lv3(images_lv3_1)
            feature_lv3_2 = encoder_lv3(images_lv3_2)
            feature_lv3_3 = encoder_lv3(images_lv3_3)
            feature_lv3_4 = encoder_lv3(images_lv3_4)
            feature_lv3_top = torch.cat((feature_lv3_1, feature_lv3_2), 3)
            feature_lv3_bot = torch.cat((feature_lv3_3, feature_lv3_4), 3)
            feature_lv3 = torch.cat((feature_lv3_top, feature_lv3_bot), 2)
            residual_lv3_top = decoder_lv3(feature_lv3_top)
            residual_lv3_bot = decoder_lv3(feature_lv3_bot)

            feature_lv2_1 = encoder_lv2(images_lv2_1 + residual_lv3_top)
            feature_lv2_2 = encoder_lv2(images_lv2_2 + residual_lv3_bot)
            feature_lv2 = torch.cat(
                (feature_lv2_1, feature_lv2_2), 2) + feature_lv3
            residual_lv2 = decoder_lv2(feature_lv2)

            feature_lv1 = encoder_lv1(images_lv1 + residual_lv2) + feature_lv2
            deblur_image = decoder_lv1(feature_lv1)

            loss_lv1 = mse(deblur_image, gt)

            loss = loss_lv1

            encoder_lv1.zero_grad()
            encoder_lv2.zero_grad()
            encoder_lv3.zero_grad()

            decoder_lv1.zero_grad()
            decoder_lv2.zero_grad()
            decoder_lv3.zero_grad()

            loss.backward()

            encoder_lv1_optim.step()
            encoder_lv2_optim.step()
            encoder_lv3_optim.step()

            decoder_lv1_optim.step()
            decoder_lv2_optim.step()
            decoder_lv3_optim.step()

            if (iteration + 1) % 50 == 0:
                stop = time.time()
                print("epoch:", epoch, "iteration:", iteration + 1,
                      "loss:%.4f" % loss.item(), 'time:%.4f' % (stop - start))
                start = time.time()

        if (epoch) % 100 == 0:
            if os.path.exists('./checkpoints/' + METHOD + '/epoch' +
                              str(epoch)) == False:
                os.system('mkdir ./checkpoints/' + METHOD + '/epoch' +
                          str(epoch))

            print("Testing...")
            test_dataset = GoProDataset(
                blur_image_files='./datas/GoPro/test_blur_file.txt',
                sharp_image_files='./datas/GoPro/test_sharp_file.txt',
                root_dir='./datas/GoPro/',
                transform=transforms.Compose([transforms.ToTensor()]))
            test_dataloader = DataLoader(test_dataset,
                                         batch_size=1,
                                         shuffle=False)
            test_time = 0.0
            total_psnr = 0
            for iteration, images in enumerate(test_dataloader):
                with torch.no_grad():
                    images_lv1 = Variable(images['blur_image'] - 0.5).cuda(GPU)
                    start = time.time()
                    H = images_lv1.size(2)
                    W = images_lv1.size(3)
                    images_lv2_1 = images_lv1[:, :, 0:int(H / 2), :]
                    images_lv2_2 = images_lv1[:, :, int(H / 2):H, :]
                    images_lv3_1 = images_lv2_1[:, :, :, 0:int(W / 2)]
                    images_lv3_2 = images_lv2_1[:, :, :, int(W / 2):W]
                    images_lv3_3 = images_lv2_2[:, :, :, 0:int(W / 2)]
                    images_lv3_4 = images_lv2_2[:, :, :, int(W / 2):W]

                    feature_lv3_1 = encoder_lv3(images_lv3_1)
                    feature_lv3_2 = encoder_lv3(images_lv3_2)
                    feature_lv3_3 = encoder_lv3(images_lv3_3)
                    feature_lv3_4 = encoder_lv3(images_lv3_4)
                    feature_lv3_top = torch.cat((feature_lv3_1, feature_lv3_2),
                                                3)
                    feature_lv3_bot = torch.cat((feature_lv3_3, feature_lv3_4),
                                                3)
                    feature_lv3 = torch.cat((feature_lv3_top, feature_lv3_bot),
                                            2)
                    residual_lv3_top = decoder_lv3(feature_lv3_top)
                    residual_lv3_bot = decoder_lv3(feature_lv3_bot)

                    feature_lv2_1 = encoder_lv2(images_lv2_1 +
                                                residual_lv3_top)
                    feature_lv2_2 = encoder_lv2(images_lv2_2 +
                                                residual_lv3_bot)
                    feature_lv2 = torch.cat(
                        (feature_lv2_1, feature_lv2_2), 2) + feature_lv3
                    residual_lv2 = decoder_lv2(feature_lv2)

                    feature_lv1 = encoder_lv1(images_lv1 +
                                              residual_lv2) + feature_lv2
                    deblur_image = decoder_lv1(feature_lv1)

                    stop = time.time()
                    test_time += stop - start
                    psnr = compare_psnr(
                        images['sharp_image'].numpy()[0],
                        deblur_image.detach().cpu().numpy()[0] + 0.5)
                    total_psnr += psnr
                    if (iteration + 1) % 50 == 0:
                        print(
                            'PSNR:%.4f' % (psnr), '  Average PSNR:%.4f' %
                            (total_psnr / (iteration + 1)))
                    save_deblur_images(deblur_image.data + 0.5, iteration,
                                       epoch)

            psnr_list.append(total_psnr / (iteration + 1))
            print("PSNR list:")
            print(psnr_list)

        torch.save(encoder_lv1.state_dict(),
                   str('./checkpoints/' + METHOD + "/encoder_lv1.pkl"))
        torch.save(encoder_lv2.state_dict(),
                   str('./checkpoints/' + METHOD + "/encoder_lv2.pkl"))
        torch.save(encoder_lv3.state_dict(),
                   str('./checkpoints/' + METHOD + "/encoder_lv3.pkl"))

        torch.save(decoder_lv1.state_dict(),
                   str('./checkpoints/' + METHOD + "/decoder_lv1.pkl"))
        torch.save(decoder_lv2.state_dict(),
                   str('./checkpoints/' + METHOD + "/decoder_lv2.pkl"))
        torch.save(decoder_lv3.state_dict(),
                   str('./checkpoints/' + METHOD + "/decoder_lv3.pkl"))