def test_validate_scheduler(labels_absolute_path):
    # None gets transformed into SchedulerConfig
    config = TrainConfig(
        labels=labels_absolute_path, scheduler_config=None, skip_load_validation=True
    )
    assert config.scheduler_config == SchedulerConfig(scheduler=None, scheduler_params=None)

    # default is valid
    config = TrainConfig(
        labels=labels_absolute_path, scheduler_config="default", skip_load_validation=True
    )
    assert config.scheduler_config == "default"

    # other strings are not
    with pytest.raises(ValueError) as error:
        TrainConfig(
            labels=labels_absolute_path, scheduler_config="StepLR", skip_load_validation=True
        )
    assert (
        "Scheduler can either be 'default', None, or a SchedulerConfig."
        == error.value.errors()[0]["msg"]
    )

    # custom scheduler
    config = TrainConfig(
        labels=labels_absolute_path,
        scheduler_config=SchedulerConfig(scheduler="StepLR", scheduler_params={"gamma": 0.2}),
        skip_load_validation=True,
    )
    assert config.scheduler_config == SchedulerConfig(
        scheduler="StepLR", scheduler_params={"gamma": 0.2}
    )

def test_labels_no_splits(labels_no_splits, tmp_path):
    # ensure species are allocated to both sets
    labels_four_videos = pd.read_csv(labels_no_splits).head(4)
    labels_four_videos["label"] = ["gorilla"] * 2 + ["elephant"] * 2
    _ = TrainConfig(
        data_dir=TEST_VIDEOS_DIR,
        labels=labels_four_videos,
        save_dir=tmp_path,
        split_proportions=dict(train=1, val=1, holdout=0),
    )
    assert (
        pd.read_csv(tmp_path / "splits.csv").split.values == ["train", "val", "train", "val"]
    ).all()

    # remove the first row which puts antelope_duiker at 2 instead of 3
    labels_with_too_few_videos = pd.read_csv(labels_no_splits).iloc[1:, :]
    with pytest.raises(ValueError) as error:
        TrainConfig(
            data_dir=TEST_VIDEOS_DIR,
            labels=labels_with_too_few_videos,
            save_dir=tmp_path,
        )
    assert (
        "Not all species have enough videos to allocate into the following splits: train, val, holdout. A minimum of 3 videos per label is required. Found the following counts: {'antelope_duiker': 2}. Either remove these labels or add more videos."
    ) == error.value.errors()[0]["msg"]

def test_model_cache_dir(labels_absolute_path, tmp_path):
    config = TrainConfig(labels=labels_absolute_path)
    assert config.model_cache_dir == Path(appdirs.user_cache_dir()) / "zamba"

    os.environ["MODEL_CACHE_DIR"] = str(tmp_path)
    config = TrainConfig(labels=labels_absolute_path)
    assert config.model_cache_dir == tmp_path

    config = PredictConfig(filepaths=labels_absolute_path, model_cache_dir=tmp_path / "my_cache")
    assert config.model_cache_dir == tmp_path / "my_cache"

def test_dry_run_and_skip_load_validation(labels_absolute_path, caplog):
    # check dry_run=True sets skip_load_validation to True
    config = TrainConfig(labels=labels_absolute_path, dry_run=True, skip_load_validation=False)
    assert config.skip_load_validation
    assert "Turning off video loading check since dry_run=True." in caplog.text

    # if dry_run is False, skip_load_validation is unchanged
    config = TrainConfig(labels=labels_absolute_path, dry_run=False, skip_load_validation=False)
    assert not config.skip_load_validation

def test_from_scratch(labels_absolute_path):
    config = TrainConfig(labels=labels_absolute_path, from_scratch=True, checkpoint=None)
    assert config.model_name == "time_distributed"
    assert config.from_scratch
    assert config.checkpoint is None

    with pytest.raises(ValueError) as error:
        TrainConfig(labels=labels_absolute_path, from_scratch=True, model_name=None)
    assert "If from_scratch=True, model_name cannot be None." == error.value.errors()[0]["msg"]

def test_labels_with_all_null_species(labels_absolute_path):
    labels = pd.read_csv(labels_absolute_path)
    labels["label"] = np.nan
    with pytest.raises(ValueError) as error:
        TrainConfig(labels=labels)
    assert "Species cannot be null for all videos." == error.value.errors()[0]["msg"]

