コード例 #1
0
ファイル: test_config.py プロジェクト: drivendataorg/zamba
def test_predict_data_dir_and_filepaths(labels_absolute_path,
                                        labels_relative_path):
    """Relative filepaths are resolved against ``data_dir``.

    With the directory that holds the videos, every resolved filepath starts
    with that directory. With a directory that does not hold them, config
    validation fails.
    """
    # data_dir that actually contains the videos
    good_config = PredictConfig(
        data_dir=TEST_VIDEOS_DIR,
        filepaths=labels_relative_path,
    )
    assert good_config.data_dir is not None
    assert good_config.filepaths is not None
    video_root = str(TEST_VIDEOS_DIR)
    assert good_config.filepaths.filepath.str.startswith(video_root).all()

    # data_dir that does not contain the videos -> nothing resolves
    with pytest.raises(ValidationError) as excinfo:
        PredictConfig(data_dir=ASSETS_DIR, filepaths=labels_relative_path)
    first_error = excinfo.value.errors()[0]
    assert "None of the video filepaths exist" in first_error["msg"]
コード例 #2
0
ファイル: test_config.py プロジェクト: drivendataorg/zamba
def test_predict_dry_run_and_save(labels_absolute_path, caplog, tmp_path):
    """dry_run=True disables saving, whether requested via save or save_dir."""
    expected_warning = (
        "Cannot save when predicting with dry_run=True. Setting save=False and save_dir=None."
    )

    # an explicit save=True is overridden and a warning is logged
    cfg = PredictConfig(
        filepaths=labels_absolute_path,
        dry_run=True,
        save=True,
    )
    assert expected_warning in caplog.text
    assert not cfg.save
    assert cfg.save_dir is None

    # an explicit save_dir is overridden too
    cfg = PredictConfig(
        filepaths=labels_absolute_path,
        dry_run=True,
        save_dir=tmp_path,
    )
    assert not cfg.save
    assert cfg.save_dir is None
コード例 #3
0
ファイル: test_config.py プロジェクト: drivendataorg/zamba
def test_predict_data_dir_only():
    """With only ``data_dir``, PredictConfig lists every file found under it.

    The resulting ``filepaths`` frame must have exactly one ``filepath``
    column containing each file found recursively in ``TEST_VIDEOS_DIR``.
    """
    config = PredictConfig(data_dir=TEST_VIDEOS_DIR)
    assert config.data_dir == TEST_VIDEOS_DIR
    assert isinstance(config.filepaths, pd.DataFrame)

    expected = sorted(
        str(f) for f in TEST_VIDEOS_DIR.rglob("*") if f.is_file())
    assert sorted(config.filepaths.filepath.values) == expected
    # Compare as a plain list: `Index == list` is elementwise and returns an
    # array, so it only "works" with exactly one column and raises (instead
    # of failing the assertion) when extra columns are present.
    assert list(config.filepaths.columns) == ["filepath"]
コード例 #4
0
ファイル: test_config.py プロジェクト: drivendataorg/zamba
def test_videos_cannot_be_loaded(tmp_path, labels_absolute_path, caplog):
    """Files that exist but cannot be loaded with ffmpeg are skipped.

    Two empty ``.mp4`` files are appended to the labels csv. Both PredictConfig
    and TrainConfig should drop them and log how many were skipped.
    """
    files_df = pd.read_csv(labels_absolute_path)

    # create bad (empty) files
    bad_rows = []
    for i in range(2):
        bad_file = tmp_path / f"bad_file_{i}.mp4"
        bad_file.touch()
        bad_rows.append({
            "filepath": bad_file,
            "label": "gorilla",
            "split": "train"
        })
    # DataFrame.append was deprecated in pandas 1.4 and removed in 2.0
    files_df = pd.concat([files_df, pd.DataFrame(bad_rows)],
                         ignore_index=True)

    csv_path = tmp_path / "labels_with_non_loadable_videos.csv"
    files_df.to_csv(csv_path)

    config = PredictConfig(filepaths=csv_path)
    assert "Skipping 2 file(s) that could not be loaded with ffmpeg" in caplog.text
    assert len(config.filepaths) == (len(files_df) - 2)

    # caplog.text accumulates for the whole test. Clear it so the next check
    # verifies TrainConfig's own log output, not PredictConfig's.
    caplog.clear()
    config = TrainConfig(labels=csv_path)
    assert "Skipping 2 file(s) that could not be loaded with ffmpeg" in caplog.text
    assert len(config.labels) == (len(files_df) - 2)
