예제 #1
0
def collect_wh_data(cfg, args, stop_count):
    """Gather the width/height of ground-truth boxes from the training data.

    The dataset is built from ``cfg.data.train`` and read through a
    dataloader, so the collected sizes reflect whatever the training pipeline
    (including data augmentation) does to the image scale.

    Args:
        cfg: Config object; only ``cfg.data.train`` is read here.
        args: Parsed arguments providing ``samples_per_gpu``,
            ``workers_per_gpu`` and ``repeat_count`` (number of passes over
            the dataloader).
        stop_count: Cap on the batch index per pass so that very large
            datasets do not take too long. Batches whose index is greater
            than ``stop_count`` are not read, i.e. at most
            ``stop_count + 1`` batches are visited in each pass.

    Returns:
        ``np.ndarray`` of shape ``(N, 2)`` holding ``(width, height)`` for
        every collected box.

    Raises:
        ValueError: If no ground-truth box was found in any visited batch.
    """
    # Go through the real training pipeline so that aspect-ratio changes
    # introduced by data augmentation are taken into account.
    dataset = build_dataset(cfg.data.train)
    dataloader = build_dataloader(dataset, args.samples_per_gpu,
                                  args.workers_per_gpu)
    print('----开始遍历数据集----')
    wh_all = []
    for _ in range(args.repeat_count):
        progress_bar = cv_core.ProgressBar(len(dataloader))
        for i, data_batch in enumerate(dataloader):
            if i > stop_count:
                break
            # `.data[0]` is a sequence of per-image box tensors; merge them
            # into a single array for the whole batch.
            gt_bboxes = data_batch['gt_bboxes'].data[0]
            gt_bboxes = torch.cat(gt_bboxes, dim=0).numpy()
            if len(gt_bboxes) > 0:
                # Columns 2/0 and 3/1 are subtracted pairwise
                # (assumes an (x1, y1, x2, y2) box layout -- TODO confirm).
                wh_all.append(gt_bboxes[:, 2:4] - gt_bboxes[:, 0:2])
            # Advance the bar for every visited batch, including batches
            # that contained no boxes.
            progress_bar.update()
    if not wh_all:
        # np.concatenate would fail with an opaque message on an empty list.
        raise ValueError(
            'No ground-truth boxes were collected from the dataloader')
    wh_all = np.concatenate(wh_all, axis=0)
    print(wh_all.shape)
    return wh_all
def main():
    """Iterate over the training dataset and draw its annotations.

    For every sample the image is rendered together with its ground-truth
    boxes, labels and (when present) masks. Results are shown on screen
    unless ``args.not_show`` is set, and written to ``args.output_dir`` when
    that directory is given.
    """
    args = parse_args()
    cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options)
    dataset = build_dataset(cfg.data.train)
    progress_bar = mmcv.ProgressBar(len(dataset))

    for sample in dataset:
        # Only build an output path when the user asked for files on disk;
        # the saved file reuses the base name of the source image.
        if args.output_dir is None:
            out_path = None
        else:
            out_path = os.path.join(args.output_dir,
                                    Path(sample['filename']).name)

        masks = sample.get('gt_masks', None)
        if masks is not None:
            masks = mask2ndarray(masks)

        box_and_text_color = (255, 102, 61)
        imshow_det_bboxes(sample['img'],
                          sample['gt_bboxes'],
                          sample['gt_labels'],
                          masks,
                          class_names=dataset.CLASSES,
                          show=not args.not_show,
                          wait_time=args.show_interval,
                          out_file=out_path,
                          bbox_color=box_and_text_color,
                          text_color=box_and_text_color)

        progress_bar.update()
