def retrain(model='./models/mobilenet_v1_1.0_224_l2norm_quant_edgetpu.tflite',
            out_file='./models/classify.tflite',
            map_file='./models/map.json'):
    """Retrain the imprinting model from the pictures listed in ./pics.json.

    Each record of the TinyDB store is read for its "class" and "img" fields.
    The images of a class are loaded from ./pics/<img>.jpg, resized to
    224x224, flattened and stacked into one array per class; the list of
    those arrays is handed to ``ImprintingEngine.TrainAll``.

    Args:
        model: Path of the base model passed to ``ImprintingEngine``.
        out_file: Path the retrained model is saved to.
        map_file: Path of the JSON file the label map is written to.

    Returns:
        False when the store does not contain both a "background" and a
        "detection" class (this includes an empty store); True once the
        model and the label map have been written.
    """
    pics = TinyDB("./pics.json")

    # Group image names by class name, in order of first appearance.
    train_set = defaultdict(list)
    for pic in pics:
        train_set[pic["class"]].append(pic["img"])
        print(pic)

    # Both classes are mandatory. The class names are the dict *keys*; the
    # values are lists of image names, so testing membership in .values()
    # could never succeed and training was unreachable. Checking before any
    # image is decoded also avoids needless work on an incomplete data set.
    if "background" not in train_set or "detection" not in train_set:
        return False

    train_input = []
    labels_map = {}
    for class_id, class_name in enumerate(train_set):
        print('Processing Class: ', class_name)
        images = []
        for filename in train_set[class_name]:
            # resize() returns an independent image, so the source file can
            # be closed as soon as the resized copy exists.
            with Image.open("./pics/{}.jpg".format(filename)) as img:
                resized = img.resize((224, 224))
            images.append(np.asarray(resized).flatten())
        train_input.append(np.array(images))
        labels_map[class_id] = class_name

    engine = ImprintingEngine(model)
    label_map = engine.TrainAll(train_input)
    # NOTE(review): TrainAll is fed a list here; if the installed engine
    # returns no mapping for list input, fall back to the index -> class name
    # map built above instead of writing "null" -- confirm against the
    # edgetpu version in use.
    if label_map is None:
        label_map = labels_map
    with open(map_file, 'w') as outfile:
        json.dump(label_map, outfile)
    engine.SaveModel(out_file)
    return True
def main():
    """Train an imprinting model on a data set, save it and evaluate it.

    Steps, in order:
      1. Parse the command line and split the data set into a training and
         a test mapping (category -> list of image names).
      2. Prepare one training input per category and extend the label map
         with the new categories, numbering them after the labels already
         read from ``args.label``.
      3. Train an ``ImprintingEngine``, save the model to ``args.output``
         and store the labels next to it.
      4. Reload the saved model with ``ClassificationEngine`` and print the
         top-1 .. top-12 accuracy over the test images.
    """
    args = _ParseArgs()

    print('--------------- Parsing data set -----------------')
    print('Dataset path:', args.data)
    train_set, test_set = _ReadData(args.data, args.test_ratio)
    print('Image list successfully parsed! Category Num = ', len(train_set))
    shape = _GetRequiredShape(args.model_path)

    print('---------------- Processing training data ----------------')
    print('This process may take more than 30 seconds.')
    train_input = []
    labels_map = _ReadLabel(args.label)
    # New categories are numbered right after the labels already present.
    class_id = len(labels_map)
    for category, image_list in train_set.items():
        print('Processing category:', category)
        train_input.append(
            _PrepareImages(
                image_list, os.path.join(args.data, category), shape)
        )
        labels_map[class_id] = category
        class_id += 1

    print('---------------- Start training -----------------')
    engine = ImprintingEngine(args.model_path, keep_classes=args.keep_classes)
    engine.TrainAll(train_input)
    print('---------------- Training finished! -----------------')

    engine.SaveModel(args.output)
    print('Model saved as : ', args.output)
    _SaveLabels(labels_map, args.output)

    print('------------------ Start evaluating ------------------')
    engine = ClassificationEngine(args.output)
    top_k = 12
    # correct[i] / wrong[i] count the images whose true category is / is not
    # among the first i+1 candidates.
    correct = [0] * top_k
    wrong = [0] * top_k
    for category, image_list in test_set.items():
        print('Evaluating category [', category, ']')
        for img_name in image_list:
            with Image.open(os.path.join(args.data, category, img_name)) as img:
                candidates = engine.ClassifyWithImage(
                    img, threshold=0.01, top_k=top_k)
            recognized = False
            for i in range(top_k):
                if i < len(candidates) and labels_map[candidates[i][0]] == category:
                    recognized = True
                # Once the category has been seen at rank <= i it stays
                # recognized for every larger k.
                if recognized:
                    correct[i] += 1
                else:
                    wrong[i] += 1

    print('---------------- Evaluation result -----------------')
    # Every test image bumps exactly one of correct[i] / wrong[i] for each i,
    # so the total is the same for every rank.
    evaluated = correct[0] + wrong[0]
    if evaluated == 0:
        # An empty test split used to end in ZeroDivisionError here, after
        # the model had already been trained and saved.
        print('No test images were evaluated.')
        return
    for i in range(top_k):
        print('Top {} : {:.0%}'.format(i + 1, correct[i] / evaluated))
