def simpleTest(device, dataloader, generator, MSE_Loss, step, alpha, resultpath='result.jpg'):
    """Run ``generator`` over ``dataloader``, report PSNR and save the prediction.

    Every batch is a 4-tuple ``(x2_target_image, x4_target_image,
    target_image, input_image)``. The target used for the PSNR is picked by
    ``step``: 1 selects the x2 target, 2 the x4 target, anything else the
    full target. Prediction and target are both rescaled with
    ``0.5 * x + 0.5`` before the loss is taken and before saving.
    NOTE(review): presumably the tensors are normalised to [-1, 1] so the
    rescale yields [0, 1] (PSNR uses a peak of 1) -- confirm against the
    dataset transform.

    Args:
        device: Device the batch tensors are moved to.
        dataloader: Sized iterable yielding the 4-tuples described above.
        generator: Callable invoked as ``generator(input_image, step, alpha)``.
        MSE_Loss: Callable returning a scalar loss exposing ``.item()``.
        step: Selects the target resolution; also forwarded to the generator.
        alpha: Forwarded unchanged to the generator.
        resultpath: Output image path. It is rewritten for every batch, so
            after the loop it holds the prediction of the last batch.
    """
    import torch  # local import: only needed for the no_grad context below

    num_batches = len(dataloader)
    # Evaluation only: skip building the autograd graph for each forward pass.
    with torch.no_grad():
        for i, (x2_target_image, x4_target_image, target_image, input_image) in enumerate(dataloader):
            if step == 1:
                target_image = x2_target_image.to(device)
            elif step == 2:
                target_image = x4_target_image.to(device)
            else:
                target_image = target_image.to(device)
            input_image = input_image.to(device)

            predicted_image = generator(input_image, step, alpha)

            # Rescale once and reuse for both the loss and the saved image.
            predicted_01 = 0.5 * predicted_image + 0.5
            target_01 = 0.5 * target_image + 0.5

            mse = MSE_Loss(predicted_01, target_01).item()
            # A perfect prediction has zero MSE; report infinite PSNR instead
            # of dividing by zero.
            psnr = float('inf') if mse == 0 else 10 * log10(1. / mse)

            sys.stdout.write('\r [%d/%d] Test progress... PSNR: %6.4f' % (i, num_batches, psnr))
            utils.save_image(predicted_01, resultpath)
    print('Image generated!')
def test(dataloader, generator, MSE_Loss, step, alpha):
    """Evaluate ``generator`` on ``dataloader`` and print average PSNR/SSIM/MS-SSIM.

    Every batch is a 4-tuple ``(x2_target_image, x4_target_image,
    target_image, input_image)``. The reference image is picked by ``step``:
    1 selects the x2 target, 2 the x4 target, anything else the full target.
    Prediction and target are rescaled with ``0.5 * x + 0.5`` before each
    metric and before the prediction is written to
    ``args.result_path/<batch index>_results.jpg``.

    Reads the module-level names ``device`` and ``args`` (they are not
    parameters of this function).

    Args:
        dataloader: Sized iterable yielding the 4-tuples described above.
        generator: Callable invoked as ``generator(input_image, step, alpha)``.
        MSE_Loss: Callable returning a scalar loss exposing ``.item()``.
        step: Selects the target resolution; also forwarded to the generator.
        alpha: Forwarded unchanged to the generator.
    """
    import torch  # local import: only needed for the no_grad context below

    num_batches = len(dataloader)
    if num_batches == 0:
        # Nothing to average over; avoid dividing by zero in the summary.
        print('Test done, no batches to evaluate.')
        return

    avg_psnr = 0
    avg_ssim = 0
    avg_msssim = 0
    # Evaluation only: skip building the autograd graph for each forward pass.
    with torch.no_grad():
        for i, (x2_target_image, x4_target_image, target_image, input_image) in enumerate(dataloader):
            if step == 1:
                target_image = x2_target_image.to(device)
            elif step == 2:
                target_image = x4_target_image.to(device)
            else:
                target_image = target_image.to(device)
            input_image = input_image.to(device)

            predicted_image = generator(input_image, step, alpha)

            # Rescale once and reuse for every metric and for the saved image.
            predicted_01 = 0.5 * predicted_image + 0.5
            target_01 = 0.5 * target_image + 0.5

            mse = MSE_Loss(predicted_01, target_01).item()
            # A perfect prediction has zero MSE; report infinite PSNR instead
            # of dividing by zero.
            psnr = float('inf') if mse == 0 else 10 * log10(1. / mse)
            avg_psnr += psnr
            avg_ssim += ssim(predicted_01, target_01).item()
            avg_msssim += msssim(predicted_01, target_01).item()

            sys.stdout.write('\r [%d/%d] Test progress... PSNR: %6.4f' % (i, num_batches, psnr))
            utils.save_image(predicted_01, os.path.join(args.result_path, '%d_results.jpg' % i))

    print('Test done, Average PSNR:%6.4f, Average SSIM:%6.4f, Average MS-SSIM:%6.4f ' % (
        avg_psnr / num_batches, avg_ssim / num_batches, avg_msssim / num_batches))
