Beispiel #1
0
def predict(data_path, pred_path, tempdir, proba_threshold, output_class_names,
            model_profile, weight_download_region, verbose):
    """Identify species in a video.

    This is a command line interface for prediction on camera trap footage.
    Given a path to camera trap footage, the predict function uses a deep
    learning model to predict the presence or absence of a variety of species
    of common interest to wildlife researchers working with camera trap data.
    """
    # Parameter notes are kept out of the docstring: if this function is a
    # click command (it reports via click.echo), the docstring is shown as
    # --help text.
    #   data_path / pred_path: input footage and prediction output location;
    #       both are forwarded to ModelManager.predict with save=True.
    #   tempdir, proba_threshold, output_class_names: forwarded unchanged to
    #       the ModelManager constructor.
    #   model_profile / weight_download_region: passed to the model as its
    #       `profile` and `download_region` keyword arguments.
    #   verbose: NOTE(review): accepted but not used in this body.
    # Returns whatever ModelManager.predict returns (presumably the predictions
    # DataFrame, per the original comment), so programmatic callers can use it.

    click.echo(f"Using data_path:\t{data_path}")
    click.echo(f"Using pred_path:\t{pred_path}")

    # Load the model into manager
    manager = ModelManager(model_path=default_model_dir,
                           model_class='cnnensemble',
                           proba_threshold=proba_threshold,
                           tempdir=tempdir,
                           output_class_names=output_class_names,
                           model_kwargs=dict(
                               profile=model_profile,
                               download_region=weight_download_region))

    # Make predictions and hand them back instead of discarding them.
    return manager.predict(data_path, pred_path=pred_path, save=True)
Beispiel #2
0
def test_predict_invalid_videos(data_dir):
    """Tests whether invalid videos are correctly skipped.

    Builds a temporary directory containing two empty (invalid) ``.mp4``
    files and two copies of a real test video from ``data_dir``, runs the
    fast ``cnnensemble`` profile over it, and checks that every invalid file
    gets an all-null prediction row while no valid file has any nulls.
    """
    # The context manager removes the directory even when an assertion fails;
    # an explicit cleanup() at the end was skipped on failure.
    with tempfile.TemporaryDirectory() as tmp_name:
        video_directory = pathlib.Path(tmp_name)

        # create invalid (empty) videos
        for i in range(2):
            (video_directory / f"invalid{i:02}.mp4").touch()

        # copy valid videos; sort so the chosen source video does not depend
        # on filesystem listing order
        candidates = sorted(data_dir.glob("*.mp4"))
        assert candidates, f"no .mp4 test videos found in {data_dir}"
        test_video_path = candidates[0]
        for i in range(2):
            shutil.copy(test_video_path, video_directory / f"video{i:02}.mp4")

        manager = ModelManager(
            '',
            model_class="cnnensemble",
            output_class_names=False,
            model_kwargs={"profile": "fast"},
        )
        predictions = manager.predict(video_directory)

        invalid_rows = predictions.loc[predictions.index.str.contains("invalid")]
        valid_rows = predictions.loc[predictions.index.str.contains("video")]

        assert invalid_rows.isnull().values.all()
        # `not` rather than `~`: `~` is a bitwise invert and yields a truthy -2
        # if the reduction ever returns a plain Python bool.
        assert not valid_rows.isnull().values.any()
Beispiel #3
0
def test_predict_full(data_dir):
    """Run the 'full' cnnensemble profile over data_dir and dump the results.

    Predictions are saved by the manager and also written to
    ``config.MODEL_DIR / 'output' / 'test_prediction.csv'``.
    """
    full_profile = {'profile': 'full'}
    manager = ModelManager(
        '',
        model_class='cnnensemble',
        output_class_names=False,
        model_kwargs=full_profile,
    )

    predictions = manager.predict(data_dir, save=True)

    csv_target = config.MODEL_DIR / 'output' / 'test_prediction.csv'
    predictions.to_csv(str(csv_target))
Beispiel #4
0
def test_predict_fast(data_dir):
    """Run the 'fast' cnnensemble profile over data_dir and sanity-check it.

    The first video's highest-scoring class must be "duiker" (manually
    verified). Results are then written to
    ``config.MODEL_DIR / 'output' / 'test_prediction.csv'``.
    """
    fast_profile = {'profile': 'fast'}
    manager = ModelManager(
        '',
        model_class='cnnensemble',
        output_class_names=False,
        model_kwargs=fast_profile,
    )

    predictions = manager.predict(data_dir, save=True)

    # Column with the largest score per row; the first row is the one checked.
    best_label_per_row = predictions.idxmax(axis=1)
    assert best_label_per_row.values[0] == "duiker"

    csv_target = config.MODEL_DIR / 'output' / 'test_prediction.csv'
    predictions.to_csv(str(csv_target))
Beispiel #5
0
def test_create_and_save(sample_model_path, sample_data_path):
    """Check the sample model's raw outputs, then save it to model_path.

    With no ``proba_threshold`` the sample model "predicts" by adding and
    multiplying each input row, so exact values are expected back.
    """
    manager = ModelManager(sample_model_path, model_class='sample')
    result = manager.predict(sample_data_path)

    first_row = result.iloc[0]
    assert first_row.added == 9          # 6 + 3
    assert first_row.multiplied == 18    # 6 * 3

    # The second row's inputs are float32, so compare against float32 maths.
    a, b = np.float32(0.3), np.float32(0.1)
    second_row = result.iloc[1]
    assert second_row.added == a + b        # 0.3 + 0.1 == 0.4
    assert second_row.multiplied == a * b   # 0.3 * 0.1 == 0.03

    manager.model.save_model()
    assert manager.model_path.exists()
Beispiel #6
0
def test_load_and_predict(sample_model_path, sample_data_path):
    """Load the sample model with a 0.5 threshold and check binarised output.

    With ``proba_threshold`` set, each "prediction" (sum / product of the
    inputs) is reported as whether it clears the threshold.
    """
    manager = ModelManager(
        sample_model_path,
        model_class='sample',
        proba_threshold=0.5,
    )
    preds = manager.predict(sample_data_path)

    high_row, low_row = preds.iloc[0], preds.iloc[1]

    # 6 + 3 == 9 and 6 * 3 == 18 are both >= 0.5 --> True
    assert high_row.added
    assert high_row.multiplied

    # 0.3 + 0.1 == 0.4 and 0.3 * 0.1 == 0.03 are both <= 0.5 --> False
    assert not low_row.added
    assert not low_row.multiplied
def test_predict():
    """Smoke-test thresholded cnnensemble prediction on the raw_test inputs."""
    raw_test_dir = config.MODEL_DIR / "input" / "raw_test"
    thresholded = ModelManager('', model_class='cnnensemble', proba_threshold=0.5)
    thresholded.predict(raw_test_dir, save=True)