Beispiel #1
0
def test_validate_scheduler(labels_absolute_path):
    """Exercise the accepted and rejected forms of ``scheduler_config``.

    ``None`` is coerced into an empty ``SchedulerConfig``, the string
    ``"default"`` is kept verbatim, any other string is rejected, and an
    explicit ``SchedulerConfig`` passes through unchanged.
    """

    def build(scheduler_config):
        # every case shares the same labels and skips the video loading check
        return TrainConfig(
            labels=labels_absolute_path,
            scheduler_config=scheduler_config,
            skip_load_validation=True,
        )

    # None -> empty SchedulerConfig
    empty_scheduler = SchedulerConfig(scheduler=None, scheduler_params=None)
    assert build(None).scheduler_config == empty_scheduler

    # the literal "default" is accepted as-is
    assert build("default").scheduler_config == "default"

    # any other bare string is a validation error
    with pytest.raises(ValueError) as error:
        build("StepLR")
    expected_msg = "Scheduler can either be 'default', None, or a SchedulerConfig."
    assert error.value.errors()[0]["msg"] == expected_msg

    # an explicit SchedulerConfig passes through unchanged
    step_lr = SchedulerConfig(scheduler="StepLR", scheduler_params={"gamma": 0.2})
    assert build(step_lr).scheduler_config == SchedulerConfig(
        scheduler="StepLR", scheduler_params={"gamma": 0.2}
    )
Beispiel #2
0
def test_labels_no_splits(labels_no_splits, tmp_path):
    """Splits are generated automatically when the labels have no split column.

    First, with two species of two videos each and train/val proportions of
    1:1, each species must land in both sets. Then, a species left with too
    few videos for train/val/holdout must be rejected.
    """
    # two gorillas followed by two elephants
    balanced = pd.read_csv(labels_no_splits).head(4)
    balanced["label"] = ["gorilla", "gorilla", "elephant", "elephant"]
    TrainConfig(
        data_dir=TEST_VIDEOS_DIR,
        labels=balanced,
        save_dir=tmp_path,
        split_proportions={"train": 1, "val": 1, "holdout": 0},
    )

    generated = pd.read_csv(tmp_path / "splits.csv")["split"]
    assert (generated.values == ["train", "val", "train", "val"]).all()

    # dropping the first row leaves antelope_duiker with 2 videos instead of 3
    too_few = pd.read_csv(labels_no_splits).iloc[1:, :]
    with pytest.raises(ValueError) as error:
        TrainConfig(data_dir=TEST_VIDEOS_DIR, labels=too_few, save_dir=tmp_path)

    expected_msg = "Not all species have enough videos to allocate into the following splits: train, val, holdout. A minimum of 3 videos per label is required. Found the following counts: {'antelope_duiker': 2}. Either remove these labels or add more videos."
    assert error.value.errors()[0]["msg"] == expected_msg
Beispiel #3
0
def test_model_cache_dir(labels_absolute_path, tmp_path):
    """Check how ``model_cache_dir`` is resolved.

    Resolution order tested here:
    1. default: ``<user cache dir>/zamba``
    2. the ``MODEL_CACHE_DIR`` environment variable
    3. an explicit ``model_cache_dir`` argument

    The environment variable is cleared before the test and restored
    afterwards, so the result does not depend on the caller's environment and
    does not leak into other tests.
    """
    previous = os.environ.pop("MODEL_CACHE_DIR", None)
    try:
        config = TrainConfig(labels=labels_absolute_path)
        assert config.model_cache_dir == Path(appdirs.user_cache_dir()) / "zamba"

        os.environ["MODEL_CACHE_DIR"] = str(tmp_path)
        config = TrainConfig(labels=labels_absolute_path)
        assert config.model_cache_dir == tmp_path

        config = PredictConfig(filepaths=labels_absolute_path,
                               model_cache_dir=tmp_path / "my_cache")
        assert config.model_cache_dir == tmp_path / "my_cache"
    finally:
        # restore whatever was there before (or remove the var we set)
        if previous is None:
            os.environ.pop("MODEL_CACHE_DIR", None)
        else:
            os.environ["MODEL_CACHE_DIR"] = previous
