Example 1
# uses torch.nn (nn) plus the project helpers ceilBy and backWarp
def initFunc(opt, x):
    *_, h, w = x.shape
    # pad the frame on the right/bottom so both sides are multiples of 64
    width = ceilBy(64)(w)
    height = ceilBy(64)(h)
    opt.pad = nn.ReflectionPad2d((0, width - w, 0, height - h))
    # h and w are quadrupled before being captured by the unpad lambda,
    # so unpad crops the 4x-upscaled output back to 4x the original size
    w = w << 2
    h = h << 2
    opt.flow_warp = backWarp(width, height, device=x.device, dtype=x.dtype)
    opt.unpad = lambda im: im[:, :h, :w]
    return height, width
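The ceilBy helper is not part of this excerpt; a minimal sketch consistent with its use above, assuming it simply rounds a length up to the next multiple of n:

# hypothetical helper, not from the original source
ceilBy = lambda n: lambda x: (x + n - 1) // n * n

assert ceilBy(64)(100) == 128  # pad 100 px up to 128
assert ceilBy(64)(128) == 128  # already aligned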
Example 2
    def forward(self, inp):
        inp = self.preprocess(inp)
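        # build 6-level pyramids for the two frames (index 0 = coarsest, 5 = full resolution)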
        ref = [0] * 5 + [inp[:, 0]]
        supp = [0] * 5 + [inp[:, 1]]

        for i in range(len(ref) - 1, 0, -1):
            ref[i - 1] = F.avg_pool2d(input=ref[i],
                                      kernel_size=2,
                                      stride=2,
                                      count_include_pad=False)
            supp[i - 1] = F.avg_pool2d(input=supp[i],
                                       kernel_size=2,
                                       stride=2,
                                       count_include_pad=False)

        N, _, H, W = ref[0].shape
        flow = ref[0].new_zeros([N, 2, H >> 1, W >> 1])
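        # (re)build the cached per-level backWarp modules when the input size changes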
        if not self.flow_warp or self.size != [H, W]:
            self.size = [H, W]
            self.flow_warp = []
            for r in ref:
                _, _, H, W = r.shape
                self.flow_warp.append(
                    backWarp(W,
                             H,
                             device=flow.device,
                             dtype=flow.dtype,
                             padding_mode='border'))
            assert not (H & 63 or W & 63)

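        # coarse-to-fine: at each level, upsample and rescale the current flow,
        # warp the supporting frame with it, and add the predicted residual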
        for level in range(len(ref)):
            upsampled_flow = F.interpolate(input=flow,
                                           scale_factor=2,
                                           mode='bilinear',
                                           align_corners=True) * 2.0

            f = self.flow_warp[level]
            flow = self.basic_module[level](
                torch.cat([ref[level],
                           f(supp[level], upsampled_flow),
                           upsampled_flow], 1)) + upsampled_flow

        return flow
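Every example here relies on a backWarp module that warps an image backward along an optical-flow field; its definition is not included in any of these excerpts. Below is a minimal grid_sample-based sketch of what such a module typically looks like. The constructor signature mirrors the calls above; everything else (buffer layout, normalization, defaults) is an assumption, not the original implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class backWarp(nn.Module):
    def __init__(self, W, H, device=None, dtype=torch.float32,
                 padding_mode='zeros'):
        super().__init__()
        self.W, self.H = W, H
        self.padding_mode = padding_mode
        # base sampling grid in pixel coordinates, shape [1, 2, H, W]
        gy, gx = torch.meshgrid(
            torch.arange(H, device=device, dtype=dtype),
            torch.arange(W, device=device, dtype=dtype),
            indexing='ij')
        self.register_buffer('grid', torch.stack((gx, gy))[None])

    def forward(self, img, flow):
        # displace the grid by the flow, then normalize to [-1, 1]
        coords = self.grid + flow
        x = 2.0 * coords[:, 0] / (self.W - 1) - 1.0
        y = 2.0 * coords[:, 1] / (self.H - 1) - 1.0
        return F.grid_sample(img, torch.stack((x, y), dim=3),
                             padding_mode=self.padding_mode,
                             align_corners=True)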
Example 3
import logging

#from torchvision.transforms import Normalize
from slomo import UNet, backWarp
from imageProcess import initModel, getStateDict
from config import config

log = logging.getLogger('Moe')
modelPath = './model/slomo/SuperSloMo.ckpt'
ramCoef = [.9 / x for x in (6418.7, 1393., 1156.3)]
#mean = [0.429, 0.431, 0.397]
#std  = [1, 1, 1]
#negMean = [-x for x in mean]
#identity = lambda x, *_: x
# round x up to the next multiple of 32 (bit-twiddling form of
# (x + 31) // 32 * 32, valid for 0 < x <= 2**32)
upTruncBy32 = lambda x: (-x & 0xffffffe0 ^ 0xffffffff) + 1
getFlowComp = lambda *_: UNet(6, 4)
getFlowIntrp = lambda *_: UNet(20, 5)
getFlowBack = lambda opt: backWarp(opt.width, opt.height, config.device(), config.dtype())

def getOpt(option):
  def opt(): pass  # bare function object used as a simple attribute namespace
  # Initialize model
  opt.model = modelPath
  dict1 = getStateDict(modelPath)
  opt.flowComp = initModel(opt, dict1['state_dictFC'], 'flowComp', getFlowComp)
  opt.ArbTimeFlowIntrp = initModel(opt, dict1['state_dictAT'], 'ArbTimeFlowIntrp', getFlowIntrp)
  opt.sf = option['sf']
  opt.firstTime = 1
  opt.notLast = 1
  opt.batchSize = 0
  if opt.sf < 2:
    raise RuntimeError('Error: --sf/slomo factor has to be at least 2')
  return opt
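A hypothetical call; 'sf' (the slow-motion factor) is the only option getOpt reads in this excerpt:

opt = getOpt({'sf': 4})  # 4x frame-rate interpolation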
Example 4
def getFlowBack(opt, width, height):
  # lazily build and cache the backward-warping module on opt;
  # opt.flowBackWarp is expected to start out falsy (e.g. None)
  if opt.flowBackWarp:
    return opt.flowBackWarp
  opt.flowBackWarp = initModel(backWarp(width, height, config.device(), config.dtype()))
  return opt.flowBackWarp
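Hypothetical usage of the cached variant above, assuming initModel returns the module it is given:

opt.flowBackWarp = None               # must start out falsy
warp = getFlowBack(opt, 1280, 768)
assert getFlowBack(opt, 1280, 768) is warp  # second call hits the cache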
Example 5
import os
import sys
import ctypes
import platform
from shutil import rmtree

import torch
import torchvision.transforms as transforms
from PIL import Image
from tqdm import tqdm

import model  # project-local module providing UNet and backWarp
import slomo_vid_loader  # project-local dataset wrapper
# check, extract_frames, create_video and combine are project-local
# helpers not included in this excerpt


def run_slomo(vid_path,
              fps,
              sf,
              out_base='./download/',
              batch_size=1,
              model_path='./model/slomo/SuperSloMo.ckpt'):
    # Check if arguments are okay
    output = out_base + os.path.split(vid_path)[-1]
    output_tmp = out_base + 'tmp_' + os.path.split(vid_path)[-1]
    error = check(sf, batch_size, fps)
    if error:
        print(error)
        sys.exit(1)

    # Create extraction folder and extract frames
    IS_WINDOWS = 'Windows' == platform.system()
    extractionDir = "tmpSuperSloMo"
    if not IS_WINDOWS:
        # Assuming UNIX-like system where "." indicates hidden directories
        extractionDir = "." + extractionDir
    if os.path.isdir(extractionDir):
        rmtree(extractionDir)
    os.mkdir(extractionDir)
    if IS_WINDOWS:
        FILE_ATTRIBUTE_HIDDEN = 0x02
        # ctypes.windll only exists on Windows
        ctypes.windll.kernel32.SetFileAttributesW(extractionDir,
                                                  FILE_ATTRIBUTE_HIDDEN)

    extractionPath = os.path.join(extractionDir, "input")
    outputPath = os.path.join(extractionDir, "output")
    os.mkdir(extractionPath)
    os.mkdir(outputPath)
    error = extract_frames(vid_path, extractionPath)
    if error:
        print(error)
        sys.exit(1)

    # Initialize transforms
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    mean = [0.429, 0.431, 0.397]
    std = [1, 1, 1]
    normalize = transforms.Normalize(mean=mean, std=std)

    negmean = [x * -1 for x in mean]
    revNormalize = transforms.Normalize(mean=negmean, std=std)

    # Temporary fix for issue #7 https://github.com/avinashpaliwal/Super-SloMo/issues/7 -
    # - Removed per channel mean subtraction for CPU.
    if device.type == "cpu":
        transform = transforms.Compose([transforms.ToTensor()])
        TP = transforms.Compose([transforms.ToPILImage()])
    else:
        transform = transforms.Compose([transforms.ToTensor(), normalize])
        TP = transforms.Compose([revNormalize, transforms.ToPILImage()])

    # Load data
    videoFrames = slomo_vid_loader.Video(root=extractionPath,
                                         transform=transform)
    videoFramesloader = torch.utils.data.DataLoader(videoFrames,
                                                    batch_size=batch_size,
                                                    shuffle=False)

    # Initialize model
    flowComp = model.UNet(6, 4)
    flowComp.to(device)
    for param in flowComp.parameters():
        param.requires_grad = False
    ArbTimeFlowIntrp = model.UNet(20, 5)
    ArbTimeFlowIntrp.to(device)
    for param in ArbTimeFlowIntrp.parameters():
        param.requires_grad = False

    flowBackWarp = model.backWarp(videoFrames.dim[0], videoFrames.dim[1],
                                  device)
    flowBackWarp = flowBackWarp.to(device)

    dict1 = torch.load(model_path, map_location='cpu')
    ArbTimeFlowIntrp.load_state_dict(dict1['state_dictAT'])
    flowComp.load_state_dict(dict1['state_dictFC'])

    # Interpolate frames
    frameCounter = 1

    with torch.no_grad():
        for frame0, frame1 in tqdm(videoFramesloader):

            I0 = frame0.to(device)
            I1 = frame1.to(device)

            flowOut = flowComp(torch.cat((I0, I1), dim=1))
            F_0_1 = flowOut[:, :2, :, :]
            F_1_0 = flowOut[:, 2:, :, :]

            # Save reference frames in output folder
            for batchIndex in range(frame0.shape[0]):
                (TP(frame0[batchIndex].detach())).resize(
                    videoFrames.origDim, Image.BILINEAR).save(
                        os.path.join(
                            outputPath,
                            str(frameCounter + sf * batchIndex) + ".jpg"))
            frameCounter += 1

            # Generate intermediate frames
            for intermediateIndex in range(1, sf):
                t = intermediateIndex / sf
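                # coefficients of the Super-SloMo intermediate-flow approximation:
                #   F_t_0 = -t*(1-t)*F_0_1 + t*t*F_1_0
                #   F_t_1 = (1-t)*(1-t)*F_0_1 - t*(1-t)*F_1_0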
                temp = -t * (1 - t)
                fCoeff = [temp, t * t, (1 - t) * (1 - t), temp]

                F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
                F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

                g_I0_F_t_0 = flowBackWarp(I0, F_t_0)
                g_I1_F_t_1 = flowBackWarp(I1, F_t_1)

                intrpOut = ArbTimeFlowIntrp(
                    torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1,
                               g_I0_F_t_0),
                              dim=1))

                F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
                F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
                V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])
                V_t_1 = 1 - V_t_0

                g_I0_F_t_0_f = flowBackWarp(I0, F_t_0_f)
                g_I1_F_t_1_f = flowBackWarp(I1, F_t_1_f)

                wCoeff = [1 - t, t]

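                # blend the two warped frames, weighting each by its time
                # coefficient and predicted visibility map, then renormalize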
                Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 *
                        g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)

                # Save intermediate frame
                for batchIndex in range(Ft_p.shape[0]):
                    (TP(Ft_p[batchIndex].cpu().detach())).resize(
                        videoFrames.origDim, Image.BILINEAR).save(
                            os.path.join(
                                outputPath,
                                str(frameCounter + sf * batchIndex) + ".jpg"))
                frameCounter += 1

            # Set counter accounting for batching of frames
            frameCounter += sf * (batch_size - 1)

    # Generate video from interpolated frames
    create_video(outputPath, fps, output_tmp)
    combine(output_tmp, vid_path, output)
    # Remove temporary files
    rmtree(extractionDir)
    return output
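A hypothetical invocation (the file name and rates below are placeholders):

out = run_slomo('./download/input.mp4', fps=60, sf=4, batch_size=2)
print('interpolated video written to', out)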