コード例 #5
0
ファイル: test_config.py プロジェクト: drivendataorg/zamba
def test_predict_filepaths_with_duplicates(labels_absolute_path, tmp_path,
                                           caplog):
    """A duplicated row in the filepaths csv is dropped and reported."""
    filepaths = pd.read_csv(labels_absolute_path, usecols=["filepath"])
    # Add a duplicate of the first row. DataFrame.append was removed in
    # pandas 2.0. `iloc[[0]]` keeps the row as a one-row frame with its
    # original index label, so the written csv matches the old behaviour.
    csv_path = tmp_path / "filepaths_with_dupe.csv"
    pd.concat([filepaths, filepaths.iloc[[0]]]).to_csv(csv_path)

    PredictConfig(filepaths=csv_path)
    assert "Found 1 duplicate row(s) in filepaths csv. Dropping duplicates" in caplog.text
コード例 #6
0
ファイル: test_config.py プロジェクト: drivendataorg/zamba
def test_model_cache_dir(labels_absolute_path, tmp_path):
    """Check how ``model_cache_dir`` is resolved.

    Precedence is: an explicit argument, then the MODEL_CACHE_DIR environment
    variable, then ``<user cache dir>/zamba``.
    """
    # Snapshot and clear the env var so the test neither depends on the
    # caller's environment nor leaks its own setting into later tests.
    original_env = os.environ.pop("MODEL_CACHE_DIR", None)
    try:
        config = TrainConfig(labels=labels_absolute_path)
        assert config.model_cache_dir == Path(
            appdirs.user_cache_dir()) / "zamba"

        os.environ["MODEL_CACHE_DIR"] = str(tmp_path)
        config = TrainConfig(labels=labels_absolute_path)
        assert config.model_cache_dir == tmp_path

        config = PredictConfig(filepaths=labels_absolute_path,
                               model_cache_dir=tmp_path / "my_cache")
        assert config.model_cache_dir == tmp_path / "my_cache"
    finally:
        if original_env is None:
            os.environ.pop("MODEL_CACHE_DIR", None)
        else:
            os.environ["MODEL_CACHE_DIR"] = original_env
コード例 #7
0
ファイル: test_config.py プロジェクト: drivendataorg/zamba
def test_predict_save(labels_absolute_path, tmp_path,
                      dummy_trained_model_checkpoint):
    """Check how PredictConfig resolves ``save``, ``save_dir`` and ``overwrite``."""
    # when save is not specified, predictions go to the current working directory
    config = PredictConfig(filepaths=labels_absolute_path,
                           skip_load_validation=True)
    assert config.save_dir == Path.cwd()

    config = PredictConfig(filepaths=labels_absolute_path,
                           save=False,
                           skip_load_validation=True)
    assert config.save is False
    assert config.save_dir is None

    # if save_dir is specified, set save to True
    config = PredictConfig(
        filepaths=labels_absolute_path,
        save=False,
        save_dir=tmp_path / "my_dir",
        skip_load_validation=True,
    )
    assert config.save is True
    # save dir gets created
    assert (tmp_path / "my_dir").exists()

    # empty save dir does not error
    save_dir = tmp_path / "save_dir"
    save_dir.mkdir()

    config = PredictConfig(filepaths=labels_absolute_path,
                           save_dir=save_dir,
                           skip_load_validation=True)
    assert config.save_dir == save_dir

    # save dir with prediction csv or yaml will error
    for pred_file in [
        (save_dir / "zamba_predictions.csv"),
        (save_dir / "predict_configuration.yaml"),
    ]:
        # just takes one of the two files to raise error
        pred_file.touch()
        # `.errors()` is a pydantic ValidationError API, so expect that type
        # explicitly rather than its ValueError base class.
        with pytest.raises(ValidationError) as error:
            PredictConfig(filepaths=labels_absolute_path,
                          save_dir=save_dir,
                          skip_load_validation=True)
        assert (
            f"zamba_predictions.csv and/or predict_configuration.yaml already exist in {save_dir}. If you would like to overwrite, set overwrite=True"
            == error.value.errors()[0]["msg"])
        pred_file.unlink()

    # can overwrite: an existing output file is accepted with overwrite=True.
    # Name the file explicitly instead of relying on the leaked loop variable.
    existing_file = save_dir / "predict_configuration.yaml"
    existing_file.touch()
    config = PredictConfig(
        filepaths=labels_absolute_path,
        save_dir=save_dir,
        skip_load_validation=True,
        overwrite=True,
    )
    assert config.save_dir == save_dir
