def forward_3_frames(self, x0_pyramid, x1_pyramid, x2_pyramid):
    """Coarse-to-fine flow estimation for a frame triplet around frame 1.

    A single 4-channel flow is refined level by level: channels 0:2 hold
    the 1->0 flow and channels 2:4 hold the 1->2 flow. Each level refines
    both directions twice, first with ``self.flow_estimators`` and then
    with ``self.context_networks``. The two directions share the same
    networks; each one also receives the other direction's flow, negated.

    :param x0_pyramid: per-level features of frame 0. Level 0 is the
        coarsest, since the flow is upsampled x2 between levels.
    :param x1_pyramid: per-level features of frame 1, the reference frame
        (level 0 is unpacked as [B, C, h, w]).
    :param x2_pyramid: per-level features of frame 2.
    :return: ``(flows_10, flows_12)``, two lists of 2-channel flows ordered
        from the last processed (finest) level back to level 0.
    """
    ref = x1_pyramid[0]
    batch, _, height, width = ref.size()
    flow = torch.zeros(batch, 4, height, width, dtype=ref.dtype,
                       device=ref.device).float()
    collected = []

    for level, (feat0, feat1, feat2) in enumerate(
            zip(x0_pyramid, x1_pyramid, x2_pyramid)):
        if level > 0:
            # Resolution doubles, so pixel displacements double as well.
            flow = F.interpolate(flow * 2, scale_factor=2,
                                 mode='bilinear', align_corners=True)
            warped0 = flow_warp(feat0, flow[:, :2])
            warped2 = flow_warp(feat2, flow[:, 2:])
        else:
            # No estimate exists yet at the first level.
            warped0, warped2 = feat0, feat2

        cost10 = self.leakyRELU(self.corr(feat1, warped0))
        cost12 = self.leakyRELU(self.corr(feat1, warped2))

        # Stage 1: residual flow from correlation + compressed features.
        x1_compact = self.conv_1x1[level](feat1)
        flow10, flow12 = flow[:, :2], flow[:, 2:]
        inter10, res10 = self.flow_estimators(
            torch.cat([x1_compact, cost10, cost12, flow10, -flow12], dim=1))
        inter12, res12 = self.flow_estimators(
            torch.cat([x1_compact, cost12, cost10, flow12, -flow10], dim=1))
        flow = flow + torch.cat([res10, res12], dim=1)

        # Stage 2: context refinement from the estimators' intermediate
        # features and the updated flows.
        flow10, flow12 = flow[:, :2], flow[:, 2:]
        ctx10 = self.context_networks(
            torch.cat([inter10, inter12, flow10, -flow12], dim=1))
        ctx12 = self.context_networks(
            torch.cat([inter12, inter10, flow12, -flow10], dim=1))
        flow = flow + torch.cat([ctx10, ctx12], dim=1)

        collected.append(flow)
        if level == self.output_level:
            break

    if self.upsample:
        collected = [F.interpolate(f * 4, scale_factor=4, mode='bilinear',
                                   align_corners=True) for f in collected]

    finest_first = collected[::-1]
    return [f[:, :2] for f in finest_first], [f[:, 2:] for f in finest_first]
Beispiel #2
0
    def _forward(self, x1_pyramid, x2_pyramid, neg=False):
        """Coarse-to-fine disparity estimation over a feature pyramid.

        After every update the estimate is sign-constrained: it is clamped
        to ``<= 0`` when ``neg`` is True and to ``>= 0`` otherwise. The
        stored outputs are negated back when ``neg`` is True, so they are
        non-negative either way.

        :param x1_pyramid: per-level features of the reference image
            (level 0 first; later levels are 2x larger).
        :param x2_pyramid: per-level features of the image that is warped.
        :param neg: whether the raw disparity is expected to be negative.
        :return: at most ``self.n_out`` disparity maps, each upsampled x4
            (values x4), ordered from the last processed (finest) level to
            level 0. The maps are presumably single-channel, since they are
            padded with a zero channel to form a 2-channel flow for warping
            (TODO confirm).
        """
        def _sign_clamp(d):
            # Keep the estimate on the expected side of zero.
            return -F.relu(-d) if neg else F.relu(d)

        outputs = []
        for level, (feat1, feat2) in enumerate(zip(x1_pyramid, x2_pyramid)):
            if level == 0:
                feat, disp = self.flow_estimators[level](self.corr(feat1, feat2))
                disp = _sign_clamp(disp)
            else:
                # The disparity is normalized (consistent with MonoDepth so
                # its hyper-parameters can be reused), so it is not scaled
                # by 2 when upsampled.
                up_disp = F.interpolate(disp, scale_factor=2,
                                        mode='bilinear', align_corners=True)
                # Build a flow whose second channel is zero.
                as_flow = torch.cat([up_disp, torch.zeros_like(up_disp)],
                                    dim=1)
                cost = self.corr(feat1, flow_warp(feat2, as_flow))
                F.leaky_relu_(cost)

                feat, disp = self.flow_estimators[level](
                    torch.cat([cost, feat1, up_disp], dim=1))
                disp = _sign_clamp(disp + up_disp)

                refine = self.context_networks[level]
                if refine:
                    delta = refine(torch.cat([disp, feat], dim=1))
                    disp = _sign_clamp(disp + delta)

            outputs.append(-disp if neg else disp)
            if len(outputs) == self.n_out:
                break

        upsampled = [F.interpolate(d * 4, scale_factor=4, mode='bilinear',
                                   align_corners=True) for d in outputs]
        return upsampled[::-1]
