Example #1
0
def test_metrics(model, video_path=None, frames=None, output_folder=None):
    """Evaluate interpolation quality (SSIM/PSNR) of `model` on frame triplets.

    Samples one (x1, gt, x2) triplet every `stride` frames, predicts the
    middle frame from the outer two, and prints the average SSIM and PSNR
    against the ground-truth middle frame.

    Args:
        model: interpolation model passed through to `interpolate`.
        video_path: optional path to a video; frames are extracted from it
            when `frames` is not given.
        frames: optional pre-extracted sequence of PIL frames.
        output_folder: optional folder; when given, each predicted frame is
            saved there as 'wiz_{i}.jpg'.

    Raises:
        ValueError: if no frames are available or fewer than 3 frames exist.
    """
    if video_path is not None and frames is None:
        frames, _ = extract_frames(video_path)

    # Original code crashed with TypeError on len(None) here.
    if frames is None:
        raise ValueError('Either video_path or frames must be provided')

    stride = 30

    # One triplet starting every `stride` frames; the start index must leave
    # room for the two following frames. Equivalent to the original
    # `1 + (len(frames) - 3) // stride` count, without negative-range issues.
    triplets = [
        (frames[start], frames[start + 1], frames[start + 2])
        for start in range(0, max(len(frames) - 2, 0), stride)
    ]

    iters = len(triplets)
    if iters == 0:
        # Original code divided by zero here.
        raise ValueError('Need at least 3 frames to build a test triplet')

    total_ssim = 0
    total_psnr = 0

    for i, (x1, gt, x2) in enumerate(triplets):
        pred = interpolate(model, x1, x2)
        if output_folder is not None:
            pred.save(join(output_folder, f'wiz_{i}.jpg'))
        gt = pil_to_tensor(gt)
        pred = pil_to_tensor(pred)
        total_ssim += ssim(pred, gt).item()
        total_psnr += psnr(pred, gt).item()
        print(f'#{i+1}/{iters} done')

    avg_ssim = total_ssim / iters
    avg_psnr = total_psnr / iters

    print(f'avg_ssim: {avg_ssim}, avg_psnr: {avg_psnr}')
Example #2
0
def interpolate3toN(model, left, middle, right, seq_len):
    """Run `model` on a (left, middle, right) frame triple to synthesize a
    longer sequence, returning the generated frames as PIL images.

    Args:
        model: the interpolation network.
        left, middle, right: PIL images with 3 channels each.
        seq_len: target sequence length; the model output is reshaped to
            `seq_len - 3` frames (the 3 inputs are presumably excluded —
            TODO confirm against the model's contract).

    Returns:
        List of `seq_len - 3` PIL images.
    """
    # Single-item batch: the three frames concatenated along the channel axis.
    batch = torch.stack([
        torch.cat(
            [pil_to_tensor(left),
             pil_to_tensor(middle),
             pil_to_tensor(right)],
            dim=0)
    ],
                        dim=0)

    frame_channels, frame_height, frame_width = batch[0].shape
    # Three frames were stacked on the channel axis; integer-divide back to
    # per-frame channels (the original used `/=`, yielding a float).
    frame_channels //= 3
    assert frame_channels == 3, "Only frames with 3 channels are supported"

    input_pad, output_pad = _get_padding_modules(frame_height, frame_width)

    if torch.cuda.is_available():
        batch = batch.cuda()
        input_pad = input_pad.cuda()
        output_pad = output_pad.cuda()
        model = model.cuda()

    batch = input_pad(batch)

    with torch.no_grad():
        output = model(batch, seq_len, True, False)

    output = output_pad(output)

    output = output.cpu().detach().numpy()

    # Split the flat channel dimension back into individual RGB frames.
    output = output.reshape(seq_len - 3, 3, frame_height, frame_width)

    return [numpy_to_pil(x) for x in output]
Example #3
0
def visual_test(epoch):
    """Interpolate each visual-test pair and log the result to TensorBoard.

    NOTE(review): relies on module-level `model`, `visual_test_set`,
    `board_writer`, `load_img`, `interpolate`, and `pil_to_tensor`, which are
    defined elsewhere in this file.
    """
    print("===> Running visual test...")
    for idx, triple in enumerate(visual_test_set):
        # Predict the middle frame from the outer two images of the triple.
        left_img = load_img(triple[0])
        right_img = load_img(triple[2])
        prediction = interpolate(model, left_img, right_img)
        image_tensor = pil_to_tensor(prediction)
        board_writer.add_image('data/visual_test_{}'.format(idx), image_tensor, epoch)
