コード例 #1
0
ファイル: predict.py プロジェクト: admshumar/MultiPlanarUNet
def set_gpu_vis(args):
    """
    Select the GPU(s) to use and return how many were selected.

    When ``args.force_GPU`` is set it is handed to ``set_gpu`` and the GPU
    count is the number of comma-separated entries in that string.
    Otherwise ``await_and_set_free_gpu`` is asked for ``args.num_GPUs``
    GPUs (polling interval argument: 120 seconds) and that number is
    returned.

    Args:
        args: Object exposing the attributes ``force_GPU`` and ``num_GPUs``.

    Returns:
        The number of GPUs that the caller should assume are in use.
    """
    forced = args.force_GPU
    if forced:
        # Explicit GPU list given by the user
        from MultiPlanarUNet.utils import set_gpu
        set_gpu(forced)
        return len(forced.split(","))

    # No explicit choice: wait for free GPU(s)
    from MultiPlanarUNet.utils import await_and_set_free_gpu
    await_and_set_free_gpu(N=args.num_GPUs, sleep_seconds=120)
    return args.num_GPUs
コード例 #2
0
def entry_func(args=None):
    """
    Launch one sub-process per cross-validation (CV) split folder.

    Every sub-process is handed a GPU set taken from a shared queue before
    it is started; the queue itself is also passed on to
    ``run_sub_experiment``. Optionally, a separate process running
    ``monitor_GPUs`` is started alongside the workers.

    Args:
        args: Optional list of command line arguments that is passed to
              ``parse_args`` of the parser returned by ``get_parser``.

    Raises:
        ValueError: If ``--run_on_split`` is outside the range of available
                    split folders, or if no GPUs are requested and a
                    negative ``--num_jobs`` was given.
    """
    # Get parser
    parser = vars(get_parser().parse_args(args))

    # Get parser arguments
    cv_dir = os.path.abspath(parser["CV_dir"])
    out_dir = os.path.abspath(parser["out_dir"])
    create_folders(out_dir)
    await_PID = parser["wait_for"]
    run_split = parser["run_on_split"]
    start_from = parser["start_from"] or 0
    num_jobs = parser["num_jobs"] or 1

    # GPU settings
    num_GPUs = parser["num_GPUs"]
    force_GPU = parser["force_GPU"]
    ignore_GPU = parser["ignore_GPU"]
    monitor_GPUs_every = parser["monitor_GPUs_every"]

    # User input assertions
    _assert_force_and_ignore_gpus(force_GPU, ignore_GPU)
    if run_split is not None:
        # Split index 0 is valid, so compare against None instead of relying
        # on truthiness (which would skip the check for the first split)
        _assert_run_split(start_from, monitor_GPUs_every, num_jobs)

    # Wait for PID?
    if await_PID:
        from MultiPlanarUNet.utils import await_PIDs
        await_PIDs(await_PID)

    # Get file paths
    script = os.path.abspath(parser["script_prototype"])
    hparams = os.path.abspath(parser["hparams_prototype"])
    no_hparams = parser["no_hparams"]

    # Get list of folders of CV data to run on
    cv_folders = get_CV_folders(cv_dir)
    if run_split is not None:
        if run_split < 0 or run_split >= len(cv_folders):
            raise ValueError("--run_on_split should be in range [0-{}], "
                             "got {}".format(len(cv_folders) - 1, run_split))
        cv_folders = [cv_folders[run_split]]
        log_appendix = "_split{}".format(run_split)
    else:
        log_appendix = ""

    # Get a logger object
    logger = Logger(base_path="./",
                    active_file="output" + log_appendix,
                    print_calling_method=False,
                    overwrite_existing=True)

    if force_GPU:
        # Only these GPUs will be chosen from
        from MultiPlanarUNet.utils import set_gpu
        set_gpu(force_GPU)
    if num_GPUs:
        # Get GPU sets (up to the number of splits)
        gpu_sets = get_free_GPU_sets(num_GPUs, ignore_GPU)[:len(cv_folders)]
    elif num_jobs < 0:
        raise ValueError("Should specify a number of jobs to run in parallel "
                         "with the --num_jobs flag when using 0 GPUs pr. "
                         "process (--num_GPUs=0 was set).")
    else:
        # No GPUs: queue one empty GPU string per parallel job. The
        # defaulted 'num_jobs' is used here because the raw parser value
        # may be None.
        gpu_sets = ["''"] * num_jobs

    # Get process pool, lock and GPU queue objects
    lock = Lock()
    gpu_queue = Queue()
    for gpu in gpu_sets:
        gpu_queue.put(gpu)

    procs = []
    if monitor_GPUs_every:
        logger("\nOBS: Monitoring GPU pool every %i seconds\n" %
               monitor_GPUs_every)
        # Start a process monitoring new GPU availability over time
        stop_event = Event()
        monitor = Process(target=monitor_GPUs,
                          args=(monitor_GPUs_every, gpu_queue, num_GPUs,
                                ignore_GPU, gpu_sets, stop_event))
        monitor.start()
        procs.append(monitor)
    else:
        stop_event = None
    try:
        for cv_folder in cv_folders[start_from:]:
            # Blocks until a GPU set is available in the queue
            gpus = gpu_queue.get()
            worker = Process(target=run_sub_experiment,
                             args=(cv_folder, out_dir, script, hparams,
                                   no_hparams, gpus, gpu_queue, lock, logger))
            worker.start()
            procs.append(worker)
            # Reap processes that have already finished
            for proc in procs:
                if not proc.is_alive():
                    proc.join()
    except KeyboardInterrupt:
        for proc in procs:
            proc.terminate()
    if stop_event is not None:
        stop_event.set()
    for proc in procs:
        proc.join()