Beispiel #4
0
def test_dry_run_and_skip_load_validation(labels_absolute_path, caplog):
    """``dry_run=True`` forces ``skip_load_validation`` on; otherwise it is left alone."""
    dry = TrainConfig(
        labels=labels_absolute_path, dry_run=True, skip_load_validation=False
    )
    assert dry.skip_load_validation
    assert "Turning off video loading check since dry_run=True." in caplog.text

    real = TrainConfig(
        labels=labels_absolute_path, dry_run=False, skip_load_validation=False
    )
    assert not real.skip_load_validation
Beispiel #5
0
def test_from_scratch(labels_absolute_path):
    """``from_scratch=True`` needs a model name; the default one is time_distributed."""
    scratch = TrainConfig(
        labels=labels_absolute_path, from_scratch=True, checkpoint=None
    )
    assert scratch.model_name == "time_distributed"
    assert scratch.from_scratch
    assert scratch.checkpoint is None

    # training from scratch without any model name is rejected
    with pytest.raises(ValueError) as error:
        TrainConfig(labels=labels_absolute_path, from_scratch=True, model_name=None)
    first_msg = error.value.errors()[0]["msg"]
    assert first_msg == "If from_scratch=True, model_name cannot be None."
Beispiel #6
0
def test_labels_with_all_null_species(labels_absolute_path):
    """A labels frame whose every label is NaN is rejected."""
    unlabeled = pd.read_csv(labels_absolute_path)
    unlabeled["label"] = np.nan

    with pytest.raises(ValueError) as error:
        TrainConfig(labels=unlabeled)
    first_msg = error.value.errors()[0]["msg"]
    assert first_msg == "Species cannot be null for all videos."
Beispiel #7
0
def test_videos_cannot_be_loaded(tmp_path, labels_absolute_path, caplog):
    """Files that exist but cannot be decoded are skipped with a warning.

    Two empty ``.mp4`` files are appended to the labels csv. Both
    PredictConfig and TrainConfig are expected to log the skip and drop
    exactly those two rows.
    """
    files_df = pd.read_csv(labels_absolute_path)

    # create bad files: empty files that exist on disk
    bad_rows = []
    for i in range(2):
        bad_file = tmp_path / f"bad_file_{i}.mp4"
        bad_file.touch()
        bad_rows.append({"filepath": bad_file, "label": "gorilla", "split": "train"})
    # DataFrame.append was removed in pandas 2.0; concat is the supported form
    files_df = pd.concat([files_df, pd.DataFrame(bad_rows)], ignore_index=True)

    labels_csv = tmp_path / "labels_with_non_loadable_videos.csv"
    files_df.to_csv(labels_csv)

    config = PredictConfig(filepaths=labels_csv)
    assert "Skipping 2 file(s) that could not be loaded with ffmpeg" in caplog.text
    assert len(config.filepaths) == (len(files_df) - 2)

    # clear captured logs so the TrainConfig assertion is not satisfied by
    # the message PredictConfig already emitted
    caplog.clear()
    config = TrainConfig(labels=labels_csv)
    assert "Skipping 2 file(s) that could not be loaded with ffmpeg" in caplog.text
    assert len(config.labels) == (len(files_df) - 2)
Beispiel #8
0
def test_labels_with_invalid_split(labels_absolute_path):
    """A split value outside train/val/holdout is rejected."""
    bad_split = pd.read_csv(labels_absolute_path)
    bad_split.loc[0, "split"] = "test"

    with pytest.raises(ValueError) as error:
        TrainConfig(labels=bad_split)

    expected_msg = "Found the following invalid values for `split`: {'test'}. `split` can only contain `train`, `val`, or `holdout.`"
    assert error.value.errors()[0]["msg"] == expected_msg
Beispiel #9
0
def test_label_column(tmp_path, labels_absolute_path):
    """A labels csv whose ``label`` column was renamed fails validation."""
    bad_csv = tmp_path / "bad_label_column.csv"
    renamed = pd.read_csv(labels_absolute_path).rename(columns={"label": "animal"})
    renamed.to_csv(bad_csv)

    with pytest.raises(ValidationError) as error:
        TrainConfig(labels=bad_csv)
    first_msg = error.value.errors()[0]["msg"]
    assert "must contain `filepath` and `label` columns" in first_msg
