예제 #1
0
 def _forward_single_image(self, left_prediction: BoxList,
                           right_prediction: BoxList) -> DisparityMap:
     """Fuse per-ROI disparity predictions into one full-image disparity map.

     Each ROI's disparity patch is resized to the wider of the matched
     left/right boxes, cropped back to the left box, shifted from box-local
     to full-image disparity, and pasted into an image-sized canvas.
     Overlapping ROIs are resolved by taking the element-wise maximum
     (larger disparity = closer object).

     Note: the original implementation also cloned the 'mask' field here,
     but never used it; that dead copy has been removed.
     """
     left_bbox = left_prediction.bbox
     right_bbox = right_prediction.bbox
     disparity_preds = left_prediction.get_field('disparity')
     assert len(left_bbox) == len(right_bbox) == len(
         disparity_preds
     ), f'{len(left_bbox), len(right_bbox), len(disparity_preds)}'
     num_rois = len(left_bbox)
     if num_rois == 0:
         # No detections: return an all-zero disparity map.
         disparity_full_image = torch.zeros(
             (left_prediction.height, left_prediction.width))
     else:
         disparity_maps = []
         for left_roi, right_roi, disp_roi in zip(
                 left_bbox, right_bbox, disparity_preds):
             x1, y1, x2, y2 = left_roi.tolist()
             x1p, _, x2p, _ = right_roi.tolist()
             x1, y1, x2, y2 = expand_box_to_integer((x1, y1, x2, y2))
             x1p, _, x2p, _ = expand_box_to_integer((x1p, y1, x2p, y2))
             disparity_map_per_roi = torch.zeros(
                 (left_prediction.height, left_prediction.width))
             # Resize the ROI prediction to the wider of the two boxes,
             # then crop to the left-box extent.
             disp_roi = DisparityMap(disp_roi).resize(
                 (max(x2 - x1, x2p - x1p), y2 - y1)).crop(
                     (0, 0, x2 - x1, y2 - y1)).data
             # Shift by the left/right box offset to get full-image disparity.
             disp_roi = disp_roi + x1 - x1p
             disparity_map_per_roi[y1:y2, x1:x2] = disp_roi
             disparity_maps.append(disparity_map_per_roi)
         # Overlapping ROIs: keep the per-pixel maximum disparity.
         disparity_full_image = torch.stack(disparity_maps).max(dim=0)[0]
     return DisparityMap(disparity_full_image)