Beispiel #3
0
def get_occu_mask_bidirection(flow12, flow21, scale=0.1, bias=0.5):
    """Occlusion mask from a forward-backward flow consistency check.

    ``flow21`` is warped into frame 1 with ``flow12`` (zero padding). A pixel
    is flagged as occluded when

        |flow12 + flow21_warped|^2 > scale * (|flow12|^2 + |flow21_warped|^2) + bias

    where the squared norms are summed over dim 1 (the flow channels).

    NOTE(review): UnFlow (Meister et al., 2018) uses scale=0.01 and
    bias=0.5 for this check; the default scale here is 0.1. Confirm this
    is intentional.

    :param flow12: flow from frame 1 to frame 2 (presumably [B, 2, H, W]).
    :param flow21: flow from frame 2 to frame 1, same shape as ``flow12``.
    :return: float tensor (1 channel, kept via keepdim) with 1.0 where the
        check fails (occluded) and 0.0 elsewhere.
    """
    def _sq_norm(t):
        # Squared magnitude over the channel dimension, keeping that dim.
        return (t * t).sum(1, keepdim=True)

    back_flow = flow_warp(flow21, flow12, pad='zeros')
    residual = _sq_norm(flow12 + back_flow)
    threshold = scale * (_sq_norm(flow12) + _sq_norm(back_flow)) + bias
    return (residual > threshold).float()
Beispiel #4
0
    def forward_2_frames(self, x1_pyramid, x2_pyramid):
        """Coarse-to-fine two-frame flow estimation (PWC-Net style).

        At each level, frame-2 features are warped with the upsampled flow
        and correlated with frame-1 features. The flow is then refined by
        ``self.flow_estimators`` and ``self.context_networks``.

        :param x1_pyramid: per-level features of frame 1. Level 0 is the
            coarsest, since the flow is upsampled x2 between levels; level 0
            is unpacked as [B, C, h, w].
        :param x2_pyramid: per-level features of frame 2.
        :return: list of 2-channel flows, from the last processed (finest)
            level back to level 0. When ``self.upsample`` is set, each flow
            is upsampled x4 (values x4).
        """
        first = x1_pyramid[0]
        n, _, rows, cols = first.size()
        flow = torch.zeros(n, 2, rows, cols, dtype=first.dtype,
                           device=first.device).float()
        estimates = []

        for level, (feat1, feat2) in enumerate(zip(x1_pyramid, x2_pyramid)):
            if level:
                # Resolution doubles, so pixel displacements double as well.
                flow = F.interpolate(flow * 2, scale_factor=2,
                                     mode='bilinear', align_corners=True)
                target = flow_warp(feat2, flow)
            else:
                target = feat2

            cost = self.leakyRELU(self.corr(feat1, target))
            inter, delta = self.flow_estimators(
                torch.cat([cost, self.conv_1x1[level](feat1), flow], dim=1))
            flow = flow + delta
            flow = flow + self.context_networks(
                torch.cat([inter, flow], dim=1))
            estimates.append(flow)

            if level == self.output_level:
                break

        if self.upsample:
            estimates = [F.interpolate(e * 4, scale_factor=4, mode='bilinear',
                                       align_corners=True) for e in estimates]
        return list(reversed(estimates))
