def test_multiple(generator, discriminator, opt, dataloader, scale):
    """Evaluate a trained multi-frame generator/discriminator pair on the test split.

    Loads the weights stored at ``opt.generatorWeights`` and
    ``opt.discriminatorWeights``, runs every batch of ``dataloader['test']``
    through both networks, saves the fake / real / low-resolution images of
    each batch with ``imsave`` and finally prints the min / mean / max PSNR
    and the mean SSIM and MS-SSIM over the batches that were processed.

    Args:
        generator: network called with frames 0-3 of the low-resolution batch
            as four separate ``(1, ...)`` tensors; its output is compared with
            the normalized high-resolution batch.
        discriminator: network called with four single-frame tensors.
        opt: options; reads ``generatorWeights``, ``discriminatorWeights``,
            ``cuda``, ``batchSize`` and ``imageSize``.
        dataloader: mapping whose ``'test'`` entry yields ``(images, labels)``
            batches. The networks always consume frames 0-3, so presumably
            ``opt.batchSize == 4`` -- TODO confirm.
        scale: transform that turns one raw image into its generator input.
    """
    generator.load_state_dict(torch.load(opt.generatorWeights))
    discriminator.load_state_dict(torch.load(opt.discriminatorWeights))

    feature_extractor = FeatureExtractor(
        torchvision.models.vgg19(pretrained=True))
    content_criterion = nn.MSELoss()
    adversarial_criterion = nn.BCELoss()

    # Per-batch tensors go to this device; without opt.cuda the evaluation
    # now runs on the CPU instead of doing nothing.
    device = torch.device('cuda' if opt.cuda else 'cpu')
    if opt.cuda:
        generator.cuda()
        discriminator.cuda()
        feature_extractor.cuda()
        content_criterion.cuda()
        adversarial_criterion.cuda()

    # Evaluation only: switch both networks to inference mode once.
    generator.train(False)
    discriminator.train(False)

    def frames(batch):
        # Both networks take frames 0-3 as separate (1, ...) tensors.
        return [batch[k:k + 1] for k in range(4)]

    inputs = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)
    n_test = len(dataloader['test'])
    n_batches = 0
    sum_psnr = 0.0
    sum_ssim = 0.0
    sum_msssim = 0.0
    max_psnr = 0.0
    min_psnr = 999.0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for batch_no, (high_img, _) in enumerate(dataloader['test']):
            # Build the generator input from the raw frame, then normalize
            # the frame in place.
            for j in range(opt.batchSize):
                inputs[j] = scale(high_img[j])
                high_img[j] = normalize(high_img[j])

            high_res_real = high_img.to(device)
            high_res_fake = generator(*frames(inputs.to(device)))

            target_real = (torch.rand(opt.batchSize, 1) * 0.5 + 0.7).to(device)
            target_fake = (torch.rand(opt.batchSize, 1) * 0.3).to(device)
            # The "real" examples are the high-resolution frames, not the
            # low-resolution generator inputs.
            discriminator_loss = (
                adversarial_criterion(
                    discriminator(*frames(high_res_real)), target_real)
                + adversarial_criterion(
                    discriminator(*frames(high_res_fake)), target_fake))

            fake_features = feature_extractor(high_res_fake)
            real_features = feature_extractor(high_res_real)
            generator_content_loss = (
                content_criterion(high_res_fake, high_res_real)
                + 0.006 * content_criterion(fake_features, real_features))

            imsave(high_res_fake.cpu(), train=False, epoch=batch_no,
                   image_type='fake')
            imsave(high_img, train=False, epoch=batch_no, image_type='real')
            imsave(inputs, train=False, epoch=batch_no, image_type='low')

            # Metric order is kept as before (MS-SSIM, SSIM, then PSNR on
            # un-normalized copies).
            sum_msssim += float(avg_msssim(high_res_real, high_res_fake))
            ssim = pytorch_ssim.ssim(high_res_fake, high_res_real).item()
            sum_ssim += ssim
            psnr_val = psnr(un_normalize(high_res_fake),
                            un_normalize(high_res_real))
            sum_psnr += psnr_val
            max_psnr = max(max_psnr, psnr_val)
            min_psnr = min(min_psnr, psnr_val)
            n_batches += 1

            sys.stdout.write(
                '\rTesting batch no. [%d/%d] Generator_content_Loss: %.4f discriminator_loss %.4f psnr %.4f ssim %.4f'
                % (batch_no, n_test, generator_content_loss.item(),
                   discriminator_loss.item(), psnr_val, ssim))

    # Average over the batches actually evaluated, not a hard-coded count.
    n = max(n_batches, 1)
    print()
    print("Min psnr is: ", min_psnr)
    print("Mean psnr is: ", sum_psnr / n)
    print("Max psnr is: ", max_psnr)
    print("Mean ssim is: ", sum_ssim / n)
    print("Mean msssim is: ", sum_msssim / n)
