def create_voc_labelmap(devkit_dir, save_dir, pascal_year='2012'):
    """
    Creates and saves the Pascal VOC labelmap.

    Category names are derived from the per-class ``<category>_val.txt``
    files found in ``<devkit_dir>/VOC<year>/ImageSets/Main``. The categories
    are sorted alphabetically before ids are assigned (starting at 1), so the
    resulting labelmap does not depend on the order in which the filesystem
    happens to list the files.

    :param devkit_dir: Root folder containing the Pascal VOC dataset.
    :param save_dir: Folder where the labelmap should be stored.
    :param pascal_year: Complete year ( YYYY ) of the VOC dataset.
    :raises OSError: If the devkit folder or its ImageSets/Main folder does
        not exist, or if no ``*_val.txt`` category files are found there.
    :return: None
    """
    if not os.path.isdir(devkit_dir):
        raise OSError('The Pascal VOC Devkit folder {} was not found.'.format(
            devkit_dir))

    search_folder = os.path.join(devkit_dir, 'VOC{}'.format(pascal_year),
                                 'ImageSets', 'Main')
    if not os.path.isdir(search_folder):
        raise OSError(
            'The Pascal VOC ImageSets folder {} was not found.'.format(
                search_folder))

    suffix = '_val.txt'
    # glob.glob returns paths in arbitrary, filesystem-dependent order; sort
    # so the category -> id assignment is deterministic across machines.
    # glob.escape protects against glob metacharacters in the folder path.
    pattern = os.path.join(glob.escape(search_folder), '*' + suffix)
    categories = sorted(
        os.path.basename(path)[:-len(suffix)] for path in glob.glob(pattern))
    if not categories:
        # Without this check an empty labelmap would be written silently.
        raise OSError('No category files matching *{} were found in {}.'.format(
            suffix, search_folder))

    items = [
        labelmap_pb2.LabelMapItem(name=category, display_name=category,
                                  id=index)
        for index, category in enumerate(categories, start=1)
    ]
    labelmap = labelmap_pb2.LabelMap(item=items)
    labelmap_filename = os.path.join(
        save_dir, 'VOC_{}_labelmap.pbtxt'.format(pascal_year))
    with open(labelmap_filename, 'w') as fid:
        fid.write(str(labelmap))
    return None
def create_openimages_labelmap(openimages_base_dir, save_dir):
    """
    Creates and saves the OpenImages (v5) segmentation labelmap.

    Reads ``class-descriptions-boxable.csv`` (``<class_name>,<display_name>``
    rows) and ``classes-segmentation.txt`` (one class name per line) from
    ``openimages_base_dir`` and writes ``openimages_v5_labelmap.pbtxt`` into
    ``save_dir``. Ids are assigned starting at 1 in the order the classes
    appear in ``classes-segmentation.txt``. Blank lines in either file are
    ignored.

    :param openimages_base_dir: Folder containing the OpenImages class files.
    :param save_dir: Folder where the labelmap should be stored.
    :raises OSError: If ``openimages_base_dir`` does not exist or a class file
        cannot be opened.
    :raises ValueError: If a class-description row does not have exactly two
        fields.
    :raises KeyError: If a segmentation class has no entry in the class
        descriptions file.
    :return: None
    """
    import csv  # local import keeps this block self-contained

    if not os.path.isdir(openimages_base_dir):
        raise OSError('The OpenImages data folder {} was not found.'.format(
            openimages_base_dir))

    class_descriptions_file = os.path.join(openimages_base_dir,
                                           'class-descriptions-boxable.csv')
    segmentation_classes_file = os.path.join(openimages_base_dir,
                                             'classes-segmentation.txt')

    # Parse with the csv module so quoted display names that contain commas
    # are read correctly; files are closed via `with` (previously leaked).
    description_dict = dict()
    with open(class_descriptions_file, 'r', encoding='utf-8',
              newline='') as fid:
        for row in csv.reader(fid):
            if not any(field.strip() for field in row):
                continue  # blank line, e.g. a trailing newline
            class_name, display_name = (field.strip() for field in row)
            description_dict[class_name] = display_name

    with open(segmentation_classes_file, 'r', encoding='utf-8') as fid:
        segmentation_classes = [line.strip() for line in fid if line.strip()]

    items = list()
    for index, segmentation_class in enumerate(segmentation_classes, start=1):
        try:
            display_name = description_dict[segmentation_class]
        except KeyError:
            raise KeyError(
                'No display name for segmentation class {} in {}.'.format(
                    segmentation_class, class_descriptions_file)) from None
        items.append(
            labelmap_pb2.LabelMapItem(name=segmentation_class,
                                      display_name=display_name,
                                      id=index))

    labelmap = labelmap_pb2.LabelMap(item=items)
    labelmap_filename = os.path.join(save_dir, 'openimages_v5_labelmap.pbtxt')
    with open(labelmap_filename, 'w') as fid:
        fid.write(str(labelmap))
    return None
