Пример #1
0
def test_metrics(model, video_path=None, frames=None, output_folder=None):

    if video_path is not None and frames is None:
        frames, _ = extract_frames(video_path)

    total_ssim = 0
    total_psnr = 0
    stride = 30
    iters = 1 + (len(frames) - 3) // stride

    triplets = []
    for i in range(iters):
        tup = (frames[i * stride], frames[i * stride + 1],
               frames[i * stride + 2])
        triplets.append(tup)

    iters = len(triplets)

    for i in range(iters):
        x1, gt, x2 = triplets[i]
        pred = interpolate(model, x1, x2)
        if output_folder is not None:
            frame_path = join(output_folder, f'wiz_{i}.jpg')
            pred.save(frame_path)
        gt = pil_to_tensor(gt)
        pred = pil_to_tensor(pred)
        total_ssim += ssim(pred, gt).item()
        total_psnr += psnr(pred, gt).item()
        print(f'#{i+1}/{iters} done')

    avg_ssim = total_ssim / iters
    avg_psnr = total_psnr / iters

    print(f'avg_ssim: {avg_ssim}, avg_psnr: {avg_psnr}')
Пример #2
0
def run_parallax_view_generation0(torchModel,
                                  t,
                                  inputDir,
                                  outputDir,
                                  netmode,
                                  numImages=-1,
                                  save_images=True):
    """Generate a parallax view sequence and score it against ground truth.

    The images at indices 0, t, 2t, ... (below ``numImages``) are the model
    inputs passed to ``generate_parallax_view``. Every returned view whose
    index is not a multiple of ``t`` is compared with the ground-truth image
    at the same index using PSNR.

    Args:
        torchModel: Forwarded to ``generate_parallax_view``.
        t: Camera interval, i.e. the step between input images.
        inputDir: Directory the ground-truth images are loaded from.
        outputDir: Destination for ``images/<index+1>.jpg`` and ``psnr.json``.
            Nothing is written when it is None.
        netmode: Forwarded unchanged to ``generate_parallax_view``.
        numImages: Number of images to use; -1 means all loaded images.
        save_images: When False, nothing is written to disk.

    Returns:
        The lowest PSNR over the scored views, or the sentinel 999999999 if
        no view was scored (for example when ``t == 1``).
    """
    cam_interval = t
    num_images = numImages
    save = save_images and outputDir is not None

    if save:
        im_output = os.path.join(outputDir, "images")
        json_output = os.path.join(outputDir, "psnr.json")
        makedirs(im_output, exist_ok=True)

    images = load_images(inputDir)

    if num_images == -1:
        num_images = len(images)

    input_images = [images[w] for w in range(0, num_images, cam_interval)]

    parallax_view = generate_parallax_view(torchModel, num_images, cam_interval,
                                           input_images, netmode)

    worst_psnr = 999999999  # sentinel returned when no view gets scored
    results = {}

    for index, view in enumerate(parallax_view):
        # Positions that were model inputs are recorded with PSNR 0.
        p = 0
        if index % cam_interval != 0:
            p = psnr(pil_to_tensor(view), pil_to_tensor(images[index])).item()
            worst_psnr = min(worst_psnr, p)
        results[index] = p

        if save:
            view.save(join_paths(im_output, '{}.jpg'.format(index + 1)),
                      'JPEG',
                      quality=95)

    if save:
        # Write the JSON once, after all views are scored. It used to be
        # rewritten in full on every iteration (quadratic I/O).
        writeJson(json_output, results)

    return worst_psnr
Пример #3
0
def validate(epoch):
    """Evaluate the global ``model`` on ``validation_data_loader``.

    Averages the loss, SSIM and PSNR over the loader's batches. The results
    are logged to ``board_writer`` as ``data/epoch_validation_loss``,
    ``data/epoch_ssmi`` and ``data/epoch_psnr`` at step ``epoch``, and the
    loss is printed.

    The model runs in eval mode for this pass and is restored to its previous
    mode afterwards. Otherwise dropout and batch-norm layers (if any) would
    behave as in training, and batch-norm running statistics would be updated
    by validation data.

    Args:
        epoch: Step value used for the TensorBoard scalars.
    """
    print("===> Running validation...")
    iters = len(validation_data_loader)
    if iters == 0:
        # Nothing to average; the metrics would otherwise divide by zero.
        print("===> Validation skipped: validation loader is empty")
        return

    ssim_loss = loss.SsimLoss()
    valid_loss, valid_ssmi, valid_psnr = 0, 0, 0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in validation_data_loader:
                inputs, target = batch[0].to(device), batch[1].to(device)
                output = model(inputs)
                valid_loss += loss_function(output, target).item()
                # The SSIM loss is subtracted. Presumably SsimLoss returns
                # -SSIM, so this accumulates SSIM; verify in `loss.SsimLoss`.
                valid_ssmi -= ssim_loss(output, target).item()
                valid_psnr += psnr(output, target).item()
    finally:
        model.train(was_training)

    valid_loss /= iters
    valid_ssmi /= iters
    valid_psnr /= iters
    board_writer.add_scalar('data/epoch_validation_loss', valid_loss, epoch)
    board_writer.add_scalar('data/epoch_ssmi', valid_ssmi, epoch)
    board_writer.add_scalar('data/epoch_psnr', valid_psnr, epoch)
    print("===> Validation loss: {:.4f}".format(valid_loss))
Пример #4
0
def test_linear_interp(validation_set=None):

    if validation_set is None:
        validation_set = get_validation_set()

    total_ssim = 0
    total_psnr = 0
    iters = len(validation_set.tuples)

    crop = CenterCrop(config.CROP_SIZE)

    for tup in validation_set.tuples:
        x1, gt, x2, = [pil_to_tensor(crop(load_img(p))) for p in tup]
        pred = torch.mean(torch.stack((x1, x2), dim=0), dim=0)
        total_ssim += ssim(pred, gt).item()
        total_psnr += psnr(pred, gt).item()

    avg_ssim = total_ssim / iters
    avg_psnr = total_psnr / iters

    print(f'avg_ssim: {avg_ssim}, avg_psnr: {avg_psnr}')
Пример #5
0
def test_on_validation_set(model, validation_set=None):

    if validation_set is None:
        validation_set = get_validation_set()

    total_ssim = 0
    total_psnr = 0
    iters = len(validation_set.tuples)

    crop = CenterCrop(config.CROP_SIZE)

    for i, tup in enumerate(validation_set.tuples):
        x1, gt, x2, = [crop(load_img(p)) for p in tup]
        pred = interpolate(model, x1, x2)
        gt = pil_to_tensor(gt)
        pred = pil_to_tensor(pred)
        total_ssim += ssim(pred, gt).item()
        total_psnr += psnr(pred, gt).item()
        print(f'#{i+1} done')

    avg_ssim = total_ssim / iters
    avg_psnr = total_psnr / iters

    print(f'avg_ssim: {avg_ssim}, avg_psnr: {avg_psnr}')