Example #1
    def __init__(self, device='cpu'):
        super().__init__()
        self.flowComp = model.UNet(6, 4).to(device)
        self.ArbTimeFlowIntrp = model.UNet(20, 5).to(device)
        self.trainFlowBackWarp = model.backWarp(352, 352, device)

        vgg16 = torchvision.models.vgg16(pretrained=True)
        vgg16_conv_4_3 = nn.Sequential(*list(vgg16.children())[0][:22])
        vgg16_conv_4_3.to(device)
        for param in vgg16_conv_4_3.parameters():
            param.requires_grad = False
        self.vgg16_conv_4_3 = vgg16_conv_4_3
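
Example #1 sets up the frozen VGG16 conv4_3 feature extractor but ends before using it. A minimal sketch of the perceptual-loss computation it typically feeds, mirroring the usage in Example #8 (the wrapper function itself is an illustration, not part of the original):

import torch.nn as nn

MSE_LossFn = nn.MSELoss()

def perceptual_loss(vgg16_conv_4_3, Ft_p, IFrame):
    # MSE between conv4_3 feature maps of the predicted frame (Ft_p) and the
    # ground truth (IFrame); gradients reach Ft_p but not the frozen VGG weights.
    return MSE_LossFn(vgg16_conv_4_3(Ft_p), vgg16_conv_4_3(IFrame))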
Example #2
    def __init__(self, checkpoint, frame0, frame1, batch_size=1):
        # Initialize transforms
        self.checkpoint = checkpoint
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.batch_size = batch_size

        mean = [0.429, 0.431, 0.397]
        std = [1, 1, 1]
        normalize = transforms.Normalize(mean=mean, std=std)

        negmean = [x * -1 for x in mean]
        revNormalize = transforms.Normalize(mean=negmean, std=std)

        # Temporary fix for issue #7 https://github.com/avinashpaliwal/Super-SloMo/issues/7 -
        # - Removed per channel mean subtraction for CPU.
        if self.device.type == "cpu":
            self.transform = transforms.Compose([transforms.ToTensor()])
            self.TP = transforms.Compose([transforms.ToPILImage()])
        else:
            self.transform = transforms.Compose(
                [transforms.ToTensor(), normalize])
            self.TP = transforms.Compose(
                [revNormalize, transforms.ToPILImage()])

        # Load data
        self.videoFrames = dataloader.Images(frame0=frame0,
                                             frame1=frame1,
                                             transform=self.transform)
        self.videoFramesloader = torch.utils.data.DataLoader(
            self.videoFrames, batch_size=self.batch_size, shuffle=False)

        # Initialize model
        self.flowComp = model.UNet(6, 4)
        self.flowComp.to(self.device)
        for param in self.flowComp.parameters():
            param.requires_grad = False
        self.ArbTimeFlowIntrp = model.UNet(20, 5)
        self.ArbTimeFlowIntrp.to(self.device)
        for param in self.ArbTimeFlowIntrp.parameters():
            param.requires_grad = False

        self.flowBackWarp = model.backWarp(self.videoFrames.dim[0],
                                           self.videoFrames.dim[1],
                                           self.device)
        self.flowBackWarp = self.flowBackWarp.to(self.device)

        dict1 = torch.load(self.checkpoint, map_location='cpu')
        self.ArbTimeFlowIntrp.load_state_dict(dict1['state_dictAT'])
        self.flowComp.load_state_dict(dict1['state_dictFC'])
Example #3
    def __init__(self, model_directory: str, sf: int, height: int, width: int, batch_size=1, **ssm):
        # sf
        self.sf = sf
        self.batch_size = batch_size

        # Check if need to expand image
        self.h_w = [int(math.ceil(height / 32) * 32 - height) if height % 32 else 0,
                    int(math.ceil(width / 32) * 32 - width) if width % 32 else 0]
        dim = [height + self.h_w[0], width + self.h_w[1]]

        cuda_availability = torch.cuda.is_available()
        device = torch.device("cuda:0" if cuda_availability else "cpu")

        # Initialize model
        self.flowComp = model.UNet(6, 4)

        self.flowComp.to(device)

        for param in self.flowComp.parameters():
            param.requires_grad = False
        self.ArbTimeFlowIntrp = model.UNet(20, 5)
        self.ArbTimeFlowIntrp.to(device)
        for param in self.ArbTimeFlowIntrp.parameters():
            param.requires_grad = False
        self.flowBackWarp = model.backWarp(dim[1], dim[0], device)
        self.flowBackWarp = self.flowBackWarp.to(device)
        dict1 = torch.load(model_directory, map_location='cpu')
        self.ArbTimeFlowIntrp.load_state_dict(dict1['state_dictAT'])
        self.flowComp.load_state_dict(dict1['state_dictFC'])

        self.ndarray2tensor = {True: self.ndarray2cuda_tensor,
                               False: self.ndarray2cpu_tensor
                               }[cuda_availability]
        self.tensor2ndarray = {True: lambda frames: (frames.detach() * 255).byte()
                                                    [:, :, self.h_w[0]:, self.h_w[1]:].permute(0, 2, 3, 1)
                                                        .cpu().numpy()[:, :, :, ::-1],
                               False: lambda frames: numpy.transpose(
                                   (numpy.array(frames.detach().cpu()) * 255).astype(numpy.uint8)
                                   [:, ::-1, self.h_w[0]:, self.h_w[1]:], (0, 2, 3, 1))
                               }[cuda_availability]
        self.batch = torch.cuda.FloatTensor(batch_size + 1, 3, dim[0], dim[1]) if cuda_availability \
            else torch.FloatTensor(batch_size + 1, 3, dim[0], dim[1])
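
The h_w computation above pads each dimension up to the next multiple of 32, since the U-Nets downsample repeatedly and non-divisible input sizes cause shape mismatches on the decoder side. A standalone sketch of the same formula on concrete sizes (the sizes are assumptions):

import math

def pad_amounts(height, width, multiple=32):
    # Extra rows/columns needed to reach the next multiple of `multiple`
    pad_h = int(math.ceil(height / multiple) * multiple - height) if height % multiple else 0
    pad_w = int(math.ceil(width / multiple) * multiple - width) if width % multiple else 0
    return pad_h, pad_w

print(pad_amounts(270, 480))  # (18, 0): 270 is padded to 288, while 480 = 15 * 32 already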
Example #4
import math, code

writer = SummaryWriter('test_log_without_keypoints')

###Loss and Optimizer
L1_lossFn = torch.nn.L1Loss()
MSE_LossFn = torch.nn.MSELoss()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
cnn1 = model.UNet(6, 4)
cnn2 = model.UNet(20, 5)
cnn1.to(device)
cnn2.to(device)

###Initialize backward warpers for train and validation datasets
trainFlowBackWarp = model.backWarp(352, 352, device)
trainFlowBackWarp = trainFlowBackWarp.to(device)
validationFlowBackWarp = model.backWarp(640, 352, device)
validationFlowBackWarp = validationFlowBackWarp.to(device)

###Create transform to display image from tensor
revNormalize = torchvision.transforms.Normalize(mean=[-0.5, -0.5, -0.5],
                                                std=[1, 1, 1])
TP = torchvision.transforms.Compose(
    [revNormalize, torchvision.transforms.ToPILImage()])

vgg16 = torchvision.models.vgg16(pretrained=True)
vgg16_conv = torch.nn.Sequential(*list(vgg16.children())[0][:22])
vgg16_conv.to(device)
# code.interact(local=dict(globals(), **locals()))
# Freeze convolutional weights
for param in vgg16_conv.parameters():
    param.requires_grad = False
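
revNormalize inverts the forward normalization only because std is all ones: subtracting mean m is undone by Normalize(mean=-m, std=1). A quick self-contained check (the random tensor is just for illustration):

import torch
import torchvision

x = torch.rand(3, 8, 8)
fwd = torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[1, 1, 1])
inv = torchvision.transforms.Normalize(mean=[-0.5, -0.5, -0.5], std=[1, 1, 1])
assert torch.allclose(inv(fwd(x)), x)  # (x - 0.5) - (-0.5) == x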
Example #5
def main():
    # Check if arguments are okay
    error = check()
    if error:
        print(error)
        exit(1)

    # Create extraction folder and extract frames
    IS_WINDOWS = 'Windows' == platform.system()
    extractionDir = "tmpSuperSloMo"
    if not IS_WINDOWS:
        # Assuming UNIX-like system where "." indicates hidden directories
        extractionDir = "." + extractionDir
    if os.path.isdir(extractionDir):
        rmtree(extractionDir)
    os.mkdir(extractionDir)
    if IS_WINDOWS:
        FILE_ATTRIBUTE_HIDDEN = 0x02
        # ctypes.windll only exists on Windows
        ctypes.windll.kernel32.SetFileAttributesW(extractionDir,
                                                  FILE_ATTRIBUTE_HIDDEN)

    extractionPath = os.path.join(extractionDir, "input")
    outputPath = os.path.join(extractionDir, "output")
    os.mkdir(extractionPath)
    os.mkdir(outputPath)
    error = extract_frames(args.video, extractionPath)
    if error:
        print(error)
        exit(1)

    # Initialize transforms
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    mean = [0.429, 0.431, 0.397]
    std = [1, 1, 1]
    normalize = transforms.Normalize(mean=mean, std=std)

    negmean = [x * -1 for x in mean]
    revNormalize = transforms.Normalize(mean=negmean, std=std)

    # Temporary fix for issue #7 https://github.com/avinashpaliwal/Super-SloMo/issues/7 -
    # - Removed per channel mean subtraction for CPU.
    if device.type == "cpu":
        transform = transforms.Compose([transforms.ToTensor()])
        TP = transforms.Compose([transforms.ToPILImage()])
    else:
        transform = transforms.Compose([transforms.ToTensor(), normalize])
        TP = transforms.Compose([revNormalize, transforms.ToPILImage()])

    # Load data
    videoFrames = dataloader.Video(root=extractionPath, transform=transform)
    videoFramesloader = torch.utils.data.DataLoader(videoFrames,
                                                    batch_size=args.batch_size,
                                                    shuffle=False)

    # Initialize model
    flowComp = model.UNet(6, 4)
    flowComp.to(device)
    for param in flowComp.parameters():
        param.requires_grad = False
    ArbTimeFlowIntrp = model.UNet(20, 5)
    ArbTimeFlowIntrp.to(device)
    for param in ArbTimeFlowIntrp.parameters():
        param.requires_grad = False

    flowBackWarp = model.backWarp(videoFrames.dim[0], videoFrames.dim[1],
                                  device)
    flowBackWarp = flowBackWarp.to(device)

    dict1 = torch.load(args.checkpoint, map_location='cpu')
    ArbTimeFlowIntrp.load_state_dict(dict1['state_dictAT'])
    flowComp.load_state_dict(dict1['state_dictFC'])

    # Interpolate frames
    frameCounter = 1

    with torch.no_grad():
        for _, (frame0, frame1) in enumerate(tqdm(videoFramesloader), 0):

            I0 = frame0.to(device)
            I1 = frame1.to(device)

            flowOut = flowComp(torch.cat((I0, I1), dim=1))
            F_0_1 = flowOut[:, :2, :, :]
            F_1_0 = flowOut[:, 2:, :, :]

            # Save reference frames in output folder
            for batchIndex in range(args.batch_size):
                (TP(frame0[batchIndex].detach())).resize(
                    videoFrames.origDim, Image.BILINEAR).save(
                        os.path.join(
                            outputPath,
                            str(frameCounter + args.sf * batchIndex) + ".jpg"))
            frameCounter += 1

            # Generate intermediate frames
            for intermediateIndex in range(1, args.sf):
                t = intermediateIndex / args.sf
                temp = -t * (1 - t)
                fCoeff = [temp, t * t, (1 - t) * (1 - t), temp]

                F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
                F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

                g_I0_F_t_0 = flowBackWarp(I0, F_t_0)
                g_I1_F_t_1 = flowBackWarp(I1, F_t_1)

                intrpOut = ArbTimeFlowIntrp(
                    torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1,
                               g_I0_F_t_0),
                              dim=1))

                F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
                F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
                V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])
                V_t_1 = 1 - V_t_0

                g_I0_F_t_0_f = flowBackWarp(I0, F_t_0_f)
                g_I1_F_t_1_f = flowBackWarp(I1, F_t_1_f)

                wCoeff = [1 - t, t]

                Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 *
                        g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)

                # Save intermediate frame
                for batchIndex in range(args.batch_size):
                    (TP(Ft_p[batchIndex].cpu().detach())).resize(
                        videoFrames.origDim, Image.BILINEAR).save(
                            os.path.join(
                                outputPath,
                                str(frameCounter + args.sf * batchIndex) +
                                ".jpg"))
                frameCounter += 1

            # Set counter accounting for batching of frames
            frameCounter += args.sf * (args.batch_size - 1)

    # Generate video from interpolated frames
    create_video(outputPath)

    # Remove temporary files
    rmtree(extractionDir)

    exit(0)
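
The fCoeff and wCoeff arrays above implement the Super SloMo intermediate-flow approximation F_t_0 = -(1-t)*t*F_0_1 + t^2*F_1_0 and F_t_1 = (1-t)^2*F_0_1 - t*(1-t)*F_1_0. A tiny sketch that spells the coefficients out:

def flow_coefficients(t):
    # Returns [c0, c1, c2, c3] such that F_t_0 = c0 * F_0_1 + c1 * F_1_0
    # and F_t_1 = c2 * F_0_1 + c3 * F_1_0
    temp = -t * (1 - t)
    return [temp, t * t, (1 - t) * (1 - t), temp]

print(flow_coefficients(0.25))  # [-0.1875, 0.0625, 0.5625, -0.1875], matching the pdb dump in Example #12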
Example #6
def setup_back_warp(w, h):
    global back_warp
    with torch.set_grad_enabled(False):
        back_warp = model.backWarp(w, h, device).to(device)
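
A brief usage sketch under assumed names and shapes: the warper is built once per frame size and then applied like a function, as the other examples do (I1 and F_t_1 here are dummy tensors):

import torch

setup_back_warp(640, 352)  # width, height of the frames to warp
I1 = torch.zeros(1, 3, 352, 640).to(device)     # dummy frame, NCHW
F_t_1 = torch.zeros(1, 2, 352, 640).to(device)  # dummy flow field
g_I1_F_t_1 = back_warp(I1, F_t_1)  # zero flow should return I1 unchanged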
Example #7
parser.add_argument("--train_batch_size", type=int, default=4, help='batch size for training.')
parser.add_argument("--validation_batch_size", type=int, default=10, help='batch size for validation. Default: 10.')
parser.add_argument("--init_learning_rate", type=float, default=0.0001, help='set initial learning rate. Default: 0.0001.')
parser.add_argument("--milestones", type=list, default=[100, 150], help='Set to epoch values where you want to decrease learning rate by a factor of 0.1. Default: [100, 150]')
parser.add_argument("--progress_iter", type=int, default=100, help='frequency of reporting progress and validation. N: after every N iterations. Default: 100.')
parser.add_argument("--checkpoint_epoch", type=int, default=5, help='checkpoint saving frequency. N: after every N epochs.Default: 5.')
parser.add_argument("--logfolder", type=str, required=True, default='log', help='path of log folder.')
args = parser.parse_args()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
flowComp = model.UNet(2, 4)
flowComp.to(device)
ArbTimeFlowIntrp = model.UNet(12, 5)
ArbTimeFlowIntrp.to(device)

trainFlowBackWarp      = model.backWarp(448, 448, device)
trainFlowBackWarp      = trainFlowBackWarp.to(device)
validationFlowBackWarp = model.backWarp(512, 512, device)
validationFlowBackWarp = validationFlowBackWarp.to(device)

mean = [0.5]
std  = [1]
normalize = transforms.Normalize(mean=mean,
                                 std=std)
transform = transforms.Compose([transforms.ToTensor(), normalize])

trainset = dataloader.AsrNet(root=args.dataset_root + '/train', transform=transform, inferNum=args.infer_num, train=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.train_batch_size, shuffle=True)

validationset = dataloader.AsrNet(root=args.dataset_root + '/validation', transform=transform, randomCropSize=(512, 512), inferNum=args.infer_num, train=False)
validationloader = torch.utils.data.DataLoader(validationset, batch_size=args.validation_batch_size, shuffle=False)
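
--milestones is parsed above but never consumed in this excerpt. A sketch of how it is typically wired up, assuming the MultiStepLR schedule the help text describes (learning rate multiplied by 0.1 at each milestone epoch), as in the original Super-SloMo training script:

import torch.optim as optim

params = list(ArbTimeFlowIntrp.parameters()) + list(flowComp.parameters())
optimizer = optim.Adam(params, lr=args.init_learning_rate)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=0.1)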
Example #8
def train():
    global writer
    # For parsing commandline arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset_root",
        type=str,
        required=True,
        help='path to dataset folder containing train-test-validation folders')
    parser.add_argument("--checkpoint_dir",
                        type=str,
                        required=True,
                        help='path to folder for saving checkpoints')
    parser.add_argument("--checkpoint",
                        type=str,
                        help='path of checkpoint for pretrained model')
    parser.add_argument(
        "--train_continue",
        action='store_true',
        help=
        'pass this flag (and set `checkpoint`) to resume from a checkpoint. Default: False.'
    )
    parser.add_argument("--epochs",
                        type=int,
                        default=200,
                        help='number of epochs to train. Default: 200.')
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=3,
                        help='batch size for training. Default: 6.')
    parser.add_argument("--validation_batch_size",
                        type=int,
                        default=6,
                        help='batch size for validation. Default: 10.')
    parser.add_argument("--init_learning_rate",
                        type=float,
                        default=0.0001,
                        help='set initial learning rate. Default: 0.0001.')
    parser.add_argument(
        "--milestones",
        type=int,
        nargs='+',
        default=[25, 50],
        help=
        'UNUSED NOW: Set to epoch values where you want to decrease learning rate by a factor of 0.1. Default: [25, 50]'
    )
    parser.add_argument(
        "--progress_iter",
        type=int,
        default=200,
        help=
        'frequency of reporting progress and validation. N: after every N iterations. Default: 200.'
    )
    parser.add_argument(
        "--checkpoint_epoch",
        type=int,
        default=5,
        help=
        'checkpoint saving frequency. N: after every N epochs. Each checkpoint is roughly 151 MB. Default: 5.'
    )
    args = parser.parse_args()

    ##[TensorboardX](https://github.com/lanpa/tensorboardX)
    ### For visualizing loss and interpolated frames

    ###Initialize flow computation and arbitrary-time flow interpolation CNNs.

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print(device)
    flowComp = model.UNet(6, 4)
    flowComp.to(device)
    ArbTimeFlowIntrp = model.UNet(20, 5)
    ArbTimeFlowIntrp.to(device)

    ###Initialize backward warpers for train and validation datasets

    train_W_dim = 352
    train_H_dim = 352

    trainFlowBackWarp = model.backWarp(train_W_dim, train_H_dim, device)
    trainFlowBackWarp = trainFlowBackWarp.to(device)
    validationFlowBackWarp = model.backWarp(train_W_dim * 2, train_H_dim,
                                            device)
    validationFlowBackWarp = validationFlowBackWarp.to(device)

    ###Load Datasets

    # Channel wise mean calculated on custom training dataset
    # mean = [0.43702903766008444, 0.43715053433990597, 0.40436416782660994]
    mean = [0.5] * 3
    std = [1, 1, 1]
    normalize = transforms.Normalize(mean=mean, std=std)
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    trainset = dataloader.SuperSloMo(root=args.dataset_root + '/train',
                                     randomCropSize=(train_W_dim, train_H_dim),
                                     transform=transform,
                                     train=True)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.train_batch_size,
                                              shuffle=True,
                                              num_workers=2,
                                              pin_memory=True)

    validationset = dataloader.SuperSloMo(
        root=args.dataset_root + '/validation',
        transform=transform,
        randomCropSize=(2 * train_W_dim, train_H_dim),
        train=False)
    validationloader = torch.utils.data.DataLoader(
        validationset,
        batch_size=args.validation_batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True)

    print(trainset, validationset)

    ###Create transform to display image from tensor

    negmean = [x * -1 for x in mean]
    revNormalize = transforms.Normalize(mean=negmean, std=std)
    TP = transforms.Compose([revNormalize, transforms.ToPILImage()])

    ###Utils

    def get_lr(optimizer):
        for param_group in optimizer.param_groups:
            return param_group['lr']

    ###Loss and Optimizer

    L1_lossFn = nn.L1Loss()
    MSE_LossFn = nn.MSELoss()

    if args.train_continue:
        dict1 = torch.load(args.checkpoint)
        last_epoch = dict1['epoch'] * len(trainloader)
    else:
        last_epoch = -1

    params = list(ArbTimeFlowIntrp.parameters()) + list(flowComp.parameters())

    optimizer = AdamW(params, lr=args.init_learning_rate, amsgrad=True)
    # optimizer = optim.SGD(params, lr=args.init_learning_rate, momentum=0.9, nesterov=True)

    # Scheduler to decrease the learning rate by a factor of 10 when the loss plateaus.
    # Suggested patience value:
    #   patience = (number of items in train dataset / train_batch_size) * (number of epochs of patience)
    # The docs say "epoch", but what is really counted here is scheduler.step()
    # calls, so each epoch of patience is given by the formula above (roughly,
    # if using an approximate dataset count).
    # If the model seems to equalize fast, reduce the number of epochs accordingly.

    # scheduler = optim.lr_scheduler.CyclicLR(optimizer,
    #                                         base_lr=1e-8,
    #                                         max_lr=9.0e-3,
    #                                         step_size_up=3500,
    #                                         mode='triangular2',
    #                                         cycle_momentum=False,
    #                                         last_epoch=last_epoch)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.1,
        patience=len(trainloader) * 3,
        cooldown=len(trainloader) * 2,
        verbose=True,
        min_lr=1e-8)

    # Changed to ReduceLROnPlateau to keep the schedule more adaptive.
    # The model used here seems to converge or plateau faster, with more rapid
    # swings over time, so letting the scheduler react to stagnation proactively
    # seems more useful than decaying at a fixed stage.
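    # Worked example of the patience formula (numbers assumed): with ~3,000
    # training triplets and train_batch_size=3, len(trainloader) is ~1,000
    # steps per epoch, so patience=len(trainloader)*3 waits roughly three
    # epochs of per-iteration scheduler.step() calls before dropping the
    # learning rate, and cooldown=len(trainloader)*2 roughly two epochs after.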

    ###Initializing VGG16 model for perceptual loss

    vgg16 = torchvision.models.vgg16(pretrained=True)
    vgg16_conv_4_3 = nn.Sequential(*list(vgg16.children())[0][:22])
    vgg16_conv_4_3.to(device)

    for param in vgg16_conv_4_3.parameters():
        param.requires_grad = False

    # Validation function

    def validate():
        # For details see training.
        psnr = 0
        tloss = 0
        flag = 1
        with torch.no_grad():
            for validationIndex, (validationData,
                                  validationFrameIndex) in enumerate(
                                      validationloader, 0):
                frame0, frameT, frame1 = validationData

                I0 = frame0.to(device)
                I1 = frame1.to(device)
                IFrame = frameT.to(device)

                torch.cuda.empty_cache()
                flowOut = flowComp(torch.cat((I0, I1), dim=1))
                F_0_1 = flowOut[:, :2, :, :]
                F_1_0 = flowOut[:, 2:, :, :]

                fCoeff = model.getFlowCoeff(validationFrameIndex, device)
                torch.cuda.empty_cache()
                F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
                F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

                g_I0_F_t_0 = validationFlowBackWarp(I0, F_t_0)
                g_I1_F_t_1 = validationFlowBackWarp(I1, F_t_1)
                torch.cuda.empty_cache()
                intrpOut = ArbTimeFlowIntrp(
                    torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1,
                               g_I0_F_t_0),
                              dim=1))

                F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
                F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
                V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])
                V_t_1 = 1 - V_t_0
                # torch.cuda.empty_cache()
                g_I0_F_t_0_f = validationFlowBackWarp(I0, F_t_0_f)
                g_I1_F_t_1_f = validationFlowBackWarp(I1, F_t_1_f)

                wCoeff = model.getWarpCoeff(validationFrameIndex, device)
                torch.cuda.empty_cache()
                Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 *
                        g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)

                # For tensorboard
                if flag:
                    retImg = torchvision.utils.make_grid(
                        [revNormalize(frame0[0]),
                         revNormalize(frameT[0]),
                         revNormalize(Ft_p.cpu()[0]),
                         revNormalize(frame1[0])],
                        padding=10)
                    flag = 0

                # loss
                recnLoss = L1_lossFn(Ft_p, IFrame)
                # torch.cuda.empty_cache()
                prcpLoss = MSE_LossFn(vgg16_conv_4_3(Ft_p),
                                      vgg16_conv_4_3(IFrame))

                warpLoss = L1_lossFn(g_I0_F_t_0, IFrame) + L1_lossFn(
                    g_I1_F_t_1, IFrame) + L1_lossFn(
                        validationFlowBackWarp(I0, F_1_0), I1) + L1_lossFn(
                            validationFlowBackWarp(I1, F_0_1), I0)
                torch.cuda.empty_cache()
                loss_smooth_1_0 = torch.mean(
                    torch.abs(F_1_0[:, :, :, :-1] -
                              F_1_0[:, :, :, 1:])) + torch.mean(
                                  torch.abs(F_1_0[:, :, :-1, :] -
                                            F_1_0[:, :, 1:, :]))
                loss_smooth_0_1 = torch.mean(
                    torch.abs(F_0_1[:, :, :, :-1] -
                              F_0_1[:, :, :, 1:])) + torch.mean(
                                  torch.abs(F_0_1[:, :, :-1, :] -
                                            F_0_1[:, :, 1:, :]))
                loss_smooth = loss_smooth_1_0 + loss_smooth_0_1

                # torch.cuda.empty_cache()
                loss = 204 * recnLoss + 102 * warpLoss + 0.005 * prcpLoss + loss_smooth
                tloss += loss.item()

                # psnr
                MSE_val = MSE_LossFn(Ft_p, IFrame)
                psnr += (10 * log10(1 / MSE_val.item()))
                torch.cuda.empty_cache()

        return (psnr / len(validationloader)), (tloss /
                                                len(validationloader)), retImg

    ### Initialization

    if args.train_continue:
        ArbTimeFlowIntrp.load_state_dict(dict1['state_dictAT'])
        flowComp.load_state_dict(dict1['state_dictFC'])

        optimizer.load_state_dict(dict1.get('state_optimizer', {}))
        scheduler.load_state_dict(dict1.get('state_scheduler', {}))

        for param_group in optimizer.param_groups:
            param_group['lr'] = dict1.get('learningRate',
                                          args.init_learning_rate)

    else:
        dict1 = {'loss': [], 'valLoss': [], 'valPSNR': [], 'epoch': -1}

    ### Training

    import time

    start = time.time()
    cLoss = dict1['loss']
    valLoss = dict1['valLoss']
    valPSNR = dict1['valPSNR']
    checkpoint_counter = 0

    ### Main training loop

    # One optimizer step up front, presumably so the first scheduler.step()
    # does not trigger PyTorch's "scheduler called before optimizer" warning.
    optimizer.step()

    for epoch in range(dict1['epoch'] + 1, args.epochs):
        print("Epoch: ", epoch)

        # Append and reset
        cLoss.append([])
        valLoss.append([])
        valPSNR.append([])
        iLoss = 0

        for trainIndex, (trainData,
                         trainFrameIndex) in enumerate(trainloader, 0):

            ## Getting the input and the target from the training set
            frame0, frameT, frame1 = trainData

            I0 = frame0.to(device)
            I1 = frame1.to(device)
            IFrame = frameT.to(device)
            optimizer.zero_grad()
            # torch.cuda.empty_cache()
            # Calculate flow between reference frames I0 and I1
            flowOut = flowComp(torch.cat((I0, I1), dim=1))

            # Extracting flows between I0 and I1 - F_0_1 and F_1_0
            F_0_1 = flowOut[:, :2, :, :]
            F_1_0 = flowOut[:, 2:, :, :]

            fCoeff = model.getFlowCoeff(trainFrameIndex, device)

            # Calculate intermediate flows
            F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
            F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

            # Get intermediate frames from the intermediate flows
            g_I0_F_t_0 = trainFlowBackWarp(I0, F_t_0)
            g_I1_F_t_1 = trainFlowBackWarp(I1, F_t_1)
            torch.cuda.empty_cache()
            # Calculate optical flow residuals and visibility maps
            intrpOut = ArbTimeFlowIntrp(
                torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1,
                           g_I0_F_t_0),
                          dim=1))

            # Extract optical flow residuals and visibility maps
            F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
            F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
            V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])
            V_t_1 = 1 - V_t_0
            # torch.cuda.empty_cache()
            # Get intermediate frames from the intermediate flows
            g_I0_F_t_0_f = trainFlowBackWarp(I0, F_t_0_f)
            g_I1_F_t_1_f = trainFlowBackWarp(I1, F_t_1_f)
            # torch.cuda.empty_cache()
            wCoeff = model.getWarpCoeff(trainFrameIndex, device)
            torch.cuda.empty_cache()
            # Calculate final intermediate frame
            Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 *
                    g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)

            # Loss
            recnLoss = L1_lossFn(Ft_p, IFrame)
            # torch.cuda.empty_cache()

            prcpLoss = MSE_LossFn(vgg16_conv_4_3(Ft_p), vgg16_conv_4_3(IFrame))
            # torch.cuda.empty_cache()
            warpLoss = L1_lossFn(g_I0_F_t_0, IFrame) + L1_lossFn(
                g_I1_F_t_1, IFrame) + L1_lossFn(
                    trainFlowBackWarp(I0, F_1_0), I1) + L1_lossFn(
                        trainFlowBackWarp(I1, F_0_1), I0)

            loss_smooth_1_0 = torch.mean(
                torch.abs(F_1_0[:, :, :, :-1] - F_1_0[:, :, :, 1:])
            ) + torch.mean(torch.abs(F_1_0[:, :, :-1, :] - F_1_0[:, :, 1:, :]))
            loss_smooth_0_1 = torch.mean(
                torch.abs(F_0_1[:, :, :, :-1] - F_0_1[:, :, :, 1:])
            ) + torch.mean(torch.abs(F_0_1[:, :, :-1, :] - F_0_1[:, :, 1:, :]))
            loss_smooth = loss_smooth_1_0 + loss_smooth_0_1
            # torch.cuda.empty_cache()
            # Total loss - coefficients 204 and 102 are used instead of 0.8 and 0.4
            # since the loss in the paper is calculated for pixels in range 0-255
            # while the input to our network is in range 0-1 (0.8 * 255 = 204, 0.4 * 255 = 102)
            loss = 204 * recnLoss + 102 * warpLoss + 0.005 * prcpLoss + loss_smooth

            # Backpropagate

            loss.backward()
            optimizer.step()
            scheduler.step(loss.item())

            iLoss += loss.item()
            torch.cuda.empty_cache()
            # Validation and progress every `args.progress_iter` iterations
            if ((trainIndex % args.progress_iter) == args.progress_iter - 1):
                # Increment scheduler count
                scheduler.step(iLoss / args.progress_iter)

                end = time.time()

                psnr, vLoss, valImg = validate()
                optimizer.zero_grad()
                # torch.cuda.empty_cache()
                valPSNR[epoch].append(psnr)
                valLoss[epoch].append(vLoss)

                # Tensorboard
                itr = trainIndex + epoch * (len(trainloader))

                writer.add_scalars(
                    'Loss', {
                        'trainLoss': iLoss / args.progress_iter,
                        'validationLoss': vLoss
                    }, itr)
                writer.add_scalar('PSNR', psnr, itr)

                writer.add_image('Validation', valImg, itr)
                #####

                endVal = time.time()

                print(
                    " Loss: %0.6f  Iterations: %4d/%4d  TrainExecTime: %0.1f  ValLoss:%0.6f  ValPSNR: %0.4f  ValEvalTime: %0.2f LearningRate: %.1e"
                    % (iLoss / args.progress_iter, trainIndex,
                       len(trainloader), end - start, vLoss, psnr,
                       endVal - end, get_lr(optimizer)))

                # torch.cuda.empty_cache()
                cLoss[epoch].append(iLoss / args.progress_iter)
                iLoss = 0
                start = time.time()

        # Create checkpoint after every `args.checkpoint_epoch` epochs
        if (epoch % args.checkpoint_epoch) == args.checkpoint_epoch - 1:
            dict1 = {
                'Detail': "End to end Super SloMo.",
                'epoch': epoch,
                'timestamp': datetime.datetime.now(),
                'trainBatchSz': args.train_batch_size,
                'validationBatchSz': args.validation_batch_size,
                'learningRate': get_lr(optimizer),
                'loss': cLoss,
                'valLoss': valLoss,
                'valPSNR': valPSNR,
                'state_dictFC': flowComp.state_dict(),
                'state_dictAT': ArbTimeFlowIntrp.state_dict(),
                'state_optimizer': optimizer.state_dict(),
                'state_scheduler': scheduler.state_dict()
            }
            torch.save(
                dict1, args.checkpoint_dir + "/SuperSloMo" +
                str(checkpoint_counter) + ".ckpt")
            checkpoint_counter += 1
Example #9
def main():
    extractPath = "./video_interpolation"
    prepare_folders(extractPath)

    video_to_images(args.video, os.path.join(extractPath, "input"))

    # Initialize transforms
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    normalize = torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[1, 1, 1])
    revNormalize = torchvision.transforms.Normalize(mean=[-0.5, -0.5, -0.5], std=[1, 1, 1])

    if device.type == "cpu":
        transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
        TP = torchvision.transforms.Compose([torchvision.transforms.ToPILImage()])
    else:
        transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), normalize])
        TP = torchvision.transforms.Compose([revNormalize, torchvision.transforms.ToPILImage()])

    # Load data
    videoFrames = dataloader.Video(root=os.path.join(extractPath, "input"), transform=transform)
    videoFramesloader = torch.utils.data.DataLoader(videoFrames, batch_size=2, shuffle=False)

    # Initialize model
    flowComp = model.UNet(6, 4)
    flowComp.to(device)
    for param in flowComp.parameters():
        param.requires_grad = False
    ArbTimeFlowIntrp = model.UNet(20, 5)
    ArbTimeFlowIntrp.to(device)
    for param in ArbTimeFlowIntrp.parameters():
        param.requires_grad = False

    flowBackWarp = model.backWarp(videoFrames.dim[0], videoFrames.dim[1], device)
    flowBackWarp = flowBackWarp.to(device)

    dict1 = torch.load("./checkpoints/Interpolation0.ckpt", map_location=device)
    ArbTimeFlowIntrp.load_state_dict(dict1['state_dictAT'])
    flowComp.load_state_dict(dict1['state_dictFC'])

    # Interpolate frames
    frameCounter = 0

    # batch_size = 2
    with torch.no_grad():
        for frameIndex, (frame0, frame1) in enumerate(tqdm(videoFramesloader), 0):

            I0 = frame0.to(device)
            I1 = frame1.to(device)

            flowOut = flowComp(torch.cat((I0, I1), dim=1))
            F_0_1 = flowOut[:, :2, :, :]
            F_1_0 = flowOut[:, 2:, :, :]

            # Save reference frames in output folder
            for batchIndex in range(2):
                (TP(frame0[batchIndex].detach())).resize(
                    videoFrames.origDim, Image.BILINEAR).save(
                        os.path.join(extractPath, "output",
                                     "frame{:05d}.png".format(frameCounter + 2 * batchIndex)))
            frameCounter += 1

            # Generate intermediate frame

            t = 0.5
            temp = -t * (1 - t)
            fCoeff = [temp, t * t, (1 - t) * (1 - t), temp]

            F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
            F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

            g_I0_F_t_0 = flowBackWarp(I0, F_t_0)
            g_I1_F_t_1 = flowBackWarp(I1, F_t_1)

            intrpOut = ArbTimeFlowIntrp(torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1, g_I0_F_t_0), dim=1))

            F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
            F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
            V_t_0   = torch.sigmoid(intrpOut[:, 4:5, :, :])
            V_t_1   = 1 - V_t_0

            g_I0_F_t_0_f = flowBackWarp(I0, F_t_0_f)
            g_I1_F_t_1_f = flowBackWarp(I1, F_t_1_f)

            wCoeff = [1 - t, t]

            Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 * g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)

            # Save intermediate frame
            for batchIndex in range(2):
                (TP(Ft_p[batchIndex].cpu().detach())).resize(
                    videoFrames.origDim, Image.BILINEAR).save(
                        os.path.join(extractPath, "output",
                                     "frame{:05d}.png".format(frameCounter + 2 * batchIndex)))
            frameCounter += 1

            frameCounter += 2
    # Generate video from interpolated frames
    create_video(os.path.join(extractPath, "output"), os.path.join(extractPath, "output"))
Example #10
    ArbTimeFlowIntrp.load_state_dict(dict1["state_dictAT"])
    flowComp.load_state_dict(dict1["state_dictFC"])
    print("Pretrained model loaded!")
else:
    # start logging info in comet-ml
    if not args.nocomet:
        comet_exp = Experiment(workspace=args.workspace, project_name=args.projectname)
        # comet_exp.log_parameters(flatten_opts(args))
    else:
        comet_exp = None
    dict1 = {"loss": [], "valLoss": [], "valPSNR": [], "valSSIM": [], "epoch": -1}

###Initialize backward warpers for train and validation datasets


trainFlowBackWarp = model.backWarp(128, 128, device)
trainFlowBackWarp = trainFlowBackWarp.to(device)
validationFlowBackWarp = model.backWarp(128, 128, device)
validationFlowBackWarp = validationFlowBackWarp.to(device)


###Load Datasets


# Channel wise mean calculated on adobe240-fps training dataset
mean = [0.5, 0.5, 0.5]
std = [1, 1, 1]
normalize = transforms.Normalize(mean=mean, std=std)
transform = transforms.Compose([transforms.ToTensor(), normalize])

trainset = dataloader.SuperSloMo(
Example #11
File: test.py  Project: brainma/ASRNet
def main():
    error = check()
    if error:
        print(error)
        exit(1)
    outputPath = args.outFolder
    if not os.path.isdir(outputPath):
        os.mkdir(outputPath)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    mean = [0.5]
    std  = [1]
    normalize = transforms.Normalize(mean=mean,
                                     std=std)

    negmean = [x * -1 for x in mean]
    revNormalize = transforms.Normalize(mean=negmean, std=std)

    if device.type == "cpu":
        transform = transforms.Compose([transforms.ToTensor()])
        TP = transforms.Compose([transforms.ToPILImage()])
    else:
        transform = transforms.Compose([transforms.ToTensor(), normalize])
        TP = transforms.Compose([revNormalize, transforms.ToPILImage()])

    # Data Loader
    frames = dataloader.AsrNetTest(root=args.inFolder, transform=transform)
    framesLoader = torch.utils.data.DataLoader(frames, batch_size=args.batch_size, shuffle=False)

    # Network arch
    flowComp = model.UNet(2, 4)
    flowComp.to(device)
    for param in flowComp.parameters():
        param.requires_grad = False
    ArbTimeFlowIntrp = model.UNet(12, 5)
    ArbTimeFlowIntrp.to(device)
    for param in ArbTimeFlowIntrp.parameters():
        param.requires_grad = False

    flowBackWarp = model.backWarp(512, 512, device)
    flowBackWarp = flowBackWarp.to(device)

    dict1 = torch.load(args.checkpoint, map_location='cpu')
    ArbTimeFlowIntrp.load_state_dict(dict1['state_dictAT'])
    flowComp.load_state_dict(dict1['state_dictFC'])
    dicomTemplate = frames.dicom_template
    dicomTemplate.NumberOfFrames = 1

    frameCounter = 1

    with torch.no_grad():
        for _, (frame0, frame1) in enumerate(framesLoader, 0):

            I0 = frame0.to(device)
            I1 = frame1.to(device)
            flowIn = torch.cat((I0, I1), dim=1)
            flowOut = flowComp(flowIn)
            F_0_1 = flowOut[:,:2,:,:]
            F_1_0 = flowOut[:,2:,:,:]

            for batchIndex in range(args.batch_size):
                (TP(frame0[batchIndex].detach())).save(
                    os.path.join(outputPath, 'img' + str(frameCounter) + ".png"),
                    compress_level=0)
                copy(os.path.join(args.inFolder, 'img' + str(frameCounter) + '.dcm'),
                     args.outDicomFolder)
            frameCounter += 1

            for intermediateIndex in range(1, args.infer_num + 1):
                t = intermediateIndex / (args.infer_num + 1)
                temp = -t * (1 - t)
                fCoeff = [temp, t * t, (1 - t) * (1 - t), temp]

                F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
                F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

                g_I0_F_t_0 = flowBackWarp(I0, F_t_0)
                g_I1_F_t_1 = flowBackWarp(I1, F_t_1)
                intrpIn = torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1, g_I0_F_t_0), dim=1)
                intrpOut = ArbTimeFlowIntrp(intrpIn)

                F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
                F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
                V_t_0   = torch.sigmoid(intrpOut[:, 4:5, :, :])
                V_t_1   = 1 - V_t_0
                g_I0_F_t_0_f = flowBackWarp(I0, F_t_0_f)
                g_I1_F_t_1_f = flowBackWarp(I1, F_t_1_f)

                wCoeff = [1 - t, t]

                Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 * g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)

                # Save output frame
                for batchIndex in range(args.batch_size):
                    img = TP(Ft_p[batchIndex].cpu().detach())
                    img.save(os.path.join(outputPath, 'img' + str(frameCounter) + "_p.png"),
                             compress_level=0)
                    img_array = np.asarray(img)
                    dicomTemplate.PixelData = img_array.tobytes()
                    dicomTemplate.save_as(os.path.join(args.outDicomFolder,
                                                       'img' + str(frameCounter) + '_p.dcm'))
                frameCounter += 1

            frameCounter += (args.infer_num + 1) * (args.batch_size - 1)

    exit(0)
Example #12
def main():
    # Check if arguments are okay
    error = check()
    if error:
        print(error)
        exit(1)

    # Create extraction folder and extract frames
    # Assuming UNIX-like system where "." indicates hidden directories
    extractionDir = ".tmpSuperSloMo"
    # if os.path.isdir(extractionDir):
    #     rmtree(extractionDir)
    # os.mkdir(extractionDir)

    extractionPath = os.path.join(extractionDir, "input")
    outputPath     = os.path.join(extractionDir, "output")
    # os.mkdir(extractionPath)
    # os.mkdir(outputPath)
    # error = extract_frames(args.video, extractionPath)
    # if error:
    #     print(error)
    #     exit(1)

    # Initialize transforms
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    mean = [0.429, 0.431, 0.397]
    std  = [1, 1, 1]
    normalize = transforms.Normalize(mean=mean,
                                     std=std)

    negmean = [x * -1 for x in mean]
    revNormalize = transforms.Normalize(mean=negmean, std=std)

    # Temporary fix for issue #7 https://github.com/avinashpaliwal/Super-SloMo/issues/7 -
    # - Removed per channel mean subtraction for CPU.
    if device.type == "cpu":
        transform = transforms.Compose([transforms.ToTensor()])
        TP = transforms.Compose([transforms.ToPILImage()])
    else:
        transform = transforms.Compose([transforms.ToTensor(), normalize])
        TP = transforms.Compose([revNormalize, transforms.ToPILImage()])

    # Load data
    videoFrames = dataloader.Video(root=extractionPath, transform=transform)
    # pdb.set_trace()
    # len(videoFrames[0]) ==> 2
    # (Pdb) videoFrames[0][0].size()
    # torch.Size([3, 512, 960])


    videoFramesloader = torch.utils.data.DataLoader(videoFrames, batch_size=1, shuffle=False)

    # Initialize model
    # UNet(inChannels, outChannels)
    # flow Computation !!!
    flowComp = model.UNet(6, 4)
    flowComp.to(device)
    for param in flowComp.parameters():
        param.requires_grad = False

    # arbitrary-time
    ArbTimeFlowIntrp = model.UNet(20, 5)
    ArbTimeFlowIntrp.to(device)
    for param in ArbTimeFlowIntrp.parameters():
        param.requires_grad = False

    flowBackWarp = model.backWarp(videoFrames.dim[0], videoFrames.dim[1], device)
    flowBackWarp = flowBackWarp.to(device)
    flowBackWarp = amp.initialize(flowBackWarp, opt_level= "O1")
    # pdb.set_trace()
    # (Pdb) videoFrames.dim[0], videoFrames.dim[1]
    # (960, 512)

    dict1 = torch.load(args.checkpoint, map_location='cpu')
    # dict_keys(['Detail', 'epoch', 'timestamp', 'trainBatchSz', 'validationBatchSz', 
    # 'learningRate', 'loss', 'valLoss', 'valPSNR', 'state_dictFC', 'state_dictAT'])

    ArbTimeFlowIntrp.load_state_dict(dict1['state_dictAT'])
    flowComp.load_state_dict(dict1['state_dictFC'])

    ArbTimeFlowIntrp = amp.initialize(ArbTimeFlowIntrp, opt_level= "O1")
    flowComp = amp.initialize(flowComp, opt_level= "O1")

    # Interpolate frames
    frameCounter = 1

    with torch.no_grad():
        for _, (frame0, frame1) in enumerate(tqdm(videoFramesloader), 0):

            I0 = frame0.to(device)
            I1 = frame1.to(device)
            # pdb.set_trace()
            # torch.Size([1, 3, 512, 960])

            flowOut = flowComp(torch.cat((I0, I1), dim=1))
            F_0_1 = flowOut[:, :2, :, :]
            F_1_0 = flowOut[:, 2:, :, :]
            # (Pdb) pp flowOut.size()
            # torch.Size([1, 4, 512, 960])
            # (Pdb) pp F_0_1.size()
            # torch.Size([1, 2, 512, 960])
            # (Pdb) pp F_1_0.size()
            # torch.Size([1, 2, 512, 960])

            # pdb.set_trace()
            # (Pdb) pp I0.size(), I1.size()
            # (torch.Size([1, 3, 512, 960]), torch.Size([1, 3, 512, 960]))
            # (Pdb) pp torch.cat((I0, I1), dim=1).size()
            # torch.Size([1, 6, 512, 960])

            # Save reference frames in output folder
            (TP(frame0[0].detach())).resize(videoFrames.origDim, Image.BILINEAR).save(\
                os.path.join(outputPath, "{:06d}.png".format(frameCounter)))
            frameCounter += 1

            # Generate intermediate frames
            # (Pdb) for i in range(1, args.sf): print(i)
            # 1
            # 2
            # 3
            for intermediateIndex in range(1, args.sf):
                t = float(intermediateIndex) / args.sf
                temp = -t * (1 - t)
                fCoeff = [temp, t * t, (1 - t) * (1 - t), temp]

                # pdb.set_trace()
                # (Pdb) pp temp
                # -0.1875
                # (Pdb) pp fCoeff
                # [-0.1875, 0.0625, 0.5625, -0.1875]

                F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
                F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

                g_I0_F_t_0 = flowBackWarp(I0, F_t_0)
                g_I1_F_t_1 = flowBackWarp(I1, F_t_1)

                intrpOut = ArbTimeFlowIntrp(\
                    torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, \
                        g_I1_F_t_1, g_I0_F_t_0), \
                    dim=1))

                F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
                F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
                # pdb.set_trace()

                # pdb.set_trace()
                # (Pdb) intrpOut.size()
                # torch.Size([1, 5, 512, 960])

                V_t_0   = torch.sigmoid(intrpOut[:, 4:5, :, :])
                V_t_1   = 1 - V_t_0

                g_I0_F_t_0_f = flowBackWarp(I0, F_t_0_f)
                g_I1_F_t_1_f = flowBackWarp(I1, F_t_1_f)

                # pdb.set_trace()

                wCoeff = [1 - t, t]

                Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 * g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)

                del g_I0_F_t_0_f, g_I1_F_t_1_f, F_t_0_f, F_t_1_f, F_t_0, F_t_1, intrpOut, V_t_0, V_t_1, wCoeff
                torch.cuda.empty_cache()

                # pdb.set_trace()

                # Save intermediate frame
                (TP(Ft_p[0].cpu().detach())).resize(videoFrames.origDim, Image.BILINEAR).save(os.path.join(outputPath, "{:06d}.png".format(frameCounter)))
                del Ft_p
                torch.cuda.empty_cache()

                frameCounter += 1

            del F_0_1, F_1_0, flowOut, I0, I1, frame0, frame1
            torch.cuda.empty_cache()

    # Generate video from interpolated frames
    # create_video(outputPath)

    # Remove temporary files
    # rmtree(extractionDir)

    exit(0)
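
Example #12 calls amp.initialize once per module. A hedged condensed alternative, assuming NVIDIA apex (whose amp.initialize also accepts a list of modules and, when no optimizer is passed, returns just the cast modules for inference-only use):

flowComp, ArbTimeFlowIntrp, flowBackWarp = amp.initialize(
    [flowComp, ArbTimeFlowIntrp, flowBackWarp], opt_level="O1")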
Example #13
def evaluate_frame_dir(extractionPath):
    outputPath = os.path.join(extractionDir, "output")
    inputframe_dir = os.path.join(extractionDir, "inputframe")

    if op.exists(outputPath): rmtree(outputPath)
    if op.exists(inputframe_dir): rmtree(inputframe_dir)

    os.makedirs(outputPath, exist_ok=True)
    os.makedirs(inputframe_dir, exist_ok=True)

    frames_gt = os.listdir(extractionPath)
    frames_gt.sort()

    print(frames_gt)
    for ind, i in enumerate(frames_gt):
        if ind % args.sf == 0:
            shutil.copyfile(os.path.join(extractionPath, i),
                            os.path.join(inputframe_dir, i))

    video_time = time.time()
    # Initialize transforms
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    mean = [0.429, 0.431, 0.397]
    std = [1, 1, 1]
    normalize = transforms.Normalize(mean=mean, std=std)

    negmean = [x * -1 for x in mean]
    revNormalize = transforms.Normalize(mean=negmean, std=std)

    # Temporary fix for issue #7 https://github.com/avinashpaliwal/Super-SloMo/issues/7 -
    # - Removed per channel mean subtraction for CPU.
    if device.type == "cpu":
        # Add a transform here, used later to filter each img
        transform = transforms.Compose([transforms.ToTensor()])
        TP = transforms.Compose([transforms.ToPILImage()])
    else:
        transform = transforms.Compose([transforms.ToTensor(), normalize])
        TP = transforms.Compose([revNormalize, transforms.ToPILImage()])

    # Load data
    videoFrames = dataloader.Video(root=inputframe_dir, transform=transform)
    videoFramesloader = torch.utils.data.DataLoader(videoFrames,
                                                    batch_size=args.batch_size,
                                                    shuffle=False)

    # Initialize model
    # First U-Net: computes optical flow
    flowComp = model.UNet(6, 4)
    flowComp.to(device)

    # Inference only here, so disable training backprop
    for param in flowComp.parameters():
        param.requires_grad = False
    # Second U-Net: synthesizes the interpolated frame
    ArbTimeFlowIntrp = model.UNet(20, 5)
    ArbTimeFlowIntrp.to(device)
    for param in ArbTimeFlowIntrp.parameters():
        param.requires_grad = False

    flowBackWarp = model.backWarp(videoFrames.dim[0], videoFrames.dim[1],
                                  device)
    flowBackWarp = flowBackWarp.to(device)
    # Load the model checkpoint
    dict1 = torch.load(args.checkpoint, map_location='cpu')
    ArbTimeFlowIntrp.load_state_dict(dict1['state_dictAT'])
    flowComp.load_state_dict(dict1['state_dictFC'])

    # Interpolate frames
    frameCounter = 0
    '''
    tqdm is a fast, extensible Python progress bar: wrap any iterator as
    tqdm(iterator) to add a progress display to long loops.
    '''
    with torch.no_grad():
        for _, (frame0, frame1) in enumerate(tqdm(videoFramesloader), 0):

            I0 = frame0.to(device)
            I1 = frame1.to(device)
            #print (I0.shape)

            # Implementation detail: concatenate along dim 1
            flowOut = flowComp(torch.cat((I0, I1), dim=1))
            # flowOut should hold the 0->1 flow in channels 0-1 and the 1->0 flow in channels 2-3
            F_0_1 = flowOut[:, :2, :, :]
            F_1_0 = flowOut[:, 2:, :, :]

            # Save reference frames (the original video frames) in output folder
            for batchIndex in range(args.batch_size):
                pass
                #(TP(frame0[batchIndex].detach())).resize(videoFrames.origDim, Image.BILINEAR).save(os.path.join(outputPath, str(frameCounter + args.sf * batchIndex) + ".jpg"))
            frameCounter += 1
            sttime = time.time()
            # Generate intermediate frames
            for intermediateIndex in range(1, args.sf):
                t = intermediateIndex / args.sf
                temp = -t * (1 - t)
                fCoeff = [temp, t * t, (1 - t) * (1 - t), temp]

                F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
                F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

                # Preliminary result of the first U-Net; compare its output here against my first-stage result
                g_I0_F_t_0 = flowBackWarp(I0, F_t_0)
                g_I1_F_t_1 = flowBackWarp(I1, F_t_1)

                # Concatenate the tensors above and feed them into the next prediction network
                intrpOut = ArbTimeFlowIntrp(
                    torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1,
                               g_I0_F_t_0),
                              dim=1))

                F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
                F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
                V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])
                V_t_1 = 1 - V_t_0

                g_I0_F_t_0_f = flowBackWarp(I0, F_t_0_f)
                g_I1_F_t_1_f = flowBackWarp(I1, F_t_1_f)

                wCoeff = [1 - t, t]

                # Note how the visibility masks enter the blending of the two frames, together with the time weights
                Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 *
                        g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)
                #print (Ft_p.shape)
                # How does the intermediate (warp-only) result look here?
                #Ft_p = (wCoeff[0] * g_I0_F_t_0 + wCoeff[1] * g_I1_F_t_1)
                # It jitters a bit temporally but is remarkably sharp; probably a matter of the loss function

                # Save the interpolated intermediate frame
                for batchIndex in range(args.batch_size):
                    #ttp="%06d.jpg"%(frameCounter + args.sf * batchIndex)
                    ttp = frames_gt[frameCounter + args.sf * batchIndex]

                    ttp = os.path.join(outputPath, ttp)
                    #print (videoFrames.origDim) #(480, 270)
                    (TP(Ft_p[batchIndex].cpu().detach())).save(ttp)
                frameCounter += 1
            print("run %d iters, time:%f ,average:%f s/iter" %
                  (args.sf - 1, time.time() - sttime,
                   (time.time() - sttime) / (args.sf - 1)))
            # Set counter accounting for batching of frames
            frameCounter += args.sf * (args.batch_size - 1)

    ssim_kep = []
    psnr_kep = []
    for i in os.listdir(outputPath):
        gt_img = cv2.imread(os.path.join(extractionPath, i))
        genimg = cv2.imread(os.path.join(outputPath, i))

        # scale > 0: target resolution = ground-truth resolution * scale; otherwise use the generated image's resolution
        scale = 1
        if scale > 0:
            target_shape = (int(gt_img.shape[1] * scale),
                            int(gt_img.shape[0] * scale))
        else:
            target_shape = (genimg.shape[1], genimg.shape[0])
        #print (genimg.shape)
        gt_img = cv2.resize(gt_img, target_shape)
        genimg = cv2.resize(genimg, target_shape)

        psnr = skimage.measure.compare_psnr(gt_img, genimg, 255)
        ssim = skimage.measure.compare_ssim(gt_img, genimg, multichannel=True)

        psnr_kep.append(psnr)
        ssim_kep.append(ssim)
    print("mean psnr:", np.mean(psnr_kep))
    print("mean ssim:", np.mean(ssim_kep))
    print("this video time used:", time.time() - video_time)
    # Generate video from interpolated frames
    #create_video(outputPath)

    # Remove temporary files
    rmtree(outputPath)
    rmtree(inputframe_dir)
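
skimage.measure.compare_psnr and compare_ssim used above were deprecated and removed in scikit-image 0.18; on current versions the equivalents live in skimage.metrics:

from skimage.metrics import peak_signal_noise_ratio, structural_similarity

psnr = peak_signal_noise_ratio(gt_img, genimg, data_range=255)
# channel_axis=-1 on scikit-image >= 0.19; use multichannel=True on 0.18
ssim = structural_similarity(gt_img, genimg, channel_axis=-1)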
Example #14
def main():
    # Check if arguments are okay
    error = check()
    if error:
        print(error)
        exit(1)

    # Create extraction folder and extract frames
    IS_WINDOWS = 'Windows' == platform.system()
    extractionDir = "tmpSuperSloMo"

    # We need a folder for the extracted frames; making it hidden is not really worth the effort
    if not IS_WINDOWS:
        # Assuming UNIX-like system where "." indicates hidden directories
        extractionDir = "." + extractionDir

    if os.path.isdir(extractionDir):
        rmtree(extractionDir)
    os.mkdir(extractionDir)
    if IS_WINDOWS:
        FILE_ATTRIBUTE_HIDDEN = 0x02
        # ctypes.windll only exists on Windows
        ctypes.windll.kernel32.SetFileAttributesW(extractionDir,
                                                  FILE_ATTRIBUTE_HIDDEN)

    extractionPath = os.path.join(extractionDir, "input")
    outputPath = os.path.join(extractionDir, "output")
    os.mkdir(extractionPath)
    os.mkdir(outputPath)
    error = extract_frames(args.video, extractionPath)
    if error:
        print(error)
        exit(1)

    # Initialize transforms
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    mean = [0.429, 0.431, 0.397]
    std = [1, 1, 1]
    normalize = transforms.Normalize(mean=mean, std=std)

    negmean = [x * -1 for x in mean]
    revNormalize = transforms.Normalize(mean=negmean, std=std)

    # Temporary fix for issue #7 https://github.com/avinashpaliwal/Super-SloMo/issues/7 -
    # - Removed per channel mean subtraction for CPU.
    if device.type == "cpu":
        # Add a transform here, used later to filter each img
        transform = transforms.Compose([transforms.ToTensor()])
        TP = transforms.Compose([transforms.ToPILImage()])
    else:
        transform = transforms.Compose([transforms.ToTensor(), normalize])
        TP = transforms.Compose([revNormalize, transforms.ToPILImage()])

    # Load data
    videoFrames = dataloader.Video(root=extractionPath, transform=transform)
    videoFramesloader = torch.utils.data.DataLoader(videoFrames,
                                                    batch_size=args.batch_size,
                                                    shuffle=False)

    # Initialize model
    # First U-Net: estimates bidirectional optical flow
    flowComp = model.UNet(6, 4)
    flowComp.to(device)

    # Inference only, so freeze the parameters (no backprop needed)
    for param in flowComp.parameters():
        param.requires_grad = False
    # Second U-Net: arbitrary-time flow interpolation and frame synthesis
    ArbTimeFlowIntrp = model.UNet(20, 5)
    ArbTimeFlowIntrp.to(device)
    for param in ArbTimeFlowIntrp.parameters():
        param.requires_grad = False

    flowBackWarp = model.backWarp(videoFrames.dim[0], videoFrames.dim[1],
                                  device)
    flowBackWarp = flowBackWarp.to(device)
    # Load the model checkpoint
    dict1 = torch.load(args.checkpoint, map_location='cpu')
    ArbTimeFlowIntrp.load_state_dict(dict1['state_dictAT'])
    flowComp.load_state_dict(dict1['state_dictFC'])

    # Interpolate frames
    frameCounter = 1
    # tqdm is a fast, extensible Python progress bar: wrapping any iterator
    # as tqdm(iterator) adds a progress indicator to long-running loops.
    with torch.no_grad():
        for _, (frame0, frame1) in enumerate(tqdm(videoFramesloader), 0):

            I0 = frame0.to(device)
            I1 = frame1.to(device)

            # Implementation detail: concatenate the two frames along dim 1 (channels)
            flowOut = flowComp(torch.cat((I0, I1), dim=1))
            # Channels 0-1 of flowOut should hold the 0->1 flow, channels 2-3 the 1->0 flow
            F_0_1 = flowOut[:, :2, :, :]
            F_1_0 = flowOut[:, 2:, :, :]

            # Save the reference (original) frames in the output folder
            for batchIndex in range(args.batch_size):
                (TP(frame0[batchIndex].detach())).resize(
                    videoFrames.origDim, Image.BILINEAR).save(
                        os.path.join(
                            outputPath,
                            str(frameCounter + args.sf * batchIndex) + ".jpg"))
            frameCounter += 1

            # Generate intermediate frames
            for intermediateIndex in range(1, args.sf):
                t = intermediateIndex / args.sf
                temp = -t * (1 - t)
                fCoeff = [temp, t * t, (1 - t) * (1 - t), temp]
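                # Intermediate-flow approximation from the Super SloMo paper
                # (Jiang et al., CVPR 2018):
                #   F_t->0 = -(1 - t) * t * F_0->1 + t^2 * F_1->0
                #   F_t->1 = (1 - t)^2 * F_0->1 - t * (1 - t) * F_1->0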

                F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
                F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

                # Preliminary warps from the first U-Net's flow; compare their
                # quality against the first-stage result
                g_I0_F_t_0 = flowBackWarp(I0, F_t_0)
                g_I1_F_t_1 = flowBackWarp(I1, F_t_1)

                # Concatenate everything above and feed it into the second
                # prediction network
                intrpOut = ArbTimeFlowIntrp(
                    torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1,
                               g_I0_F_t_0),
                              dim=1))

                F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
                F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
                V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])
                V_t_1 = 1 - V_t_0
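                # V_t_0 is a soft visibility map; tying V_t_1 = 1 - V_t_0 makes
                # pixels occluded in one frame draw from the other frame.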

                g_I0_F_t_0_f = flowBackWarp(I0, F_t_0_f)
                g_I1_F_t_1_f = flowBackWarp(I1, F_t_1_f)

                wCoeff = [1 - t, t]
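                # Final synthesis from the paper: a visibility-weighted,
                # temporally weighted blend of the two warped frames,
                #   I_t = ((1-t) * V_t0 * g(I0, F_t0_f) + t * V_t1 * g(I1, F_t1_f))
                #         / ((1-t) * V_t0 + t * V_t1)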

                # Note how the visibility masks enter the two-frame blend,
                # together with the temporal weights
                Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 *
                        g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)

                # To inspect the first-stage intermediate result instead:
                #Ft_p = (wCoeff[0] * g_I0_F_t_0 + wCoeff[1] * g_I1_F_t_1)
                # Temporally it feels slightly jittery, but the sharpness is
                # remarkably good; this is probably down to the loss function.

                # Save the interpolated intermediate frame
                for batchIndex in range(args.batch_size):
                    (TP(Ft_p[batchIndex].cpu().detach())).resize(
                        videoFrames.origDim, Image.BILINEAR).save(
                            os.path.join(
                                outputPath,
                                str(frameCounter + args.sf * batchIndex) +
                                ".jpg"))
                frameCounter += 1

            # Set counter accounting for batching of frames
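            # Each batch element b writes its frames at frameCounter + sf * b,
            # so after finishing one batch the counter must jump past the
            # slots reserved for the remaining (batch_size - 1) elements.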
            frameCounter += args.sf * (args.batch_size - 1)

    # Generate video from interpolated frames
    create_video(outputPath)

    # Remove temporary files
    rmtree(extractionDir)

    exit(0)
示例#15
0
def main():
    os.makedirs(args.output, exist_ok=True)
    outputPath = args.output

    if args.sf < 2:
        print("Slowmo factor must be at least 2.")
        return

    # Initialize transforms
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    mean = [0.429, 0.431, 0.397]
    std = [1, 1, 1]
    normalize = transforms.Normalize(mean=mean, std=std)

    negmean = [x * -1 for x in mean]
    revNormalize = transforms.Normalize(mean=negmean, std=std)

    # Temporary fix for issue #7 https://github.com/avinashpaliwal/Super-SloMo/issues/7 -
    # - Removed per channel mean subtraction for CPU.
    if (device == "cpu"):
        transform = transforms.Compose([transforms.ToTensor()])
        TP = transforms.Compose([transforms.ToPILImage()])
    else:
        transform = transforms.Compose([transforms.ToTensor(), normalize])
        TP = transforms.Compose([revNormalize, transforms.ToPILImage()])

    # Load data
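    # NOTE: extractionPath (the directory of pre-extracted input frames) is
    # assumed to be defined elsewhere in this script; this snippet never sets it.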
    videoFrames = dataloader.Video(root=extractionPath, transform=transform)
    videoFramesloader = torch.utils.data.DataLoader(videoFrames,
                                                    batch_size=args.batch_size,
                                                    shuffle=False)

    # Initialize model
    flowComp = model.UNet(6, 4)
    flowComp.to(device)
    for param in flowComp.parameters():
        param.requires_grad = False
    ArbTimeFlowIntrp = model.UNet(20, 5)
    ArbTimeFlowIntrp.to(device)
    for param in ArbTimeFlowIntrp.parameters():
        param.requires_grad = False

    flowBackWarp = model.backWarp(videoFrames.dim[0], videoFrames.dim[1],
                                  device)
    flowBackWarp = flowBackWarp.to(device)

    dict1 = torch.load(args.checkpoint, map_location='cpu')
    ArbTimeFlowIntrp.load_state_dict(dict1['state_dictAT'])
    flowComp.load_state_dict(dict1['state_dictFC'])

    # Interpolate frames
    frameCounter = 1

    with torch.no_grad():
        for _, (frame0, frame1) in enumerate(tqdm(videoFramesloader), 0):

            I0 = frame0.to(device)
            I1 = frame1.to(device)

            flowOut = flowComp(torch.cat((I0, I1), dim=1))
            F_0_1 = flowOut[:, :2, :, :]
            F_1_0 = flowOut[:, 2:, :, :]

            # Save reference frames in output folder
            for batchIndex in range(args.batch_size):
                (TP(frame0[batchIndex].detach())).resize(
                    videoFrames.origDim, Image.BILINEAR).save(
                        os.path.join(
                            outputPath,
                            str(frameCounter + args.sf * batchIndex).zfill(8) +
                            ".png"))
            frameCounter += 1

            # Generate intermediate frames
            for intermediateIndex in range(1, args.sf):
                t = float(intermediateIndex) / args.sf
                temp = -t * (1 - t)
                fCoeff = [temp, t * t, (1 - t) * (1 - t), temp]

                F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
                F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

                g_I0_F_t_0 = flowBackWarp(I0, F_t_0)
                g_I1_F_t_1 = flowBackWarp(I1, F_t_1)

                intrpOut = ArbTimeFlowIntrp(
                    torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1,
                               g_I0_F_t_0),
                              dim=1))

                F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
                F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
                V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])
                V_t_1 = 1 - V_t_0

                g_I0_F_t_0_f = flowBackWarp(I0, F_t_0_f)
                g_I1_F_t_1_f = flowBackWarp(I1, F_t_1_f)

                wCoeff = [1 - t, t]

                Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 *
                        g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)

                # Save intermediate frame
                for batchIndex in range(args.batch_size):
                    (TP(Ft_p[batchIndex].cpu().detach())).resize(
                        videoFrames.origDim, Image.BILINEAR).save(
                            os.path.join(
                                outputPath,
                                str(frameCounter +
                                    args.sf * batchIndex).zfill(8) + ".png"))
                frameCounter += 1

            # Set counter accounting for batching of frames
            frameCounter += args.sf * (args.batch_size - 1)

    exit(0)
示例#16
0
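    # Rebuild the backward-warping module whenever the target frame size
    # changes: model.backWarp precomputes a sampling grid for a fixed
    # width/height, so it cannot be reused across resolutions.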
    def UpdateFlowBackWarp(self, fw, fh):
        self.flowBackWarp = model.backWarp(fw, fh, self.device)
        self.flowBackWarp = self.flowBackWarp.to(self.device)