コード例 #3
0
def entry_func(args=None):
    """
    Predict on (and optionally evaluate) volumes with a 3D patch model.

    Reads the project's 'train_hparams.yaml', loads the best model weights
    from '<project_dir>/model', predicts on every image of the test set (or
    on the single file given with 'f'), and stores the predicted class map
    of each image under '<out_dir>/nii_files'. Unless running in predict
    mode, dice scores are computed per image and written with save_all_3D.

    Args:
        args: Optional list of command line arguments that is passed to
              ``parse_args`` of the parser returned by ``get_argparser``.
    """
    # Get command line arguments
    args = vars(get_argparser().parse_args(args))
    base_dir = os.path.abspath(args["project_dir"])
    _file = args["f"]
    label = args["l"]
    N_extra = args["extra"]
    try:
        N_extra = int(N_extra)
    except ValueError:
        # Non-integer values are passed on unchanged
        pass

    # Get settings from YAML file
    from MultiPlanarUNet.train.hparams import YAMLHParams
    hparams = YAMLHParams(os.path.join(base_dir, "train_hparams.yaml"))

    # Set strides
    hparams["fit"]["strides"] = args["strides"]

    if not _file:
        try:
            # Data specified from command line?
            data_dir = os.path.abspath(args["data_dir"])

            # Set with default sub dirs
            hparams["test_data"] = {
                "base_dir": data_dir,
                "img_subdir": "images",
                "label_subdir": "labels"
            }
        except (AttributeError, TypeError):
            # No data dir given on the command line; use the one from the
            # hyperparameter file
            data_dir = hparams["test_data"]["base_dir"]
    else:
        data_dir = False
    out_dir = os.path.abspath(args["out_dir"])
    overwrite = args["overwrite"]
    predict_mode = args["no_eval"]
    save_only_pred = args["save_only_pred"]

    # Check if valid dir structures
    validate_folders(base_dir, data_dir, out_dir, overwrite)

    # Import all needed modules (folder is valid at this point)
    import numpy as np
    from MultiPlanarUNet.image import ImagePairLoader, ImagePair
    from MultiPlanarUNet.utils import get_best_model, create_folders, \
                                    pred_to_class, await_and_set_free_gpu, set_gpu
    from MultiPlanarUNet.utils.fusion import predict_3D_patches, predict_3D_patches_binary, pred_3D_iso
    from MultiPlanarUNet.logging import init_result_dict_3D, save_all_3D
    from MultiPlanarUNet.evaluate import dice_all
    from MultiPlanarUNet.bin.predict import save_nii_files

    # Fetch GPU(s)
    num_GPUs = args["num_GPUs"]
    force_gpu = args["force_GPU"]
    # Wait for free GPU
    if force_gpu == -1:
        await_and_set_free_gpu(N=num_GPUs, sleep_seconds=240)
    else:
        set_gpu(force_gpu)

    # Read settings from the project hyperparameter file
    dim = hparams["build"]["dim"]
    n_classes = hparams["build"]["n_classes"]
    mode = hparams["fit"]["intrp_style"]

    # Set ImagePairLoader object
    if not _file:
        image_pair_loader = ImagePairLoader(predict_mode=predict_mode,
                                            **hparams["test_data"])
    else:
        # Single file: evaluation is only possible if a label file was given
        predict_mode = not bool(label)
        image_pair_loader = ImagePairLoader(predict_mode=predict_mode,
                                            initialize_empty=True)
        image_pair_loader.add_image(ImagePair(_file, label))

    # Put them into a dict and remove from image_pair_loader to gain more
    # control with garbage collection
    all_images = {image.id: image for image in image_pair_loader.images}
    image_pair_loader.images = None

    # Define UNet model
    from MultiPlanarUNet.models import model_initializer
    hparams["build"]["batch_size"] = 1
    unet = model_initializer(hparams, False, base_dir)
    model_path = get_best_model(base_dir + "/model")
    unet.load_weights(model_path)

    # Evaluate?
    if not predict_mode:
        # Prepare dictionary to store results in pd df
        results, detailed_res = init_result_dict_3D(all_images, n_classes)

        # Save to check correct format
        save_all_3D(results, detailed_res, out_dir)

    # Define result paths
    nii_res_dir = os.path.join(out_dir, "nii_files")
    create_folders(nii_res_dir)

    for image_id in sorted(all_images):
        print("\n[*] Running on: %s" % image_id)

        # Set image_pair_loader object with only the given file
        image = all_images[image_id]
        image_pair_loader.images = [image]

        seq = image_pair_loader.get_sequencer(n_classes=n_classes,
                                              no_log=True,
                                              **hparams["fit"])

        if mode.lower() == "iso_live_3d":
            pred = pred_3D_iso(model=unet,
                               sequence=seq,
                               image=image,
                               extra_boxes=N_extra,
                               min_coverage=None)
        else:
            # Predict on volume using model
            if n_classes > 1:
                pred = predict_3D_patches(model=unet,
                                          patches=seq,
                                          image=image,
                                          N_extra=N_extra)
            else:
                pred = predict_3D_patches_binary(model=unet,
                                                 patches=seq,
                                                 image_id=image_id,
                                                 N_extra=N_extra)

        # Class map of the prediction. Computed in both modes, because the
        # prediction must be stored even when no evaluation is performed.
        p = pred_to_class(pred, img_dims=3, has_batch_dim=False)

        if not predict_mode:
            # Get labels for the current image
            y = image.labels

            # Calculate dice score
            print("Mean dice: ", end="", flush=True)
            dices = dice_all(y, p, n_classes=n_classes, ignore_zero=True)
            mean_dice = dices[~np.isnan(dices)].mean()
            print("Dices: ", dices)
            print("%s (n=%i)" % (mean_dice, len(dices)))

            # Add to results
            results[image_id] = [mean_dice]
            detailed_res[image_id] = dices

            # Overwrite with so-far results
            save_all_3D(results, detailed_res, out_dir)

        # Save the prediction. Previously this only happened when evaluating,
        # so predict-only runs produced no output files.
        save_nii_files(p, image, nii_res_dir, save_only_pred)

        # Remove image from dictionary and image_pair_loader to free memory
        del all_images[image_id]
        image_pair_loader.images.remove(image)

    if not predict_mode:
        # Write final results
        save_all_3D(results, detailed_res, out_dir)
