Ejemplo n.º 1
0
def test_get_embeddings():
    """Check that `get_embeddings` returns one embedding vector per tile.

    The embeddings are first computed "by hand" by pushing batches from a
    `DataLoader` through the model, and the shape of that stacked result is
    then compared against the `tile_id` and `emb_dim` counts of the array
    produced by the `get_embeddings` utility function.
    """
    arch = "resnet18"
    model = TripletTrainerModel(pretrained=True, base_arch=arch)

    dataset_dir = fetch_example_dataset(dataset=ExampleData.TINY10)
    predict_transforms = get_transforms(step="predict", normalize_for_arch=arch)
    dataset = ImageSingletDataset(
        data_dir=dataset_dir,
        stage="train",
        tile_type=TileType.ANCHOR,
        transform=predict_transforms,
    )

    # reference: run the model directly on every batch and stack the outputs
    loader = DataLoader(dataset, batch_size=32)
    per_batch_outputs = [model.forward(batch) for batch in loader]
    per_batch_arrays = [out.cpu().detach().numpy() for out in per_batch_outputs]
    results = np.vstack(per_batch_arrays)

    # same computation, but through the utility function (note the different
    # batch size; the number of tiles and dimensions must not depend on it)
    da_embeddings = get_embeddings(
        tile_dataset=dataset, model=model, prediction_batch_size=16
    )

    n_tiles, n_dims = results.shape

    assert int(da_embeddings.tile_id.count()) == n_tiles
    assert int(da_embeddings.emb_dim.count()) == n_dims
Ejemplo n.º 2
0
def test_dendrogram_plot_triplets():
    """Smoke-test the dendrogram plot for every triplet sampling method.

    Embeddings are produced for a triplet dataset with a model using the
    default pretrained resnet weights, and `interpretation_plot.dendrogram`
    is then called once per sampling method, plotting the anchor tiles.
    """
    arch = "resnet18"
    model = TripletTrainerModel(pretrained=True, base_arch=arch)

    dataset_dir = fetch_example_dataset(dataset=ExampleData.SMALL100)
    predict_transforms = get_transforms(step="predict", normalize_for_arch=arch)
    tile_dataset = ImageTripletDataset(
        data_dir=dataset_dir,
        stage="train",
        transform=predict_transforms,
    )

    da_embeddings = get_embeddings(
        tile_dataset=tile_dataset, model=model, prediction_batch_size=16
    )

    sampling_methods = ("random", "center_dist", "best_triplets", "worst_triplets")
    for method in sampling_methods:
        interpretation_plot.dendrogram(
            da_embeddings=da_embeddings,
            sampling_method=method,
            tile_type="anchor",
        )
Ejemplo n.º 3
0
def test_finetune_pretrained():
    """Fit a pretrained resnet18 model for 5 epochs with `HeadFineTuner`.

    Passes if training on the TINY10 example dataset completes without
    raising.
    """
    trainer = pl.Trainer(
        max_epochs=5,
        callbacks=[HeadFineTuner()],
    )
    backbone_arch = "resnet18"
    model = TripletTrainerModel(pretrained=True, base_arch=backbone_arch)
    dataset_dir = fetch_example_dataset(dataset=ExampleData.TINY10)
    dm = TripletTrainerDataModule(
        data_dir=dataset_dir,
        batch_size=2,
        normalize_for_arch=backbone_arch,
    )
    trainer.fit(model=model, datamodule=dm)
Ejemplo n.º 4
0
def test_train_new_with_preloading():
    """Fit a non-pretrained resnet18 model with `preload_data=True`.

    Passes if 5 epochs of training on the TINY10 example dataset complete
    without raising.
    """
    trainer = pl.Trainer(max_epochs=5)
    backbone_arch = "resnet18"
    model = TripletTrainerModel(pretrained=False, base_arch=backbone_arch)
    dataset_dir = fetch_example_dataset(dataset=ExampleData.TINY10)
    dm = TripletTrainerDataModule(
        data_dir=dataset_dir,
        batch_size=2,
        normalize_for_arch=backbone_arch,
        preload_data=True,
    )
    trainer.fit(model=model, datamodule=dm)
