Code Example #1
# Assumed imports for this excerpt. `util` is the project's helper module;
# `Fig` and `plot_col` are project-local helpers whose import paths are not
# shown here.
import numpy as np
import seaborn as sns
import torch
import util
from matplotlib import pyplot as plt


def make_fig(args):
    # set up
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    isic_model_checkpoints = ['./trained_models/isic18/softm/', './trained_models/isic18/ensemble/',
                              './trained_models/isic18/mcdropout/', './trained_models/isic18/punet/']
    isic_models = [util.load_model_from_checkpoint(ckpt).to(device)
                   for ckpt in isic_model_checkpoints]
    lidc_model_checkpoints = ['./trained_models/lidc/softm/', './trained_models/lidc/ensemble/',
                              './trained_models/lidc/mcdropout/', './trained_models/lidc/punet/']
    lidc_models = [util.load_model_from_checkpoint(ckpt).to(device)
                   for ckpt in lidc_model_checkpoints]

    fig = Fig(
        rows=5,
        cols=args.images_each*2,
        title=None,
        figsize=(5, 3),
        background=True,
    )
    plt.tight_layout()
    colors = np.array(sns.color_palette("Paired")) * 255
    color_gt = colors[1]
    color_model = colors[7]
    np.random.seed(7)
    img_indices = np.random.choice(range(500), args.images_each*2)

    # plot isic dataset
    for i in range(args.images_each):
        plot_col(fig, i, img_indices[i], device, isic_models,
                 4, color_gt, color_model)

    # plot lidc dataset
    for i in range(args.images_each, args.images_each*2):
        plot_col(fig, i, img_indices[i], device, lidc_models,
                 2, color_gt, color_model)

    # adjust spacing
    plt.subplots_adjust(left=0, bottom=0, right=1,
                        top=1., wspace=0.06, hspace=0.06)

    for f in args.output_file:
        fig.save(f, close=False)
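
The argparse wiring for this function is not part of the excerpt. A hypothetical driver, inferred only from the `args` attributes the function reads (`images_each` and the list-valued `output_file`):

if __name__ == '__main__':
    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser.add_argument('--images_each', type=int, default=2,
                        help='Number of image columns per dataset.')
    parser.add_argument('--output_file', type=str, nargs='+',
                        default=['./plots/qualitative.png'],
                        help='One or more files to save the figure to.')
    make_fig(parser.parse_args())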
Code Example #2
# Assumed imports; `util`, `load_sample`, `set_up_figure`, `predict`, and
# `plot_predictions` are project-local helpers not shown in this excerpt.
import os

import torch
import util
from tqdm import tqdm


def make_fig(model_checkpoints, output_file):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    for row, model_checkpoint in enumerate(tqdm(model_checkpoints, desc='Plotting Model outputs...')):
        model = util.load_model_from_checkpoint(model_checkpoint).to(device)
        datamodule = util.load_datamodule_for_model(model, batch_size=1)
        sample = load_sample(datamodule, idx=0, device=device)
        if row == 0:
            fig = set_up_figure(len(model_checkpoints), sample)
        predictions = predict(model, sample[0])
        plot_predictions(fig, row + 1, predictions, model.model_shortname())

    os.makedirs("./plots/", exist_ok=True)
    fig.save(output_file)
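
The helpers `load_sample` and `predict` are not included in these excerpts. A plausible minimal sketch, assuming the test set yields (image, [annotations]) pairs and the models expose the `sample_prediction` method used in Code Example #6; the project's actual implementations may differ:

def load_sample(datamodule, idx, device):
    # fetch the idx-th test sample and move it to the target device
    x, ys = datamodule.test_dataloader().dataset[idx]
    return util.to_device((x.unsqueeze(0), [y.unsqueeze(0) for y in ys]), device)


def predict(model, x, sample_cnt=4):
    # draw a few stochastic segmentations from the model
    with torch.no_grad():
        return [model.sample_prediction(x) for _ in range(sample_cnt)]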
Code Example #3
# Assumed imports; helpers as in Code Example #2, plus `load_train_sample`.
import os

import torch
import util


def make_fig(model_checkpoint, output_file, indices):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # only the first of the given checkpoints is used
    model = util.load_model_from_checkpoint(model_checkpoint[0]).to(device)
    datamodule = util.load_datamodule_for_model(model, batch_size=1)

    for row, idx in enumerate(indices):
        sample = load_train_sample(datamodule, idx=idx, device=device)
        if row == 0:
            fig = set_up_figure(len(indices), sample)
        predictions = predict(model, sample[0])
        plot_predictions(fig, row + 1, predictions, str(idx))

    os.makedirs("./plots/", exist_ok=True)
    fig.save(output_file)
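
`load_train_sample` is likewise not shown; presumably it is the train-split analogue of the `load_sample` sketch above:

def load_train_sample(datamodule, idx, device):
    # like load_sample, but reads from the training split
    x, ys = datamodule.train_dataloader().dataset[idx]
    return util.to_device((x.unsqueeze(0), [y.unsqueeze(0) for y in ys]), device)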
Code Example #4
# Assumed imports; `util`, `load_sample`, `set_up_figure`, `predict`, and
# `plot_predictions` are project-local helpers not shown in this excerpt.
import os

import torch
import util
from matplotlib import pyplot as plt
from tqdm import tqdm


def make_fig(args):
    device = 'cpu'  # 'cuda' if torch.cuda.is_available() else 'cpu'
    models = [util.load_model_from_checkpoint(ckpt).to(device)
              for ckpt in args.model_checkpoints]
    datamodule = util.load_datamodule_for_model(models[0], batch_size=1)
    x, ys = load_sample(datamodule, idx=args.sample_idx, device=device)
    fig = set_up_figure(len(models), x, ys)  # one row per model (was hard-coded to 4)

    for row, model in enumerate(tqdm(models, desc='Plotting Model outputs...')):
        predictions = predict(model, x)
        plot_predictions(fig, row + 1, predictions)

    # adjust spacing
    plt.subplots_adjust(left=0, bottom=0, right=1,
                        top=0.9, wspace=0.0, hspace=0.06)

    os.makedirs("./plots/", exist_ok=True)
    for f in args.output_file:
        fig.save(f, close=False)
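
As with Code Example #1, the CLI setup is missing here. A hypothetical driver consistent with the `args` attributes used above (`model_checkpoints`, `sample_idx`, `output_file`):

if __name__ == '__main__':
    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser.add_argument('--model_checkpoints', type=str, nargs='+', required=True,
                        help='Checkpoints of the models to compare.')
    parser.add_argument('--sample_idx', type=int, default=0,
                        help='Index of the test sample to plot.')
    parser.add_argument('--output_file', type=str, nargs='+',
                        default=['./plots/model_comparison.png'],
                        help='One or more files to save the figure to.')
    make_fig(parser.parse_args())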
Code Example #5
# Assumed imports; `util`, `Fig`, and `load_sample` are project-local helpers
# not shown in this excerpt.
import os

import numpy as np
import pytorch_lightning as pl
import seaborn as sns
import torch
import util
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
from tqdm import tqdm


def make_fig(args):
    # set up
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = util.load_model_from_checkpoint(args.model_path).to(device)
    datamodule = util.load_datamodule_for_model(model, batch_size=1)

    for idx in tqdm(range(100), desc='Generating Images'):
        x, y = load_sample(datamodule, idx=idx, device=device)
        fig = Fig(
            rows=1,
            cols=2,
            title=None,
            figsize=None,
            background=True,
        )
        colors = np.array(sns.color_palette("Paired")) * 255
        color_gt = colors[1]
        color_model = colors[7]

        # draw samples
        pl.seed_everything(42)
        with torch.no_grad():
            # method name kept as called in the source; the misspelling
            # ('probabaility') presumably matches the model's own API
            p = model.pixel_wise_probabaility(x, sample_cnt=args.samples)
            _, y_pred = p.max(dim=1, keepdim=True)
            uncertainty = util.entropy(p)

        # plot image
        fig.plot_img(0, 0, x[0], vmin=0, vmax=1)

        # plot uncertainty heatmap
        fig.plot_overlay(0, 1, uncertainty[0], alpha=1, vmin=None, vmax=None,
                         cmap='Greys', colorbar=True,
                         colorbar_label="Model Uncertainty")

        # plot model prediction outline
        fig.plot_contour(0, 0, y_pred[0], contour_class=1, width=2, rgba=color_model)
        fig.plot_contour(0, 1, y_pred[0], contour_class=1, width=2, rgba=color_model)

        # plot gt seg outline
        fig.plot_contour(0, 0, y[0], contour_class=1, width=2, rgba=color_gt)
        fig.plot_contour(0, 1, y[0], contour_class=1, width=2, rgba=color_gt)

        # add legend (matplotlib imports moved to the top of the file)
        legend_data = [
            [0, color_gt, "GT Annotation"],
            [1, color_model, "Model Prediction"],
        ]
        handles = [
            Rectangle((0, 0), 1, 1, color=[v/255 for v in c]) for k, c, n in legend_data
        ]
        labels = [n for k, c, n in legend_data]

        plt.legend(handles, labels, ncol=len(legend_data))

        os.makedirs("./plots/", exist_ok=True)
        # fig.save(args.output_file)
        os.makedirs(args.output_folder, exist_ok=True)
        fig.save(os.path.join(args.output_folder, f'test_{idx}.png'), close=False)
        fig.save(os.path.join(args.output_folder, f'test_{idx}.pdf'))
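
`util.entropy` is not part of the excerpt. For a per-pixel class distribution p with classes along dim=1, the standard Shannon entropy is the natural definition; a sketch, not the project's verified implementation:

def entropy(p, eps=1e-12):
    # p: (B, C, H, W) class probabilities summing to 1 over dim 1;
    # returns the per-pixel Shannon entropy with shape (B, H, W)
    return -(p * (p + eps).log()).sum(dim=1)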
Code Example #6
# Assumed imports; `util` and the metric helpers (`generalized_energy_distance`,
# `heatmap_dice_loss`, `pearsonr`, `index_select`) are project-local and not
# shown in this excerpt.
import os
import pickle
from argparse import ArgumentParser
from collections import defaultdict

import numpy as np
import pytorch_lightning as pl
import torch
import util
from tqdm import tqdm


def cli_main():
    pl.seed_everything(1234)
    supported_models = util.get_supported_models()

    # ------------
    # args
    # ------------
    parser = ArgumentParser()
    parser.add_argument(
        '--model_path', type=str, help='Path to the trained model.')
    parser.add_argument(
        '--file', type=str, default='./plots/experiment_results.pickl',
        help='File to save the results in.')
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    # ------------
    # model
    # ------------
    model = util.load_model_from_checkpoint(args.model_path)
    checkpoint_path = util.get_checkpoint_path(args.model_path)
    datamodule = util.load_datamodule_for_model(model)
    dataset = model.hparams.dataset

    # ------------
    # file
    # ------------
    if os.path.isfile(args.file):
        print('Loading existing result file')
        with open(args.file, 'rb') as f:
            test_results = pickle.load(f)
    else:
        print('Creating new result file')
        test_results = {}

    if dataset not in test_results:
        test_results[dataset] = {}
    test_results[dataset][model.model_shortname()] = {}
    test_results[dataset][model.model_shortname()]['model_name'] = model.model_name()
    test_results[dataset][model.model_shortname()]['model_shortname'] = model.model_shortname()

    # ------------
    # Run model test script
    # ------------
    trainer = pl.Trainer.from_argparse_args(args)
    test_metrics = trainer.test(
        model=model, ckpt_path=checkpoint_path, datamodule=datamodule)
    for k, v in test_metrics[0].items():
        test_results[dataset][model.model_shortname()][k] = v

    # ------------
    # Record sample-specific metrics
    # ------------
    with torch.no_grad():
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model.to(device)
        model.eval()
        metrics = defaultdict(list)
        pixel_metrics = defaultdict(list)
        for i, (x, ys) in enumerate(tqdm(datamodule.test_dataloader(), desc='Collecting sample-individual metrics...')):
            x, ys = util.to_device((x, ys), device)
            assert ys[0].max() <= 1
            y_mean = torch.stack(ys).float().mean(dim=0)

            #
            # Image-wise metrics
            #
            for sample_count in [1, 4, 8, 16]:
                ged, sample_diversity = generalized_energy_distance(
                    model, x, ys, sample_count=sample_count)
                metrics[f"test/ged/{sample_count}"].append(ged)
                metrics[f"test/sample_diversity/{sample_count}"].append(
                    sample_diversity)

                dice = heatmap_dice_loss(
                    model, x, ys, sample_count=sample_count)
                metrics[f"test/diceloss/{sample_count}"].append(dice)

            sample_count = 16
            uncertainty = model.pixel_wise_uncertainty(
                x, sample_cnt=sample_count)
            # correlation between model uncertainty and the per-pixel error of
            # a sampled prediction w.r.t. a randomly drawn annotation, averaged
            # over 16 draws
            correl = torch.stack([
                pearsonr(uncertainty, torch.nn.functional.binary_cross_entropy(
                    model.sample_prediction(x).float(),
                    ys[torch.randint(len(ys), ())].float(),
                    reduction='none'))
                for _ in range(16)]).mean(dim=0)
            metrics["test/uncertainty_seg_error_correl"].append(correl)
            
            #
            # Pixel-wise metrics
            #
            # subsample a fixed number of pixels per image to limit memory use
            PIXELS_PER_IMAGE = 1000
            B = uncertainty.shape[0]
            indices = np.moveaxis(np.indices(uncertainty.shape[1:]), 0, -1)
            indices = indices.reshape((-1, 3))
            indices = indices[np.random.randint(len(indices), size=PIXELS_PER_IMAGE)]

            y_hat = model.sample_prediction(x)  # thresholded model prediction
            annotator_sum = torch.stack(ys).sum(dim=0)  # sum of annotator votes
            annotator_cnt = len(ys)
            model_uncertainty = uncertainty  # pixel-wise model uncertainty

            annot_uncertainty = util.binary_entropy(y_mean)  # annotator uncertainty

            # record conditional uncertainty only where the annotators reach
            # consensus
            for b in range(B):
                # iterate over images of the batch
                # sample pixels by indices
                idx = torch.tensor([[b]+list(i) for i in indices])
                y_hat_subsampled = index_select(y_hat, idx)
                annotator_sum_subsampled = index_select(annotator_sum, idx)
                model_uncertainty_subsampled = index_select(model_uncertainty, idx)
                annot_uncertainty_subsampled = index_select(annot_uncertainty, idx)
                pixel_metrics["test/tp_uncertainty"].append(
                    model_uncertainty_subsampled[(y_hat_subsampled == 1) & (annotator_sum_subsampled == annotator_cnt)].cpu().tolist())
                pixel_metrics["test/fp_uncertainty"].append(
                    model_uncertainty_subsampled[(y_hat_subsampled == 1) & (annotator_sum_subsampled == 0)].cpu().tolist())
                pixel_metrics["test/fn_uncertainty"].append(
                    model_uncertainty_subsampled[(y_hat_subsampled == 0) & (annotator_sum_subsampled == annotator_cnt)].cpu().tolist())
                pixel_metrics["test/tn_uncertainty"].append(
                    model_uncertainty_subsampled[(y_hat_subsampled == 0) & (annotator_sum_subsampled == 0)].cpu().tolist())
                pixel_metrics["test/model_uncertainty"].append(model_uncertainty_subsampled.cpu().tolist())
                pixel_metrics["test/annotator_uncertainty"].append(annot_uncertainty_subsampled.cpu().tolist())
                pixel_metrics["test/is_prediction_correct"].append((y_hat_subsampled == (annotator_sum_subsampled+2)//4).tolist())

    # collect per-sample metrics: tensors become numpy arrays, pixel metrics
    # stay nested python lists
    test_results[dataset][model.model_shortname()]['per_sample'] = {}
    for k in metrics:
        test_results[dataset][model.model_shortname()]['per_sample'][k] = torch.cat(
            metrics[k]).cpu().numpy()
    for k in pixel_metrics:
        test_results[dataset][model.model_shortname()]['per_sample'][k] = pixel_metrics[k]

    # ------------
    # save results
    # ------------
    with open(args.file, 'wb') as f:
        pickle.dump(test_results, f)
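
Several helpers referenced in Code Example #6 (`index_select`, `util.binary_entropy`, `pearsonr`) are also not included above. Minimal sketches consistent with how they are called, offered as assumptions rather than the project's actual implementations:

def index_select(t, idx):
    # gather the values of tensor t at the integer index rows in idx;
    # idx has shape (N, t.dim()), the result has shape (N,)
    return t[tuple(idx.long().T)]


def binary_entropy(p, eps=1e-12):
    # element-wise entropy of a Bernoulli(p) variable, e.g. the mean
    # annotator vote per pixel
    return -(p * (p + eps).log() + (1 - p) * (1 - p + eps).log())


def pearsonr(a, b):
    # Pearson correlation between two equally-shaped tensors, one value
    # per batch element
    a = a.flatten(start_dim=1).float()
    b = b.flatten(start_dim=1).float()
    a = a - a.mean(dim=1, keepdim=True)
    b = b - b.mean(dim=1, keepdim=True)
    return (a * b).sum(dim=1) / (a.norm(dim=1) * b.norm(dim=1) + 1e-12)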