def retrain(model='./models/mobilenet_v1_1.0_224_l2norm_quant_edgetpu.tflite',
            out_file='./models/classify.tflite',
            map_file='./models/map.json'):
    """Imprint a classifier from the captured pictures and save it.

    Sample records are read from ``./pics.json`` (TinyDB). Each record's
    ``"class"`` value names its category and ``"img"`` names the picture
    ``./pics/<img>.jpg``. Pictures are resized to 224x224 and flattened.
    One list entry of training data is built per category, in the order the
    categories are first seen, and that order defines the class indices.

    NOTE(review): a later ``def retrain`` in this module rebinds the same
    name, so this version is unreachable by name while that one exists.

    Args:
        model: Path to the base model used for imprinting.
        out_file: Path the retrained model is written to.
        map_file: Path of a JSON file mapping class index -> class name.

    Returns:
        False if the "background" or "detection" class has no samples
        (including when there are no samples at all). True after training,
        writing ``map_file`` and saving ``out_file``.
    """
    pics = TinyDB("./pics.json")
    samples = pics.all()

    # Group image ids by class. Dict insertion order fixes each class's index.
    images_by_class = defaultdict(list)
    for sample in samples:
        images_by_class[sample["class"]].append(sample["img"])

    # Validate on class names (dict keys) before decoding any image.
    if ("background" not in images_by_class or
            "detection" not in images_by_class):
        return False

    train_input = []
    labels_map = {}
    for class_id, class_name in enumerate(images_by_class):
        print('Processing Class: ', class_name)
        rows = []
        for image_id in images_by_class[class_name]:
            with Image.open("./pics/{}.jpg".format(image_id)) as img:
                rows.append(np.asarray(img.resize((224, 224))).flatten())
        train_input.append(np.array(rows))
        labels_map[class_id] = class_name

    engine = ImprintingEngine(model)
    # TrainAll receives a list whose position i is class i. The
    # index -> name mapping written below is the one built alongside it.
    engine.TrainAll(train_input)
    with open(map_file, 'w') as outfile:
        json.dump(labels_map, outfile)

    engine.SaveModel(out_file)
    return True
def main():
  """Retrain a model via weight imprinting, save it, then evaluate it.

  Steps:
    1. Parse the dataset at ``args.data`` into train and test image lists.
    2. Imprint one class per training category onto ``args.model_path``.
    3. Save the model and its labels to ``args.output``.
    4. Report cumulative top-1..top-k accuracy on the test images.

  The accuracy report is skipped when the test split is empty, rather than
  dividing by zero.
  """
  args = _ParseArgs()

  print('---------------      Parsing data set    -----------------')
  print('Dataset path:', args.data)
  train_set, test_set = _ReadData(args.data, args.test_ratio)
  print('Image list successfully parsed! Category Num = ', len(train_set))
  shape = _GetRequiredShape(args.model_path)

  print('---------------- Processing training data ----------------')
  print('This process may take more than 30 seconds.')
  train_input = []
  labels_map = _ReadLabel(args.label)
  # New categories are numbered after the labels already read from args.label.
  class_id = len(labels_map)

  for category, image_list in train_set.items():
    print('Processing category:', category)
    train_input.append(
        _PrepareImages(image_list, os.path.join(args.data, category), shape))
    labels_map[class_id] = category
    class_id += 1

  print('----------------      Start training     -----------------')
  engine = ImprintingEngine(args.model_path, keep_classes=args.keep_classes)
  engine.TrainAll(train_input)

  print('----------------     Training finished!  -----------------')
  engine.SaveModel(args.output)
  print('Model saved as : ', args.output)
  _SaveLabels(labels_map, args.output)

  print('------------------   Start evaluating   ------------------')
  engine = ClassificationEngine(args.output)
  top_k = 12
  # correct[i] counts images whose true category is among the top i+1 results.
  correct = [0] * top_k
  wrong = [0] * top_k
  for category, image_list in test_set.items():
    print('Evaluating category [', category, ']')
    for img_name in image_list:
      with Image.open(os.path.join(args.data, category, img_name)) as img:
        candidates = engine.ClassifyWithImage(
            img, threshold=0.01, top_k=top_k)
      recognized = False
      for i in range(top_k):
        # .get(): a class id without a label counts as a miss, not a KeyError.
        if (i < len(candidates) and
            labels_map.get(candidates[i][0]) == category):
          recognized = True
        if recognized:
          correct[i] += 1
        else:
          wrong[i] += 1

  print('----------------     Evaluation result   -----------------')
  # Every test image increments exactly one of correct[i] / wrong[i] for each
  # i, so the total is the same for every rank.
  total = correct[0] + wrong[0]
  if total == 0:
    print('No test images; skipping accuracy report.')
    return
  for i in range(top_k):
    print('Top {} : {:.0%}'.format(i + 1, correct[i] / total))