예제 #2
0
 def process_input_eval(self, left_inputs, right_inputs, targets, threshold=0.7, padding=1):
     """Build centered, rotated per-ROI point clouds for evaluation.

     Converts each ROI's predicted disparity to depth via the stereo
     calibration, back-projects to 3D points, rotates each view cone to
     face forward (via ``rotate_pc_along_y``), and centers the points.
     Side effects: sets ``self.rotator`` and ``self.pts_mean`` so callers
     can undo the transform later.

     Returns the point tensor; an empty ``(0, 768, 3)`` cuda tensor when no
     image produced any ROI.
     """
     depth_maps = []
     mask_pred_list = []
     fus = []
     for left_prediction, right_prediction, target in zip(left_inputs, right_inputs, targets):
         left_bbox = left_prediction.bbox
         right_bbox = right_prediction.bbox
         disparity_preds = left_prediction.get_field('disparity')
         masks = left_prediction.get_field('mask')
         # Binarize soft ROI masks and project them onto the full image.
         masker = Masker(threshold=threshold, padding=padding)
         mask_pred = masker([masks], [left_prediction])[0].squeeze(1)
         # assert len(left_bbox) == len(right_bbox) == len(
         #     disparity_preds), f'{len(left_bbox), len(right_bbox), len(disparity_preds)}'
         num_rois = len(left_bbox)
         # One focal length per ROI, consumed by the rotator below.
         fus.extend([target.get_field('calib').calib.fu for _ in range(num_rois)])
         depth_maps_per_img = []
         disparity_maps_per_img = []
         if num_rois != 0:
             for left_roi, right_roi, disp_or_depth_roi, mask_p in zip(left_bbox, right_bbox,
                                                                       disparity_preds, mask_pred):
                 x1, y1, x2, y2 = expand_box_to_integer(left_roi.tolist())
                 x1p, _, x2p, _ = expand_box_to_integer(right_roi.tolist())
                 depth_map_per_roi = torch.zeros((left_prediction.height, left_prediction.width)).cuda()
                 disparity_map_per_roi = torch.zeros_like(depth_map_per_roi)
                 mask = mask_p.squeeze(0)
                 # Resize the ROI disparity to the wider of the two boxes,
                 # crop back to the left box, then shift from box-local to
                 # full-image disparity.
                 disp_roi = DisparityMap(disp_or_depth_roi).resize(
                     (max(x2 - x1, x2p - x1p), y2 - y1)).crop(
                     (0, 0, x2 - x1, y2 - y1)).data
                 disp_roi = disp_roi + x1 - x1p
                 # depth = fu * baseline / disparity; epsilon avoids /0 and
                 # the clamp keeps depth >= 1.
                 depth_roi = target.get_field('calib').stereo_fuxbaseline / (disp_roi + 1e-6)
                 depth_map_per_roi[y1:y2, x1:x2] = depth_roi.clamp(min=1.0)
                 disparity_map_per_roi[y1:y2, x1:x2] = disp_roi
                 disparity_map_per_roi = disparity_map_per_roi * mask.float().cuda()
                 # imageio.imsave('~/code/disprcnn_plus/tmp.jpg', depth_map_per_roi.cpu().numpy())
                 depth_maps_per_img.append(depth_map_per_roi)
                 disparity_maps_per_img.append(disparity_map_per_roi)
             # NOTE(review): with num_rois != 0 this else branch can only
             # fire if mask_pred is shorter than the boxes (the length
             # assert above is commented out) — confirm it is reachable.
             if len(depth_maps_per_img) != 0:
                 depth_maps_per_img = torch.stack(depth_maps_per_img)
                 disparity_maps_per_img = torch.stack(disparity_maps_per_img).sum(dim=0)
             else:
                 depth_maps_per_img = torch.zeros((1, left_prediction.height, left_prediction.width))
                 disparity_maps_per_img = torch.zeros((left_prediction.height, left_prediction.width))
             depth_maps.append(depth_maps_per_img)
             mask_pred_list.append(mask_pred.cuda())
     if len(depth_maps) != 0:
         fus = torch.tensor(fus).cuda()
         self.rotator = rotate_pc_along_y(left_inputs, fus)
         pts = self.back_project(depth_maps, mask_pred_list, targets=targets, fix_seed=True)
         pts = self.rotator.__call__(pts.permute(0, 2, 1)).permute(0, 2, 1) # Transformation of view cone of point cloud
         # pts_tmp = pts.cpu().numpy()
         # with open('/home/liangzx/code/disprcnn_plus/tmp2.obj', 'w+') as f:
         #     for i in range(pts_tmp.shape[1]):
         #             f.write("v" + " " + str(pts_tmp[0,i,0]) + " " + str(pts_tmp[0,i,1]) + " " + str(pts_tmp[0,i,2]) + "\n")
         # Center each ROI cloud; remember the offset for de-normalization.
         pts_mean = pts.mean(1)
         self.pts_mean = pts_mean
         pts = pts - pts_mean[:, None, :]
     else:
         pts = torch.empty((0, 768, 3)).cuda()
     return pts
예제 #3
0
    def roi_disp_postprocess(self, left_result: List[BoxList],
                             right_result: List[BoxList],
                             output: torch.Tensor):
        """Paste per-ROI disparity outputs back into full-image maps.

        ``output`` holds the concatenated ROI disparity predictions for the
        whole batch; it is split per image by detection count. Each ROI's
        prediction is resized/cropped to its left box, shifted to full-image
        disparity, masked by the predicted instance mask, and fused across
        ROIs with a per-pixel max. The fused map is attached to each left
        BoxList under the 'disparity' map key.

        Returns ``left_result`` (mutated in place).
        """
        output_splited = torch.split(output, [len(a) for a in left_result])
        for lr, rr, out in zip(left_result, right_result, output_splited):
            # each image
            roi_disps_per_img = []
            mask_preds_per_img = self.masker([lr.get_field('mask')],
                                             [lr])[0].squeeze(1)
            if mask_preds_per_img.ndimension() == 2:
                # squeeze collapsed the ROI dim when there is one detection.
                mask_preds_per_img = mask_preds_per_img.unsqueeze(0)
            for i, (leftbox, rightbox, mask_pred) in enumerate(
                    zip(lr.bbox.tolist(), rr.bbox.tolist(),
                        mask_preds_per_img)):
                x1, y1, x2, y2 = expand_box_to_integer(leftbox)
                x1p, _, x2p, _ = expand_box_to_integer(rightbox)
                roi_disp = DisparityMap(out[i]).resize(
                    (max(x2 - x1, x2p - x1p), y2 - y1)).crop(
                        (0, 0, x2 - x1, y2 - y1))
                disparity_map_per_roi = torch.zeros((lr.height, lr.width))
                # Shift box-local disparity by the left/right box offset to
                # obtain full-image disparity.
                disparity_map_per_roi[int(y1):int(y1) + roi_disp.height,
                                      int(x1):int(x1) +
                                      roi_disp.width] = roi_disp.data + (x1 -
                                                                         x1p)
                # clamp is out-of-place, so no defensive clone is needed.
                disparity_map_per_roi = disparity_map_per_roi.clamp(
                    min=0)  # clip to 0.
                disparity_map_per_roi = disparity_map_per_roi * mask_pred.float()
                roi_disps_per_img.append(disparity_map_per_roi)
            if len(roi_disps_per_img) != 0:
                # Overlapping ROIs: keep the per-pixel maximum disparity.
                roi_disps_per_img = torch.stack(roi_disps_per_img).cuda().max(
                    dim=0)[0]
            else:
                roi_disps_per_img = torch.zeros((lr.height, lr.width))
            # print(roi_disps_per_img.max(),roi_disps_per_img.min())
            # lr.add_field('disparity_full_img_size', roi_disps_per_img)
            lr.add_map('disparity', roi_disps_per_img)

        return left_result
