def forward(self, diff1, diff2, tensorInput1, tensorInput2):
    """Synthesize a frame from two inputs plus two auxiliary tensors.

    ``diff1``, ``diff2`` and both frames are concatenated along dim 1 and fed
    through five conv/pool levels (``moduleConv1..5`` / ``modulePool1..5``).
    They then go back up through four deconv/upsample levels
    (``moduleDeconv5..2`` / ``moduleUpsample5..2``), adding the matching conv
    output after each upsample. The decoder output is turned into vertical
    and horizontal kernel tensors for each frame. Each frame is padded with
    ``self.modulePad`` and filtered by ``sepconv.FunctionSepconv``.

    Returns:
        The sum of the two filtered frames.
    """
    features = torch.cat([diff1, diff2, tensorInput1, tensorInput2], 1)

    # Encoder: keep every conv output for the additive skip connections.
    skips = {}
    for level in range(1, 6):
        features = getattr(self, 'moduleConv%d' % level)(features)
        skips[level] = features
        features = getattr(self, 'modulePool%d' % level)(features)

    # Decoder: deconv -> upsample -> add the skip of the same level (5..2).
    for level in range(5, 1, -1):
        deconvolved = getattr(self, 'moduleDeconv%d' % level)(features)
        upsampled = getattr(self, 'moduleUpsample%d' % level)(deconvolved)
        features = upsampled + skips[level]

    combined = features
    dot_first = sepconv.FunctionSepconv()(
        self.modulePad(tensorInput1),
        self.moduleVertical1(combined),
        self.moduleHorizontal1(combined))
    dot_second = sepconv.FunctionSepconv()(
        self.modulePad(tensorInput2),
        self.moduleVertical2(combined),
        self.moduleHorizontal2(combined))
    return dot_first + dot_second
def forward(self, tensorInput1, tensorInput2):
    '''Blend two frames with separable kernels predicted by an encoder/decoder.

    tensorInput1/2 : [bcz, 3, height, width] (as documented by the original
    author; not checked here)

    The frames are concatenated along dim 1 and passed through five
    conv/pool levels. They then go through four deconv/upsample levels, each
    followed by an addition of the conv output of the same level. The
    resulting features give vertical/horizontal kernel tensors per frame.
    Each frame is padded with ``self.modulePad`` and filtered with
    ``sepconv.FunctionSepconv``, and the two results are summed.
    '''
    # Feature sizes noted by the original author (apparently for a
    # 128x128 input):
    #   conv1 [32, 128, 128]  conv2 [64, 64, 64]  conv3 [128, 32, 32]
    #   conv4 [256, 16, 16]   conv5 [512, 8, 8]
    #   decoder output after level 2: [64, 64, 64]
    features = torch.cat([tensorInput1, tensorInput2], 1)

    skip = {}
    for level in (1, 2, 3, 4, 5):
        features = getattr(self, 'moduleConv{}'.format(level))(features)
        skip[level] = features
        features = getattr(self, 'modulePool{}'.format(level))(features)

    for level in (5, 4, 3, 2):
        features = getattr(self, 'moduleDeconv{}'.format(level))(features)
        features = getattr(self, 'moduleUpsample{}'.format(level))(features)
        features = features + skip[level]

    first = sepconv.FunctionSepconv()(
        self.modulePad(tensorInput1),
        self.moduleVertical1(features),
        self.moduleHorizontal1(features))
    second = sepconv.FunctionSepconv()(
        self.modulePad(tensorInput2),
        self.moduleVertical2(features),
        self.moduleHorizontal2(features))
    return first + second
def forward(self, tensorInput1, tensorInput2):
    """Interpolate a frame from two inputs with separable convolution.

    The inputs are concatenated along dim 1, encoded by five conv/pool
    stages, and decoded by four deconv/upsample stages. Each decoder stage
    adds the conv output of the matching level. Kernels derived from the
    decoder output filter each padded input, and the results are summed.
    """
    down_path = [
        (self.moduleConv1, self.modulePool1),
        (self.moduleConv2, self.modulePool2),
        (self.moduleConv3, self.modulePool3),
        (self.moduleConv4, self.modulePool4),
        (self.moduleConv5, self.modulePool5),
    ]
    up_path = [
        (self.moduleDeconv5, self.moduleUpsample5),
        (self.moduleDeconv4, self.moduleUpsample4),
        (self.moduleDeconv3, self.moduleUpsample3),
        (self.moduleDeconv2, self.moduleUpsample2),
    ]

    x = torch.cat([tensorInput1, tensorInput2], 1)
    conv_outputs = []
    for conv, pool in down_path:
        x = conv(x)
        conv_outputs.append(x)
        x = pool(x)

    # Pair decoder stages 5..2 with conv outputs 5..2 (conv1 is unused).
    for (deconv, upsample), skip in zip(up_path, reversed(conv_outputs[1:])):
        x = upsample(deconv(x)) + skip

    dot1 = sepconv.FunctionSepconv()(
        self.modulePad(tensorInput1),
        self.moduleVertical1(x),
        self.moduleHorizontal1(x))
    dot2 = sepconv.FunctionSepconv()(
        self.modulePad(tensorInput2),
        self.moduleVertical2(x),
        self.moduleHorizontal2(x))
    return dot1 + dot2
