Example #1
def test_multiple(generator, discriminator, opt, dataloader, scale):
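    """Evaluate a trained generator/discriminator pair on the held-out test
    split, reporting content/adversarial losses plus PSNR and SSIM per batch.

    Note: the four-frame call signatures below assume opt.batchSize == 4.
    """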
    generator.load_state_dict(torch.load(opt.generatorWeights))
    discriminator.load_state_dict(torch.load(opt.discriminatorWeights))
    feature_extractor = FeatureExtractor(
        torchvision.models.vgg19(pretrained=True))

    content_criterion = nn.MSELoss()
    adversarial_criterion = nn.BCELoss()

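    # Constant "real" labels: the generator is rewarded when the
    # discriminator scores its outputs as real.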
    ones_const = Variable(torch.ones(opt.batchSize, 1))

    if opt.cuda:
        generator.cuda()
        discriminator.cuda()
        feature_extractor.cuda()
        content_criterion.cuda()
        adversarial_criterion.cuda()
        ones_const = ones_const.cuda()

    inputs = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)
    high_img = torch.FloatTensor(opt.batchSize, 3, opt.imageSize,
                                 opt.imageSize)
    high_res_fake = torch.FloatTensor(opt.batchSize, 3, opt.imageSize,
                                      opt.imageSize)

    mean_generator_content_loss = 0.0
    mean_generator_adversarial_loss = 0.0
    mean_generator_total_loss = 0.0
    mean_discriminator_loss = 0.0
    mean_msssim = 0.0
    max_psnr = 0.0
    mean_psnr = 0.0
    min_psnr = 999.0
    mean_ssim = 0.0
    for batch_no, data in enumerate(dataloader['test']):
        high_img, _ = data
        generator.train(False)
        discriminator.train(False)
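        # Downscale each frame to build the low-res input; normalize the
        # high-res frame in place to match the generator's training targets.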
        for j in range(opt.batchSize):
            inputs[j] = scale(high_img[j])
            high_img[j] = normalize(high_img[j])

        if opt.cuda:
            high_res_real = Variable(high_img.cuda())
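            # The generator takes the four low-res frames as separate
            # arguments (hence the per-frame unsqueeze via np.newaxis).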
            high_res_fake = generator(
                Variable(inputs[0][np.newaxis, :]).cuda(),
                Variable(inputs[1][np.newaxis, :]).cuda(),
                Variable(inputs[2][np.newaxis, :]).cuda(),
                Variable(inputs[3][np.newaxis, :]).cuda())
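            # Soft targets (one-sided label smoothing): real labels are drawn
            # from [0.7, 1.2), fake labels from [0.0, 0.3).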
            target_real = Variable(torch.rand(opt.batchSize, 1) * 0.5 +
                                   0.7).cuda()
            target_fake = Variable(torch.rand(opt.batchSize, 1) * 0.3).cuda()

            discriminator_loss = adversarial_criterion(
                discriminator(Variable(inputs[0][np.newaxis, :]).cuda(),
                              Variable(inputs[1][np.newaxis, :]).cuda(),
                              Variable(inputs[2][np.newaxis, :]).cuda(),
                              Variable(inputs[3][np.newaxis, :]).cuda()),
                target_real) + \
                adversarial_criterion(
                    discriminator(high_res_fake[0][np.newaxis, :],
                                  high_res_fake[1][np.newaxis, :],
                                  high_res_fake[2][np.newaxis, :],
                                  high_res_fake[3][np.newaxis, :]),
                    target_fake)
            mean_discriminator_loss += discriminator_loss.data.item()

            #high_res_fake_cat = torch.cat([image for image in high_res_fake], 0)
            fake_features = feature_extractor(high_res_fake)
            real_features = Variable(feature_extractor(high_res_real).data)

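            # SRGAN-style perceptual loss: pixel MSE plus VGG-feature MSE
            # weighted by 0.006 (~1/12.75, the rescaling used in SRGAN).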
            generator_content_loss = content_criterion(
                high_res_fake, high_res_real) + 0.006 * content_criterion(
                    fake_features, real_features)
            mean_generator_content_loss += generator_content_loss.data.item()
            generator_adversarial_loss = adversarial_criterion(
                discriminator(high_res_fake[0][np.newaxis, :],
                              high_res_fake[1][np.newaxis, :],
                              high_res_fake[2][np.newaxis, :],
                              high_res_fake[3][np.newaxis, :]), ones_const)
            mean_generator_adversarial_loss += generator_adversarial_loss.data.item()

            generator_total_loss = generator_content_loss + 1e-3 * generator_adversarial_loss
            mean_generator_total_loss += generator_total_loss.data.item()
        imsave(high_res_fake.cpu().data,
               train=False,
               epoch=batch_no,
               image_type='fake')
        imsave(high_img, train=False, epoch=batch_no, image_type='real')
        imsave(inputs, train=False, epoch=batch_no, image_type='low')

        mssim = avg_msssim(high_res_real, high_res_fake)
        ssim = pytorch_ssim.ssim(high_res_fake, high_res_real).data.item()
        mean_ssim += ssim
        psnr_val = psnr(un_normalize(high_res_fake),
                        un_normalize(high_res_real))
        mean_psnr += psnr_val
        max_psnr = psnr_val if psnr_val > max_psnr else max_psnr
        min_psnr = psnr_val if psnr_val < min_psnr else min_psnr
        sys.stdout.write(
            '\rTesting batch no. [%d/%d] Generator_content_Loss: %.4f discriminator_loss %.4f psnr %.4f ssim %.4f'
            % (batch_no, len(dataloader['test']), generator_content_loss,
               discriminator_loss, psnr_val, ssim))
    print("Min psnr is: ", min_psnr)
    print("Mean psnr is: ", mean_psnr / 72)
    print("Max psnr is: ", max_psnr)
    print("Mean ssim is: ", mean_ssim / 72)
def train_multiple(generator, discriminator, opt, dataloader, writer, scale):
    feature_extractor = FeatureExtractor(
        torchvision.models.vgg19(pretrained=True))

    content_criterion = nn.MSELoss()
    adversarial_criterion = nn.BCELoss()
    aesthetic_loss = AestheticLoss()

    ones_const = Variable(torch.ones(opt.batchSize, 1))

    if opt.cuda:
        generator.cuda()
        discriminator.cuda()
        feature_extractor.cuda()
        content_criterion.cuda()
        adversarial_criterion.cuda()
        ones_const = ones_const.cuda()

    optimizer = optim.Adam(generator.parameters(), lr=opt.generatorLR)
    optim_discriminator = optim.Adam(
        discriminator.parameters(), lr=opt.discriminatorLR)
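    # Cut each learning rate by 30% when its mean epoch loss plateaus
    # for 10 epochs.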
    scheduler_gen = ReduceLROnPlateau(
        optimizer, 'min', factor=0.7, patience=10, verbose=True)
    scheduler_dis = ReduceLROnPlateau(
        optim_discriminator, 'min', factor=0.7, patience=10, verbose=True)
    curr_time = time.time()
    inputs = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)

    # pretraining
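    # Warm-start: train the generator alone with plain MSE for two epochs
    # before adversarial training begins.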
    for epoch in range(2):
        mean_generator_content_loss = 0.0

        inputs = torch.FloatTensor(
            opt.batchSize, 3, opt.imageSize, opt.imageSize)

        for batch_no, data in enumerate(dataloader['train']):
            high_img, _ = data
            # save_image(high_img, "test.png")
            # time.sleep(10)

            for j in range(opt.batchSize):
                inputs[j] = scale(high_img[j])
                high_img[j] = normalize(high_img[j])

            # print(high_img[0].shape)
            # print(inputs[0].shape)
            # time.sleep(10)

            if opt.cuda:
                optimizer.zero_grad()
                high_res_real = Variable(high_img.cuda())
                high_res_fake = generator(
                    Variable(inputs[0][np.newaxis, :]).cuda(),
                    Variable(inputs[1][np.newaxis, :]).cuda(),
                    Variable(inputs[2][np.newaxis, :]).cuda(),
                    Variable(inputs[3][np.newaxis, :]).cuda())

                generator_content_loss = content_criterion(
                    high_res_fake, high_res_real)

                mean_generator_content_loss += generator_content_loss.data.item()

                generator_content_loss.backward()
                optimizer.step()

                sys.stdout.write('\r[%d/%d][%d/%d] Generator_MSE_Loss: %.4f' % (
                    epoch, 2, batch_no, len(dataloader['train']), generator_content_loss.data.item()))

    # training
    for epoch in range(opt.nEpochs):
        for phase in ['train', 'test']:
            if phase == 'test':
                generator.train(False)
                discriminator.train(False)
            else:
                generator.train(True)
                discriminator.train(True)

            mean_generator_content_loss = 0.0
            mean_generator_adversarial_loss = 0.0
            mean_generator_total_loss = 0.0
            mean_discriminator_loss = 0.0
            # mean_psnr = 0.0
            # mean_msssim = 0.0
            high_img = torch.FloatTensor(
                opt.batchSize, 3, opt.imageSize, opt.imageSize)
            inputs = torch.FloatTensor(
                opt.batchSize, 3, opt.imageSize, opt.imageSize)

            for batch_no, data in enumerate(dataloader[phase]):
                high_img, _ = data

                for j in range(opt.batchSize):
                    inputs[j] = scale(high_img[j])
                    high_img[j] = normalize(high_img[j])

                if opt.cuda:
                    optimizer.zero_grad()
                    high_res_real = Variable(high_img.cuda())
                    high_res_fake = generator(
                        Variable(inputs[0][np.newaxis, :]).cuda(),
                        Variable(inputs[1][np.newaxis, :]).cuda(),
                        Variable(inputs[2][np.newaxis, :]).cuda(),
                        Variable(inputs[3][np.newaxis, :]).cuda())

                    # save_image(high_res_real, "REAL.png")
                    # save_image(high_res_fake, "FAKE.png")

                    target_real = Variable(torch.rand(
                        opt.batchSize, 1) * 0.5 + 0.7).cuda()
                    target_fake = Variable(torch.rand(
                        opt.batchSize, 1) * 0.3).cuda()

                    discriminator.zero_grad()

                    discriminator_loss = adversarial_criterion(
                        discriminator(Variable(inputs[0][np.newaxis, :]).cuda(),
                                      Variable(inputs[1][np.newaxis, :]).cuda(),
                                      Variable(inputs[2][np.newaxis, :]).cuda(),
                                      Variable(inputs[3][np.newaxis, :]).cuda()),
                        target_real) + \
                        adversarial_criterion(
                            discriminator(high_res_fake[0][np.newaxis, :],
                                          high_res_fake[1][np.newaxis, :],
                                          high_res_fake[2][np.newaxis, :],
                                          high_res_fake[3][np.newaxis, :]),
                            target_fake)
                    mean_discriminator_loss += discriminator_loss.data.item()

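                    # retain_graph=True keeps the forward graph alive so the
                    # generator losses below can backprop through the same
                    # high_res_fake.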
                    if phase == 'train':
                        discriminator_loss.backward(retain_graph=True)
                        optim_discriminator.step()

                    #high_res_fake_cat = torch.cat([ image for image in high_res_fake ], 0)
                    fake_features = feature_extractor(high_res_fake)
                    real_features = Variable(
                        feature_extractor(high_res_real).data)

                    # generator_content_loss = content_criterion(high_res_fake, high_res_real) + 0.006*content_criterion(fake_features, real_features)
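                    # The feature term is left unweighted here; the
                    # 0.006-weighted SRGAN variant is kept commented out above.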
                    generator_content_loss = content_criterion(high_res_fake,
                                                               high_res_real) + content_criterion(fake_features,
                                                                                                  real_features)
                    mean_generator_content_loss += generator_content_loss.data.item()
                    generator_adversarial_loss = adversarial_criterion(
                        discriminator(high_res_fake[0][np.newaxis, :],
                                      high_res_fake[1][np.newaxis, :],
                                      high_res_fake[2][np.newaxis, :],
                                      high_res_fake[3][np.newaxis, :]),
                        ones_const)
                    mean_generator_adversarial_loss += generator_adversarial_loss.data.item()

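                    # SRGAN total objective: content loss plus
                    # 1e-3 * adversarial loss.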
                    generator_total_loss = generator_content_loss + 1e-3 * generator_adversarial_loss
                    mean_generator_total_loss += generator_total_loss.data.item()

                    if phase == 'train':
                        generator_total_loss.backward()
                        optimizer.step()

                    if batch_no % 10 == 0:
                        # print("phase {} batch no. {} generator_content_loss {} discriminator_loss {}".format(phase, batch_no, generator_content_loss, discriminator_loss))
                        sys.stdout.write('\rphase [%s] epoch [%d/%d] batch no. [%d/%d] Generator_content_Loss: %.4f discriminator_loss %.4f' % (
                            phase, epoch, opt.nEpochs, batch_no, len(dataloader[phase]), generator_content_loss, discriminator_loss))

            if phase == 'train':
                imsave(high_res_fake.cpu().data, train=True,
                       epoch=epoch, image_type='fake')
                imsave(high_img, train=True, epoch=epoch, image_type='real')
                imsave(inputs, train=True, epoch=epoch, image_type='low')
                writer.add_scalar(phase + " per epoch/generator lr",
                                  optimizer.param_groups[0]['lr'], epoch + 1)
                writer.add_scalar(phase + " per epoch/discriminator lr", optim_discriminator.param_groups[0]['lr'],
                                  epoch + 1)
                scheduler_gen.step(
                    mean_generator_total_loss / len(dataloader[phase]))
                scheduler_dis.step(mean_discriminator_loss /
                                   len(dataloader[phase]))
            else:
                imsave(high_res_fake.cpu().data, train=False,
                       epoch=epoch, image_type='fake')
                imsave(high_img, train=False, epoch=epoch, image_type='real')
                imsave(inputs, train=False, epoch=epoch, image_type='low')
            # import ipdb;
            # ipdb.set_trace()
            mssim = avg_msssim(high_res_real, high_res_fake)
            psnr_val = psnr(un_normalize(high_res_real),
                            un_normalize(high_res_fake))

            writer.add_scalar(phase + " per epoch/PSNR", psnr_val,
                              epoch + 1)
            writer.add_scalar(phase+" per epoch/discriminator loss",
                              mean_discriminator_loss/len(dataloader[phase]), epoch+1)
            writer.add_scalar(phase+" per epoch/generator loss",
                              mean_generator_total_loss/len(dataloader[phase]), epoch+1)
            writer.add_scalar("per epoch/total time taken",
                              time.time()-curr_time, epoch+1)
            writer.add_scalar(phase+" per epoch/avg_mssim", mssim, epoch+1)
        # Do checkpointing
        torch.save(generator.state_dict(), '%s/generator_final.pth' % opt.out)
        torch.save(discriminator.state_dict(),
                   '%s/discriminator_final.pth' % opt.out)

def train_firstmodel(generator, opt, dataloader, writer, scale):
    content_criterion = nn.MSELoss()

    ones_const = Variable(torch.ones(1, 1))

    if opt.cuda:
        generator.cuda()
        content_criterion.cuda()

    optimizer = optim.SGD(generator.parameters(), lr=opt.generatorLR)
    scheduler_gen = ReduceLROnPlateau(optimizer,
                                      'min',
                                      factor=0.5,
                                      patience=3,
                                      verbose=True)
    curr_time = time.time()

    for epoch in range(opt.nEpochs):
        mean_generator_content_loss = 0.0
        mean_generator_total_loss = 0.0

        high_res_fake = 0
        for phase in ['train', 'test']:

            if phase == 'test':
                generator.train(False)
            else:
                generator.train(True)

            for batch_no, data in enumerate(dataloader[phase]):
                high_img, _ = data

                input1 = high_img[0, :, :, :]
                input2 = high_img[1, :, :, :]
                input3 = high_img[2, :, :, :]
                input4 = high_img[3, :, :, :]
                # imshow(input3)
                for j in range(opt.batchSize):
                    high_img[j] = normalize(high_img[j])
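                # Concatenate the four 3-channel frames along the channel
                # axis, giving the model a single 12-channel tensor.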
                high_comb = torch.cat(
                    [high_img[0], high_img[1], high_img[2], high_img[3]], 0)

                high_comb = Variable(high_comb[np.newaxis, :]).cuda()
                # imshow(high_comb.cpu().data)
                input_comb = torch.cat([
                    scale(input1),
                    scale(input2),
                    scale(input3),
                    scale(input4)
                ], 0)
                input_comb = input_comb[np.newaxis, :]
                if opt.cuda:
                    if phase == 'train':
                        optimizer.zero_grad()
                    high_res_real = Variable(high_img.cuda())
                    high_res_fake = generator(Variable(input_comb).cuda())

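                    # Split the 12-channel output back into four 3-channel
                    # frames and stack them along the batch axis for metrics
                    # and visualization.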
                    outputs = torch.chunk(high_res_fake, 4, 1)
                    outputs = torch.cat(
                        [outputs[0], outputs[1], outputs[2], outputs[3]], 0)
                    # imshow(outputs[0])
                    generator_content_loss = content_criterion(
                        high_res_fake, high_comb)
                    mean_generator_content_loss += generator_content_loss.data.item()

                    generator_total_loss = generator_content_loss
                    mean_generator_total_loss += generator_total_loss.data.item()

                    if phase == 'train':
                        generator_total_loss.backward()
                        optimizer.step()

                    if batch_no % 10 == 0:
                        sys.stdout.write(
                            '\rphase [%s] epoch [%d/%d] batch no. [%d/%d] Generator_content_Loss: %.4f '
                            % (phase, epoch, opt.nEpochs, batch_no,
                               len(dataloader[phase]), generator_content_loss))

            if phase == 'train':
                # imsave(outputs,train=True,epoch=epoch,image_type='fake')
                # imsave(high_img, train=True, epoch=epoch, image_type='real')
                # imsave(input_comb, train=True, epoch=epoch, image_type='low')
                writer.add_scalar(phase + " per epoch/generator lr",
                                  optimizer.param_groups[0]['lr'], epoch + 1)
                scheduler_gen.step(mean_generator_total_loss /
                                   len(dataloader[phase]))

            mssim = avg_msssim(high_res_real, outputs)
            psnr_val = psnr(un_normalize(high_res_real), un_normalize(outputs))

            writer.add_scalar(phase + " per epoch/PSNR", psnr_val, epoch + 1)
            writer.add_scalar(
                phase + " per epoch/generator loss",
                mean_generator_total_loss / len(dataloader[phase]), epoch + 1)
            writer.add_scalar("per epoch/total time taken",
                              time.time() - curr_time, epoch + 1)
            writer.add_scalar(phase + " per epoch/avg_mssim", mssim, epoch + 1)

            torch.save(generator.state_dict(),
                       '%s/generator_firstfinal.pth' % opt.out)

def test_firstmodel(generator, opt, dataloader, writer, scale):
    content_criterion = nn.MSELoss()

    ones_const = Variable(torch.ones(1, 1))

    if opt.cuda:
        generator.cuda()
        content_criterion.cuda()

    curr_time = time.time()

    for epoch in range(opt.nEpochs):
        mean_generator_content_loss = 0.0
        mean_generator_total_loss = 0.0

        high_res_fake = 0

        for batch_no, data in enumerate(dataloader['test']):
            high_img, _ = data
            generator.train(False)

            input1 = high_img[0, :, :, :]
            input2 = high_img[1, :, :, :]
            input3 = high_img[2, :, :, :]
            input4 = high_img[3, :, :, :]
            # imshow(input3)

            for j in range(opt.batchSize):
                high_img[j] = normalize(high_img[j])
            high_comb = torch.cat(
                [high_img[0], high_img[1], high_img[2], high_img[3]], 0)

            high_comb = Variable(high_comb[np.newaxis, :]).cuda()
            # imshow(high_comb.cpu().data)
            input_comb = torch.cat(
                [scale(input1),
                 scale(input2),
                 scale(input3),
                 scale(input4)], 0)
            input_comb = input_comb[np.newaxis, :]

            if opt.cuda:
                high_res_real = Variable(high_img.cuda())
                high_res_fake = generator(Variable(input_comb).cuda())

                outputs = torch.chunk(high_res_fake, 4, 1)
                outputs = torch.cat(
                    [outputs[0], outputs[1], outputs[2], outputs[3]], 0)
                # imshow(outputs[0])
                generator_content_loss = content_criterion(
                    high_res_fake, high_comb)
                mean_generator_content_loss += generator_content_loss.data.item()

                generator_total_loss = generator_content_loss
                mean_generator_total_loss += generator_total_loss.data.item()

                if batch_no % 10 == 0:
                    sys.stdout.write(
                        '\r epoch [%d/%d] batch no. [%d/%d] Generator_content_Loss: %.4f '
                        % (epoch, opt.nEpochs, batch_no,
                           len(dataloader['test']), generator_content_loss))

        mssim = avg_msssim(high_res_real, outputs)
        psnr_val = psnr(un_normalize(high_res_real), un_normalize(outputs))

        writer.add_scalar("test per epoch/PSNR", psnr_val, epoch + 1)
        # writer.add_scalar(phase+" per epoch/discriminator loss", mean_discriminator_loss/len(dataloader[phase]), epoch+1)
        writer.add_scalar("test per epoch/generator loss",
                          mean_generator_total_loss / len(dataloader[phase]),
                          epoch + 1)
        writer.add_scalar("per epoch/total time taken",
                          time.time() - curr_time, epoch + 1)
        writer.add_scalar("test per epoch/avg_mssim", mssim, epoch + 1)

        torch.save(generator.state_dict(),
                   '%s/generator_firstfinal.pth' % opt.out)
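
# Hypothetical wiring sketch (not part of the original file): Generator,
# Discriminator, the argparse namespace `opt`, and the {'train', 'test'}
# dataloader dict all come from the surrounding project.
#
#   generator = Generator()
#   discriminator = Discriminator()
#   writer = SummaryWriter(log_dir=opt.out)
#   scale = transforms.Compose([transforms.ToPILImage(),
#                               transforms.Resize(opt.imageSize // 4),
#                               transforms.ToTensor()])
#   train_multiple(generator, discriminator, opt, dataloader, writer, scale)
#   test_multiple(generator, discriminator, opt, dataloader, scale)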