def forward(self, x, flow):
    """Run the four conv stages and warp each stage's output with the flow.

    After every conv stage the flow is bilinearly resized with
    ``scale_factor=0.5`` and its values are multiplied by 0.5; the stage
    output is then passed to ``warp`` together with that resized flow.

    Args:
        x: Input tensor fed to ``self.conv1``.
        flow: Flow tensor; it goes through bilinear ``F.interpolate``.

    Returns:
        A list ``[f1, f2, f3, f4]`` with the warped output of each stage,
        in stage order.
    """
    stages = (self.conv1, self.conv2, self.conv3, self.conv4)
    warped_features = []
    for stage in stages:
        x = stage(x)
        # NOTE(review): multiplying by 0.5 presumably keeps displacements
        # consistent with the halved grid, which assumes each conv stage
        # halves the spatial size -- confirm against the conv definitions.
        flow = 0.5 * F.interpolate(
            flow,
            scale_factor=0.5,
            mode="bilinear",
            align_corners=False,
            recompute_scale_factor=False,
        )
        warped_features.append(warp(x, flow))
    return warped_features
    def forward(self,
                x,
                scale_list=(4, 2, 1),
                training=False,
                ensemble=True,
                ada=True):
        """Estimate flow in three refinement stages and blend the warped images.

        ``x`` is split in half along dim 1 into ``img0`` and ``img1``. The
        flow (4 channels) and mask (1 channel) start at zero; each of the
        three blocks returns a flow update and a mask update that are added
        to the running values. The images are re-warped after every stage.

        Args:
            x: Tensor holding both images concatenated along dim 1.
            scale_list: One ``scale`` value per block, indexed 0..2.
            training: Not used by this method.
            ensemble: When true, each block is also run with the two images
                swapped (mask negated, flow halves swapped) and the two
                predictions are averaged.
            ada: Not used by this method.

        Returns:
            ``(merged, flow)``. ``merged`` is a list of three entries: the
            first two are ``(warped_img0, warped_img1)`` tuples, the last is
            the mask-blended tensor. ``flow`` is the final accumulated flow.
        """
        channel = x.shape[1] // 2
        img0 = x[:, :channel]
        img1 = x[:, channel:]
        merged = []
        warped_img0 = img0
        warped_img1 = img1
        flow = torch.zeros_like(x[:, :4]).to(device)
        mask = torch.zeros_like(x[:, :1]).to(device)
        blocks = (self.block0, self.block1, self.block2)

        for i, block in enumerate(blocks):
            f0, m0 = block(
                torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1),
                flow,
                scale=scale_list[i])
            if ensemble:
                # Second pass with the image order reversed; its outputs are
                # mapped back (flow halves swapped, mask negated) and averaged.
                f1, m1 = block(
                    torch.cat(
                        (warped_img1[:, :3], warped_img0[:, :3], -mask), 1),
                    torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                    scale=scale_list[i])
                f0 = (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
                m0 = (m0 + (-m1)) / 2
            flow = flow + f0  # TODO all dark frame output in 3.8 model
            mask = mask + m0
            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])
            merged.append((warped_img0, warped_img1))

        # NOTE: a contextnet/unet refinement step used to be sketched here as
        # commented-out code; it is not part of this forward pass.

        # Only the final stage is blended; earlier entries stay as tuples.
        blend = torch.sigmoid(mask)
        last_img0, last_img1 = merged[-1]
        merged[-1] = last_img0 * blend + last_img1 * (1 - blend)
        return merged, flow
 def forward(self, img0, img1, flow, c0, c1, flow_gt):
     """Warp both images with the flow and run the encoder/decoder stack.

     Args:
         img0: First image; warped with ``flow[:, :2]``.
         img1: Second image; warped with ``flow[:, 2:4]``.
         flow: Flow tensor; channels 0-1 and 2-3 are used for the two warps.
         c0: Indexable with four entries (0..3), concatenated into the
             encoder inputs alongside ``c1``.
         c1: Same layout as ``c0``.
         flow_gt: Optional second flow. When it is ``None`` the two ``*_gt``
             outputs are ``None``; otherwise the images are also warped
             with it.

     Returns:
         ``(x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt)``
         where ``x`` is the output of ``self.conv``.
     """
     warped_img0 = warp(img0, flow[:, :2])
     warped_img1 = warp(img1, flow[:, 2:4])
     # Identity check: `==` would be dispatched to the tensor's __eq__.
     if flow_gt is None:
         warped_img0_gt = warped_img1_gt = None
     else:
         warped_img0_gt = warp(img0, flow_gt[:, :2])
         warped_img1_gt = warp(img1, flow_gt[:, 2:4])

     # Encoder: each stage after the first also receives one entry of c0/c1.
     x = self.conv0(torch.cat((warped_img0, warped_img1, flow), 1))
     s0 = self.down0(x)
     s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
     s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
     s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))

     # Decoder: each stage after the first is concatenated with the matching
     # encoder output (skip connection).
     x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
     x = self.up1(torch.cat((x, s2), 1))
     x = self.up2(torch.cat((x, s1), 1))
     x = self.up3(torch.cat((x, s0), 1))
     x = self.conv(x)
     return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