def forward(self, prediction, target):
    """Return the weighted sum of three dissimilarity terms.

    ``prediction`` and ``target`` are each passed through ``self.encoder``.
    The encoded outputs feed an MSE term (weight ``self.alpha``) and a
    ``1 - mean cosine similarity along dim 1`` term (weight ``self.gamma``).
    The un-encoded inputs feed a ``1 - msssim(..., normalize=True)`` term
    (weight ``self.beta``).
    """
    encoded_prediction = self.encoder(prediction)
    encoded_target = self.encoder(target)

    # Terms computed on the encoder outputs.
    feature_mse = F.mse_loss(encoded_prediction, encoded_target)
    cosine = F.cosine_similarity(encoded_prediction, encoded_target, dim=1)
    cosine_distance = 1.0 - cosine.mean()

    # Term computed on the raw inputs.
    structural_distance = 1.0 - msssim(prediction, target, normalize=True)

    return (self.alpha * feature_mse
            + self.gamma * cosine_distance
            + self.beta * structural_distance)
def main():
    """Print per-font and average SSIM / MS-SSIM between reference and predicted glyphs.

    Command-line arguments:
        --meta_path: JSON file whose ``target_chars`` entry lists the
            characters to compare.
        --img_dir: Directory with one sub-directory per font holding the
            reference images ``uni<HEX>.png``.
        --pred_dir: Directory with the same per-font layout holding the
            predictions ``inferred_<HEX>.png``.

    ``<HEX>`` is the upper-case hexadecimal code point of the character.
    For each font all glyphs are stacked into one batch and scored with
    ``ssim`` and ``msssim``; the averages over fonts are printed last.
    """
    parser = ArgumentParser('Calculate SSIM, MSSSIM')
    # All three paths are needed; fail early with a usage message instead of
    # a TypeError deeper in the script.
    parser.add_argument('--meta_path', type=str, required=True)
    parser.add_argument('--img_dir', type=str, required=True)
    parser.add_argument('--pred_dir', type=str, required=True)
    args = parser.parse_args()
    img_dir = args.img_dir
    pred_dir = args.pred_dir

    with open(args.meta_path, encoding="utf-8") as meta_file:
        meta = json.load(meta_file)
    target_uni = [format(ord(char), 'X') for char in meta['target_chars']]

    # Fall back to the CPU so the script also runs on hosts without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    to_tensor = transforms.ToTensor()

    def load_tensor(path):
        # Convert inside the context manager so the image file is closed again.
        with Image.open(path) as image:
            return to_tensor(image)

    ssims = []
    msssims = []
    for font in os.listdir(img_dir):
        # Only per-font sub-directories hold glyphs; skip stray files.
        if not os.path.isdir(os.path.join(img_dir, font)):
            continue

        imgs = []
        preds = []
        for code in target_uni:
            imgs.append(load_tensor(os.path.join(img_dir, font, 'uni' + code + '.png')))
            preds.append(load_tensor(os.path.join(pred_dir, font, 'inferred_' + code + '.png')))

        img_tensor = torch.stack(imgs).to(device)
        pred_tensor = torch.stack(preds).to(device)
        font_ssim = ssim(img_tensor, pred_tensor).item()
        font_msssim = msssim(img_tensor, pred_tensor).item()
        ssims.append(font_ssim)
        msssims.append(font_msssim)
        print(font, "SSIM:", font_ssim, "MSSSIM", font_msssim)

    if not ssims:
        # No font was scored; avoid dividing by zero in the averages.
        print("No font directories found in", img_dir)
        return
    print("AVERAGE SSIM:", sum(ssims) / len(ssims), "MSSSIM:", sum(msssims) / len(msssims))