def train_multiple(generator, discriminator, opt, dataloader, writer, scale):
    """Pre-train, then adversarially train, the multi-frame generator.

    Runs two epochs of content-loss-only (pixel MSE) pre-training on
    ``dataloader['train']``. It then runs ``opt.nEpochs`` epochs, each made of
    a ``'train'`` phase followed by a ``'test'`` phase.

    Every batch computes:
      * a discriminator loss (real high-res frames vs. detached generated
        frames, with noisy soft labels);
      * a generator loss: pixel MSE + VGG19-feature MSE + 1e-3 x adversarial.

    Only the ``'train'`` phase back-propagates and steps the optimizers.

    After each phase, the following are logged to ``writer``:
      * mean losses;
      * learning rates (train phase only);
      * PSNR and MS-SSIM of the phase's last batch.

    Both state dicts are saved under ``opt.out`` after every epoch.

    Args:
        generator: network called with frames 0-3 of the low-resolution batch
            as four separate ``(1, ...)`` tensors.
        discriminator: network called with four single-frame tensors.
        opt: options; reads ``cuda``, ``batchSize``, ``imageSize``,
            ``generatorLR``, ``discriminatorLR``, ``nEpochs`` and ``out``.
        dataloader: mapping with ``'train'`` and ``'test'`` entries yielding
            ``(images, labels)`` batches (presumably of 4 frames -- TODO confirm).
        writer: object with an ``add_scalar(tag, value, step)`` method.
        scale: transform that turns one raw image into its generator input.
    """
    feature_extractor = FeatureExtractor(
        torchvision.models.vgg19(pretrained=True))
    content_criterion = nn.MSELoss()
    adversarial_criterion = nn.BCELoss()
    ones_const = torch.ones(opt.batchSize, 1)

    # Per-batch tensors go to this device; without opt.cuda training now runs
    # on the CPU instead of failing after the first epoch.
    device = torch.device('cuda' if opt.cuda else 'cpu')
    if opt.cuda:
        generator.cuda()
        discriminator.cuda()
        feature_extractor.cuda()
        content_criterion.cuda()
        adversarial_criterion.cuda()
    ones_const = ones_const.to(device)

    optimizer = optim.Adam(generator.parameters(), lr=opt.generatorLR)
    optim_discriminator = optim.Adam(
        discriminator.parameters(), lr=opt.discriminatorLR)
    scheduler_gen = ReduceLROnPlateau(
        optimizer, 'min', factor=0.7, patience=10, verbose=True)
    scheduler_dis = ReduceLROnPlateau(
        optim_discriminator, 'min', factor=0.7, patience=10, verbose=True)
    curr_time = time.time()

    # Reused buffer holding the low-resolution generator inputs of a batch.
    inputs = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)

    def frames(batch):
        # Both networks take frames 0-3 as separate (1, ...) tensors.
        return [batch[k:k + 1] for k in range(4)]

    def prepare(high_img):
        # Fill `inputs` from the raw frames, then normalize the frames in place.
        for j in range(opt.batchSize):
            inputs[j] = scale(high_img[j])
            high_img[j] = normalize(high_img[j])
        return high_img.to(device), inputs.to(device)

    # Pre-training: two epochs on the pixel-wise content loss only.
    n_pretrain = len(dataloader['train'])
    for epoch in range(2):
        for batch_no, (high_img, _) in enumerate(dataloader['train']):
            high_res_real, low_res = prepare(high_img)
            optimizer.zero_grad()
            high_res_fake = generator(*frames(low_res))
            generator_content_loss = content_criterion(
                high_res_fake, high_res_real)
            generator_content_loss.backward()
            optimizer.step()
            sys.stdout.write('\r[%d/%d][%d/%d] Generator_MSE_Loss: %.4f' % (
                epoch, 2, batch_no, n_pretrain, generator_content_loss.item()))

    # Adversarial training.
    for epoch in range(opt.nEpochs):
        for phase in ['train', 'test']:
            is_train = phase == 'train'
            generator.train(is_train)
            discriminator.train(is_train)
            n_batches = len(dataloader[phase])
            mean_generator_total_loss = 0.0
            mean_discriminator_loss = 0.0

            # Autograd is only needed while training.
            with torch.set_grad_enabled(is_train):
                for batch_no, (high_img, _) in enumerate(dataloader[phase]):
                    high_res_real, low_res = prepare(high_img)
                    if is_train:
                        optimizer.zero_grad()
                    high_res_fake = generator(*frames(low_res))
                    target_real = (torch.rand(opt.batchSize, 1) * 0.5
                                   + 0.7).to(device)
                    target_fake = (torch.rand(opt.batchSize, 1)
                                   * 0.3).to(device)

                    # Discriminator: real high-res frames vs. generated frames.
                    # The fake branch is detached so this loss cannot push
                    # gradients into the generator's parameters (they are not
                    # cleared again before optimizer.step()).
                    if is_train:
                        discriminator.zero_grad()
                    discriminator_loss = (
                        adversarial_criterion(
                            discriminator(*frames(high_res_real)), target_real)
                        + adversarial_criterion(
                            discriminator(*frames(high_res_fake.detach())),
                            target_fake))
                    mean_discriminator_loss += discriminator_loss.item()
                    if is_train:
                        discriminator_loss.backward()
                        optim_discriminator.step()

                    # Generator: pixel + VGG-feature content loss plus a small
                    # adversarial term against the updated discriminator.
                    fake_features = feature_extractor(high_res_fake)
                    real_features = feature_extractor(high_res_real).detach()
                    generator_content_loss = (
                        content_criterion(high_res_fake, high_res_real)
                        + content_criterion(fake_features, real_features))
                    generator_adversarial_loss = adversarial_criterion(
                        discriminator(*frames(high_res_fake)), ones_const)
                    generator_total_loss = (generator_content_loss
                                            + 1e-3 * generator_adversarial_loss)
                    mean_generator_total_loss += generator_total_loss.item()
                    if is_train:
                        generator_total_loss.backward()
                        optimizer.step()

                    if batch_no % 10 == 0:
                        sys.stdout.write(
                            '\rphase [%s] epoch [%d/%d] batch no. [%d/%d] Generator_content_Loss: %.4f discriminator_loss %.4f' % (
                                phase, epoch, opt.nEpochs, batch_no, n_batches,
                                generator_content_loss.item(),
                                discriminator_loss.item()))

            imsave(high_res_fake.detach().cpu(), train=is_train, epoch=epoch,
                   image_type='fake')
            imsave(high_img, train=is_train, epoch=epoch, image_type='real')
            imsave(inputs, train=is_train, epoch=epoch, image_type='low')
            if is_train:
                writer.add_scalar(phase + " per epoch/generator lr",
                                  optimizer.param_groups[0]['lr'], epoch + 1)
                writer.add_scalar(phase + " per epoch/discriminator lr",
                                  optim_discriminator.param_groups[0]['lr'],
                                  epoch + 1)
                scheduler_gen.step(mean_generator_total_loss / n_batches)
                scheduler_dis.step(mean_discriminator_loss / n_batches)

            # Image metrics of the phase's last batch.
            with torch.no_grad():
                mssim = avg_msssim(high_res_real, high_res_fake)
                psnr_val = psnr(un_normalize(high_res_real),
                                un_normalize(high_res_fake))
            writer.add_scalar(phase + " per epoch/PSNR", psnr_val, epoch + 1)
            writer.add_scalar(phase + " per epoch/discriminator loss",
                              mean_discriminator_loss / n_batches, epoch + 1)
            writer.add_scalar(phase + " per epoch/generator loss",
                              mean_generator_total_loss / n_batches, epoch + 1)
            writer.add_scalar("per epoch/total time taken",
                              time.time() - curr_time, epoch + 1)
            writer.add_scalar(phase + " per epoch/avg_mssim", mssim, epoch + 1)

        # Checkpoint both networks after every epoch.
        torch.save(generator.state_dict(), '%s/generator_final.pth' % opt.out)
        torch.save(discriminator.state_dict(),
                   '%s/discriminator_final.pth' % opt.out)
