# Example 1
def predict(config_path,
            model_path=None,
            labels_json=None,
            data_dir=None,
            group=None):
    """Run batch inference over a labelled dataset and print accuracy stats.

    Args:
        config_path: Path to the config file parsed by ``tools.read_configs``.
        model_path: Optional trained-model path; falls back to the config's
            ``model_path`` entry.
        labels_json: Optional labels-JSON path; falls back to the config.
        data_dir: Optional image directory; falls back to the config.
        group: Optional label-group identifier; falls back to the config.
    """
    configs = tools.read_configs(config_path)

    # Explicit arguments take precedence over config-file values.
    model_path = tools.str2path(model_path or configs.get('model_path'))
    data_dir = tools.str2path(data_dir or configs.get('data_dir'))
    labels_json = tools.str2path(labels_json or configs.get('labels_json'))
    group = group or configs.get('group')

    model_builder = ModelBuilder(configs,
                                 mode='predict',
                                 model_path=model_path.as_posix())
    model = model_builder.build()

    pred_gen = DataGenerator(configs,
                             image_dir=data_dir,
                             labels_json=labels_json,
                             group=group,
                             mode='predict')

    correct = 0
    false = 0
    for x, y, xnames in pred_gen.flow_from_labels():
        predictions = model.predict(x, verbose=1)
        pred_cls_ids = np.argmax(predictions, axis=1)
        tp = np.sum(pred_cls_ids == y)
        correct += tp
        false += len(y) - tp

    # BUG FIX: the original computed ``(1 - false / correct) * 100``, which is
    # not accuracy (correct=1, false=1 gave 0% instead of 50%) and raised
    # ZeroDivisionError when no prediction was correct. Accuracy is
    # correct / total; guard against an empty dataset.
    total = correct + false
    accuracy = (correct / total) * 100 if total else 0.0

    print(f"Correct: {correct}")
    print(f'False: {false}')
    print(f'Accuracy: {accuracy}')
# Example 2
def train(config_path,
          train_dir=None,
          val_dir=None,
          output_dir=None,
          train_labels_json=None,
          val_labels_json=None,
          group=None,
          model_name=None,
          model_suffix=None):
    """Train a classification model and checkpoint the best weights.

    All keyword arguments override the corresponding entries of the config
    file read from ``config_path``; required entries (``train_dir``,
    ``val_dir``, ``output_dir``, ``group``, ``model_name``, ``model_suffix``,
    ``epochs``, ``loss``, ``optimizer``, ``network_parameters.classes``) raise
    ``KeyError`` if absent from both.

    Args:
        config_path: Path to the config file parsed by ``tools.read_configs``.
        train_dir: Optional training-image directory.
        val_dir: Optional validation-image directory.
        output_dir: Optional directory for the saved ``.h5`` model.
        train_labels_json: Optional training-labels JSON path.
        val_labels_json: Optional validation-labels JSON path.
        group: Optional label-group identifier.
        model_name: Optional model architecture name.
        model_suffix: Optional suffix for the output model filename.
    """
    np.random.seed(42)  # for reproducibility
    logger = logging.getLogger('root')
    configs = tools.read_configs(config_path)

    # Explicit arguments take precedence over config-file values.
    train_dir = tools.str2path(train_dir or configs['train_dir'])
    val_dir = tools.str2path(val_dir or configs['val_dir'])
    train_labels_json = tools.str2path(train_labels_json
                                       or configs['train_labels_json'])
    val_labels_json = tools.str2path(val_labels_json
                                     or configs['val_labels_json'])
    output_dir = tools.str2path(output_dir or configs['output_dir'])
    group = group or configs['group']
    model_name = model_name or configs['model_name']
    model_suffix = model_suffix or configs['model_suffix']
    train_counts = configs.get('train_class_counts')
    val_counts = configs.get('val_class_counts')

    # FIX: also create missing parent directories; the original
    # mkdir(exist_ok=True) raised FileNotFoundError for a nested output path.
    output_dir.mkdir(parents=True, exist_ok=True)

    model_out_name = f'{model_name}_{group}_{model_suffix}.h5'
    model_path = output_dir / model_out_name

    train_gen = DataGenerator(configs, train_dir, train_labels_json, 'train',
                              group, train_counts)
    val_gen = DataGenerator(configs, val_dir, val_labels_json, 'val', group,
                            val_counts)

    epochs = configs['epochs']
    classes = configs['network_parameters']['classes']
    loss = configs['loss']
    optimizer = configs['optimizer']

    model_builder = ModelBuilder(configs, 'train', model_name, model_path,
                                 classes, loss, optimizer)
    model = model_builder.build()

    # Keep only the checkpoint with the lowest training loss.
    checkpoint = keras.callbacks.ModelCheckpoint(model_path.as_posix(),
                                                 monitor='loss',
                                                 verbose=1,
                                                 save_best_only=True,
                                                 save_weights_only=False,
                                                 mode='min')
    # Lazy %-style args: the message is only formatted if the level is enabled.
    logger.info("Training model %s for %s epochs", model_out_name, epochs)
    logger.info("Class weights: %s", train_gen.class_weights)

    # NOTE(review): fit_generator is deprecated in TF2-era Keras in favor of
    # model.fit(...) — confirm the installed Keras version before migrating.
    model.fit_generator(generator=train_gen.flow_generator,
                        steps_per_epoch=train_gen.steps_per_epoch,
                        epochs=epochs,
                        verbose=1,
                        class_weight=train_gen.class_weights,
                        callbacks=[checkpoint],
                        validation_data=val_gen.flow_generator,
                        validation_steps=val_gen.steps_per_epoch)