Ejemplo n.º 1
0
def _multi_view_predict_on(image_pair, image_pair_loader, model, views,
                           hparams, results, per_view_results, out_dir, args):
    """
    Predict on a single image along every view and map each view's
    prediction back onto the image's voxel grid.

    Args:
        image_pair:        The ImagePair to predict on.
        image_pair_loader: ImagePairLoader used to build the sequencer; its
                           ``images`` attribute is overwritten with
                           ``[image_pair]``.
        model:             Model passed to ``predict_volume``.
        views:             Sequence of views to sample the image along.
        hparams:           Hyperparameters with "fit" and "build" sections.
                           Not modified by this function.
        results, per_view_results, out_dir, args:
                           Forwarded to ``_per_view_evaluation`` for every
                           view unless ``args.no_eval`` is truthy.

    Returns:
        float32 array of shape (len(views), d0, d1, d2, n_classes), where
        (d0, d1, d2) are the leading dims of ``image_pair.image`` with its
        last axis dropped (presumably the channel axis -- confirm).
    """
    from MultiPlanarUNet.utils.fusion import predict_volume, map_real_space_pred
    from MultiPlanarUNet.interpolation.sample_grid import get_voxel_grid_real_space

    # Set image_pair_loader object with only the given file
    image_pair_loader.images = [image_pair]
    n_classes = hparams["build"]["n_classes"]

    # Merge the 'fit' and 'build' sections into a fresh dict so the caller's
    # hparams["fit"] section is not mutated.
    kwargs = dict(hparams["fit"])
    kwargs.update(hparams["build"])
    seq = image_pair_loader.get_sequencer(views=views, **kwargs)

    # Get voxel grid in real space
    voxel_grid_real_space = get_voxel_grid_real_space(image_pair)

    # Prepare tensor to store combined prediction
    d = image_pair.image.shape[:-1]
    combined = np.empty(shape=(len(views), d[0], d[1], d[2], n_classes),
                        dtype=np.float32)
    print("Predicting on brain hyper-volume of shape:", combined.shape)

    # Predict for each view
    for n_view, view in enumerate(views):
        print("\n[*] (%i/%i) View: %s" % (n_view + 1, len(views), view))
        # for each view, predict on all voxels and map the predictions
        # back into the original coordinate system

        # Sample planes from the image at grid_real_space grid
        # in real space (scanner RAS) coordinates.
        X, y, grid, inv_basis = seq.get_view_from(image_pair.id,
                                                  view,
                                                  n_planes="same+20")

        # Predict on volume using model
        pred = predict_volume(model, X, axis=2, batch_size=seq.batch_size)

        # Map the real space coordinate predictions to nearest
        # real space coordinates defined on voxel grid
        mapped_pred = map_real_space_pred(pred,
                                          grid,
                                          inv_basis,
                                          voxel_grid_real_space,
                                          method="nearest")
        combined[n_view] = mapped_pred

        if not args.no_eval:
            _per_view_evaluation(image_id=image_pair.id,
                                 pred=pred,
                                 true=y,
                                 mapped_pred=mapped_pred,
                                 mapped_true=image_pair.labels,
                                 view=view,
                                 n_classes=n_classes,
                                 results=results,
                                 per_view_results=per_view_results,
                                 out_dir=out_dir,
                                 args=args)
    return combined