コード例 #8
0
ファイル: test_config.py プロジェクト: drivendataorg/zamba
def test_filepath_column(tmp_path, labels_absolute_path):
    """Configs reject a csv whose path column is not named ``filepath``."""
    bad_csv = tmp_path / "bad_filepath_column.csv"
    renamed = pd.read_csv(labels_absolute_path).rename(
        columns={"filepath": "video"})
    renamed.to_csv(bad_csv)

    # PredictConfig validates its `filepaths` csv
    with pytest.raises(ValidationError) as excinfo:
        PredictConfig(filepaths=bad_csv)
    predict_msg = excinfo.value.errors()[0]["msg"]
    assert "must contain a `filepath` column" in predict_msg

    # TrainConfig validates its `labels` csv
    with pytest.raises(ValidationError) as excinfo:
        TrainConfig(labels=bad_csv)
    train_msg = excinfo.value.errors()[0]["msg"]
    assert "must contain `filepath` and `label` columns" in train_msg
コード例 #9
0
ファイル: test_config.py プロジェクト: drivendataorg/zamba
def test_default_video_loader_config(labels_absolute_path):
    """ModelConfig fills in the model's default video loader config when given None."""

    def _check_default_loader(**stage_config):
        # Build a ModelConfig with no loader config and require one to be set.
        model_config = ModelConfig(**stage_config, video_loader_config=None)
        assert model_config.video_loader_config is not None

    _check_default_loader(train_config=TrainConfig(
        labels=labels_absolute_path, skip_load_validation=True))
    _check_default_loader(predict_config=PredictConfig(
        filepaths=labels_absolute_path, skip_load_validation=True))
コード例 #10
0
ファイル: test_config.py プロジェクト: drivendataorg/zamba
def test_one_video_does_not_exist(tmp_path, labels_absolute_path, caplog):
    """A filepath that does not exist is skipped with a warning.

    Both PredictConfig and TrainConfig should drop the missing file and log
    how many files were skipped.
    """
    files_df = pd.read_csv(labels_absolute_path)
    # Add a row pointing at a file that does not exist. pd.concat replaces
    # DataFrame.append, which was removed in pandas 2.0.
    fake_row = pd.DataFrame([{
        "filepath": "fake_file.mp4",
        "label": "gorilla",
        "split": "train"
    }])
    files_df = pd.concat([files_df, fake_row], ignore_index=True)

    csv_path = tmp_path / "labels_with_fake_video.csv"
    files_df.to_csv(csv_path)

    config = PredictConfig(filepaths=csv_path)
    assert "Skipping 1 file(s) that could not be found" in caplog.text
    # one fewer file than in original list since bad file is skipped
    assert len(config.filepaths) == (len(files_df) - 1)

    # caplog.text accumulates for the whole test. Clear it so the next check
    # verifies TrainConfig's own log output, not PredictConfig's.
    caplog.clear()
    config = TrainConfig(labels=csv_path)
    assert "Skipping 1 file(s) that could not be found" in caplog.text
    assert len(config.labels) == (len(files_df) - 1)
コード例 #11
0
ファイル: test_config.py プロジェクト: drivendataorg/zamba
def test_predict_filepaths_only(labels_absolute_path):
    """PredictConfig can be built from a filepaths csv alone."""
    assert PredictConfig(filepaths=labels_absolute_path).filepaths is not None
コード例 #12
0
def predict_metadata(filepaths) -> pd.DataFrame:
    """Return the ``filepaths`` frame that PredictConfig builds from *filepaths*."""
    validated_config = PredictConfig(filepaths=filepaths)
    return validated_config.filepaths
