Example #1
# Shared imports for the snippets on this page; normalize, un_normalize,
# psnr, imsave, avg_msssim, FeatureExtractor and AestheticLoss are
# project-local helpers.
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import pytorch_ssim  # third-party: pip install pytorch-ssim
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.autograd import Variable
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import transforms


def bilinear_upsampling(opt, dataloader, scale):
    for batch_no, data in enumerate(dataloader['test']):
        high_img, _ = data
        inputs = torch.FloatTensor(opt.batchSize, 3, opt.imageSize,
                                   opt.imageSize)
        for j in range(opt.batchSize):
            inputs[j] = scale(high_img[j])
            high_img[j] = normalize(high_img[j])
        # F.upsample is deprecated; F.interpolate is the current API.
        outputs = F.interpolate(inputs,
                                scale_factor=opt.upSampling,
                                mode='bilinear',
                                align_corners=True)
        # Invert the ImageNet normalization before saving as PNG.
        transform = transforms.Compose([
            transforms.Normalize(mean=[-2.118, -2.036, -1.804],
                                 std=[4.367, 4.464, 4.444]),
            transforms.ToPILImage()
        ])
        transform(outputs[0]).save('output/train/bilinear_fake/' +
                                   str(batch_no) + '.png')
        transform(high_img[0]).save('output/train/bilinear_real/' +
                                    str(batch_no) + '.png')

        # for output, himg in zip(outputs, high_img):
        #     psnr_val = psnr(output, himg)
        #     mssim = avg_msssim(himg, output)
        print(psnr(un_normalize(outputs), un_normalize(high_img)))
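The normalize / un_normalize pair is project-local. The inverse constants hard-coded in the transform above (-0.485/0.229 = -2.118, 1/0.229 = 4.367, and so on) are exactly the inverse of the standard ImageNet statistics, so a minimal sketch of the pair, assuming that convention, would be:

import torch
from torchvision import transforms

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Forward normalization applied to every image tensor.
normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

def un_normalize(batch):
    # Invert `normalize` on a (N, 3, H, W) batch -- hypothetical helper,
    # not necessarily the project's actual implementation.
    mean = torch.tensor(IMAGENET_MEAN, device=batch.device).view(1, 3, 1, 1)
    std = torch.tensor(IMAGENET_STD, device=batch.device).view(1, 3, 1, 1)
    return batch * std + mean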
Example #2
def show_generated_images(dataset, net, device, show_n=5):
    # Pick show_n random samples from the dataset.
    image_idx = np.random.choice(len(dataset), show_n)
    images = []
    for idx in image_idx:
        images.append(dataset[idx])

    images = torch.stack(images).to(device)
    original_images = un_normalize(images)
    with torch.no_grad():
        generated_images = un_normalize(net(images))

    fig, axes = plt.subplots(2, len(original_images))

    # imshow expects HWC arrays on the CPU, so permute the CHW tensors.
    for i in range(len(original_images)):
        axes[0, i].imshow(original_images[i].cpu().permute(1, 2, 0))
        axes[1, i].imshow(generated_images[i].cpu().permute(1, 2, 0))

    plt.show()
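A minimal call site for this helper, assuming dataset[idx] yields a single normalized (3, H, W) tensor and net is any image-to-image model (generator and test_dataset are hypothetical names):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = generator.to(device).eval()
show_generated_images(test_dataset, net, device, show_n=5)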
Example #3
def test_multiple(generator, discriminator, opt, dataloader, scale):
    generator.load_state_dict(torch.load(opt.generatorWeights))
    discriminator.load_state_dict(torch.load(opt.discriminatorWeights))
    feature_extractor = FeatureExtractor(
        torchvision.models.vgg19(pretrained=True))

    content_criterion = nn.MSELoss()
    adversarial_criterion = nn.BCELoss()

    ones_const = Variable(torch.ones(opt.batchSize, 1))

    if opt.cuda:
        generator.cuda()
        discriminator.cuda()
        feature_extractor.cuda()
        content_criterion.cuda()
        adversarial_criterion.cuda()
        ones_const = ones_const.cuda()

    curr_time = time.time()

    mean_generator_content_loss = 0.0
    mean_generator_adversarial_loss = 0.0
    mean_generator_total_loss = 0.0
    mean_discriminator_loss = 0.0
    inputs = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)
    max_psnr = 0.0
    mean_psnr = 0.0
    min_psnr = 999.0
    mean_ssim = 0.0
    for batch_no, data in enumerate(dataloader['test']):
        high_img, _ = data
        generator.train(False)
        discriminator.train(False)
        for j in range(opt.batchSize):
            inputs[j] = scale(high_img[j])
            high_img[j] = normalize(high_img[j])

        if opt.cuda:
            high_res_real = Variable(high_img.cuda())
            high_res_fake = generator(
                Variable(inputs[0][np.newaxis, :]).cuda(),
                Variable(inputs[1][np.newaxis, :]).cuda(),
                Variable(inputs[2][np.newaxis, :]).cuda(),
                Variable(inputs[3][np.newaxis, :]).cuda())
            target_real = Variable(torch.rand(opt.batchSize, 1) * 0.5 +
                                   0.7).cuda()
            target_fake = Variable(torch.rand(opt.batchSize, 1) * 0.3).cuda()

            discriminator_loss = adversarial_criterion(
                discriminator(Variable(inputs[0][np.newaxis, :]).cuda(),
                              Variable(inputs[1][np.newaxis, :]).cuda(),
                              Variable(inputs[2][np.newaxis, :]).cuda(),
                              Variable(inputs[3][np.newaxis, :]).cuda()),
                target_real) + \
                adversarial_criterion(
                    discriminator(high_res_fake[0][np.newaxis, :],
                                  high_res_fake[1][np.newaxis, :],
                                  high_res_fake[2][np.newaxis, :],
                                  high_res_fake[3][np.newaxis, :]),
                    target_fake)
            mean_discriminator_loss += discriminator_loss.data.item()

            #high_res_fake_cat = torch.cat([image for image in high_res_fake], 0)
            fake_features = feature_extractor(high_res_fake)
            real_features = Variable(feature_extractor(high_res_real).data)

            generator_content_loss = content_criterion(
                high_res_fake, high_res_real) + 0.006 * content_criterion(
                    fake_features, real_features)
            mean_generator_content_loss += generator_content_loss.data.item()
            generator_adversarial_loss = adversarial_criterion(
                discriminator(high_res_fake[0][np.newaxis, :],
                              high_res_fake[1][np.newaxis, :],
                              high_res_fake[2][np.newaxis, :],
                              high_res_fake[3][np.newaxis, :]), ones_const)
            mean_generator_adversarial_loss += generator_adversarial_loss.data.item()

            generator_total_loss = generator_content_loss + 1e-3 * generator_adversarial_loss
            mean_generator_total_loss += generator_total_loss.data.item()
        imsave(high_res_fake.cpu().data,
               train=False,
               epoch=batch_no,
               image_type='fake')
        imsave(high_img, train=False, epoch=batch_no, image_type='real')
        imsave(inputs, train=False, epoch=batch_no, image_type='low')

        mssim = avg_msssim(high_res_real, high_res_fake)
        ssim = pytorch_ssim.ssim(high_res_fake, high_res_real).data.item()
        mean_ssim += ssim
        psnr_val = psnr(un_normalize(high_res_fake),
                        un_normalize(high_res_real))
        mean_psnr += psnr_val
        max_psnr = max(max_psnr, psnr_val)
        min_psnr = min(min_psnr, psnr_val)
        sys.stdout.write(
            '\rTesting batch no. [%d/%d] Generator_content_Loss: %.4f discriminator_loss %.4f psnr %.4f ssim %.4f'
            % (batch_no, len(dataloader['test']), generator_content_loss,
               discriminator_loss, psnr_val, ssim))
    print("Min psnr is: ", min_psnr)
    print("Mean psnr is: ", mean_psnr / 72)
    print("Max psnr is: ", max_psnr)
    print("Mean ssim is: ", mean_ssim / 72)
Example #4
def train_single(generator, discriminator, opt, dataloader, writer, scale):
    content_criterion = nn.MSELoss()
    adversarial_criterion = nn.BCELoss()

    ones_const = Variable(torch.ones(1, 1))

    if opt.cuda:
        generator.cuda()
        discriminator.cuda()
        content_criterion.cuda()
        adversarial_criterion.cuda()
        ones_const = ones_const.cuda()

    optimizer = optim.Adam(generator.parameters(), lr=opt.generatorLR)
    optim_discriminator = optim.Adam(discriminator.parameters(),
                                     lr=opt.discriminatorLR)
    scheduler_gen = ReduceLROnPlateau(optimizer,
                                      'min',
                                      factor=0.5,
                                      patience=3,
                                      verbose=True)
    scheduler_dis = ReduceLROnPlateau(optim_discriminator,
                                      'min',
                                      factor=0.5,
                                      patience=3,
                                      verbose=True)
    curr_time = time.time()

    for epoch in range(opt.nEpochs):
        mean_generator_content_loss = 0.0
        mean_generator_adversarial_loss = 0.0
        mean_generator_total_loss = 0.0
        mean_discriminator_loss = 0.0
        high_res_fake = 0
        for phase in ['train', 'test']:
            if phase == 'test':
                generator.train(False)
                discriminator.train(False)
            else:
                generator.train(True)
                discriminator.train(True)
            for batch_no, data in enumerate(dataloader[phase]):
                high_img, _ = data

                input1 = high_img[0, :, :, :]
                input2 = high_img[1, :, :, :]
                input3 = high_img[2, :, :, :]
                input4 = high_img[3, :, :, :]
                inputs = torch.FloatTensor(opt.batchSize, 3, opt.imageSize,
                                           opt.imageSize)
                # imshow(input3)
                for j in range(opt.batchSize):
                    inputs[j] = scale(high_img[j])
                    high_img[j] = normalize(high_img[j])
                high_comb = torch.cat(
                    [high_img[0], high_img[1], high_img[2], high_img[3]], 0)

                high_comb = Variable(high_comb[np.newaxis, :])
                if opt.cuda:
                    high_comb = high_comb.cuda()
                # imshow(high_comb.cpu().data)
                input_comb = torch.cat([
                    scale(input1),
                    scale(input2),
                    scale(input3),
                    scale(input4)
                ], 0)
                input_comb = input_comb[np.newaxis, :]
                if opt.cuda:
                    optimizer.zero_grad()
                    high_res_real = Variable(high_img.cuda())
                    high_res_fake = generator(Variable(input_comb).cuda())
                    target_real = Variable(torch.rand(1, 1) * 0.5 + 0.7).cuda()
                    target_fake = Variable(torch.rand(1, 1) * 0.3).cuda()

                    outputs = torch.chunk(high_res_fake, 4, 1)
                    outputs = torch.cat(
                        [outputs[0], outputs[1], outputs[2], outputs[3]], 0)

                    discriminator.zero_grad()

                    discriminator_loss = (
                        adversarial_criterion(discriminator(high_comb),
                                              target_real) +
                        adversarial_criterion(
                            discriminator(Variable(high_res_fake.data)),
                            target_fake))
                    mean_discriminator_loss += discriminator_loss.item()

                    if phase == 'train':
                        discriminator_loss.backward()
                        optim_discriminator.step()

                    generator_content_loss = content_criterion(
                        high_res_fake, high_comb)
                    mean_generator_content_loss += generator_content_loss.item()
                    generator_adversarial_loss = adversarial_criterion(
                        discriminator(high_res_fake), ones_const)
                    mean_generator_adversarial_loss += generator_adversarial_loss.item()

                    generator_total_loss = (generator_content_loss +
                                            1e-3 * generator_adversarial_loss)
                    mean_generator_total_loss += generator_total_loss.item()

                    if phase == 'train':
                        generator_total_loss.backward()
                        optimizer.step()

                    if (batch_no % 10 == 0):
                        sys.stdout.write(
                            '\rphase [%s] epoch [%d/%d] batch no. [%d/%d] Generator_content_Loss: %.4f Discriminator loss: %.4f'
                            % (phase, epoch, opt.nEpochs, batch_no,
                               len(dataloader[phase]), generator_content_loss,
                               discriminator_loss))

            # imshow(high_res_fake.cpu().data)
            scheduler_gen.step(mean_generator_total_loss /
                               len(dataloader[phase]))
            scheduler_dis.step(mean_discriminator_loss /
                               len(dataloader[phase]))
            psnr_val = psnr(un_normalize(high_res_real), un_normalize(outputs))
            # imsave(outputs.cpu().data, train=True, epoch=epoch, image_type='fake')
            # imsave(high_img, train=True, epoch=epoch, image_type='real')
            # imsave(inputs, train=True, epoch=epoch, image_type='low')

            writer.add_scalar(phase + " per epoch/generator lr",
                              optimizer.param_groups[0]['lr'], epoch)
            writer.add_scalar(phase + " per epoch/discriminator lr",
                              optim_discriminator.param_groups[0]['lr'], epoch)
            writer.add_scalar(phase + " per epoch/PSNR", psnr_val, epoch)
            if phase == 'train':
                writer.add_scalar(
                    phase + " per epoch/discriminator training loss",
                    mean_discriminator_loss, epoch)
                writer.add_scalar(phase + " per epoch/generator training loss",
                                  mean_generator_total_loss, epoch)

            writer.add_scalar("per epoch/time taken",
                              time.time() - curr_time, epoch)
            torch.save(generator.state_dict(),
                       '%s/generator_single.pth' % opt.out)
            torch.save(discriminator.state_dict(),
                       '%s/discriminator_single.pth' % opt.out)
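train_single packs four RGB frames into one 12-channel sample (torch.cat along the channel axis plus a leading batch axis) and unpacks the generator output with torch.chunk(..., 4, 1). A toy round trip of that layout, with hypothetical sizes:

import torch

frames = [torch.randn(3, 32, 32) for _ in range(4)]
packed = torch.cat(frames, 0)[None, :]   # (1, 12, 32, 32); [None, :] == np.newaxis
chunks = torch.chunk(packed, 4, 1)       # four (1, 3, 32, 32) tensors
unpacked = torch.cat(chunks, 0)          # (4, 3, 32, 32) batch again
assert unpacked.shape == (4, 3, 32, 32)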
Example #5
def train_firstmodel(generator, opt, dataloader, writer, scale):
    content_criterion = nn.MSELoss()

    ones_const = Variable(torch.ones(1, 1))

    if opt.cuda:
        generator.cuda()
        content_criterion.cuda()

    optimizer = optim.SGD(generator.parameters(), lr=opt.generatorLR)
    scheduler_gen = ReduceLROnPlateau(optimizer,
                                      'min',
                                      factor=0.5,
                                      patience=3,
                                      verbose=True)
    curr_time = time.time()

    for epoch in range(opt.nEpochs):
        mean_generator_content_loss = 0.0
        mean_generator_total_loss = 0.0

        high_res_fake = 0
        for phase in ['train', 'test']:

            if phase == 'test':
                generator.train(False)
            else:
                generator.train(True)

            for batch_no, data in enumerate(dataloader[phase]):
                high_img, _ = data

                input1 = high_img[0, :, :, :]
                input2 = high_img[1, :, :, :]
                input3 = high_img[2, :, :, :]
                input4 = high_img[3, :, :, :]
                # imshow(input3)
                for j in range(opt.batchSize):
                    high_img[j] = normalize(high_img[j])
                high_comb = torch.cat(
                    [high_img[0], high_img[1], high_img[2], high_img[3]], 0)

                high_comb = Variable(high_comb[np.newaxis, :])
                if opt.cuda:
                    high_comb = high_comb.cuda()
                # imshow(high_comb.cpu().data)
                input_comb = torch.cat([
                    scale(input1),
                    scale(input2),
                    scale(input3),
                    scale(input4)
                ], 0)
                input_comb = input_comb[np.newaxis, :]
                if opt.cuda:
                    if phase == 'train':
                        optimizer.zero_grad()
                    high_res_real = Variable(high_img.cuda())
                    high_res_fake = generator(Variable(input_comb).cuda())

                    outputs = torch.chunk(high_res_fake, 4, 1)
                    outputs = torch.cat(
                        [outputs[0], outputs[1], outputs[2], outputs[3]], 0)
                    # imshow(outputs[0])
                    generator_content_loss = content_criterion(
                        high_res_fake, high_comb)
                    mean_generator_content_loss += generator_content_loss.item()

                    generator_total_loss = generator_content_loss
                    mean_generator_total_loss += generator_total_loss.item()

                    if phase == 'train':
                        generator_total_loss.backward()
                        optimizer.step()

                    if (batch_no % 10 == 0):
                        sys.stdout.write(
                            '\rphase [%s] epoch [%d/%d] batch no. [%d/%d] Generator_content_Loss: %.4f '
                            % (phase, epoch, opt.nEpochs, batch_no,
                               len(dataloader[phase]), generator_content_loss))

            if phase == 'train':
                # imsave(outputs,train=True,epoch=epoch,image_type='fake')
                # imsave(high_img, train=True, epoch=epoch, image_type='real')
                # imsave(input_comb, train=True, epoch=epoch, image_type='low')
                writer.add_scalar(phase + " per epoch/generator lr",
                                  optimizer.param_groups[0]['lr'], epoch + 1)
                scheduler_gen.step(mean_generator_total_loss /
                                   len(dataloader[phase]))

            mssim = avg_msssim(high_res_real, outputs)
            psnr_val = psnr(un_normalize(high_res_real), un_normalize(outputs))

            writer.add_scalar(phase + " per epoch/PSNR", psnr_val, epoch + 1)
            writer.add_scalar(
                phase + " per epoch/generator loss",
                mean_generator_total_loss / len(dataloader[phase]), epoch + 1)
            writer.add_scalar("per epoch/total time taken",
                              time.time() - curr_time, epoch + 1)
            writer.add_scalar(phase + " per epoch/avg_mssim", mssim, epoch + 1)

            torch.save(generator.state_dict(),
                       '%s/generator_firstfinal.pth' % opt.out)
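The scale argument is built by the caller. Given how it is used here (high-res tensor in, normalized low-res tensor out) and the upSampling option in Example #1, a plausible definition, assuming opt.imageSize is the low-resolution side, would be:

from torchvision import transforms

# Hypothetical: downscale each high-res crop to the low-res input size,
# then apply the same ImageNet normalization used on the targets.
scale = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(opt.imageSize),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])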
Example #6
def train_multiple(generator, discriminator, opt, dataloader, writer, scale):
    feature_extractor = FeatureExtractor(
        torchvision.models.vgg19(pretrained=True))

    content_criterion = nn.MSELoss()
    adversarial_criterion = nn.BCELoss()
    aesthetic_loss = AestheticLoss()

    ones_const = Variable(torch.ones(opt.batchSize, 1))

    if opt.cuda:
        generator.cuda()
        discriminator.cuda()
        feature_extractor.cuda()
        content_criterion.cuda()
        adversarial_criterion.cuda()
        ones_const = ones_const.cuda()

    optimizer = optim.Adam(generator.parameters(), lr=opt.generatorLR)
    optim_discriminator = optim.Adam(
        discriminator.parameters(), lr=opt.discriminatorLR)
    scheduler_gen = ReduceLROnPlateau(
        optimizer, 'min', factor=0.7, patience=10, verbose=True)
    scheduler_dis = ReduceLROnPlateau(
        optim_discriminator, 'min', factor=0.7, patience=10, verbose=True)
    curr_time = time.time()
    inputs = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)

    # pretraining
    for epoch in range(2):
        mean_generator_content_loss = 0.0

        inputs = torch.FloatTensor(
            opt.batchSize, 3, opt.imageSize, opt.imageSize)

        for batch_no, data in enumerate(dataloader['train']):
            high_img, _ = data
            # save_image(high_img, "test.png")
            # time.sleep(10)

            for j in range(opt.batchSize):
                inputs[j] = scale(high_img[j])
                high_img[j] = normalize(high_img[j])

            # print(high_img[0].shape)
            # print(inputs[0].shape)
            # time.sleep(10)

            if opt.cuda:
                optimizer.zero_grad()
                high_res_real = Variable(high_img.cuda())
                high_res_fake = generator(
                    Variable(inputs[0][np.newaxis, :]).cuda(),
                    Variable(inputs[1][np.newaxis, :]).cuda(),
                    Variable(inputs[2][np.newaxis, :]).cuda(),
                    Variable(inputs[3][np.newaxis, :]).cuda())

                generator_content_loss = content_criterion(
                    high_res_fake, high_res_real)

                mean_generator_content_loss += generator_content_loss.data.item()

                generator_content_loss.backward()
                optimizer.step()

                sys.stdout.write(
                    '\r[%d/%d][%d/%d] Generator_MSE_Loss: %.4f' %
                    (epoch, 2, batch_no, len(dataloader['train']),
                     generator_content_loss.data.item()))

    # training
    for epoch in range(opt.nEpochs):
        for phase in ['train', 'test']:
            if phase == 'test':
                generator.train(False)
                discriminator.train(False)
            else:
                generator.train(True)
                discriminator.train(True)

            mean_generator_content_loss = 0.0
            mean_generator_adversarial_loss = 0.0
            mean_generator_total_loss = 0.0
            mean_discriminator_loss = 0.0
            # mean_psnr = 0.0
            # mean_msssim = 0.0
            high_img = torch.FloatTensor(
                opt.batchSize, 3, opt.imageSize, opt.imageSize)
            inputs = torch.FloatTensor(
                opt.batchSize, 3, opt.imageSize, opt.imageSize)

            for batch_no, data in enumerate(dataloader[phase]):
                high_img, _ = data

                for j in range(opt.batchSize):
                    inputs[j] = scale(high_img[j])
                    high_img[j] = normalize(high_img[j])

                if opt.cuda:
                    optimizer.zero_grad()
                    high_res_real = Variable(high_img.cuda())
                    high_res_fake = generator(
                        Variable(inputs[0][np.newaxis, :]).cuda(),
                        Variable(inputs[1][np.newaxis, :]).cuda(),
                        Variable(inputs[2][np.newaxis, :]).cuda(),
                        Variable(inputs[3][np.newaxis, :]).cuda())

                    # save_image(high_res_real, "REAL.png")
                    # save_image(high_res_fake, "FAKE.png")

                    target_real = Variable(torch.rand(
                        opt.batchSize, 1) * 0.5 + 0.7).cuda()
                    target_fake = Variable(torch.rand(
                        opt.batchSize, 1) * 0.3).cuda()

                    discriminator.zero_grad()

                    discriminator_loss = adversarial_criterion(
                        discriminator(Variable(inputs[0][np.newaxis, :]).cuda(),
                                      Variable(inputs[1][np.newaxis, :]).cuda(),
                                      Variable(inputs[2][np.newaxis, :]).cuda(),
                                      Variable(inputs[3][np.newaxis, :]).cuda()),
                        target_real) + \
                        adversarial_criterion(
                            discriminator(high_res_fake[0][np.newaxis, :],
                                          high_res_fake[1][np.newaxis, :],
                                          high_res_fake[2][np.newaxis, :],
                                          high_res_fake[3][np.newaxis, :]),
                            target_fake)
                    mean_discriminator_loss += discriminator_loss.data.item()

                    if phase == 'train':
                        discriminator_loss.backward(retain_graph=True)
                        optim_discriminator.step()

                    #high_res_fake_cat = torch.cat([ image for image in high_res_fake ], 0)
                    fake_features = feature_extractor(high_res_fake)
                    real_features = Variable(
                        feature_extractor(high_res_real).data)

                    # generator_content_loss = content_criterion(high_res_fake, high_res_real) + 0.006*content_criterion(fake_features, real_features)
                    generator_content_loss = (
                        content_criterion(high_res_fake, high_res_real) +
                        content_criterion(fake_features, real_features))
                    mean_generator_content_loss += generator_content_loss.data.item()
                    generator_adversarial_loss = adversarial_criterion(
                        discriminator(high_res_fake[0][np.newaxis, :],
                                      high_res_fake[1][np.newaxis, :],
                                      high_res_fake[2][np.newaxis, :],
                                      high_res_fake[3][np.newaxis, :]),
                        ones_const)
                    mean_generator_adversarial_loss += generator_adversarial_loss.data.item()

                    generator_total_loss = generator_content_loss + 1e-3 * generator_adversarial_loss
                    mean_generator_total_loss += generator_total_loss.data.item()

                    if phase == 'train':
                        generator_total_loss.backward()
                        optimizer.step()

                    if batch_no % 10 == 0:
                        sys.stdout.write(
                            '\rphase [%s] epoch [%d/%d] batch no. [%d/%d] '
                            'Generator_content_Loss: %.4f discriminator_loss %.4f'
                            % (phase, epoch, opt.nEpochs, batch_no,
                               len(dataloader[phase]), generator_content_loss,
                               discriminator_loss))

            if phase == 'train':
                imsave(high_res_fake.cpu().data, train=True,
                       epoch=epoch, image_type='fake')
                imsave(high_img, train=True, epoch=epoch, image_type='real')
                imsave(inputs, train=True, epoch=epoch, image_type='low')
                writer.add_scalar(phase + " per epoch/generator lr",
                                  optimizer.param_groups[0]['lr'], epoch + 1)
                writer.add_scalar(phase + " per epoch/discriminator lr", optim_discriminator.param_groups[0]['lr'],
                                  epoch + 1)
                scheduler_gen.step(
                    mean_generator_total_loss / len(dataloader[phase]))
                scheduler_dis.step(mean_discriminator_loss /
                                   len(dataloader[phase]))
            else:
                imsave(high_res_fake.cpu().data, train=False,
                       epoch=epoch, image_type='fake')
                imsave(high_img, train=False, epoch=epoch, image_type='real')
                imsave(inputs, train=False, epoch=epoch, image_type='low')
            mssim = avg_msssim(high_res_real, high_res_fake)
            psnr_val = psnr(un_normalize(high_res_real),
                            un_normalize(high_res_fake))

            writer.add_scalar(phase + " per epoch/PSNR", psnr_val,
                              epoch + 1)
            writer.add_scalar(phase+" per epoch/discriminator loss",
                              mean_discriminator_loss/len(dataloader[phase]), epoch+1)
            writer.add_scalar(phase+" per epoch/generator loss",
                              mean_generator_total_loss/len(dataloader[phase]), epoch+1)
            writer.add_scalar("per epoch/total time taken",
                              time.time()-curr_time, epoch+1)
            writer.add_scalar(phase+" per epoch/avg_mssim", mssim, epoch+1)
        # Do checkpointing
        torch.save(generator.state_dict(), '%s/generator_final.pth' % opt.out)
        torch.save(discriminator.state_dict(),
                   '%s/discriminator_final.pth' % opt.out)
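FeatureExtractor wraps VGG19 for the perceptual term of the content loss. A common SRGAN-style sketch, assuming the project truncates vgg19.features at some layer (the cutoff index here is an assumption, not the project's actual value):

import torch.nn as nn

class FeatureExtractor(nn.Module):
    # Hypothetical: return activations from the first `layer` modules
    # of a CNN such as torchvision's VGG19.
    def __init__(self, cnn, layer=31):
        super().__init__()
        self.features = nn.Sequential(*list(cnn.features.children())[:layer])

    def forward(self, x):
        return self.features(x)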
Example #7
def test_firstmodel(generator, opt, dataloader, writer, scale):
    content_criterion = nn.MSELoss()

    ones_const = Variable(torch.ones(1, 1))

    if opt.cuda:
        generator.cuda()
        content_criterion.cuda()

    curr_time = time.time()

    for epoch in range(opt.nEpochs):
        mean_generator_content_loss = 0.0
        mean_generator_total_loss = 0.0

        high_res_fake = 0

        for batch_no, data in enumerate(dataloader['test']):
            high_img, _ = data
            generator.train(False)

            input1 = high_img[0, :, :, :]
            input2 = high_img[1, :, :, :]
            input3 = high_img[2, :, :, :]
            input4 = high_img[3, :, :, :]
            # imshow(input3)

            for j in range(opt.batchSize):
                high_img[j] = normalize(high_img[j])
            high_comb = torch.cat(
                [high_img[0], high_img[1], high_img[2], high_img[3]], 0)

            high_comb = Variable(high_comb[np.newaxis, :])
            if opt.cuda:
                high_comb = high_comb.cuda()
            # imshow(high_comb.cpu().data)
            input_comb = torch.cat(
                [scale(input1),
                 scale(input2),
                 scale(input3),
                 scale(input4)], 0)
            input_comb = input_comb[np.newaxis, :]

            if opt.cuda:
                high_res_real = Variable(high_img.cuda())
                high_res_fake = generator(Variable(input_comb).cuda())

                outputs = torch.chunk(high_res_fake, 4, 1)
                outputs = torch.cat(
                    [outputs[0], outputs[1], outputs[2], outputs[3]], 0)
                # imshow(outputs[0])
                generator_content_loss = content_criterion(
                    high_res_fake, high_comb)
                mean_generator_content_loss += generator_content_loss.item()

                generator_total_loss = generator_content_loss
                mean_generator_total_loss += generator_total_loss.item()

                if batch_no % 10 == 0:
                    sys.stdout.write(
                        '\r epoch [%d/%d] batch no. [%d/%d] Generator_content_Loss: %.4f '
                        % (epoch, opt.nEpochs, batch_no,
                           len(dataloader['test']), generator_content_loss))

        mssim = avg_msssim(high_res_real, outputs)
        psnr_val = psnr(un_normalize(high_res_real), un_normalize(outputs))

        writer.add_scalar("test per epoch/PSNR", psnr_val, epoch + 1)
        # writer.add_scalar(phase+" per epoch/discriminator loss", mean_discriminator_loss/len(dataloader[phase]), epoch+1)
        writer.add_scalar("test per epoch/generator loss",
                          mean_generator_total_loss / len(dataloader[phase]),
                          epoch + 1)
        writer.add_scalar("per epoch/total time taken",
                          time.time() - curr_time, epoch + 1)
        writer.add_scalar("test per epoch/avg_mssim", mssim, epoch + 1)

        torch.save(generator.state_dict(),
                   '%s/generator_firstfinal.pth' % opt.out)
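avg_msssim is also project-local. A hedged stand-in using the third-party pytorch-msssim package (assumptions: inputs are normalized network tensors, and the images are large enough for the multi-scale pyramid):

from pytorch_msssim import ms_ssim  # pip install pytorch-msssim

def avg_msssim(real, fake):
    # Hypothetical replacement: batch-mean MS-SSIM on [0, 1] images.
    return ms_ssim(un_normalize(real), un_normalize(fake),
                   data_range=1.0).item()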
Example #8
def test_single(generator, discriminator, opt, dataloader, scale):
    generator.load_state_dict(torch.load(opt.generatorWeights))
    discriminator.load_state_dict(torch.load(opt.discriminatorWeights))

    content_criterion = nn.MSELoss()
    adversarial_criterion = nn.BCELoss()

    ones_const = Variable(torch.ones(1, 1))

    if opt.cuda:
        generator.cuda()
        discriminator.cuda()
        content_criterion.cuda()
        adversarial_criterion.cuda()
        ones_const = ones_const.cuda()

    curr_time = time.time()

    # for epoch in range(opt.nEpochs):
    mean_generator_content_loss = 0.0
    mean_generator_adversarial_loss = 0.0
    mean_generator_total_loss = 0.0
    mean_discriminator_loss = 0.0
    high_res_fake = 0
    for batch_no, data in enumerate(dataloader['test']):
        high_img, _ = data
        generator.train(False)
        discriminator.train(False)

        print("batch no. {} shape of input {}".format(batch_no,
                                                      high_img.shape))
        input1 = high_img[0, :, :, :]
        input2 = high_img[1, :, :, :]
        input3 = high_img[2, :, :, :]
        input4 = high_img[3, :, :, :]
        inputs = torch.FloatTensor(opt.batchSize, 3, opt.imageSize,
                                   opt.imageSize)
        # imshow(input3)
        for j in range(opt.batchSize):
            inputs[j] = scale(high_img[j])
            high_img[j] = normalize(high_img[j])
        high_comb = torch.cat(
            [high_img[0], high_img[1], high_img[2], high_img[3]], 0)

        high_comb = Variable(high_comb[np.newaxis, :])
        if opt.cuda:
            high_comb = high_comb.cuda()
        # imshow(high_comb.cpu().data)
        input_comb = torch.cat(
            [scale(input1),
             scale(input2),
             scale(input3),
             scale(input4)], 0)
        # inputs = [scale(input1), scale(input2), scale(input3), scale(input4)]
        input_comb = input_comb[np.newaxis, :]
        if opt.cuda:
            # optimizer.zero_grad()
            high_res_real = Variable(high_img.cuda())
            high_res_fake = generator(Variable(input_comb).cuda())
            target_real = Variable(torch.rand(1, 1) * 0.5 + 0.7).cuda()
            target_fake = Variable(torch.rand(1, 1) * 0.3).cuda()

            outputs = torch.chunk(high_res_fake, 4, 1)
            outputs = torch.cat(
                [outputs[0], outputs[1], outputs[2], outputs[3]], 0)
            # imshow(outputs[0])
            generator_content_loss = content_criterion(high_res_fake,
                                                       high_comb)
            mean_generator_content_loss += generator_content_loss.item()
            generator_adversarial_loss = adversarial_criterion(
                discriminator(high_res_fake), ones_const)
            mean_generator_adversarial_loss += generator_adversarial_loss.item()

            generator_total_loss = (generator_content_loss +
                                    1e-3 * generator_adversarial_loss)
            mean_generator_total_loss += generator_total_loss.item()

            discriminator_loss = (
                adversarial_criterion(discriminator(high_comb), target_real) +
                adversarial_criterion(
                    discriminator(Variable(high_res_fake.data)), target_fake))
            mean_discriminator_loss += discriminator_loss.item()
            psnr_val = psnr(un_normalize(outputs), un_normalize(high_res_real))
            print(psnr_val)

            imsave(outputs.cpu().data,
                   train=False,
                   epoch=batch_no,
                   image_type='fake')
            imsave(high_img, train=False, epoch=batch_no, image_type='real')
            imsave(inputs, train=False, epoch=batch_no, image_type='low')
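All of these entry points expect an argparse-style opt namespace. A minimal driver sketch, with hypothetical defaults, using only option names that appear in the snippets above:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--batchSize', type=int, default=4)
parser.add_argument('--imageSize', type=int, default=32)
parser.add_argument('--upSampling', type=int, default=4)
parser.add_argument('--nEpochs', type=int, default=100)
parser.add_argument('--generatorLR', type=float, default=1e-4)
parser.add_argument('--discriminatorLR', type=float, default=1e-4)
parser.add_argument('--generatorWeights', default='checkpoints/generator_final.pth')
parser.add_argument('--discriminatorWeights', default='checkpoints/discriminator_final.pth')
parser.add_argument('--out', default='checkpoints')
parser.add_argument('--cuda', action='store_true')
opt = parser.parse_args()

# e.g. test_single(generator, discriminator, opt, dataloader, scale)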