Beispiel #5
0
    def _forward(self, x1_pyramid, x2_pyramid):
        """Coarse-to-fine flow estimation with per-level estimator networks.

        Level 0 estimates flow directly from correlation. Each later level
        refines the x2-upsampled estimate, using its own
        ``self.flow_estimators[level]`` and, if present,
        ``self.context_networks[level]``.

        :param x1_pyramid: per-level features of frame 1 (level 0 first;
            later levels are 2x larger).
        :param x2_pyramid: per-level features of frame 2.
        :return: at most ``self.n_out`` flows, each upsampled x4 (values x4),
            ordered from the last processed (finest) level to level 0.
        """
        results = []
        for level, (feat1, feat2) in enumerate(zip(x1_pyramid, x2_pyramid)):
            if level == 0:
                feat, flow = self.flow_estimators[level](
                    self.corr(feat1, feat2))
            else:
                # Resolution doubles, so pixel displacements double as well.
                prior = F.interpolate(flow * 2, scale_factor=2,
                                      mode='bilinear', align_corners=True)
                cost = self.corr(feat1, flow_warp(feat2, prior))
                F.leaky_relu_(cost)

                feat, flow = self.flow_estimators[level](
                    torch.cat([cost, feat1, prior], dim=1))
                flow = flow + prior

                refine = self.context_networks[level]
                if refine:
                    flow = flow + refine(torch.cat([flow, feat], dim=1))

            results.append(flow)
            if len(results) == self.n_out:
                break

        upsampled = [F.interpolate(f * 4, scale_factor=4, mode='bilinear',
                                   align_corners=True) for f in results]
        return upsampled[::-1]