예제 #3
0
def show_featuremap_from_datalayer(featurevis, feature_indexs, is_show, output_dir):
    """Visualise feature maps for every sample of the test dataset.

    The dataset is built from the module-level ``cfg.data.test``. Each sample
    is handed to ``_show_save_data`` together with both its normalised image
    and a de-normalised ``uint8`` copy.

    Args:
        featurevis: Visualiser object forwarded unchanged to
            ``_show_save_data``.
        feature_indexs: A single index or a list/tuple of indices; a single
            value is wrapped into a one-element list.
        is_show: Forwarded to ``_show_save_data``.
        output_dir: Forwarded to ``_show_save_data``.
    """
    if isinstance(feature_indexs, (list, tuple)):
        indices = feature_indexs
    else:
        indices = [feature_indexs]

    for sample in build_dataset(cfg.data.test):
        image_field = sample['img']
        meta = sample['img_metas'][0].data
        source_name = meta['filename']
        norm_cfg = meta['img_norm_cfg']
        # Still the normalised image here; axes are reordered with
        # transpose(1, 2, 0) before de-normalisation.
        normalized = image_field[0].cpu().numpy().transpose(1, 2, 0)
        restored = imdenormalize(normalized, norm_cfg['mean'], norm_cfg['std'])
        restored = restored.astype(np.uint8)
        _show_save_data(featurevis, normalized, restored, indices,
                        source_name, is_show, output_dir)
예제 #4
0
def main():
    """Browse the training dataset, including semantic-segmentation labels.

    When the last training pipeline step lists ``gt_semantic_seg`` among its
    keys, every ``SegRescale`` step is dropped from the train pipeline before
    the dataset is built. For samples that carry a semantic map, extra masks
    are derived from it and appended to the instance masks, and the bounding
    boxes are hidden.
    """
    args = parse_args()
    cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options)

    collected_keys = cfg.train_pipeline[-1]['keys']
    if 'gt_semantic_seg' in collected_keys:
        cfg.data.train.pipeline = [
            step for step in cfg.data.train.pipeline
            if step['type'] != 'SegRescale'
        ]
    dataset = build_dataset(cfg.data.train)
    progress_bar = mmcv.ProgressBar(len(dataset))

    for sample in dataset:
        if args.output_dir is None:
            out_path = None
        else:
            out_path = os.path.join(args.output_dir,
                                    Path(sample['filename']).name)

        boxes = sample['gt_bboxes']
        labels = sample['gt_labels']
        masks = sample.get('gt_masks', None)
        if masks is not None:
            masks = mask2ndarray(masks)

        semantic_map = sample.get('gt_semantic_seg', None)
        if semantic_map is not None:
            pad_value = 255  # the padding value of gt_seg
            # Pool the instance labels with the distinct semantic labels and
            # count how often each value occurs in that pool.
            pooled = np.concatenate((labels, np.unique(semantic_map)), axis=0)
            distinct, occurrences = np.unique(pooled, return_counts=True)
            # Values seen exactly once (and not the padding value) are taken
            # as the stuff labels.
            is_stuff = np.logical_and(occurrences < 2, distinct != pad_value)
            stuff_labels = distinct[is_stuff]
            # One boolean mask per stuff label, via broadcasting.
            stuff_masks = semantic_map[None] == stuff_labels[:, None, None]
            labels = np.concatenate((labels, stuff_labels), axis=0)
            masks = np.concatenate((masks, stuff_masks.astype(np.uint8)),
                                   axis=0)
            # If you need to show the bounding boxes,
            # please comment the following line
            boxes = None

        imshow_det_bboxes(sample['img'],
                          boxes,
                          labels,
                          masks,
                          class_names=dataset.CLASSES,
                          show=not args.not_show,
                          wait_time=args.show_interval,
                          out_file=out_path,
                          bbox_color=dataset.PALETTE,
                          text_color=(200, 200, 200),
                          mask_color=dataset.PALETTE)

        progress_bar.update()