Ejemplo n.º 5
0
def test_train_new_anti_aliased():
    """Fit a non-pretrained resnet18 model with `anti_aliased_backbone=True`.

    Passes if 5 epochs of training on the TINY10 example dataset complete
    without raising. The trainer is given `gpus=N_GPUS`.
    """
    trainer = pl.Trainer(max_epochs=5, gpus=N_GPUS)
    backbone_arch = "resnet18"
    model = TripletTrainerModel(
        pretrained=False,
        base_arch=backbone_arch,
        anti_aliased_backbone=True,
    )
    dataset_dir = fetch_example_dataset(dataset=ExampleData.TINY10)
    dm = TripletTrainerDataModule(
        data_dir=dataset_dir,
        batch_size=2,
        normalize_for_arch=backbone_arch,
    )
    trainer.fit(model=model, datamodule=dm)
Ejemplo n.º 6
0
def test_train_new_onecycle():
    """Fit a non-pretrained resnet18 model using `OneCycleTrainer`.

    A `LearningRateMonitor` logging at every step is attached as a callback.
    Passes if 5 epochs of training on the TINY10 example dataset complete
    without raising.
    """
    step_lr_logger = LearningRateMonitor(logging_interval="step")
    trainer = OneCycleTrainer(
        max_epochs=5,
        callbacks=[step_lr_logger],
        gpus=N_GPUS,
    )
    backbone_arch = "resnet18"
    model = TripletTrainerModel(pretrained=False, base_arch=backbone_arch)
    dataset_dir = fetch_example_dataset(dataset=ExampleData.TINY10)
    dm = TripletTrainerDataModule(
        data_dir=dataset_dir,
        batch_size=2,
        normalize_for_arch=backbone_arch,
    )
    trainer.fit(model=model, datamodule=dm)
Ejemplo n.º 7
0
def main(model_path, image_path):
    """Plot model output for an image and save the figure as a PNG.

    Parameters
    ----------
    model_path
        Path to a model checkpoint. The model name used in the plot and in
        the output filename is the name of the directory two levels above
        the checkpoint file.
    image_path
        Path to the image passed to `make_plot`. The part of its filename
        before the first "." is used as the scene id in the output filename.

    The figure is written to `<model_name>.<scene_id>.PCA012_rgb.png`
    (relative path, so it lands in the current working directory).
    """
    checkpoint_path = Path(model_path)
    model_name = checkpoint_path.parent.parent.name
    model = TripletTrainerModel.load_from_checkpoint(checkpoint_path)

    tile_size = (256, 256)
    make_plot(
        model=model,
        image_path=image_path,
        model_name=model_name,
        N_tile=tile_size,
    )

    # everything before the first dot, so multi-suffix names are fully stripped
    image_filename = Path(image_path).name
    scene_id = image_filename.split(".")[0]
    output_img_filename = f"{model_name}.{scene_id}.PCA012_rgb.png"
    plt.savefig(output_img_filename)
    print(f"wrote image to `{output_img_filename}`")
Ejemplo n.º 8
0
def test_dendrogram_plot():
    """Smoke-test the dendrogram plot with its default arguments.

    Embeddings for the anchor tiles of the SMALL100 example dataset are
    produced with a model using the default pretrained resnet weights and
    handed to `interpretation_plot.dendrogram`.
    """
    arch = "resnet18"
    model = TripletTrainerModel(pretrained=True, base_arch=arch)

    dataset_dir = fetch_example_dataset(dataset=ExampleData.SMALL100)
    predict_transforms = get_transforms(step="predict", normalize_for_arch=arch)
    tile_dataset = ImageSingletDataset(
        data_dir=dataset_dir,
        stage="train",
        tile_type=TileType.ANCHOR,
        transform=predict_transforms,
    )

    da_embeddings = get_embeddings(
        tile_dataset=tile_dataset, model=model, prediction_batch_size=16
    )
    interpretation_plot.dendrogram(da_embeddings=da_embeddings)