def test_videos_cannot_be_loaded(tmp_path, labels_absolute_path, caplog):
    files_df = pd.read_csv(labels_absolute_path)

    # create bad files
    for i in np.arange(2):
        bad_file = tmp_path / f"bad_file_{i}.mp4"
        bad_file.touch()
        files_df = files_df.append(
            {"filepath": bad_file, "label": "gorilla", "split": "train"}, ignore_index=True
        )

    files_df.to_csv(tmp_path / "labels_with_non_loadable_videos.csv")

    config = PredictConfig(filepaths=tmp_path / "labels_with_non_loadable_videos.csv")
    assert "Skipping 2 file(s) that could not be loaded with ffmpeg" in caplog.text
    assert len(config.filepaths) == (len(files_df) - 2)

    config = TrainConfig(labels=tmp_path / "labels_with_non_loadable_videos.csv")
    assert "Skipping 2 file(s) that could not be loaded with ffmpeg" in caplog.text
    assert len(config.labels) == (len(files_df) - 2)

def test_labels_with_invalid_split(labels_absolute_path):
    labels = pd.read_csv(labels_absolute_path)
    labels.loc[0, "split"] = "test"
    with pytest.raises(ValueError) as error:
        TrainConfig(labels=labels)
    assert (
        "Found the following invalid values for `split`: {'test'}. `split` can only contain `train`, `val`, or `holdout.`"
    ) == error.value.errors()[0]["msg"]

def test_label_column(tmp_path, labels_absolute_path):
    pd.read_csv(labels_absolute_path).rename(columns={"label": "animal"}).to_csv(
        tmp_path / "bad_label_column.csv"
    )
    with pytest.raises(ValidationError) as error:
        TrainConfig(labels=tmp_path / "bad_label_column.csv")
    assert "must contain `filepath` and `label` columns" in error.value.errors()[0]["msg"]

def test_checkpoint_sets_model_to_none(labels_absolute_path, dummy_trained_model_checkpoint):
    config = TrainConfig(
        labels=labels_absolute_path,
        checkpoint=dummy_trained_model_checkpoint,
        skip_load_validation=True,
    )
    assert config.model_name is None

def test_labels_with_partially_null_split(labels_absolute_path):
    labels = pd.read_csv(labels_absolute_path)
    labels.loc[0, "split"] = np.nan
    with pytest.raises(ValueError) as error:
        TrainConfig(labels=labels)
    assert (
        "Found 1 row(s) with null `split`. Fill in these rows with either `train`, `val`, or `holdout`"
    ) in error.value.errors()[0]["msg"]

def test_train_data_dir_only():
    with pytest.raises(ValidationError) as error:
        TrainConfig(data_dir=TEST_VIDEOS_DIR)

    # labels is missing
    assert error.value.errors() == [
        {"loc": ("labels",), "msg": "field required", "type": "value_error.missing"}
    ]

def test_train_data_dir_and_labels(tmp_path, labels_relative_path, labels_absolute_path):
    # correct data dir
    config = TrainConfig(data_dir=TEST_VIDEOS_DIR, labels=labels_relative_path)
    assert config.data_dir is not None
    assert config.labels is not None

    # data dir ignored if absolute path provided in filepath
    config = TrainConfig(data_dir=tmp_path, labels=labels_absolute_path)
    assert config.data_dir is not None
    assert config.labels is not None
    assert not config.labels.filepath.str.startswith(str(tmp_path)).any()

    # incorrect data dir with relative filepaths
    with pytest.raises(ValidationError) as error:
        TrainConfig(data_dir=ASSETS_DIR, labels=labels_relative_path)
    assert "None of the video filepaths exist" in error.value.errors()[0]["msg"]

def test_finetune_new_labels(labels_absolute_path, model, tmp_path):
    config = TrainConfig(labels=labels_absolute_path, model_name=model, skip_load_validation=True)
    model = instantiate_model(
        checkpoint=config.checkpoint,
        weight_download_region=config.weight_download_region,
        scheduler_config="default",
        labels=pd.DataFrame([{"filepath": "kangaroo.mp4", "species_kangaroo": 1}]),
        model_cache_dir=tmp_path,
    )
    assert model.species == ["kangaroo"]

def test_labels_split_proportions(labels_no_splits, tmp_path):
    config = TrainConfig(
        data_dir=TEST_VIDEOS_DIR,
        labels=labels_no_splits,
        split_proportions={"a": 3, "b": 1},
        save_dir=tmp_path,
    )
    assert config.labels.split.value_counts().to_dict() == {"a": 13, "b": 6}

def test_filepath_column(tmp_path, labels_absolute_path):
    pd.read_csv(labels_absolute_path).rename(columns={"filepath": "video"}).to_csv(
        tmp_path / "bad_filepath_column.csv"
    )

    # predict: filepaths
    with pytest.raises(ValidationError) as error:
        PredictConfig(filepaths=tmp_path / "bad_filepath_column.csv")
    assert "must contain a `filepath` column" in error.value.errors()[0]["msg"]

    # train: labels
    with pytest.raises(ValidationError) as error:
        TrainConfig(labels=tmp_path / "bad_filepath_column.csv")
    assert "must contain `filepath` and `label` columns" in error.value.errors()[0]["msg"]