Beispiel #6
0
    def forward(self, pyramid_disp, fl_bl, pyramid_K, pyramid_K_inv, raw_W,
                pyramid_flow, images):
        """Multi-scale unsupervised loss combining optical flow and rigid flow.

        The camera pose is estimated once, at pyramid level 0, from depth and
        forward flow. Each evaluated level then:
        - builds a rigid flow from depth and pose;
        - builds a rigid mask of pixels where the rigid reconstruction is not
          much worse than the optical-flow one;
        - accumulates photometric, smoothness, flow/rigid consistency and
          rigid photometric terms.

        :param pyramid_disp: multi-scale disparities n * [B x h x w]. They are
            presumably normalized by image width, since they are multiplied
            by ``raw_W`` below.
        :param fl_bl: focal length * baseline [B]
        :param pyramid_K: multi-scale intrinsics n * [B, 3, 3]
        :param pyramid_K_inv: multi-scale inverse intrinsics n * [B, 3, 3]
        :param raw_W: original width of the images [B]
        :param pyramid_flow: multi-scale forward/backward flows
            n * [B x 4 x h x w] (forward in channels 0:2)
        :param images: image pairs [B x 6 x H x W] (frame 1 in channels 0:3,
            frame 2 in channels 3:6)
        :return: tuple ``(total, photometric, rigid photometric,
            1000 * smooth, consistency, mean rigid-mask ratio, mean PnP
            inlier ratio)``. The x1000 applies only to the reported smooth
            term, not to ``total``.
        :raises ValueError: if no pyramid level is evaluated (e.g.
            ``cfg.valid_s <= 0``).
        """
        B = images.size(0)
        im1_origin = images[:, :3]
        im2_origin = images[:, 3:]

        pyramid_l_photometric = []
        pyramid_l_smooth = []
        pyramid_l_consistency = []
        pyramid_l_photometric_rigid = []
        pyramid_rigid_mask = []

        # Smoothness normalizer. Defaults to 1 so it is always bound; with
        # cfg.norm_smooth it becomes min(h, w) of level 0.
        s = 1.

        for i, (disp, flow, K, K_inv, md) in enumerate(
                zip(pyramid_disp, pyramid_flow, pyramid_K, pyramid_K_inv,
                    self.cfg.pyramid_md)):
            # Only the first cfg.valid_s scales contribute to the loss.
            if i >= self.cfg.valid_s:
                break
            _, _, h, w = flow.size()

            if i == 0 and self.cfg.norm_smooth:
                s = min(h, w)

            # Resize the disparity to this level and scale it by the raw width.
            disp = F.interpolate(
                disp.unsqueeze(1), (h, w), mode='bilinear',
                align_corners=True).squeeze(1) * raw_W.reshape(-1, 1, 1)

            # depth = focal * baseline / disparity; the clamp avoids
            # dividing by (near) zero.
            depth = fl_bl.reshape(-1, 1, 1) / disp.clamp(min=1e-3)  # [B, h, w]

            # Use the largest depth and flow (level 0) to estimate the pose.
            if i == 0:
                pose_mat, _, inlier_ratio = depth_flow2pose_pt(
                    depth,
                    flow[:, :2],
                    K,
                    K_inv,
                    gs=16,
                    th=2.,
                    method=self.cfg.PnP_method)

            rigid_flow = depth_pose2flow_pt(depth, pose_mat, K, K_inv)

            # Resize the images to match the size of this level.
            im1_scaled = F.interpolate(im1_origin, (h, w), mode='area')
            im2_scaled = F.interpolate(im2_origin, (h, w), mode='area')

            im1_recons, occu_mask1 = flow_warp(im2_scaled, flow[:, :2],
                                               flow[:, 2:])

            im1_recons_rigid = flow_warp(im2_scaled, rigid_flow)

            # Pixels whose optical flow is close to the rigid flow; the
            # threshold halves at each coarser level.
            th_mask = EPE(flow[:, :2], rigid_flow) < self.cfg.mask_th / 2**i

            flow_e = F.pad(SSIM(im1_scaled, im1_recons, md=md),
                           [md] * 4).mean(1, keepdim=True)  # [B, 1, h, w]
            rigid_e = F.pad(SSIM(im1_scaled, im1_recons_rigid, md=md),
                            [md] * 4).mean(1, keepdim=True)

            # Larger where the rigid reconstruction error exceeds the
            # optical-flow one. This assumes SSIM here returns a
            # dissimilarity, as the *_e names suggest (TODO confirm).
            dist_e = rigid_e - flow_e
            dist_e = gaussianblur_pt(dist_e, (11, 11), 5)

            delta = percentile_pt(dist_e,
                                  th=self.cfg.recons_p).reshape(-1, 1, 1, 1)
            rigid_mask = dist_e < delta  # [B, 1, h, w]
            rigid_mask = rigid_mask & th_mask

            # Mask out the region where depth estimation failed.
            rigid_mask = rigid_mask & (depth.unsqueeze(1) < 80)

            # A failed pose estimation makes the whole rigid_mask false.
            valid_poses = (inlier_ratio > 0.2).type_as(rigid_mask)
            rigid_mask = rigid_mask & valid_poses.reshape(-1, 1, 1, 1)

            rigid_mask = rigid_mask.float()

            # In occluded regions, optionally force rigid_mask to 1
            # (the original TF implementation does this).
            if self.cfg.mask_with_occu:
                rigid_mask = (rigid_mask + (occu_mask1 < 0.2).float()).clamp(
                    0., 1.)

            if self.cfg.smooth_mask_by == 'th':
                sm_mask = 1 - (th_mask & (depth.unsqueeze(1) < 80)).float()
            else:
                sm_mask = 1 - rigid_mask  # same as the paper

            l_photometric = self.loss_photomatric(im1_scaled, im1_recons,
                                                  occu_mask1)

            l_smooth = self.loss_smooth(flow[:, :2] / s, im1_scaled, sm_mask)

            l_consistency = self.loss_consistancy(flow[:, :2],
                                                  rigid_flow.detach(),
                                                  rigid_mask)

            # NOTE(review): no occlusion mask is applied to the rigid term.
            l_photometric_rigid = self.loss_photomatric(
                im1_scaled, im1_recons_rigid, rigid_mask)

            pyramid_l_photometric.append(l_photometric * self.cfg.w_scales[i])
            pyramid_l_smooth.append(l_smooth * self.cfg.w_sm_scales[i])
            pyramid_l_consistency.append(l_consistency *
                                         self.cfg.w_cons_scales[i])
            pyramid_l_photometric_rigid.append(l_photometric_rigid *
                                               self.cfg.w_rigid_scales[i])

            # Fraction of rigid pixels among samples with a valid pose.
            pyramid_rigid_mask.append(rigid_mask.mean() * B /
                                      (valid_poses.sum() + 1e-6))

        if not pyramid_rigid_mask:
            # pose_mat / inlier_ratio would be unbound below.
            raise ValueError('no pyramid level was evaluated; '
                             'check cfg.valid_s and the input pyramids')

        w_l_photometric = sum(pyramid_l_photometric)
        w_l_photometric_rigid = sum(pyramid_l_photometric_rigid)
        w_l_smooth = sum(pyramid_l_smooth)
        w_l_consistency = sum(pyramid_l_consistency)

        final_loss = w_l_photometric + \
                     self.cfg.w_rigid_warp * w_l_photometric_rigid + \
                     self.cfg.w_smooth * w_l_smooth + \
                     self.cfg.w_cons * w_l_consistency

        # Average the rigid-mask ratio over the levels actually evaluated
        # (at most cfg.valid_s), not over all provided levels.
        return final_loss, w_l_photometric, w_l_photometric_rigid, \
               1000 * w_l_smooth, w_l_consistency, \
               sum(pyramid_rigid_mask) / len(pyramid_rigid_mask), \
               inlier_ratio.mean()