Ejemplo n.º 2
0
def predict_and_map(model, seq, image, view, batch_size=None,
                    voxel_grid_real_space=None, targets=None, eval_prob=1.0,
                    n_planes='same+20', torch=False):
    """
    Predict on an image sampled along a single view and map the prediction
    back onto the image's voxel grid.

    Args:
        model:                 Model passed to ``predict_volume``.
        seq:                   Sequencer providing ``get_view_from`` and a
                               default ``batch_size``.
        image:                 ImagePair to predict on; ``image.id`` is passed
                               to ``seq.get_view_from``.
        view:                  The view to sample planes along.
        batch_size:            Prediction batch size; ``seq.batch_size`` is
                               used when None.
        voxel_grid_real_space: Pre-computed real-space voxel grid of
                               ``image``; computed with
                               ``get_voxel_grid_real_space`` when None.
        targets:               Optional labels. When given, they are compared
                               against ``mapped.argmax(-1).reshape(-1, 1)``
                               and dice scores are printed.
        eval_prob:             Probability of running the printed evaluation
                               when ``targets`` is given (1.0 = always).
        n_planes:              Number-of-planes spec forwarded to
                               ``seq.get_view_from``.
        torch:                 If True, use the PyTorch ``predict_volume``
                               implementation instead of the default one.

    Returns:
        The mapped prediction returned by ``map_real_space_pred``.
    """
    # Import locally, like predict_volume, so this function does not rely on
    # a module-level import of map_real_space_pred.
    from MultiPlanarUNet.utils.fusion import map_real_space_pred
    if torch:
        # Predict using a PyTorch model
        from MultiPlanarUNet.torch.utils.fusion import predict_volume
    else:
        # Normal prediction
        from MultiPlanarUNet.utils.fusion import predict_volume

    # Sample planes from the image at grid_real_space grid
    # in real space (scanner RAS) coordinates.
    X, y, grid, inv_basis = seq.get_view_from(image.id, view, n_planes=n_planes)

    # Predict on volume using model
    bs = seq.batch_size if batch_size is None else batch_size
    pred = predict_volume(model, X, axis=2, batch_size=bs)

    # Map the real space coordinate predictions to nearest
    # real space coordinates defined on voxel grid
    if voxel_grid_real_space is None:
        from MultiPlanarUNet.interpolation.sample_grid import get_voxel_grid_real_space
        voxel_grid_real_space = get_voxel_grid_real_space(image)

    # Map the predicted volume to real space
    mapped = map_real_space_pred(pred, grid, inv_basis, voxel_grid_real_space)

    # Print dice scores
    if targets is not None and np.random.rand(1)[0] <= eval_prob:
        print("Computing evaluations...")
        print("View dice scores:   ", dice_all(y, pred.argmax(-1),
                                               ignore_zero=False))
        print("Mapped dice scores: ", dice_all(targets,
                                               mapped.argmax(-1).reshape(-1, 1),
                                               ignore_zero=False))
    else:
        print("-- Skipping evaluation")

    return mapped
Ejemplo n.º 3
0
def _run_fusion_training(sets, logger, hparams, min_val_images, is_validation,
                         views, n_classes, unet, fusion_model_org,
                         fusion_model, early_stopping, fm_batch_size, epochs,
                         eval_prob, fusion_weights_path):

    for _round, _set in enumerate(sets):
        s = "Set %i/%i:\n%s" % (_round + 1, len(sets), _set)
        logger("\n%s" % highlighted(s))

        # Reload data
        images = ImagePairLoader(**hparams["val_data"])
        if len(images) < min_val_images:
            images.add_images(ImagePairLoader(**hparams["train_data"]))

        # Get list of ImagePair objects to run on
        image_set_dict = {m.id: m for m in images if m.id in _set}

        # Fetch points from the set images
        points_collection = []
        targets_collection = []
        N_im = len(image_set_dict)
        for num_im, image_id in enumerate(list(image_set_dict.keys())):
            logger("")
            logger(
                highlighted("(%i/%i) Running on %s (%s)" %
                            (num_im + 1, N_im, image_id,
                             "val" if is_validation[image_id] else "train")))

            # Set the current ImagePair
            image = image_set_dict[image_id]
            images.images = [image]

            # Load views
            kwargs = hparams["fit"]
            kwargs.update(hparams["build"])
            seq = images.get_sequencer(views=views, **kwargs)

            # Get voxel grid in real space
            voxel_grid_real_space = get_voxel_grid_real_space(image)

            # Get array to store predictions across all views
            targets = image.labels.reshape(-1, 1)
            points = np.empty(shape=(len(targets), len(views), n_classes),
                              dtype=np.float32)
            points.fill(np.nan)

            # Predict on all views
            for k, v in enumerate(views):
                print("\n%s" % "View: %s" % v)
                points[:, k, :] = predict_and_map(
                    model=unet,
                    seq=seq,
                    image=image,
                    view=v,
                    voxel_grid_real_space=voxel_grid_real_space,
                    n_planes='same+20',
                    targets=targets,
                    eval_prob=eval_prob).reshape(-1, n_classes)

            # Clean up a bit
            del image_set_dict[image_id]
            del image  # Should be GC at this point anyway

            # add to collections
            points_collection.append(points)
            targets_collection.append(targets)

        # Stack points into one matrix
        logger("Stacking points...")
        X, y = stack_collections(points_collection, targets_collection)

        # Shuffle train
        print("Shuffling points...")
        X, y = shuffle(X, y)

        print("Getting validation set...")
        val_ind = int(0.20 * X.shape[0])
        X_val, y_val = X[:val_ind], y[:val_ind]
        X, y = X[val_ind:], y[val_ind:]

        # Prepare dice score callback for validation data
        val_cb = ValDiceScores((X_val, y_val), n_classes, 50000, logger)

        # Callbacks
        cbs = [
            val_cb,
            CSVLogger(filename="logs/fusion_training.csv",
                      separator=",",
                      append=True),
            PrintLayerWeights(fusion_model_org.layers[-1],
                              every=1,
                              first=1000,
                              per_epoch=True,
                              logger=logger)
        ]

        es = EarlyStopping(monitor='val_dice',
                           min_delta=0.0,
                           patience=early_stopping,
                           verbose=1,
                           mode='max')
        cbs.append(es)

        # Start training
        try:
            fusion_model.fit(X,
                             y,
                             batch_size=fm_batch_size,
                             epochs=epochs,
                             callbacks=cbs,
                             verbose=1)
        except KeyboardInterrupt:
            pass
        fusion_model_org.save_weights(fusion_weights_path)