Esempio n. 4
0
 def forward(self, x, scale=1.0, ensemble=False, ada=True):
     """Predict a flow with ``block0`` and refine it with ``block1``..``block3``.

     ``x[:, :3]`` and ``x[:, 3:]`` are treated as the two images. ``block0``
     gives the initial flow; each later block receives both images warped
     with the 2x-upsampled running flow (values also multiplied by 2) plus
     that upsampled flow, and its output is added to the running flow.

     Args:
         x: Input tensor; resized by ``scale`` first when ``scale != 1.0``.
         scale: Resize factor for the input. The final flow is resized by
             ``1 / scale`` and divided by ``scale`` when it is not 1.0.
         ensemble: When true, every block is also run with the image order
             swapped and the two predictions (after swapping the flow
             halves back) are averaged.
         ada: Not used by this method.

     Returns:
         ``(F4, [F1, F2, F3, F4])``: the final flow and the list of running
         flows after each block. The list's last entry is the same
         (possibly rescaled) ``F4``.
     """

     def swap_halves(t):
         # Exchange channels 0-1 with channels 2-3.
         return torch.cat((t[:, 2:4], t[:, :2]), 1)

     if scale != 1.0:
         x = F.interpolate(x,
                           scale_factor=scale,
                           mode="bilinear",
                           align_corners=False)
     first, second = x[:, :3], x[:, 3:]

     # Initial estimate.
     running = self.block0(torch.cat((first, second), 1))
     if ensemble:
         mirrored = self.block0(torch.cat((second, first), 1))
         running = (running + swap_halves(mirrored)) / 2
     history = [running]

     # Residual refinements; `running` is rebound, never modified in place,
     # so earlier history entries stay intact.
     for refine in (self.block1, self.block2, self.block3):
         upsampled = F.interpolate(
             running, scale_factor=2.0, mode="bilinear",
             align_corners=False) * 2.0
         warped_first = warp(first, upsampled[:, :2])
         warped_second = warp(second, upsampled[:, 2:4])
         delta = refine(
             torch.cat((warped_first, warped_second, upsampled), 1))
         if ensemble:
             mirrored = refine(
                 torch.cat(
                     (warped_second, warped_first, swap_halves(upsampled)),
                     1))
             delta = (delta + swap_halves(mirrored)) / 2
         running = running + delta
         history.append(running)

     if scale != 1.0:
         running = F.interpolate(running,
                                 scale_factor=1 / scale,
                                 mode="bilinear",
                                 align_corners=False) / scale
         history[-1] = running
     return running, history