def train_firstmodel(generator, opt, dataloader, writer, scale):
    """Train the generator that sees four frames stacked along the channel axis.

    For every batch, frames 0-3 are scaled and concatenated along dim 0 into a
    single ``(1, 4*C, h, w)`` input. The normalized frames are concatenated
    the same way to form the target. The generator is trained with MSE and SGD.

    Each epoch runs a ``'train'`` and a ``'test'`` phase. Per phase, the
    following are logged to ``writer``:
      * the mean loss;
      * the learning rate and LR-scheduler step (train phase only);
      * PSNR and MS-SSIM of the phase's last batch.

    The generator weights are saved under ``opt.out`` after every epoch.

    Args:
        generator: network mapping the stacked input to a stacked output with
            as many channels as the stacked target.
        opt: options; reads ``cuda``, ``generatorLR``, ``nEpochs``,
            ``batchSize`` and ``out``.
        dataloader: mapping with ``'train'`` and ``'test'`` entries yielding
            ``(images, labels)`` batches (presumably of 4 frames -- TODO confirm).
        writer: object with an ``add_scalar(tag, value, step)`` method.
        scale: transform that turns one raw image into its generator input.
    """
    content_criterion = nn.MSELoss()
    device = torch.device('cuda' if opt.cuda else 'cpu')
    if opt.cuda:
        generator.cuda()
        content_criterion.cuda()
    optimizer = optim.SGD(generator.parameters(), lr=opt.generatorLR)
    scheduler_gen = ReduceLROnPlateau(
        optimizer, 'min', factor=0.5, patience=3, verbose=True)
    curr_time = time.time()

    for epoch in range(opt.nEpochs):
        for phase in ['train', 'test']:
            is_train = phase == 'train'
            generator.train(is_train)
            n_batches = len(dataloader[phase])
            # Reset per phase so the test statistics do not include the
            # training losses.
            mean_generator_total_loss = 0.0

            with torch.set_grad_enabled(is_train):
                for batch_no, (high_img, _) in enumerate(dataloader[phase]):
                    # Scale the RAW frames first: normalizing high_img in place
                    # below would otherwise also change views taken from it.
                    input_comb = torch.cat(
                        [scale(high_img[k]) for k in range(4)], 0).unsqueeze(0)
                    for j in range(opt.batchSize):
                        high_img[j] = normalize(high_img[j])
                    high_comb = torch.cat(
                        [high_img[k] for k in range(4)], 0).unsqueeze(0).to(device)
                    high_res_real = high_img.to(device)

                    if is_train:
                        optimizer.zero_grad()
                    high_res_fake = generator(input_comb.to(device))
                    # Split the stacked output back into four frames (dim 0).
                    outputs = torch.cat(torch.chunk(high_res_fake, 4, 1), 0)

                    generator_content_loss = content_criterion(
                        high_res_fake, high_comb)
                    mean_generator_total_loss += generator_content_loss.item()
                    if is_train:
                        generator_content_loss.backward()
                        optimizer.step()

                    if batch_no % 10 == 0:
                        sys.stdout.write(
                            '\rphase [%s] epoch [%d/%d] batch no. [%d/%d] Generator_content_Loss: %.4f ' % (
                                phase, epoch, opt.nEpochs, batch_no, n_batches,
                                generator_content_loss.item()))

            if is_train:
                writer.add_scalar(phase + " per epoch/generator lr",
                                  optimizer.param_groups[0]['lr'], epoch + 1)
                scheduler_gen.step(mean_generator_total_loss / n_batches)

            # Image metrics of the phase's last batch.
            with torch.no_grad():
                mssim = avg_msssim(high_res_real, outputs)
                psnr_val = psnr(un_normalize(high_res_real),
                                un_normalize(outputs))
            writer.add_scalar(phase + " per epoch/PSNR", psnr_val, epoch + 1)
            writer.add_scalar(phase + " per epoch/generator loss",
                              mean_generator_total_loss / n_batches, epoch + 1)
            writer.add_scalar("per epoch/total time taken",
                              time.time() - curr_time, epoch + 1)
            writer.add_scalar(phase + " per epoch/avg_mssim", mssim, epoch + 1)

        torch.save(generator.state_dict(),
                   '%s/generator_firstfinal.pth' % opt.out)
