Example #1
0
def convert_classes(classes, start=1):
    """Render a sequence of class names as a pbtxt label map string.

    Each name is paired with a consecutive integer id, beginning at ``start``
    and following the iteration order of ``classes``.

    Args:
        classes: iterable of class-name strings.
        start: id given to the first class (defaults to 1).

    Returns:
        The label map in protobuf text format as a ``str``; non-ASCII names
        are emitted verbatim (``as_utf8=True``) rather than escaped.
    """
    label_map = StringIntLabelMap()
    label_map.item.extend([
        StringIntLabelMapItem(id=class_id, name=class_name)
        for class_id, class_name in enumerate(classes, start=start)
    ])

    serialized = text_format.MessageToBytes(label_map, as_utf8=True)
    return serialized.decode('utf-8')
Example #2
0
 def _save_label_dict_to_file(label_dict: dict, label_map_path: str) -> None:
     """Write ``label_dict`` to ``label_map_path`` as a pbtxt label map.

     Args:
         label_dict: mapping from label name to integer id.
         label_map_path: destination file path; overwritten if it exists.
     """
     label_map = StringIntLabelMap()
     for label, label_id in label_dict.items():
         label_map.item.append(StringIntLabelMapItem(id=label_id, name=label))
     # Serialize before opening so a failure cannot leave a truncated file.
     text = str(text_format.MessageToBytes(label_map, as_utf8=True), "utf-8")
     # as_utf8=True keeps non-ASCII label names unescaped, so the file must be
     # written as UTF-8 explicitly instead of in the platform locale encoding.
     with open(label_map_path, "w", encoding="utf-8") as f:
         f.write(text)
Example #3
0
    def create_labelmap_pbtxt(self, path):
        """Write the enabled entries of ``self.categories`` to ``path`` as pbtxt.

        Each category must support ``['enabled']``, ``['id']`` and ``['name']``
        lookups; only those whose ``'enabled'`` value is truthy are written.

        Args:
            path: destination file path; overwritten if it exists.
        """
        msg = StringIntLabelMap()
        for category in self.categories:
            if category['enabled']:
                msg.item.append(StringIntLabelMapItem(id=category['id'], name=category['name']))

        txt = str(text_format.MessageToBytes(msg, as_utf8=True), 'utf-8')
        # as_utf8=True emits non-ASCII names verbatim; write UTF-8 explicitly
        # rather than relying on the platform's locale encoding.
        with open(path, 'w', encoding='utf-8') as f:
            f.write(txt)
Example #4
0
def make_label_map(classes, output_path):
    """Write ``classes`` to ``<output_path>/pascal_label_map.pbtxt``.

    Ids are assigned consecutively from 1 in the iteration order of
    ``classes``. ``output_path`` (and any missing parents) is created first.

    Args:
        classes: iterable of class-name strings.
        output_path: directory that receives ``pascal_label_map.pbtxt``.
    """
    out_dir = Path(output_path)
    out_dir.mkdir(exist_ok=True, parents=True)

    msg = StringIntLabelMap()
    for class_id, name in enumerate(classes, start=1):
        msg.item.append(StringIntLabelMapItem(id=class_id, name=name))

    text = str(text_format.MessageToBytes(msg, as_utf8=True), 'utf-8')

    # as_utf8=True emits non-ASCII names verbatim, so encode as UTF-8
    # explicitly instead of using the platform's locale encoding.
    (out_dir / 'pascal_label_map.pbtxt').write_text(text, encoding='utf-8')
def label_map_to_file(label_map: Dict[str, int], filepath: Union[str, Path]):
    """Write ``label_map`` to ``filepath`` as a pbtxt label map.

    Entries are written in ascending order of their integer id.

    Args:
        label_map: mapping from label name to integer id.
        filepath: destination file path; overwritten if it exists.
    """
    msg = StringIntLabelMap()
    for label, i in sorted(label_map.items(), key=lambda item: item[1]):
        # pylint: disable=no-member
        msg.item.append(StringIntLabelMapItem(id=i, name=label))

    text = str(text_format.MessageToBytes(msg, as_utf8=True), 'utf-8')
    # as_utf8=True emits non-ASCII names verbatim; write UTF-8 explicitly.
    with open(filepath, 'w', encoding='utf-8') as out:
        out.write(text)
    logger.info('label_map saved to %s', filepath)
Example #6
0
def get_category_mapping(label_names_file):
    """Load a pbtxt label map and build name->id and id->name lookups.

    Args:
        label_names_file: path to a text-format ``StringIntLabelMap`` file.

    Returns:
        A ``(cat2idx, idx2cat)`` tuple: ``cat2idx`` maps each item's ``name``
        to its ``id``, and ``idx2cat`` maps each ``id`` back to its ``name``.
        If a name or id is repeated, the last occurrence wins.
    """
    # Label maps may contain non-ASCII names (e.g. written with as_utf8=True),
    # so decode as UTF-8 rather than with the platform's locale encoding.
    with open(label_names_file, 'r', encoding='utf-8') as f:
        label_maps = text_format.Parse(f.read(), StringIntLabelMap())
    cat2idx = {item.name: item.id for item in label_maps.item}
    idx2cat = {item.id: item.name for item in label_maps.item}
    return cat2idx, idx2cat