Esempio n. 5
0
 def forward(self, x, scale=(4, 2, 1), ensemble=False):
     """Three-stage student flow estimation with optional teacher distillation.

     ``x`` is sliced along dim 1 into ``img0`` (channels 0-2), ``img1``
     (channels 3-5) and ``gt`` (channels 6 onward). The teacher branch and
     the distillation loss are only evaluated when ``gt`` has exactly three
     channels.

     Args:
         x: Input tensor holding both images and, optionally, the target.
         scale: One ``scale`` value per student block, indexed 0..2.
         ensemble: When true, each student block is also run with the image
             order swapped (mask negated, flow halves swapped) and the two
             predictions are averaged.

     Returns:
         ``(flow_list, mask, merged, flow_teacher, merged_teacher,
         loss_distill)`` where ``flow_list`` holds the flow after each
         stage, ``mask`` is the sigmoid mask of the last stage, ``merged``
         holds the blended image of each stage (the last one with the
         refinement residual added and clamped to [0, 1]),
         ``flow_teacher`` / ``merged_teacher`` are ``None`` without a
         target, and ``loss_distill`` stays the integer 0 without a target.
     """
     img0 = x[:, :3]
     img1 = x[:, 3:6]
     # Without a target x has no channels past index 5, so gt is a
     # zero-channel slice (never None); has_gt tells the two cases apart.
     gt = x[:, 6:]
     has_gt = gt.shape[1] == 3
     flow_list = []
     merged = []
     mask_list = []
     warped_img0 = img0
     warped_img1 = img1
     flow = None
     loss_distill = 0
     stu = (self.block0, self.block1, self.block2)
     for i, student in enumerate(stu):
         if flow is not None:
             # Later stages: predict residual updates from the images,
             # their current warps and the current mask.
             flow_d0, mask_d0 = student(
                 torch.cat((img0, img1, warped_img0, warped_img1, mask), 1),
                 flow,
                 scale=scale[i])
             if ensemble:
                 flow_d1, mask_d1 = student(
                     torch.cat(
                         (img1, img0, warped_img1, warped_img0, -mask), 1),
                     torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                     scale=scale[i])
                 flow_d0 = (flow_d0 + torch.cat(
                     (flow_d1[:, 2:4], flow_d1[:, :2]), 1)) / 2
                 mask_d0 = (mask_d0 + (-mask_d1)) / 2
             flow = flow + flow_d0
             mask = mask + mask_d0
         else:
             # First stage: no prior flow or mask exists yet.
             flow, mask = student(torch.cat((img0, img1), 1),
                                  None,
                                  scale=scale[i])
             if ensemble:
                 flow1, mask1 = student(torch.cat((img1, img0), 1),
                                        None,
                                        scale=scale[i])
                 flow = (flow + torch.cat(
                     (flow1[:, 2:4], flow1[:, :2]), 1)) / 2
                 mask = (mask + (-mask1)) / 2
         mask_list.append(torch.sigmoid(mask))
         flow_list.append(flow)
         warped_img0 = warp(img0, flow[:, :2])
         warped_img1 = warp(img1, flow[:, 2:4])
         merged.append((warped_img0, warped_img1))

     if has_gt:
         # The teacher block additionally receives the target.
         flow_d, mask_d = self.block_tea(
             torch.cat((img0, img1, warped_img0, warped_img1, mask, gt), 1),
             flow,
             scale=1)
         flow_teacher = flow + flow_d
         warped_img0_teacher = warp(img0, flow_teacher[:, :2])
         warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
         mask_teacher = torch.sigmoid(mask + mask_d)
         merged_teacher = (warped_img0_teacher * mask_teacher +
                           warped_img1_teacher * (1 - mask_teacher))
     else:
         flow_teacher = None
         merged_teacher = None

     for i in range(3):
         merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (
             1 - mask_list[i])
         if has_gt:
             # Distil only where the student's channel-mean absolute error
             # exceeds the teacher's by more than 0.01.
             loss_mask = ((merged[i] - gt).abs().mean(1, True) >
                          (merged_teacher - gt).abs().mean(1, True) +
                          0.01).float().detach()
             loss_distill += ((flow_teacher.detach() - flow_list[i]).abs() *
                              loss_mask).mean()

     # Refinement: the first three channels of the unet output, mapped by
     # `* 2 - 1`, are added as a residual to the last merged image.
     c0 = self.contextnet(img0, flow[:, :2])
     c1 = self.contextnet(img1, flow[:, 2:4])
     tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0,
                     c1)
     res = tmp[:, :3] * 2 - 1
     merged[2] = torch.clamp(merged[2] + res, 0, 1)
     return flow_list, mask_list[
         2], merged, flow_teacher, merged_teacher, loss_distill