Ejemplo n.º 4
0
def predict_single(image, model, hparams, verbose=1):
    """
    A generic prediction function that sets up an ImagePairLoader object for
    the given image, prepares the image and predicts.

    Note that this function should only be used for convenience in scripts
    that work on single images at a time anyway, as batch-preparing the entire
    ImagePairLoader object prior to prediction is faster.

    NOTE: Only works with iso_live intrp modes at this time

    Args:
        image:   ImagePair to predict on.
        model:   Model used for prediction.
        hparams: Hyperparameters with "fit" and "build" sections. For the
                 "iso_live" mode it must also expose ``project_path``
                 containing a "views.npz" file. Not modified.
        verbose: Verbosity passed to the sequencer (as kwargs["verbose"]).

    Returns:
        For "iso_live": float32 array (n_views, d0, d1, d2, n_classes) of
        per-view predictions mapped onto the image voxel grid. For
        "iso_live_3d": whatever ``pred_3D_iso`` returns.

    Raises:
        AssertionError: If hparams["fit"]["intrp_style"] is not one of
                        "iso_live" / "iso_live_3d" (case-insensitive).
    """
    mode = hparams["fit"]["intrp_style"].lower()
    assert mode in ("iso_live", "iso_live_3d")

    # Prepare sequencer kwargs in a fresh dict so the caller's hparams["fit"]
    # section is left untouched; no save/restore of 'verbose' is needed.
    kwargs = dict(hparams["fit"])
    kwargs.update(hparams["build"])
    kwargs["verbose"] = verbose

    # Create an ImagePairLoader with only the given file
    from MultiPlanarUNet.image import ImagePairLoader
    # NOTE(review): no_log=bool(verbose) sets no_log=True when verbose is
    # truthy, which looks inverted -- confirm against ImagePairLoader.
    image_pair_loader = ImagePairLoader(predict_mode=True,
                                        single_file_mode=True,
                                        no_log=bool(verbose))
    image_pair_loader.add_image(image)

    # Get N classes
    n_classes = kwargs["n_classes"]

    if mode == "iso_live":
        # Add views if SMMV model
        views = np.load(hparams.project_path + "/views.npz")["arr_0"]
        kwargs["views"] = views

        # Get sequence object
        sequence = image_pair_loader.get_sequencer(**kwargs)

        # Get voxel grid in real space
        voxel_grid_real_space = get_voxel_grid_real_space(image)

        # Prepare tensor to store combined prediction
        d = image.image.shape
        predicted = np.empty(shape=(len(views), d[0], d[1], d[2], n_classes),
                             dtype=np.float32)
        print("Predicting on brain hyper-volume of shape:", predicted.shape)

        for n_view, v in enumerate(views):
            print("\nView %i/%i: %s" % (n_view+1, len(views), v))
            # Sample the volume along the view
            X, y, grid, inv_basis = sequence.get_view_from(image.id, v,
                                                           n_planes="same+20")

            # Predict on volume using model
            pred = predict_volume(model, X, axis=2)

            # Map the real space coordinate predictions to nearest
            # real space coordinates defined on voxel grid
            predicted[n_view] = map_real_space_pred(pred, grid, inv_basis,
                                                    voxel_grid_real_space,
                                                    method="nearest")
    else:
        predicted = pred_3D_iso(model=model, sequence=image_pair_loader.get_sequencer(**kwargs),
                                image=image, extra_boxes="3x", min_coverage=None)

    return predicted