Beispiel #7
0
    # Estimate the forward flow between the input images and save a colour
    # visualisation of it.
    imgs = [imageio.imread(img).astype(np.float32) for img in args.img_list]
    h, w = imgs[0].shape[:2]

    flow_12 = ts.run(imgs)['flows_fw'][0]

    flow_12 = resize_flow(flow_12, (h, w))
    np_flow_12 = flow_12[0].detach().cpu().numpy().transpose([1, 2, 0])

    vis_flow = flowpy.flow_to_rgb(np_flow_12)

    cv2.imwrite(t0[0] + " " + t1[0] + ".png", vis_flow)

    im1 = cv2.imread("/content/drive/MyDrive/data1/NATL_AN_2007-01-03.png")
    # No trailing space in the path: cv2.imread silently returns None for a
    # file it cannot find.
    im2 = cv2.imread("/content/drive/MyDrive/data1/NATL_AN_2007-01-04.png")

    # flow_warp is called with [B, C, H, W] tensors elsewhere in this file,
    # so convert the HxWxC numpy image and move it to the flow's device.
    # NOTE(review): assumes im2 has the same size as imgs[0]; confirm.
    im2_tensor = torch.from_numpy(
        np.transpose(im2, [2, 0, 1]).astype(np.float32)).unsqueeze(0).to(
            flow_12.device)
    warped = flow_warp(im2_tensor, flow_12, pad='border', mode='bilinear')
    warped = warped[0].detach().cpu().numpy().transpose([1, 2, 0])
    # cv2.imwrite stores 8-bit data for PNG.
    warped = np.clip(warped, 0, 255).astype(np.uint8)

    cv2.imwrite("warped" + t0[0] + " " + t1[0] + ".png", warped)
    def PSNR(original, compressed, max_pixel=255.0):
        """Peak signal-to-noise ratio (dB) between two images.

        Both inputs are promoted to float64 before subtracting. With uint8
        arrays (as returned by cv2.imread), ``original - compressed`` would
        otherwise wrap around modulo 256 and give a wrong MSE.

        :param original: reference image (array-like).
        :param compressed: image to compare, same shape as ``original``.
        :param max_pixel: maximum possible pixel value (255 for 8-bit).
        :return: PSNR in dB, or 100 when the images are identical. In that
            case the MSE is zero and the PSNR would be unbounded.
        """
        diff = (np.asarray(original, dtype=np.float64)
                - np.asarray(compressed, dtype=np.float64))
        mse = np.mean(diff ** 2)
        if mse == 0:
            return 100
        return 20 * np.log10(max_pixel / np.sqrt(mse))
  # 5. Compute the Structural Similarity Index (SSIM) between the two
  #    images, ensuring that the difference image is returned