def forward(self, frames):
    """Interpolate a frame from a sequence of equally-shaped frames.

    Args:
        frames: sequence of 4-D tensors, unpacked as ``(_, _, h, w)``, all
            with the same shape. With exactly four frames the sequence is
            reordered to ``[frames[1], frames[2], frames[0], frames[3]]``, so
            the two inner frames are the ones filtered below.

    The frames are zero-padded at the bottom/right so height and width
    become multiples of 32. ``self.get_kernel`` predicts two pairs of
    separable kernels from the padded list, and these are applied to the
    first two padded frames. The sum is cropped back to ``h`` x ``w``.

    Raises:
        ValueError: if fewer than two frames are given, or their shapes
            differ.
    """
    # padded_frames[1] is used below, so at least two frames are required.
    if len(frames) < 2:
        raise ValueError('at least two frames are required')
    # Explicit check instead of ``assert``, which is stripped under -O.
    if not all(frame.shape == frames[0].shape for frame in frames):
        raise ValueError('all frames must have the same shape')

    _, _, h, w = frames[0].shape
    if len(frames) == 4:
        frames = [frames[1], frames[2], frames[0], frames[3]]

    # Bottom/right padding up to the next multiple of 32 (0 if aligned).
    # One F.pad over both axes equals the former two sequential zero pads.
    pad_h = (32 - h % 32) % 32
    pad_w = (32 - w % 32) % 32
    if pad_h or pad_w:
        padded_frames = [F.pad(frame, (0, pad_w, 0, pad_h)) for frame in frames]
    else:
        padded_frames = list(frames)

    Vertical1, Horizontal1, Vertical2, Horizontal2 = self.get_kernel(
        padded_frames)
    tensorDot1 = sepconv.FunctionSepconv()(self.modulePad(
        padded_frames[0]), Vertical1, Horizontal1)
    tensorDot2 = sepconv.FunctionSepconv()(self.modulePad(
        padded_frames[1]), Vertical2, Horizontal2)
    frame1 = tensorDot1 + tensorDot2

    # Remove the alignment padding again.
    if pad_h:
        frame1 = frame1[:, :, 0:h, :]
    if pad_w:
        frame1 = frame1[:, :, :, 0:w]
    return frame1
def forward(self, tenFirst, tenSecond):
    """Synthesize a frame from two inputs via separable local convolution.

    The inputs are concatenated along dim 1. They pass through five
    ``netConv*`` stages separated by 2x2 average pooling, then back up
    through ``netDeconv*``/``netUpsample*``, adding the matching conv output
    before each decoder step. Both inputs are replicate-padded by
    ``floor(51 / 2) = 25`` pixels on every side. Each is filtered with the
    vertical/horizontal kernels computed from the decoder output, and the
    two results are summed.
    """
    def halve(tenInput):
        # 2x2 average pooling with stride 2.
        return torch.nn.functional.avg_pool2d(input=tenInput, kernel_size=2, stride=2, count_include_pad=False)

    tenEnc1 = self.netConv1(torch.cat([tenFirst, tenSecond], 1))
    tenEnc2 = self.netConv2(halve(tenEnc1))
    tenEnc3 = self.netConv3(halve(tenEnc2))
    tenEnc4 = self.netConv4(halve(tenEnc3))
    tenEnc5 = self.netConv5(halve(tenEnc4))

    tenDec = self.netUpsample5(self.netDeconv5(halve(tenEnc5)))
    tenDec = self.netUpsample4(self.netDeconv4(tenDec + tenEnc5))
    tenDec = self.netUpsample3(self.netDeconv3(tenDec + tenEnc4))
    tenDec = self.netUpsample2(self.netDeconv2(tenDec + tenEnc3))
    tenCombine = tenDec + tenEnc2

    intPad = int(math.floor(51 / 2.0))
    tenPaddedFirst = torch.nn.functional.pad(input=tenFirst, pad=[intPad] * 4, mode='replicate')
    tenPaddedSecond = torch.nn.functional.pad(input=tenSecond, pad=[intPad] * 4, mode='replicate')

    tenDot1 = sepconv.FunctionSepconv(tenInput=tenPaddedFirst, tenVertical=self.netVertical1(tenCombine), tenHorizontal=self.netHorizontal1(tenCombine))
    tenDot2 = sepconv.FunctionSepconv(tenInput=tenPaddedSecond, tenVertical=self.netVertical2(tenCombine), tenHorizontal=self.netHorizontal2(tenCombine))
    return tenDot1 + tenDot2
def forward(self, frame0, frame2):
    """Interpolate the frame between ``frame0`` and ``frame2``.

    Both frames are zero-padded at the bottom/right so height and width are
    multiples of 32. ``self.get_kernel`` estimates two pairs of separable
    kernels from the padded frames. Each padded frame is filtered through
    ``self.modulePad`` and ``sepconv.FunctionSepconv``. The sum is cropped
    back to the original size.

    Raises:
        ValueError: if the two frames differ in height or width.
    """
    h0, w0 = int(frame0.size(2)), int(frame0.size(3))
    h2, w2 = int(frame2.size(2)), int(frame2.size(3))
    if h0 != h2 or w0 != w2:
        # Raise instead of sys.exit(): a forward pass must not terminate the
        # interpreter, and SystemExit escapes ``except Exception`` handlers.
        raise ValueError('Frame sizes do not match')

    # Bottom/right zero padding up to the next multiple of 32.
    pad_h = (32 - h0 % 32) % 32
    pad_w = (32 - w0 % 32) % 32
    if pad_h or pad_w:
        frame0 = F.pad(frame0, (0, pad_w, 0, pad_h))
        frame2 = F.pad(frame2, (0, pad_w, 0, pad_h))

    Vertical1, Horizontal1, Vertical2, Horizontal2 = self.get_kernel(
        frame0, frame2)
    tensorDot1 = sepconv.FunctionSepconv()(self.modulePad(frame0),
                                           Vertical1, Horizontal1)
    tensorDot2 = sepconv.FunctionSepconv()(self.modulePad(frame2),
                                           Vertical2, Horizontal2)
    frame1 = tensorDot1 + tensorDot2

    if pad_h:
        frame1 = frame1[:, :, 0:h0, :]
    if pad_w:
        frame1 = frame1[:, :, :, 0:w0]
    return frame1
