示例#1
0
文件: cli.py 项目: qyj0731/data_work
def predict(data_path, pred_path, tempdir, proba_threshold, output_class_names,
            model_profile, weight_download_region, verbose):
    """Identify species in a video.

      This is a command line interface for prediction on camera trap footage. Given a path to camera trap footage,
      the predict function uses a deep learning model to predict the presence or absence of a variety of species of
      common interest to wildlife researchers working with camera trap data.

    """

    click.echo(f"Using data_path:\t{data_path}")
    click.echo(f"Using pred_path:\t{pred_path}")

    # Load the model into the manager. `verbose` is forwarded so the CLI flag
    # actually has an effect (previously it was accepted and silently ignored).
    manager = ModelManager(model_path=default_model_dir,
                           model_class='cnnensemble',
                           proba_threshold=proba_threshold,
                           tempdir=tempdir,
                           output_class_names=output_class_names,
                           verbose=verbose,
                           model_kwargs=dict(
                               profile=model_profile,
                               download_region=weight_download_region))

    # Make predictions with save=True, passing pred_path through to the
    # manager. Whatever predict() returns is not used by the CLI.
    manager.predict(data_path, pred_path=pred_path, save=True)
示例#2
0
def test_predict_invalid_videos(data_dir):
    """Tests whether invalid videos are correctly skipped.

    Two empty (invalid) ``.mp4`` files and two copies of a real test video are
    placed in a temporary directory. Rows whose index mentions "invalid" must be
    entirely null; rows whose index mentions "video" must contain no nulls.
    """
    # The context manager removes the temporary directory even when the
    # prediction or one of the assertions fails.
    with tempfile.TemporaryDirectory() as tmp_name:
        video_directory = pathlib.Path(tmp_name)

        # create invalid (empty) videos
        for i in range(2):
            (video_directory / f"invalid{i:02}.mp4").touch()

        # copy valid videos
        test_video_path = list(data_dir.glob("*.mp4"))[0]
        for i in range(2):
            shutil.copy(test_video_path, video_directory / f"video{i:02}.mp4")

        manager = ModelManager(
            '',
            model_class="cnnensemble",
            output_class_names=False,
            model_kwargs={"profile": "fast"},
        )
        predictions = manager.predict(video_directory)

        invalid_rows = predictions.loc[predictions.index.str.contains("invalid")]
        assert invalid_rows.isnull().values.all()

        # Use `not` rather than `~`: `~` is a logical negation only for numpy
        # booleans. On a plain Python bool it yields -1/-2, which are truthy and
        # would make this assertion pass unconditionally.
        valid_rows = predictions.loc[predictions.index.str.contains("video")]
        assert not valid_rows.isnull().values.any()
示例#3
0
def test_predict_full(data_dir):
    """Run the full-profile ensemble over ``data_dir`` and write the results to CSV."""
    full_profile = {'profile': 'full'}
    manager = ModelManager(
        '',
        model_class='cnnensemble',
        output_class_names=False,
        model_kwargs=full_profile,
    )
    predictions = manager.predict(data_dir, save=True)

    csv_path = config.MODEL_DIR / 'output' / 'test_prediction.csv'
    predictions.to_csv(str(csv_path))
示例#4
0
def test_predict_fast(data_dir):
    """Run the fast-profile ensemble over ``data_dir`` and sanity-check the top class."""
    manager = ModelManager(
        '',
        model_class='cnnensemble',
        output_class_names=False,
        model_kwargs={'profile': 'fast'},
    )
    predictions = manager.predict(data_dir, save=True)

    # The highest-scoring class for the first row should be duiker
    # (this expectation was verified by hand).
    top_class_per_row = predictions.idxmax(axis=1)
    assert top_class_per_row.values[0] == "duiker"

    csv_path = config.MODEL_DIR / 'output' / 'test_prediction.csv'
    predictions.to_csv(str(csv_path))
示例#5
0
def test_load_data(data_dir):
    """The fast ensemble's ``load_data`` should yield at least one entry for ``data_dir``."""
    fast_manager = ModelManager('',
                                model_class='cnnensemble',
                                output_class_names=False,
                                model_kwargs={'profile': 'fast'})
    loaded = fast_manager.model.load_data(data_dir)
    assert len(loaded) > 0
示例#6
0
文件: cli.py 项目: qyj0731/data_work
def train(data_path, labels, site_aware, tempdir):
    """Retrain network from scratch.

    Train the weights from scratch using
    the provided data_path and labels.
    """
    click.echo(f"Using data_path:\t{data_path}")
    click.echo(f"Using labels:\t{labels}")
    click.echo(f"Using tempdir:\t{tempdir}")

    # Load the model into the manager. Pretrained weights are not downloaded
    # (download_weights=False). The raw videos and labels file are handed to
    # the model via model_kwargs.
    manager = ModelManager(model_path=default_model_dir,
                           model_class='cnnensemble',
                           tempdir=tempdir,
                           model_kwargs=dict(download_weights=False,
                                             site_aware=site_aware,
                                             raw_video_dir=data_path,
                                             labels_path=labels))

    # Train the model. Nothing is returned to the caller.
    manager.train()
示例#7
0
def test_create_and_save(sample_model_path, sample_data_path):
    """Predict with the sample model (no threshold), check raw outputs, then save it."""
    manager = ModelManager(sample_model_path, model_class='sample')

    # With no proba_threshold the sample model's raw add/multiply values come back.
    result = manager.predict(sample_data_path)

    first_row = result.iloc[0]
    assert first_row.added == 9          # 6 + 3
    assert first_row.multiplied == 18    # 6 * 3

    # Second row is float32 arithmetic on 0.3 and 0.1, compared exactly.
    second_row = result.iloc[1]
    lhs, rhs = np.float32(0.3), np.float32(0.1)
    assert second_row.added == lhs + rhs
    assert second_row.multiplied == lhs * rhs

    manager.model.save_model()
    assert manager.model_path.exists()
示例#8
0
def test_load_and_predict(sample_model_path, sample_data_path):
    """With proba_threshold=0.5 the sample model's outputs are thresholded to booleans."""
    manager = ModelManager(sample_model_path,
                           model_class='sample',
                           proba_threshold=0.5)

    preds = manager.predict(sample_data_path)

    # Row 0: 6 + 3 = 9 and 6 * 3 = 18, both well above 0.5.
    high = preds.iloc[0]
    assert high.added
    assert high.multiplied

    # Row 1: 0.3 + 0.1 = 0.4 and 0.3 * 0.1 = 0.03, both below 0.5.
    low = preds.iloc[1]
    assert not low.added
    assert not low.multiplied
示例#9
0
def test_train():
    """Smoke-test training of the ensemble without downloading pretrained weights."""
    skip_download = {'download_weights': False}
    manager = ModelManager(model_class='cnnensemble', verbose=True,
                           model_kwargs=skip_download)
    manager.train(config)
示例#10
0
def test_predict():
    """Smoke-test thresholded prediction on the raw test inputs, saving the output."""
    raw_test_dir = config.MODEL_DIR / "input" / "raw_test"
    manager = ModelManager('', model_class='cnnensemble', proba_threshold=0.5)
    manager.predict(raw_test_dir, save=True)