def test_metrics(model, video_path=None, frames=None, output_folder=None):
    """Evaluate a frame-interpolation model on a video or a frame list.

    Samples (x1, gt, x2) triplets of consecutive frames, one triplet every
    `stride` frames, predicts the middle frame from the outer two, and
    prints the average SSIM / PSNR of the predictions against ground truth.

    :param model: interpolation model, forwarded to `interpolate`.
    :param video_path: optional video path; frames are extracted from it
        when `frames` is not given.
    :param frames: optional pre-extracted frame list (PIL images, presumably —
        they are fed to `pil_to_tensor`).
    :param output_folder: when given, each predicted frame is saved there
        as `wiz_<i>.jpg`.
    :raises ValueError: if fewer than 3 frames are available — the original
        code divided by zero (or indexed out of range) in that case.
    """
    if video_path is not None and frames is None:
        frames, _ = extract_frames(video_path)
    if frames is None or len(frames) < 3:
        raise ValueError('need at least 3 frames to form an (x1, gt, x2) triplet')

    stride = 30
    # One triplet per `stride` frames; each start index must leave room
    # for the two following frames.
    triplets = [
        (frames[start], frames[start + 1], frames[start + 2])
        for start in range(0, len(frames) - 2, stride)
    ]
    iters = len(triplets)

    total_ssim = 0.0
    total_psnr = 0.0
    for i, (x1, gt, x2) in enumerate(triplets):
        pred = interpolate(model, x1, x2)
        if output_folder is not None:
            frame_path = join(output_folder, f'wiz_{i}.jpg')
            pred.save(frame_path)
        gt = pil_to_tensor(gt)
        pred = pil_to_tensor(pred)
        total_ssim += ssim(pred, gt).item()
        total_psnr += psnr(pred, gt).item()
        print(f'#{i+1}/{iters} done')

    avg_ssim = total_ssim / iters
    avg_psnr = total_psnr / iters
    print(f'avg_ssim: {avg_ssim}, avg_psnr: {avg_psnr}')
def run_parallax_view_generation0(torchModel, t, inputDir, outputDir, netmode, numImages=-1, save_images=True):
    """Generate a parallax view sequence and score it against ground truth.

    Seeds `generate_parallax_view` with every `t`-th ground-truth image and
    measures PSNR of each interpolated frame (any index not on the camera
    interval) against the corresponding ground-truth image. Seed frames are
    recorded with a placeholder PSNR of 0.

    :param torchModel: model forwarded to `generate_parallax_view`.
    :param t: camera interval between seed frames.
    :param inputDir: directory of ground-truth images (via `load_images`).
    :param outputDir: output directory for images and `psnr.json`.
    :param netmode: mode flag forwarded to `generate_parallax_view`.
    :param numImages: number of ground-truth frames to use; -1 means all.
    :param save_images: when True (and outputDir is set), generated frames
        and the per-frame PSNR JSON are written to disk.
    :returns: the worst (lowest) PSNR among interpolated frames, or
        `float('inf')` when nothing was interpolated (was a magic
        999999999 sentinel before).
    """
    cam_interval = t
    num_frames = numImages  # the old code confusingly reassigned `t` itself
    save_enabled = save_images and outputDir is not None  # `!= None` fixed
    if save_enabled:
        im_output = os.path.join(outputDir, "images")
        json_output = os.path.join(outputDir, "psnr.json")
        makedirs(im_output, exist_ok=True)

    images = load_images(inputDir)
    if num_frames == -1:
        num_frames = len(images)

    input_images = [images[w] for w in range(0, num_frames, cam_interval)]
    parallax_view = generate_parallax_view(torchModel, num_frames, cam_interval, input_images, netmode)

    worst_psnr = float('inf')
    results_dict = {}
    for index, view in enumerate(parallax_view):
        p = 0
        if index % cam_interval != 0:
            # Only interpolated frames are scored; seed frames are inputs.
            p = psnr(pil_to_tensor(view), pil_to_tensor(images[index])).item()
            worst_psnr = min(worst_psnr, p)
        results_dict[index] = p
        if save_enabled:
            view.save(join_paths(im_output, '{}.jpg'.format(index + 1)), 'JPEG', quality=95)

    if save_enabled:
        # Write the JSON once at the end instead of rewriting it every frame.
        writeJson(json_output, results_dict)
    return worst_psnr