Example #7
0
def save_metadata_label_map(output_path: str, label_map_dict: Dict[str, int]):
    """
    Generate ProtocolBuffer text file with the classes map for the detector.

    The file is written to ``os.path.join(output_path, GENERATOR_LABEL_FILE)``.

    :param output_path: directory that receives the label map file.
    :param label_map_dict: mapping from class name to integer id.
    :return: None
    """
    from object_detection.protos.string_int_label_map_pb2 import StringIntLabelMap

    label_map = StringIntLabelMap()
    for name, class_id in label_map_dict.items():
        label_item = label_map.item.add()
        label_item.name = name
        label_item.id = class_id

    # Build and serialize before opening the file so a bad entry (e.g. a
    # non-int id) raises without truncating an existing label map file.
    payload = str(label_map).encode()
    with open(os.path.join(output_path, GENERATOR_LABEL_FILE), "wb") as f:
        f.write(payload)
Example #8
0
def create_labes(label_basepath, categories):
    """Write the enabled categories as ``label_map.pbtxt`` and ``label_map.txt``.

    Both files are created in ``label_basepath``. The pbtxt text is also
    printed to stdout. ``label_map.txt`` holds one enabled category name per
    line, in the same order as the pbtxt items.

    Args:
        label_basepath: directory that receives both files.
        categories: iterable of items supporting ``['enabled']``, ``['id']``
            and ``['name']`` lookups.
    """
    # Materialize once: the original iterated ``categories`` twice, which
    # silently produced an empty label_map.txt for one-shot iterators.
    enabled = [category for category in categories if category['enabled']]

    msg = StringIntLabelMap()
    for category in enabled:
        msg.item.append(
            StringIntLabelMapItem(id=category['id'],
                                  name=category['name']))

    txt = str(text_format.MessageToBytes(msg, as_utf8=True), 'utf-8')
    print(txt)
    # as_utf8=True emits non-ASCII names verbatim; write UTF-8 explicitly.
    path_pbtxt = os.path.join(label_basepath, 'label_map.pbtxt')
    with open(path_pbtxt, 'w', encoding='utf-8') as f:
        f.write(txt)

    path_txt = os.path.join(label_basepath, 'label_map.txt')
    with open(path_txt, 'w', encoding='utf-8') as f:
        f.writelines(category['name'] + '\n' for category in enabled)
Example #9
0
from object_detection.protos.string_int_label_map_pb2 import StringIntLabelMap
from google.protobuf import text_format

label_map_path = './imagenet_label_map.pbtxt'

# Parse the text-format label map into a StringIntLabelMap message.
# The context manager closes the file even if parsing fails; decode as UTF-8
# since label maps may contain non-ASCII names.
x = StringIntLabelMap()
with open(label_map_path, 'r', encoding='utf-8') as fid:
    text_format.Merge(fid.read(), x)
import os
import sys

from google.protobuf import text_format
from object_detection.protos.pipeline_pb2 import TrainEvalPipelineConfig
from object_detection.protos.string_int_label_map_pb2 import StringIntLabelMap, StringIntLabelMapItem

# Seven positional arguments are required; any extras are ignored.
if len(sys.argv) < 8:
    sys.exit('usage: {} LABELMAP_PATH CLASS_NAME PIPELINE_PATH '
             'SOURCE_MODEL_PATH SOURCE_PIPELINE_PATH TRAIN_TFRECORD_PATH '
             'TEST_TFRECORD_PATH'.format(sys.argv[0]))
(LABELMAP_PATH, CLASS_NAME, PIPELINE_PATH, SOURCE_MODEL_PATH,
 SOURCE_PIPELINE_PATH, TRAIN_TFRECORD_PATH, TEST_TFRECORD_PATH) = sys.argv[1:8]

# Save a single-class label map: id 1 -> CLASS_NAME.
labelmap = StringIntLabelMap()
labelmap.item.append(StringIntLabelMapItem(id=1, name=CLASS_NAME))
with open(LABELMAP_PATH, 'w') as f:
    f.write(text_format.MessageToString(labelmap))

# Edit source pipeline and save as pipeline.
# NOTE(review): the save step (presumably to PIPELINE_PATH) is past this chunk.
pipeline = TrainEvalPipelineConfig()
with open(SOURCE_PIPELINE_PATH, 'r') as f:
    text_format.Merge(f.read(), pipeline)
# Single class, matching the one-item label map written to LABELMAP_PATH.
pipeline.model.ssd.num_classes = 1
pipeline.train_input_reader.label_map_path = LABELMAP_PATH
# Overwrites the first input_path entry in place; this assumes the source
# config already lists at least one (indexing an empty repeated field fails).
pipeline.train_input_reader.tf_record_input_reader.input_path[
    0] = TRAIN_TFRECORD_PATH
# Fine-tune from the 'checkpoint/ckpt-0' checkpoint under SOURCE_MODEL_PATH.
pipeline.train_config.fine_tune_checkpoint = os.path.join(
    SOURCE_MODEL_PATH, 'checkpoint/ckpt-0')
pipeline.train_config.fine_tune_checkpoint_type = 'detection'
pipeline.train_config.batch_size = 4
pipeline.train_config.use_bfloat16 = False
# Only the first eval input reader is updated.
pipeline.eval_input_reader[0].label_map_path = LABELMAP_PATH
pipeline.eval_input_reader[0].tf_record_input_reader.input_path[