コード例 #1
0
def get_dataset(dataset, args):
    """Build the training/validation datasets and VOC07 detection metrics.

    Parameters
    ----------
    dataset : str
        Dataset name, matched case-insensitively: ``'voc'`` or
        ``'coco_pretrain'``.
    args : object
        Options namespace. Reads ``deg`` (used as the split-name suffix),
        ``val_2012`` (truthy selects the ``val_2012_bboxwh`` validation
        split) and ``num_samples``.

    Returns
    -------
    tuple
        ``(train_dataset, val_dataset, val_metric, val_polygon_metric)``.

    Raises
    ------
    NotImplementedError
        If ``dataset`` is not one of the supported names.

    Notes
    -----
    If ``args.num_samples`` is negative it is overwritten in place with
    ``len(train_dataset)``.
    """
    name = dataset.lower()
    if name == 'voc':
        train_dataset = gdata.VOCDetection(
            splits=[('sbdche', 'train_{}_bboxwh'.format(args.deg))])
    elif name == 'coco_pretrain':
        train_dataset = gdata.coco_pretrain_Detection(
            splits=[('_coco_20', 'train_{}_bboxwh'.format(args.deg))])
    else:
        raise NotImplementedError(
            'Dataset: {} not implemented.'.format(dataset))

    # Both supported training sets are validated on the same 'sbdche'
    # split with the same metrics, so they are built once here.
    if args.val_2012:
        val_split = 'val_2012_bboxwh'
    else:
        val_split = 'val_{}_bboxwh'.format(args.deg)
    val_dataset = gdata.VOC_Val_Detection(splits=[('sbdche', val_split)])
    val_metric = VOC07MApMetric(iou_thresh=0.5,
                                class_names=val_dataset.classes)
    val_polygon_metric = VOC07PolygonMApMetric(
        iou_thresh=0.5, class_names=val_dataset.classes)

    if args.num_samples < 0:
        args.num_samples = len(train_dataset)
    return train_dataset, val_dataset, val_metric, val_polygon_metric
コード例 #2
0
ファイル: coco_eval.py プロジェクト: lixiny/ese_seg
def get_dataset(dataset, args):
    """Build the validation dataset and VOC07 metrics for ``dataset``.

    ``dataset`` is matched case-insensitively. For ``'voc'``, a truthy
    ``args.val_voc2012`` selects the ``val_2012_bboxwh`` split; otherwise
    the hard-coded ``val_8_bboxwh`` split is used.

    NOTE(review): this snippet is truncated after the ``'coco'`` branch
    header; the rest of the body (the coco branch, any error handling and
    the return statement) is not visible here.
    """
    if dataset.lower() == 'voc':
        if args.val_voc2012:
            val_dataset = gdata.VOC_Val_Detection(
            splits=[('sbdche', 'val_2012_bboxwh')])
        else:
            val_dataset = gdata.VOC_Val_Detection(
                splits=[('sbdche', 'val'+'_'+'8'+'_bboxwh')])
        # Box mAP and polygon mAP, both at IoU threshold 0.5.
        val_metric = VOC07MApMetric(iou_thresh=0.5, class_names=val_dataset.classes)
        val_polygon_metric = VOC07PolygonMApMetric(iou_thresh=0.5, class_names=val_dataset.classes)
    elif dataset.lower() == 'coco':
        # NOTE(review): branch body is missing from this listing.
コード例 #3
0
ファイル: sbd_train_che_8_1.py プロジェクト: lixiny/ese_seg
def get_dataset(dataset, args):
    """Build training/validation datasets and detection metrics.

    Parameters
    ----------
    dataset : str
        Dataset name, matched case-insensitively: ``'voc'``,
        ``'coco_pretrain'`` or ``'coco'``.
    args : object
        Options namespace. Reads ``val_2012`` (for ``'voc'`` and
        ``'coco_pretrain'``), ``num_samples`` and ``mixup``.

    Returns
    -------
    tuple
        ``(train_dataset, val_dataset, val_metric, val_polygon_metric)``;
        ``val_polygon_metric`` is ``None`` for ``'coco'``.

    Raises
    ------
    NotImplementedError
        If ``dataset`` is not one of the supported names.

    Notes
    -----
    If ``args.num_samples`` is negative it is overwritten in place with
    ``len(train_dataset)``, measured before any mixup wrapping. When
    ``args.mixup`` is truthy the training set is wrapped in
    ``gluoncv.data.MixupDetection``.
    """
    name = dataset.lower()
    if name in ('voc', 'coco_pretrain'):
        # The '8' split-name suffix is hard-coded in this script.
        if name == 'voc':
            # With val_2012 the VOC2012 training list is used, paired with
            # the 2012 validation split below.
            train_split = ('train_voc2012_bboxwh' if args.val_2012
                           else 'train_8_bboxwh')
            train_dataset = gdata.VOCDetection(
                splits=[('sbdche', train_split)])
        else:
            train_dataset = gdata.coco_pretrain_Detection(
                splits=[('_coco_20', 'train_8_bboxwh')])
        # Both branches validate on the same 'sbdche' split and metrics.
        val_split = 'val_2012_bboxwh' if args.val_2012 else 'val_8_bboxwh'
        val_dataset = gdata.VOC_Val_Detection(splits=[('sbdche', val_split)])
        val_metric = VOC07MApMetric(iou_thresh=0.7,
                                    class_names=val_dataset.classes)
        val_polygon_metric = VOC07PolygonMApMetric(
            iou_thresh=0.7, class_names=val_dataset.classes)
    elif name == 'coco':
        train_dataset = gdata.cocoDetection(
            root='/home/tutian/dataset/coco_to_voc/train',
            subfolder='./bases_50_xml_each_var')
        val_dataset = gdata.cocoDetection(
            root='/home/tutian/dataset/coco_to_voc/val',
            subfolder='./bases_50_xml_raw_coef')
        val_metric = VOC07MApMetric(iou_thresh=0.5,
                                    class_names=val_dataset.classes)
        # Polygon mAP is disabled for this dataset. Alternative (currently
        # unused): New07PolygonMApMetric(iou_thresh=0.5,
        #     class_names=val_dataset.classes,
        #     root='/home/tutian/dataset/coco_to_voc/val/')
        val_polygon_metric = None
    else:
        raise NotImplementedError(
            'Dataset: {} not implemented.'.format(dataset))
    if args.num_samples < 0:
        args.num_samples = len(train_dataset)
    if args.mixup:
        from gluoncv.data import MixupDetection
        train_dataset = MixupDetection(train_dataset)
    return train_dataset, val_dataset, val_metric, val_polygon_metric
