Пример #1
0
def test():
    """
    Run the testing/validation loop for the Full Network for a single epoch.

    For every batch yielded by ``testloader`` the model is evaluated (in
    ``eval`` mode, without gradient tracking). Then:

    * the input views, the generated second view and the viewpoint
      estimates are exported to ``output_video_dir``;
    * the reconstruction loss and viewpoint loss are accumulated.

    Progress is printed every 10 batches, and the mean of both losses is
    printed at the end.

    Relies on module-level names defined elsewhere:

    * globals ``model``, ``testloader``, ``device``, ``criterion`` and
      ``output_video_dir``;
    * helpers ``get_first_frame``, ``convert_to_vid`` and ``export_vps``.

    :return: None
    """
    running_recon_loss = 0.0
    running_vp_loss = 0.0
    # Evaluated once instead of on every progress print and per average.
    num_batches = len(testloader)

    model.eval()

    for batch_idx, (vp_diff, vid1, vid2) in enumerate(testloader):
        batch_num = batch_idx + 1  # 1-based index used in filenames / logs
        vp_diff = vp_diff.type(torch.FloatTensor).to(device)
        vid1, vid2 = vid1.to(device), vid2.to(device)
        # The model consumes only the first frame of view 2. The first frame
        # of view 1 was previously computed here but never used.
        img2 = get_first_frame(vid2).to(device)

        with torch.no_grad():
            gen_v2, vp_est = model(vp_diff=vp_diff, vid1=vid1, img2=img2)

            # save videos
            convert_to_vid(tensor=vid1,
                           output_dir=output_video_dir,
                           batch_num=batch_num,
                           view=1,
                           item_type='input')
            convert_to_vid(tensor=vid2,
                           output_dir=output_video_dir,
                           batch_num=batch_num,
                           view=2,
                           item_type='input')
            convert_to_vid(tensor=gen_v2,
                           output_dir=output_video_dir,
                           batch_num=batch_num,
                           view=2,
                           item_type='output')
            export_vps(vp_gt=vp_diff,
                       vp_est=vp_est,
                       output_dir=output_video_dir,
                       batch_num=batch_num)

            # loss -- converted to Python floats once, for both accumulation
            # and printing.
            recon_loss = criterion(gen_v2, vid2).item()
            vp_loss = criterion(vp_est, vp_diff).item()

        running_recon_loss += recon_loss
        running_vp_loss += vp_loss
        if batch_num % 10 == 0:
            print('\tBatch {}/{} ReconLoss:{} VPLoss:{}'.format(
                batch_num, num_batches, "{0:.5f}".format(recon_loss),
                "{0:.5f}".format(vp_loss)))

    if num_batches == 0:
        # Avoid ZeroDivisionError when averaging over an empty loader.
        print('Testing Complete: testloader yielded no batches')
        return

    print('Testing Complete ReconLoss:{} VPLoss:{}'.format(
        "{0:.5f}".format(running_recon_loss / num_batches),
        "{0:.5f}".format(running_vp_loss / num_batches)))
Пример #2
0
# print(x.requires_grad_())
# x = x.requires_grad_()
# print(x)

# x = [1,2,3,4]
# x.reverse()
# print(x)
#
# del x
# print(x)

import torch
from utils.modelIOFuncs import export_vps

# Experiment 1: smoke-test export_vps with a pair of random 4-element
# tensors standing in for ground-truth and estimated viewpoints of batch 1.
# output_dir='./' targets the current working directory; what export_vps
# actually writes there is defined in utils.modelIOFuncs (not shown here).
vp_gt = torch.randn(4)
vp_est = torch.randn(4)

export_vps(vp_gt=vp_gt, vp_est=vp_est, output_dir='./', batch_num=1)


# Experiment 2: show the effect of keepdim on a reduction. Summing the
# (3, 4, 5, 6) tensor over dim 1 with keepdim=True keeps that axis with
# size 1, so y has shape (3, 1, 5, 6) while x keeps (3, 4, 5, 6).
# (torch is already imported at the top of this example; the duplicate
# import that used to sit here has been removed.)
x = torch.zeros(3, 4, 5, 6)
y = torch.sum(x, dim=1, keepdim=True)

print(x.size())
print(y.size())