def initFunc(opt, x):
  *_, h, w = x.shape
  # Pad the input on the right/bottom up to a multiple of 64.
  width = ceilBy(64)(w)
  height = ceilBy(64)(h)
  opt.pad = nn.ReflectionPad2d((0, width - w, 0, height - h))
  # The model upscales 4x, so unpad crops the output at 4x the original size.
  w = w << 2
  h = h << 2
  opt.flow_warp = backWarp(width, height, device=x.device, dtype=x.dtype)
  opt.unpad = lambda im: im[:, :h, :w]
  return height, width
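# A minimal usage sketch of initFunc, assuming ceilBy/backWarp from this repo
# are in scope; the Opt class and the 4x model output are hypothetical stand-ins:
import torch

class Opt: pass                                   # attribute bag like the repo's opt

opt = Opt()
x = torch.zeros(3, 270, 480)                      # C, H, W input frame
height, width = initFunc(opt, x)                  # 320, 512: rounded up to multiples of 64
padded = opt.pad(x)                               # (3, 320, 512), reflection-padded right/bottom
upscaled = torch.zeros(3, height * 4, width * 4)  # stand-in for the 4x model's output
restored = opt.unpad(upscaled)                    # (3, 1080, 1920), 4x the original frame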
def forward(self, inp):
  inp = self.preprocess(inp)
  # Build a 6-level image pyramid for each of the two frames;
  # index 5 is full resolution, index 0 is 1/32 resolution.
  ref = [0] * 5 + [inp[:, 0]]
  supp = [0] * 5 + [inp[:, 1]]
  for i in range(len(ref) - 1, 0, -1):
    ref[i - 1] = F.avg_pool2d(input=ref[i], kernel_size=2, stride=2, count_include_pad=False)
    supp[i - 1] = F.avg_pool2d(input=supp[i], kernel_size=2, stride=2, count_include_pad=False)
  N, _, H, W = ref[0].shape
  flow = ref[0].new_zeros([N, 2, H >> 1, W >> 1])
  # Cache one backward-warp module per pyramid level; rebuild only when
  # the input size changes.
  if not self.flow_warp or self.size != [H, W]:
    self.size = [H, W]
    self.flow_warp = []
    for r in ref:
      _, _, H, W = r.shape
      self.flow_warp.append(
        backWarp(W, H, device=flow.device, dtype=flow.dtype, padding_mode='border'))
    # The full resolution must be a multiple of 64 so every level (down to the
    # half-size initial flow) has integral dimensions.
    assert not (H & 63 or W & 63)
  # Coarse-to-fine refinement: upsample the flow, warp the support frame with
  # it, and let each level's module predict a residual correction.
  for level in range(len(ref)):
    upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
    f = self.flow_warp[level]
    flow = self.basic_module[level](torch.cat(
      [ref[level], f(supp[level], upsampled_flow), upsampled_flow], 1)) + upsampled_flow
  return flow
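# `backWarp` is used throughout but not defined in this section. The sketch
# below is a plausible stand-in modeled on the original Super-SloMo backWarp:
# shift a fixed sampling grid by the flow, then resample with F.grid_sample.
# The class name and the padding_mode pass-through are assumptions, not this
# repo's code.
import torch
import torch.nn as nn
import torch.nn.functional as F

class BackWarpSketch(nn.Module):
  """Backward-warp an image along an optical flow field."""
  def __init__(self, w, h, device=None, dtype=None, padding_mode='zeros'):
    super().__init__()
    # One base (x, y) sampling coordinate per output pixel.
    gridY, gridX = torch.meshgrid(
      torch.arange(h, device=device, dtype=dtype),
      torch.arange(w, device=device, dtype=dtype), indexing='ij')
    self.register_buffer('gridX', gridX)
    self.register_buffer('gridY', gridY)
    self.w, self.h = w, h
    self.padding_mode = padding_mode

  def forward(self, img, flow):
    # Displace the base grid by the flow, then normalize to [-1, 1],
    # the coordinate range grid_sample expects.
    x = 2 * (self.gridX + flow[:, 0]) / (self.w - 1) - 1  # N, H, W
    y = 2 * (self.gridY + flow[:, 1]) / (self.h - 1) - 1
    grid = torch.stack((x, y), dim=3)                     # N, H, W, 2
    return F.grid_sample(img, grid, padding_mode=self.padding_mode,
                         align_corners=True)

# padding_mode='border' (as the pyramid code above requests) repeats edge
# pixels instead of sampling zeros when the flow points outside the frame.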
import logging
#from torchvision.transforms import Normalize
from slomo import UNet, backWarp
from imageProcess import initModel, getStateDict
from config import config

log = logging.getLogger('Moe')
modelPath = './model/slomo/SuperSloMo.ckpt'
ramCoef = [.9 / x for x in (6418.7, 1393., 1156.3)]
#mean = [0.429, 0.431, 0.397]
#std = [1, 1, 1]
#negMean = [-x for x in mean]
#identity = lambda x, *_: x
# Round x up to a multiple of 32; bit-twiddling form of ceil(x / 32) * 32.
upTruncBy32 = lambda x: (-x & 0xffffffe0 ^ 0xffffffff) + 1
getFlowComp = lambda *_: UNet(6, 4)
getFlowIntrp = lambda *_: UNet(20, 5)
getFlowBack = lambda opt: backWarp(opt.width, opt.height, config.device(), config.dtype())

def getOpt(option):
  def opt(): pass  # bare function used as a mutable attribute bag
  # Initialize model
  opt.model = modelPath
  dict1 = getStateDict(modelPath)
  opt.flowComp = initModel(opt, dict1['state_dictFC'], 'flowComp', getFlowComp)
  opt.ArbTimeFlowIntrp = initModel(opt, dict1['state_dictAT'], 'ArbTimeFlowIntrp', getFlowIntrp)
  opt.sf = option['sf']
  opt.firstTime = 1
  opt.notLast = 1
  opt.batchSize = 0
  if opt.sf < 2:
    raise RuntimeError('Error: --sf/slomo factor has to be at least 2')
  return opt
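# A hypothetical call, matching the option dict shape getOpt reads:
#
#   opt = getOpt({'sf': 4})   # insert 3 frames between each pair (4x frame rate)
#
# After this, opt.flowComp and opt.ArbTimeFlowIntrp are loaded for inference;
# any sf below 2 raises before the models are used.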
def getFlowBack(opt, width, height):
  # Lazily build and cache the backward-warp module for this resolution.
  if opt.flowBackWarp:
    return opt.flowBackWarp
  opt.flowBackWarp = initModel(backWarp(width, height, config.device(), config.dtype()))
  return opt.flowBackWarp
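# Usage sketch for the cached warp (assumes opt.flowBackWarp is pre-seeded to a
# falsy value; otherwise the first attribute access raises AttributeError):
#
#   opt.flowBackWarp = None
#   warp = getFlowBack(opt, 512, 320)           # built and cached on first call
#   assert warp is getFlowBack(opt, 512, 320)   # later calls reuse the cache,
#                                               # even if width/height differ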
import os
import ctypes
import platform
from shutil import rmtree

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from tqdm import tqdm

import model
import slomo_vid_loader

# check, extract_frames, create_video and combine are helpers defined elsewhere.
def run_slomo(vid_path, fps, sf, out_base='./download/', batch_size=1,
              model_path='./model/slomo/SuperSloMo.ckpt'):
  # Check if arguments are okay
  output = out_base + os.path.split(vid_path)[-1]
  output_tmp = out_base + 'tmp_' + os.path.split(vid_path)[-1]
  error = check(sf, batch_size, fps)
  if error:
    print(error)
    exit(1)

  # Create extraction folder and extract frames
  IS_WINDOWS = 'Windows' == platform.system()
  extractionDir = "tmpSuperSloMo"
  if not IS_WINDOWS:
    # Assuming UNIX-like system where "." indicates hidden directories
    extractionDir = "." + extractionDir
  if os.path.isdir(extractionDir):
    rmtree(extractionDir)
  os.mkdir(extractionDir)
  if IS_WINDOWS:
    FILE_ATTRIBUTE_HIDDEN = 0x02
    # ctypes.windll only exists on Windows
    ctypes.windll.kernel32.SetFileAttributesW(extractionDir, FILE_ATTRIBUTE_HIDDEN)

  extractionPath = os.path.join(extractionDir, "input")
  outputPath = os.path.join(extractionDir, "output")
  os.mkdir(extractionPath)
  os.mkdir(outputPath)
  error = extract_frames(vid_path, extractionPath)
  if error:
    print(error)
    exit(1)

  # Initialize transforms
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  mean = [0.429, 0.431, 0.397]
  std = [1, 1, 1]
  normalize = transforms.Normalize(mean=mean, std=std)
  negmean = [x * -1 for x in mean]
  revNormalize = transforms.Normalize(mean=negmean, std=std)

  # Temporary fix for issue #7 https://github.com/avinashpaliwal/Super-SloMo/issues/7 -
  # - Removed per channel mean subtraction for CPU.
  if device.type == "cpu":  # was `device == "cpu"`, which is always False for a torch.device
    transform = transforms.Compose([transforms.ToTensor()])
    TP = transforms.Compose([transforms.ToPILImage()])
  else:
    transform = transforms.Compose([transforms.ToTensor(), normalize])
    TP = transforms.Compose([revNormalize, transforms.ToPILImage()])

  # Load data
  videoFrames = slomo_vid_loader.Video(root=extractionPath, transform=transform)
  videoFramesloader = torch.utils.data.DataLoader(videoFrames, batch_size=batch_size, shuffle=False)

  # Initialize model
  flowComp = model.UNet(6, 4)
  flowComp.to(device)
  for param in flowComp.parameters():
    param.requires_grad = False
  ArbTimeFlowIntrp = model.UNet(20, 5)
  ArbTimeFlowIntrp.to(device)
  for param in ArbTimeFlowIntrp.parameters():
    param.requires_grad = False

  flowBackWarp = model.backWarp(videoFrames.dim[0], videoFrames.dim[1], device)
  flowBackWarp = flowBackWarp.to(device)

  dict1 = torch.load(model_path, map_location='cpu')
  ArbTimeFlowIntrp.load_state_dict(dict1['state_dictAT'])
  flowComp.load_state_dict(dict1['state_dictFC'])

  # Interpolate frames
  frameCounter = 1
  with torch.no_grad():
    for _, (frame0, frame1) in enumerate(tqdm(videoFramesloader), 0):
      I0 = frame0.to(device)
      I1 = frame1.to(device)

      # Estimate bidirectional optical flow between the two input frames.
      flowOut = flowComp(torch.cat((I0, I1), dim=1))
      F_0_1 = flowOut[:, :2, :, :]
      F_1_0 = flowOut[:, 2:, :, :]

      # Save reference frames in output folder
      for batchIndex in range(batch_size):
        (TP(frame0[batchIndex].detach())).resize(
          videoFrames.origDim, Image.BILINEAR).save(
            os.path.join(outputPath, str(frameCounter + sf * batchIndex) + ".jpg"))
      frameCounter += 1

      # Generate intermediate frames
      for intermediateIndex in range(1, sf):
        t = intermediateIndex / sf
        temp = -t * (1 - t)
        fCoeff = [temp, t * t, (1 - t) * (1 - t), temp]

        # Approximate the flows from time t back to frames 0 and 1.
        F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
        F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

        g_I0_F_t_0 = flowBackWarp(I0, F_t_0)
        g_I1_F_t_1 = flowBackWarp(I1, F_t_1)

        # Refine the flows and predict a visibility map.
        intrpOut = ArbTimeFlowIntrp(
          torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1, g_I0_F_t_0), dim=1))

        F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
        F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
        V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])  # F.sigmoid is deprecated
        V_t_1 = 1 - V_t_0

        g_I0_F_t_0_f = flowBackWarp(I0, F_t_0_f)
        g_I1_F_t_1_f = flowBackWarp(I1, F_t_1_f)

        # Blend the two warped frames, weighted by time and visibility.
        wCoeff = [1 - t, t]
        Ft_p = ((wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 * g_I1_F_t_1_f)
                / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1))

        # Save intermediate frame
        for batchIndex in range(batch_size):
          (TP(Ft_p[batchIndex].cpu().detach())).resize(
            videoFrames.origDim, Image.BILINEAR).save(
              os.path.join(outputPath, str(frameCounter + sf * batchIndex) + ".jpg"))
        frameCounter += 1

      # Set counter accounting for batching of frames
      frameCounter += sf * (batch_size - 1)

  # Generate video from interpolated frames
  create_video(outputPath, fps, output_tmp)
  combine(output_tmp, vid_path, output)

  # Remove temporary files
  rmtree(extractionDir)

  return output
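# The fCoeff/wCoeff arithmetic above matches the intermediate-flow
# approximation from the Super-SloMo paper:
#   F_t_0 ~ -(1-t)*t*F_0_1 + t^2*F_1_0
#   F_t_1 ~ (1-t)^2*F_0_1 - t*(1-t)*F_1_0
# A quick numeric check at the midpoint:
t = 0.5
temp = -t * (1 - t)
print([temp, t * t, (1 - t) * (1 - t), temp])  # [-0.25, 0.25, 0.25, -0.25]
# At t = 0.5 both flow estimates weigh F_0_1 and F_1_0 symmetrically, and
# wCoeff = [1 - t, t] = [0.5, 0.5] blends the two warped frames equally.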