示例#1
0
def retrain(model='./models/mobilenet_v1_1.0_224_l2norm_quant_edgetpu.tflite',
            out_file='./models/classify.tflite',
            map_file='./models/map.json'):
    """Imprint a classifier from the labelled pictures stored in ./pics.json.

    Each record of the TinyDB store ``./pics.json`` is read for a ``"class"``
    label and an ``"img"`` id naming the file ``./pics/<img>.jpg``. Images are
    grouped by class, resized to 224x224, flattened, and passed to
    ``ImprintingEngine.TrainAll`` as one array per class.

    Args:
        model: Path of the base model handed to ``ImprintingEngine``.
        out_file: Path the retrained model is saved to.
        map_file: Path the label map returned by ``TrainAll`` is dumped to
            as JSON.

    Returns:
        False if the store lacks a "background" or a "detection" class
        (including when it is empty); True once the model and label map
        have been written.
    """
    pics = TinyDB("./pics.json")
    try:
        samples = pics.all()
    finally:
        pics.close()

    # Group image ids by class label (first-seen class order is preserved).
    train_set = defaultdict(list)
    for pic in samples:
        train_set[pic["class"]].append(pic["img"])
        print(pic)

    # The required labels are *keys* of train_set; its values are lists of
    # image ids, so testing membership in .values() could never succeed.
    # Checking before decoding images avoids wasted work on the failure path.
    # An empty store yields an empty train_set and fails this check too.
    if not ("background" in train_set and "detection" in train_set):
        return False

    train_input = []
    for label, image_ids in train_set.items():
        print('Processing Class: ', label)
        rows = []
        for image_id in image_ids:
            with Image.open("./pics/{}.jpg".format(image_id)) as img:
                rows.append(np.asarray(img.resize((224, 224))).flatten())
        # One 2-D array per class: each row is one flattened image.
        train_input.append(np.array(rows))

    engine = ImprintingEngine(model)
    label_map = engine.TrainAll(train_input)
    with open(map_file, 'w') as outfile:
        json.dump(label_map, outfile)

    engine.SaveModel(out_file)
    return True
def main():
  """Retrain a classifier by weight imprinting, then evaluate it.

  Parses command-line arguments, splits the dataset into train and test sets,
  imprints one new class per training category (label ids continue after
  those already read from ``args.label``), saves the model and labels to
  ``args.output``, and finally prints top-1..top-12 accuracy on the test split.
  """
  args = _ParseArgs()

  print('---------------      Parsing data set    -----------------')
  print('Dataset path:', args.data)
  train_set, test_set = _ReadData(args.data, args.test_ratio)
  print('Image list successfully parsed! Category Num = ', len(train_set))
  shape = _GetRequiredShape(args.model_path)

  print('---------------- Processing training data ----------------')
  print('This process may take more than 30 seconds.')
  train_input = []
  labels_map = _ReadLabel(args.label)
  # New categories get ids after the existing labels.
  class_id = len(labels_map)

  for category, image_list in train_set.items():
    print('Processing category:', category)
    train_input.append(
        _PrepareImages(image_list, os.path.join(args.data, category), shape))
    labels_map[class_id] = category
    class_id += 1

  print('----------------      Start training     -----------------')
  engine = ImprintingEngine(args.model_path, keep_classes=args.keep_classes)
  engine.TrainAll(train_input)

  print('----------------     Training finished!  -----------------')
  engine.SaveModel(args.output)
  print('Model saved as : ', args.output)
  _SaveLabels(labels_map, args.output)

  print('------------------   Start evaluating   ------------------')
  engine = ClassificationEngine(args.output)
  top_k = 12
  correct = [0] * top_k
  wrong = [0] * top_k
  for category, image_list in test_set.items():
    print('Evaluating category [', category, ']')
    for img_name in image_list:
      with Image.open(os.path.join(args.data, category, img_name)) as img:
        candidates = engine.ClassifyWithImage(img, threshold=0.01, top_k=top_k)
      # Once the true label appears at rank i, the image counts as correct
      # for every top-k with k >= i + 1.
      recognized = False
      for i in range(top_k):
        if i < len(candidates) and labels_map[candidates[i][0]] == category:
          recognized = True
        if recognized:
          correct[i] = correct[i] + 1
        else:
          wrong[i] = wrong[i] + 1

  print('----------------     Evaluation result   -----------------')
  # Every test image increments exactly one of correct[i] / wrong[i] for each
  # i, so this is the number of evaluated images for every rank.
  evaluated = correct[0] + wrong[0]
  if evaluated == 0:
    print('No test images were evaluated; accuracy is undefined.')
    return
  for i in range(top_k):
    print('Top {} : {:.0%}'.format(i + 1, correct[i] / evaluated))
示例#3
0
def retrain(
        model='./models/mobilenet_v1_1.0_224_quant_embedding_extractor_edgetpu.tflite',
        out_file='./models/classify.tflite',
        map_file='./models/map.json'):
    """Imprint a classifier from the labelled pictures stored in ./pics.json.

    Each record of the TinyDB store ``./pics.json`` is read for a ``"class"``
    label and an ``"img"`` id naming ``./pics/<img>.jpg``. Images are resized
    to 224x224, flattened, grouped by class label into a dict, and passed to
    ``ImprintingEngine.TrainAll``.

    Args:
        model: Path of the base model handed to ``ImprintingEngine``.
        out_file: Path the retrained model is saved to.
        map_file: Path the label map returned by ``TrainAll`` is dumped to
            as JSON.

    Returns:
        False if the store lacks a "background" or a "detection" class
        (including when it is empty); True once the model and label map
        have been written.
    """
    pics = TinyDB("./pics.json")
    try:
        samples = pics.all()
    finally:
        pics.close()

    # The required-class check needs only the labels, so do it before
    # decoding any image. An empty store yields an empty set and fails here.
    classes = {s["class"] for s in samples}
    if not {"background", "detection"} <= classes:
        return False

    train_dict = defaultdict(list)
    for s in samples:
        with Image.open("./pics/{}.jpg".format(s["img"])) as img:
            train_dict[s["class"]].append(
                np.array(img.resize((224, 224))).flatten())

    engine = ImprintingEngine(model)
    label_map = engine.TrainAll(train_dict)
    with open(map_file, 'w') as outfile:
        json.dump(label_map, outfile)

    engine.SaveModel(out_file)
    return True