Ejemplo n.º 5
0
def entry_func(args=None):
    """
    Command-line entry point.

    Predicts on a set of images with a trained multi-view model, fuses the
    per-view predictions, optionally evaluates them (dice), and saves .nii
    files.

    Fusion mode is selected by the command-line flags:
      * default:      a FusionModel combines the stored per-view predictions,
                      with weights loaded from
                      <project_dir>/model/fusion_weights/<model>_fusion_weights.h5
      * --analytical: the stored per-view predictions are summed
      * --majority:   per-view predictions are accumulated into one running
                      sum while predicting (no per-view storage)
    Unless ``no_argmax`` is set, the fused result is converted to class labels
    with ``pred_to_class`` before saving.

    Args:
        args: Optional list of argument strings given to the parser returned
              by ``get_argparser`` (presumably argparse, which reads sys.argv
              when None -- confirm).

    Raises:
        ValueError: If both --analytical and --majority are specified.
    """
    # Get command line arguments
    args = vars(get_argparser().parse_args(args))
    base_dir = os.path.abspath(args["project_dir"])
    analytical = args["analytical"]
    majority = args["majority"]
    _file = args["f"]
    label = args["l"]
    await_PID = args["wait_for"]
    eval_prob = args["eval_prob"]
    _continue = args["continue"]
    if analytical and majority:
        raise ValueError("Cannot specify both --analytical and --majority.")

    # Get settings from YAML file
    from MultiPlanarUNet.train.hparams import YAMLHParams
    hparams = YAMLHParams(os.path.join(base_dir, "train_hparams.yaml"))

    if not _file:
        try:
            # Data specified from command line?
            data_dir = os.path.abspath(args["data_dir"])

            # Set with default sub dirs
            hparams["test_data"] = {
                "base_dir": data_dir,
                "img_subdir": "images",
                "label_subdir": "labels"
            }
        except (AttributeError, TypeError):
            # No usable data_dir argument; fall back to the YAML setting
            data_dir = hparams["test_data"]["base_dir"]
    else:
        data_dir = False
    out_dir = os.path.abspath(args["out_dir"])
    overwrite = args["overwrite"]
    predict_mode = args["no_eval"]
    save_input_files = args["save_input_files"]
    no_argmax = args["no_argmax"]
    on_val = args["on_val"]

    # Check if valid dir structures
    validate_folders(base_dir, out_dir, overwrite, _continue)

    # Import all needed modules (folder is valid at this point)
    import numpy as np
    from MultiPlanarUNet.image import ImagePairLoader, ImagePair
    from MultiPlanarUNet.models import FusionModel
    from MultiPlanarUNet.models.model_init import init_model
    from MultiPlanarUNet.utils import await_and_set_free_gpu, get_best_model, \
                                    create_folders, pred_to_class, set_gpu
    from MultiPlanarUNet.logging import init_result_dicts, save_all, load_result_dicts
    from MultiPlanarUNet.evaluate import dice_all
    from MultiPlanarUNet.utils.fusion import predict_volume, map_real_space_pred
    from MultiPlanarUNet.interpolation.sample_grid import get_voxel_grid_real_space

    # Wait for PID?
    if await_PID:
        from MultiPlanarUNet.utils import await_PIDs
        await_PIDs(await_PID)

    # Set GPU device
    # Fetch GPU(s)
    num_GPUs = args["num_GPUs"]
    force_gpu = args["force_GPU"]
    # Wait for free GPU
    if not force_gpu:
        await_and_set_free_gpu(N=num_GPUs, sleep_seconds=120)
        num_GPUs = 1
    else:
        set_gpu(force_gpu)
        num_GPUs = len(force_gpu.split(","))

    # Read settings from the project hyperparameter file
    n_classes = hparams["build"]["n_classes"]

    # Get views
    views = np.load("%s/views.npz" % base_dir)["arr_0"]

    # Force settings
    hparams["fit"]["max_background"] = 1
    hparams["fit"]["test_mode"] = True
    hparams["fit"]["mix_planes"] = False
    hparams["fit"]["live_intrp"] = False
    # Views come from views.npz; drop these keys if present (no KeyError)
    hparams["fit"].pop("use_bounds", None)
    hparams["fit"].pop("views", None)

    if hparams["build"]["out_activation"] == "linear":
        # Trained with logit targets?
        hparams["build"][
            "out_activation"] = "softmax" if n_classes > 1 else "sigmoid"

    # Set ImagePairLoader object
    if not _file:
        data = "test_data" if not on_val else "val_data"
        image_pair_loader = ImagePairLoader(predict_mode=predict_mode,
                                            **hparams[data])
    else:
        predict_mode = not bool(label)
        image_pair_loader = ImagePairLoader(predict_mode=predict_mode,
                                            single_file_mode=True)
        image_pair_loader.add_image(ImagePair(_file, label))

    # Put them into a dict and remove from image_pair_loader to gain more control with
    # garbage collection
    all_images = {image.id: image for image in image_pair_loader.images}
    image_pair_loader.images = None
    if _continue:
        all_images = remove_already_predicted(all_images, out_dir)

    # Evaluate?
    if not predict_mode:
        if _continue:
            csv_dir = os.path.join(out_dir, "csv")
            results, detailed_res = load_result_dicts(csv_dir=csv_dir,
                                                      views=views)
        else:
            # Prepare dictionary to store results in pd df
            results, detailed_res = init_result_dicts(views, all_images,
                                                      n_classes)

        # Save to check correct format
        save_all(results, detailed_res, out_dir)

    # Define result paths
    nii_res_dir = os.path.join(out_dir, "nii_files")
    create_folders(nii_res_dir)
    """ Define UNet model """
    model_path = get_best_model(base_dir + "/model")
    unet = init_model(hparams["build"])
    unet.load_weights(model_path, by_name=True)

    if num_GPUs > 1:
        from tensorflow.keras.utils import multi_gpu_model
        n_classes = unet.n_classes
        unet = multi_gpu_model(unet, gpus=num_GPUs)
        unet.n_classes = n_classes

    weights_name = os.path.splitext(os.path.split(model_path)[1])[0]
    if not analytical and not majority:
        # Get Fusion model
        fm = FusionModel(n_inputs=len(views), n_classes=n_classes)

        weights = base_dir + "/model/fusion_weights/%s_fusion_weights.h5" % weights_name
        print("\n[*] Loading weights:\n", weights)

        # Load fusion weights
        fm.load_weights(weights)
        print("\nLoaded weights:\n\n%s\n%s\n---" %
              tuple(fm.layers[-1].get_weights()))

        # Multi-gpu?
        if num_GPUs > 1:
            print("Using multi-GPU model (%i GPUs)" % num_GPUs)
            fm = multi_gpu_model(fm, gpus=num_GPUs)

    # Sequencer kwargs are identical for every image; merge them once into a
    # fresh dict instead of updating hparams["fit"] in place per image.
    seq_kwargs = dict(hparams["fit"])
    seq_kwargs.update(hparams["build"])
    """
    Finally predict on the images
    """
    image_ids = sorted(all_images)
    N_images = len(image_ids)
    for n_image, image_id in enumerate(image_ids):
        print("\n[*] (%i/%s) Running on: %s" %
              (n_image + 1, N_images, image_id))

        # Set image_pair_loader object with only the given file
        image = all_images[image_id]
        image_pair_loader.images = [image]

        # Load views
        seq = image_pair_loader.get_sequencer(views=views, **seq_kwargs)

        # Get voxel grid in real space
        voxel_grid_real_space = get_voxel_grid_real_space(image)

        # Prepare tensor to store combined prediction
        d = image.image.shape[:-1]
        if not majority:
            combined = np.empty(shape=(len(views), d[0], d[1], d[2],
                                       n_classes),
                                dtype=np.float32)
        else:
            # Running sum over views: must start at zero (np.empty would
            # leave uninitialised memory in the accumulator)
            combined = np.zeros(shape=(d[0], d[1], d[2], n_classes),
                                dtype=np.float32)
        print("Predicting on brain hyper-volume of shape:", combined.shape)

        # Predict for each view
        for n_view, v in enumerate(views):
            print("\n[*] (%i/%i) View: %s" % (n_view + 1, len(views), v))
            # for each view, predict on all voxels and map the predictions
            # back into the original coordinate system

            # Sample planes from the image at grid_real_space grid
            # in real space (scanner RAS) coordinates.
            X, y, grid, inv_basis = seq.get_view_from(image.id,
                                                      v,
                                                      n_planes="same+20")

            # Predict on volume using model
            pred = predict_volume(unet, X, axis=2, batch_size=seq.batch_size)

            # Map the real space coordinate predictions to nearest
            # real space coordinates defined on voxel grid
            mapped_pred = map_real_space_pred(pred,
                                              grid,
                                              inv_basis,
                                              voxel_grid_real_space,
                                              method="nearest")

            view_pred = mapped_pred
            if n_classes == 1:
                # Set to background if outside pred domain. Done on the
                # per-view prediction so it is also correct in majority mode,
                # where 'combined' has no view axis.
                view_pred = np.where(np.isnan(mapped_pred), 0., mapped_pred)
            if not majority:
                combined[n_view] = view_pred
            else:
                combined += view_pred

            if not predict_mode and np.random.rand() <= eval_prob:
                view_dices = dice_all(y,
                                      pred_to_class(pred,
                                                    img_dims=3,
                                                    has_batch_dim=False),
                                      ignore_zero=False,
                                      n_classes=n_classes,
                                      skip_if_no_y=False)
                mapped_dices = dice_all(image.labels,
                                        pred_to_class(mapped_pred,
                                                      img_dims=3,
                                                      has_batch_dim=False),
                                        ignore_zero=False,
                                        n_classes=n_classes,
                                        skip_if_no_y=False)
                # Mean over foreground classes: drop class 0 first, then NaNs
                fg_dices = mapped_dices[1:]
                mean_dice = fg_dices[~np.isnan(fg_dices)].mean()

                # Print dice scores
                print("View dice scores:   ", view_dices)
                print("Mapped dice scores: ", mapped_dices)
                print("Mean dice (n=%i): " % (len(mapped_dices) - 1),
                      mean_dice)

                # Add to results
                results.loc[image_id, str(v)] = mean_dice
                detailed_res[str(v)][image_id] = mapped_dices[1:]

                # Overwrite with so-far results
                save_all(results, detailed_res, out_dir)
            else:
                print("Skipping evaluation for this view... "
                      "(eval_prob=%.3f, predict_mode=%s)" %
                      (eval_prob, predict_mode))

        if not analytical and not majority:
            # Combine predictions across views using Fusion model
            print("\nFusing views...")
            combined = np.moveaxis(combined, 0, -2).reshape(
                (-1, len(views), n_classes))
            combined = fm.predict(combined, batch_size=10**4,
                                  verbose=1).reshape(
                                      (d[0], d[1], d[2], n_classes))
        elif analytical:
            print("\nFusing views (analytical)...")
            combined = np.sum(combined, axis=0)

        if not no_argmax:
            print("\nComputing majority vote...")
            combined = pred_to_class(combined.squeeze(),
                                     img_dims=3).astype(np.uint8)

        if not predict_mode:
            if no_argmax:
                # MAP only for dice calculation
                c_temp = pred_to_class(combined, img_dims=3).astype(np.uint8)
            else:
                c_temp = combined

            # Calculate combined prediction dice
            dices = dice_all(image.labels,
                             c_temp,
                             n_classes=n_classes,
                             ignore_zero=True,
                             skip_if_no_y=False)
            mean_dice = dices[~np.isnan(dices)].mean()
            detailed_res["MJ"][image_id] = dices

            print("Combined dices: ", dices)
            print("Combined mean dice: ", mean_dice)
            results.loc[image_id, "MJ"] = mean_dice

            # Overwrite with so-far results
            save_all(results, detailed_res, out_dir)

        # Save combined prediction volume as .nii file
        print("Saving .nii files...")
        save_nii_files(combined, image, nii_res_dir, save_input_files)

        # Remove image from dictionary and image_pair_loader to free memory
        del all_images[image_id]
        image_pair_loader.images.remove(image)

    if not predict_mode:
        # Write final results
        save_all(results, detailed_res, out_dir)