def visual_test(epoch):
    """Log one interpolated frame per entry of ``visual_test_set``.

    For every entry of the module-level ``visual_test_set``, items 0 and 2
    are loaded with ``load_img``. They are interpolated with the module-level
    ``model``. The prediction is converted with ``pil_to_tensor`` and passed
    to ``board_writer.add_image`` under the tag
    ``data/visual_test_<index>``, with ``epoch`` as the step argument.
    Item 1 of each entry is not used here.
    """
    print("===> Running visual test...")
    for index, sample in enumerate(visual_test_set):
        first_frame = load_img(sample[0])
        last_frame = load_img(sample[2])
        predicted = pil_to_tensor(interpolate(model, first_frame, last_frame))
        board_writer.add_image(f'data/visual_test_{index}', predicted, epoch)
def test_metrics(model, video_path=None, frames=None, output_folder=None, stride=30):
    """Print (and return) average SSIM/PSNR of frame interpolation on a video.

    Frames are sampled as consecutive triples
    ``(frames[s], frames[s + 1], frames[s + 2])`` for start indices
    ``s = 0, stride, 2 * stride, ...``. For each triple, the outer two frames
    are interpolated with ``model``, and the prediction is scored against the
    middle frame.

    Args:
        model: model passed through to ``interpolate``.
        video_path: video to read frames from via ``extract_frames``. It is
            only used when ``frames`` is None.
        frames: pre-extracted frames. ``pil_to_tensor`` is applied to them,
            so they are presumably PIL images (TODO confirm).
        output_folder: if given, each prediction is saved there as
            ``wiz_<index>.jpg``.
        stride: distance between the start indices of consecutive triples.
            Defaults to 30, the value that was previously hard-coded.

    Returns:
        A tuple ``(avg_ssim, avg_psnr)``. Both values are also printed.

    Raises:
        TypeError: if neither ``video_path`` nor ``frames`` is given.
        ValueError: if ``stride`` is not positive, or if fewer than three
            frames are available (this used to surface as a
            ZeroDivisionError).
    """
    if frames is None:
        if video_path is None:
            raise TypeError('test_metrics() needs either video_path or frames')
        frames, _ = extract_frames(video_path)
    if stride < 1:
        raise ValueError(f'stride must be a positive integer, got {stride}')
    if len(frames) < 3:
        raise ValueError(
            f'need at least 3 frames to form a triplet, got {len(frames)}')

    # A start index s is usable while s + 2 is still a valid index, i.e.
    # s <= len(frames) - 3. This gives the same starts as the previous
    # 1 + (len(frames) - 3) // stride iterations.
    triplets = [
        (frames[start], frames[start + 1], frames[start + 2])
        for start in range(0, len(frames) - 2, stride)
    ]
    iters = len(triplets)

    total_ssim = 0
    total_psnr = 0
    for i, (x1, gt, x2) in enumerate(triplets):
        pred = interpolate(model, x1, x2)
        if output_folder is not None:
            frame_path = join(output_folder, f'wiz_{i}.jpg')
            pred.save(frame_path)
        gt = pil_to_tensor(gt)
        pred = pil_to_tensor(pred)
        total_ssim += ssim(pred, gt).item()
        total_psnr += psnr(pred, gt).item()
        print(f'#{i+1}/{iters} done')

    avg_ssim = total_ssim / iters
    avg_psnr = total_psnr / iters
    print(f'avg_ssim: {avg_ssim}, avg_psnr: {avg_psnr}')
    return avg_ssim, avg_psnr