Beispiel #10
0
def test_checkpoint_sets_model_to_none(labels_absolute_path,
                                       dummy_trained_model_checkpoint):
    """Supplying a checkpoint leaves ``model_name`` as None."""
    settings = dict(
        labels=labels_absolute_path,
        checkpoint=dummy_trained_model_checkpoint,
        skip_load_validation=True,
    )
    assert TrainConfig(**settings).model_name is None
Beispiel #11
0
def test_labels_with_partially_null_split(labels_absolute_path):
    """A split column with some (but not all) nulls is rejected."""
    partial = pd.read_csv(labels_absolute_path)
    partial.loc[0, "split"] = np.nan

    with pytest.raises(ValueError) as error:
        TrainConfig(labels=partial)

    expected_fragment = "Found 1 row(s) with null `split`. Fill in these rows with either `train`, `val`, or `holdout`"
    assert expected_fragment in error.value.errors()[0]["msg"]
Beispiel #12
0
def test_train_data_dir_only():
    """``labels`` is a required field; a data dir alone is not enough."""
    expected_errors = [
        {"loc": ("labels",), "msg": "field required", "type": "value_error.missing"}
    ]
    with pytest.raises(ValidationError) as error:
        TrainConfig(data_dir=TEST_VIDEOS_DIR)
    assert error.value.errors() == expected_errors
Beispiel #13
0
def test_train_data_dir_and_labels(tmp_path, labels_relative_path,
                                   labels_absolute_path):
    """How ``data_dir`` combines with relative and absolute label filepaths."""
    # relative filepaths resolved against the correct data dir
    relative = TrainConfig(data_dir=TEST_VIDEOS_DIR, labels=labels_relative_path)
    assert relative.data_dir is not None
    assert relative.labels is not None

    # absolute filepaths are not prefixed with data_dir
    absolute = TrainConfig(data_dir=tmp_path, labels=labels_absolute_path)
    assert absolute.data_dir is not None
    assert absolute.labels is not None
    prefixed = absolute.labels.filepath.str.startswith(str(tmp_path))
    assert not prefixed.any()

    # relative filepaths against the wrong data dir find no videos
    with pytest.raises(ValidationError) as error:
        TrainConfig(data_dir=ASSETS_DIR, labels=labels_relative_path)
    first_msg = error.value.errors()[0]["msg"]
    assert "None of the video filepaths exist" in first_msg
def test_finetune_new_labels(labels_absolute_path, model, tmp_path):
    """Loading a checkpoint with unseen labels yields exactly the new species."""
    config = TrainConfig(labels=labels_absolute_path, model_name=model, skip_load_validation=True)
    kangaroo_labels = pd.DataFrame([{"filepath": "kangaroo.mp4", "species_kangaroo": 1}])

    # separate name so the `model` fixture value is not shadowed
    finetuned = instantiate_model(
        checkpoint=config.checkpoint,
        weight_download_region=config.weight_download_region,
        scheduler_config="default",
        labels=kangaroo_labels,
        model_cache_dir=tmp_path,
    )
    assert finetuned.species == ["kangaroo"]
Beispiel #15
0
def test_labels_split_proportions(labels_no_splits, tmp_path):
    """Custom split names and proportions are honored when generating splits."""
    proportions = {"a": 3, "b": 1}
    config = TrainConfig(
        data_dir=TEST_VIDEOS_DIR,
        labels=labels_no_splits,
        split_proportions=proportions,
        save_dir=tmp_path,
    )
    counts = config.labels.split.value_counts().to_dict()
    assert counts == {"a": 13, "b": 6}
Beispiel #16
0
def test_filepath_column(tmp_path, labels_absolute_path):
    """A csv lacking a ``filepath`` column fails for both predict and train."""
    bad_csv = tmp_path / "bad_filepath_column.csv"
    renamed = pd.read_csv(labels_absolute_path).rename(columns={"filepath": "video"})
    renamed.to_csv(bad_csv)

    # predict: filepaths
    with pytest.raises(ValidationError) as predict_error:
        PredictConfig(filepaths=bad_csv)
    assert "must contain a `filepath` column" in predict_error.value.errors()[0]["msg"]

    # train: labels
    with pytest.raises(ValidationError) as train_error:
        TrainConfig(labels=bad_csv)
    assert "must contain `filepath` and `label` columns" in train_error.value.errors()[0]["msg"]
