def main():
    """Train an imprinting model on an image data set and evaluate it.

    Reads the command-line arguments, splits the data set into a training
    and a test part, imprints one class per category onto the model given
    by ``args.model_path``, saves the retrained model and its labels next
    to ``args.output`` and finally prints the top-1 .. top-5 accuracy
    measured on the test part.
    """
    args = _parse_args()

    print('--------------- Parsing data set -----------------')
    print('Dataset path:', args.data)
    train_set, test_set = _read_data(args.data, args.test_ratio)
    print('Image list successfully parsed! Category Num = ', len(train_set))
    shape = _get_required_shape(args.model_path)

    print('---------------- Processing training data ----------------')
    print('This process may take more than 30 seconds.')
    train_input = []
    labels_map = {}
    # The position of a category in train_input is its class id in the
    # retrained model, so labels_map is filled in the same order.
    for class_id, (category, image_list) in enumerate(train_set.items()):
        print('Processing category:', category)
        train_input.append(
            _prepare_images(image_list, os.path.join(args.data, category),
                            shape))
        labels_map[class_id] = category

    print('---------------- Start training -----------------')
    engine = ImprintingEngine(args.model_path)
    engine.train_all(train_input)
    print('---------------- Training finished! -----------------')

    engine.save_model(args.output)
    print('Model saved as : ', args.output)
    _save_labels(labels_map, args.output)

    print('------------------ Start evaluating ------------------')
    engine = ClassificationEngine(args.output)
    top_k = 5
    # correct[i] / wrong[i] count the test images whose true category was /
    # was not among the first i + 1 candidates.
    correct = [0] * top_k
    wrong = [0] * top_k
    for category, image_list in test_set.items():
        print('Evaluating category [', category, ']')
        for img_name in image_list:
            # Use a context manager so the image file is closed again.
            with Image.open(os.path.join(args.data, category,
                                         img_name)) as img:
                candidates = engine.classify_with_image(
                    img, threshold=0.1, top_k=top_k)
            recognized = False
            for i in range(top_k):
                # Once the category was seen at rank i it also counts as a
                # hit for every larger rank.
                if (i < len(candidates)
                        and labels_map[candidates[i][0]] == category):
                    recognized = True
                if recognized:
                    correct[i] += 1
                else:
                    wrong[i] += 1

    print('---------------- Evaluation result -----------------')
    for i in range(top_k):
        total = correct[i] + wrong[i]
        if total == 0:
            # Empty test split: there is no accuracy to report.
            print('Top {} : n/a (no test images)'.format(i + 1))
        else:
            print('Top {} : {:.0%}'.format(i + 1, correct[i] / total))