コード例 #4
0
def get_dataset(dataset, args):
    """Create the validation dataset plus box and polygon VOC07 mAP metrics.

    Parameters
    ----------
    dataset : str
        Dataset name, matched case-insensitively; only ``'voc'`` is handled.
    args : object
        Options namespace; ``val_voc2012`` and ``deg`` are read.

    Returns
    -------
    tuple
        ``(val_dataset, val_metric, val_polygon_metric)``.

    Raises
    ------
    NotImplementedError
        If ``dataset`` is not ``'voc'``.
    """
    if dataset.lower() != 'voc':
        raise NotImplementedError('Dataset: {} not implemented.'.format(dataset))

    # A truthy val_voc2012 picks the 2012 split; otherwise the split name
    # is suffixed with args.deg.
    split_name = ('val_2012_bboxwh' if args.val_voc2012
                  else '_'.join(('val', str(args.deg), 'bboxwh')))
    val_dataset = gdata.VOC_Val_Detection(splits=[('sbdche', split_name)])

    iou = 0.5
    val_metric = VOC07MApMetric(iou_thresh=iou, class_names=val_dataset.classes)
    val_polygon_metric = VOC07PolygonMApMetric(iou_thresh=iou,
                                               class_names=val_dataset.classes)
    return val_dataset, val_metric, val_polygon_metric
コード例 #5
0
ファイル: sbd_eval_che_8.py プロジェクト: lixiny/ese_seg
def get_dataset(dataset, args):
    """Create the validation dataset and its box/polygon mAP metrics.

    Parameters
    ----------
    dataset : str
        ``'voc'`` or ``'coco'``, matched case-insensitively.
    args : object
        Options namespace; ``val_voc2012`` is read for ``'voc'``.

    Returns
    -------
    tuple
        ``(val_dataset, val_metric, val_polygon_metric)``.

    Raises
    ------
    NotImplementedError
        For any other dataset name.
    """
    kind = dataset.lower()
    if kind == 'voc':
        if args.val_voc2012:
            split = 'val_2012_bboxwh'
        else:
            split = 'val_8_bboxwh'
        val_dataset = gdata.VOC_Val_Detection(splits=[('sbdche', split)])
        classes = val_dataset.classes
        val_metric = VOC07MApMetric(iou_thresh=0.5, class_names=classes)
        val_polygon_metric = VOC07PolygonMApMetric(iou_thresh=0.5,
                                                   class_names=classes)
    elif kind == 'coco':
        coco_val_root = '/home/tutian/dataset/coco_to_voc/val'
        val_dataset = gdata.cocoDetection(root=coco_val_root,
                                          subfolder='./bases_50_xml_raw_coef')
        classes = val_dataset.classes
        # The COCO branch uses a stricter IoU threshold of 0.75.
        val_metric = VOC07MApMetric(iou_thresh=0.75, class_names=classes)
        val_polygon_metric = New07PolygonMApMetric(iou_thresh=0.75,
                                                   class_names=classes,
                                                   root=coco_val_root + '/')
    else:
        raise NotImplementedError('Dataset: {} not implemented.'.format(dataset))
    return val_dataset, val_metric, val_polygon_metric
コード例 #6
0
def get_dataset(dataset, args):
    """Create the validation dataset and evaluation metrics.

    Parameters
    ----------
    dataset : str
        ``'voc'`` or ``'coco'``, matched case-insensitively.
    args : object
        Options namespace; ``val_voc2012`` is read for ``'voc'``.

    Returns
    -------
    tuple
        ``(val_dataset, val_metric, val_polygon_metric)``; the polygon
        metric is ``None`` for ``'coco'``.

    Raises
    ------
    NotImplementedError
        For any other dataset name.
    """
    key = dataset.lower()

    if key == 'coco':
        coco_val = COCOInstance(root='/home/tutian/dataset/', skip_empty=False)
        coco_metric = COCOInstanceMetric(coco_val, 'test_cocoapi', method='var')
        return coco_val, coco_metric, None

    if key != 'voc':
        raise NotImplementedError('Dataset: {} not implemented.'.format(dataset))

    chosen_split = 'val_2012_bboxwh' if args.val_voc2012 else 'val_8_bboxwh'
    voc_val = gdata.VOC_Val_Detection(splits=[('sbdche', chosen_split)])
    box_metric = VOC07MApMetric(iou_thresh=0.5, class_names=voc_val.classes)
    poly_metric = VOC07PolygonMApMetric(iou_thresh=0.5,
                                        class_names=voc_val.classes)
    return voc_val, box_metric, poly_metric