예제 #4
0
def main():
    """Export per-ROI stereo training samples from saved predictions.

    For each split ('train'/'val'), pairs every predicted left/right box,
    crops equal-width left/right image patches plus the ground-truth ROI
    disparity, and writes per-sample files:
      * image/left|right/<n>.webp  — image crops,
      * disparity/<n>.zarr         — box-local GT disparity,
      * label/<n>.pkl              — mask, box coords, calibration, image size.
    """
    args = parser.parse_args()
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    root = os.path.join(PROJECT_ROOT, 'data/kitti')
    splits = ['train', 'val']
    masker = Masker(args.masker_thresh)
    for split in splits:
        prediction_pth = args.prediction_template % split
        predictions = torch.load(prediction_pth)
        left_predictions, right_predictions = predictions['left'], predictions[
            'right']
        os.makedirs(os.path.join(output_dir, split, 'image', 'left'),
                    exist_ok=True)
        os.makedirs(os.path.join(output_dir, split, 'image', 'right'),
                    exist_ok=True)
        os.makedirs(os.path.join(output_dir, split, 'label'), exist_ok=True)
        os.makedirs(os.path.join(output_dir, split, 'disparity'),
                    exist_ok=True)
        ds = KITTIObjectDataset(root,
                                split,
                                filter_empty=False,
                                mask_disp_sub_path=args.mask_disp_sub_path)
        wrote = 0
        assert len(left_predictions) == len(ds)
        for i, (images, targets, _) in enumerate(tqdm(ds)):
            leftimg, rightimg = images['left'], images['right']
            leftanno, rightanno = targets['left'], targets['right']
            # Rescale stored predictions to the original image resolution.
            left_prediction_per_img = left_predictions[i].resize(leftimg.size)
            right_prediction_per_img = right_predictions[i].resize(
                leftimg.size)
            calib = leftanno.get_field('calib')
            if len(leftanno) == 0 or len(left_prediction_per_img) == 0:
                continue
            masks_per_img = masker([left_prediction_per_img.get_field('mask')],
                                   [left_prediction_per_img])[0].squeeze(1)
            disparity_per_img = leftanno.get_map('disparity')
            assert len(left_prediction_per_img.bbox) == len(
                right_prediction_per_img.bbox) == len(masks_per_img)
            for left_bbox, right_bbox, mask in zip(
                    left_prediction_per_img.bbox,
                    right_prediction_per_img.bbox, masks_per_img):
                x1, y1, x2, y2 = expand_box_to_integer(left_bbox.tolist())
                x1p, _, x2p, _ = expand_box_to_integer(right_bbox.tolist())
                # Left/right crops share the wider of the two box widths,
                # clipped so the left crop stays inside the image.
                max_width = max(x2 - x1, x2p - x1p)
                max_width = min(max_width, leftimg.width - x1)
                roi_mask = mask[y1:y2, x1:x1 + max_width]
                roi_mask = SegmentationMask(
                    roi_mask, (roi_mask.shape[1], roi_mask.shape[0]),
                    mode='mask')
                roi_left_img: Image.Image = leftimg.crop(
                    (x1, y1, x1 + max_width, y2))
                roi_right_img: Image.Image = rightimg.crop(
                    (x1p, y1, x1p + max_width, y2))
                roi_disparity = disparity_per_img.crop(
                    (x1, y1, x1 + max_width, y2)).data
                # Convert full-image disparity to box-local disparity.
                roi_disparity = roi_disparity - (x1 - x1p)
                assert roi_left_img.size == roi_right_img.size and \
                       roi_left_img.size == roi_disparity.shape[::-1], \
                    f'{roi_left_img.size} {roi_right_img.size} {roi_disparity.shape[::-1]}' \
                    f'{x1, x1p, max_width}'
                roi_left_img.save(
                    os.path.join(output_dir, split, 'image/left',
                                 str(wrote) + '.webp'))
                roi_right_img.save(
                    os.path.join(output_dir, split, 'image/right',
                                 str(wrote) + '.webp'))
                zarr.convenience.save(
                    os.path.join(output_dir, split, 'disparity',
                                 str(wrote) + '.zarr'), roi_disparity.numpy())
                label = {
                    'mask': roi_mask,
                    'x1': x1,
                    'y1': y1,
                    'x2': x2,
                    'y2': y2,
                    'x1p': x1p,
                    'x2p': x2p,
                    'fuxb': calib.stereo_fuxbaseline,
                    'image_width': leftimg.width,
                    'image_height': leftimg.height
                }
                # Use a context manager so the file handle is closed and
                # flushed deterministically (the old pickle.dump(..., open(...))
                # leaked the handle).
                with open(
                        os.path.join(output_dir, split, 'label',
                                     str(wrote) + '.pkl'), 'wb') as f:
                    pickle.dump(label, f)
                wrote += 1
        print(f'made {wrote} pairs for {split}.')