def retrain(
        model='./models/mobilenet_v1_1.0_224_quant_embedding_extractor_edgetpu.tflite',
        out_file='./models/classify.tflite',
        map_file='./models/map.json'):
    """Imprint a classifier from the captured pictures and save it.

    Sample records are read from ``./pics.json`` (TinyDB). Each record's
    ``"class"`` value names its category and ``"img"`` names the picture
    ``./pics/<img>.jpg``. Pictures are resized to 224x224 and flattened,
    then grouped by class name into the dict passed to ``TrainAll``.

    Args:
        model: Path to the base model used for imprinting.
        out_file: Path the retrained model is written to.
        map_file: Path of the JSON file receiving the label map returned by
            ``TrainAll``.

    Returns:
        False if the "background" or "detection" class has no samples
        (including when there are no samples at all). True after training,
        writing ``map_file`` and saving ``out_file``.
    """
    pics = TinyDB("./pics.json")
    samples = pics.all()

    # Check the required classes from the records alone, before decoding any
    # picture, so an incomplete dataset fails fast.
    class_names = {s["class"] for s in samples}
    if "background" not in class_names or "detection" not in class_names:
        return False

    train_dict = defaultdict(list)
    for s in samples:
        with Image.open("./pics/{}.jpg".format(s["img"])) as img:
            train_dict[s["class"]].append(
                np.array(img.resize((224, 224))).flatten())

    engine = ImprintingEngine(model)
    label_map = engine.TrainAll(train_dict)
    with open(map_file, 'w') as outfile:
        json.dump(label_map, outfile)

    engine.SaveModel(out_file)
    return True
class DemoImprintingEngine(object):
  """Wraps ImprintingEngine for incremental, button-labelled demo training.

  Images are added one at a time under an arbitrary "button" label. Each new
  button label is mapped to the next consecutive "real" class id. Images are
  buffered and imprinted once ``batch_size`` of them are pending.
  ``clear()`` saves the model (if anything was added) and resets the label
  mappings and buffered images.
  """

  def __init__(self, model_path, output_path, keep_classes, batch_size):
    """Creates the wrapped ImprintingEngine for the given model.

    Args:
      model_path: String, path to TF-Lite Flatbuffer file.
      output_path: String, path to output tflite file.
      keep_classes: Bool, whether to keep base model classes.
      batch_size: Int, number of pending images that triggers one training
        pass in addImage().

    Raises:
      RuntimeError: The model's input tensor is not shaped
        [1, height, width, 3].
    """
    self._model_path = model_path
    self._keep_classes = keep_classes
    self._output_path = output_path
    self._batch_size = batch_size
    self._required_image_size = self.getRequiredInputShape()
    # Must be 0 before clear() so that clear() does not save an empty model.
    self._example_count = 0
    self._imprinting_engine = ImprintingEngine(self._model_path, keep_classes=self._keep_classes)
    self.clear()

  def getRequiredInputShape(self):
    """Return the model's input size as ``(width, height)``.

    The order matches the size argument of PIL's ``Image.resize``.

    Raises:
      RuntimeError: The input tensor is not shaped [1, height, width, 3].
    """
    basic_engine = BasicEngine(self._model_path)
    input_tensor_shape = basic_engine.get_input_tensor_shape()
    if (input_tensor_shape.size != 4 or input_tensor_shape[3] != 3 or
        input_tensor_shape[0] != 1):
      raise RuntimeError(
          'Invalid input tensor shape! Expected: [1, height, width, 3]')
    return (input_tensor_shape[2], input_tensor_shape[1])

  def clear(self):
    """Save the model if anything was added, then reset the image store.

    Resets the example count, both label maps, the next real label, and the
    buffered images. The wrapped ImprintingEngine itself is kept.

    NOTE(review): images still buffered (fewer than a full batch) are not
    trained before saving and are discarded.
    """
    if self._example_count > 0:
      self._imprinting_engine.SaveModel(self._output_path)
    # Total number of images added since the last clear().
    self._example_count = 0

    # The ImprintingEngine does not allow training images with too large labels.
    # For example, with an existing model with 3 output classes, there are two
    # options: training existing classes [0, 1, 2] or training exactly the next
    # class [3].

    # Mappings from button_label to real_label, and vice versa.
    self._label_map_button2real = {}
    self._label_map_real2button = {}
    self._max_real_label = 0
    # Real label -> list of flattened images waiting to be trained.
    self._image_map = defaultdict(list)

  def trainAndUpdateModel(self):
    """Train all buffered images, then empty the buffer.

    Labels are trained in ascending order so that each label is either
    already known or exactly the next one, as the engine requires.
    """
    for label_real in range(0, self._max_real_label):
      if label_real in self._image_map:
        self._imprinting_engine.Train(np.array(self._image_map[label_real]), label_real)
    self._image_map = defaultdict(list)

  def addImage(self, img, label_button):
    """Buffer ``img`` under ``label_button`` and train once a batch is full.

    A previously unseen ``label_button`` is assigned the next real label.
    """
    if label_button not in self._label_map_button2real:
      self._label_map_button2real[label_button] = self._max_real_label
      self._label_map_real2button[self._max_real_label] = label_button
      self._max_real_label += 1
    label_real = self._label_map_button2real[label_button]
    self._example_count += 1
    resized_img = img.resize(self._required_image_size, Image.NEAREST)
    self._image_map[label_real].append(np.asarray(resized_img).flatten())

    # ">=" matches "==" for positive batch sizes, but still trains when
    # batch_size is not positive instead of never training.
    if sum(len(v) for v in self._image_map.values()) >= self._batch_size:
      self.trainAndUpdateModel()

  def classify(self, img):
    """Return the button label predicted for ``img``, or None.

    Returns None in three cases:
      * nothing has been added since the last clear();
      * the engine returns no candidate;
      * the top class id has no button mapping. This can happen after
        clear(), which resets the maps but keeps the engine.
    """
    if self.exampleCount() == 0:
      return None
    resized_img = img.resize(self._required_image_size, Image.NEAREST)
    scores = self._imprinting_engine.ClassifyWithResizedImage(resized_img, top_k=1)
    if not scores:
      return None
    return self._label_map_real2button.get(scores[0][0])

  def exampleCount(self):
    """Return the number of images added since the last clear()."""
    return self._example_count