def test_firstmodel(generator, opt, dataloader, writer, scale):
    """Evaluate the channel-stacked four-frame generator on the test split.

    For ``opt.nEpochs`` passes over ``dataloader['test']``, frames 0-3 of
    each batch are scaled and concatenated along dim 0 into one
    ``(1, 4*C, h, w)`` input. The MSE against the stacked normalized frames is
    accumulated.

    After each pass, the following are logged to ``writer``:
      * the mean loss;
      * PSNR and MS-SSIM of the last batch.

    The (unchanged) generator weights are then written to ``opt.out``.

    Args:
        generator: network mapping the stacked input to a stacked output.
        opt: options; reads ``cuda``, ``nEpochs``, ``batchSize`` and ``out``.
        dataloader: mapping whose ``'test'`` entry yields ``(images, labels)``
            batches (presumably of 4 frames -- TODO confirm).
        writer: object with an ``add_scalar(tag, value, step)`` method.
        scale: transform that turns one raw image into its generator input.
    """
    content_criterion = nn.MSELoss()
    device = torch.device('cuda' if opt.cuda else 'cpu')
    if opt.cuda:
        generator.cuda()
        content_criterion.cuda()
    # Evaluation only.
    generator.train(False)
    curr_time = time.time()
    n_batches = len(dataloader['test'])

    for epoch in range(opt.nEpochs):
        mean_generator_total_loss = 0.0
        with torch.no_grad():
            for batch_no, (high_img, _) in enumerate(dataloader['test']):
                # Scale the RAW frames first: normalizing high_img in place
                # below would otherwise also change views taken from it.
                input_comb = torch.cat(
                    [scale(high_img[k]) for k in range(4)], 0).unsqueeze(0)
                for j in range(opt.batchSize):
                    high_img[j] = normalize(high_img[j])
                high_comb = torch.cat(
                    [high_img[k] for k in range(4)], 0).unsqueeze(0).to(device)
                high_res_real = high_img.to(device)

                high_res_fake = generator(input_comb.to(device))
                # Split the stacked output back into four frames (dim 0).
                outputs = torch.cat(torch.chunk(high_res_fake, 4, 1), 0)

                generator_content_loss = content_criterion(
                    high_res_fake, high_comb)
                mean_generator_total_loss += generator_content_loss.item()

                if batch_no % 10 == 0:
                    sys.stdout.write(
                        '\r epoch [%d/%d] batch no. [%d/%d] Generator_content_Loss: %.4f ' % (
                            epoch, opt.nEpochs, batch_no, n_batches,
                            generator_content_loss.item()))

            # Image metrics of the last batch.
            mssim = avg_msssim(high_res_real, outputs)
            psnr_val = psnr(un_normalize(high_res_real), un_normalize(outputs))

        writer.add_scalar("test per epoch/PSNR", psnr_val, epoch + 1)
        writer.add_scalar("test per epoch/generator loss",
                          mean_generator_total_loss / n_batches, epoch + 1)
        writer.add_scalar("per epoch/total time taken",
                          time.time() - curr_time, epoch + 1)
        writer.add_scalar("test per epoch/avg_mssim", mssim, epoch + 1)
        torch.save(generator.state_dict(),
                   '%s/generator_firstfinal.pth' % opt.out)