def test_finetune_new_labels(labels_absolute_path, model, tmp_path):
    """Tests that instantiating from a config checkpoint with a single kangaroo label
    yields a model whose species list is exactly that label."""
    train_config = TrainConfig(
        labels=labels_absolute_path,
        model_name=model,
        skip_load_validation=True,
    )
    kangaroo_labels = pd.DataFrame([{"filepath": "kangaroo.mp4", "species_kangaroo": 1}])

    finetuned_model = instantiate_model(
        checkpoint=train_config.checkpoint,
        weight_download_region=train_config.weight_download_region,
        scheduler_config="default",
        labels=kangaroo_labels,
        model_cache_dir=tmp_path,
    )

    # the only species on the resulting model is the one named in the labels
    assert finetuned_model.species == ["kangaroo"]
def test_remove_scheduler(time_distributed_checkpoint, tmp_path):
    """Tests that a scheduler config whose scheduler is None clears the scheduler hparam."""
    gorilla_labels = pd.DataFrame([{"filepath": "gorilla.mp4", "species_gorilla": 1}])
    no_scheduler_config = SchedulerConfig(scheduler=None)

    model_without_scheduler = instantiate_model(
        checkpoint=time_distributed_checkpoint,
        weight_download_region="us",
        scheduler_config=no_scheduler_config,
        labels=gorilla_labels,
        model_cache_dir=tmp_path,
    )

    # the pretrained model comes with a scheduler; the explicit SchedulerConfig overrides it
    assert model_without_scheduler.hparams["scheduler"] is None
def test_scheduler_ignored_for_prediction(dummy_checkpoint, tmp_path):
    """Tests that a model instantiated without labels (prediction) keeps the checkpoint's
    scheduler hparam and ignores the scheduler config passed in."""
    checkpoint_contents = torch.load(dummy_checkpoint)
    checkpoint_hparams = checkpoint_contents["hyper_parameters"]
    # precondition: the checkpoint itself was saved without a scheduler
    assert checkpoint_hparams["scheduler"] is None

    step_lr_config = SchedulerConfig(scheduler="StepLR", scheduler_params=None)
    prediction_model = instantiate_model(
        checkpoint=dummy_checkpoint,
        weight_download_region="us",
        scheduler_config=step_lr_config,
        labels=None,
        model_cache_dir=tmp_path,
    )

    # labels is None, so this is the prediction path and hparams are left untouched
    assert prediction_model.hparams["scheduler"] is None
def test_scheduler_used_if_passed(time_distributed_checkpoint, tmp_path):
    """Tests that a user-supplied scheduler config replaces the scheduler settings stored
    on the time distributed checkpoint, both with and without scheduler params."""

    def build_model(scheduler_config):
        # each call builds its own labels frame, as two independent instantiations would
        return instantiate_model(
            checkpoint=time_distributed_checkpoint,
            weight_download_region="us",
            scheduler_config=scheduler_config,
            labels=pd.DataFrame([{"filepath": "gorilla.mp4", "species_gorilla": 1}]),
            model_cache_dir=tmp_path,
        )

    model_with_scheduler = build_model(SchedulerConfig(scheduler="StepLR"))
    scheduler_hparams = model_with_scheduler.hparams

    # the user specified scheduler shows up in hparams
    assert scheduler_hparams["scheduler"] == "StepLR"
    # no scheduler params were given, so they stay None (PTL default for that scheduler)
    assert scheduler_hparams["scheduler_params"] is None

    # explicit scheduler params are carried through to hparams
    model_with_params = build_model(
        SchedulerConfig(scheduler="StepLR", scheduler_params={"gamma": 0.3})
    )
    assert model_with_params.hparams["scheduler_params"] == {"gamma": 0.3}
def test_default_scheduler_used(time_distributed_checkpoint, tmp_path):
    """Tests that the "default" scheduler config keeps the scheduler hparams already
    stored on the model."""
    gorilla_labels = pd.DataFrame([{"filepath": "gorilla.mp4", "species_gorilla": 1}])

    model_with_defaults = instantiate_model(
        checkpoint=time_distributed_checkpoint,
        weight_download_region="us",
        scheduler_config="default",
        labels=gorilla_labels,
        model_cache_dir=tmp_path,
    )
    hparams = model_with_defaults.hparams

    # "default" means the hparams from training are used as-is
    expected_scheduler_params = {"milestones": [3], "gamma": 0.5, "verbose": True}
    assert hparams["scheduler"] == "MultiStepLR"
    assert hparams["scheduler_params"] == expected_scheduler_params