def forward(self, Frame1, Frame3):
    """Interpolate ``Frame2`` between ``Frame1`` and ``Frame3``.

    Both frames are zero-padded at the right/bottom so width and height are
    multiples of 32. ``self.estimate_kernel`` predicts two pairs of separable
    kernels. Each padded frame is filtered through ``self.modulePad`` and
    ``sepconv.FunctionSepconv``, and the sum is cropped back to the input size.

    Returns:
        ``(Frame2, Ver1, Hor1, Ver2, Hor2)``: the interpolated frame and the
        four kernel tensors produced by ``self.estimate_kernel``.

    Raises:
        ValueError: if the two frames differ in height or width.
    """
    h_1, w_1 = int(Frame1.size(2)), int(Frame1.size(3))
    h_3, w_3 = int(Frame3.size(2)), int(Frame3.size(3))
    # Make sure frame size is same. Raise instead of sys.exit(), which would
    # terminate the whole process from inside a forward pass.
    if h_1 != h_3 or w_1 != w_3:
        raise ValueError('Size mismatch')

    # Right/bottom zero padding up to the next multiple of 32.
    pad_w = (32 - w_1 % 32) % 32
    pad_h = (32 - h_1 % 32) % 32
    if pad_w or pad_h:
        Frame1 = F.pad(Frame1, (0, pad_w, 0, pad_h))
        Frame3 = F.pad(Frame3, (0, pad_w, 0, pad_h))

    Ver1, Hor1, Ver2, Hor2 = self.estimate_kernel(Frame1, Frame3)
    tenDot1 = sepconv.FunctionSepconv()(self.modulePad(Frame1), Ver1, Hor1)
    tenDot2 = sepconv.FunctionSepconv()(self.modulePad(Frame3), Ver2, Hor2)
    Frame2 = tenDot1 + tenDot2

    if pad_h:
        Frame2 = Frame2[:, :, 0:h_1, :]
    if pad_w:
        Frame2 = Frame2[:, :, :, 0:w_1]
    return Frame2, Ver1, Hor1, Ver2, Hor2
def forward(self, frames):
    """Interpolate a frame from two frames of a stacked sequence.

    ``frames`` is unpacked as ``(_, f, _, h, w)``, so dim 1 indexes frames.
    Height and width are zero-padded at the bottom/right to the next multiple
    of 32. ``self.get_kernel`` predicts two pairs of separable kernels from
    the whole padded stack. These are applied to the frames at indices
    ``int(f / 4)`` and ``int(f / 4) + 1``, and the sum is cropped back to
    ``h`` x ``w``.
    """
    _, f, _, h, w = frames.shape
    needs_h_pad = h % 32 != 0
    needs_w_pad = w % 32 != 0

    stack = frames.clone()
    if needs_h_pad:
        stack = F.pad(stack, (0, 0, 0, 32 - (h % 32)))
    if needs_w_pad:
        stack = F.pad(stack, (0, 32 - (w % 32), 0, 0))

    Vertical1, Horizontal1, Vertical2, Horizontal2 = self.get_kernel(stack)

    idx_before = int(f / 4)
    idx_after = idx_before + 1
    filtered_before = sepconv.FunctionSepconv()(
        self.modulePad(stack[:, idx_before]), Vertical1, Horizontal1)
    filtered_after = sepconv.FunctionSepconv()(
        self.modulePad(stack[:, idx_after]), Vertical2, Horizontal2)
    result = filtered_before + filtered_after

    if needs_h_pad:
        result = result[:, :, 0:h, :]
    if needs_w_pad:
        result = result[:, :, :, 0:w]
    return result
def forward(self, tensorFirst, tensorSecond):
    """Debugging variant of the kernel-estimation forward pass.

    NOTE(review): this looks like a debugging stub. The method:
      * runs the encoder/decoder on both inputs;
      * prints the sizes of the first vertical and horizontal kernel tensors;
      * calls ``sepconv.FunctionSepconv().forward`` on the padded
        ``tensorFirst`` and discards the result;
      * returns a constant zero tensor of shape (1, 1, 512, 640), whatever
        the inputs.
    Calling ``.forward`` directly presumably bypasses autograd if
    ``FunctionSepconv`` is a ``torch.autograd.Function`` (TODO confirm).
    """
    encoder = (
        (self.moduleConv1, self.modulePool1),
        (self.moduleConv2, self.modulePool2),
        (self.moduleConv3, self.modulePool3),
        (self.moduleConv4, self.modulePool4),
        (self.moduleConv5, self.modulePool5),
    )
    decoder = (
        (self.moduleDeconv5, self.moduleUpsample5),
        (self.moduleDeconv4, self.moduleUpsample4),
        (self.moduleDeconv3, self.moduleUpsample3),
        (self.moduleDeconv2, self.moduleUpsample2),
    )

    hidden = torch.cat([tensorFirst, tensorSecond], 1)
    conv_out = []
    for conv, pool in encoder:
        hidden = conv(hidden)
        conv_out.append(hidden)
        hidden = pool(hidden)
    # Decoder levels 5..2 add conv outputs 5..2 (indices 4..1).
    for (deconv, upsample), skip in zip(decoder, conv_out[:0:-1]):
        hidden = upsample(deconv(hidden)) + skip

    print(self.moduleVertical1(hidden).size())
    print(self.moduleHorizontal1(hidden).size())
    sepconv.FunctionSepconv().forward(
        self.modulePad(tensorFirst),
        self.moduleVertical1(hidden),
        self.moduleHorizontal1(hidden))
    # (388 + 124, 584 + 56) in the original, i.e. (512, 640).
    return torch.zeros(1, 1, 512, 640)
