示例#1
0
def create_voc_labelmap(devkit_dir, save_dir, pascal_year='2012'):
    """
    Creates and saves the Pascal VOC labelmap.

    Category names are derived from the ``<category>_val.txt`` files found in
    ``<devkit_dir>/VOC<pascal_year>/ImageSets/Main``. The names are sorted
    before ids are assigned (starting at 1), so the category -> id mapping is
    the same on every machine regardless of directory-listing order.

    The labelmap is written to ``<save_dir>/VOC_<pascal_year>_labelmap.pbtxt``.

    :param devkit_dir: Root folder containing the Pascal VOC dataset.
    :param save_dir: Folder where the labelmap should be stored.
    :param pascal_year: Complete year ( YYYY ) of the VOC dataset.
    :return: None
    :raises OSError: if ``devkit_dir`` is not a directory.
    """
    if not os.path.isdir(devkit_dir):
        raise OSError('The Pascal VOC Devkit folder {} was not found.'.format(
            devkit_dir))

    search_folder = os.path.join(devkit_dir, 'VOC{}'.format(pascal_year),
                                 'ImageSets', 'Main')
    suffix = '_val.txt'
    # glob.glob returns matches in arbitrary, filesystem-dependent order.
    # Sorting the stripped category names makes the ids reproducible.
    # (The stripped names are sorted rather than the file names, because the
    # shared suffix can change the relative order.)
    categories = sorted(
        os.path.basename(path)[:-len(suffix)]
        for path in glob.glob(os.path.join(search_folder, '*' + suffix)))

    items = [
        labelmap_pb2.LabelMapItem(name=category,
                                  display_name=category,
                                  id=index)
        for index, category in enumerate(categories, start=1)
    ]

    labelmap = labelmap_pb2.LabelMap(item=items)
    labelmap_filename = os.path.join(
        save_dir, 'VOC_{}_labelmap.pbtxt'.format(pascal_year))
    with open(labelmap_filename, 'w') as fid:
        fid.write(str(labelmap))

    return None
示例#2
0
def create_openimages_labelmap(openimages_base_dir, save_dir):
    """
    Creates and saves the Open Images (v5) segmentation labelmap.

    Reads two files from ``openimages_base_dir``:

    * ``class-descriptions-boxable.csv``, with rows of the form
      ``<class id>,<display name>``;
    * ``classes-segmentation.txt``, with one class id per line.

    Every segmentation class becomes one item. Its id is the class's 1-based
    position among the non-blank lines of ``classes-segmentation.txt``.

    The labelmap is written to ``<save_dir>/openimages_v5_labelmap.pbtxt``.

    :param openimages_base_dir: Folder containing the two class files above.
    :param save_dir: Folder where the labelmap should be stored.
    :return: None
    :raises OSError: if ``openimages_base_dir`` is not a directory, or a class
        file cannot be opened.
    :raises ValueError: if a non-blank CSV row does not have exactly two
        columns.
    :raises KeyError: if a segmentation class has no entry in the class
        descriptions file.
    """
    import csv

    if not os.path.isdir(openimages_base_dir):
        raise OSError('The Open Images folder {} was not found.'.format(
            openimages_base_dir))

    class_descriptions_file = os.path.join(openimages_base_dir,
                                           'class-descriptions-boxable.csv')
    segmentation_classes_file = os.path.join(openimages_base_dir,
                                             'classes-segmentation.txt')

    # csv.reader handles quoted display names correctly. The `with` blocks
    # make sure both files are closed.
    description_dict = dict()
    with open(class_descriptions_file, 'r', encoding='utf-8',
              newline='') as fid:
        for row in csv.reader(fid):
            if not row:
                continue  # blank line
            class_name, display_name = (field.strip() for field in row)
            description_dict[class_name] = display_name

    with open(segmentation_classes_file, 'r', encoding='utf-8') as fid:
        # Skip blank lines (e.g. a trailing newline). Otherwise they would be
        # looked up as '' and fail with KeyError.
        segmentation_classes = [line.strip() for line in fid if line.strip()]

    items = [
        labelmap_pb2.LabelMapItem(name=segmentation_class,
                                  display_name=description_dict[
                                      segmentation_class],
                                  id=index)
        for index, segmentation_class in enumerate(segmentation_classes,
                                                   start=1)
    ]

    labelmap = labelmap_pb2.LabelMap(item=items)
    labelmap_filename = os.path.join(save_dir, 'openimages_v5_labelmap.pbtxt')
    with open(labelmap_filename, 'w', encoding='utf-8') as fid:
        fid.write(str(labelmap))

    return None