Beispiel #17
0
def test_default_video_loader_config(labels_absolute_path):
    """ModelConfig fills in a video loader config when given None."""

    def resolved_loader(**stage_config):
        # build a ModelConfig with no explicit loader and return what it picked
        return ModelConfig(video_loader_config=None, **stage_config).video_loader_config

    train_cfg = TrainConfig(labels=labels_absolute_path, skip_load_validation=True)
    assert resolved_loader(train_config=train_cfg) is not None

    predict_cfg = PredictConfig(filepaths=labels_absolute_path, skip_load_validation=True)
    assert resolved_loader(predict_config=predict_cfg) is not None
Beispiel #18
0
def test_one_video_does_not_exist(tmp_path, labels_absolute_path, caplog):
    """A filepath that does not exist is skipped with a warning.

    One nonexistent file is appended to the labels csv. Both PredictConfig
    and TrainConfig are expected to log the skip and drop exactly that row.
    """
    files_df = pd.read_csv(labels_absolute_path)
    # add a fake file; DataFrame.append was removed in pandas 2.0, so use concat
    fake_row = {"filepath": "fake_file.mp4", "label": "gorilla", "split": "train"}
    files_df = pd.concat([files_df, pd.DataFrame([fake_row])], ignore_index=True)

    labels_csv = tmp_path / "labels_with_fake_video.csv"
    files_df.to_csv(labels_csv)

    config = PredictConfig(filepaths=labels_csv)
    assert "Skipping 1 file(s) that could not be found" in caplog.text
    # one fewer file than in original list since bad file is skipped
    assert len(config.filepaths) == (len(files_df) - 1)

    # clear captured logs so the TrainConfig assertion is not satisfied by
    # the message PredictConfig already emitted
    caplog.clear()
    config = TrainConfig(labels=labels_csv)
    assert "Skipping 1 file(s) that could not be found" in caplog.text
    assert len(config.labels) == (len(files_df) - 1)