Beispiel #8
0
    def forward(self, output, target):
        """Multi-scale unsupervised photometric + smoothness loss.

        The occlusion masks are computed once, at the first level with a
        non-zero weight (the reference level). Every later level reuses them,
        resized with nearest-neighbour interpolation. With
        ``cfg.occ_from_back`` they come from ``get_occu_mask_backward``;
        otherwise they come from the forward-backward consistency check.

        :param output: multi-scale forward/backward flows n * [B x 4 x h x w]
            (forward flow in channels 0:2, backward flow in channels 2:4)
        :param target: image pairs [B x 6 x H x W]
        :return: ``(total_loss, warp_loss, smooth_loss, mean |flow|)``. The
            last item is computed on level 0.
        """
        pyramid_flows = output
        im1_origin = target[:, :3]
        im2_origin = target[:, 3:]

        pyramid_smooth_losses = []
        pyramid_warp_losses = []
        self.pyramid_occu_mask1 = []
        self.pyramid_occu_mask2 = []

        s = 1.
        for i, flow in enumerate(pyramid_flows):
            if self.cfg.w_scales[i] == 0:
                # Keep list positions aligned with cfg.w_scales / w_sm_scales.
                pyramid_warp_losses.append(0)
                pyramid_smooth_losses.append(0)
                continue

            _, _, h, w = flow.size()
            # The first *evaluated* level is the reference. Keying this on
            # i == 0 left pyramid_occu_mask1 empty (IndexError below) when
            # cfg.w_scales[0] == 0.
            is_reference = not self.pyramid_occu_mask1

            # Resize the images to match the size of this level.
            im1_scaled = F.interpolate(im1_origin, (h, w), mode='area')
            im2_scaled = F.interpolate(im2_origin, (h, w), mode='area')

            im1_recons = flow_warp(im2_scaled,
                                   flow[:, :2],
                                   pad=self.cfg.warp_pad)
            im2_recons = flow_warp(im1_scaled,
                                   flow[:, 2:],
                                   pad=self.cfg.warp_pad)

            if is_reference:
                # The masks hold 1 for visible pixels (1 - occlusion).
                if self.cfg.occ_from_back:
                    occu_mask1 = 1 - get_occu_mask_backward(flow[:, 2:],
                                                            th=0.2)
                    occu_mask2 = 1 - get_occu_mask_backward(flow[:, :2],
                                                            th=0.2)
                else:
                    occu_mask1 = 1 - get_occu_mask_bidirection(
                        flow[:, :2], flow[:, 2:])
                    occu_mask2 = 1 - get_occu_mask_bidirection(
                        flow[:, 2:], flow[:, :2])
            else:
                occu_mask1 = F.interpolate(self.pyramid_occu_mask1[0], (h, w),
                                           mode='nearest')
                occu_mask2 = F.interpolate(self.pyramid_occu_mask2[0], (h, w),
                                           mode='nearest')

            self.pyramid_occu_mask1.append(occu_mask1)
            self.pyramid_occu_mask2.append(occu_mask2)

            loss_warp = self.loss_photomatric(im1_scaled, im1_recons,
                                              occu_mask1)

            # Normalize flow magnitudes for the smoothness term by the
            # reference level's smaller side.
            if is_reference:
                s = min(h, w)

            loss_smooth = self.loss_smooth(flow[:, :2] / s, im1_scaled)

            if self.cfg.with_bk:
                # Average the forward and backward terms.
                loss_warp = (loss_warp + self.loss_photomatric(
                    im2_scaled, im2_recons, occu_mask2)) / 2.
                loss_smooth = (loss_smooth + self.loss_smooth(
                    flow[:, 2:] / s, im2_scaled)) / 2.

            pyramid_warp_losses.append(loss_warp)
            pyramid_smooth_losses.append(loss_smooth)

        pyramid_warp_losses = [
            l * w for l, w in zip(pyramid_warp_losses, self.cfg.w_scales)
        ]
        pyramid_smooth_losses = [
            l * w for l, w in zip(pyramid_smooth_losses, self.cfg.w_sm_scales)
        ]

        warp_loss = sum(pyramid_warp_losses)
        smooth_loss = self.cfg.w_smooth * sum(pyramid_smooth_losses)
        total_loss = warp_loss + smooth_loss

        return total_loss, warp_loss, smooth_loss, pyramid_flows[0].abs().mean(
        )
Beispiel #9
0
        h, w = imgs[0].shape[:2]

        res_dict = ts.run(imgs)
        flow_12 = res_dict['flows_fw'][0]
        flow_21 = res_dict['flows_bw'][0]

        flow_12 = resize_flow(flow_12, (h, w))  # [1, 2, H, W]
        flow_21 = resize_flow(flow_21, (h, w))  # [1, 2, H, W]
        occu_mask1 = 1 - get_occu_mask_bidirection(flow_12,
                                                   flow_21)  # [1, 1, H, W]
        occu_mask2 = 1 - get_occu_mask_bidirection(flow_21, flow_12)
        # Backward-check masks: frame 1 is checked with the backward flow and
        # frame 2 with the forward flow. This mirrors the occ_from_back branch
        # of the training loss (occu_mask1 <- flow[:, 2:],
        # occu_mask2 <- flow[:, :2]).
        back_occu_mask1 = get_occu_mask_backward(flow_21)
        back_occu_mask2 = get_occu_mask_backward(flow_12)

        # Move the HxWxC images to the flows' device as 1xCxHxW tensors,
        # instead of assuming CUDA.
        device = flow_12.device
        warped_image_12 = flow_warp(torch.from_numpy(
            np.transpose(imgs[1], [2, 0, 1])).unsqueeze(0).to(device),
                                    flow_12,
                                    pad='border')
        warped_image_21 = flow_warp(torch.from_numpy(
            np.transpose(imgs[0], [2, 0, 1])).unsqueeze(0).to(device),
                                    flow_21,
                                    pad='border')
        np_warped_image12 = warped_image_12[0].detach().cpu().numpy(
        ).transpose([1, 2, 0])
        np_warped_image21 = warped_image_21[0].detach().cpu().numpy(
        ).transpose([1, 2, 0])

        np_flow_12 = flow_12[0].detach().cpu().numpy().transpose([1, 2, 0])
        np_flow_21 = flow_21[0].detach().cpu().numpy().transpose([1, 2, 0])
        # vx = np_flow_12[:, :, 0]
        # vy = np_flow_12[:, :, 1]
        # f = open(os.path.join(r'G:\ARFlow-master\data\flow_dataset\ceshi_tmp', name + '_vx.bin'), 'wb')
