Example #1
def eval_det_multiprocessing(pred_all, gt_all, ovthresh=0.25, use_07_metric=False, get_iou_func=get_iou):
    """ Generic functions to compute precision/recall for object detection
        for multiple classes.
        Input:
            pred_all: map of {img_id: [(classname, bbox, score)]}
            gt_all: map of {img_id: [(classname, bbox)]}
            ovthresh: scalar, iou threshold
            use_07_metric: bool, if true use VOC07 11 point method
        Output:
            rec: {classname: rec}
            prec: {classname: prec}
            ap: {classname: scalar}
    """
    pred = {} # map {classname: pred}
    gt = {} # map {classname: gt}
    for img_id in pred_all.keys():
        for classname, bbox, score in pred_all[img_id]:
            if classname not in pred: pred[classname] = {}
            if img_id not in pred[classname]:
                pred[classname][img_id] = []
            if classname not in gt: gt[classname] = {}
            if img_id not in gt[classname]:
                gt[classname][img_id] = []
            pred[classname][img_id].append((bbox,score))
    for img_id in gt_all.keys():
        for classname, bbox in gt_all[img_id]:
            if classname not in gt: gt[classname] = {}
            if img_id not in gt[classname]:
                gt[classname][img_id] = []
            gt[classname][img_id].append(bbox)

    rec = {}
    prec = {}
    ap = {}
    p = Pool(processes=10)
    # NOTE: pulling class2type from the ScanNet config makes this "generic" helper ScanNet-specific.
    class2type_map = ScannetDatasetConfig().class2type
    # Evaluate only classes present in both pred and gt, and keep their order in a
    # list so results can be matched back by index. Enumerating gt.keys() directly
    # would mis-align results whenever a gt class has no predictions.
    eval_classnames = [classname for classname in gt.keys() if classname in pred]
    ret_values = p.map(eval_det_cls_wrapper,
                       [(pred[classname], gt[classname], class2type_map[classname],
                         ovthresh, use_07_metric, get_iou_func)
                        for classname in eval_classnames])
    p.close()
    for i, classname in enumerate(eval_classnames):
        rec[classname], prec[classname], ap[classname] = ret_values[i]
    for classname in gt.keys():
        if classname not in pred:
            rec[classname] = 0
            prec[classname] = 0
            ap[classname] = 0

    return rec, prec, ap
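
A minimal usage sketch, assuming eval_det_cls_wrapper, get_iou, ScannetDatasetConfig, and `from multiprocessing import Pool` are already in scope as the function body requires; the scene ids, class id, and box array below are hypothetical toy values, and the box representation should match whatever get_iou_func consumes:

import numpy as np

box = np.random.rand(8, 3)  # toy box as (8, 3) corners -- an assumed format, not a fixed one
pred_all = {'scene0000_00': [(0, box, 0.9)], 'scene0001_00': [(0, box, 0.4)]}
gt_all = {'scene0000_00': [(0, box)], 'scene0001_00': [(0, box)]}
rec, prec, ap = eval_det_multiprocessing(pred_all, gt_all, ovthresh=0.25)
print(ap)  # {0: <scalar AP for class id 0>}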
Example #2
def get_loader(args):
    # Init datasets and dataloaders
    def my_worker_init_fn(worker_id):
        # Seed each dataloader worker differently (base numpy seed + worker id)
        # so workers do not draw identical random augmentations.
        np.random.seed(np.random.get_state()[1][0] + worker_id)

    # Create Dataset and Dataloader
    if args.dataset == 'sunrgbd':
        from sunrgbd.sunrgbd_detection_dataset import SunrgbdDetectionVotesDataset
        from sunrgbd.model_util_sunrgbd import SunrgbdDatasetConfig

        DATASET_CONFIG = SunrgbdDatasetConfig()
        TEST_DATASET = SunrgbdDetectionVotesDataset('val', num_points=args.num_point,
                                                    augment=False,
                                                    use_color=args.use_color,
                                                    use_height=args.use_height,
                                                    use_v1=(not args.use_sunrgbd_v2),
                                                    data_root=args.data_root)
    elif args.dataset == 'scannet':
        sys.path.append(os.path.join(ROOT_DIR, 'scannet'))
        from scannet.scannet_detection_dataset import ScannetDetectionDataset
        from scannet.model_util_scannet import ScannetDatasetConfig

        DATASET_CONFIG = ScannetDatasetConfig()
        TEST_DATASET = ScannetDetectionDataset('val', num_points=args.num_point,
                                               augment=False,
                                               use_color=args.use_color,
                                               use_height=args.use_height,
                                               data_root=args.data_root)
    else:
        raise NotImplementedError(f'Unknown dataset {args.dataset}. Exiting...')

    logger.info(f'Test dataset size: {len(TEST_DATASET)}')

    TEST_DATALOADER = DataLoader(TEST_DATASET, batch_size=args.batch_size * torch.cuda.device_count(),
                                 shuffle=args.shuffle_dataset,
                                 num_workers=4,
                                 worker_init_fn=my_worker_init_fn)
    return TEST_DATALOADER, DATASET_CONFIG
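
A hedged usage sketch; the field names below mirror the attributes get_loader reads, but the Namespace object and its values are hypothetical. Note that the effective batch size is args.batch_size * torch.cuda.device_count(), which evaluates to zero on a CPU-only machine, so this loader assumes at least one visible GPU:

from argparse import Namespace

args = Namespace(dataset='scannet', num_point=40000, batch_size=8,
                 use_color=False, use_height=True, use_sunrgbd_v2=False,
                 shuffle_dataset=False, data_root='data/scannet')
test_loader, dataset_config = get_loader(args)
for batch in test_loader:
    # each batch is whatever the dataset's __getitem__ returns, collated by PyTorch
    break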
Example #3
def gen_scannet_split(labeled_ratio, count):
    DC = ScannetDatasetConfig()
    split_set = 'train'
    split_filenames = os.path.join(ROOT_DIR, 'scannet/meta_data',
                                   'scannetv2_{}.txt'.format(split_set))
    with open(split_filenames, 'r') as f:
        scan_names = f.read().splitlines()
    # remove unavailable scans
    num_scans = len(scan_names)
    scan2label = np.zeros((num_scans, DC.num_class))
    num_labeled_scans = int(labeled_ratio * num_scans)
    data_path = os.path.join(BASE_DIR, 'scannet/scannet_train_detection_data')
    for i, scan_name in enumerate(scan_names):
        instance_bboxes = np.load(os.path.join(data_path, scan_name) + '_bbox.npy')
        class_ind = [DC.nyu40id2class[x] for x in instance_bboxes[:, -1]]
        if not class_ind:
            continue
        for j in set(class_ind):
            scan2label[i, j] = 1

    # Rejection sampling: redraw the labeled subset until every class is covered at least once.
    while True:
        choices = np.random.choice(num_scans, num_labeled_scans, replace=False)
        class_distr = np.sum(scan2label[choices], axis=0)
        class_mask = np.where(class_distr > 0, 1, 0)
        if np.sum(class_mask) == DC.num_class:
            labeled_scan_names = list(np.array(scan_names)[choices])
            with open(os.path.join(ROOT_DIR, 'scannet/meta_data/scannetv2_train_{}_{}.txt'.format(labeled_ratio, count)),
                      'w') as f:
                for scan_name in labeled_scan_names:
                    f.write(scan_name + '\n')
            break

    unlabeled_scan_names = list(set(scan_names) - set(labeled_scan_names))
    print('\tSelected {} labeled scans; {} unlabeled scans remain'.format(len(labeled_scan_names), len(unlabeled_scan_names)))
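
A short usage sketch; the seed is an added assumption for reproducibility, and the ratio/count values are arbitrary. Each call writes scannet/meta_data/scannetv2_train_<ratio>_<count>.txt, matching the filename pattern inside the function:

np.random.seed(0)  # assumption: fix the RNG so the sampled splits are reproducible
for count in range(3):
    gen_scannet_split(labeled_ratio=0.1, count=count)
# writes scannetv2_train_0.1_0.txt, scannetv2_train_0.1_1.txt, scannetv2_train_0.1_2.txt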
Example #4
def get_loader(args):
    # Init datasets and dataloaders
    def my_worker_init_fn(worker_id):
        np.random.seed(np.random.get_state()[1][0] + worker_id)

    # Create Dataset and Dataloader
    if args.dataset == 'sunrgbd':
        from sunrgbd.sunrgbd_detection_dataset import SunrgbdDetectionVotesDataset
        from sunrgbd.model_util_sunrgbd import SunrgbdDatasetConfig

        DATASET_CONFIG = SunrgbdDatasetConfig()
        TRAIN_DATASET = SunrgbdDetectionVotesDataset(
            'train',
            num_points=args.num_point,
            augment=True,
            use_color=args.use_color,
            use_height=args.use_height,
            use_v1=(not args.use_sunrgbd_v2),
            data_root=args.data_root)
        TEST_DATASET = SunrgbdDetectionVotesDataset(
            'val',
            num_points=args.num_point,
            augment=False,
            use_color=args.use_color,
            use_height=args.use_height,
            use_v1=(not args.use_sunrgbd_v2),
            data_root=args.data_root)
    elif args.dataset == 'scannet':
        sys.path.append(os.path.join(ROOT_DIR, 'scannet'))
        from scannet.scannet_detection_dataset import ScannetDetectionDataset
        from scannet.model_util_scannet import ScannetDatasetConfig

        DATASET_CONFIG = ScannetDatasetConfig()
        TRAIN_DATASET = ScannetDetectionDataset(
            'train',
            num_points=args.num_point,
            augment=True,
            use_color=args.use_color,
            use_height=args.use_height,
            data_root=args.data_root)
        TEST_DATASET = ScannetDetectionDataset(
            'val',
            num_points=args.num_point,
            augment=False,
            use_color=args.use_color,
            use_height=args.use_height,
            data_root=args.data_root)
    else:
        raise NotImplementedError(
            f'Unknown dataset {args.dataset}. Exiting...')

    print(f"train_len: {len(TRAIN_DATASET)}, test_len: {len(TEST_DATASET)}")

    # Under DDP, shuffling is handled by the DistributedSampler, so the
    # DataLoaders below are created with shuffle=False.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        TRAIN_DATASET)
    train_loader = torch.utils.data.DataLoader(
        TRAIN_DATASET,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        worker_init_fn=my_worker_init_fn,
        pin_memory=True,
        sampler=train_sampler,
        drop_last=True)

    test_sampler = torch.utils.data.distributed.DistributedSampler(
        TEST_DATASET, shuffle=False)
    test_loader = torch.utils.data.DataLoader(TEST_DATASET,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.num_workers,
                                              worker_init_fn=my_worker_init_fn,
                                              pin_memory=True,
                                              sampler=test_sampler,
                                              drop_last=False)
    print(
        f"train_loader_len: {len(train_loader)}, test_loader_len: {len(test_loader)}"
    )

    return train_loader, test_loader, DATASET_CONFIG
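
One caveat when consuming these loaders: a DistributedSampler only reshuffles across epochs if it is told the current epoch, so the training loop should call set_epoch (a standard method of torch.utils.data.distributed.DistributedSampler) each pass. A minimal sketch, assuming the usual DDP process-group setup has already run and that args and num_epochs are defined by the caller:

train_loader, test_loader, dataset_config = get_loader(args)
for epoch in range(num_epochs):
    # without set_epoch, every epoch iterates the same shuffled order
    train_loader.sampler.set_epoch(epoch)
    for batch in train_loader:
        pass  # forward/backward step goes here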
Example #5
        use_height=(not FLAGS.no_height),
        use_v1=(not FLAGS.use_sunrgbd_v2),
        load_labels=FLAGS.view_stats)
    TEST_DATASET = SunrgbdDetectionVotesDataset(
        'val',
        num_points=NUM_POINT,
        augment=False,
        use_color=FLAGS.use_color,
        use_height=(not FLAGS.no_height),
        use_v1=(not FLAGS.use_sunrgbd_v2))
elif FLAGS.dataset == 'scannet':
    sys.path.append(os.path.join(ROOT_DIR, 'scannet'))
    from scannet.scannet_detection_dataset import ScannetDetectionDataset
    from scannet.scannet_ssl_dataset import ScannetSSLLabeledDataset, ScannetSSLUnlabeledDataset
    from scannet.model_util_scannet import ScannetDatasetConfig
    DATASET_CONFIG = ScannetDatasetConfig()
    LABELED_DATASET = ScannetSSLLabeledDataset(
        labeled_sample_list=FLAGS.labeled_sample_list,
        num_points=NUM_POINT,
        augment=True,
        use_color=FLAGS.use_color,
        use_height=(not FLAGS.no_height))
    UNLABELED_DATASET = ScannetSSLUnlabeledDataset(
        labeled_sample_list=FLAGS.labeled_sample_list,
        num_points=NUM_POINT,
        use_color=FLAGS.use_color,
        use_height=(not FLAGS.no_height),
        load_labels=FLAGS.view_stats)
    TEST_DATASET = ScannetDetectionDataset('val',
                                           num_points=NUM_POINT,
                                           augment=False,
                                           use_color=FLAGS.use_color,
                                           use_height=(not FLAGS.no_height))
Example #6
Modified by Yezhen Cong, 2020
"""
import os
import sys
import random
import numpy as np
from torch.utils.data import Dataset

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(BASE_DIR)
sys.path.append(ROOT_DIR)
sys.path.append(os.path.join(ROOT_DIR, 'utils'))
import utils.pc_util as pc_util
from scannet.model_util_scannet import ScannetDatasetConfig, rotate_aligned_boxes

DC = ScannetDatasetConfig()
MAX_NUM_OBJ = 64
MEAN_COLOR_RGB = np.array([109.8, 97.2, 83.8])


class ScannetSSLLabeledDataset(Dataset):
    def __init__(self,
                 labeled_sample_list=None,
                 num_points=20000,
                 use_color=False,
                 use_height=False,
                 augment=False):

        print('--------- Scannet Labeled Dataset Initialization ---------')
        self.data_path = os.path.join(BASE_DIR, 'scannet_train_detection_data')
        if labeled_sample_list is not None: