Example #1
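The snippets below share a common set of imports; the following is a sketch inferred from the calls they make (project-specific helpers such as get_image and patch_wise_prediction live elsewhere in the repository and are not shown here):

import os
import json
from glob import glob
from pathlib import Path

import numpy as np
import nibabel as nib
from scipy import ndimage
from nilearn.image import resample_to_img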
def prediction_to_image(prediction,
                        label_map=False,
                        threshold=0.5,
                        labels=None):
    if prediction.shape[0] == 1:
        data = prediction[0]
        if label_map:
            # match data's shape so the boolean mask assignment below lines up
            label_map_data = np.zeros(prediction[0].shape, np.int8)
            if labels:
                label = labels[0]
            else:
                label = 1
            label_map_data[data > threshold] = label
            data = label_map_data
    elif prediction.shape[1] > 1:
        if label_map:
            label_map_data = get_prediction_labels(prediction,
                                                   threshold=threshold,
                                                   labels=labels)
            data = label_map_data[0]
        else:
            return multi_class_prediction(prediction)
    else:
        raise RuntimeError("Invalid prediction array shape: {0}".format(
            prediction.shape))
    return get_image(data)
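A minimal, self-contained sketch of the thresholding step prediction_to_image performs when label_map is set (the probability map, threshold and label value here are toy values for illustration):

import numpy as np

# toy probability map: one sample, one channel, 4x4x4 volume
prediction = np.random.rand(1, 1, 4, 4, 4).astype(np.float32)

data = prediction[0]                          # drop the sample axis
label_map_data = np.zeros(data.shape, np.int8)
label_map_data[data > 0.5] = 1                # voxels above the threshold get the label
print(label_map_data.sum(), "voxels labelled")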
Example #2
def run_validation_case_from_image_simple(output_dir,
                                          model,
                                          processed_image,
                                          patch_shape,
                                          image_gt=None,
                                          image_pred=None,
                                          overlap_factor=0.8,
                                          prev_truth_index=None,
                                          prev_truth_size=None,
                                          pred_index=None,
                                          pred_size=None,
                                          is3d=False):
    """

    :param output_dir: folder to save prediction results in
    :param model: loaded trained model
    :param processed_image: as a numpy ndarray, after any processing (scaling, normalization, augmentation) wanted
    :param patch_shape: size of patch on which to run model
    :param image_gt: optional, {0,1} ndarray of ground truth segmentation
    :param image_pred: optional, {0-1} ndarray of predicted segmentation
    :param overlap_factor: amount of overlap between consecutive patches, float 0-1
    :param prev_truth_index: if truth used in prediction - the starting index in input batch (depth), else None
    :param prev_truth_size:if truth used in prediction - amount of truth slices in input batch, else None
    :param pred_index: if other pred used in prediction - the starting index in input batch (depth), else None
    :param pred_size: if other pred used in prediction - amount of pred slices in input batch (depth), else None
    :return: prediction + path to saved location
    """

    prediction, _ = \
        patch_wise_prediction(model=model, data=np.expand_dims(processed_image.squeeze(), 0), overlap_factor=overlap_factor,
                              patch_shape=patch_shape, truth_data=image_gt, prev_truth_index=prev_truth_index,
                              prev_truth_size=prev_truth_size, pred_data=image_pred, pred_index=pred_index,
                              pred_size=pred_size, is3d=is3d)  # [np.newaxis]
    prediction = prediction.squeeze()
    prediction_image = get_image(prediction)

    filename = os.path.join(output_dir, "prediction.nii.gz")
    name_counter = 0
    while os.path.exists(filename):
        name_counter += 1
        filename = os.path.join(output_dir,
                                "prediction_{0}.nii.gz".format(name_counter))
    print("Saving to {}".format(filename))
    prediction_image.to_filename(filename)
    return prediction, filename
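patch_wise_prediction itself is not included in these snippets; the sketch below shows the general technique it implements under the simplest assumptions (overlapping patches whose predictions are averaged where they overlap), not the repository's actual implementation:

import numpy as np

def sliding_window_predict(volume, patch_shape, step, predict_fn):
    """Overlap-average a per-voxel prediction over 3D patches (simplified sketch)."""
    out = np.zeros(volume.shape, dtype=np.float32)
    counts = np.zeros(volume.shape, dtype=np.float32)
    px, py, pz = patch_shape
    for x in range(0, volume.shape[0] - px + 1, step):
        for y in range(0, volume.shape[1] - py + 1, step):
            for z in range(0, volume.shape[2] - pz + 1, step):
                patch = volume[x:x + px, y:y + py, z:z + pz]
                out[x:x + px, y:y + py, z:z + pz] += predict_fn(patch)
                counts[x:x + px, y:y + py, z:z + pz] += 1.0
    return out / np.maximum(counts, 1.0)   # average where patches overlapped

# toy usage with an identity "model"
vol = np.random.rand(16, 16, 8).astype(np.float32)
avg = sliding_window_predict(vol, patch_shape=(8, 8, 4), step=4, predict_fn=lambda p: p)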
Example #3
def save_nifti(data, path):
    nifti = get_image(data)
    nib.save(nifti, path)
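save_nifti relies on the project's get_image helper; with plain nibabel, the equivalent round trip looks roughly like this (identity affine assumed for illustration):

import numpy as np
import nibabel as nib

data = np.zeros((4, 4, 4), dtype=np.float32)
img = nib.Nifti1Image(data, affine=np.eye(4))   # wrap the array together with an affine
nib.save(img, "example.nii.gz")                 # equivalent to img.to_filename(...)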
Example #4
def augment_data(data,
                 truth,
                 data_min,
                 data_max,
                 mask=None,
                 scale_deviation=None,
                 iso_scale_deviation=None,
                 rotate_deviation=None,
                 translate_deviation=None,
                 flip=None,
                 contrast_deviation=None,
                 poisson_noise=None,
                 gaussian_noise=None,
                 speckle_noise=None,
                 piecewise_affine=None,
                 elastic_transform=None,
                 intensity_multiplication_range=None,
                 gaussian_filter=None,
                 coarse_dropout=None,
                 data_range=None,
                 truth_range=None,
                 prev_truth_range=None):
    n_dim = len(truth.shape)
    if scale_deviation:
        scale_factor = random_scale_factor(n_dim, std=scale_deviation)
    else:
        scale_factor = [1, 1, 1]
    if iso_scale_deviation:
        iso_scale_factor = np.random.uniform(1, iso_scale_deviation["max"])
        if random_boolean():
            iso_scale_factor = 1 / iso_scale_factor
        scale_factor[0] *= iso_scale_factor
        scale_factor[1] *= iso_scale_factor
    else:
        iso_scale_factor = None
    if rotate_deviation:
        rotate_factor = random_rotation_angle(n_dim, std=rotate_deviation)
        rotate_factor = np.deg2rad(rotate_factor)
    else:
        rotate_factor = None
    if flip is not None and flip:
        flip_axis = random_flip_dimensions(n_dim, flip)
    else:
        flip_axis = None
    if translate_deviation is not None:
        translate_factor = random_translate_factor(
            n_dim, -np.array(translate_deviation),
            np.array(translate_deviation))
        translate_factor[-1] = np.floor(
            translate_factor[-1])  # z-translate should be int
    else:
        translate_factor = None
    if contrast_deviation is not None:
        val_range = data_max - data_min
        contrast_min_val = data_min + contrast_deviation[
            "min_factor"] * np.random.uniform(-1, 1) * val_range
        contrast_max_val = data_max + contrast_deviation[
            "max_factor"] * np.random.uniform(-1, 1) * val_range
    else:
        contrast_min_val, contrast_max_val = None, None
    if poisson_noise is not None:
        apply_poisson_noise = poisson_noise > np.random.random()
    else:
        apply_poisson_noise = False
    if gaussian_noise is not None:
        apply_gaussian_noise = gaussian_noise["prob"] > np.random.random()
    else:
        apply_gaussian_noise = False
    if speckle_noise is not None:
        apply_speckle_noise = speckle_noise["prob"] > np.random.random()
    else:
        apply_speckle_noise = False

    if gaussian_filter is not None and gaussian_filter["prob"] > 0:
        gaussian_sigma = gaussian_filter["max_sigma"] * np.random.random()
        apply_gaussian = gaussian_filter["prob"] > np.random.random()
    else:
        apply_gaussian, gaussian_sigma = False, None
    if piecewise_affine is not None:
        piecewise_affine_scale = np.random.random() * piecewise_affine["scale"]
    else:
        piecewise_affine_scale = 0
    if (elastic_transform is not None) and (elastic_transform["alpha"] > 0):
        elastic_transform_scale = np.random.random(
        ) * elastic_transform["alpha"]
    else:
        elastic_transform_scale = 0
    if intensity_multiplication_range is not None:
        a, b = intensity_multiplication_range
        intensity_multiplication = np.random.random() * (b - a) + a
    else:
        intensity_multiplication = 1
    if coarse_dropout is not None:
        coarse_dropout_rate = coarse_dropout['rate']
        coarse_dropout_size = coarse_dropout['size_percent']

    image, affine = data, np.eye(4)
    distorted_data, distorted_affine = distort_image(
        image,
        affine,
        flip_axis=flip_axis,
        scale_factor=scale_factor,
        rotate_factor=rotate_factor,
        translate_factor=translate_factor)
    if data_range is None:
        data = resample_to_img(get_image(distorted_data, distorted_affine),
                               image,
                               interpolation="continuous",
                               copy=False,
                               clip=True).get_fdata()
    else:

        data = interpolate_affine_range(distorted_data,
                                        distorted_affine,
                                        data_range,
                                        order=1,
                                        mode='constant',
                                        cval=data_min)

    truth_image, truth_affine = truth, np.eye(4)
    distorted_truth_data, distorted_truth_affine = distort_image(
        truth_image,
        truth_affine,
        flip_axis=flip_axis,
        scale_factor=scale_factor,
        rotate_factor=rotate_factor,
        translate_factor=translate_factor)
    if truth_range is None:
        truth_data = resample_to_img(get_image(distorted_truth_data,
                                               distorted_truth_affine),
                                     truth_image,
                                     interpolation="nearest",
                                     copy=False,
                                     clip=True).get_fdata()  # get_data() is removed in newer nibabel
    else:
        truth_data = interpolate_affine_range(distorted_truth_data,
                                              distorted_truth_affine,
                                              truth_range,
                                              order=0,
                                              mode='constant',
                                              cval=0)

    if prev_truth_range is None:
        prev_truth_data = None
    else:
        prev_truth_data = interpolate_affine_range(distorted_truth_data,
                                                   distorted_truth_affine,
                                                   prev_truth_range,
                                                   order=0,
                                                   mode='constant',
                                                   cval=0)

    if mask is None:
        mask_data = None
    else:
        mask_image, mask_affine = mask, np.eye(4)
        distorted_mask_data, distorted_mask_affine = distort_image(
            mask_image,
            mask_affine,
            flip_axis=flip_axis,
            scale_factor=scale_factor,
            rotate_factor=rotate_factor,
            translate_factor=translate_factor)
        if truth_range is None:
            mask_data = resample_to_img(get_image(distorted_mask_data,
                                                  distorted_mask_affine),
                                        mask_image,
                                        interpolation="nearest",
                                        copy=False,
                                        clip=True).get_fdata()  # get_data() is removed in newer nibabel
        else:
            mask_data = interpolate_affine_range(distorted_mask_data,
                                                 distorted_mask_affine,
                                                 truth_range,
                                                 order=0,
                                                 mode='constant',
                                                 cval=0)

    if piecewise_affine_scale > 0:
        data, truth_data, prev_truth_data, mask_data = apply_piecewise_affine(
            data, truth_data, prev_truth_data, mask_data,
            piecewise_affine_scale)

    if elastic_transform_scale > 0:
        data, truth_data, prev_truth_data, mask_data = apply_elastic_transform(
            data, truth_data, prev_truth_data, mask_data,
            elastic_transform_scale, elastic_transform["sigma"])

    if contrast_deviation is not None:
        data = contrast_augment(data, contrast_min_val, contrast_max_val)

    if intensity_multiplication != 1:
        data = data * intensity_multiplication

    if apply_gaussian:
        data = apply_gaussian_filter(data, gaussian_sigma)

    if apply_poisson_noise:
        data = shot_noise(data)

    if apply_speckle_noise:
        data = add_speckle_noise(data, speckle_noise["sigma"])

    if apply_gaussian_noise:
        data = add_gaussian_noise(data, gaussian_noise["sigma"])

    if coarse_dropout is not None:
        data = apply_coarse_dropout(data,
                                    rate=coarse_dropout_rate,
                                    size_percent=coarse_dropout_size,
                                    per_channel=coarse_dropout["per_channel"])

    return data, truth_data, prev_truth_data, mask_data
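augment_data gates each corruption on a random draw against a configured probability; here is a self-contained sketch of that pattern for Gaussian noise (the project's add_gaussian_noise helper is not shown, so plain NumPy is used in its place):

import numpy as np

data = np.random.rand(8, 8, 8).astype(np.float32)
gaussian_noise = {"prob": 0.5, "sigma": 0.05}

# apply the augmentation only if a uniform draw falls below the configured probability
if gaussian_noise["prob"] > np.random.random():
    data = data + np.random.normal(0.0, gaussian_noise["sigma"], data.shape).astype(np.float32)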
Example #5
def main(pred_dir,
         config,
         split='test',
         overlap_factor=1,
         preprocess_method=None):
    padding = [16, 16, 8]
    prediction2_dir = os.path.abspath(
        os.path.join(config['base_dir'], 'predictions2', split))
    model = load_old_model(get_last_model_path(config["model_file"]))
    # 'opts' is a module-level options object (not defined in this snippet)
    with open(os.path.join(opts.config_dir, 'norm_params.json'), 'r') as f:
        norm_params = json.load(f)

    for sample_folder in glob(os.path.join(pred_dir, split, '*')):
        mask_path = os.path.join(sample_folder, 'prediction.nii.gz')
        truth_path = os.path.join(sample_folder, 'truth.nii.gz')

        subject_id = Path(sample_folder).name
        dest_folder = os.path.join(prediction2_dir, subject_id)
        Path(dest_folder).mkdir(parents=True, exist_ok=True)

        truth = nib.load(truth_path)
        nib.save(truth, os.path.join(dest_folder, Path(truth_path).name))

        mask = nib.load(mask_path)
        mask = process_pred(mask.get_fdata(), gaussian_std=0.5, threshold=0.5)
        bbox_start, bbox_end = find_bounding_box(mask)
        check_bounding_box(mask, bbox_start, bbox_end)
        if padding is not None:
            bbox_start = np.maximum(bbox_start - padding, 0)
            bbox_end = np.minimum(bbox_end + padding, mask.shape)
        print("BBox: {}-{}".format(bbox_start, bbox_end))

        # 'original_data_folder' is a module-level path to the raw data
        volume = nib.load(
            os.path.join(original_data_folder, subject_id, 'volume.nii'))
        orig_volume_shape = np.array(volume.shape)
        volume = cut_bounding_box(volume, bbox_start,
                                  bbox_end).get_fdata().astype(float)

        if preprocess_method is not None:
            print('Applying preprocess by {}...'.format(preprocess_method))
            if preprocess_method == 'window_1_99':
                volume = window_intensities_data(volume)
            else:
                raise Exception(
                    'Unknown preprocess: {}'.format(preprocess_method))

        if norm_params is not None and any(norm_params.values()):
            volume = normalize_data(volume,
                                    mean=norm_params['mean'],
                                    std=norm_params['std'])

        prediction = patch_wise_prediction(model=model,
                                           data=np.expand_dims(volume, 0),
                                           patch_shape=config["patch_shape"] +
                                           [config["patch_depth"]],
                                           overlap_factor=overlap_factor)
        prediction = prediction.squeeze()

        padding2 = list(zip(bbox_start, orig_volume_shape - bbox_end))
        print(padding2)
        prediction = np.pad(prediction,
                            padding2,
                            mode='constant',
                            constant_values=0)
        assert all(
            [s1 == s2 for s1, s2 in zip(prediction.shape, orig_volume_shape)])
        prediction = get_image(prediction)
        nib.save(prediction, os.path.join(dest_folder, Path(mask_path).name))
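find_bounding_box and cut_bounding_box are project helpers; a minimal NumPy sketch of the same idea (tight bounding box of a binary mask, padded and clamped to the volume) might look like this:

import numpy as np

mask = np.zeros((32, 32, 16), dtype=np.uint8)
mask[10:20, 12:22, 4:10] = 1

coords = np.argwhere(mask > 0)
bbox_start = coords.min(axis=0)
bbox_end = coords.max(axis=0) + 1                       # half-open box
padding = np.array([16, 16, 8])
bbox_start = np.maximum(bbox_start - padding, 0)
bbox_end = np.minimum(bbox_end + padding, mask.shape)

cropped = mask[bbox_start[0]:bbox_end[0],
               bbox_start[1]:bbox_end[1],
               bbox_start[2]:bbox_end[2]]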
Example #6
    test_data = np.asarray([data_file.root.data[sid]])
    test_truth_data = np.asarray([data_file.root.truth[sid]])

    # Get all predictions
    for i, model in enumerate(step_1_networks):
        config = step_1_configs[i]
        # step 1 - no use of context for prediction
        if config["use_augmentations"]: # TODO - add this key to configs, default False
            prediction = predict_augment(data=test_data, model=model, overlap_factor=overlap_factor,
                                         patch_shape=config["patch_shape"])
        else:
            prediction, _ = \
                patch_wise_prediction(model=model, data=test_data, overlap_factor=overlap_factor,
                                      patch_shape=config["patch_shape"], permute=config["augment"]["permute"])
        prediction = prediction.squeeze()
        prediction_image = get_image(prediction) # NIB format
        filename = os.path.join(output_prediction_dir, subject_id, "prediction_{}.nii.gz".format(i))
        prediction_image.to_filename(filename)

        if i > 0:
            avg_pred += prediction
        else:
            avg_pred = prediction

    # Get the final averaged prediction and write to file
    avg_pred /= n_step_1_models
    nib.save(nib.Nifti1Image(avg_pred, prediction_image.affine, header=prediction_image.header),
             os.path.join(output_prediction_dir, subject_id, f'averaged_prediction.nii'))
    pred_storage[sid] = np.asarray(avg_pred).astype(float)  # np.float was removed from NumPy

    bin_avg_pred = adapted_postprocess_pred(os.path.join(output_prediction_dir, subject_id, f'averaged_prediction.nii'))
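The running-sum average over the step-1 models (avg_pred) is equivalent to stacking the per-model predictions and taking the mean; a toy sketch:

import numpy as np

predictions = [np.random.rand(8, 8, 4) for _ in range(3)]   # stand-ins for per-model outputs

avg_pred = predictions[0].copy()
for p in predictions[1:]:
    avg_pred += p
avg_pred /= len(predictions)

assert np.allclose(avg_pred, np.mean(np.stack(predictions), axis=0))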
Example #7
def run_validation_case(data_index,
                        output_dir,
                        model,
                        data_file,
                        training_modalities,
                        patch_shape,
                        overlap_factor=0,
                        permute=False,
                        prev_truth_index=None,
                        prev_truth_size=None,
                        is3d=False,
                        pred_index=None,
                        pred_size=None,
                        use_augmentations=False,
                        scale_xy=None,
                        resolution_file=''):
    """
    Runs a test case and writes predicted images to file.
    :param data_index: Index from of the list of test cases to get an image prediction from.
    :param output_dir: Where to write prediction images.
    :param output_label_map: If True, will write out a single image with one or more labels. Otherwise outputs
    the (sigmoid) prediction values from the model.
    :param threshold: If output_label_map is set to True, this threshold defines the value above which is 
    considered a positive result and will be assigned a label.  
    :param labels:
    :param training_modalities:
    :param data_file:
    :param scale_xy: wether to attempt and scale the image to main resolution
    :param resolution_file: a file containing a dict of all existing scans' resolutions
    :param model:
    """
    cur_subject_id = data_file.root.subject_ids[data_index].decode()

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    test_data = np.asarray([data_file.root.data[data_index]])
    #test_data = scale_data(test_data, cur_subject_id, dict_pkl=resolution_file, scale_xy=scale_xy)
    if prev_truth_index is not None:
        test_truth_data = np.asarray([data_file.root.truth[data_index]])
        #test_truth_data = scale_data(test_truth_data, cur_subject_id, dict_pkl=resolution_file, scale_xy=scale_xy)
    else:
        test_truth_data = None

    if pred_index is not None:
        test_pred_data = np.asarray([data_file.root.pred[data_index]])
        #test_pred_data = scale_data(test_pred_data, cur_subject_id, dict_pkl=resolution_file, scale_xy=scale_xy)
    else:
        test_pred_data = None

    # for i, modality in enumerate(training_modalities):
    #     image = get_image(test_data[i])
    #     image.to_filename(os.path.join(output_dir, "data_{0}.nii.gz".format(modality)))
    try:
        test_truth = np.asarray([data_file.root.truth[data_index]])
        test_truth = get_image(
            scale_data(test_truth,
                       cur_subject_id,
                       dict_pkl=resolution_file,
                       scale_xy=scale_xy))
        test_truth.to_filename(os.path.join(output_dir, "truth.nii.gz"))
    except Exception:  # truth may be unavailable for this case; skip saving it
        pass

    if patch_shape == test_data.shape[-3:]:
        print("Warning - went in where it wasn't expected!!!!!")
        prediction = predict(model, test_data, permute=permute)
    else:
        if use_augmentations:
            prediction = predict_augment(model=model,
                                         data=test_data,
                                         overlap_factor=overlap_factor,
                                         patch_shape=patch_shape,
                                         permute=permute,
                                         truth_data=test_truth_data,
                                         prev_truth_index=prev_truth_index,
                                         prev_truth_size=prev_truth_size,
                                         pred_data=test_pred_data,
                                         pred_index=pred_index,
                                         pred_size=pred_size,
                                         is3d=is3d)
        else:
            prediction, prediction_var = \
                patch_wise_prediction(model=model, data=test_data, overlap_factor=overlap_factor,
                                      patch_shape=patch_shape, permute=permute, truth_data=test_truth_data,
                                      prev_truth_index=prev_truth_index, prev_truth_size=prev_truth_size,
                                      pred_data=test_pred_data, pred_index=pred_index, pred_size=pred_size, is3d=is3d)
    # if prediction.shape[-1] > 1:
    #     prediction = prediction[..., 1]
    prediction = prediction.squeeze()
    prediction_image = get_image(prediction)
    # prediction_var = prediction_var.squeeze()
    # prediction_var_image = get_image(prediction_var)

    name_counter = 0
    if isinstance(prediction_image, list):
        for i, image in enumerate(prediction_image):
            filename = os.path.join(output_dir,
                                    "prediction_{0}.nii.gz".format(i + 1))
            while os.path.exists(filename):
                name_counter += 1
                filename = os.path.join(
                    output_dir,
                    "prediction_{0}_{1}.nii.gz".format(i + 1, name_counter))
            image.to_filename(filename)
    else:
        filename = os.path.join(output_dir, "prediction.nii.gz")
        # var_fname = os.path.join(output_dir, "prediction_variance.nii.gz")
        while os.path.exists(filename):
            name_counter += 1
            filename = os.path.join(
                output_dir, "prediction_{0}.nii.gz".format(name_counter))
            var_fname = os.path.join(
                output_dir,
                "prediction_variance_{0}.nii.gz".format(name_counter))
        print("Saving to {}".format(filename))
        prediction_image.to_filename(filename)
        # prediction_var_image.to_filename(var_fname)
    return filename
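The collision-free naming loop used in Examples #2 and #7, isolated into a small self-contained helper (output_dir is whatever folder the caller passes in):

import os

def unique_prediction_path(output_dir, stem="prediction", ext=".nii.gz"):
    # try "prediction.nii.gz", then "prediction_1.nii.gz", "prediction_2.nii.gz", ...
    filename = os.path.join(output_dir, stem + ext)
    counter = 0
    while os.path.exists(filename):
        counter += 1
        filename = os.path.join(output_dir, "{0}_{1}{2}".format(stem, counter, ext))
    return filename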
Example #8
def multi_class_prediction(prediction, affine):
    prediction_images = []
    for i in range(prediction.shape[1]):
        # build one image per class channel, keeping the original affine
        prediction_images.append(get_image(prediction[0, i], affine))
    return prediction_images
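With plain nibabel, splitting a one-sample, multi-channel prediction into per-class images looks roughly like this (identity affine and toy shapes assumed):

import numpy as np
import nibabel as nib

prediction = np.random.rand(1, 3, 8, 8, 4)     # (sample, class, x, y, z)
affine = np.eye(4)

prediction_images = [nib.Nifti1Image(prediction[0, i], affine)
                     for i in range(prediction.shape[1])]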
Example #9
def predict_augment(model,
                    data,
                    patch_shape,
                    num_augments=5,
                    overlap_factor=0,
                    batch_size=32,
                    is3d=False,
                    permute=False,
                    truth_data=None,
                    prev_truth_index=None,
                    prev_truth_size=None,
                    pred_data=None,
                    pred_index=None,
                    pred_size=None,
                    specific_slice=None):

    data_max = data.max()
    data_min = data.min()
    data = data.squeeze()
    curr_truth_data = None
    curr_pred_data = None

    order = 2
    cval = np.percentile(data, q=1)

    init_prediction, _ = \
            patch_wise_prediction(model=model, data=data[np.newaxis, ...], overlap_factor=overlap_factor,
                                  batch_size=batch_size, patch_shape=patch_shape, permute=permute,
                                  truth_data=truth_data, prev_truth_index=prev_truth_index,
                                  prev_truth_size=prev_truth_size, pred_data=pred_data, pred_index=pred_index,
                                  pred_size=pred_size, is3d=is3d)

    predictions = [init_prediction.squeeze()]
    for i in range(num_augments):
        ### pixel-wise augmentations - don't need to apply to truth or predictions
        val_range = data_max - data_min
        contrast_min_val = data_min + 0.10 * np.random.uniform(-1,
                                                               1) * val_range
        contrast_max_val = data_max + 0.10 * np.random.uniform(-1,
                                                               1) * val_range
        # curr_data = contrast_augment(data, contrast_min_val, contrast_max_val)
        curr_data = data.copy()

        ### spatial augmentations - need to apply to truth or predictions
        # rotate_factor = np.random.uniform(-30, 30)
        rotate_factor = 0
        to_flip = np.arange(0, 3)[np.random.choice([True, False], size=3)]
        to_transpose = np.random.choice([True, False])
        scale_factors = np.random.normal(1, [0.15, 0.15, 0], 3)

        curr_data = flip_it(curr_data, to_flip)
        curr_affine = scale_image(np.eye(4), scale_factors)
        curr_data = resample_to_img(get_image(curr_data, curr_affine),
                                    get_image(curr_data),
                                    interpolation="continuous",
                                    copy=False,
                                    clip=True).get_fdata()
        curr_truth_data = None
        curr_pred_data = None
        if truth_data is not None:
            curr_truth_data = truth_data.copy().squeeze()
            curr_truth_data = flip_it(curr_truth_data, to_flip)
            curr_truth_data = resample_to_img(get_image(
                curr_truth_data, curr_affine),
                                              get_image(curr_truth_data),
                                              interpolation="continuous",
                                              copy=False,
                                              clip=True).get_fdata()
        if pred_data is not None:
            curr_pred_data = pred_data.copy().squeeze()
            curr_pred_data = flip_it(curr_pred_data, to_flip)
            curr_pred_data = resample_to_img(get_image(curr_pred_data,
                                                       curr_affine),
                                             get_image(curr_pred_data),
                                             interpolation="continuous",
                                             copy=False,
                                             clip=True).get_fdata()

        if to_transpose:
            curr_data = curr_data.transpose([1, 0, 2])
            if truth_data is not None:
                curr_truth_data = curr_truth_data.transpose([1, 0, 2])
            if pred_data is not None:
                curr_pred_data = curr_pred_data.transpose([1, 0, 2])

        curr_data = ndimage.rotate(curr_data,
                                   rotate_factor,
                                   order=order,
                                   reshape=False,
                                   mode='constant',
                                   cval=cval)
        if truth_data is not None:
            curr_truth_data = ndimage.rotate(curr_truth_data,
                                             rotate_factor,
                                             order=order,
                                             reshape=False,
                                             mode='constant',
                                             cval=cval)
            curr_truth_data = curr_truth_data[np.newaxis, ...]
        if pred_data is not None:
            curr_pred_data = ndimage.rotate(curr_pred_data,
                                            rotate_factor,
                                            order=order,
                                            reshape=False,
                                            mode='constant',
                                            cval=cval)
            curr_pred_data = curr_pred_data[np.newaxis, ...]

        curr_prediction, _ = \
            patch_wise_prediction(model=model, data=curr_data[np.newaxis, ...], overlap_factor=overlap_factor,
                                  batch_size=batch_size, patch_shape=patch_shape, permute=permute,
                                  truth_data=curr_truth_data, prev_truth_index=prev_truth_index,
                                  prev_truth_size=prev_truth_size, pred_data=curr_pred_data, pred_index=pred_index,
                                  pred_size=pred_size, is3d=is3d)
        curr_prediction = curr_prediction.squeeze()

        curr_prediction = ndimage.rotate(curr_prediction,
                                         -rotate_factor,
                                         order=0,
                                         reshape=False,
                                         mode='constant',
                                         cval=0)
        if to_transpose:
            curr_prediction = curr_prediction.transpose([1, 0, 2])
        rev_scale_factors = [1 / x for x in scale_factors]
        curr_affine = scale_image(np.eye(4), rev_scale_factors)
        curr_prediction = resample_to_img(get_image(curr_prediction,
                                                    curr_affine),
                                          get_image(curr_prediction),
                                          interpolation="continuous",
                                          copy=False,
                                          clip=True).get_fdata()
        curr_prediction = flip_it(curr_prediction, to_flip)

        predictions += [curr_prediction.squeeze()]

    res = np.stack(predictions, axis=0)
    return res
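predict_augment averages predictions over randomly flipped/transposed/scaled copies, undoing each transform on the prediction before stacking; here is a stripped-down, self-contained sketch of flip-only test-time augmentation (the "model" is a stand-in lambda, not the repository's network):

import numpy as np

def tta_flip_predict(volume, predict_fn, num_augments=4, seed=0):
    rng = np.random.default_rng(seed)
    predictions = [predict_fn(volume)]                      # un-augmented pass
    for _ in range(num_augments):
        flip_axes = tuple(np.arange(3)[rng.choice([True, False], size=3)])
        aug = np.flip(volume, axis=flip_axes) if flip_axes else volume
        pred = predict_fn(aug)
        if flip_axes:
            pred = np.flip(pred, axis=flip_axes)            # undo the flip on the prediction
        predictions.append(pred)
    return np.stack(predictions, axis=0)                    # caller can mean/median over axis 0

vol = np.random.rand(8, 8, 4).astype(np.float32)
stacked = tta_flip_predict(vol, predict_fn=lambda v: v)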
Example #10
def run_validation_case(data_index,
                        output_dir,
                        model,
                        data_file,
                        training_modalities,
                        patch_shape,
                        overlap_factor=0,
                        permute=False,
                        prev_truth_index=None,
                        prev_truth_size=None,
                        use_augmentations=False):
    """
    Runs a test case and writes predicted images to file.
    :param data_index: Index from of the list of test cases to get an image prediction from.
    :param output_dir: Where to write prediction images.
    :param output_label_map: If True, will write out a single image with one or more labels. Otherwise outputs
    the (sigmoid) prediction values from the model.
    :param threshold: If output_label_map is set to True, this threshold defines the value above which is 
    considered a positive result and will be assigned a label.  
    :param labels:
    :param training_modalities:
    :param data_file:
    :param model:
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    test_data = np.asarray([data_file.root.data[data_index]])
    if prev_truth_index is not None:
        test_truth_data = np.asarray([data_file.root.truth[data_index]])
    else:
        test_truth_data = None

    for i, modality in enumerate(training_modalities):
        image = get_image(test_data[i])
        image.to_filename(
            os.path.join(output_dir, "data_{0}.nii.gz".format(modality)))

    test_truth = get_image(data_file.root.truth[data_index])
    test_truth.to_filename(os.path.join(output_dir, "truth.nii.gz"))

    if patch_shape == test_data.shape[-3:]:
        prediction = predict(model, test_data, permute=permute)
    else:
        if use_augmentations:
            prediction = predict_augment(data=test_data,
                                         model=model,
                                         overlap_factor=overlap_factor,
                                         patch_shape=patch_shape)
        else:
            prediction = \
                patch_wise_prediction(model=model, data=test_data, overlap_factor=overlap_factor,
                                      patch_shape=patch_shape,
                                      truth_data=test_truth_data, prev_truth_index=prev_truth_index,
                                      prev_truth_size=prev_truth_size)[np.newaxis]

    prediction = prediction.squeeze()
    prediction_image = get_image(prediction)
    if isinstance(prediction_image, list):
        for i, image in enumerate(prediction_image):
            image.to_filename(
                os.path.join(output_dir,
                             "prediction_{0}.nii.gz".format(i + 1)))
    else:
        filename = os.path.join(output_dir, "prediction.nii.gz")
        prediction_image.to_filename(filename)
    return filename