def forward(self, diff, tensorInput1, tensorInput2):
    '''Interpolate a frame from kernel-filtered inputs and a flow-warped frame.

    tensorInput1/2 : [bcz, 3, height, width]
    diff: [bcz, 2, height, width]
    (shapes as documented by the original author; not checked here)

    ``diff`` (scaled by 2) drives the ``opt*`` encoder/decoder whose
    ``optPred`` output is passed with ``tensorInput2`` to ``self.opt.warp``.
    The original comments describe this as back-forward optical flow. The
    frames themselves drive the ``module*`` encoder/decoder, which yields
    separable kernels for each padded frame.

    Returns:
        ``(tensorDot, tensorWarp1, tensorRet)``:
          * ``tensorDot``: the sum of the two filtered frames;
          * ``tensorWarp1``: the warped ``tensorInput2``;
          * ``tensorRet``: ``self.fuse`` applied to both, concatenated
            along dim 1.
    '''
    # Scale by 2 (the original note says this helps the warping). This is
    # done out of place: the former ``diff *= 2.0`` modified the caller's
    # tensor, doubling it again on every call, and raises for leaf tensors
    # that require grad.
    diff = diff * 2.0
    tensorJoin = torch.cat([tensorInput1, tensorInput2], 1)

    # ---------------- Flow branch: predict a flow from diff and warp tensorInput2
    tensorOptConv1 = self.optConv1(diff)  # [32, 128, 128]
    tensorOptPool1 = self.optPool1(tensorOptConv1)
    tensorOptConv2 = self.optConv2(tensorOptPool1)  # [64, 64, 64]
    tensorOptPool2 = self.optPool2(tensorOptConv2)
    tensorOptConv3 = self.optConv3(tensorOptPool2)  # [128, 32, 32]
    tensorOptPool3 = self.optPool3(tensorOptConv3)
    tensorOptConv4 = self.optConv4(tensorOptPool3)  # [256, 16, 16]
    tensorOptPool4 = self.optPool4(tensorOptConv4)
    tensorOptConv5 = self.optConv5(tensorOptPool4)  # [512, 8, 8]
    tensorOptPool5 = self.optPool5(tensorOptConv5)

    tensorOptDeconv5 = self.optDeconv5(tensorOptPool5)
    tensorOptUpsample5 = self.optUpsample5(tensorOptDeconv5)
    tensorCombine = tensorOptUpsample5 + tensorOptConv5
    tensorOptDeconv4 = self.optDeconv4(tensorCombine)
    tensorOptUpsample4 = self.optUpsample4(tensorOptDeconv4)
    tensorCombine = tensorOptUpsample4 + tensorOptConv4
    tensorOptDeconv3 = self.optDeconv3(tensorCombine)
    tensorOptUpsample3 = self.optUpsample3(tensorOptDeconv3)
    tensorCombine = tensorOptUpsample3 + tensorOptConv3
    tensorOptDeconv2 = self.optDeconv2(tensorCombine)
    tensorOptUpsample2 = self.optUpsample2(tensorOptDeconv2)
    tensorCombine = tensorOptUpsample2 + tensorOptConv2
    tensorOptDeconv1 = self.optDeconv1(tensorCombine)
    tensorOptUpsample1 = self.optUpsample1(tensorOptDeconv1)
    tensorCombine = tensorOptUpsample1 + tensorOptConv1
    tensorOptPred1 = self.optPred(tensorCombine)

    # Warp the raw image
    tensorWarp1 = self.opt.warp(tensorOptPred1, tensorInput2)

    # ---------------- Kernel branch: separable kernels from both frames
    tensorConv1 = self.moduleConv1(tensorJoin)  # [32, 128, 128]
    tensorPool1 = self.modulePool1(tensorConv1)
    tensorConv2 = self.moduleConv2(tensorPool1)  # [64, 64, 64]
    tensorPool2 = self.modulePool2(tensorConv2)
    tensorConv3 = self.moduleConv3(tensorPool2)  # [128, 32, 32]
    tensorPool3 = self.modulePool3(tensorConv3)
    tensorConv4 = self.moduleConv4(tensorPool3)  # [256, 16, 16]
    tensorPool4 = self.modulePool4(tensorConv4)
    tensorConv5 = self.moduleConv5(tensorPool4)  # [512, 8, 8]
    tensorPool5 = self.modulePool5(tensorConv5)

    tensorDeconv5 = self.moduleDeconv5(tensorPool5)  # [512, 4, 4]
    tensorUpsample5 = self.moduleUpsample5(tensorDeconv5)  # [512, 8, 8]
    tensorCombine = tensorUpsample5 + tensorConv5  # [512, 8, 8]
    tensorDeconv4 = self.moduleDeconv4(tensorCombine)  # [256, 8, 8]
    tensorUpsample4 = self.moduleUpsample4(tensorDeconv4)  # [256, 16, 16]
    tensorCombine = tensorUpsample4 + tensorConv4  # [256, 16, 16]
    tensorDeconv3 = self.moduleDeconv3(tensorCombine)  # [128, 16, 16]
    tensorUpsample3 = self.moduleUpsample3(tensorDeconv3)  # [128, 32, 32]
    tensorCombine = tensorUpsample3 + tensorConv3  # [128, 32, 32]
    tensorDeconv2 = self.moduleDeconv2(tensorCombine)  # [64, 32, 32]
    tensorUpsample2 = self.moduleUpsample2(tensorDeconv2)  # [64, 64, 64]
    tensorCombine = tensorUpsample2 + tensorConv2  # [64, 64, 64]

    tensorDot1 = sepconv.FunctionSepconv()(
        self.modulePad(tensorInput1), self.moduleVertical1(tensorCombine),
        self.moduleHorizontal1(tensorCombine))
    tensorDot2 = sepconv.FunctionSepconv()(
        self.modulePad(tensorInput2), self.moduleVertical2(tensorCombine),
        self.moduleHorizontal2(tensorCombine))
    tensorDot = tensorDot1 + tensorDot2

    tensorRet = self.fuse(torch.cat([tensorDot, tensorWarp1], 1))
    return tensorDot, tensorWarp1, tensorRet