def test(dataloader, generator, MSE_Loss, step, alpha):
    """Evaluate ``generator`` in double precision and print average PSNR/SSIM/MS-SSIM.

    Every batch is a 4-tuple ``(x2_target_image, x4_target_image,
    target_image, input_image)``. The reference image is picked by ``step``:
    1 selects the x2 target, 2 the x4 target, anything else the full target.
    Prediction and target are cast with ``.double()`` and rescaled with
    ``0.5 * x + 0.5`` before each metric. When ``args.local_rank == 0`` the
    prediction and target are concatenated along dim 0 and written to
    ``args.result_path/<batch index>_results.jpg``.

    Reads the module-level names ``device`` and ``args`` (they are not
    parameters of this function).

    Args:
        dataloader: Sized iterable yielding the 4-tuples described above.
        generator: Callable invoked as ``generator(input_image, step, alpha)``.
        MSE_Loss: Callable returning a scalar loss exposing ``.item()``.
        step: Selects the target resolution; also forwarded to the generator.
        alpha: Forwarded unchanged to the generator.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        # Nothing to average over; avoid dividing by zero in the summary.
        print('Test done, no batches to evaluate.')
        return

    avg_psnr = 0
    avg_ssim = 0
    avg_msssim = 0
    # Evaluation only: skip building the autograd graph for each forward pass.
    with torch.no_grad():
        # Reference kept from the original code: https://stackoverflow.com/a/24658101
        for i, (x2_target_image, x4_target_image, target_image, input_image) in enumerate(dataloader):
            if step == 1:
                target_image = x2_target_image.to(device)
            elif step == 2:
                target_image = x4_target_image.to(device)
            else:
                target_image = target_image.to(device)
            input_image = input_image.to(device)

            # Predict, then compare against the reference in double precision.
            predicted_image = generator(input_image, step, alpha).double()
            target_image = target_image.double()

            # Rescale once and reuse for every metric and for the saved image.
            predicted_01 = 0.5 * predicted_image + 0.5
            target_01 = 0.5 * target_image + 0.5

            mse = MSE_Loss(predicted_01, target_01).item()
            # A perfect prediction has zero MSE; report infinite PSNR instead
            # of dividing by zero.
            psnr = float('inf') if mse == 0 else 10 * log10(1. / mse)
            avg_psnr += psnr
            avg_ssim += ssim(predicted_01, target_01).item()
            avg_msssim += msssim(predicted_01, target_01).item()

            sys.stdout.write('\r [%d/%d] Test progress... PSNR: %6.4f' % (i, num_batches, psnr))
            if args.local_rank == 0:
                # Only rank 0 writes files, so only rank 0 builds the preview.
                comparison = torch.cat([predicted_01, target_01], dim=0)
                utils.save_image(comparison, os.path.join(args.result_path, '%d_results.jpg' % i))

    print('Test done, Average PSNR:%6.4f, Average SSIM:%6.4f, Average MS-SSIM:%6.4f ' % (
        avg_psnr / num_batches, avg_ssim / num_batches, avg_msssim / num_batches))
def forward(self, prediction, target):
    """Return the MSE plus ``self.alpha`` times ``1 - msssim``.

    Both terms are computed directly on ``prediction`` and ``target``;
    ``self.normalize`` is forwarded to ``msssim`` as its ``normalize``
    argument.
    """
    pixel_mse = F.mse_loss(prediction, target)
    structural_distance = 1.0 - msssim(prediction, target, normalize=self.normalize)
    return pixel_mse + self.alpha * structural_distance