コード例 #4
0
def entry_func(args=None):
    """
    Fit the fusion model of a project and store its weights.

    Loads the project's hyperparameters, views and best model weights,
    assembles a pool of images (validation images, topped up with randomly
    drawn training images if fewer than 15 validation images exist), splits
    the pool into rounds of 'images_per_round' images and hands everything
    to _run_fusion_training. The fusion weights are saved in a 'finally'
    clause, so they are written even if training is interrupted or fails.

    Args:
        args: Optional list of command line arguments that is passed to
              ``parse_args`` of the parser returned by ``get_argparser``.
    """
    # Project base path
    args = vars(get_argparser().parse_args(args))
    basedir = os.path.abspath(args["project_dir"])
    overwrite = args["overwrite"]
    continue_training = args["continue_training"]
    eval_prob = args["eval_prob"]
    await_PID = args["wait_for"]
    dice_weight = args["dice_weight"]
    print("Fitting fusion model for project-folder: %s" % basedir)

    # Minimum images in validation set before also using training images
    min_val_images = 15

    # Fusion model training params
    epochs = args['epochs']
    fm_batch_size = args["batch_size"]

    # Early stopping params
    early_stopping = args["early_stopping"]

    # Wait for PID?
    if await_PID:
        from MultiPlanarUNet.utils import await_PIDs
        await_PIDs(await_PID)

    # Fetch GPU(s)
    num_GPUs = args["num_GPUs"]
    force_gpu = args["force_GPU"]
    # Wait for free GPU
    if not force_gpu:
        await_and_set_free_gpu(N=num_GPUs, sleep_seconds=120)
        # NOTE(review): the requested number of GPUs is overwritten with 1
        # here, so the multi-GPU branch below is only reached with
        # --force_GPU; confirm this is intended.
        num_GPUs = 1
    else:
        set_gpu(force_gpu)
        num_GPUs = len(force_gpu.split(","))

    # Get logger
    logger = Logger(base_path=basedir,
                    active_file="train_fusion",
                    overwrite_existing=overwrite)

    # Get YAML hyperparameters
    hparams = YAMLHParams(os.path.join(basedir, "train_hparams.yaml"))

    # Get some key settings
    n_classes = hparams["build"]["n_classes"]

    if hparams["build"]["out_activation"] == "linear":
        # Trained with logit targets?
        hparams["build"][
            "out_activation"] = "softmax" if n_classes > 1 else "sigmoid"

    # Get views
    views = np.load("%s/views.npz" % basedir)["arr_0"]
    del hparams["fit"]["views"]

    # Get weights and set fusion (output) path
    weights = get_best_model("%s/model" % basedir)
    weights_name = os.path.splitext(os.path.split(weights)[-1])[0]
    fusion_weights = "%s/model/fusion_weights/" \
                     "%s_fusion_weights.h5" % (basedir, weights_name)
    create_folders(os.path.split(fusion_weights)[0])

    # Log a few things
    log(logger, hparams, views, weights, fusion_weights)

    # Check if exists already...
    if not overwrite and os.path.exists(fusion_weights):
        from sys import exit
        print("\n[*] A fusion weights file already exists at '%s'."
              "\n    Use the --overwrite flag to overwrite." % fusion_weights)
        exit(0)

    # Load validation data
    images = ImagePairLoader(**hparams["val_data"], logger=logger)
    is_validation = {m.id: True for m in images}

    # Define random sets of images to train on simul. (cant be all due
    # to memory constraints)
    image_IDs = [m.id for m in images]

    if len(images) < min_val_images:
        # Pick N random training images
        diff = min_val_images - len(images)
        logger("Adding %i training images to set" % diff)

        # Load the training data and pick diff images (sampling with
        # replacement only if there are too few training images)
        train = ImagePairLoader(**hparams["train_data"], logger=logger)
        indx = np.random.choice(np.arange(len(train)),
                                diff,
                                replace=diff > len(train))

        # Add the images to the image set
        train_add = [train[i] for i in indx]
        for m in train_add:
            is_validation[m.id] = False
            image_IDs.append(m.id)
        images.add_images(train_add)

    # Append to length % sub_size == 0
    sub_size = args["images_per_round"]
    rest = int(sub_size * np.ceil(len(image_IDs) / sub_size)) - len(image_IDs)
    if rest:
        # Sampling without replacement fails if more IDs are requested than
        # exist (large 'images_per_round'); fall back to sampling with
        # replacement in that case, as done for the training images above.
        image_IDs += list(np.random.choice(image_IDs, rest,
                                           replace=rest > len(image_IDs)))

    # Shuffle and split (the length is a multiple of sub_size at this point,
    # so the integer division is exact)
    random.shuffle(image_IDs)
    sets = [
        set(s) for s in np.array_split(image_IDs,
                                       len(image_IDs) // sub_size)
    ]
    assert (contains_all_images(sets, image_IDs))

    # Define fusion model (named 'org' to store reference to original model if
    # multi gpu model is created below)
    fusion_model_org = FusionModel(n_inputs=len(views),
                                   n_classes=n_classes,
                                   weight=dice_weight,
                                   logger=logger,
                                   verbose=False)

    if continue_training:
        fusion_model_org.load_weights(fusion_weights)
        print("\n[OBS] CONTINUED TRAINING FROM:\n", fusion_weights)

    # Define model
    unet = init_model(hparams["build"], logger)
    print("\n[*] Loading weights: %s\n" % weights)
    unet.load_weights(weights, by_name=True)

    if num_GPUs > 1:
        from tensorflow.keras.utils import multi_gpu_model

        # Set for predictor model; the n_classes attribute is re-attached to
        # the wrapping multi-GPU model
        unet = multi_gpu_model(unet, gpus=num_GPUs)
        unet.n_classes = n_classes

        # Set for fusion model
        fusion_model = multi_gpu_model(fusion_model_org, gpus=num_GPUs)
    else:
        fusion_model = fusion_model_org

    # Compile the model
    logger("Compiling...")
    metrics = [
        "sparse_categorical_accuracy", sparse_fg_precision, sparse_fg_recall
    ]
    fusion_model.compile(optimizer=Adam(lr=1e-3),
                         loss=fusion_model_org.loss,
                         metrics=metrics)
    fusion_model_org._log()

    try:
        _run_fusion_training(sets, logger, hparams, min_val_images,
                             is_validation, views, n_classes, unet,
                             fusion_model_org, fusion_model, early_stopping,
                             fm_batch_size, epochs, eval_prob, fusion_weights)
    except KeyboardInterrupt:
        # Deliberately swallowed: an interrupted run still saves its weights
        pass
    finally:
        if not os.path.exists(os.path.split(fusion_weights)[0]):
            os.mkdir(os.path.split(fusion_weights)[0])
        # Save fusion model weights
        # OBS: Must be original model if multi-gpu is performed!
        fusion_model_org.save_weights(fusion_weights)
コード例 #5
0
ファイル: predict.py プロジェクト: liujie1977/MultiPlanarUNet
def entry_func(args=None):
    """
    Predict on (and optionally evaluate) volumes by fusing per-view
    predictions of a trained model.

    For every image, the model predicts along each view stored in
    '<project_dir>/views.npz'; each prediction is mapped onto the voxel grid
    of the image and the views are combined either by the trained fusion
    model (default), by summing the per-view outputs ('analytical') or by
    accumulating them in a single volume ('majority'). The combined volume
    is saved under '<out_dir>/nii_files'. Unless running in predict mode,
    dice scores are computed and written with save_all.

    Args:
        args: Optional list of command line arguments that is passed to
              ``parse_args`` of the parser returned by ``get_argparser``.

    Raises:
        ValueError: If both 'analytical' and 'majority' are requested.
    """
    # Get command line arguments
    args = vars(get_argparser().parse_args(args))
    base_dir = os.path.abspath(args["project_dir"])
    analytical = args["analytical"]
    majority = args["majority"]
    _file = args["f"]
    label = args["l"]
    await_PID = args["wait_for"]
    eval_prob = args["eval_prob"]
    _continue = args["continue"]
    if analytical and majority:
        raise ValueError("Cannot specify both --analytical and --majority.")

    # Get settings from YAML file
    from MultiPlanarUNet.train.hparams import YAMLHParams
    hparams = YAMLHParams(os.path.join(base_dir, "train_hparams.yaml"))

    if not _file:
        try:
            # Data specified from command line?
            data_dir = os.path.abspath(args["data_dir"])

            # Set with default sub dirs
            hparams["test_data"] = {
                "base_dir": data_dir,
                "img_subdir": "images",
                "label_subdir": "labels"
            }
        except (AttributeError, TypeError):
            # No data dir given on the command line; use the one from the
            # hyperparameter file
            data_dir = hparams["test_data"]["base_dir"]
    else:
        data_dir = False
    out_dir = os.path.abspath(args["out_dir"])
    overwrite = args["overwrite"]
    predict_mode = args["no_eval"]
    save_input_files = args["save_input_files"]
    no_argmax = args["no_argmax"]
    on_val = args["on_val"]

    # Check if valid dir structures
    validate_folders(base_dir, out_dir, overwrite, _continue)

    # Import all needed modules (folder is valid at this point)
    import numpy as np
    from MultiPlanarUNet.image import ImagePairLoader, ImagePair
    from MultiPlanarUNet.models import FusionModel
    from MultiPlanarUNet.models.model_init import init_model
    from MultiPlanarUNet.utils import await_and_set_free_gpu, get_best_model, \
                                    create_folders, pred_to_class, set_gpu
    from MultiPlanarUNet.logging import init_result_dicts, save_all, load_result_dicts
    from MultiPlanarUNet.evaluate import dice_all
    from MultiPlanarUNet.utils.fusion import predict_volume, map_real_space_pred
    from MultiPlanarUNet.interpolation.sample_grid import get_voxel_grid_real_space

    # Wait for PID?
    if await_PID:
        from MultiPlanarUNet.utils import await_PIDs
        await_PIDs(await_PID)

    # Set GPU device
    # Fetch GPU(s)
    num_GPUs = args["num_GPUs"]
    force_gpu = args["force_GPU"]
    # Wait for free GPU
    if not force_gpu:
        await_and_set_free_gpu(N=num_GPUs, sleep_seconds=120)
        # NOTE(review): the requested number of GPUs is overwritten with 1
        # here, so the multi-GPU branches below are only reached with
        # --force_GPU; confirm this is intended.
        num_GPUs = 1
    else:
        set_gpu(force_gpu)
        num_GPUs = len(force_gpu.split(","))

    # Read settings from the project hyperparameter file
    n_classes = hparams["build"]["n_classes"]

    # Get views
    views = np.load("%s/views.npz" % base_dir)["arr_0"]

    # Force settings
    hparams["fit"]["max_background"] = 1
    hparams["fit"]["test_mode"] = True
    hparams["fit"]["mix_planes"] = False
    hparams["fit"]["live_intrp"] = False
    if "use_bounds" in hparams["fit"]:
        del hparams["fit"]["use_bounds"]
    del hparams["fit"]["views"]

    if hparams["build"]["out_activation"] == "linear":
        # Trained with logit targets?
        hparams["build"][
            "out_activation"] = "softmax" if n_classes > 1 else "sigmoid"

    # Set ImagePairLoader object
    if not _file:
        data = "test_data" if not on_val else "val_data"
        image_pair_loader = ImagePairLoader(predict_mode=predict_mode,
                                            **hparams[data])
    else:
        # Single file: evaluation is only possible if a label file was given
        predict_mode = not bool(label)
        image_pair_loader = ImagePairLoader(predict_mode=predict_mode,
                                            single_file_mode=True)
        image_pair_loader.add_image(ImagePair(_file, label))

    # Put them into a dict and remove from image_pair_loader to gain more
    # control with garbage collection
    all_images = {image.id: image for image in image_pair_loader.images}
    image_pair_loader.images = None
    if _continue:
        all_images = remove_already_predicted(all_images, out_dir)

    # Evaluate?
    if not predict_mode:
        if _continue:
            csv_dir = os.path.join(out_dir, "csv")
            results, detailed_res = load_result_dicts(csv_dir=csv_dir,
                                                      views=views)
        else:
            # Prepare dictionary to store results in pd df
            results, detailed_res = init_result_dicts(views, all_images,
                                                      n_classes)

        # Save to check correct format
        save_all(results, detailed_res, out_dir)

    # Define result paths
    nii_res_dir = os.path.join(out_dir, "nii_files")
    create_folders(nii_res_dir)

    # Define UNet model
    model_path = get_best_model(base_dir + "/model")
    unet = init_model(hparams["build"])
    unet.load_weights(model_path, by_name=True)

    if num_GPUs > 1:
        from tensorflow.keras.utils import multi_gpu_model
        # The n_classes attribute is re-attached to the wrapping model
        n_classes = unet.n_classes
        unet = multi_gpu_model(unet, gpus=num_GPUs)
        unet.n_classes = n_classes

    weights_name = os.path.splitext(os.path.split(model_path)[1])[0]
    if not analytical and not majority:
        # Get Fusion model
        fm = FusionModel(n_inputs=len(views), n_classes=n_classes)

        weights = base_dir + "/model/fusion_weights/%s_fusion_weights.h5" % weights_name
        print("\n[*] Loading weights:\n", weights)

        # Load fusion weights
        fm.load_weights(weights)
        print("\nLoaded weights:\n\n%s\n%s\n---" %
              tuple(fm.layers[-1].get_weights()))

        # Multi-gpu?
        if num_GPUs > 1:
            print("Using multi-GPU model (%i GPUs)" % num_GPUs)
            fm = multi_gpu_model(fm, gpus=num_GPUs)

    # Finally predict on the images
    image_ids = sorted(all_images)
    N_images = len(image_ids)
    for n_image, image_id in enumerate(image_ids):
        print("\n[*] (%i/%s) Running on: %s" %
              (n_image + 1, N_images, image_id))

        # Set image_pair_loader object with only the given file
        image = all_images[image_id]
        image_pair_loader.images = [image]

        # Load views
        kwargs = hparams["fit"]
        kwargs.update(hparams["build"])
        seq = image_pair_loader.get_sequencer(views=views, **kwargs)

        # Get voxel grid in real space
        voxel_grid_real_space = get_voxel_grid_real_space(image)

        # Prepare tensor to store combined prediction
        d = image.image.shape[:-1]
        if not majority:
            # One slot per view; every slot is fully overwritten below
            combined = np.empty(shape=(len(views), d[0], d[1], d[2],
                                       n_classes),
                                dtype=np.float32)
        else:
            # Accumulator that the per-view predictions are added to; it
            # must start at zero (np.empty would leave arbitrary values in
            # the sum)
            combined = np.zeros(shape=(d[0], d[1], d[2], n_classes),
                                dtype=np.float32)
        print("Predicting on brain hyper-volume of shape:", combined.shape)

        # Predict for each view
        for n_view, v in enumerate(views):
            print("\n[*] (%i/%i) View: %s" % (n_view + 1, len(views), v))
            # for each view, predict on all voxels and map the predictions
            # back into the original coordinate system

            # Sample planes from the image at grid_real_space grid
            # in real space (scanner RAS) coordinates.
            X, y, grid, inv_basis = seq.get_view_from(image.id,
                                                      v,
                                                      n_planes="same+20")

            # Predict on volume using model
            pred = predict_volume(unet, X, axis=2, batch_size=seq.batch_size)

            # Map the real space coordinate predictions to nearest
            # real space coordinates defined on voxel grid
            mapped_pred = map_real_space_pred(pred,
                                              grid,
                                              inv_basis,
                                              voxel_grid_real_space,
                                              method="nearest")
            if not majority:
                combined[n_view] = mapped_pred
                if n_classes == 1:
                    # Set to background if outside pred domain
                    combined[n_view][np.isnan(combined[n_view])] = 0.
            elif n_classes == 1:
                # Set to background if outside pred domain. In majority mode
                # 'combined' has no view axis, so the NaN entries are masked
                # before accumulating instead of indexing by view.
                combined += np.where(np.isnan(mapped_pred), 0., mapped_pred)
            else:
                combined += mapped_pred

            if not predict_mode and np.random.rand() <= eval_prob:
                view_dices = dice_all(y,
                                      pred_to_class(pred,
                                                    img_dims=3,
                                                    has_batch_dim=False),
                                      ignore_zero=False,
                                      n_classes=n_classes,
                                      skip_if_no_y=False)
                mapped_dices = dice_all(image.labels,
                                        pred_to_class(mapped_pred,
                                                      img_dims=3,
                                                      has_batch_dim=False),
                                        ignore_zero=False,
                                        n_classes=n_classes,
                                        skip_if_no_y=False)
                # Drop the first (background) entry before filtering NaN
                # values; filtering first would shift the indices and could
                # discard a foreground class instead
                fg_dices = mapped_dices[1:]
                mean_dice = fg_dices[~np.isnan(fg_dices)].mean()

                # Print dice scores
                print("View dice scores:   ", view_dices)
                print("Mapped dice scores: ", mapped_dices)
                print("Mean dice (n=%i): " % (len(mapped_dices) - 1),
                      mean_dice)

                # Add to results
                results.loc[image_id, str(v)] = mean_dice
                detailed_res[str(v)][image_id] = mapped_dices[1:]

                # Overwrite with so-far results
                save_all(results, detailed_res, out_dir)
            else:
                print("Skipping evaluation for this view... "
                      "(eval_prob=%.3f, predict_mode=%s)" %
                      (eval_prob, predict_mode))

        if not analytical and not majority:
            # Combine predictions across views using Fusion model
            print("\nFusing views...")
            combined = np.moveaxis(combined, 0, -2).reshape(
                (-1, len(views), n_classes))
            combined = fm.predict(combined, batch_size=10**4,
                                  verbose=1).reshape(
                                      (d[0], d[1], d[2], n_classes))
        elif analytical:
            print("\nFusing views (analytical)...")
            combined = np.sum(combined, axis=0)

        if not no_argmax:
            print("\nComputing majority vote...")
            combined = pred_to_class(combined.squeeze(),
                                     img_dims=3).astype(np.uint8)

        if not predict_mode:
            if no_argmax:
                # MAP only for dice calculation
                c_temp = pred_to_class(combined, img_dims=3).astype(np.uint8)
            else:
                c_temp = combined

            # Calculate combined prediction dice
            dices = dice_all(image.labels,
                             c_temp,
                             n_classes=n_classes,
                             ignore_zero=True,
                             skip_if_no_y=False)
            mean_dice = dices[~np.isnan(dices)].mean()
            detailed_res["MJ"][image_id] = dices

            print("Combined dices: ", dices)
            print("Combined mean dice: ", mean_dice)
            results.loc[image_id, "MJ"] = mean_dice

            # Overwrite with so-far results
            save_all(results, detailed_res, out_dir)

        # Save combined prediction volume as .nii file
        print("Saving .nii files...")
        save_nii_files(combined, image, nii_res_dir, save_input_files)

        # Remove image from dictionary and image_pair_loader to free memory
        del all_images[image_id]
        image_pair_loader.images.remove(image)

    if not predict_mode:
        # Write final results
        save_all(results, detailed_res, out_dir)