def forward(self, tensorInput1, tensorInput2):
    '''Blend two kernel-filtered frames with a predicted mask.

    tensorInput1/2 : [bcz, 3, height, width] (per the original docstring;
    not checked here)

    Both branches read the channel-wise concatenation of the two frames:
      * mask branch: conv1-3 with pooling and a bottleneck, then bilinear x2
        upsampling concatenated with conv3/conv2/conv1 through deconv1-3,
        then conv4 and tanh;
      * kernel branch: ``moduleConv*``/``moduleDeconv*`` features yield
        separable kernels applied to each padded frame.
    The tanh output is mapped from [-1, 1] to [0, 1] and repeated 3 times
    along dim 1. It is then used as
    ``mask * tensorDot1 + (1 - mask) * tensorDot2``.
    '''
    tensorJoin = torch.cat([tensorInput1, tensorInput2], 1)

    # ---------------- Mask branch
    # ``interpolate``/``torch.tanh`` replace the deprecated
    # ``nn.functional.upsample``/``nn.functional.tanh`` aliases with
    # identical arguments and results.
    x = self.conv1(tensorJoin)
    conv1 = self.relu(x)
    x = self.pool(conv1)
    x = self.conv2(x)
    conv2 = self.relu(x)
    x = self.pool(conv2)
    x = self.conv3(x)
    conv3 = self.relu(x)
    x = self.pool(conv3)
    x = self.relu(self.bottleneck(x))

    x = nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
    x = torch.cat([x, conv3], dim=1)
    x = self.relu(self.deconv1(x))
    x = nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
    x = torch.cat([x, conv2], dim=1)
    x = self.relu(self.deconv2(x))
    x = nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
    x = torch.cat([x, conv1], dim=1)
    x = self.relu(self.deconv3(x))
    x = self.conv4(x)
    mask = torch.tanh(x)

    # ---------------- Kernel branch
    tensorConv1 = self.moduleConv1(tensorJoin)  # [32, 128, 128]
    tensorPool1 = self.modulePool1(tensorConv1)
    tensorConv2 = self.moduleConv2(tensorPool1)  # [64, 64, 64]
    tensorPool2 = self.modulePool2(tensorConv2)
    tensorConv3 = self.moduleConv3(tensorPool2)  # [128, 32, 32]
    tensorPool3 = self.modulePool3(tensorConv3)
    tensorConv4 = self.moduleConv4(tensorPool3)  # [256, 16, 16]
    tensorPool4 = self.modulePool4(tensorConv4)
    tensorConv5 = self.moduleConv5(tensorPool4)  # [512, 8, 8]
    tensorPool5 = self.modulePool5(tensorConv5)

    tensorDeconv5 = self.moduleDeconv5(tensorPool5)  # [512, 4, 4]
    tensorUpsample5 = self.moduleUpsample5(tensorDeconv5)  # [512, 8, 8]
    tensorCombine = tensorUpsample5 + tensorConv5  # [512, 8, 8]
    tensorDeconv4 = self.moduleDeconv4(tensorCombine)  # [256, 8, 8]
    tensorUpsample4 = self.moduleUpsample4(tensorDeconv4)  # [256, 16, 16]
    tensorCombine = tensorUpsample4 + tensorConv4  # [256, 16, 16]
    tensorDeconv3 = self.moduleDeconv3(tensorCombine)  # [128, 16, 16]
    tensorUpsample3 = self.moduleUpsample3(tensorDeconv3)  # [128, 32, 32]
    tensorCombine = tensorUpsample3 + tensorConv3  # [128, 32, 32]
    tensorDeconv2 = self.moduleDeconv2(tensorCombine)  # [64, 32, 32]
    tensorUpsample2 = self.moduleUpsample2(tensorDeconv2)  # [64, 64, 64]
    tensorCombine = tensorUpsample2 + tensorConv2  # [64, 64, 64]

    tensorDot1 = sepconv.FunctionSepconv()(
        self.modulePad(tensorInput1), self.moduleVertical1(tensorCombine),
        self.moduleHorizontal1(tensorCombine))
    tensorDot2 = sepconv.FunctionSepconv()(
        self.modulePad(tensorInput2), self.moduleVertical2(tensorCombine),
        self.moduleHorizontal2(tensorCombine))

    # Map tanh output from [-1, 1] to [0, 1]; the mask is presumably single
    # channel (TODO confirm conv4's out_channels), so repeat it for RGB.
    mask = 0.5 * (1.0 + mask)
    mask = mask.repeat([1, 3, 1, 1])
    x = mask * tensorDot1 + (1.0 - mask) * tensorDot2
    return x