Example #4
0
def run_parallax_view_generation0(torchModel,
                                  t,
                                  inputDir,
                                  outputDir,
                                  netmode,
                                  numImages=-1,
                                  save_images=True):
    """Generate a parallax view sequence and return the worst per-view PSNR.

    Note: despite its name, parameter `t` is the camera interval (stride
    between the ground-truth views fed to the network) and `numImages` is the
    total number of views (-1 means "all images in inputDir").

    Args:
        torchModel: network passed to `generate_parallax_view`.
        t: camera interval (stride) between input views.
        inputDir: directory of ground-truth images.
        outputDir: directory for generated images and the PSNR report;
            ignored when None or `save_images` is False.
        netmode: mode flag forwarded to `generate_parallax_view`.
        numImages: number of views to process, or -1 for all.
        save_images: whether to write generated views and 'psnr.json'.

    Returns:
        The lowest PSNR among the synthesized (non-input) views, or the
        sentinel 999999999 when no view was scored.
    """
    cam_interval = t
    total_views = numImages
    parallax_output_dir = outputDir

    save_outputs = save_images and parallax_output_dir is not None
    if save_outputs:
        im_output = os.path.join(parallax_output_dir, "images")
        json_output = os.path.join(parallax_output_dir, "psnr.json")
        makedirs(im_output, exist_ok=True)

    images = load_images(inputDir)

    if total_views == -1:
        total_views = len(images)

    # Every cam_interval-th ground-truth image is fed to the network.
    input_images = [images[w] for w in range(0, total_views, cam_interval)]

    parallax_view = generate_parallax_view(torchModel, total_views, cam_interval,
                                           input_images, netmode)

    # Sentinel kept from the original code so callers see the same value
    # when no view is scored (e.g. cam_interval == 1).
    worstPsnr = 999999999

    resultsDict = {}

    for index, view in enumerate(parallax_view):
        p = 0
        # Views at multiples of cam_interval are network inputs: skip scoring.
        if index % cam_interval != 0:
            p = psnr(pil_to_tensor(view), pil_to_tensor(images[index])).item()
            worstPsnr = min(worstPsnr, p)
        resultsDict[index] = p

        if save_outputs:
            view.save(join_paths(im_output, '{}.jpg'.format(index + 1)),
                      'JPEG',
                      quality=95)

    # Write the report once; the original rewrote the full JSON file on
    # every loop iteration.
    if save_outputs:
        writeJson(json_output, resultsDict)

    return worstPsnr
Example #5
0
def interpolate_batch(model_, pil_frames):
    """Interpolate a middle frame for every consecutive pair in `pil_frames`.

    Builds a batch where each item is two adjacent frames concatenated on the
    channel axis, runs one forward pass, and converts the outputs back to PIL.

    Args:
        model_: the interpolation network.
        pil_frames: sequence of at least two 3-channel PIL images.

    Returns:
        List of `len(pil_frames) - 1` interpolated PIL images.
    """
    assert len(
        pil_frames) > 1, "Frames to be interpolated must be at least two"

    # Each batch item: frame i and frame i+1 stacked along the channel axis.
    batch = []
    for i in range(len(pil_frames) - 1):
        frame1 = pil_to_tensor(pil_frames[i])
        frame2 = pil_to_tensor(pil_frames[i + 1])
        batch.append(torch.cat((frame1, frame2), dim=0))
    batch = torch.stack(batch, dim=0)

    frame_channels, frame_height, frame_width = batch[0].shape
    # Two frames were concatenated; integer-divide back to per-frame channels
    # (the original used `/=`, yielding a float compared against an int).
    frame_channels //= 2
    assert frame_channels == 3, "Only frames with 3 channels are supported"

    # Generate the padding functions for the given input size
    input_pad, output_pad = _get_padding_modules(frame_height, frame_width)

    # Use CUDA if possible
    if torch.cuda.is_available():
        batch = batch.cuda()
        input_pad = input_pad.cuda()
        output_pad = output_pad.cuda()
        model_ = model_.cuda()

    # Apply input padding
    batch = input_pad(batch)

    # Run forward pass
    with torch.no_grad():
        output = model_(batch)

    # Apply output padding
    output = output_pad(output)

    # Get numpy representation of the output
    output = output.cpu().detach().numpy()

    output_pils = [numpy_to_pil(x) for x in output]
    return output_pils
Example #6
0
def test_on_validation_set(model, validation_set=None):
    """Print the mean SSIM and PSNR of `model` over a validation set.

    Each validation tuple holds three image paths (x1, ground truth, x2);
    the model predicts the middle image from the outer two. Images are
    center-cropped to `config.CROP_SIZE` before evaluation.
    """
    if validation_set is None:
        validation_set = get_validation_set()

    crop = CenterCrop(config.CROP_SIZE)
    tuples = validation_set.tuples
    iters = len(tuples)

    ssim_sum = 0
    psnr_sum = 0

    for idx, paths in enumerate(tuples):
        first, middle, last = (crop(load_img(p)) for p in paths)
        predicted = interpolate(model, first, last)
        target = pil_to_tensor(middle)
        predicted = pil_to_tensor(predicted)
        ssim_sum += ssim(predicted, target).item()
        psnr_sum += psnr(predicted, target).item()
        print(f'#{idx+1} done')

    avg_ssim = ssim_sum / iters
    avg_psnr = psnr_sum / iters

    print(f'avg_ssim: {avg_ssim}, avg_psnr: {avg_psnr}')
Example #7
0
def test_linear_interp(validation_set=None):
    """Baseline metric: average the two outer frames pixel-wise and score the
    blend against the ground-truth middle frame with SSIM and PSNR.
    """
    if validation_set is None:
        validation_set = get_validation_set()

    crop = CenterCrop(config.CROP_SIZE)
    tuples = validation_set.tuples
    iters = len(tuples)

    ssim_sum = 0
    psnr_sum = 0

    for paths in tuples:
        left, target, right = (pil_to_tensor(crop(load_img(p))) for p in paths)
        # Linear interpolation: the per-pixel mean of the two outer frames.
        blended = torch.stack((left, right), dim=0).mean(dim=0)
        ssim_sum += ssim(blended, target).item()
        psnr_sum += psnr(blended, target).item()

    avg_ssim = ssim_sum / iters
    avg_psnr = psnr_sum / iters

    print(f'avg_ssim: {avg_ssim}, avg_psnr: {avg_psnr}')