예제 #5
0
    def process_input(self,
                      left_inputs,
                      right_inputs,
                      targets,
                      threshold=0.7,
                      padding=1):
        """Build augmented, rotated, centered point clouds and RPN labels.

        Per-ROI predicted disparity is converted to depth and back-projected
        to 3D points; matched GT 3D boxes are transformed in lockstep with
        the points through random scale/flip augmentation (skipped when
        ``self.cfg.RPN.FIXED``), view-cone rotation, and mean-centering.
        Side effects: sets ``self.rotator`` and ``self.pts_mean``; mutates
        the matched targets' 'box3d' fields in place.

        Returns (pts, cls_label, reg_label, matched_targets).
        """
        # Drop proposal pairs that cannot form a valid stereo ROI.
        left_inputs, right_inputs = remove_empty_proposals(
            left_inputs, right_inputs)
        left_inputs, right_inputs = remove_too_right_proposals(
            left_inputs, right_inputs)
        depth_maps = []
        mask_pred_list = []

        matched_targets = []
        fus = []
        for left_prediction, right_prediction, target_per_image in zip(
                left_inputs, right_inputs, targets):
            if len(target_per_image) != 0:
                matched_target = self.match_targets_to_proposals(
                    left_prediction, target_per_image)
                matched_targets.append(matched_target)
            else:
                # No ground truth in this image: skip it entirely.
                continue
            left_bbox = left_prediction.bbox
            right_bbox = right_prediction.bbox
            disparity_or_depth_preds = left_prediction.get_field('disparity')
            masks = left_prediction.get_field('mask')
            # Binarize soft ROI masks onto the full image.
            masker = Masker(threshold=threshold, padding=padding)
            mask_pred = masker([masks], [left_prediction])[0].squeeze(1)
            num_rois = len(left_bbox)

            # One focal length per ROI, consumed by the rotator below.
            fus.extend([
                target_per_image.get_field('calib').calib.fu
                for _ in range(num_rois)
            ])
            depth_maps_per_img = []
            # mask_preds_per_img = []
            if num_rois != 0:
                for left_roi, right_roi, disp_or_depth_roi, maskp in zip(
                        left_bbox, right_bbox, disparity_or_depth_preds,
                        mask_pred):
                    x1, y1, x2, y2 = expand_box_to_integer(left_roi.tolist())
                    x1p, _, x2p, _ = expand_box_to_integer(right_roi.tolist())
                    depth_map_per_roi = torch.zeros(
                        (left_prediction.height,
                         left_prediction.width)).cuda()
                    # Resize ROI disparity to the wider box, crop to the
                    # left box, then shift to full-image disparity.
                    disp_roi = DisparityMap(disp_or_depth_roi).resize(
                        (max(x2 - x1, x2p - x1p), y2 - y1)).crop(
                            (0, 0, x2 - x1, y2 - y1)).data
                    disp_roi = disp_roi + x1 - x1p
                    # depth = fu * baseline / disparity (epsilon avoids /0).
                    depth_roi = target_per_image.get_field(
                        'calib').stereo_fuxbaseline / (disp_roi + 1e-6)
                    depth_map_per_roi[y1:y2, x1:x2] = depth_roi
                    depth_maps_per_img.append(depth_map_per_roi)
                depth_maps.append(depth_maps_per_img)
                mask_pred_list.append(mask_pred.cuda())
        depth_full_image = [torch.stack(d) for d in depth_maps]
        mask_pred_all = mask_pred_list
        pts = self.back_project(depth_full_image, mask_pred_all, targets)
        fus = torch.tensor(fus).cuda()
        gt_box3d_xyzhwlry = torch.cat([
            t.get_field('box3d').convert('xyzhwl_ry').bbox_3d.view(-1, 7)
            for t in matched_targets
        ])
        # aug
        # scale: jitter points and box centers/sizes by the same factor.
        if not self.cfg.RPN.FIXED:
            scale = np.random.uniform(0.95, 1.05)
            pts = pts * scale

            gt_box3d_xyzhwlry[:, 0:6] = gt_box3d_xyzhwlry[:, 0:6] * scale
        # flip
        if not self.cfg.RPN.FIXED:
            do_flip = np.random.random() < 0.5
        else:
            do_flip = False
        if do_flip:
            # Mirror x; heading ry maps to sign(ry)*pi - ry under the flip.
            pts[:, :, 0] = -pts[:, :, 0]
            gt_box3d_xyzhwlry[:, 0] = -gt_box3d_xyzhwlry[:, 0]
            gt_box3d_xyzhwlry[:, 6] = torch.sign(
                gt_box3d_xyzhwlry[:, 6]) * np.pi - gt_box3d_xyzhwlry[:, 6]
            # rotate (negated angle to match the mirrored geometry)
            self.rotator = rotate_pc_along_y(left_inputs, fus)
            self.rotator.rot_angle *= -1
        else:
            # rotate
            self.rotator = rotate_pc_along_y(left_inputs, fus)

        # Write the (augmented) xyzhwl_ry boxes back into the targets.
        gt_box3d_xyzhwlry_batch_splited = torch.split(
            gt_box3d_xyzhwlry, [len(b) for b in matched_targets])
        for i in range(len(matched_targets)):
            matched_targets[i].extra_fields['box3d'] = matched_targets[
                i].extra_fields['box3d'].convert('xyzhwl_ry')
            matched_targets[i].extra_fields[
                'box3d'].bbox_3d = gt_box3d_xyzhwlry_batch_splited[i]
        # rotate points and GT corners with the same rotator.
        pts = self.rotator.__call__(pts.permute(0, 2, 1)).permute(0, 2, 1)
        target_corners = self.rotator.__call__(
            torch.cat([
                t.get_field('box3d').convert('corners').bbox_3d.view(
                    -1, 8, 3).permute(0, 2, 1) for t in matched_targets
            ])).permute(0, 2, 1)
        # translate: center each ROI cloud; keep the offset (self.pts_mean)
        # so downstream code can undo the normalization.
        pts_mean = pts.mean(1)
        self.pts_mean = pts_mean
        pts = pts - pts_mean[:, None, :]
        target_corners = target_corners - pts_mean[:, None, :]
        target_corners_splited = torch.split(target_corners,
                                             [len(b) for b in matched_targets])
        for i in range(len(matched_targets)):
            # Store the transformed corners back as flat (N, 24) tensors.
            matched_targets[i].extra_fields['box3d'] = matched_targets[
                i].extra_fields['box3d'].convert('corners')
            matched_targets[i].extra_fields[
                'box3d'].bbox_3d = target_corners_splited[i].contiguous().view(
                    -1, 24)

        cls_label, reg_label = generate_rpn_training_labels(
            pts, matched_targets)
        return pts, cls_label, reg_label, matched_targets