def forward(self, tensorInput1, tensorInput2, tensorResidual=None, tensorHidden=None):
    '''Interpolate a frame with kernels conditioned on a ConvLSTM state.

    tensorInput1/2 : [bcz, 3, height, width]
    tensorResidual: [bcz, 3, height, width]
    tensorHidden: (tuple or None) ([bcz, hidden_dim, height, width])
    (shapes as documented by the original author; not checked here)

    ``tensorResidual is None`` marks the first time step. In that case an
    all-zero residual is encoded, and ``self.moduleLSTM`` is called without a
    hidden state (``tensorHidden`` is ignored).

    Returns:
        ``(tensorDot1 + tensorDot2, (tensorH_next, tensorC_next))``: the
        interpolated frame and the next ConvLSTM state.
    '''
    # ------------------- LSTM part --------------------
    first_step = tensorResidual is None
    if first_step:
        # Zero residual on the same device/dtype as the input. The former
        # ``var(torch.zeros(...)).cuda()`` always allocated float32 on the
        # default CUDA device, breaking CPU and non-default-GPU runs.
        tensorResidual = tensorInput1.new_zeros(tensorInput1.size())
    tensorEncRes = self.moduleDownH(self.moduleConvH(tensorResidual))
    if first_step:
        tensorH_next, tensorC_next = self.moduleLSTM(tensorEncRes)
    else:
        tensorH_next, tensorC_next = self.moduleLSTM(tensorEncRes, tensorHidden)

    # ------------------- Encoder part -----------------
    tensorJoin = torch.cat([tensorInput1, tensorInput2], 1)
    tensorConv1 = self.moduleConv1(tensorJoin)  # [32, 128, 128]
    tensorPool1 = self.modulePool1(tensorConv1)
    tensorConv2 = self.moduleConv2(tensorPool1)  # [64, 64, 64]
    tensorPool2 = self.modulePool2(tensorConv2)
    tensorConv3 = self.moduleConv3(tensorPool2)  # [128, 32, 32]
    tensorPool3 = self.modulePool3(tensorConv3)
    tensorConv4 = self.moduleConv4(tensorPool3)  # [256, 16, 16]
    tensorPool4 = self.modulePool4(tensorConv4)
    tensorConv5 = self.moduleConv5(tensorPool4)  # [512, 8, 8]
    tensorPool5 = self.modulePool5(tensorConv5)

    # ------------------- Decoder part -----------------
    tensorDeconv5 = self.moduleDeconv5(tensorPool5)  # [512, 4, 4]
    tensorUpsample5 = self.moduleUpsample5(tensorDeconv5)  # [512, 8, 8]
    tensorCombine = tensorUpsample5 + tensorConv5  # [512, 8, 8]
    tensorDeconv4 = self.moduleDeconv4(tensorCombine)  # [256, 8, 8]
    tensorUpsample4 = self.moduleUpsample4(tensorDeconv4)  # [256, 16, 16]
    tensorCombine = tensorUpsample4 + tensorConv4  # [256, 16, 16]
    tensorDeconv3 = self.moduleDeconv3(tensorCombine)  # [128, 16, 16]
    tensorUpsample3 = self.moduleUpsample3(tensorDeconv3)  # [128, 32, 32]
    tensorCombine = tensorUpsample3 + tensorConv3  # [128, 32, 32]
    tensorDeconv2 = self.moduleDeconv2(tensorCombine)  # [64, 32, 32]
    tensorUpsample2 = self.moduleUpsample2(tensorDeconv2)  # [64, 64, 64]
    tensorCombine1 = tensorUpsample2 + tensorConv2  # [64, 64, 64]

    # Condition the kernel heads on the new LSTM hidden state.
    tensorCombine = torch.cat([tensorCombine1, tensorH_next], 1)
    tensorDot1 = sepconv.FunctionSepconv()(
        self.modulePad(tensorInput1), self.moduleVertical11(tensorCombine),
        self.moduleHorizontal11(tensorCombine))
    tensorDot2 = sepconv.FunctionSepconv()(
        self.modulePad(tensorInput2), self.moduleVertical22(tensorCombine),
        self.moduleHorizontal22(tensorCombine))

    # Return the predicted tensor and the next state of the ConvLSTM.
    return tensorDot1 + tensorDot2, (tensorH_next, tensorC_next)