Beispiel #10
0
    def forward(self, output, target):
        """Multi-scale photometric + smoothness loss with soft occlusion masks.

        At every level, ``flow_warp`` (called with both flow directions)
        returns the reconstruction and an occlusion mask. The raw masks are
        stored in ``self.pyramid_occu_mask1/2``. When ``cfg.hard_occu`` is
        set, they are binarized with ``cfg.hard_occu_th`` before weighting
        the losses.

        :param output: multi-scale forward/backward flows n * [B x 4 x h x w]
        :param target: image pairs Nx6xHxW
        :return: ``(total_loss, warp_loss, w_smooth * smooth_loss,
            mean |flow| of level 0)``
        """
        cfg = self.cfg
        frame1 = target[:, :3]
        frame2 = target[:, 3:]

        warp_terms = []
        smooth_terms = []
        self.pyramid_occu_mask1 = []
        self.pyramid_occu_mask2 = []

        norm = 1.
        for level, flow in enumerate(output):
            _, _, h, w = flow.size()
            fwd, bwd = flow[:, :2], flow[:, 2:]

            # Resize the images to match the size of this level.
            im1_scaled = F.interpolate(frame1, (h, w), mode='area')
            im2_scaled = F.interpolate(frame2, (h, w), mode='area')

            im1_recons, occ1 = flow_warp(im2_scaled, fwd, bwd)
            im2_recons, occ2 = flow_warp(im1_scaled, bwd, fwd)

            # Store the soft masks before any thresholding.
            self.pyramid_occu_mask1.append(occ1)
            self.pyramid_occu_mask2.append(occ2)

            if cfg.hard_occu:
                occ1 = (occ1 > cfg.hard_occu_th).float()
                occ2 = (occ2 > cfg.hard_occu_th).float()

            photo = self.loss_photomatric(im1_scaled, im1_recons, occ1)

            if level == 0 and cfg.norm_smooth:
                norm = min(h, w)

            smooth = self.loss_smooth(fwd / norm, im1_scaled,
                                      occ1 if cfg.s_mask else None)

            if cfg.with_bk:
                # Average the forward and backward terms.
                photo = photo + self.loss_photomatric(
                    im2_scaled, im2_recons, occ2)
                smooth = smooth + self.loss_smooth(
                    bwd / norm, im2_scaled, occ2 if cfg.s_mask else None)
                photo = photo / 2.
                smooth = smooth / 2.

            warp_terms.append(photo)
            smooth_terms.append(smooth)

        warp_terms = [t * wt for t, wt in zip(warp_terms, cfg.w_scales)]
        smooth_terms = [t * wt for t, wt in zip(smooth_terms, cfg.w_sm_scales)]

        warp_loss = sum(warp_terms)
        smooth_loss = cfg.w_smooth * sum(smooth_terms)
        return warp_loss + smooth_loss, warp_loss, smooth_loss, \
               output[0].abs().mean()