def generate_parallax_view(torchModel, t, cam_interval, cam_views, netmode):
    """Expand sparse camera views into a denser sequence of views.

    cam_views is expected to be a list of PIL images. The returned list
    interleaves those views with interpolated frames.

    netmode == "2to1":
        A list of ``t`` slots is built. Slot ``k`` holds
        ``cam_views[k // cam_interval]`` when ``k`` is a multiple of
        ``cam_interval``, and ``None`` otherwise.

        The gaps are then filled coarse-to-fine. Each pass writes, via
        ``interpolate``, the slot halfway between two slots ``step`` apart.
        ``step`` is then halved (integer division) until it is no longer
        greater than 1.

        Slots that no pass writes stay ``None``. This happens, for example,
        when ``cam_interval`` is not a power of two.
        NOTE(review): confirm that callers only use powers of two.

    any other netmode:
        Views are consumed as overlapping triples ``(iv, iv + 1, iv + 2)``
        for even ``iv``. Each triple contributes, in order:

        - its first view;
        - the frames returned by
          ``interpolate3toN(torchModel, a, b, c, 2 * cam_interval + 1)``,
          with the middle view inserted right after the
          ``cam_interval - 1``-th of those frames.

        The last element of ``cam_views`` is appended at the very end.

        NOTE(review): with an even number of views, ``cam_views[-2]`` is
        never emitted. With ``cam_interval == 1``, the middle view of each
        triple is never emitted. Confirm both are intended.
    """
    if netmode == "2to1":
        frames = [
            cam_views[k // cam_interval] if k % cam_interval == 0 else None
            for k in range(t)
        ]
        step = cam_interval
        while step > 1:
            half = step // 2
            # Later passes read midpoints written by earlier passes,
            # so the pass order matters.
            for left in range(0, t - step, step):
                frames[left + half] = interpolate(
                    torchModel, frames[left], frames[left + step])
            step = half
        return frames

    seq_len = cam_interval * 2 + 1
    sequence = []
    for start in range(0, len(cam_views) - 2, 2):
        left = cam_views[start]
        middle = cam_views[start + 1]
        right = cam_views[start + 2]
        sequence.append(left)
        between = interpolate3toN(torchModel, left, middle, right, seq_len)
        for pos, frame in enumerate(between):
            sequence.append(frame)
            if pos == cam_interval - 2:
                sequence.append(middle)
    sequence.append(cam_views[-1])
    return sequence
def test_on_validation_set(model, validation_set=None):
    """Print (and return) average SSIM/PSNR of interpolation on a validation set.

    Each entry of ``validation_set.tuples`` must unpack into exactly three
    items, each passed to ``load_img``. All three images are center-cropped
    to ``config.CROP_SIZE``. The first and third are interpolated with
    ``model``, and the prediction is scored against the middle one.

    Args:
        model: model passed through to ``interpolate``.
        validation_set: object exposing a sized ``tuples`` sequence.
            Defaults to ``get_validation_set()``.

    Returns:
        A tuple ``(avg_ssim, avg_psnr)``. Both values are also printed.

    Raises:
        ValueError: if the validation set has no tuples (this used to
            surface as a ZeroDivisionError after the loop).
    """
    if validation_set is None:
        validation_set = get_validation_set()
    tuples = validation_set.tuples
    iters = len(tuples)
    # Fail fast, before building the crop or loading any image.
    if iters == 0:
        raise ValueError('validation set is empty; cannot average SSIM/PSNR')

    crop = CenterCrop(config.CROP_SIZE)
    total_ssim = 0
    total_psnr = 0
    for i, tup in enumerate(tuples):
        x1, gt, x2 = [crop(load_img(p)) for p in tup]
        pred = interpolate(model, x1, x2)
        gt = pil_to_tensor(gt)
        pred = pil_to_tensor(pred)
        total_ssim += ssim(pred, gt).item()
        total_psnr += psnr(pred, gt).item()
        print(f'#{i+1} done')

    avg_ssim = total_ssim / iters
    avg_psnr = total_psnr / iters
    print(f'avg_ssim: {avg_ssim}, avg_psnr: {avg_psnr}')
    return avg_ssim, avg_psnr
def interpolate(self, *args):
    """Forward to ``interpol.interpolate``, passing ``self`` as the first argument.

    Despite its name, ``self`` is not a bound instance. This is a
    module-level function, and callers in this file pass the model
    positionally, e.g. ``interpolate(model, x1, x2)``. Any further
    positional arguments are forwarded unchanged, and the delegate's return
    value is returned as-is.
    """
    forwarded = (self, *args)
    return interpol.interpolate(*forwarded)