def test_get_embeddings():
    """Check that `get_embeddings` matches a manual forward pass over the dataset.

    Embeds the TINY10 example tiles both by iterating a DataLoader directly
    and via the `get_embeddings` utility, then compares the resulting
    (n_tiles, n_dims) dimensions.
    """
    backbone_arch = "resnet18"
    model = TripletTrainerModel(pretrained=True, base_arch=backbone_arch)

    data_path = fetch_example_dataset(dataset=ExampleData.TINY10)
    dataset = ImageSingletDataset(
        data_dir=data_path,
        stage="train",
        tile_type=TileType.ANCHOR,
        transform=get_transforms(step="predict", normalize_for_arch=backbone_arch),
    )

    # reference: run the model batch-by-batch ourselves
    loader = DataLoader(dataset, batch_size=32)
    per_batch = [model.forward(batch) for batch in loader]
    reference = np.vstack([out.cpu().detach().numpy() for out in per_batch])

    # same thing via the utility function
    da_embeddings = get_embeddings(
        tile_dataset=dataset, model=model, prediction_batch_size=16
    )

    n_tiles, n_dims = reference.shape
    assert int(da_embeddings.tile_id.count()) == n_tiles
    assert int(da_embeddings.emb_dim.count()) == n_dims
def test_dendrogram_plot_triplets():
    """Smoke-test dendrogram plotting on triplet embeddings.

    Uses a model with default resnet weights to produce embedding vectors,
    then exercises every supported tile-sampling method.
    """
    backbone_arch = "resnet18"
    model = TripletTrainerModel(pretrained=True, base_arch=backbone_arch)

    data_path = fetch_example_dataset(dataset=ExampleData.SMALL100)
    tile_dataset = ImageTripletDataset(
        data_dir=data_path,
        stage="train",
        transform=get_transforms(step="predict", normalize_for_arch=backbone_arch),
    )
    da_embeddings = get_embeddings(
        tile_dataset=tile_dataset, model=model, prediction_batch_size=16
    )

    sampling_methods = ("random", "center_dist", "best_triplets", "worst_triplets")
    for method in sampling_methods:
        interpretation_plot.dendrogram(
            da_embeddings=da_embeddings,
            sampling_method=method,
            tile_type="anchor",
        )
def test_finetune_pretrained():
    """Smoke-test fine-tuning only the head of a pretrained backbone."""
    arch = "resnet18"
    model = TripletTrainerModel(pretrained=True, base_arch=arch)

    data_path = fetch_example_dataset(dataset=ExampleData.TINY10)
    datamodule = TripletTrainerDataModule(
        data_dir=data_path,
        batch_size=2,
        normalize_for_arch=arch,
    )

    # HeadFineTuner freezes the backbone so only the head trains
    trainer = pl.Trainer(max_epochs=5, callbacks=[HeadFineTuner()])
    trainer.fit(model=model, datamodule=datamodule)
def test_train_new_with_preloading():
    """Smoke-test training from scratch with the dataset preloaded into memory."""
    arch = "resnet18"
    model = TripletTrainerModel(pretrained=False, base_arch=arch)

    data_path = fetch_example_dataset(dataset=ExampleData.TINY10)
    datamodule = TripletTrainerDataModule(
        data_dir=data_path,
        batch_size=2,
        normalize_for_arch=arch,
        preload_data=True,
    )

    trainer = pl.Trainer(max_epochs=5)
    trainer.fit(model=model, datamodule=datamodule)
def test_train_new_anti_aliased():
    """Smoke-test training from scratch with an anti-aliased backbone."""
    arch = "resnet18"
    model = TripletTrainerModel(
        pretrained=False,
        base_arch=arch,
        anti_aliased_backbone=True,
    )

    data_path = fetch_example_dataset(dataset=ExampleData.TINY10)
    datamodule = TripletTrainerDataModule(
        data_dir=data_path,
        batch_size=2,
        normalize_for_arch=arch,
    )

    trainer = pl.Trainer(max_epochs=5, gpus=N_GPUS)
    trainer.fit(model=model, datamodule=datamodule)
def test_train_new_onecycle():
    """Smoke-test training from scratch with the one-cycle LR schedule."""
    arch = "resnet18"
    model = TripletTrainerModel(pretrained=False, base_arch=arch)

    data_path = fetch_example_dataset(dataset=ExampleData.TINY10)
    datamodule = TripletTrainerDataModule(
        data_dir=data_path,
        batch_size=2,
        normalize_for_arch=arch,
    )

    # monitor the per-step learning rate so the schedule is visible in logs
    lr_monitor = LearningRateMonitor(logging_interval="step")
    trainer = OneCycleTrainer(max_epochs=5, callbacks=[lr_monitor], gpus=N_GPUS)
    trainer.fit(model=model, datamodule=datamodule)