def retrain(
        model='./models/mobilenet_v1_1.0_224_quant_embedding_extractor_edgetpu.tflite',
        out_file='./models/classify.tflite',
        map_file='./models/map.json'):
    """Retrain the imprinting model from the samples stored in ./pics.json.

    Every sample's image (./pics/<img>.jpg) is resized to 224x224, flattened
    and grouped under the sample's class name. The resulting mapping is
    passed to ``ImprintingEngine.TrainAll``; the value it returns is written
    to ``map_file`` as JSON and the model is saved to ``out_file``.

    Args:
        model: Path of the base model passed to ``ImprintingEngine``.
        out_file: Path the retrained model is saved to.
        map_file: Path of the JSON file the label map is written to.

    Returns:
        False if there are no samples or if either the "background" or the
        "detection" class is missing; True after training and saving.
    """
    database = TinyDB("./pics.json")
    samples = database.all()

    # class name -> list of flattened pixel arrays
    train_dict = defaultdict(list)
    for record in samples:
        picture_path = "./pics/{}.jpg".format(record["img"])
        resized = Image.open(picture_path).resize((224, 224))
        train_dict[record["class"]].append(np.array(resized).flatten())

    # Guard clause: training needs samples and both mandatory classes.
    if not samples or "background" not in train_dict or "detection" not in train_dict:
        return False

    engine = ImprintingEngine(model)
    label_map = engine.TrainAll(train_dict)
    with open(map_file, 'w') as outfile:
        json.dump(label_map, outfile)
    engine.SaveModel(out_file)
    return True
