Exemplo n.º 1
0
def _multi_view_predict_on(image_pair, image_pair_loader, model,
                           views, hparams, results, per_view_results,
                           out_dir, args):
    """
    Predict on a single image from several views and map every view's
    prediction back onto the image's voxel grid.

    For each view, planes are sampled from ``image_pair`` in real space
    (scanner RAS) coordinates and the model predicts on them. The
    predictions are then mapped onto the image's voxel grid using
    nearest-neighbour lookup. Unless ``args.no_eval`` is set, each view is
    also evaluated via ``_per_view_evaluation``.

    Side effect: ``image_pair_loader.images`` is overwritten with
    ``[image_pair]`` so that the sequencer only serves this image.
    ``hparams`` is NOT modified.

    Args:
        image_pair: Image to predict on. Its ``.id``, ``.image`` and
            ``.labels`` attributes are read.
        image_pair_loader: Loader used to build the view sequencer. Its
            ``images`` attribute is replaced in place.
        model: Model passed to ``predict_volume``.
        views: Sized iterable of views to predict from.
        hparams: Hyperparameters. ``hparams["fit"]`` and ``hparams["build"]``
            are merged (build values win) and passed to ``get_sequencer``.
            ``hparams["build"]["n_classes"]`` sets the class axis size.
        results: Forwarded to ``_per_view_evaluation``.
        per_view_results: Forwarded to ``_per_view_evaluation``.
        out_dir: Forwarded to ``_per_view_evaluation``.
        args: Namespace. ``args.no_eval`` disables evaluation; ``args`` is
            also forwarded to ``_per_view_evaluation``.

    Returns:
        np.ndarray of dtype float32 with shape
        ``(len(views), d0, d1, d2, n_classes)``. ``(d0, d1, d2)`` are the
        leading axes of ``image_pair.image`` with its last axis dropped.
        Presumably that last axis is a channel axis (TODO confirm).
    """
    from mpunet.utils.fusion import predict_volume, map_real_space_pred
    from mpunet.interpolation.sample_grid import get_voxel_grid_real_space

    # Restrict the loader to the given file only
    image_pair_loader.images = [image_pair]
    n_classes = hparams["build"]["n_classes"]

    # Build the sequencer kwargs in a new dict rather than calling
    # hparams["fit"].update(...), which would mutate the caller's config.
    seq_kwargs = {**hparams["fit"], **hparams["build"]}
    seq = image_pair_loader.get_sequencer(views=views, **seq_kwargs)

    # Real space coordinates of every voxel of the image
    voxel_grid_real_space = get_voxel_grid_real_space(image_pair)

    # One prediction volume per view. np.empty is safe because every
    # slot is written in the loop below.
    d = image_pair.image.shape[:-1]
    combined = np.empty(
        shape=(len(views), d[0], d[1], d[2], n_classes),
        dtype=np.float32
    )
    print("Predicting on brain hyper-volume of shape:", combined.shape)

    for n_view, view in enumerate(views):
        print("\n[*] (%i/%i) View: %s" % (n_view + 1, len(views), view))

        # Sample planes from the image at grid_real_space grid
        # in real space (scanner RAS) coordinates.
        X, y, grid, inv_basis = seq.get_view_from(image_pair.id, view,
                                                  n_planes="same+20")

        # Predict on the sampled volume using the model
        pred = predict_volume(model, X, axis=2, batch_size=seq.batch_size)

        # Map the real space coordinate predictions to the nearest
        # real space coordinates defined on the voxel grid
        mapped_pred = map_real_space_pred(pred, grid, inv_basis,
                                          voxel_grid_real_space,
                                          method="nearest")
        combined[n_view] = mapped_pred

        if not args.no_eval:
            _per_view_evaluation(image_id=image_pair.id,
                                 pred=pred,
                                 true=y,
                                 mapped_pred=mapped_pred,
                                 mapped_true=image_pair.labels,
                                 view=view,
                                 n_classes=n_classes,
                                 results=results,
                                 per_view_results=per_view_results,
                                 out_dir=out_dir,
                                 args=args)
    return combined
Exemplo n.º 2
0
def predict_and_map(model,
                    seq,
                    image,
                    view,
                    batch_size=None,
                    voxel_grid_real_space=None,
                    targets=None,
                    eval_prob=1.0,
                    n_planes='same+20'):
    """
    Predict on one image from a single view and map the prediction onto
    the image's voxel grid.

    Args:
        model: Model passed to ``predict_volume``.
        seq: Sequencer providing ``get_view_from(image_id, view, n_planes=...)``
            and a default ``batch_size``.
        image: Image to predict on. Its ``.id`` is passed to the sequencer.
        view: View along which planes are sampled.
        batch_size: Prediction batch size. Uses ``seq.batch_size`` when None.
        voxel_grid_real_space: Precomputed real space voxel grid of ``image``.
            Computed with ``get_voxel_grid_real_space(image)`` when None.
        targets: Optional ground truth compared against the mapped
            prediction's argmax, reshaped to ``(-1, 1)``. Presumably a
            flattened label column (TODO confirm with callers). When None,
            no evaluation is done.
        eval_prob: Probability in [0, 1] of computing evaluations when
            ``targets`` is given. 1.0 always evaluates; 0.0 never does.
        n_planes: Number-of-planes spec passed to ``seq.get_view_from``.

    Returns:
        The prediction mapped onto the voxel grid, as returned by
        ``map_real_space_pred``.
    """
    # Imported locally together, like predict_volume, so that
    # map_real_space_pred is always in scope for this function.
    from mpunet.utils.fusion import predict_volume, map_real_space_pred

    # Sample planes from the image at grid_real_space grid
    # in real space (scanner RAS) coordinates.
    X, y, grid, inv_basis = seq.get_view_from(image.id,
                                              view,
                                              n_planes=n_planes)

    # Predict on volume using model
    bs = seq.batch_size if batch_size is None else batch_size
    pred = predict_volume(model, X, axis=2, batch_size=bs)

    # Real space coordinates of the voxel grid that the predictions are
    # mapped onto (computed only if not supplied by the caller)
    if voxel_grid_real_space is None:
        from mpunet.interpolation.sample_grid import get_voxel_grid_real_space
        voxel_grid_real_space = get_voxel_grid_real_space(image)

    # Map the predicted volume to real space
    mapped = map_real_space_pred(pred, grid, inv_basis, voxel_grid_real_space)

    # Print dice scores. np.random.rand samples [0, 1), so a strict '<'
    # evaluates with probability exactly eval_prob (never when it is 0).
    if targets is not None and np.random.rand(1)[0] < eval_prob:
        print("Computing evaluations...")
        print("View dice scores:   ",
              dice_all(y, pred.argmax(-1), ignore_zero=False))
        print(
            "Mapped dice scores: ",
            dice_all(targets,
                     mapped.argmax(-1).reshape(-1, 1),
                     ignore_zero=False))
    else:
        print("-- Skipping evaluation")

    return mapped