def main(model_path, image_path):
    """Load a model checkpoint, render its RGB-PCA plot for an image, and save it.

    The model name is derived from the checkpoint's grandparent directory,
    and the output filename combines that name with the scene id taken from
    the image filename.
    """
    model_path = Path(model_path)
    # checkpoint layout: <model_name>/<something>/<checkpoint file>
    model_name = model_path.parent.parent.name
    model = TripletTrainerModel.load_from_checkpoint(model_path)

    N_tile = (256, 256)
    make_plot(
        model=model,
        image_path=image_path,
        model_name=model_name,
        N_tile=N_tile,
    )

    # scene id is the image filename up to the first "."
    scene_id = Path(image_path).name.split(".")[0]
    output_img_filename = f"{model_name}.{scene_id}.PCA012_rgb.png"
    plt.savefig(output_img_filename)
    print(f"wrote image to `{output_img_filename}`")
def test_dendrogram_plot():
    """Smoke-test the default dendrogram plot on singlet (anchor) embeddings.

    Uses a model with default resnet weights to produce embedding vectors.
    """
    backbone_arch = "resnet18"
    model = TripletTrainerModel(pretrained=True, base_arch=backbone_arch)

    data_path = fetch_example_dataset(dataset=ExampleData.SMALL100)
    tile_dataset = ImageSingletDataset(
        data_dir=data_path,
        stage="train",
        tile_type=TileType.ANCHOR,
        transform=get_transforms(step="predict", normalize_for_arch=backbone_arch),
    )

    da_embeddings = get_embeddings(
        tile_dataset=tile_dataset, model=model, prediction_batch_size=16
    )
    interpretation_plot.dendrogram(da_embeddings=da_embeddings)
def main(model_path, data_path):
    """Compute (or reuse cached) embeddings for a triplet dataset and save them.

    Parameters
    ----------
    model_path : str
        Either a checkpoint path, or `ex://<name>` to use a pretrained
        example model from `PretrainedModel`.
    data_path : str
        Either a dataset directory, or `ex://<name>` to fetch an example
        dataset from `ExampleData`.

    Raises
    ------
    Exception
        If an `ex://` model or dataset name is not recognised.
    """
    if model_path.startswith("ex://"):
        model_name = model_path[5:]
        available_models = [m.name for m in list(PretrainedModel)]
        if model_name not in available_models:
            raise Exception(
                f"pretrained model `{model_name}` not found."
                f" available models: {', '.join(available_models)}"
            )
        model = load_pretrained_model(PretrainedModel[model_name])
    else:
        model = TripletTrainerModel.load_from_checkpoint(model_path)

    if data_path.startswith("ex://"):
        dset_name = data_path[5:]
        available_dsets = [m.name for m in list(ExampleData)]
        # bug fix: previously this checked `model_name not in available_models`
        # (a copy-paste error that could NameError when the model came from a
        # checkpoint) and needlessly re-loaded the model here.
        if dset_name not in available_dsets:
            raise Exception(
                f"example dataset `{dset_name}` not found."
                f" available dataset: {', '.join(available_dsets)}"
            )
        data_path = fetch_example_dataset(dataset=ExampleData[dset_name])

    transforms = get_transforms(step="predict", normalize_for_arch=model.base_arch)
    dset = ImageTripletDataset(data_dir=data_path, stage="train", transform=transforms)

    # cache embeddings to disk keyed on (model, data) so reruns are cheap
    embs_id = _make_hash(f"{model_path}__{data_path}")
    fpath_embs = Path(f"embs-{embs_id}.nc")

    if not fpath_embs.exists():
        da_embs = get_embeddings(
            tile_dataset=dset, model=model, prediction_batch_size=64
        )
        da_embs.to_netcdf(fpath_embs)
    else:
        da_embs = xr.open_dataarray(fpath_embs)
        print(f"using saved embeddings from `{fpath_embs}`")

    _save_embeddings(da_embs=da_embs, dset=dset)
def test_rectpred_sliding_window_inference():
    """Check sliding-window embedding of a rectangular image.

    Uses a model with default resnet weights, then verifies that the
    embedding dimension and the tile counts along each image axis match
    what the window size and step imply.
    """
    backbone_arch = "resnet18"
    model = TripletTrainerModel(pretrained=True, base_arch=backbone_arch)

    # TODO: make this a property of the model
    N_tile = (256, 256)

    img = Image.open(RECTPRED_IMG_EXAMPLE_PATH)
    step = (500, 200)
    da_emb_rect = make_sliding_tile_model_predictions(img=img, model=model, step=step)

    nx_img, ny_img = img.size
    # number of window positions that fit along each axis
    expected_nx = (nx_img - N_tile[0] + step[0]) // step[0]
    expected_ny = (ny_img - N_tile[1] + step[1]) // step[1]

    assert da_emb_rect.emb_dim.count() == model.n_embedding_dims
    assert da_emb_rect.i0.count() == expected_nx
    assert da_emb_rect.j0.count() == expected_ny