예제 #6
0
    def prepare_psmnet_input_and_target(self,
                                        left_images: ImageList,
                                        right_images: ImageList,
                                        left_result: List[BoxList],
                                        right_result: List[BoxList],
                                        left_targets: List[BoxList],
                                        require_mask_tgts=True):
        """Crop aligned left/right ROI pairs as PSMNet input.

        With ``require_mask_tgts=True`` (training) also builds per-ROI
        disparity targets and supervision masks, all at
        ``self.disp_resolution``. With ``require_mask_tgts=False``
        (inference) instead returns the calibration and box coordinates
        needed to map ROI disparity back to the full image.

        Returns either
        ``(left_roi_images, right_roi_images, roi_disp_targets, roi_masks)``
        or
        ``(left_roi_images, right_roi_images, fxus, x1s, x1ps, x2s, x2ps)``.
        """

        def _align_boxes(leftbox, rightbox, height, width):
            # Clip a left/right box pair to the image and pick one shared
            # crop width: the wider of the two boxes, limited so both crops
            # fit inside the image. (Shared by both branches below.)
            x1, y1, x2, y2 = expand_box_to_integer(leftbox)
            x1p, _, x2p, _ = expand_box_to_integer(rightbox)
            x1 = max(0, x1)
            x1p = max(0, x1p)
            y1 = max(0, y1)
            y2 = min(y2, height - 1)
            x2 = min(x2, width - 1)
            x2p = min(x2p, width - 1)
            max_width = max(x2 - x1, x2p - x1p)
            max_width = min(max_width, min(width - x1, width - x1p))
            return x1, y1, y2, x1p, max_width

        if require_mask_tgts:
            roi_disp_targets = []
            roi_masks = []
            ims_per_batch = len(left_result)
            rois_for_image_crop_left = []
            rois_for_image_crop_right = []
            for i in range(ims_per_batch):
                left_target = left_targets[i]
                mask_gt_per_img = left_target.get_field(
                    'masks').get_full_image_mask_tensor().byte()
                disparity_map_per_img: DisparityMap = left_target.get_map(
                    'disparity')

                mask_preds_per_img = self.masker(
                    [left_result[i].get_field('mask')],
                    [left_result[i]])[0].squeeze(1).byte()
                if mask_preds_per_img.ndimension() == 2:
                    # A single detection loses its ROI dim in squeeze.
                    mask_preds_per_img = mask_preds_per_img.unsqueeze(0)
                for leftbox, rightbox, mask_pred in zip(
                        left_result[i].bbox.tolist(),
                        right_result[i].bbox.tolist(), mask_preds_per_img):
                    # 1 align left box and right box
                    x1, y1, y2, x1p, max_width = _align_boxes(
                        leftbox, rightbox, left_result[i].height,
                        left_result[i].width)
                    rois_for_image_crop_left.append(
                        [i, x1, y1, x1 + max_width, y2])
                    rois_for_image_crop_right.append(
                        [i, x1p, y1, x1p + max_width, y2])
                    # prepare target: box-local disparity at the fixed
                    # PSMNet resolution.
                    roi_disparity_map = disparity_map_per_img.crop(
                        (x1, y1, x1 + max_width, y2))
                    roi_disparity_map.data = roi_disparity_map.data - (x1 -
                                                                       x1p)
                    roi_disp_target = roi_disparity_map.resize(
                        (self.disp_resolution, self.disp_resolution)).data
                    # Supervise only where prediction and GT masks agree.
                    mask_pred = mask_pred & mask_gt_per_img
                    roi_mask = mask_pred[y1:y2, x1:x1 + max_width]
                    roi_mask = interpolate(
                        roi_mask[None, None].float(),
                        (self.disp_resolution, self.disp_resolution),
                        mode='bilinear',
                        align_corners=True)[0, 0].byte()
                    roi_disp_targets.append(roi_disp_target)
                    roi_masks.append(roi_mask)
            # crop and resize images
            left_roi_images = self.crop_and_transform_roi_img(
                left_images.tensors, rois_for_image_crop_left)
            right_roi_images = self.crop_and_transform_roi_img(
                right_images.tensors, rois_for_image_crop_right)
            if len(left_roi_images) != 0:
                roi_disp_targets = torch.stack(roi_disp_targets)
                roi_masks = torch.stack(roi_masks).cuda()
            else:
                # No ROI in the whole batch: return consistently-shaped
                # empty tensors.
                left_roi_images = torch.empty(
                    (0, 3, self.disp_resolution, self.disp_resolution)).cuda()
                right_roi_images = torch.empty(
                    (0, 3, self.disp_resolution, self.disp_resolution)).cuda()
                roi_disp_targets = torch.empty(
                    (0, self.disp_resolution, self.disp_resolution)).cuda()
                roi_masks = torch.empty(
                    (0, self.disp_resolution, self.disp_resolution)).cuda()
            return left_roi_images, right_roi_images, roi_disp_targets, roi_masks
        else:
            ims_per_batch = len(left_result)
            rois_for_image_crop_left = []
            rois_for_image_crop_right = []
            fxus, x1s, x1ps, x2s, x2ps = [], [], [], [], []
            for i in range(ims_per_batch):
                left_target = left_targets[i]
                calib = left_target.get_field('calib')
                fxus.extend([
                    calib.stereo_fuxbaseline
                    for _ in range(len(left_result[i]))
                ])
                mask_preds_per_img = self.masker(
                    [left_result[i].get_field('mask')],
                    [left_result[i]])[0].squeeze(1).byte()
                if mask_preds_per_img.ndimension() == 2:
                    mask_preds_per_img = mask_preds_per_img.unsqueeze(0)
                for leftbox, rightbox, mask_pred in zip(
                        left_result[i].bbox.tolist(),
                        right_result[i].bbox.tolist(), mask_preds_per_img):
                    # 1 align left box and right box
                    x1, y1, y2, x1p, max_width = _align_boxes(
                        leftbox, rightbox, left_result[i].height,
                        left_result[i].width)
                    rois_for_image_crop_left.append(
                        [i, x1, y1, x1 + max_width, y2])
                    rois_for_image_crop_right.append(
                        [i, x1p, y1, x1p + max_width, y2])
                    x1s.append(x1)
                    x1ps.append(x1p)
                    x2s.append(x1 + max_width)
                    x2ps.append(x1p + max_width)
            # crop and resize images
            left_roi_images = self.crop_and_transform_roi_img(
                left_images.tensors, rois_for_image_crop_left)
            right_roi_images = self.crop_and_transform_roi_img(
                right_images.tensors, rois_for_image_crop_right)
            if len(left_roi_images) != 0:
                x1s = torch.tensor(x1s).cuda()
                x1ps = torch.tensor(x1ps).cuda()
                x2s = torch.tensor(x2s).cuda()
                x2ps = torch.tensor(x2ps).cuda()
                fxus = torch.tensor(fxus).cuda()
            else:
                left_roi_images = torch.empty(
                    (0, 3, self.disp_resolution, self.disp_resolution)).cuda()
                right_roi_images = torch.empty(
                    (0, 3, self.disp_resolution, self.disp_resolution)).cuda()
            return left_roi_images, right_roi_images, fxus, x1s, x1ps, x2s, x2ps
