def main():
    """Train an imprinting model on an image folder and report its accuracy.

    Steps, in order:

    1. Parse the arguments and split the data set at ``args.data`` into a
       training and a test mapping (``category -> list of image names``)
       using ``args.test_ratio``.
    2. Prepare the training images of every category for the input shape
       required by the model at ``args.model_path`` and train an
       ``ImprintingEngine`` with them.
    3. Save the retrained model and its ``class_id -> category`` labels map
       to ``args.output``.
    4. Classify every test image with the saved model and print the
       cumulative top-1 .. top-5 accuracy.
    """
    args = _parse_args()
    print('---------------      Parsing data set    -----------------')
    print('Dataset path:', args.data)

    train_set, test_set = _read_data(args.data, args.test_ratio)
    print('Image list successfully parsed! Category Num = ', len(train_set))
    shape = _get_required_shape(args.model_path)

    print('---------------- Processing training data ----------------')
    print('This process may take more than 30 seconds.')
    train_input = []
    labels_map = {}
    # The position of a category in train_input is its class id, so the
    # labels map has to be built in the same iteration order.
    for class_id, (category, image_list) in enumerate(train_set.items()):
        print('Processing category:', category)
        train_input.append(
            _prepare_images(image_list, os.path.join(args.data, category),
                            shape))
        labels_map[class_id] = category
    print('----------------      Start training     -----------------')
    engine = ImprintingEngine(args.model_path)
    engine.train_all(train_input)
    print('----------------     Training finished!  -----------------')

    engine.save_model(args.output)
    print('Model saved as : ', args.output)
    _save_labels(labels_map, args.output)

    print('------------------   Start evaluating   ------------------')
    engine = ClassificationEngine(args.output)
    top_k = 5
    # correct[i] / wrong[i] count the test images whose true category was /
    # was not among the first i + 1 candidates.
    correct = [0] * top_k
    wrong = [0] * top_k
    for category, image_list in test_set.items():
        print('Evaluating category [', category, ']')
        for img_name in image_list:
            # Close the image file as soon as it has been classified instead
            # of keeping one open handle per test image.
            with Image.open(os.path.join(args.data, category,
                                         img_name)) as img:
                candidates = engine.classify_with_image(img,
                                                        threshold=0.1,
                                                        top_k=top_k)
            recognized = False
            for i in range(top_k):
                # Once the category was found at rank i it also counts as
                # recognized for every larger k.
                if (i < len(candidates)
                        and labels_map[candidates[i][0]] == category):
                    recognized = True
                if recognized:
                    correct[i] += 1
                else:
                    wrong[i] += 1
    print('----------------     Evaluation result   -----------------')
    # Every evaluated image increments exactly one counter per rank, so the
    # sum at rank 0 is the number of evaluated images.
    evaluated = correct[0] + wrong[0]
    if evaluated == 0:
        # An empty test split would otherwise raise ZeroDivisionError below.
        print('No test images were evaluated.')
    else:
        for i in range(top_k):
            print('Top {} : {:.0%}'.format(i + 1, correct[i] / evaluated))