def test_head_replaced_for_new_species(dummy_trained_model_checkpoint, tmp_path):
    """Tests that training with labels that are not a subset of the model species
    finetunes the model and replaces the model head."""
    checkpoint_model = DummyZambaVideoClassificationLightningModule.from_disk(
        dummy_trained_model_checkpoint
    )
    alien_labels = pd.DataFrame([{"filepath": "alien.mp4", "species_alien": 1}])

    finetuned_model = instantiate_model(
        checkpoint=dummy_trained_model_checkpoint,
        weight_download_region="us",
        scheduler_config="default",
        labels=alien_labels,
        model_cache_dir=tmp_path,
    )

    # every head weight must differ from the checkpoint's head
    weights_differ = finetuned_model.head.weight != checkpoint_model.head.weight
    assert weights_differ.all()
    assert finetuned_model.hparams["species"] == ["alien"]
    assert finetuned_model.head.out_features == 1
def test_not_predict_all_zamba_species(dummy_trained_model_checkpoint, tmp_path):
    """Tests that training with labels that are a subset of the model species, but with
    predict_all_zamba_species=False, replaces the model head."""
    checkpoint_model = DummyZambaVideoClassificationLightningModule.from_disk(
        dummy_trained_model_checkpoint
    )
    gorilla_labels = pd.DataFrame([{"filepath": "gorilla.mp4", "species_gorilla": 1}])

    subset_only_model = instantiate_model(
        checkpoint=dummy_trained_model_checkpoint,
        weight_download_region="us",
        scheduler_config="default",
        labels=gorilla_labels,
        model_cache_dir=tmp_path,
        predict_all_zamba_species=False,
    )

    # every head weight must differ from the checkpoint's head
    weights_differ = subset_only_model.head.weight != checkpoint_model.head.weight
    assert weights_differ.all()
    # only the labelled species remains, and the final layer has one output for it
    assert subset_only_model.hparams["species"] == ["gorilla"]
    assert subset_only_model.model[-1].out_features == 1
def test_head_not_replaced_for_species_subset(dummy_trained_model_checkpoint, tmp_path):
    """Tests that training with labels that are a subset of the model species resumes
    training without replacing the model head."""
    checkpoint_model = DummyZambaVideoClassificationLightningModule.from_disk(
        dummy_trained_model_checkpoint
    )
    gorilla_labels = pd.DataFrame([{"filepath": "gorilla.mp4", "species_gorilla": 1}])

    resumed_model = instantiate_model(
        checkpoint=dummy_trained_model_checkpoint,
        weight_download_region="us",
        scheduler_config="default",
        labels=gorilla_labels,
        model_cache_dir=tmp_path,
    )

    # the head weights are unchanged from the checkpoint
    weights_match = resumed_model.head.weight == checkpoint_model.head.weight
    assert weights_match.all()
    # the full species list is kept, and the final layer still has one output per species
    expected_species = ["antelope_duiker", "elephant", "gorilla"]
    assert resumed_model.hparams["species"] == expected_species
    assert resumed_model.model[-1].out_features == 3
def test_resume_subset_labels(labels_absolute_path, model, tmp_path):
    """Tests that instantiating from a config checkpoint with a single bird label applies
    the passed scheduler and leaves the model with its full species list."""
    train_config = TrainConfig(
        labels=labels_absolute_path,
        model_name=model,
        skip_load_validation=True,
    )
    # pick species that is present in all models
    bird_labels = pd.DataFrame([{"filepath": "bird.mp4", "species_bird": 1}])

    resumed_model = instantiate_model(
        checkpoint=train_config.checkpoint,
        weight_download_region=train_config.weight_download_region,
        scheduler_config=SchedulerConfig(scheduler="StepLR", scheduler_params=None),
        labels=bird_labels,
        model_cache_dir=tmp_path,
    )

    assert resumed_model.hparams["scheduler"] == "StepLR"

    # the european model has its own species list; every other model shares the list below
    if train_config.model_name == "european":
        expected_species = [
            "bird",
            "blank",
            "domestic_cat",
            "european_badger",
            "european_beaver",
            "european_hare",
            "european_roe_deer",
            "north_american_raccoon",
            "red_fox",
            "weasel",
            "wild_boar",
        ]
    else:
        expected_species = [
            "aardvark",
            "antelope_duiker",
            "badger",
            "bat",
            "bird",
            "blank",
            "cattle",
            "cheetah",
            "chimpanzee_bonobo",
            "civet_genet",
            "elephant",
            "equid",
            "forest_buffalo",
            "fox",
            "giraffe",
            "gorilla",
            "hare_rabbit",
            "hippopotamus",
            "hog",
            "human",
            "hyena",
            "large_flightless_bird",
            "leopard",
            "lion",
            "mongoose",
            "monkey_prosimian",
            "pangolin",
            "porcupine",
            "reptile",
            "rodent",
            "small_cat",
            "wild_dog_jackal",
        ]

    assert resumed_model.species == expected_species