コード例 #13
0
def predict(
    data_dir: Path = typer.Option(None,
                                  exists=True,
                                  help="Path to folder containing videos."),
    filepaths: Path = typer.Option(
        None,
        exists=True,
        help="Path to csv containing `filepath` column with videos."),
    model: ModelEnum = typer.Option(
        "time_distributed",
        help=
        "Model to use for inference. Model will be superseded by checkpoint if provided.",
    ),
    checkpoint: Path = typer.Option(
        None,
        exists=True,
        help=
        "Model checkpoint path to use for inference. If provided, model is not required.",
    ),
    gpus: int = typer.Option(
        None,
        help=
        "Number of GPUs to use for inference. If not specified, will use all GPUs found on machine.",
    ),
    batch_size: int = typer.Option(None,
                                   help="Batch size to use for inference."),
    save: bool = typer.Option(
        None,
        help=
        "Whether to save out predictions. If you want to specify the output directory, use save_dir instead.",
    ),
    save_dir: Path = typer.Option(
        None,
        help=
        "An optional directory in which to save the model predictions and configuration yaml. "
        "Defaults to the current working directory if save is True.",
    ),
    dry_run: bool = typer.Option(
        None, help="Runs one batch of inference to check for bugs."),
    config: Path = typer.Option(
        None,
        exists=True,
        help=
        "Specify options using yaml configuration file instead of through command line options.",
    ),
    proba_threshold: float = typer.Option(
        None,
        help=
        "Probability threshold for classification between 0 and 1. If specified binary predictions "
        "are returned with 1 being greater than the threshold, 0 being less than or equal to. If not "
        "specified, probabilities between 0 and 1 are returned.",
    ),
    output_class_names: bool = typer.Option(
        None,
        help=
        "If True, we just return a video and the name of the most likely class. If False, "
        "we return a probability or indicator (depending on --proba-threshold) for every "
        "possible class.",
    ),
    num_workers: int = typer.Option(
        None,
        help="Number of subprocesses to use for data loading.",
    ),
    weight_download_region: RegionEnum = typer.Option(
        None, help="Server region for downloading weights."),
    skip_load_validation: bool = typer.Option(
        None,
        help=
        "Skip check that verifies all videos can be loaded prior to inference. Only use if you're very confident all your videos can be loaded.",
    ),
    overwrite: bool = typer.Option(
        None,
        "--overwrite",
        "-o",
        help="Overwrite outputs in the save directory if they exist."),
    yes: bool = typer.Option(
        False,
        "--yes",
        "-y",
        help=
        "Skip confirmation of configuration and proceed right to prediction.",
    ),
):
    """Identify species in a video.

    This is a command line interface for prediction on camera trap footage. Given a path to camera
    trap footage, the predict function uses a deep learning model to predict the presence or absence of
    a variety of species of common interest to wildlife researchers working with camera trap data.

    If an argument is specified in both the command line and in a yaml file, the command line input will take precedence.
    """
    # Base settings come from the user's yaml file, or else from the bundled
    # config.yaml for the chosen model.
    if config is not None:
        with config.open() as f:
            config_dict = yaml.safe_load(f)
        config_file = config
    else:
        with (MODELS_DIRECTORY / f"{model.value}/config.yaml").open() as f:
            config_dict = yaml.safe_load(f)
        config_file = None

    if "video_loader_config" in config_dict.keys():
        video_loader_config = VideoLoaderConfig(
            **config_dict["video_loader_config"])
    else:
        # None lets ModelConfig fill in the model's default loader config.
        video_loader_config = None

    # NOTE: this dict is the one inside config_dict, so the overrides below
    # also modify config_dict["predict_config"].
    predict_dict = config_dict["predict_config"]

    # override if any command line arguments are passed
    if data_dir is not None:
        predict_dict["data_dir"] = data_dir

    if filepaths is not None:
        predict_dict["filepaths"] = filepaths

    # Only a non-default model overrides the yaml; an explicit
    # `--model time_distributed` is indistinguishable from the default.
    if model != "time_distributed":
        predict_dict["model_name"] = model

    if checkpoint is not None:
        predict_dict["checkpoint"] = checkpoint

    if batch_size is not None:
        predict_dict["batch_size"] = batch_size

    if gpus is not None:
        predict_dict["gpus"] = gpus

    if dry_run is not None:
        predict_dict["dry_run"] = dry_run

    if save is not None:
        predict_dict["save"] = save

    # save_dir takes precedence over save
    if save_dir is not None:
        predict_dict["save_dir"] = save_dir

    if proba_threshold is not None:
        predict_dict["proba_threshold"] = proba_threshold

    if output_class_names is not None:
        predict_dict["output_class_names"] = output_class_names

    if num_workers is not None:
        predict_dict["num_workers"] = num_workers

    if weight_download_region is not None:
        predict_dict["weight_download_region"] = weight_download_region

    if skip_load_validation is not None:
        predict_dict["skip_load_validation"] = skip_load_validation

    if overwrite is not None:
        predict_dict["overwrite"] = overwrite

    try:
        manager = ModelManager(
            ModelConfig(
                video_loader_config=video_loader_config,
                predict_config=PredictConfig(**predict_dict),
            ))
    except ValidationError as ex:
        logger.error("Invalid configuration.")
        # NOTE(review): the exception object is passed as the exit "code".
        # With click's standalone mode this reaches sys.exit, which presumably
        # prints the validation errors and exits with status 1. Confirm.
        raise typer.Exit(ex)

    # Distinct name so the `config` (yaml path) parameter is not shadowed.
    model_config = manager.config

    msg = f"""The following configuration will be used for inference:

    Config file: {config_file}
    Data directory: {data_dir if data_dir is not None else config_dict["predict_config"].get("data_dir")}
    Filepath csv: {filepaths if filepaths is not None else config_dict["predict_config"].get("filepaths")}
    Model: {model_config.predict_config.model_name}
    Checkpoint: {checkpoint if checkpoint is not None else config_dict["predict_config"].get("checkpoint")}
    Batch size: {model_config.predict_config.batch_size}
    Number of workers: {model_config.predict_config.num_workers}
    GPUs: {model_config.predict_config.gpus}
    Dry run: {model_config.predict_config.dry_run}
    Save directory: {model_config.predict_config.save_dir}
    Proba threshold: {model_config.predict_config.proba_threshold}
    Output class names: {model_config.predict_config.output_class_names}
    Weight download region: {model_config.predict_config.weight_download_region}
    """

    if yes:
        typer.echo(
            f"{msg}\n\nSkipping confirmation and proceeding to prediction.")
    else:
        yes = typer.confirm(
            f"{msg}\n\nIs this correct?",
            abort=False,
            default=True,
        )

    if yes:
        # kick off prediction
        manager.predict()
