def test_dendrogram_plot_triplets():
    """Plot dendrograms of triplet-dataset anchor embeddings, once per sampling method.

    A resnet18-backed model with default pretrained weights is used purely to
    produce some embedding vectors for the training split of the SMALL100
    example dataset; no training happens here.
    """
    arch_name = "resnet18"
    triplet_model = TripletTrainerModel(pretrained=True, base_arch=arch_name)

    example_dir = fetch_example_dataset(dataset=ExampleData.SMALL100)
    predict_transform = get_transforms(step="predict", normalize_for_arch=arch_name)
    triplets = ImageTripletDataset(
        data_dir=example_dir,
        stage="train",
        transform=predict_transform,
    )
    embeddings = get_embeddings(
        tile_dataset=triplets, model=triplet_model, prediction_batch_size=16
    )

    sampling_methods = ("random", "center_dist", "best_triplets", "worst_triplets")
    for method in sampling_methods:
        interpretation_plot.dendrogram(
            da_embeddings=embeddings,
            sampling_method=method,
            tile_type="anchor",
        )
def test_get_embeddings():
    """Check `get_embeddings` returns one embedding per tile with the model's dimensionality.

    A reference result is computed by running the model directly over a
    DataLoader of the TINY10 anchor tiles. Only the shape (number of tiles,
    number of embedding dims) is compared with the DataArray returned by the
    utility function.
    """
    import torch

    backbone_arch = "resnet18"
    model = TripletTrainerModel(pretrained=True, base_arch=backbone_arch)
    data_path = fetch_example_dataset(dataset=ExampleData.TINY10)
    dataset = ImageSingletDataset(
        data_dir=data_path,
        stage="train",
        tile_type=TileType.ANCHOR,
        transform=get_transforms(step="predict", normalize_for_arch=backbone_arch),
    )

    # direct prediction; gradients are never used, so skip building the
    # autograd graph (and holding on to its activations) for every batch
    dl_predict = DataLoader(dataset, batch_size=32)
    with torch.no_grad():
        batched_results = [model.forward(x_batch) for x_batch in dl_predict]
    results = np.vstack([v.cpu().detach().numpy() for v in batched_results])

    # via utility function
    da_embeddings = get_embeddings(
        tile_dataset=dataset, model=model, prediction_batch_size=16
    )

    n_tiles, n_dims = results.shape
    assert int(da_embeddings.tile_id.count()) == n_tiles
    assert int(da_embeddings.emb_dim.count()) == n_dims
def test_finetune_pretrained():
    """Fine-tune a pretrained resnet18 triplet model on TINY10 using the `HeadFineTuner` callback.

    Uses `gpus=N_GPUS` like the other training tests in this module, so the
    run honours the same GPU configuration.
    """
    trainer = pl.Trainer(max_epochs=5, callbacks=[HeadFineTuner()], gpus=N_GPUS)
    arch = "resnet18"
    model = TripletTrainerModel(pretrained=True, base_arch=arch)
    data_path = fetch_example_dataset(dataset=ExampleData.TINY10)
    datamodule = TripletTrainerDataModule(
        data_dir=data_path, batch_size=2, normalize_for_arch=arch
    )
    trainer.fit(model=model, datamodule=datamodule)
def test_train_new_with_preloading():
    """Train a resnet18 triplet model from scratch on TINY10 with `preload_data=True`.

    Uses `gpus=N_GPUS` like the other training tests in this module, so the
    run honours the same GPU configuration.
    """
    trainer = pl.Trainer(max_epochs=5, gpus=N_GPUS)
    arch = "resnet18"
    model = TripletTrainerModel(pretrained=False, base_arch=arch)
    data_path = fetch_example_dataset(dataset=ExampleData.TINY10)
    datamodule = TripletTrainerDataModule(
        data_dir=data_path, batch_size=2, normalize_for_arch=arch, preload_data=True
    )
    trainer.fit(model=model, datamodule=datamodule)
