def test_load_from_weights():
    model = load_pretrained_model(pretrained_model=PretrainedModel.FIXED_NORM_STAGE2)

    # there was a bug where fetching, loading and producing embeddings the
    # same way again yielded different embeddings. I need to check that using
    # a loaded network always gives the same result
    model2 = load_pretrained_model(pretrained_model=PretrainedModel.FIXED_NORM_STAGE2)
    assert_models_equal(model, model2)

    data_path = fetch_example_dataset(dataset=ExampleData.TINY10)
    dataset = ImageSingletDataset(
        data_dir=data_path,
        stage="train",
        tile_type=TileType.ANCHOR,
        transform=get_transforms(step="predict", normalize_for_arch=model.base_arch),
    )

    da_emb = get_embeddings(tile_dataset=dataset, model=model, prediction_batch_size=16)
    da_emb2 = get_embeddings(tile_dataset=dataset, model=model2, prediction_batch_size=16)
    np.testing.assert_allclose(da_emb, da_emb2)


def test_dendrogram_plot_triplets():
    # use a model with default resnet weights to generate some embedding
    # vectors to plot with
    backbone_arch = "resnet18"
    model = TripletTrainerModel(pretrained=True, base_arch=backbone_arch)

    data_path = fetch_example_dataset(dataset=ExampleData.SMALL100)
    tile_dataset = ImageTripletDataset(
        data_dir=data_path,
        stage="train",
        transform=get_transforms(step="predict", normalize_for_arch=backbone_arch),
    )
    da_embeddings = get_embeddings(
        tile_dataset=tile_dataset, model=model, prediction_batch_size=16
    )

    for sampling_method in ["random", "center_dist", "best_triplets", "worst_triplets"]:
        interpretation_plot.dendrogram(
            da_embeddings=da_embeddings,
            sampling_method=sampling_method,
            tile_type="anchor",
        )


def test_get_embeddings():
    backbone_arch = "resnet18"
    model = TripletTrainerModel(pretrained=True, base_arch=backbone_arch)

    data_path = fetch_example_dataset(dataset=ExampleData.TINY10)
    dataset = ImageSingletDataset(
        data_dir=data_path,
        stage="train",
        tile_type=TileType.ANCHOR,
        transform=get_transforms(step="predict", normalize_for_arch=backbone_arch),
    )

    # direct
    dl_predict = DataLoader(dataset, batch_size=32)
    batched_results = [model.forward(x_batch) for x_batch in dl_predict]
    results = np.vstack([v.cpu().detach().numpy() for v in batched_results])

    # via utility function
    da_embeddings = get_embeddings(
        tile_dataset=dataset, model=model, prediction_batch_size=16
    )

    Ntiles, Ndim = results.shape
    assert int(da_embeddings.tile_id.count()) == Ntiles
    assert int(da_embeddings.emb_dim.count()) == Ndim


def test_finetune_pretrained():
    trainer = pl.Trainer(max_epochs=5, callbacks=[HeadFineTuner()])
    arch = "resnet18"
    model = TripletTrainerModel(pretrained=True, base_arch=arch)
    data_path = fetch_example_dataset(dataset=ExampleData.TINY10)
    datamodule = TripletTrainerDataModule(
        data_dir=data_path, batch_size=2, normalize_for_arch=arch
    )
    trainer.fit(model=model, datamodule=datamodule)


def test_train_new_with_preloading():
    trainer = pl.Trainer(max_epochs=5)
    arch = "resnet18"
    model = TripletTrainerModel(pretrained=False, base_arch=arch)
    data_path = fetch_example_dataset(dataset=ExampleData.TINY10)
    datamodule = TripletTrainerDataModule(
        data_dir=data_path, batch_size=2, normalize_for_arch=arch, preload_data=True
    )
    trainer.fit(model=model, datamodule=datamodule)


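# `N_GPUS` is used by the training tests below but its definition is not
# included in this excerpt; a minimal sketch, assuming all available CUDA
# devices should be used
import torch

N_GPUS = torch.cuda.device_count()

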
def test_train_new_anti_aliased():
    trainer = pl.Trainer(max_epochs=5, gpus=N_GPUS)
    arch = "resnet18"
    model = TripletTrainerModel(
        pretrained=False, base_arch=arch, anti_aliased_backbone=True
    )
    data_path = fetch_example_dataset(dataset=ExampleData.TINY10)
    datamodule = TripletTrainerDataModule(
        data_dir=data_path, batch_size=2, normalize_for_arch=arch
    )
    trainer.fit(model=model, datamodule=datamodule)


def test_train_new_onecycle():
    lr_monitor = LearningRateMonitor(logging_interval="step")
    trainer = OneCycleTrainer(max_epochs=5, callbacks=[lr_monitor], gpus=N_GPUS)
    arch = "resnet18"
    model = TripletTrainerModel(pretrained=False, base_arch=arch)
    data_path = fetch_example_dataset(dataset=ExampleData.TINY10)
    datamodule = TripletTrainerDataModule(
        data_dir=data_path, batch_size=2, normalize_for_arch=arch
    )
    trainer.fit(model=model, datamodule=datamodule)


def test_dendrogram_plot():
    # use a model with default resnet weights to generate some embedding
    # vectors to plot with
    backbone_arch = "resnet18"
    model = TripletTrainerModel(pretrained=True, base_arch=backbone_arch)

    data_path = fetch_example_dataset(dataset=ExampleData.SMALL100)
    tile_dataset = ImageSingletDataset(
        data_dir=data_path,
        stage="train",
        tile_type=TileType.ANCHOR,
        transform=get_transforms(step="predict", normalize_for_arch=backbone_arch),
    )
    da_embeddings = get_embeddings(
        tile_dataset=tile_dataset, model=model, prediction_batch_size=16
    )

    interpretation_plot.dendrogram(da_embeddings=da_embeddings)


def test_isomap2d():
    data_path = fetch_example_dataset(dataset=ExampleData.SMALL100)
    model = load_pretrained_model(pretrained_model=PretrainedModel.FIXED_NORM_STAGE2)
    dataset = ImageTripletDataset(
        data_dir=data_path,
        transform=get_transforms(step="predict", normalize_for_arch=model.base_arch),
        stage="train",
    )
    da_embs = get_embeddings(
        tile_dataset=dataset,
        model=model,
        prediction_batch_size=4,
    )

    isomap2d.make_isomap_reference_plot(da_embs=da_embs, tile_size=0.1)


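# `_make_hash` and `_save_embeddings` are used by `main` below but their
# definitions are not included in this excerpt. A minimal sketch of
# `_make_hash`, assuming it only needs to produce a short deterministic id
# for a string
import hashlib


def _make_hash(s):
    # md5 is sufficient here since the hash is only used to build a cache
    # filename, not for anything security-sensitive
    return hashlib.md5(s.encode("utf-8")).hexdigest()

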
def main(model_path, data_path):
    if model_path.startswith("ex://"):
        model_name = model_path[5:]
        available_models = [m.name for m in list(PretrainedModel)]
        if model_name not in available_models:
            raise Exception(
                f"pretrained model `{model_name}` not found."
                f" available models: {', '.join(available_models)}"
            )
        model = load_pretrained_model(PretrainedModel[model_name])
    else:
        model = TripletTrainerModel.load_from_checkpoint(model_path)

    if data_path.startswith("ex://"):
        dset_name = data_path[5:]
        available_dsets = [m.name for m in list(ExampleData)]
        if dset_name not in available_dsets:
            raise Exception(
                f"example dataset `{dset_name}` not found."
                f" available datasets: {', '.join(available_dsets)}"
            )
        data_path = fetch_example_dataset(dataset=ExampleData[dset_name])

    transforms = get_transforms(step="predict", normalize_for_arch=model.base_arch)
    dset = ImageTripletDataset(data_dir=data_path, stage="train", transform=transforms)

    embs_id = _make_hash(f"{model_path}__{data_path}")
    fpath_embs = Path(f"embs-{embs_id}.nc")

    if not fpath_embs.exists():
        da_embs = get_embeddings(
            tile_dataset=dset, model=model, prediction_batch_size=64
        )
        da_embs.to_netcdf(fpath_embs)
    else:
        da_embs = xr.open_dataarray(fpath_embs)
        print(f"using saved embeddings from `{fpath_embs}`")

    _save_embeddings(da_embs=da_embs, dset=dset)


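# a minimal command-line entry point for `main` above; the original argument
# parsing is not shown in this excerpt, so this is an assumed sketch
if __name__ == "__main__":
    import argparse

    argparser = argparse.ArgumentParser()
    argparser.add_argument("model_path", help="path to a checkpoint, or ex://<NAME>")
    argparser.add_argument("data_path", help="path to a dataset, or ex://<NAME>")
    args = argparser.parse_args()
    main(model_path=args.model_path, data_path=args.data_path)

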
def assert_models_equal(model_1, model_2):
    """Check that two models have identical weights in every layer."""
    differences = {}
    for (k_1, v_1), (k_2, v_2) in zip(
        model_1.state_dict().items(), model_2.state_dict().items()
    ):
        if torch.equal(v_1, v_2):
            pass
        elif k_1 == k_2:
            differences[k_1] = (v_1, v_2)
        else:
            raise Exception

    if len(differences) > 0:
        msg = (
            f"There were differences found in {len(differences)} out of "
            f"{len(model_1.state_dict())} layers: " + ", ".join(differences.keys())
        )
        raise Exception(msg)


def test_grid_overview_plot():
    data_path = fetch_example_dataset(dataset=ExampleData.SMALL100)
    tile_dataset = ImageSingletDataset(
        data_dir=data_path, stage="train", tile_type=TileType.ANCHOR
    )
    interpretation_plot.grid_overview(tile_dataset=tile_dataset, points=10)