示例#3
0
def create_coco_labelmap(coco_data_dir, save_dir, coco_year='2017'):
    """
    Builds the MSCOCO labelmap and writes it to disk.

    Category names come from the ``train`` annotation file of the requested
    year, in the order returned by ``coco.loadCats(coco.getCatIds())``. Item
    ids are assigned consecutively from 1 in that order; they are not read
    from the category records.

    The labelmap is written to ``<save_dir>/coco_<coco_year>_labelmap.pbtxt``.

    :param coco_data_dir: Root folder containing the MSCOCO dataset.
    :param save_dir: Folder where the labelmap should be stored.
    :param coco_year: Complete year ( YYYY ) of the MSCOCO dataset.
    :return: None
    :raises OSError: if ``coco_data_dir`` is not a directory.
    """
    if not os.path.isdir(coco_data_dir):
        raise OSError(
            'The COCO data folder {} was not found.'.format(coco_data_dir))

    train_annotations = get_coco_annotation_file(coco_data_dir, coco_year,
                                                 'train')
    coco_api = get_coco_object(train_annotations)
    category_names = [
        record['name']
        for record in coco_api.loadCats(coco_api.getCatIds())
    ]

    label_items = [
        labelmap_pb2.LabelMapItem(name=cat_name,
                                  display_name=cat_name,
                                  id=position)
        for position, cat_name in enumerate(category_names, start=1)
    ]
    labelmap = labelmap_pb2.LabelMap(item=label_items)

    output_path = os.path.join(save_dir,
                               'coco_{}_labelmap.pbtxt'.format(coco_year))
    with open(output_path, 'w') as out_file:
        out_file.write(str(labelmap))

    return None
示例#4
0
def load_labelmap(labelmapfile):
    """
    Loads a labelmap from a text-format protobuf file.

    :param labelmapfile: Full path to the labelmap text file. Any path that
        ``tf.io.gfile`` understands is accepted, not just local paths.
    :return: A message of type labelmap_pb2.LabelMap()
    :raises OSError: if the labelmap text file was not found.
    :raises google.protobuf.text_format.ParseError: if the file is not valid
        text-format for ``labelmap_pb2.LabelMap``.
    :raises ValueError: if the labelmap is not valid, as reported by
        ``_validate_label_map`` (assumed from the original contract).
    """
    # The file is read through tf.io.gfile, so check existence through it too.
    # os.path.exists would reject paths GFile can read, such as
    # remote-filesystem URLs.
    if not tf.io.gfile.exists(labelmapfile):
        raise OSError(
            'The labelmap file {} was not found.'.format(labelmapfile))

    with tf.io.gfile.GFile(labelmapfile) as fid:
        labelmap_string = fid.read()

    # The file is closed before parsing; parsing only needs the string.
    labelmap = labelmap_pb2.LabelMap()
    text_format.Merge(labelmap_string, labelmap)

    _validate_label_map(labelmap)
    return labelmap
示例#5
0
def create_lvis_labelmap(lvis_annotation_dir, save_dir, lvis_version='0.5'):
    """
    Builds the LVIS labelmap and writes it to disk.

    Categories are taken from the ``train`` annotation file of the given LVIS
    version via ``get_lvis_object(...).cats``.

    * Each item keeps the category's own id (the key in ``cats``).
    * The category ``'name'`` is used as both ``name`` and ``display_name``.
    * Items follow the iteration order of ``cats``.

    The labelmap is written to ``<save_dir>/lvis_v<lvis_version>_labelmap.pbtxt``.

    :param lvis_annotation_dir: Directory passed to get_lvis_annotation_file
        to locate the train annotation file.
    :param save_dir: Folder where the labelmap should be stored.
    :param lvis_version: LVIS dataset version string (default '0.5').
    :return: None
    """
    train_annotations = get_lvis_annotation_file(lvis_annotation_dir, 'train',
                                                 lvis_version)
    categories_by_id = get_lvis_object(train_annotations).cats

    label_items = []
    for category_id, category in categories_by_id.items():
        category_name = category['name']
        label_items.append(
            labelmap_pb2.LabelMapItem(name=category_name,
                                      display_name=category_name,
                                      id=category_id))
    labelmap = labelmap_pb2.LabelMap(item=label_items)

    output_path = os.path.join(
        save_dir, 'lvis_v{}_labelmap.pbtxt'.format(lvis_version))
    with open(output_path, 'w') as out_file:
        out_file.write(str(labelmap))

    return None
示例#6
0
                    help='File containing class information.')
parser.add_argument('--save_name',
                    required=True,
                    type=str,
                    help='Full path to the filename to be stored as labelmap')

if __name__ == "__main__":
    # Builds a labelmap from a class-definition file in which every non-blank
    # line is "<name>,<display_name>,<id>", and writes it to --save_name.
    args = parser.parse_args()

    # parse_args() returns an argparse.Namespace, which only supports
    # attribute access. args['...'] would raise TypeError.
    class_def_file = args.class_def_file
    if not os.path.exists(class_def_file):
        raise OSError('The class definition file {} was not found.'.format(
            class_def_file))

    items = list()

    with open(class_def_file, 'r') as class_fid:
        for line_number, line in enumerate(class_fid, start=1):
            line = line.strip()
            if not line:
                continue  # tolerate blank / trailing lines
            fields = [field.strip() for field in line.split(',')]
            if len(fields) < 3:
                raise ValueError(
                    'Malformed line {} in {}: expected '
                    '"name,display_name,id", got {!r}'.format(
                        line_number, class_def_file, line))
            name, display_name, label_id = fields[0], fields[1], fields[2]
            # The id read from the file is a string. The other labelmap
            # builders pass integer ids, so convert it (assumes the proto
            # `id` field is an integer type -- confirm against labelmap.proto).
            items.append(
                labelmap_pb2.LabelMapItem(name=name,
                                          display_name=display_name,
                                          id=int(label_id)))

    labelmap = labelmap_pb2.LabelMap(item=items)

    with open(args.save_name, 'w') as fid:
        fid.write(str(labelmap))