Beispiel #19
0
def train_model(
    train_config: TrainConfig,
    video_loader_config: Optional[VideoLoaderConfig] = None,
):
    """Trains a model.

    After fitting, unless ``train_config.dry_run`` is set, the model is also
    evaluated on the holdout and validation dataloaders when they exist.

    Args:
        train_config (TrainConfig): Pydantic config for training.
        video_loader_config (VideoLoaderConfig, optional): Pydantic config for preprocessing videos.
            If None, will use default for model specified in TrainConfig.

    Returns:
        pl.Trainer: the trainer after ``fit`` (and ``test``/``validate`` when run).

    Side effects:
        Creates ``train_config.save_dir``. Checkpoints go to the logging
        directory through ``ModelCheckpoint``. Unless ``dry_run`` is set,
        ``train_configuration.yaml`` and, where the corresponding dataloader
        exists, ``test_metrics.json`` / ``val_metrics.json`` are written to
        the same directory.
    """
    # get default VLC for model if not specified
    if video_loader_config is None:
        video_loader_config = ModelConfig(
            train_config=train_config, video_loader_config=video_loader_config
        ).video_loader_config

    # set up model
    model = instantiate_model(
        checkpoint=train_config.checkpoint,
        scheduler_config=train_config.scheduler_config,
        weight_download_region=train_config.weight_download_region,
        model_cache_dir=train_config.model_cache_dir,
        labels=train_config.labels,
        from_scratch=train_config.from_scratch,
        model_name=train_config.model_name,
        predict_all_zamba_species=train_config.predict_all_zamba_species,
    )

    data_module = ZambaDataModule(
        video_loader_config=video_loader_config,
        transform=MODEL_MAPPING[model.__class__.__name__]["transform"],
        train_metadata=train_config.labels,
        batch_size=train_config.batch_size,
        num_workers=train_config.num_workers,
    )

    validate_species(model, data_module)

    train_config.save_dir.mkdir(parents=True, exist_ok=True)

    # add folder version_n that auto increments if we are not overwriting;
    # when overwriting, parent + name reproduces save_dir exactly
    tensorboard_version = train_config.save_dir.name if train_config.overwrite else None
    tensorboard_save_dir = (
        train_config.save_dir.parent if train_config.overwrite else train_config.save_dir
    )

    tensorboard_logger = TensorBoardLogger(
        save_dir=tensorboard_save_dir,
        name=None,
        version=tensorboard_version,
        default_hp_metric=False,
    )

    logging_and_save_dir = (
        tensorboard_logger.log_dir if not train_config.overwrite else train_config.save_dir
    )

    early_stopping_config = train_config.early_stopping_config

    # checkpoint selection mirrors the early-stopping metric/mode when one is configured
    model_checkpoint = ModelCheckpoint(
        dirpath=logging_and_save_dir,
        filename=train_config.model_name,
        monitor=early_stopping_config.monitor if early_stopping_config is not None else None,
        mode=early_stopping_config.mode if early_stopping_config is not None else "min",
    )

    callbacks = [model_checkpoint]

    if early_stopping_config is not None:
        callbacks.append(EarlyStopping(**early_stopping_config.dict()))

    if train_config.backbone_finetune_config is not None:
        callbacks.append(BackboneFinetuning(**train_config.backbone_finetune_config.dict()))

    # DDP is enabled only when the data module sets a multiprocessing context
    use_ddp = data_module.multiprocessing_context is not None

    trainer = pl.Trainer(
        gpus=train_config.gpus,
        max_epochs=train_config.max_epochs,
        auto_lr_find=train_config.auto_lr_find,
        logger=tensorboard_logger,
        callbacks=callbacks,
        fast_dev_run=train_config.dry_run,
        accelerator="ddp" if use_ddp else None,
        plugins=DDPPlugin(find_unused_parameters=False) if use_ddp else None,
    )

    if video_loader_config.cache_dir is None:
        logger.info("No cache dir is specified. Videos will not be cached.")
    else:
        logger.info(f"Videos will be cached to {video_loader_config.cache_dir}.")

    if train_config.auto_lr_find:
        logger.info("Finding best learning rate.")
        trainer.tune(model, data_module)

    try:
        git_hash = git.Repo(search_parent_directories=True).head.object.hexsha
    except git.exc.InvalidGitRepositoryError:
        git_hash = None

    configuration = {
        "git_hash": git_hash,
        "model_class": model.model_class,
        "species": model.species,
        "starting_learning_rate": model.lr,
        "train_config": json.loads(train_config.json(exclude={"labels"})),
        "training_start_time": datetime.utcnow().isoformat(),
        "video_loader_config": json.loads(video_loader_config.json()),
    }

    output_dir = Path(logging_and_save_dir)

    if not train_config.dry_run:
        config_path = output_dir / "train_configuration.yaml"
        config_path.parent.mkdir(exist_ok=True, parents=True)
        logger.info(f"Writing out full configuration to {config_path}.")
        with config_path.open("w") as fp:
            yaml.dump(configuration, fp)

    logger.info("Starting training...")
    trainer.fit(model, data_module)

    if not train_config.dry_run:
        # build each dataloader once and reuse it, rather than constructing it
        # a second time just to pass it to the trainer
        test_dataloader = trainer.datamodule.test_dataloader()
        if test_dataloader is not None:
            logger.info("Calculating metrics on holdout set.")
            test_metrics = trainer.test(dataloaders=test_dataloader)[0]
            _write_metrics(test_metrics, output_dir / "test_metrics.json")

        val_dataloader = trainer.datamodule.val_dataloader()
        if val_dataloader is not None:
            logger.info("Calculating metrics on validation set.")
            val_metrics = trainer.validate(dataloaders=val_dataloader)[0]
            _write_metrics(val_metrics, output_dir / "val_metrics.json")

    return trainer


def _write_metrics(metrics, path):
    """Write a metrics mapping to ``path`` as 2-space-indented JSON."""
    with Path(path).open("w") as fp:
        json.dump(metrics, fp, indent=2)
Beispiel #20
0
def test_train_labels_only(labels_absolute_path):
    """A labels path alone is enough to build a TrainConfig with labels."""
    assert TrainConfig(labels=labels_absolute_path).labels is not None
Beispiel #21
0
def train_metadata(labels_absolute_path) -> pd.DataFrame:
    """Return the labels frame TrainConfig builds from ``labels_absolute_path``."""
    validated = TrainConfig(labels=labels_absolute_path)
    return validated.labels