コード例 #14
0
def predict_model(
    predict_config: PredictConfig,
    video_loader_config: VideoLoaderConfig = None,
):
    """Predicts from a model and writes out predictions to a csv.

    Args:
        predict_config (PredictConfig): Pydantic config for performing inference.
        video_loader_config (VideoLoaderConfig, optional): Pydantic config for preprocessing videos.
            If None, will use default for model specified in PredictConfig.

    Returns:
        pd.DataFrame or pd.Series: Predictions indexed by the dataset's original indices.
            Returns one column per species, containing either 0/1 indicators (when
            ``proba_threshold`` is set) or probabilities rounded to 5 places. When
            ``output_class_names`` is set and no threshold is given, returns the
            highest-scoring species name per row.
    """
    # get default VLC for model if not specified
    if video_loader_config is None:
        video_loader_config = ModelConfig(
            predict_config=predict_config, video_loader_config=video_loader_config
        ).video_loader_config

    # set up model
    model = instantiate_model(
        checkpoint=predict_config.checkpoint,
        weight_download_region=predict_config.weight_download_region,
        model_cache_dir=predict_config.model_cache_dir,
        scheduler_config=None,
        labels=None,
    )

    data_module = ZambaDataModule(
        video_loader_config=video_loader_config,
        transform=MODEL_MAPPING[model.__class__.__name__]["transform"],
        predict_metadata=predict_config.filepaths,
        batch_size=predict_config.batch_size,
        num_workers=predict_config.num_workers,
    )

    validate_species(model, data_module)

    if video_loader_config.cache_dir is None:
        logger.info("No cache dir is specified. Videos will not be cached.")
    else:
        logger.info(f"Videos will be cached to {video_loader_config.cache_dir}.")

    trainer = pl.Trainer(
        gpus=predict_config.gpus, logger=False, fast_dev_run=predict_config.dry_run
    )

    # Record everything needed to reproduce this run. The filepaths frame is
    # excluded from the serialized predict_config.
    configuration = {
        "model_class": model.model_class,
        "species": model.species,
        "predict_config": json.loads(predict_config.json(exclude={"filepaths"})),
        "inference_start_time": datetime.utcnow().isoformat(),
        "video_loader_config": json.loads(video_loader_config.json()),
    }

    if predict_config.save is not False:

        config_path = predict_config.save_dir / "predict_configuration.yaml"
        logger.info(f"Writing out full configuration to {config_path}.")
        with config_path.open("w") as fp:
            yaml.dump(configuration, fp)

    dataloader = data_module.predict_dataloader()
    logger.info("Starting prediction...")
    probas = trainer.predict(model=model, dataloaders=dataloader)

    # trainer.predict returns one array per batch; stack them into one row per
    # video, labelled with the original indices from the input filepaths.
    df = pd.DataFrame(
        np.vstack(probas), columns=model.species, index=dataloader.dataset.original_indices
    )

    # change output format if specified
    if predict_config.proba_threshold is not None:
        df = (df > predict_config.proba_threshold).astype(int)

    elif predict_config.output_class_names:
        df = df.idxmax(axis=1)

    else:  # round to a useful number of places
        df = df.round(5)

    if predict_config.save is not False:

        preds_path = predict_config.save_dir / "zamba_predictions.csv"
        logger.info(f"Saving out predictions to {preds_path}.")
        # Pass the path rather than a text-mode handle. pandas requires text
        # handles to be opened with newline="", otherwise rows end in
        # "\r\r\n" on Windows.
        df.to_csv(preds_path, index=True)

    return df