예제 #7
0
def main():
    """Export 224x224 stereo ROI crops for a given KITTI object class.

    For each split, pairs predicted left/right boxes, extracts aligned
    224x224 left/right image crops via ROIAlign, resizes the GT ROI
    disparity and mask to 224x224, and writes per-sample zarr/pickle files
    under ``output_dir/<split>/{image,disparity,label}``.
    """
    args = parser.parse_args()
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    root = 'data/kitti'
    roi_align = ROIAlign((224, 224), 1.0, 0)
    if args.splits == 'trainval':
        splits = ['train', 'val']
    else:
        splits = [args.splits]
    masker = Masker(args.masker_thresh)
    for split in splits:
        prediction_pth = args.prediction_template % split
        predictions = torch.load(prediction_pth)
        left_predictions, right_predictions = predictions['left'], predictions[
            'right']
        os.makedirs(os.path.join(output_dir, split, 'image', 'left'),
                    exist_ok=True)
        os.makedirs(os.path.join(output_dir, split, 'image', 'right'),
                    exist_ok=True)
        os.makedirs(os.path.join(output_dir, split, 'label'), exist_ok=True)
        os.makedirs(os.path.join(output_dir, split, 'disparity'),
                    exist_ok=True)
        # Pick the dataset variant for the requested object class.
        if args.cls == 'car':
            ds = KITTIObjectDatasetCar(root,
                                       split,
                                       filter_empty=False,
                                       shape_prior_base=args.shape_prior_base)
        elif args.cls == 'pedestrian':
            ds = KITTIObjectDatasetPedestrian(
                root,
                split,
                filter_empty=False,
                shape_prior_base=args.shape_prior_base)
        else:  # cyclist
            ds = KITTIObjectDatasetCyclist(root,
                                           split,
                                           filter_empty=False,
                                           shape_prior_base='notused')

        wrote = 0
        assert len(left_predictions) == len(ds)
        for i, (images, targets, _) in enumerate(tqdm(ds)):
            leftimg, rightimg = images['left'], images['right']
            leftanno, rightanno = targets['left'], targets['right']
            # Rescale stored predictions to the original image resolution.
            left_prediction_per_img = left_predictions[i].resize(leftimg.size)
            right_prediction_per_img = right_predictions[i].resize(
                leftimg.size)

            calib = leftanno.get_field('calib')
            if len(leftanno) == 0 or len(left_prediction_per_img) == 0:
                continue
            imgid: int = leftanno.get_field('imgid')[0, 0].item()
            # os.makedirs(osp.join(output_dir, split, 'imgid_org_left', str(imgid)), exist_ok=True)
            masks_per_img = masker([left_prediction_per_img.get_field('mask')],
                                   [left_prediction_per_img])[0].squeeze(1)
            disparity_per_img = leftanno.get_map('disparity')
            assert len(left_prediction_per_img.bbox) == len(
                right_prediction_per_img.bbox) == len(masks_per_img)
            rois_for_image_crop_left = []
            rois_for_image_crop_right = []
            fxus, x1s, x1ps, x2s, x2ps, y1s, y2s = [], [], [], [], [], [], []
            roi_masks = []
            roi_disps = []
            for left_bbox, right_bbox, mask in zip(
                    left_prediction_per_img.bbox,
                    right_prediction_per_img.bbox, masks_per_img):
                x1, y1, x2, y2 = expand_box_to_integer(left_bbox.tolist())
                x1p, _, x2p, _ = expand_box_to_integer(right_bbox.tolist())
                # Left/right crops share the wider of the two box widths,
                # clipped so both crops stay inside the image.
                max_width = max(x2 - x1, x2p - x1p)
                max_width = min(max_width, leftimg.width - x1)
                allow_extend_width = min(left_prediction_per_img.width - x1,
                                         left_prediction_per_img.width - x1p)
                max_width = min(max_width, allow_extend_width)
                rois_for_image_crop_left.append(
                    [0, x1, y1, x1 + max_width, y2])
                rois_for_image_crop_right.append(
                    [0, x1p, y1, x1p + max_width, y2])
                x1s.append(x1)
                x1ps.append(x1p)
                x2s.append(x1 + max_width)
                x2ps.append(x1p + max_width)
                y1s.append(y1)
                y2s.append(y2)

                roi_mask = mask[y1:y2, x1:x1 + max_width]
                roi_mask = SegmentationMask(
                    roi_mask, (roi_mask.shape[1], roi_mask.shape[0]),
                    mode='mask')
                roi_mask = roi_mask.resize((224, 224))
                roi_disparity = disparity_per_img.crop(
                    (x1, y1, x1 + max_width, y2)).data
                dispfg_mask = SegmentationMask(
                    roi_disparity != 0,
                    (roi_disparity.shape[1], roi_disparity.shape[0]),
                    mode='mask').resize((224, 224)).get_mask_tensor()

                # Convert to box-local disparity and resize to the crop size.
                roi_disparity = roi_disparity - (x1 - x1p)
                roi_disparity = DisparityMap(roi_disparity).resize(
                    (224, 224)).data
                roi_masks.append(roi_mask)
                roi_disps.append(roi_disparity)
            # crop and resize image
            leftimg = F.to_tensor(leftimg).unsqueeze(0)
            rightimg = F.to_tensor(rightimg).unsqueeze(0)
            rois_for_image_crop_left = torch.as_tensor(
                rois_for_image_crop_left).float()
            rois_for_image_crop_right = torch.as_tensor(
                rois_for_image_crop_right).float()
            roi_left_imgs = roi_align(leftimg, rois_for_image_crop_left)
            roi_right_imgs = roi_align(rightimg, rois_for_image_crop_right)
            for j in range(len(roi_left_imgs)):
                zarr.save(
                    osp.join(output_dir, split, 'image/left',
                             str(wrote) + '.zarr'), roi_left_imgs[j].numpy())
                zarr.save(
                    osp.join(output_dir, split, 'image/right',
                             str(wrote) + '.zarr'), roi_right_imgs[j].numpy())
                zarr.save(
                    osp.join(output_dir, split, 'disparity',
                             str(wrote) + '.zarr'), roi_disps[j].numpy())
                out_path = os.path.join(output_dir, split, 'label',
                                        str(wrote) + '.pkl')
                label = {
                    'mask': roi_masks[j],
                    'x1': x1s[j],
                    'y1': y1s[j],
                    'x2': x2s[j],
                    'y2': y2s[j],
                    'x1p': x1ps[j],
                    'x2p': x2ps[j],
                    'fuxb': calib.stereo_fuxbaseline,
                    'imgid': imgid
                }
                # Use a context manager so the file handle is closed and
                # flushed deterministically (the old pickle.dump(..., open(...))
                # leaked the handle).
                with open(out_path, 'wb') as f:
                    pickle.dump(label, f)
                wrote += 1
        print(f'made {wrote} pairs for {split}.')