class DemoImprintingEngine(object):
    """Engine wrapping from Imprinting Engine for demo usage.

    Images are collected per button label with ``addImage``; once a batch is
    complete they are passed to the wrapped ``ImprintingEngine`` for
    training. ``classify`` maps the engine's answer back to a button label.
    """

    def __init__(self, model_path, output_path, keep_classes, batch_size):
        """Creates a ImprintingEngine with given model and labels.

        Args:
          model_path: String, path to TF-Lite Flatbuffer file.
          output_path: String, path to output tflite file.
          keep_classes: Bool, whether to keep base model classes.
          batch_size: Int, batch size for engine to train once.

        Raises:
          RuntimeError: If the model's input tensor shape is not
            [1, height, width, 3].
        """
        self._model_path = model_path
        self._keep_classes = keep_classes
        self._output_path = output_path
        self._batch_size = batch_size
        self._required_image_size = self.getRequiredInputShape()
        # Must exist before clear() runs, because clear() reads it.
        self._example_count = 0
        self._imprinting_engine = ImprintingEngine(
            self._model_path, keep_classes=self._keep_classes)
        self.clear()

    def getRequiredInputShape(self):
        """Get the required input shape for the model.

        Returns:
          A tuple built from entries 2 and 1 of the input tensor shape, i.e.
          (width, height) for the expected [1, height, width, 3] layout.

        Raises:
          RuntimeError: If the input tensor shape does not have four
            entries, a leading 1 and a trailing 3.
        """
        basic_engine = BasicEngine(self._model_path)
        input_tensor_shape = basic_engine.get_input_tensor_shape()
        if (input_tensor_shape.size != 4 or input_tensor_shape[3] != 3 or
                input_tensor_shape[0] != 1):
            raise RuntimeError(
                'Invalid input tensor shape! Expected: [1, height, width, 3]')
        return (input_tensor_shape[2], input_tensor_shape[1])

    def clear(self):
        """Save the trained model, then forget all stored images and labels.

        The model is saved to the output path only when at least one image
        has been added since the last clear.
        """
        # Save the trained model.
        if self._example_count > 0:
            self._imprinting_engine.SaveModel(self._output_path)

        # The size of all the image store.
        self._example_count = 0

        # The ImprintingEngine does not allow training images with too large
        # labels. For example, with an existing model with 3 output classes,
        # there are two options: training existing classes [0, 1, 2] or
        # training exactly the next class [3].
        # We have two maps to store the mappings from button_label to
        # real_label, and vice versa.
        self._label_map_button2real = {}
        self._label_map_real2button = {}
        self._max_real_label = 0
        # A map with real label as key, and training images as value.
        self._image_map = defaultdict(list)

    def trainAndUpdateModel(self):
        """Train a batch of images and update the engines.

        Labels are trained in ascending order of their real label, then the
        pending image store is emptied.
        """
        for label_real in range(self._max_real_label):
            if label_real in self._image_map:
                self._imprinting_engine.Train(
                    np.array(self._image_map[label_real]), label_real)
        self._image_map = defaultdict(list)  # reset

    def addImage(self, img, label_button):
        """Add an image to the store.

        Args:
          img: Image to add; it is resized to the model's input size.
          label_button: Label chosen by the caller. It is mapped to the next
            free real label the first time it is seen.
        """
        # Update the label map.
        if label_button not in self._label_map_button2real:
            self._label_map_button2real[label_button] = self._max_real_label
            self._label_map_real2button[self._max_real_label] = label_button
            self._max_real_label += 1
        label_real = self._label_map_button2real[label_button]
        self._example_count += 1
        resized_img = img.resize(self._required_image_size, Image.NEAREST)
        self._image_map[label_real].append(np.asarray(resized_img).flatten())
        # Train a batch of images. ">=" rather than "==" so that a batch size
        # below 1 trains on every image instead of never training at all.
        pending = sum(len(images) for images in self._image_map.values())
        if pending >= self._batch_size:
            self.trainAndUpdateModel()

    def classify(self, img):
        """Classify an image and return the matching button label.

        Returns:
          The button label mapped to the engine's top result, or None when
          no image has been added yet, when the engine returns no result,
          or when the returned class has no button label.
        """
        # If we have nothing trained, the answer is None
        if self.exampleCount() == 0:
            return None
        resized_img = img.resize(self._required_image_size, Image.NEAREST)
        scores = self._imprinting_engine.ClassifyWithResizedImage(
            resized_img, top_k=1)
        # An empty result used to raise IndexError on scores[0].
        if len(scores) == 0:
            return None
        # NOTE(review): a class id without a button label (presumably a base
        # model class when keep_classes is set) yields None instead of a
        # KeyError -- confirm this is the desired answer for callers.
        return self._label_map_real2button.get(scores[0][0])

    def exampleCount(self):
        """Just returns the size of the image store."""
        return self._example_count