예제 #5
0
    def __init__(self, dataset_cfgs, dataset_sampling_weights=None):
        """Build one dataset per config and set up sampling probabilities.

        Args:
            dataset_cfgs: Sequence of dataset configs; each one is passed to
                ``build_dataset``.
            dataset_sampling_weights: Optional non-negative weights, one per
                dataset. They are normalised by their sum to obtain
                ``self.dataset_sampling_probs``. When omitted, every dataset
                gets the same probability ``1 / len(dataset_cfgs)``.

        Raises:
            AssertionError: If a weight is negative or the weights do not
                sum to a positive value.
        """
        if dataset_sampling_weights is not None:
            for weight in dataset_sampling_weights:
                assert weight >= 0.
            weight_sum = float(sum(dataset_sampling_weights))
            assert weight_sum > 0.
            probs = [
                weight / weight_sum for weight in dataset_sampling_weights
            ]
        else:
            num_cfgs = len(dataset_cfgs)
            probs = [1. / num_cfgs] * num_cfgs
        self.dataset_sampling_probs = probs

        datasets = []
        for dataset_cfg in dataset_cfgs:
            datasets.append(build_dataset(dataset_cfg))
        # add an attribute `CLASSES` for the calling in `tools/train.py`
        self.CLASSES = datasets[0].CLASSES

        super().__init__(datasets)
예제 #6
0
def main():
    """Render the ground-truth boxes and labels of each training sample.

    Images are displayed unless ``args.not_show`` is set and are additionally
    saved under ``args.output_dir`` when that option is provided.
    """
    args = parse_args()
    cfg = retrieve_data_cfg(args.config, args.skip_type)
    dataset = build_dataset(cfg.data.train)

    progress_bar = cv_core.ProgressBar(len(dataset))
    for sample in dataset:
        # The saved file keeps the base name of the source image.
        if args.output_dir is None:
            out_path = None
        else:
            out_path = os.path.join(args.output_dir,
                                    Path(sample['filename']).name)
        cv_core.imshow_det_bboxes(sample['img'],
                                  sample['gt_bboxes'],
                                  sample['gt_labels'],
                                  class_names=dataset.CLASSES,
                                  show=not args.not_show,
                                  out_file=out_path,
                                  wait_time=args.show_interval)
        progress_bar.update()
def main():
    """Dump annotated training images and print box-size statistics.

    For every sample of the training dataset the ground-truth boxes (and the
    ignored boxes, in green) are drawn on the image, which is then written to
    ``args.output_root``. Box widths and heights are accumulated over the
    whole dataset, clustered with ``kmeans`` (``k=10``), and the resulting
    aspect ratios and sizes are printed before ``plot_statistics`` is called.
    """
    args = parse_args()
    os.makedirs(args.output_root, exist_ok=True)
    cfg = retrieve_data_cfg(args.config_path, args.fold, args.skip_type)

    dataset: WheatDataset = build_dataset(cfg.data.train)
    heights = []
    widths = []
    for i, data in tqdm(enumerate(dataset), total=len(dataset)):
        image_id = osp.basename(dataset.data_infos[i]["file_name"])
        image = data["img"]
        bboxes = data["gt_bboxes"]
        ignore_bboxes = data["gt_bboxes_ignore"]
        if len(bboxes) == 0:
            # Report samples that ended up without any ground-truth box.
            print(image.shape)
        # Box extents from columns 2-0 and 3-1
        # (assumes an (x1, y1, x2, y2) layout -- TODO confirm).
        widths.append(bboxes[:, 2] - bboxes[:, 0])
        heights.append(bboxes[:, 3] - bboxes[:, 1])
        # NOTE(review): the return value is discarded, so the helper
        # presumably draws on `image` in place -- confirm.
        draw_bounding_boxes_on_image(image,
                                     bboxes,
                                     use_normalized_coordinates=False,
                                     thickness=5)
        if len(ignore_bboxes):
            draw_bounding_boxes_on_image(
                image,
                ignore_bboxes,
                label2colors={None: {
                    "bbox": (0, 255, 0)
                }},
                use_normalized_coordinates=False,
                thickness=5,
            )
        cv2.imwrite(osp.join(args.output_root, f"{i}_{image_id}"), image)
    widths = np.concatenate(widths)
    heights = np.concatenate(heights)
    # Cluster (height, width) pairs; column 0 is height and column 1 is width
    # (assumes kmeans returns centres in the same column order -- TODO confirm).
    clusters = kmeans(np.stack([heights, widths], axis=1), k=10)
    print(f"aspect rations: {clusters[:, 0] / clusters[:, 1]}")
    print(f"sizes: {np.sqrt(clusters[:, 0] * clusters[:, 1])}")
    plot_statistics(widths, heights, args.output_root)