def retrain_model(props):
    """Retrain the classification model on the pictures of all users.

    Uses the imprinting technique, which retrains the model by only
    changing its last layer. The engine is created with
    ``keep_classes=False``, so all classes of the base model are abandoned
    and the new model has exactly one class per trained user.

    For every user that has pictures the function asks how many of them
    should be held back for validation, trains on the rest, saves the new
    model and its label map, prints the top-1 .. top-5 accuracy on the
    held-back pictures, optionally removes the previous model files and
    finally stores the updated properties via ``save_properties``.

    Args:
        props: Nested properties mapping. Read are
            ``props['classification']['default_path']`` (base model),
            ``props['classification']['path']`` / ``['labels']`` (previous
            model files) and ``props['user'][<name>]['images']`` (picture
            directory of each user). ``props['classification']['path']``
            and ``['labels']`` are overwritten with the new file locations.
    """
    MODEL_PATH = props['classification']['default_path']
    MIN_TEST_IMAGES = 2
    MAX_TEST_RATIO = 0.25

    click.echo('Parsing data for retraining...')
    train_set = {}
    test_set = {}
    for user, user_props in props['user'].items():
        image_dir = user_props['images']
        images = [
            f for f in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, f))
        ]
        if not images:
            continue

        # Allocate the number of images for training and validation.
        net_pictures = len(images)
        click.echo(
            click.style('We found {} pictures for {}'.format(
                net_pictures, user), fg='green'))

        max_test = int(MAX_TEST_RATIO * net_pictures)
        if max_test < MIN_TEST_IMAGES:
            # No answer could satisfy both limits below, so prompting would
            # never terminate. Skip the user instead.
            click.echo(
                click.style(
                    'Skipping {}: not enough pictures to use at least {} '
                    'and at most 25% of them for testing the model!'.format(
                        user, MIN_TEST_IMAGES),
                    fg='yellow'))
            continue

        while True:
            # type=int lets click re-prompt on non-numeric input instead of
            # raising ValueError.
            k = click.prompt(
                'How many pictures do you want for validating the training?',
                type=int)
            if k > max_test:
                click.echo(
                    click.style(
                        'At most 25% ({} pictures) of the training data can be used for testing the model!'
                        .format(max_test),
                        fg='yellow'))
            elif k < MIN_TEST_IMAGES:
                click.echo(
                    click.style(
                        'At least {} pictures must be used for testing the model!'
                        .format(MIN_TEST_IMAGES),
                        fg='yellow'))
            else:
                break

        # MIN_TEST_IMAGES <= k <= max_test guarantees both parts are
        # non-empty.
        test_set[user] = images[:k]
        train_set[user] = images[k:]

    if not train_set:
        click.echo(
            click.style('No user has enough pictures to retrain the model.',
                        fg='yellow'))
        return

    # Get the input shape of the model to retrain.
    tmp = BasicEngine(MODEL_PATH)
    input_tensor = tmp.get_input_tensor_shape()
    shape = (input_tensor[2], input_tensor[1])

    # Resize the pictures and create the new labels map. The position of a
    # user in train_input is the class id of that user in the new model.
    train_input = []
    labels_map = {}
    for user_id, (user, image_list) in enumerate(train_set.items()):
        ret = []
        for filename in image_list:
            with Image.open(
                    os.path.join(props['user'][user]['images'],
                                 filename)) as img:
                img = img.convert('RGB')
                img = img.resize(shape, Image.NEAREST)
                ret.append(np.asarray(img).flatten())
        train_input.append(np.array(ret))
        labels_map[user_id] = user

    # Train the model.
    click.echo('Start training')
    engine = ImprintingEngine(MODEL_PATH, keep_classes=False)
    engine.train_all(train_input)
    click.echo(click.style('Training finished!', fg='green'))

    # Remember the old model files before the properties are overwritten.
    old_model = props['classification']['path']
    old_labels = props['classification']['labels']

    # Save the new model; its file name lists all trained users.
    props['classification']['path'] = './Models/model{}.tflite'.format(''.join(
        ['_' + u for u in labels_map.values()]))
    engine.save_model(props['classification']['path'])

    # Save the labels next to the model.
    props['classification']['labels'] = props['classification'][
        'path'].replace('classification', 'labels').replace('tflite', 'json')
    with open(props['classification']['labels'], 'w') as f:
        json.dump(labels_map, f, indent=4)

    # Evaluate how well the retrained model performs.
    click.echo('Start evaluation')
    engine = ClassificationEngine(props['classification']['path'])
    top_k = 5
    # correct[i] / wrong[i] count the test pictures whose user was / was not
    # among the first i + 1 candidates.
    correct = [0] * top_k
    wrong = [0] * top_k
    for user, image_list in test_set.items():
        for img_name in image_list:
            # Use a context manager so the picture file is closed again.
            with Image.open(
                    os.path.join(props['user'][user]['images'],
                                 img_name)) as img:
                candidates = engine.classify_with_image(
                    img, threshold=0.1, top_k=top_k)
            recognized = False
            for i in range(top_k):
                # Once the user was seen at rank i it also counts as a hit
                # for every larger rank.
                if (i < len(candidates)
                        and user == labels_map[candidates[i][0]]):
                    recognized = True
                if recognized:
                    correct[i] += 1
                else:
                    wrong[i] += 1

    click.echo('Evaluation Results:')
    for i in range(top_k):
        total = correct[i] + wrong[i]
        if total == 0:
            click.echo('Top {} : n/a (no test pictures)'.format(i + 1))
        else:
            click.echo('Top {} : {:.0%}'.format(i + 1, correct[i] / total))
    # TODO highlight with colors how well it performed

    # Offer to clean up the previous model files, but only when both paths
    # changed and at least one of the old files is still on disk.
    # NOTE(review): if the old model path can equal 'default_path', removing
    # it would delete the base model used for retraining - confirm.
    stale_files = [p for p in (old_model, old_labels) if os.path.exists(p)]
    if (old_model != props['classification']['path']
            and old_labels != props['classification']['labels']
            and stale_files):
        if not click.confirm('Do you want to keep old models?'):
            # Remove only what exists; one of the two files may be missing.
            for stale in stale_files:
                os.remove(stale)
            click.echo(click.style('Old models removed.', fg='green'))

    # Save the properties.
    save_properties(props)