def forward(self, x, flow):
    """Build a four-level list of warped features.

    At every level the features are passed through the next conv layer
    (``conv1`` .. ``conv4``), the flow is resized by a factor of 0.5 with
    bilinear interpolation and its values are multiplied by 0.5, and the
    new features are warped with the resized flow.

    Args:
        x: input feature tensor fed to ``self.conv1``.
        flow: flow tensor accepted by ``F.interpolate`` and ``warp``.

    Returns:
        A list with the four warped feature tensors, first level first.
    """
    warped_levels = []
    for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
        x = conv(x)
        # Halving the resolution also halves the flow magnitudes.
        flow = 0.5 * 1.0 * F.interpolate(
            flow,
            scale_factor=0.5,
            mode="bilinear",
            align_corners=False,
            recompute_scale_factor=False,
        ) if False else F.interpolate(
            flow,
            scale_factor=0.5,
            mode="bilinear",
            align_corners=False,
            recompute_scale_factor=False,
        ) * 0.5
        warped_levels.append(warp(x, flow))
    return warped_levels
def forward(self, x, scale_list=(4, 2, 1), training=False, ensemble=True, ada=True):
    """Estimate flow in three refinement stages and blend the warped frames.

    The channel dimension of ``x`` is split in half into ``img0`` and
    ``img1``. Starting from an all-zero flow (4 channels) and mask
    (1 channel), each of ``block0`` .. ``block2`` predicts a flow/mask
    increment that is added to the running estimate; both frames are then
    re-warped with the updated flow.

    Args:
        x: input tensor; dim 1 is split in half into the two frames.
        scale_list: per-stage ``scale`` value handed to each block; indexed
            with 0, 1 and 2. The default is an immutable tuple so the
            default object can never be mutated between calls.
        training: unused here; kept for interface compatibility.
        ensemble: when true, every block is also run on the swapped frame
            order (with negated mask and swapped flow halves) and the two
            predictions are averaged.
        ada: unused here; kept for interface compatibility.

    Returns:
        ``(merged, flow)`` where ``merged`` is a three-element list: the
        first two entries are ``(warped_img0, warped_img1)`` tuples from
        stages 0 and 1, the last entry is the stage-2 pair blended with
        ``sigmoid(mask)``; ``flow`` is the final accumulated flow.
    """
    channel = x.shape[1] // 2
    img0 = x[:, :channel]
    img1 = x[:, channel:]
    merged = []
    warped_img0 = img0
    warped_img1 = img1
    flow = torch.zeros_like(x[:, :4]).to(device)
    mask = torch.zeros_like(x[:, :1]).to(device)
    block = [self.block0, self.block1, self.block2]
    for i in range(3):
        f0, m0 = block[i](
            torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1),
            flow,
            scale=scale_list[i],
        )
        if ensemble:
            # Second pass with the frames swapped; its output is swapped
            # back (flow halves exchanged, mask negated) before averaging.
            f1, m1 = block[i](
                torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1),
                torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                scale=scale_list[i],
            )
            f0 = (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
            m0 = (m0 + (-m1)) / 2
        flow = flow + f0  # TODO all dark frame output in 3.8 model
        mask = mask + m0
        warped_img0 = warp(img0, flow[:, :2])
        warped_img1 = warp(img1, flow[:, 2:4])
        merged.append((warped_img0, warped_img1))
    # Only the last stage is blended; earlier entries stay as raw pairs.
    blend = torch.sigmoid(mask)
    merged[2] = merged[2][0] * blend + merged[2][1] * (1 - blend)
    return merged, flow
def forward(self, img0, img1, flow, c0, c1, flow_gt):
    """Run the encoder/decoder over the warped frames and context features.

    ``img0`` is warped with ``flow[:, :2]`` and ``img1`` with
    ``flow[:, 2:4]``. The warped frames and the flow are concatenated along
    dim 1, passed through ``conv0`` and four down stages (``down0`` ..
    ``down3``), then four up stages (``up0`` .. ``up3``) with skip
    connections from the down stages, and finally ``conv``.

    Args:
        img0: first frame.
        img1: second frame.
        flow: flow tensor; channels 0-1 warp ``img0``, channels 2-3 warp
            ``img1``.
        c0: indexable with 0..3; entries are concatenated into the
            ``down1`` .. ``down3`` and ``up0`` inputs.
        c1: same layout as ``c0``.
        flow_gt: optional reference flow with the same channel layout as
            ``flow``, or ``None``.

    Returns:
        ``(x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt)``;
        the last two are ``None`` when ``flow_gt`` is ``None``.
    """
    warped_img0 = warp(img0, flow[:, :2])
    warped_img1 = warp(img1, flow[:, 2:4])
    # Identity test: ``== None`` would be routed through the tensor's
    # ``__eq__`` instead of simply checking for the sentinel.
    if flow_gt is None:
        warped_img0_gt, warped_img1_gt = None, None
    else:
        warped_img0_gt = warp(img0, flow_gt[:, :2])
        warped_img1_gt = warp(img1, flow_gt[:, 2:4])
    x = self.conv0(torch.cat((warped_img0, warped_img1, flow), 1))
    s0 = self.down0(x)
    s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
    s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
    s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
    x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
    # Decoder stages reuse the encoder outputs as skip connections.
    x = self.up1(torch.cat((x, s2), 1))
    x = self.up2(torch.cat((x, s1), 1))
    x = self.up3(torch.cat((x, s0), 1))
    x = self.conv(x)
    return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
def forward(self, x, scale=1.0, ensemble=False, ada=True):
    """Predict flow with ``block0`` and refine it with ``block1`` .. ``block3``.

    ``block0`` sees the (optionally rescaled) input directly. Each later
    block sees both frames warped by the 2x-upsampled running flow plus
    that upsampled flow, and its output is added to the running total.

    Args:
        x: input tensor; channels ``[:3]`` are treated as the first frame
            and ``[3:]`` as the second.
        scale: when different from 1.0 the input is resized by this factor
            first, and the final flow is resized by ``1 / scale`` and
            divided by ``scale``.
        ensemble: when true, every block is also run with the two frames
            (and flow halves) swapped and the two predictions are averaged.
        ada: unused here; kept for interface compatibility.

    Returns:
        ``(final_flow, [F1, F2, F3, F4])`` where ``Fk`` is the sum of the
        first ``k`` block outputs and the last list entry is the same
        object as ``final_flow`` (i.e. after the optional rescale).
    """

    def swap_halves(t):
        # Exchange channels 0-1 with channels 2-3.
        return torch.cat((t[:, 2:4], t[:, :2]), 1)

    def add_in_order(parts):
        # Left-to-right sum starting from the first part, rebuilt each time.
        total = parts[0]
        for part in parts[1:]:
            total = total + part
        return total

    def upsample_x2(t):
        return F.interpolate(
            t, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0

    rescaled = scale != 1.0
    if rescaled:
        x = F.interpolate(
            x, scale_factor=scale, mode="bilinear", align_corners=False)

    first, second = x[:, :3], x[:, 3:]

    delta = self.block0(torch.cat((first, second), 1))
    if ensemble:
        mirrored = self.block0(torch.cat((second, first), 1))
        delta = (delta + swap_halves(mirrored)) / 2

    deltas = [delta]
    running = [delta]

    for refine in (self.block1, self.block2, self.block3):
        coarse = upsample_x2(running[-1])
        warped_img0 = warp(x[:, :3], coarse[:, :2])
        warped_img1 = warp(x[:, 3:], coarse[:, 2:4])
        delta = refine(torch.cat((warped_img0, warped_img1, coarse), 1))
        if ensemble:
            mirrored = refine(
                torch.cat((warped_img1, warped_img0, swap_halves(coarse)), 1))
            delta = (delta + swap_halves(mirrored)) / 2
        deltas.append(delta)
        running.append(add_in_order(deltas))

    if rescaled:
        # Undo the initial resize; the returned list carries this version.
        running[-1] = F.interpolate(
            running[-1],
            scale_factor=1 / scale,
            mode="bilinear",
            align_corners=False,
        ) / scale

    return running[-1], running
def forward(self, x, scale=(4, 2, 1), ensemble=False):
    """Three-stage student flow estimation with optional teacher distillation.

    ``x`` is sliced along dim 1 into ``img0`` (channels 0-2), ``img1``
    (channels 3-5) and ``gt`` (everything from channel 6 on). The student
    blocks ``block0`` .. ``block2`` produce and then refine a flow and a
    mask. When ``gt`` has exactly three channels, ``block_tea`` computes a
    teacher flow from the final student state plus ``gt`` and a
    distillation loss is accumulated over the three student stages.
    Finally the last merged frame gets a residual from ``unet`` and is
    clamped to ``[0, 1]``.

    Args:
        x: input tensor holding the two frames and, optionally, three
            ground-truth channels after them.
        scale: per-stage ``scale`` value handed to each student block;
            indexed with 0, 1 and 2. The default is an immutable tuple so
            the default object can never be mutated between calls.
        ensemble: when true, every student block is also run with the
            frames swapped and the two predictions are averaged.

    Returns:
        ``(flow_list, mask_list[2], merged, flow_teacher, merged_teacher,
        loss_distill)``. ``flow_teacher`` and ``merged_teacher`` are
        ``None`` and ``loss_distill`` stays ``0`` when no ground truth is
        present.
    """
    img0 = x[:, :3]
    img1 = x[:, 3:6]
    # Whatever follows the first six channels; never None, but it has no
    # channels when x only carries the two frames.
    gt = x[:, 6:]
    has_gt = gt.shape[1] == 3
    flow_list = []
    merged = []
    mask_list = []
    warped_img0 = img0
    warped_img1 = img1
    flow = None
    loss_distill = 0
    stu = [self.block0, self.block1, self.block2]
    for i in range(3):
        # Identity test: ``!= None`` would go through the tensor's
        # ``__ne__`` instead of simply checking for the sentinel.
        if flow is not None:
            flow_d0, mask_d0 = stu[i](
                torch.cat((img0, img1, warped_img0, warped_img1, mask), 1),
                flow,
                scale=scale[i],
            )
            if ensemble:
                flow_d1, mask_d1 = stu[i](
                    torch.cat(
                        (img1, img0, warped_img1, warped_img0, -mask), 1),
                    torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                    scale=scale[i],
                )
                flow_d0 = (flow_d0 + torch.cat(
                    (flow_d1[:, 2:4], flow_d1[:, :2]), 1)) / 2
                mask_d0 = (mask_d0 + (-mask_d1)) / 2
            flow = flow + flow_d0
            mask = mask + mask_d0
        else:
            # First stage: no previous flow or mask to refine.
            flow, mask = stu[i](
                torch.cat((img0, img1), 1), None, scale=scale[i])
            if ensemble:
                flow1, mask1 = stu[i](
                    torch.cat((img1, img0), 1), None, scale=scale[i])
                flow = (flow + torch.cat(
                    (flow1[:, 2:4], flow1[:, :2]), 1)) / 2
                mask = (mask + (-mask1)) / 2
        mask_list.append(torch.sigmoid(mask))
        flow_list.append(flow)
        warped_img0 = warp(img0, flow[:, :2])
        warped_img1 = warp(img1, flow[:, 2:4])
        merged_student = (warped_img0, warped_img1)
        merged.append(merged_student)
    if has_gt:
        flow_d, mask_d = self.block_tea(
            torch.cat((img0, img1, warped_img0, warped_img1, mask, gt), 1),
            flow,
            scale=1,
        )
        flow_teacher = flow + flow_d
        warped_img0_teacher = warp(img0, flow_teacher[:, :2])
        warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
        mask_teacher = torch.sigmoid(mask + mask_d)
        merged_teacher = (
            warped_img0_teacher * mask_teacher
            + warped_img1_teacher * (1 - mask_teacher)
        )
    else:
        flow_teacher = None
        merged_teacher = None
    for i in range(3):
        merged[i] = (
            merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
        )
        if has_gt:
            # Distill only where the student's error exceeds the
            # teacher's by more than 0.01.
            loss_mask = (
                (merged[i] - gt).abs().mean(1, True)
                > (merged_teacher - gt).abs().mean(1, True) + 0.01
            ).float().detach()
            loss_distill += (
                (flow_teacher.detach() - flow_list[i]).abs() * loss_mask
            ).mean()
    c0 = self.contextnet(img0, flow[:, :2])
    c1 = self.contextnet(img1, flow[:, 2:4])
    tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
    # Map the first three unet channels to [-1, 1] — assumes unet output
    # lies in [0, 1]; TODO confirm against the unet definition.
    res = tmp[:, :3] * 2 - 1
    merged[2] = torch.clamp(merged[2] + res, 0, 1)
    return (
        flow_list,
        mask_list[2],
        merged,
        flow_teacher,
        merged_teacher,
        loss_distill,
    )