def test_default_video_loader_config(labels_absolute_path):
    # if no video loader is specified, use default for model
    config = ModelConfig(
        train_config=TrainConfig(labels=labels_absolute_path, skip_load_validation=True),
        video_loader_config=None,
    )
    assert config.video_loader_config is not None

    config = ModelConfig(
        predict_config=PredictConfig(filepaths=labels_absolute_path, skip_load_validation=True),
        video_loader_config=None,
    )
    assert config.video_loader_config is not None

def test_one_video_does_not_exist(tmp_path, labels_absolute_path, caplog):
    files_df = pd.read_csv(labels_absolute_path)

    # add a fake file
    files_df = files_df.append(
        {"filepath": "fake_file.mp4", "label": "gorilla", "split": "train"}, ignore_index=True
    )
    files_df.to_csv(tmp_path / "labels_with_fake_video.csv")

    config = PredictConfig(filepaths=tmp_path / "labels_with_fake_video.csv")
    assert "Skipping 1 file(s) that could not be found" in caplog.text
    # one fewer file than in original list since bad file is skipped
    assert len(config.filepaths) == (len(files_df) - 1)

    config = TrainConfig(labels=tmp_path / "labels_with_fake_video.csv")
    assert "Skipping 1 file(s) that could not be found" in caplog.text
    assert len(config.labels) == (len(files_df) - 1)

def train_model(
    train_config: TrainConfig,
    video_loader_config: Optional[VideoLoaderConfig] = None,
):
    """Trains a model.

    Args:
        train_config (TrainConfig): Pydantic config for training.
        video_loader_config (VideoLoaderConfig, optional): Pydantic config for preprocessing
            videos. If None, will use default for model specified in TrainConfig.
    """
    # get default VLC for model if not specified
    if video_loader_config is None:
        video_loader_config = ModelConfig(
            train_config=train_config, video_loader_config=video_loader_config
        ).video_loader_config

    # set up model
    model = instantiate_model(
        checkpoint=train_config.checkpoint,
        scheduler_config=train_config.scheduler_config,
        weight_download_region=train_config.weight_download_region,
        model_cache_dir=train_config.model_cache_dir,
        labels=train_config.labels,
        from_scratch=train_config.from_scratch,
        model_name=train_config.model_name,
        predict_all_zamba_species=train_config.predict_all_zamba_species,
    )

    data_module = ZambaDataModule(
        video_loader_config=video_loader_config,
        transform=MODEL_MAPPING[model.__class__.__name__]["transform"],
        train_metadata=train_config.labels,
        batch_size=train_config.batch_size,
        num_workers=train_config.num_workers,
    )

    validate_species(model, data_module)

    train_config.save_dir.mkdir(parents=True, exist_ok=True)

    # add folder version_n that auto increments if we are not overwriting
    tensorboard_version = train_config.save_dir.name if train_config.overwrite else None
    tensorboard_save_dir = (
        train_config.save_dir.parent if train_config.overwrite else train_config.save_dir
    )

    tensorboard_logger = TensorBoardLogger(
        save_dir=tensorboard_save_dir,
        name=None,
        version=tensorboard_version,
        default_hp_metric=False,
    )

    logging_and_save_dir = (
        tensorboard_logger.log_dir if not train_config.overwrite else train_config.save_dir
    )

    model_checkpoint = ModelCheckpoint(
        dirpath=logging_and_save_dir,
        filename=train_config.model_name,
        monitor=train_config.early_stopping_config.monitor
        if train_config.early_stopping_config is not None
        else None,
        mode=train_config.early_stopping_config.mode
        if train_config.early_stopping_config is not None
        else "min",
    )

    callbacks = [model_checkpoint]

    if train_config.early_stopping_config is not None:
        callbacks.append(EarlyStopping(**train_config.early_stopping_config.dict()))

    if train_config.backbone_finetune_config is not None:
        callbacks.append(BackboneFinetuning(**train_config.backbone_finetune_config.dict()))

    trainer = pl.Trainer(
        gpus=train_config.gpus,
        max_epochs=train_config.max_epochs,
        auto_lr_find=train_config.auto_lr_find,
        logger=tensorboard_logger,
        callbacks=callbacks,
        fast_dev_run=train_config.dry_run,
        accelerator="ddp" if data_module.multiprocessing_context is not None else None,
        plugins=DDPPlugin(find_unused_parameters=False)
        if data_module.multiprocessing_context is not None
        else None,
    )

    if video_loader_config.cache_dir is None:
        logger.info("No cache dir is specified. Videos will not be cached.")
    else:
        logger.info(f"Videos will be cached to {video_loader_config.cache_dir}.")

    if train_config.auto_lr_find:
        logger.info("Finding best learning rate.")
        trainer.tune(model, data_module)

    try:
        git_hash = git.Repo(search_parent_directories=True).head.object.hexsha
    except git.exc.InvalidGitRepositoryError:
        git_hash = None

    configuration = {
        "git_hash": git_hash,
        "model_class": model.model_class,
        "species": model.species,
        "starting_learning_rate": model.lr,
        "train_config": json.loads(train_config.json(exclude={"labels"})),
        "training_start_time": datetime.utcnow().isoformat(),
        "video_loader_config": json.loads(video_loader_config.json()),
    }

    if not train_config.dry_run:
        config_path = Path(logging_and_save_dir) / "train_configuration.yaml"
        config_path.parent.mkdir(exist_ok=True, parents=True)
        logger.info(f"Writing out full configuration to {config_path}.")
        with config_path.open("w") as fp:
            yaml.dump(configuration, fp)

    logger.info("Starting training...")
    trainer.fit(model, data_module)

    if not train_config.dry_run:
        if trainer.datamodule.test_dataloader() is not None:
            logger.info("Calculating metrics on holdout set.")
            test_metrics = trainer.test(dataloaders=trainer.datamodule.test_dataloader())[0]
            with (Path(logging_and_save_dir) / "test_metrics.json").open("w") as fp:
                json.dump(test_metrics, fp, indent=2)

        if trainer.datamodule.val_dataloader() is not None:
            logger.info("Calculating metrics on validation set.")
            val_metrics = trainer.validate(dataloaders=trainer.datamodule.val_dataloader())[0]
            with (Path(logging_and_save_dir) / "val_metrics.json").open("w") as fp:
                json.dump(val_metrics, fp, indent=2)

    return trainer