Ejemplo n.º 9
0
def main(model_path, data_path):
    """Compute (or load cached) embeddings for a dataset and save them.

    Parameters
    ----------
    model_path : str
        Either the path to a model checkpoint, or `ex://<name>` where
        `<name>` is the name of a `PretrainedModel` member.
    data_path : str
        Either the path to a dataset directory, or `ex://<name>` where
        `<name>` is the name of an `ExampleData` member (which is then
        fetched with `fetch_example_dataset`).

    Raises
    ------
    Exception
        If an `ex://` model or dataset name is not one of the available
        members.

    The embeddings are cached in `embs-<hash>.nc` in the current working
    directory, where the hash is derived from the model path and the
    (resolved) data path; an existing cache file is reused.
    """
    example_prefix = "ex://"

    if model_path.startswith(example_prefix):
        model_name = model_path[len(example_prefix):]
        available_models = [m.name for m in PretrainedModel]
        if model_name not in available_models:
            raise Exception(
                f"pretrained model `{model_name}` not found."
                f" available models: {', '.join(available_models)}"
            )
        model = load_pretrained_model(PretrainedModel[model_name])
    else:
        model = TripletTrainerModel.load_from_checkpoint(model_path)

    if data_path.startswith(example_prefix):
        dset_name = data_path[len(example_prefix):]
        available_dsets = [d.name for d in ExampleData]
        # validate the *dataset* name; the model was already loaded above and
        # must not be touched here (model_name may not even be defined when
        # the model came from a checkpoint path)
        if dset_name not in available_dsets:
            raise Exception(
                f"example dataset `{dset_name}` not found."
                f" available dataset: {', '.join(available_dsets)}"
            )
        data_path = fetch_example_dataset(dataset=ExampleData[dset_name])

    transforms = get_transforms(step="predict", normalize_for_arch=model.base_arch)
    dset = ImageTripletDataset(data_dir=data_path, stage="train", transform=transforms)

    # cache key combines the model and the resolved dataset location
    embs_id = _make_hash(f"{model_path}__{data_path}")
    fpath_embs = Path(f"embs-{embs_id}.nc")

    if not fpath_embs.exists():
        da_embs = get_embeddings(
            tile_dataset=dset, model=model, prediction_batch_size=64
        )
        da_embs.to_netcdf(fpath_embs)
    else:
        da_embs = xr.open_dataarray(fpath_embs)
        print(f"using saved embeddings from `{fpath_embs}`")

    _save_embeddings(da_embs=da_embs, dset=dset)
Ejemplo n.º 10
0
def test_rectpred_sliding_window_inference():
    """Check the output grid of sliding-window predictions on one image.

    A model with default pretrained resnet weights is applied to the example
    image with a fixed step, and the number of embedding dimensions and of
    tiles along each axis is compared against the expected values.
    """
    arch = "resnet18"
    model = TripletTrainerModel(pretrained=True, base_arch=arch)
    # TODO: make this a property of the model
    # NOTE(review): hard-coded here and not passed to the prediction call;
    # presumably matches the tile size that call uses - confirm
    tile_nx, tile_ny = (256, 256)

    img = Image.open(RECTPRED_IMG_EXAMPLE_PATH)
    step = (500, 200)
    step_x, step_y = step
    da_emb_rect = make_sliding_tile_model_predictions(
        img=img, model=model, step=step
    )

    img_nx, img_ny = img.size

    # number of whole tiles that fit along each direction for the given step
    expected_nx = (img_nx - tile_nx + step_x) // step_x
    expected_ny = (img_ny - tile_ny + step_y) // step_y

    assert da_emb_rect.emb_dim.count() == model.n_embedding_dims
    assert da_emb_rect.i0.count() == expected_nx
    assert da_emb_rect.j0.count() == expected_ny