def get_dataset(dataset, args):
    """Create the training set, validation set and validation metric.

    Parameters
    ----------
    dataset : str
        Dataset name, compared case-insensitively with ``'polar'``,
        ``'voca'``, ``'voc'`` or ``'coco'``.
    args : object
        Option namespace; only ``args.save_prefix`` is read, and only for
        ``'coco'``.

    Returns
    -------
    tuple
        ``(train_dataset, val_dataset, val_metric)``.

    Raises
    ------
    NotImplementedError
        If *dataset* is not one of the names listed above.
    """
    name = dataset.lower()

    if name == 'polar':
        train_set = gdata.POLARDetection(split='train')
        val_set = gdata.POLARDetection(split='val')
        return train_set, val_set, VOCMApMetric(
            iou_thresh=0.5, class_names=val_set.classes)

    if name == 'voca':
        train_set = gdata.VOCAction(split='train')
        val_set = gdata.VOCAction(split='val')
        return train_set, val_set, VOCMApMetric(
            iou_thresh=0.5, class_names=val_set.classes)

    if name == 'voc':
        # Training uses the 2007 and 2012 'trainval' splits; validation uses
        # the 2007 'test' split and the VOC07 variant of the mAP metric.
        train_set = gdata.VOCDetection(
            splits=[(2007, 'trainval'), (2012, 'trainval')])
        val_set = gdata.VOCDetection(splits=[(2007, 'test')])
        return train_set, val_set, VOC07MApMetric(
            iou_thresh=0.5, class_names=val_set.classes)

    if name == 'coco':
        train_set = gdata.COCODetection(
            splits='instances_train2017', use_crowd=False)
        val_set = gdata.COCODetection(
            splits='instances_val2017', skip_empty=False)
        # The metric receives the validation set itself plus a path-like
        # prefix derived from args.save_prefix.
        return train_set, val_set, COCODetectionMetric(
            val_set, args.save_prefix + '_eval', cleanup=True)

    raise NotImplementedError('Dataset: {} not implemented.'.format(dataset))
def get_dataset(dataset, args):
    """Create the VOC action train/val sets and their multi-class metric.

    Parameters
    ----------
    dataset : str
        Dataset name; only ``'voca'`` (case-insensitive) is accepted.
    args : object
        Not used by this function.

    Returns
    -------
    tuple
        ``(train_dataset, val_dataset, val_metric)``.

    Raises
    ------
    NotImplementedError
        If *dataset* is anything other than ``'voca'``.
    """
    if dataset.lower() != 'voca':
        raise NotImplementedError('Dataset: {} not implemented.'.format(dataset))

    # Both splits load boxes; only the training split enables box augmentation.
    train_set = gdata.VOCAction(split='train', augment_box=True, load_box=True)
    val_set = gdata.VOCAction(split='val', load_box=True)
    # NOTE(review): ignore_label=-1 presumably marks labels the metric should
    # skip -- confirm against VOCMultiClsMApMetric.
    metric = VOCMultiClsMApMetric(
        class_names=val_set.classes, ignore_label=-1, voc_action_type=True)
    return train_set, val_set, metric
def get_dataset(dataset, args):
    """Create the evaluation dataset for *dataset*.

    Parameters
    ----------
    dataset : str
        Dataset name, compared case-insensitively with ``'voca'``,
        ``'st40'`` or ``'hico'``.
    args : object
        Not used by this function.

    Returns
    -------
    object
        The constructed dataset (the ``'test'`` split for voca/st40, the
        ``'all'`` split for hico).

    Raises
    ------
    NotImplementedError
        If *dataset* is not one of the names listed above.
    """
    # Builders are zero-argument callables so only the selected dataset is
    # actually constructed.
    builders = {
        'voca': lambda: gdata.VOCAction(split='test'),
        'st40': lambda: gdata.Stanford40Action(split='test'),
        'hico': lambda: gdata.HICOClassification(
            split='all', preload_label=False),
    }
    build = builders.get(dataset.lower())
    if build is None:
        raise NotImplementedError('Dataset: {} not implemented.'.format(dataset))
    return build()