def visualize_proposals(self,
                            imgs,
                            proposals,
                            gt_bboxes,
                            img_meta,
                            slice_num,
                            isProposal=True):
        """Write one PNG per batch element overlaying boxes on an image slice.

        The background image is rendered once with ``tensor2img3D`` and is
        shared by every batch element. For each element, the boxes that
        touch the requested slice are drawn together with that element's
        ground-truth boxes through ``self.show_bboxes_gt_bboxes``.

        Args:
            imgs: Image data forwarded unchanged to ``tensor2img3D``.
            proposals: Iterable with one tensor of boxes per batch element.
                Each row is indexed at positions 0-5; positions 0-3 are
                drawn and positions 4-5 are compared against ``slice_num``
                (presumably x1, y1, x2, y2, z1, z2 -- confirm with caller).
            gt_bboxes: Iterable of ground-truth boxes, one entry per batch
                element, passed through as ``gt_bboxes=``.
            img_meta: Iterable of dicts; only ``'image_id'`` is read.
            slice_num: Slice index to visualise. ``None`` renders slice 45
                as the background and keeps every box regardless of depth.
            isProposal: Selects the file-name suffix, ``'prop'`` when truthy
                and ``'anch'`` otherwise.
        """
        # With no explicit slice the background falls back to slice 45.
        background = tensor2img3D(
            imgs, slice_num=45 if slice_num is None else slice_num)
        suffix = 'prop' if isProposal else 'anch'

        for batch_idx, (boxes, gts, meta) in enumerate(
                zip(proposals, gt_bboxes, img_meta)):
            boxes_np = boxes.cpu().numpy()
            # Keep the first four coordinates of every box whose depth range
            # (floor of index 4 .. ceil of index 5) contains the slice; keep
            # all boxes when no slice was requested.
            kept = [
                [box[0], box[1], box[2], box[3]] for box in boxes_np
                if slice_num is None
                or math.floor(box[4]) <= slice_num <= math.ceil(box[5])
            ]
            out_path = 'tests/iter_{}_img_id_{}_batch_{}_{}.png'.format(
                self.iteration, meta['image_id'], batch_idx, suffix)
            self.show_bboxes_gt_bboxes(background,
                                       np.array(kept),
                                       gt_bboxes=gts,
                                       out_file=out_path)
 def visualize_gt_bboxes_masks(self, imgs, gt_bboxes, img_meta, gt_masks):
     """Save one PNG per (image, box, slice) with the mask and box overlaid.

     For every image in ``gt_bboxes`` and every box of that image, each
     slice in ``range(int(bbox[4]), int(bbox[5]))`` (upper bound excluded)
     is rendered with ``tensor2img3D``. The mask slice is blended on top
     at 30% opacity, the box is outlined in red, and the figure is written
     under ``tests/``.

     Args:
         imgs: Image data forwarded unchanged to ``tensor2img3D``.
         gt_bboxes: Sequence with one tensor of boxes per image. Each row
             is indexed at positions 0-5 (presumably x1, y1, x2, y2, z1,
             z2 -- confirm with caller).
         img_meta: Sequence of dicts; only ``'image_id'`` is read.
         gt_masks: Sequence with one mask tensor per image, indexed as
             ``[0, slice, :, :]`` after conversion to NumPy.

     Note:
         The method ends with ``breakpoint()``, so execution pauses in the
         debugger once all images have been written.
     """
     for num_img, img_bboxes in enumerate(gt_bboxes):
         gt_bboxes_np = img_bboxes.cpu().numpy()
         gt_masks_np = gt_masks[num_img].cpu().numpy()
         for bbox_num, bbox in enumerate(gt_bboxes_np):
             for slice_num in range(int(bbox[4]), int(bbox[5])):
                 img = tensor2img3D(imgs, slice_num=slice_num)
                 filename = 'tests/iter_{}_img_id_{}_bbox_num_{}_slice_{}.png'.format(
                     self.iteration, img_meta[num_img]['image_id'],
                     bbox_num, slice_num)
                 fig = plt.figure()
                 try:
                     plt.imshow(img)
                     # NOTE(review): the overlay always reads index 0 on the
                     # first mask axis, whichever box is being drawn --
                     # confirm this is intended.
                     plt.imshow(gt_masks_np[0, slice_num, :, :] * 255,
                                alpha=0.3)
                     rect = pts.Rectangle((bbox[0], bbox[1]),
                                          bbox[2] - bbox[0],
                                          bbox[3] - bbox[1],
                                          fill=False,
                                          edgecolor='red',
                                          linewidth=2)
                     plt.gca().add_patch(rect)
                     plt.savefig(filename)
                 finally:
                     # Close the figure even when drawing or saving fails so
                     # figures do not pile up in pyplot's global registry.
                     plt.close(fig)
     # Pause for interactive inspection of the files written above.
     breakpoint()
    def visualize_gt_bboxes(self, imgs, gt_bboxes, img_meta):
        """Save one PNG per slice covered by each box of the first image.

        Only ``gt_bboxes[0]`` and ``img_meta[0]`` are read. For every box,
        each slice in ``range(int(box[4]), int(box[5]))`` (upper bound
        excluded) is rendered with ``tensor2img3D`` and handed to
        ``mmcv.imshow_bboxes`` with ``show=False`` so that it is written to
        disk instead of displayed.

        The file name carries no box index, so two boxes that cover the
        same slice write to the same path.

        Args:
            imgs: Image data forwarded unchanged to ``tensor2img3D``.
            gt_bboxes: Sequence whose first entry is a tensor of boxes; each
                row is indexed at positions 4 and 5 for the slice range.
            img_meta: Sequence whose first entry is a dict with
                ``'image_id'``.
        """
        first_image_boxes = gt_bboxes[0].cpu().numpy()
        name_template = 'tests/iter_{}_img_id_{}_slice_{}.png'

        for box in first_image_boxes:
            z_start, z_stop = int(box[4]), int(box[5])
            for z in range(z_start, z_stop):
                slice_img = tensor2img3D(imgs, slice_num=z)
                out_path = name_template.format(
                    self.iteration, img_meta[0]['image_id'], z)
                # The full row is passed on as a single-box array.
                mmcv.imshow_bboxes(slice_img,
                                   np.array([box]),
                                   show=False,
                                   out_file=out_path)
    def visualize_anchor_boxes(self,
                               imgs,
                               cls_scores,
                               img_metas,
                               slice_num=45,
                               top_k=None,
                               shuffle=False):
        """Interactively display anchors lying in a fixed central window.

        Anchors are generated with ``self.get_anchors`` from the trailing
        three dimensions of every entry in ``cls_scores``. Only the anchors
        of the first image (``anchor_list[0]``) are inspected, and an anchor
        is kept when all of the following hold:

        * indices 0 and 2 are both within [100, 400],
        * indices 1 and 3 are both within [150, 450],
        * ``slice_num`` lies between index 4 and index 5 (inclusive).

        This is a debugging aid: it prints the (always empty) set described
        below, stops at ``breakpoint()`` before displaying, and stops again
        afterwards.

        Args:
            imgs: Image data forwarded unchanged to ``tensor2img3D``.
            cls_scores: Iterable of tensors; only ``.size()[-3:]`` is used.
            img_metas: Passed through to ``self.get_anchors``.
            slice_num: Slice rendered as background and used for the depth
                test.
            top_k: Forwarded to ``mmcv.imshow_bboxes`` when not ``None``;
                ignored when ``shuffle`` is ``True``.
            shuffle: When exactly ``True``, reshuffles the kept anchors and
                displays them with ``top_k=20`` in a loop that never exits
                on its own.
        """
        featmap_sizes = [level.size()[-3:] for level in cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generators)
        anchor_list, _ = self.get_anchors(featmap_sizes, img_metas)
        background = tensor2img3D(imgs, slice_num=slice_num)

        # Nothing is ever added to this set, so the membership test below
        # never rejects an anchor and the print shows an empty set. A
        # previously used (now disabled) variant recorded the depth extent
        # ``box[5] - box[4]`` here to keep a single anchor per extent.
        seen_widths = set()
        central_anchors = []
        for level_anchors in anchor_list[0]:
            for box in level_anchors.cpu().numpy():
                if not (100 <= box[0] <= 400 and 100 <= box[2] <= 400):
                    continue
                if not (150 <= box[1] <= 450 and 150 <= box[3] <= 450):
                    continue
                if not box[4] <= slice_num <= box[5]:
                    continue
                if (box[2] - box[0]) in seen_widths:
                    continue
                central_anchors.append([box[0], box[1], box[2], box[3]])

        print(seen_widths)
        breakpoint()

        if shuffle is True:
            # Endless slideshow: each pass shows a different ordering of
            # the kept anchors, limited to 20 boxes.
            while True:
                random.shuffle(central_anchors)
                mmcv.imshow_bboxes(background,
                                   np.array(central_anchors),
                                   top_k=20)
        else:
            # Forward top_k only when the caller supplied one.
            extra = {} if top_k is None else {'top_k': top_k}
            mmcv.imshow_bboxes(background, np.array(central_anchors),
                               **extra)
        breakpoint()