def predict(data_path, pred_path, tempdir, proba_threshold, output_class_names,
            model_profile, weight_download_region, verbose):
    """Identify species in a video.

    This is a command line interface for prediction on camera trap footage.
    Given a path to camera trap footage, the predict function uses a deep
    learning model to predict the presence or absence of a variety of species
    of common interest to wildlife researchers working with camera trap data.
    """
    click.echo(f"Using data_path:\t{data_path}")
    click.echo(f"Using pred_path:\t{pred_path}")

    # NOTE(review): `verbose` is accepted but not used anywhere in this
    # function — confirm whether it should be forwarded to ModelManager.

    # Load the cnnensemble model from the default model directory; the
    # profile and weight download region are forwarded to the model itself.
    manager = ModelManager(
        model_path=default_model_dir,
        model_class='cnnensemble',
        proba_threshold=proba_threshold,
        tempdir=tempdir,
        output_class_names=output_class_names,
        model_kwargs=dict(
            profile=model_profile,
            download_region=weight_download_region,
        ),
    )

    # Make predictions (saved to pred_path) and hand the result back to the
    # caller instead of discarding it, so the function is also usable outside
    # the CLI.
    return manager.predict(data_path, pred_path=pred_path, save=True)
def test_predict_invalid_videos(data_dir):
    """Tests whether invalid videos are correctly skipped.

    Creates a temporary directory holding two empty ``invalidNN.mp4`` files
    and two copies of a real test video (``videoNN.mp4``), runs the fast
    cnnensemble profile over it, and checks that every prediction for the
    invalid files is null while no prediction for the valid copies is null.
    """
    # The context manager removes the directory even when an assertion
    # fails; an explicit cleanup() at the end would be skipped on failure.
    with tempfile.TemporaryDirectory() as tmp:
        video_directory = pathlib.Path(tmp)

        # create invalid (empty) videos
        for i in range(2):
            (video_directory / f"invalid{i:02}.mp4").touch()

        # copy valid videos
        test_video_path = next(data_dir.glob("*.mp4"), None)
        assert test_video_path is not None, "no .mp4 test videos found in data_dir"
        for i in range(2):
            shutil.copy(test_video_path, video_directory / f"video{i:02}.mp4")

        manager = ModelManager(
            '',
            model_class="cnnensemble",
            output_class_names=False,
            model_kwargs={"profile": "fast"},
        )
        predictions = manager.predict(video_directory)

        invalid_rows = predictions.loc[predictions.index.str.contains("invalid")]
        valid_rows = predictions.loc[predictions.index.str.contains("video")]

        assert invalid_rows.isnull().values.all()
        # Use `not`, not `~`: `~` on a plain Python bool gives -1/-2, which
        # are both truthy and would make this assertion always pass.
        assert not valid_rows.isnull().values.any()
def test_predict_full(data_dir):
    """Smoke test: run the 'full' cnnensemble profile and write a CSV.

    There are no assertions; the test passes if prediction and the CSV
    export both complete without raising.
    """
    ensemble_kwargs = {'profile': 'full'}
    full_manager = ModelManager(
        model_path='',
        model_class='cnnensemble',
        output_class_names=False,
        model_kwargs=ensemble_kwargs,
    )
    predictions = full_manager.predict(data_dir, save=True)

    # NOTE(review): test_predict_fast writes to the same CSV path, so
    # whichever test runs last overwrites the other's output.
    csv_path = config.MODEL_DIR / 'output' / 'test_prediction.csv'
    predictions.to_csv(str(csv_path))
def test_predict_fast(data_dir):
    """Run the 'fast' cnnensemble profile and check the top class.

    The first video's highest-scoring class is expected to be "duiker"
    (verified manually); the predictions are then exported to a CSV.
    """
    ensemble_kwargs = {'profile': 'fast'}
    fast_manager = ModelManager(
        model_path='',
        model_class='cnnensemble',
        output_class_names=False,
        model_kwargs=ensemble_kwargs,
    )
    predictions = fast_manager.predict(data_dir, save=True)

    # check that duiker is most likely class (manually verified)
    top_class_per_row = predictions.idxmax(axis=1)
    assert top_class_per_row.values[0] == "duiker"

    csv_path = config.MODEL_DIR / 'output' / 'test_prediction.csv'
    predictions.to_csv(str(csv_path))
def test_create_and_save(sample_model_path, sample_data_path):
    """Check raw sample-model outputs, then save the model to disk.

    Without a ``proba_threshold`` the sample model returns exact values:
    the element-wise sum ("added") and product ("multiplied") of its inputs.
    """
    manager = ModelManager(model_path=sample_model_path, model_class='sample')
    result = manager.predict(sample_data_path)

    # Row 0 holds integer inputs 6 and 3.
    row = result.iloc[0]
    assert row.added == 9          # 6 + 3
    assert row.multiplied == 18    # 6 * 3

    # Row 1 holds float32 inputs 0.3 and 0.1; compare against float32
    # arithmetic so rounding matches what the model computes.
    row = result.iloc[1]
    assert row.added == np.float32(0.3) + np.float32(0.1)        # ~0.4
    assert row.multiplied == np.float32(0.3) * np.float32(0.1)   # ~0.03

    manager.model.save_model()
    assert manager.model_path.exists()
def test_load_and_predict(sample_model_path, sample_data_path):
    """Load the sample model with a 0.5 threshold and check binary outputs.

    With ``proba_threshold=0.5`` each output becomes truthy when the raw
    value reaches the threshold and falsy otherwise.
    """
    # load the sample model in the ModelManager
    manager = ModelManager(
        model_path=sample_model_path,
        model_class='sample',
        proba_threshold=0.5,
    )
    preds = manager.predict(sample_data_path)

    # Row 0: 6 + 3 == 9 and 6 * 3 == 18, both above 0.5 -> True
    above = preds.iloc[0]
    assert above.added
    assert above.multiplied

    # Row 1: 0.3 + 0.1 == 0.4 and 0.3 * 0.1 == 0.03, both below 0.5 -> False
    below = preds.iloc[1]
    assert not below.added
    assert not below.multiplied
def test_predict():
    """Smoke test: thresholded cnnensemble prediction over raw_test, saved."""
    thresholded_manager = ModelManager(
        model_path='',
        model_class='cnnensemble',
        proba_threshold=0.5,
    )
    raw_test_dir = config.MODEL_DIR / "input" / "raw_test"
    thresholded_manager.predict(raw_test_dir, save=True)