# Training-loop fragment: loads saved weights, then runs a manual
# forward/backward through sepconv.FunctionSepconv.
# moduleNetwork, predir, train and loss_fn are defined outside this fragment.
# NOTE(review): all inputs are constant torch.ones placeholders, not real data.
# NOTE(review): this fragment appears truncated after the kgrad2 call.
moduleNetwork.load_state_dict(torch.load("../models/" + predir))
for epoch in range(100):
    for n in range(train):
        # Make (placeholder) training data.
        # 178 = 128 + 2 * 25: 25 pixels wider on each side so the kernels can
        # be applied to I1 and I3.
        image3b = torch.ones((1, 3, 178, 178)).cuda()
        image1b = torch.ones((1, 3, 178, 178)).cuda()
        image1 = torch.ones(1, 3, 128, 128).cuda()
        image2 = torch.ones(1, 3, 128, 128).cuda()
        image3 = torch.ones(1, 3, 128, 128).cuda()
        # Forward calculation: split the network output into 4 equal chunks
        # along dim 3, used below as (vertical1, horizontal1, vertical2,
        # horizontal2).
        Kernel = moduleNetwork.forward(image1, image3)
        kernelDiv = torch.chunk(Kernel, 4, dim=3)
        # detach() cuts the graph; requires_grad is re-enabled so that
        # tensorDot1/2.grad is populated and can be passed to the manual
        # sepconv backward calls below.
        tensorDot1 = sepconv.FunctionSepconv().forward(
            image1b, kernelDiv[0], kernelDiv[1]).detach()
        tensorDot2 = sepconv.FunctionSepconv().forward(
            image3b, kernelDiv[2], kernelDiv[3]).detach()
        tensorDot1.requires_grad = True
        tensorDot2.requires_grad = True
        tensorCombine = tensorDot1 + tensorDot2
        # Backward calculation.
        loss = loss_fn(tensorCombine, image2)
        value_loss = loss.item()
        loss.backward()
        kgrad1 = sepconv.FunctionSepconv().backward(
            tensorDot1.grad, (tensorCombine, image1b, kernelDiv[0], kernelDiv[1]))
        kgrad2 = sepconv.FunctionSepconv().backward(
            tensorDot2.grad,
def forward(self, tensorInput1, tensorInput2, tensorResidual=None, tensorHidden=None):
    '''Multi-scale kernel interpolation conditioned on a ConvLSTM state.

    tensorInput1/2 : [bcz, 3, height, width]
    tensorResidual: [bcz, 3, height, width]
    tensorHidden: (tuple or None) ([bcz, hidden_dim, height, width])
    (shapes as documented by the original author; not checked here)

    ``tensorResidual is None`` marks the first time step. In that case a zero
    residual is encoded, and ``self.moduleLSTM`` is called without a hidden
    state (``tensorHidden`` is ignored).

    Three versions of the LSTM hidden state are concatenated with the decoder
    features at three depths: the state itself and two reduced copies from
    ``moduleDownLSTM1/2`` (stride-2 convolutions per the original comment).
    Each depth predicts separable kernels for both inputs, resized to 1/4,
    1/2 and full resolution respectively.

    Returns:
        ``(full, quarter, half, (tensorH_next, tensorC_next))``. Each image is
        the sum of the two filtered inputs at that resolution.
    '''
    # ------------------- LSTM part --------------------
    first_step = tensorResidual is None
    if first_step:
        # Zero residual on the input's device/dtype instead of the former
        # ``var(torch.zeros(...)).cuda()``, which was pinned to float32 on
        # the default CUDA device.
        tensorResidual = tensorInput1.new_zeros(tensorInput1.size())
    tensorEncRes = self.moduleDownH(self.moduleConvH(tensorResidual))
    if first_step:
        tensorH_next, tensorC_next = self.moduleLSTM(tensorEncRes)
    else:
        tensorH_next, tensorC_next = self.moduleLSTM(tensorEncRes, tensorHidden)

    # LSTM features at three resolutions (128, 64, 32 per the original note).
    tensorL0 = tensorH_next
    tensorL1 = self.moduleDownLSTM1(tensorL0)
    tensorL2 = self.moduleDownLSTM2(tensorL1)

    # Target sizes of the reduced-resolution branches; both inputs are
    # resized to tensorInput1's spatial size, as before.
    intHeight, intWidth = tensorInput1.shape[2], tensorInput1.shape[3]
    sizeQuarter = (intHeight // 4, intWidth // 4)
    sizeHalf = (intHeight // 2, intWidth // 2)

    # ------------------- Encoder ----------------------
    tensorJoin = torch.cat([tensorInput1, tensorInput2], 1)
    tensorConv1 = self.moduleConv1(tensorJoin)
    tensorPool1 = self.modulePool1(tensorConv1)
    tensorConv2 = self.moduleConv2(tensorPool1)
    tensorPool2 = self.modulePool2(tensorConv2)
    tensorConv3 = self.moduleConv3(tensorPool2)
    tensorPool3 = self.modulePool3(tensorConv3)
    tensorConv4 = self.moduleConv4(tensorPool3)
    tensorPool4 = self.modulePool4(tensorConv4)
    tensorConv5 = self.moduleConv5(tensorPool4)
    tensorPool5 = self.modulePool5(tensorConv5)

    # ------------------- Decoder + quarter-resolution kernels ----
    tensorDeconv5 = self.moduleDeconv5(tensorPool5)
    tensorUpsample5 = self.moduleUpsample5(tensorDeconv5)
    tensorCombine = tensorUpsample5 + tensorConv5
    tensorDeconv4 = self.moduleDeconv4(tensorCombine)
    tensorUpsample4 = self.moduleUpsample4(tensorDeconv4)
    tensorCombine = tensorUpsample4 + tensorConv4
    tensorCombineL2 = torch.cat([tensorCombine, tensorL2], 1)  # 256 + 128 = 384 channels (original note)
    # ``interpolate`` replaces the deprecated ``upsample`` alias (same arguments).
    tensorDot1_a = sepconv.FunctionSepconv()(
        self.modulePad_a(func.interpolate(tensorInput1, size=sizeQuarter, mode='bilinear', align_corners=True)),
        self.mv1_a_(tensorCombineL2), self.mh1_a_(tensorCombineL2))
    tensorDot2_a = sepconv.FunctionSepconv()(
        self.modulePad_a(func.interpolate(tensorInput2, size=sizeQuarter, mode='bilinear', align_corners=True)),
        self.mv2_a_(tensorCombineL2), self.mh2_a_(tensorCombineL2))

    # ------------------- Half-resolution kernels ------
    tensorDeconv3 = self.moduleDeconv3(tensorCombine)
    tensorUpsample3 = self.moduleUpsample3(tensorDeconv3)
    tensorCombine = tensorUpsample3 + tensorConv3
    tensorCombineL1 = torch.cat([tensorCombine, tensorL1], 1)  # 128 + 64 = 192 channels (original note)
    tensorDot1_b = sepconv.FunctionSepconv()(
        self.modulePad_b(func.interpolate(tensorInput1, size=sizeHalf, mode='bilinear', align_corners=True)),
        self.mv1_b_(tensorCombineL1), self.mh1_b_(tensorCombineL1))
    tensorDot2_b = sepconv.FunctionSepconv()(
        self.modulePad_b(func.interpolate(tensorInput2, size=sizeHalf, mode='bilinear', align_corners=True)),
        self.mv2_b_(tensorCombineL1), self.mh2_b_(tensorCombineL1))

    # ------------------- Full-resolution kernels ------
    tensorDeconv2 = self.moduleDeconv2(tensorCombine)
    tensorUpsample2 = self.moduleUpsample2(tensorDeconv2)
    tensorCombine = tensorUpsample2 + tensorConv2
    tensorCombineL0 = torch.cat([tensorCombine, tensorL0], 1)  # 64 + 32 = 96 channels (original note)
    tensorDot1 = sepconv.FunctionSepconv()(
        self.modulePad(tensorInput1), self.moduleVertical1_(tensorCombineL0),
        self.moduleHorizontal1_(tensorCombineL0))
    tensorDot2 = sepconv.FunctionSepconv()(
        self.modulePad(tensorInput2), self.moduleVertical2_(tensorCombineL0),
        self.moduleHorizontal2_(tensorCombineL0))

    return (tensorDot1 + tensorDot2, tensorDot1_a + tensorDot2_a,
            tensorDot1_b + tensorDot2_b, (tensorH_next, tensorC_next))
def forward(self, tensorInput1, tensorInput2):
    '''Interpolate a frame from ``tensorInput1`` and a warped ``tensorInput2``.

    tensorInput1/2 : [bcz, 3, height, width] (per the original docstring;
    not checked here)

    Two encoder/decoder passes read the channel-wise concatenation of the
    inputs:
      * the ``opt*`` modules (five levels down, five up, then ``optPred``)
        give the tensor that ``self.opt.warp`` applies to ``tensorInput2``
        (an optical flow according to the original comments);
      * the ``module*`` modules (five levels down, four up) give the
        features from which the separable kernels are built.
    ``tensorInput1`` and the warped frame are padded, filtered with their
    kernels and summed.
    '''
    def run_branch(prefix, features, lowest_level):
        # Conv/pool levels 1..5. Then deconv/upsample from level 5 down to
        # ``lowest_level``, adding the conv output of each level.
        skips = {}
        for level in range(1, 6):
            features = getattr(self, '{}Conv{}'.format(prefix, level))(features)
            skips[level] = features
            features = getattr(self, '{}Pool{}'.format(prefix, level))(features)
        for level in range(5, lowest_level - 1, -1):
            features = getattr(self, '{}Deconv{}'.format(prefix, level))(features)
            features = getattr(self, '{}Upsample{}'.format(prefix, level))(features)
            features = features + skips[level]
        return features

    joined = torch.cat([tensorInput1, tensorInput2], 1)

    # Flow branch: predict, then warp tensorInput2.
    flow = self.optPred(run_branch('opt', joined, 1))
    warped = self.opt.warp(flow, tensorInput2)

    # Kernel branch.
    combined = run_branch('module', joined, 2)
    first = sepconv.FunctionSepconv()(
        self.modulePad(tensorInput1),
        self.moduleVertical1(combined),
        self.moduleHorizontal1(combined))
    second = sepconv.FunctionSepconv()(
        self.modulePad(warped),
        self.moduleVertical2(combined),
        self.moduleHorizontal2(combined))
    return first + second
def forward(self, frames):
    """Interpolate a frame using separable kernels on several frames/scales.

    ``frames`` is unpacked as ``(_, f, _, h, w)``, so dim 1 indexes frames.
    Height and width are zero-padded at the bottom/right to multiples of 32.
    ``self.get_kernel`` then returns eight kernel collections
    ``(V1, H1, V2, H2, VQ1, HQ1, VQ2, HQ2)``; the tuple is also stored on
    ``self.interpolation_kernels``.

    The output is the sum of:
      * the frames at indices ``int(f/4)`` and ``int(f/4) + 1`` filtered with
        ``V1[0]/H1[0]`` and ``V2[0]/H2[0]``;
      * if ``self.kl_d_size`` is not None: the same two frames passed through
        ``self.down_l``, filtered with ``V1[1]/H1[1]``/``V2[1]/H2[1]`` and
        passed through ``self.up_l``;
      * if ``self.kq_size`` is not None: frames 0 and 3 filtered with
        ``VQ1[0]/HQ1[0]`` and ``VQ2[0]/HQ2[0]``;
      * if ``self.kq_d_size`` is not None: frames 0 and 3 passed through
        ``self.down_q``, filtered with ``VQ1[1]/HQ1[1]``/``VQ2[1]/HQ2[1]``
        and passed through ``self.up_q``.
    The result is cropped back to ``h`` x ``w``.
    NOTE(review): indices 0 and 3 are hard-coded, presumably assuming f == 4.
    """
    _, f, _, h, w = frames.shape
    h_padded = h % 32 != 0
    w_padded = w % 32 != 0
    padded_frames = frames.clone()
    if h_padded:
        padded_frames = F.pad(padded_frames, (0, 0, 0, 32 - (h % 32)))
    if w_padded:
        padded_frames = F.pad(padded_frames, (0, 32 - (w % 32), 0, 0))

    # get kernels from subnets
    V1, H1, V2, H2, VQ1, HQ1, VQ2, HQ2 = self.interpolation_kernels = self.get_kernel(
        padded_frames)

    frame_before = int(0 + f / 4)
    frame_after = int(1 + f / 4)
    tensorDotL = sepconv.FunctionSepconv()(self.modulePad_l(
        padded_frames[:, frame_before]), V1[0], H1[0])
    tensorDotR = sepconv.FunctionSepconv()(self.modulePad_l(
        padded_frames[:, frame_after]), V2[0], H2[0])

    # Out-of-place additions below: in-place ``+=`` on autograd outputs can
    # break backward if those outputs are saved for the gradient.
    if self.kl_d_size is not None:
        # downscale input frames
        im1d = self.down_l(padded_frames[:, frame_before])
        im2d = self.down_l(padded_frames[:, frame_after])
        # convolve and upscale back to original size
        tensorDotL = tensorDotL + self.up_l(sepconv.FunctionSepconv()(
            self.modulePad_ld(im1d), V1[1], H1[1]))
        tensorDotR = tensorDotR + self.up_l(sepconv.FunctionSepconv()(
            self.modulePad_ld(im2d), V2[1], H2[1]))

    if self.kq_size is not None:
        tensorDotLL = sepconv.FunctionSepconv()(self.modulePad_q(
            padded_frames[:, 0]), VQ1[0], HQ1[0])
        tensorDotRR = sepconv.FunctionSepconv()(self.modulePad_q(
            padded_frames[:, 3]), VQ2[0], HQ2[0])
    else:
        tensorDotLL = tensorDotRR = 0

    if self.kq_d_size is not None:
        im1qd = self.down_q(padded_frames[:, 0])
        im2qd = self.down_q(padded_frames[:, 3])
        tensorDotLL = tensorDotLL + self.up_q(sepconv.FunctionSepconv()(
            self.modulePad_qd(im1qd), VQ1[1], HQ1[1]))
        tensorDotRR = tensorDotRR + self.up_q(sepconv.FunctionSepconv()(
            self.modulePad_qd(im2qd), VQ2[1], HQ2[1]))

    frame_out = tensorDotL + tensorDotR + tensorDotLL + tensorDotRR
    if h_padded:
        frame_out = frame_out[:, :, 0:h, :]
    if w_padded:
        frame_out = frame_out[:, :, :, 0:w]
    return frame_out