Beispiel #22
0
def time_distributed_checkpoint(labels_absolute_path) -> os.PathLike:
    """Return the checkpoint TrainConfig resolves for the time_distributed model."""
    resolved = TrainConfig(labels=labels_absolute_path, model_name="time_distributed")
    return resolved.checkpoint
def test_resume_subset_labels(labels_absolute_path, model, tmp_path):
    """Labels that are a subset of the checkpoint's species keep the full species list.

    Also checks that the explicit scheduler choice lands in the model hparams.
    """
    european_species = [
        "bird",
        "blank",
        "domestic_cat",
        "european_badger",
        "european_beaver",
        "european_hare",
        "european_roe_deer",
        "north_american_raccoon",
        "red_fox",
        "weasel",
        "wild_boar",
    ]
    african_species = [
        "aardvark",
        "antelope_duiker",
        "badger",
        "bat",
        "bird",
        "blank",
        "cattle",
        "cheetah",
        "chimpanzee_bonobo",
        "civet_genet",
        "elephant",
        "equid",
        "forest_buffalo",
        "fox",
        "giraffe",
        "gorilla",
        "hare_rabbit",
        "hippopotamus",
        "hog",
        "human",
        "hyena",
        "large_flightless_bird",
        "leopard",
        "lion",
        "mongoose",
        "monkey_prosimian",
        "pangolin",
        "porcupine",
        "reptile",
        "rodent",
        "small_cat",
        "wild_dog_jackal",
    ]

    config = TrainConfig(labels=labels_absolute_path, model_name=model, skip_load_validation=True)
    resumed = instantiate_model(
        checkpoint=config.checkpoint,
        weight_download_region=config.weight_download_region,
        scheduler_config=SchedulerConfig(scheduler="StepLR", scheduler_params=None),
        # pick species that is present in all models
        labels=pd.DataFrame([{"filepath": "bird.mp4", "species_bird": 1}]),
        model_cache_dir=tmp_path,
    )
    assert resumed.hparams["scheduler"] == "StepLR"

    expected = european_species if config.model_name == "european" else african_species
    assert resumed.species == expected
Beispiel #24
0
def test_labels_with_partially_null_species(labels_absolute_path, caplog):
    """Rows with a missing label are skipped with a warning, not rejected."""
    partial = pd.read_csv(labels_absolute_path)
    partial.loc[0, "label"] = np.nan

    TrainConfig(labels=partial)

    expected_log = "Found 1 filepath(s) with no label. Will skip."
    assert expected_log in caplog.text
Beispiel #25
0
def test_labels_with_all_null_split(labels_absolute_path, caplog):
    """An entirely-null split column triggers automatic split generation."""
    unsplit = pd.read_csv(labels_absolute_path)
    unsplit["split"] = np.nan

    TrainConfig(labels=unsplit)

    expected_log = "Split column is entirely null. Will generate splits automatically"
    assert expected_log in caplog.text
