def test_keep_categories_with_unique_id(self):
     """Checks that label map items reusing an earlier id are left out.

     The label map text below lists four items but only two distinct ids
     (2 once, 1 three times).  With max_num_classes=3 the expected result
     holds a single category per id, in the order the ids first appear:
     'cat' for id 2 and 'child' (the first item using id 1) for id 1.
     """
     proto_text = """
   item {
     id:2
     name:'cat'
   }
   item {
     id:1
     name:'child'
   }
   item {
     id:1
     name:'person'
   }
   item {
     id:1
     name:'n00007846'
   }
 """
     parsed_map = string_int_label_map_pb2.StringIntLabelMap()
     text_format.Merge(proto_text, parsed_map)

     # One entry per distinct id; later items that repeat id 1 are absent.
     expected = [
         {'id': 2, 'name': u'cat'},
         {'id': 1, 'name': u'child'},
     ]
     actual = label_map_util.convert_label_map_to_categories(
         parsed_map, max_num_classes=3)
     self.assertListEqual(expected, actual)
 def _generate_label_map(self, num_classes):
     """Builds a StringIntLabelMap holding items with ids 1..num_classes.

     Each item gets name 'label_<id>' and display_name '<id>'.

     Args:
       num_classes: number of items to add; ids run from 1 to this value
         inclusive.

     Returns:
       The populated StringIntLabelMap message.
     """
     generated_map = string_int_label_map_pb2.StringIntLabelMap()
     class_ids = range(1, num_classes + 1)
     for class_id in class_ids:
         id_text = str(class_id)
         entry = generated_map.item.add()
         entry.id = class_id
         entry.name = 'label_' + id_text
         entry.display_name = id_text
     return generated_map
Esempio n. 3
0
def load_labelmap(path):
    """Reads a label map file and returns it as a proto message.

    The file contents are first parsed as protobuf text format.  If that
    raises text_format.ParseError, the same contents are parsed as a
    serialized message instead.  The resulting message is passed to
    _validate_label_map before being returned.

    Args:
      path: path to StringIntLabelMap proto text file.

    Returns:
      a StringIntLabelMapProto
    """
    # Read the whole file up front; parsing does not need the open handle.
    with tf.gfile.GFile(path, 'r') as label_file:
        raw_contents = label_file.read()

    parsed_map = string_int_label_map_pb2.StringIntLabelMap()
    try:
        text_format.Merge(raw_contents, parsed_map)
    except text_format.ParseError:
        # Not valid text format: fall back to the serialized representation.
        parsed_map.ParseFromString(raw_contents)

    _validate_label_map(parsed_map)
    return parsed_map