def test_train_new_anti_aliased():
    """Train a resnet18 triplet model with an anti-aliased backbone from scratch on TINY10."""
    trainer = pl.Trainer(max_epochs=5, gpus=N_GPUS)

    backbone = "resnet18"
    anti_aliased_model = TripletTrainerModel(
        pretrained=False,
        base_arch=backbone,
        anti_aliased_backbone=True,
    )

    tiny_dir = fetch_example_dataset(dataset=ExampleData.TINY10)
    tiny_datamodule = TripletTrainerDataModule(
        data_dir=tiny_dir,
        batch_size=2,
        normalize_for_arch=backbone,
    )

    trainer.fit(model=anti_aliased_model, datamodule=tiny_datamodule)
def test_train_new_onecycle():
    """Train a resnet18 triplet model from scratch on TINY10 with the `OneCycleTrainer`.

    A `LearningRateMonitor` logging every step is attached, so the learning
    rate can be inspected in the logs.
    """
    step_lr_logger = LearningRateMonitor(logging_interval="step")
    onecycle_trainer = OneCycleTrainer(
        max_epochs=5,
        callbacks=[step_lr_logger],
        gpus=N_GPUS,
    )

    backbone = "resnet18"
    fresh_model = TripletTrainerModel(pretrained=False, base_arch=backbone)

    tiny_dir = fetch_example_dataset(dataset=ExampleData.TINY10)
    tiny_datamodule = TripletTrainerDataModule(
        data_dir=tiny_dir,
        batch_size=2,
        normalize_for_arch=backbone,
    )

    onecycle_trainer.fit(model=fresh_model, datamodule=tiny_datamodule)
def test_dendrogram_plot():
    """Plot a dendrogram of anchor-tile embeddings from the SMALL100 example dataset.

    A resnet18-backed model with default pretrained weights is used purely to
    produce some embedding vectors to plot; no training happens here.
    """
    arch_name = "resnet18"
    embedding_model = TripletTrainerModel(pretrained=True, base_arch=arch_name)

    example_dir = fetch_example_dataset(dataset=ExampleData.SMALL100)
    predict_transform = get_transforms(step="predict", normalize_for_arch=arch_name)
    anchor_tiles = ImageSingletDataset(
        data_dir=example_dir,
        stage="train",
        tile_type=TileType.ANCHOR,
        transform=predict_transform,
    )

    embeddings = get_embeddings(
        tile_dataset=anchor_tiles, model=embedding_model, prediction_batch_size=16
    )
    interpretation_plot.dendrogram(da_embeddings=embeddings)
def test_rectpred_sliding_window_inference():
    """Check the output shape of sliding-window embedding predictions over an example image.

    Runs `make_sliding_tile_model_predictions` with a pretrained resnet18 model
    and asserts that:

    - the number of embedding dims equals `model.n_embedding_dims`;
    - the number of tile positions along `i0` and `j0` matches the window
      count implied by the image size, tile size and step.
    """
    # a model with default (pretrained) resnet weights is enough to produce
    # embedding vectors; no training is done here
    backbone_arch = "resnet18"
    model = TripletTrainerModel(pretrained=True, base_arch=backbone_arch)

    # TODO: make this a property of the model
    N_tile = (256, 256)
    step = (500, 200)

    # use a context manager so the image file handle is released afterwards
    with Image.open(RECTPRED_IMG_EXAMPLE_PATH) as img:
        da_emb_rect = make_sliding_tile_model_predictions(
            img=img, model=model, step=step
        )
        nx_img, ny_img = img.size

    # number of tiles expected in each direction: windows of size N_tile placed
    # every `step` pixels that fit entirely within the image
    nxt = (nx_img - N_tile[0] + step[0]) // step[0]
    nyt = (ny_img - N_tile[1] + step[1]) // step[1]

    # `.count()` returns a 0-d DataArray; cast to int for a plain comparison
    assert int(da_emb_rect.emb_dim.count()) == model.n_embedding_dims
    assert int(da_emb_rect.i0.count()) == nxt
    assert int(da_emb_rect.j0.count()) == nyt