def validate(epoch):
    """Run one pass over the validation set and log epoch-level metrics.

    Averages the training loss, SSIM, and PSNR over all validation batches
    and writes them to the TensorBoard writer under the given epoch.
    Relies on module-level globals: `validation_data_loader`, `model`,
    `device`, `loss_function`, `board_writer`, `loss`, `psnr`.
    """
    print("===> Running validation...")
    ssim_metric = loss.SsimLoss()
    n_batches = len(validation_data_loader)
    loss_sum = 0.0
    ssim_sum = 0.0
    psnr_sum = 0.0
    with torch.no_grad():
        for batch in validation_data_loader:
            inputs = batch[0].to(device)
            targets = batch[1].to(device)
            predictions = model(inputs)
            loss_sum += loss_function(predictions, targets).item()
            # NOTE(review): SsimLoss presumably returns a negated similarity,
            # hence the subtraction accumulates a positive SSIM — confirm.
            ssim_sum -= ssim_metric(predictions, targets).item()
            psnr_sum += psnr(predictions, targets).item()
    avg_loss = loss_sum / n_batches
    avg_ssim = ssim_sum / n_batches
    avg_psnr = psnr_sum / n_batches
    board_writer.add_scalar('data/epoch_validation_loss', avg_loss, epoch)
    board_writer.add_scalar('data/epoch_ssmi', avg_ssim, epoch)
    board_writer.add_scalar('data/epoch_psnr', avg_psnr, epoch)
    print("===> Validation loss: {:.4f}".format(avg_loss))
def test_linear_interp(validation_set=None):
    """Baseline: score naive linear interpolation (pixel-wise mean of the
    two outer frames) on the validation set, printing average SSIM / PSNR.

    :param validation_set: object with a `.tuples` list of 3-path tuples
        (x1, gt, x2); fetched via `get_validation_set()` when omitted.
    """
    if validation_set is None:
        validation_set = get_validation_set()
    iters = len(validation_set.tuples)
    if iters == 0:
        # Guard: the original divided by zero on an empty validation set.
        print('validation set is empty, nothing to evaluate')
        return

    total_ssim = 0.0
    total_psnr = 0.0
    crop = CenterCrop(config.CROP_SIZE)
    for tup in validation_set.tuples:
        x1, gt, x2 = [pil_to_tensor(crop(load_img(p))) for p in tup]
        # Linear interpolation = mean of the two outer frames.
        pred = torch.mean(torch.stack((x1, x2), dim=0), dim=0)
        total_ssim += ssim(pred, gt).item()
        total_psnr += psnr(pred, gt).item()

    avg_ssim = total_ssim / iters
    avg_psnr = total_psnr / iters
    print(f'avg_ssim: {avg_ssim}, avg_psnr: {avg_psnr}')
def test_on_validation_set(model, validation_set=None):
    """Evaluate `model` frame interpolation on the validation set, printing
    average SSIM / PSNR of predicted middle frames against ground truth.

    :param model: interpolation model, forwarded to `interpolate`.
    :param validation_set: object with a `.tuples` list of 3-path tuples
        (x1, gt, x2); fetched via `get_validation_set()` when omitted.
    """
    if validation_set is None:
        validation_set = get_validation_set()
    iters = len(validation_set.tuples)
    if iters == 0:
        # Guard: the original divided by zero on an empty validation set.
        print('validation set is empty, nothing to evaluate')
        return

    total_ssim = 0.0
    total_psnr = 0.0
    crop = CenterCrop(config.CROP_SIZE)
    for i, tup in enumerate(validation_set.tuples):
        x1, gt, x2 = [crop(load_img(p)) for p in tup]
        pred = interpolate(model, x1, x2)
        gt = pil_to_tensor(gt)
        pred = pil_to_tensor(pred)
        total_ssim += ssim(pred, gt).item()
        total_psnr += psnr(pred, gt).item()
        print(f'#{i+1} done')

    avg_ssim = total_ssim / iters
    avg_psnr = total_psnr / iters
    print(f'avg_ssim: {avg_ssim}, avg_psnr: {avg_psnr}')