Beispiel #11
0
    def _validate_with_gt2(self):
        """Debug validation on KITTI flow that also dumps occlusion-mask plots.

        For each validation batch this:
        - evaluates the predicted forward flow with ``evaluate_kitti_flow``;
        - saves a plot under ``./tmp/occu_soft_hard/`` showing where the soft
          mask from ``flow_warp`` is below 0.2, next to the ground-truth
          occlusion (noc mask minus occ mask, first channel).

        :return: ``(error_meters.avg, error_names)`` with error names
            EPE, E_noc, E_occ, F1_all.
        """
        from utils.warp_utils import flow_warp
        from utils.misc_utils import plot_imgs

        batch_time = AverageMeter()

        error_names = ['EPE', 'E_noc', 'E_occ', 'F1_all']
        error_meters = AverageMeter(i=len(error_names))

        self.model.eval()
        self.model = self.model.float()
        end = time.time()
        for i_step, data in enumerate(self.valid_loader):
            img1, img2 = data['img1'], data['img2']
            img_pair = torch.cat([img1, img2], 1).to(self.device)

            # Inference only: nothing here calls backward, so do not build an
            # autograd graph (saves memory and time).
            with torch.no_grad():
                flow = self.model(img_pair, with_bk=True)[0]
                im1_origin = img_pair[:, :3]
                _, occu_mask1 = flow_warp(im1_origin, flow[:, :2],
                                          flow[:, 2:])

            res = list(map(load_flow, data['flow_occ']))
            gt_flows, occ_masks = [r[0] for r in res], [r[1] for r in res]
            res = list(map(load_flow, data['flow_noc']))
            noc_masks = [r[1] for r in res]

            gt_flows = [np.concatenate([gt_flow, occ_mask, noc_mask], axis=2)
                        for gt_flow, occ_mask, noc_mask
                        in zip(gt_flows, occ_masks, noc_masks)]
            pred_flows = flow[:, :2].detach().cpu().numpy().transpose(
                [0, 2, 3, 1])
            es = evaluate_kitti_flow(gt_flows, pred_flows)
            error_meters.update([l.item() for l in es], img_pair.size(0))

            plot_list = []
            occu_mask1 = (occu_mask1 < 0.2).detach().cpu().numpy()[0, 0] * 255
            plot_list.append({'im': occu_mask1, 'title': 'occu mask 1'})

            gt_occu_mask1 = (noc_masks[0] - occ_masks[0])[:, :, 0].astype(
                np.float32) * 255
            plot_list.append({'im': gt_occu_mask1, 'title': 'gt occu mask 1'})
            plot_imgs(plot_list,
                      save_path='./tmp/occu_soft_hard/occu_hard_{:03d}.jpg'.format(
                          i_step))

            # Measure elapsed time.
            batch_time.update(time.time() - end)
            end = time.time()

            if i_step % self.cfg.print_freq == 0:
                self._log.info('Test: [{0}/{1}]\t Time {2}\t '.format(
                    i_step, self.cfg.valid_size, batch_time) + ' '.join(
                    map('{:.2f}'.format, error_meters.avg)))

            # NOTE(review): this processes batches 0..valid_size+1 before
            # stopping; kept as-is.
            if i_step > self.cfg.valid_size:
                break

        # Write the errors to TensorBoard.
        for value, name in zip(error_meters.avg, error_names):
            self.summary_writer.add_scalar('Valid_' + name, value, self.i_epoch)

        # To save space while debugging, only save the model after more
        # than cfg.save_iter iterations.
        if self.i_iter > self.cfg.save_iter:
            self.save_model(error_meters.avg[0], 'KITTI_flow')

        return error_meters.avg, error_names
    def forward_2_frames(self, x1_pyramid, x2_pyramid):
        """Coarse-to-fine two-frame flow estimation (PWC-Net style).

        At each level, frame-2 features are warped with the upsampled flow
        and correlated with frame-1 features. The flow is then refined by
        ``self.flow_estimators`` and ``self.context_networks``.

        Author's debug trace (batch 2, output_level 4), kept for reference
        and unverified here:
        - feature shapes per level were [2, 192, 6, 13], [2, 128, 12, 26],
          [2, 96, 24, 52], [2, 64, 48, 104] and [2, 32, 96, 208];
        - ``self.corr`` produced 81 channels;
        - ``conv_1x1`` compresses features to 32 channels.

        :param x1_pyramid: per-level features of frame 1. Level 0 is the
            coarsest, since the flow is upsampled x2 between levels; level 0
            is unpacked as [B, C, h, w].
        :param x2_pyramid: per-level features of frame 2.
        :return: list of 2-channel flows, from the last processed (finest)
            level back to level 0. When ``self.upsample`` is set, each flow
            is upsampled x4 (values x4).
        """
        first = x1_pyramid[0]
        n, _, rows, cols = first.size()
        flow = torch.zeros(n, 2, rows, cols, dtype=first.dtype,
                           device=first.device).float()
        estimates = []

        for level, (feat1, feat2) in enumerate(zip(x1_pyramid, x2_pyramid)):
            if level:
                # Resolution doubles, so pixel displacements double as well.
                flow = F.interpolate(flow * 2, scale_factor=2,
                                     mode='bilinear', align_corners=True)
                target = flow_warp(feat2, flow)
            else:
                target = feat2

            # Compare frame-1 features with the (warped) frame-2 features.
            cost = self.leakyRELU(self.corr(feat1, target))
            inter, delta = self.flow_estimators(
                torch.cat([cost, self.conv_1x1[level](feat1), flow], dim=1))
            flow = flow + delta
            flow = flow + self.context_networks(
                torch.cat([inter, flow], dim=1))
            estimates.append(flow)

            if level == self.output_level:
                break

        if self.upsample:
            estimates = [F.interpolate(e * 4, scale_factor=4, mode='bilinear',
                                       align_corners=True) for e in estimates]
        return list(reversed(estimates))