def create_coco_labelmap(coco_data_dir, save_dir, coco_year='2017'):
    """
    Builds the MSCOCO labelmap from the training annotations and saves it.

    Every category returned by ``loadCats`` for the ``train`` annotation file
    becomes one labelmap item whose name and display name are the category
    name. Ids are consecutive integers starting at 1, in the order the
    categories are returned.

    :param coco_data_dir: Root folder containing the MSCOCO dataset.
    :param save_dir: Folder where the labelmap should be stored.
    :param coco_year: Complete year ( YYYY ) of the MSCOCO dataset.
    :raises OSError: If ``coco_data_dir`` is not an existing folder.
    :return: None
    """
    data_dir_exists = os.path.isdir(coco_data_dir)
    if not data_dir_exists:
        message = 'The COCO data folder {} was not found.'.format(coco_data_dir)
        raise OSError(message)

    train_annotations = get_coco_annotation_file(coco_data_dir, coco_year,
                                                 'train')
    coco_api = get_coco_object(train_annotations)
    category_records = coco_api.loadCats(coco_api.getCatIds())
    category_names = [record['name'] for record in category_records]

    label_items = [
        labelmap_pb2.LabelMapItem(name=category_name,
                                  display_name=category_name,
                                  id=position)
        for position, category_name in enumerate(category_names, start=1)
    ]
    labelmap = labelmap_pb2.LabelMap(item=label_items)

    output_path = os.path.join(save_dir,
                               'coco_{}_labelmap.pbtxt'.format(coco_year))
    with open(output_path, 'w') as out_file:
        out_file.write(str(labelmap))
    return None
def create_lvis_labelmap(lvis_annotation_dir, save_dir, lvis_version='0.5'):
    """
    Builds the LVIS labelmap from the training annotations and saves it.

    Each entry of the LVIS object's ``cats`` mapping becomes one labelmap
    item. The item id is the mapping key itself (not renumbered), and both
    the name and display name are the category's ``'name'`` field.

    :param lvis_annotation_dir: Location handed to ``get_lvis_annotation_file``
        to find the ``train`` annotations (presumably a folder — confirm).
    :param save_dir: Folder where the labelmap should be stored.
    :param lvis_version: LVIS version string; selects the annotation file and
        is embedded in the output file name.
    :return: None
    """
    train_annotation_file = get_lvis_annotation_file(lvis_annotation_dir,
                                                     'train', lvis_version)
    lvis_api = get_lvis_object(train_annotation_file)

    label_items = []
    for category_id, category in lvis_api.cats.items():
        category_name = category['name']
        label_items.append(
            labelmap_pb2.LabelMapItem(name=category_name,
                                      display_name=category_name,
                                      id=category_id))

    output_path = os.path.join(
        save_dir, 'lvis_v{}_labelmap.pbtxt'.format(lvis_version))
    labelmap = labelmap_pb2.LabelMap(item=label_items)
    with open(output_path, 'w') as out_file:
        out_file.write(str(labelmap))
    return None
help='File containing class information.') parser.add_argument('--save_name', required=True, type=str, help='Full path to the filename to be stored as labelmap') if __name__ == "__main__": args = parser.parse_args() if not os.path.exists(args['class_def_file']): raise OSError('The class definition file {} was not found.'.format( args['class_def_file'])) items = list() for line in open(args['class_def_file'], 'r'): line = line.strip() line = line.split(',') name = line[0] display_name = line[1] id = line[2] items.append( labelmap_pb2.LabelMapItem(name=name, display_name=display_name, id=id)) labelmap = labelmap_pb2.LabelMap(item=items) with open(args['save_name'], 'w') as fid: fid.write(str(labelmap))