示例#2
0
def retrain_model(props):
    """
    Retrain the classification model with the Imprinting technique.

    Only the last layer of the base model is changed. The engine is created
    with ``keep_classes=False``, i.e. the classes of the base model are
    abandoned and the new model only knows the users trained here.

    For every user in ``props['user']`` that has image files in its
    ``'images'`` directory, the caller is asked how many pictures to hold
    back for validation. The remaining pictures are used for training. The
    new model and its labels map are written below ``./Models``, the
    cumulative top-1 .. top-5 accuracy on the held-back pictures is printed,
    and ``props['classification']`` is updated and saved through
    ``save_properties``.

    Args:
        props: Properties mapping. Reads ``props['classification']``
            (``'default_path'``, ``'path'``, ``'labels'``) and
            ``props['user'][<name>]['images']``; ``'path'`` and ``'labels'``
            are overwritten with the new file locations.
    """
    MODEL_PATH = props['classification']['default_path']
    # Bounds for the number of validation pictures per user.
    min_test_images = 2
    max_test_ratio = 0.25

    click.echo('Parsing data for retraining...')
    train_set = {}
    test_set = {}
    for user, user_props in props['user'].items():
        image_dir = user_props['images']
        images = [
            f for f in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, f))
        ]
        if not images:
            continue

        # Allocate the number of images for training and validation.
        net_pictures = len(images)
        click.echo(
            click.style('We found {} pictures for {}'.format(
                net_pictures, user),
                        fg='green'))

        max_test_images = int(max_test_ratio * net_pictures)
        if max_test_images < min_test_images:
            # No answer could satisfy both bounds, so prompting would loop
            # forever; leave this user out instead.
            click.echo(
                click.style(
                    'Skipping {}: not enough pictures to hold back {} '
                    'for testing the model.'.format(user, min_test_images),
                    fg='yellow'))
            continue

        while True:
            # type=int lets click re-prompt on non-numeric input instead of
            # raising ValueError.
            k = click.prompt(
                'How many pictures do you want for validating the training?',
                type=int)
            if k > max_test_images:
                click.echo(
                    click.style(
                        'At most 25% ({} pictures) of the training data can be used for testing the model!'
                        .format(max_test_images),
                        fg='yellow'))
            elif k < min_test_images:
                click.echo(
                    click.style(
                        'At least {} pictures must be used for testing the model!'
                        .format(min_test_images),
                        fg='yellow'))
            else:
                break

        test_set[user] = images[:k]
        assert test_set[user], 'No images to test [{}]'.format(user)
        train_set[user] = images[k:]
        assert train_set[user], 'No images to train [{}]'.format(user)

    # Get the input shape of the model to retrain.
    tmp = BasicEngine(MODEL_PATH)
    input_tensor = tmp.get_input_tensor_shape()
    # assumes the tensor shape is (batch, height, width, channels), which
    # would make this the (width, height) tuple PIL expects - TODO confirm
    shape = (input_tensor[2], input_tensor[1])

    # Resize the pictures and create the new labels map. The position of a
    # user in train_input is the class id used as key in labels_map.
    train_input = []
    labels_map = {}
    for user_id, (user, image_list) in enumerate(train_set.items()):
        image_dir = props['user'][user]['images']
        ret = []
        for filename in image_list:
            with Image.open(os.path.join(image_dir, filename)) as img:
                img = img.convert('RGB')
                img = img.resize(shape, Image.NEAREST)
                ret.append(np.asarray(img).flatten())
        train_input.append(np.array(ret))
        labels_map[user_id] = user

    # Train the model.
    click.echo('Start training')
    engine = ImprintingEngine(MODEL_PATH, keep_classes=False)
    engine.train_all(train_input)
    click.echo(click.style('Training finished!', fg='green'))

    # Remember the old model files so they can be cleaned up afterwards.
    old_model = props['classification']['path']
    old_labels = props['classification']['labels']

    # Save the new model; its file name lists all trained users.
    new_model = './Models/model{}.tflite'.format(''.join(
        '_' + u for u in labels_map.values()))
    props['classification']['path'] = new_model
    engine.save_model(new_model)

    # Save the labels next to the model.
    new_labels = new_model.replace('classification',
                                   'labels').replace('tflite', 'json')
    props['classification']['labels'] = new_labels
    with open(new_labels, 'w') as f:
        json.dump(labels_map, f, indent=4)

    # Evaluate how well the retrained model performs.
    click.echo('Start evaluation')
    engine = ClassificationEngine(new_model)
    top_k = 5
    # correct[i] / wrong[i] count the test images whose user was / was not
    # among the first i + 1 candidates.
    correct = [0] * top_k
    wrong = [0] * top_k
    for user, image_list in test_set.items():
        image_dir = props['user'][user]['images']
        for img_name in image_list:
            # Close the image file as soon as it has been classified.
            with Image.open(os.path.join(image_dir, img_name)) as img:
                candidates = engine.classify_with_image(img,
                                                        threshold=0.1,
                                                        top_k=top_k)
            recognized = False
            for i in range(top_k):
                # Once the user was found at rank i it also counts as
                # recognized for every larger k.
                if (i < len(candidates)
                        and user == labels_map[candidates[i][0]]):
                    recognized = True
                if recognized:
                    correct[i] += 1
                else:
                    wrong[i] += 1

    # Report once, after all users were evaluated: the counters are
    # cumulative over all users, so intermediate values are only partial.
    click.echo('Evaluation Results:')
    evaluated = correct[0] + wrong[0]
    if evaluated == 0:
        click.echo('No test images were evaluated.')
    else:
        for i in range(top_k):
            click.echo('Top {} : {:.0%}'.format(i + 1,
                                                correct[i] / evaluated))
    #  TODO  highlight with colors how well it performed

    # Offer to delete the previous model files when both paths changed.
    # NOTE(review): if the old model path equals 'default_path', deleting it
    # would remove the base model used for retraining - confirm with callers.
    if old_model != new_model and old_labels != new_labels:
        stale_files = [
            path for path in (old_model, old_labels) if os.path.exists(path)
        ]
        if stale_files and not click.confirm(
                'Do you want to keep old models?'):
            # Only remove what is actually there; one of the two files may
            # already be missing.
            for path in stale_files:
                os.remove(path)
            click.echo(click.style('Old models removed.', fg='green'))

    # Save the properties.
    save_properties(props)