def test_train_labels_only(labels_absolute_path):
    config = TrainConfig(labels=labels_absolute_path)
    assert config.labels is not None

def train_metadata(labels_absolute_path) -> pd.DataFrame:
    return TrainConfig(labels=labels_absolute_path).labels


def time_distributed_checkpoint(labels_absolute_path) -> os.PathLike:
    return TrainConfig(labels=labels_absolute_path, model_name="time_distributed").checkpoint

def test_resume_subset_labels(labels_absolute_path, model, tmp_path):
    config = TrainConfig(labels=labels_absolute_path, model_name=model, skip_load_validation=True)
    model = instantiate_model(
        checkpoint=config.checkpoint,
        weight_download_region=config.weight_download_region,
        scheduler_config=SchedulerConfig(scheduler="StepLR", scheduler_params=None),
        # pick species that is present in all models
        labels=pd.DataFrame([{"filepath": "bird.mp4", "species_bird": 1}]),
        model_cache_dir=tmp_path,
    )
    assert model.hparams["scheduler"] == "StepLR"

    if config.model_name == "european":
        assert model.species == [
            "bird", "blank", "domestic_cat", "european_badger", "european_beaver",
            "european_hare", "european_roe_deer", "north_american_raccoon", "red_fox",
            "weasel", "wild_boar",
        ]
    else:
        assert model.species == [
            "aardvark", "antelope_duiker", "badger", "bat", "bird", "blank", "cattle",
            "cheetah", "chimpanzee_bonobo", "civet_genet", "elephant", "equid",
            "forest_buffalo", "fox", "giraffe", "gorilla", "hare_rabbit", "hippopotamus",
            "hog", "human", "hyena", "large_flightless_bird", "leopard", "lion",
            "mongoose", "monkey_prosimian", "pangolin", "porcupine", "reptile", "rodent",
            "small_cat", "wild_dog_jackal",
        ]

def test_labels_with_partially_null_species(labels_absolute_path, caplog):
    labels = pd.read_csv(labels_absolute_path)
    labels.loc[0, "label"] = np.nan
    TrainConfig(labels=labels)
    assert "Found 1 filepath(s) with no label. Will skip." in caplog.text

def test_labels_with_all_null_split(labels_absolute_path, caplog):
    labels = pd.read_csv(labels_absolute_path)
    labels["split"] = np.nan
    TrainConfig(labels=labels)
    assert "Split column is entirely null. Will generate splits automatically" in caplog.text