Beispiel #26
0
def train(
    data_dir: Path = typer.Option(None,
                                  exists=True,
                                  help="Path to folder containing videos."),
    labels: Path = typer.Option(None,
                                exists=True,
                                help="Path to csv containing video labels."),
    model: ModelEnum = typer.Option(
        "time_distributed",
        help=
        "Model to train. Model will be superseded by checkpoint if provided.",
    ),
    checkpoint: Path = typer.Option(
        None,
        exists=True,
        help=
        "Model checkpoint path to use for training. If provided, model is not required.",
    ),
    config: Path = typer.Option(
        None,
        exists=True,
        help=
        "Specify options using yaml configuration file instead of through command line options.",
    ),
    batch_size: int = typer.Option(None,
                                   help="Batch size to use for training."),
    gpus: int = typer.Option(
        None,
        help=
        "Number of GPUs to use for training. If not specified, will use all GPUs found on machine.",
    ),
    dry_run: bool = typer.Option(
        None,
        help="Runs one batch of train and validation to check for bugs.",
    ),
    save_dir: Path = typer.Option(
        None,
        help=
        "An optional directory in which to save the model checkpoint and configuration file. If not specified, will save to a `version_n` folder in your working directory.",
    ),
    num_workers: int = typer.Option(
        None,
        help="Number of subprocesses to use for data loading.",
    ),
    weight_download_region: RegionEnum = typer.Option(
        None, help="Server region for downloading weights."),
    skip_load_validation: bool = typer.Option(
        None,
        help=
        "Skip check that verifies all videos can be loaded prior to training. Only use if you're very confident all your videos can be loaded.",
    ),
    yes: bool = typer.Option(
        False,
        "--yes",
        "-y",
        help=
        "Skip confirmation of configuration and proceed right to training.",
    ),
):
    """Train a model on your labeled data.

    If an argument is specified in both the command line and in a yaml file, the command line input will take precedence.
    """
    # NOTE: the docstring above doubles as the CLI help text, so it is kept verbatim.
    # Load either the user's yaml or the model's bundled default config.
    if config is not None:
        with config.open() as f:
            config_dict = yaml.safe_load(f)
        config_file = config
    else:
        with (MODELS_DIRECTORY / f"{model.value}/config.yaml").open() as f:
            config_dict = yaml.safe_load(f)
        config_file = None

    if "video_loader_config" in config_dict:
        video_loader_config = VideoLoaderConfig(
            **config_dict["video_loader_config"])
    else:
        video_loader_config = None

    # a yaml may omit train_config entirely when everything comes from the CLI;
    # setdefault keeps train_dict and config_dict["train_config"] the same object
    train_dict = config_dict.setdefault("train_config", {})

    # override if any command line arguments are passed; "time_distributed" is
    # the CLI default, so it cannot be told apart from "not passed"
    if data_dir is not None:
        train_dict["data_dir"] = data_dir

    if labels is not None:
        train_dict["labels"] = labels

    if model != "time_distributed":
        train_dict["model_name"] = model

    if checkpoint is not None:
        train_dict["checkpoint"] = checkpoint

    if batch_size is not None:
        train_dict["batch_size"] = batch_size

    if gpus is not None:
        train_dict["gpus"] = gpus

    if dry_run is not None:
        train_dict["dry_run"] = dry_run

    if save_dir is not None:
        train_dict["save_dir"] = save_dir

    if num_workers is not None:
        train_dict["num_workers"] = num_workers

    if weight_download_region is not None:
        train_dict["weight_download_region"] = weight_download_region

    if skip_load_validation is not None:
        train_dict["skip_load_validation"] = skip_load_validation

    try:
        manager = ModelManager(
            ModelConfig(
                video_loader_config=video_loader_config,
                train_config=TrainConfig(**train_dict),
            ))
    except ValidationError as ex:
        logger.error("Invalid configuration.")
        # NOTE(review): passes the exception as the exit "code"; presumably
        # relied on to print the validation errors and exit non-zero — confirm
        raise typer.Exit(ex)

    # distinct name so the `config` path parameter is not shadowed
    manager_config = manager.config

    # get species to confirm; anchor the pattern so only one-hot species_*
    # columns are listed and the prefix strip cannot fail
    spacer = "\n\t- "
    species_prefix = "species_"
    species_names = sorted(
        column[len(species_prefix):]
        for column in manager_config.train_config.labels.filter(
            regex=f"^{species_prefix}").columns
    )
    species = spacer + spacer.join(species_names)

    msg = f"""The following configuration will be used for training:

    Config file: {config_file}
    Data directory: {data_dir if data_dir is not None else train_dict.get("data_dir")}
    Labels csv: {labels if labels is not None else train_dict.get("labels")}
    Species: {species}
    Model name: {manager_config.train_config.model_name}
    Checkpoint: {checkpoint if checkpoint is not None else train_dict.get("checkpoint")}
    Batch size: {manager_config.train_config.batch_size}
    Number of workers: {manager_config.train_config.num_workers}
    GPUs: {manager_config.train_config.gpus}
    Dry run: {manager_config.train_config.dry_run}
    Save directory: {manager_config.train_config.save_dir}
    Weight download region: {manager_config.train_config.weight_download_region}
    """

    if yes:
        typer.echo(f"{msg}\n\nSkipping confirmation and proceeding to train.")
    else:
        yes = typer.confirm(
            f"{msg}\n\nIs this correct?",
            abort=False,
            default=True,
        )

    if yes:
        # kick off training
        manager.train()