def train(
    data_dir: Path = typer.Option(None, exists=True, help="Path to folder containing videos."),
    labels: Path = typer.Option(None, exists=True, help="Path to csv containing video labels."),
    model: ModelEnum = typer.Option(
        "time_distributed",
        help="Model to train. Model will be superseded by checkpoint if provided.",
    ),
    checkpoint: Path = typer.Option(
        None,
        exists=True,
        help="Model checkpoint path to use for training. If provided, model is not required.",
    ),
    config: Path = typer.Option(
        None,
        exists=True,
        help="Specify options using yaml configuration file instead of through command line options.",
    ),
    batch_size: int = typer.Option(None, help="Batch size to use for training."),
    gpus: int = typer.Option(
        None,
        help="Number of GPUs to use for training. If not specified, will use all GPUs found on machine.",
    ),
    dry_run: bool = typer.Option(
        None,
        help="Runs one batch of train and validation to check for bugs.",
    ),
    save_dir: Path = typer.Option(
        None,
        help="An optional directory in which to save the model checkpoint and configuration file. If not specified, will save to a `version_n` folder in your working directory.",
    ),
    num_workers: int = typer.Option(
        None,
        help="Number of subprocesses to use for data loading.",
    ),
    weight_download_region: RegionEnum = typer.Option(
        None, help="Server region for downloading weights."
    ),
    skip_load_validation: bool = typer.Option(
        None,
        help="Skip check that verifies all videos can be loaded prior to training. Only use if you're very confident all your videos can be loaded.",
    ),
    yes: bool = typer.Option(
        False,
        "--yes",
        "-y",
        help="Skip confirmation of configuration and proceed right to training.",
    ),
):
    """Train a model on your labeled data.

    If an argument is specified in both the command line and in a yaml file,
    the command line input will take precedence.
    """
    if config is not None:
        with config.open() as f:
            config_dict = yaml.safe_load(f)
        config_file = config
    else:
        with (MODELS_DIRECTORY / f"{model.value}/config.yaml").open() as f:
            config_dict = yaml.safe_load(f)
        config_file = None

    if "video_loader_config" in config_dict.keys():
        video_loader_config = VideoLoaderConfig(**config_dict["video_loader_config"])
    else:
        video_loader_config = None

    train_dict = config_dict["train_config"]

    # override if any command line arguments are passed
    if data_dir is not None:
        train_dict["data_dir"] = data_dir
    if labels is not None:
        train_dict["labels"] = labels
    if model != "time_distributed":
        train_dict["model_name"] = model
    if checkpoint is not None:
        train_dict["checkpoint"] = checkpoint
    if batch_size is not None:
        train_dict["batch_size"] = batch_size
    if gpus is not None:
        train_dict["gpus"] = gpus
    if dry_run is not None:
        train_dict["dry_run"] = dry_run
    if save_dir is not None:
        train_dict["save_dir"] = save_dir
    if num_workers is not None:
        train_dict["num_workers"] = num_workers
    if weight_download_region is not None:
        train_dict["weight_download_region"] = weight_download_region
    if skip_load_validation is not None:
        train_dict["skip_load_validation"] = skip_load_validation

    try:
        manager = ModelManager(
            ModelConfig(
                video_loader_config=video_loader_config,
                train_config=TrainConfig(**train_dict),
            )
        )
    except ValidationError as ex:
        logger.error("Invalid configuration.")
        raise typer.Exit(ex)

    config = manager.config

    # get species to confirm
    spacer = "\n\t- "
    species = spacer + spacer.join(
        sorted(
            [
                c.split("species_", 1)[1]
                for c in config.train_config.labels.filter(regex="species").columns
            ]
        )
    )

    msg = f"""The following configuration will be used for training:

    Config file: {config_file}
    Data directory: {data_dir if data_dir is not None else config_dict["train_config"].get("data_dir")}
    Labels csv: {labels if labels is not None else config_dict["train_config"].get("labels")}
    Species: {species}
    Model name: {config.train_config.model_name}
    Checkpoint: {checkpoint if checkpoint is not None else config_dict["train_config"].get("checkpoint")}
    Batch size: {config.train_config.batch_size}
    Number of workers: {config.train_config.num_workers}
    GPUs: {config.train_config.gpus}
    Dry run: {config.train_config.dry_run}
    Save directory: {config.train_config.save_dir}
    Weight download region: {config.train_config.weight_download_region}
    """

    if yes:
        typer.echo(f"{msg}\n\nSkipping confirmation and proceeding to train.")
    else:
        yes = typer.confirm(
            f"{msg}\n\nIs this correct?",
            abort=False,
            default=True,
        )

    if yes:
        # kick off training
        manager.train()