Example #1
    @property
    def y_pred_cv(self):
        if self.y_pred_cv_ is None:
            self.y_pred_cv_ = []
            for i, (train, test) in enumerate(self.split_):
                clf = self.classifiers_[i]
                if hasattr(clf, 'predict') and callable(clf.predict):
                    result = clf.predict(get_slice(self.X_, rows=test))
                elif hasattr(clf, 'predict_proba') and callable(
                        clf.predict_proba):
                    # Fall back to hard labels derived from probabilities.
                    proba = clf.predict_proba(get_slice(self.X_, rows=test))
                    result = self.classes_[np.argmax(proba, axis=1)]
                else:
                    raise ValueError(
                        'Base classifier does not have "predict" or '
                        '"predict_proba" callable.')
                if is_pandas(self.y_):
                    result = pd.Series(result,
                                       index=self.y_true_cv[i].index)
                self.y_pred_cv_.append(result)
        return self.y_pred_cv_
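The elif branch above reconstructs hard labels from the predicted probabilities. A standalone illustration of that mapping (assuming classes_ is ordered as in scikit-learn, where column j of predict_proba corresponds to classes_[j]):

import numpy as np

classes_ = np.array(['cat', 'dog'])
proba = np.array([[0.9, 0.1],
                  [0.2, 0.8]])

# Pick, per row, the class with the highest predicted probability.
labels = classes_[np.argmax(proba, axis=1)]
print(labels)  # ['cat' 'dog']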
Example #2
def plateau(k=5):
    factor = 2.3

    mu = np.array([-0.5, 1, 3.5, 7, 6]) - 3
    sigma = np.array([0.05, 0.5, 0.2, 2, 0.3]) * 1.2
    weights = np.array([0.2, 1, 0.4, 1, 3])
    indexes_2d = utils.normalize_to_indexes(low=[-10, -10],
                                            high=[10, 10],
                                            n=100)

    # mu = [0]
    # sigma = [1]
    # weights = [1]
    # indexes_2d = utils.normalize_to_indexes(low=[-4, -4], high=[4, 4], n=100)

    p = support_mixed_gaussian_2d(indexes_2d, mu, sigma, weights)

    # Raise the whole density by a constant offset proportional to its maximum.
    p_raised = p + np.max(p) * factor

    # Zero the raised density outside a circle of radius sqrt(32), leaving a
    # plateau with the mixture bumps on top.
    for ix, x in enumerate(indexes_2d[0]):
        for iy, y in enumerate(indexes_2d[1]):
            if x * x + y * y > 32:
                p_raised[ix][iy] = 0

    # plot
    #plotting.plot_combined(p_raised, indexes=indexes_2d, k=[3,7,10])

    levels = iso_levels.equi_prob_per_level(p_raised, k=k)
    levels2 = iso_levels.equi_value(p_raised, k=k)

    slice_ = utils.get_slice(p_raised, indexes_2d, 'y', 0)

    fig, ax = plt.subplots(2, 3, figsize=(3 * 5, 8))
    plotting.combined_2d(p_raised,
                         levels2,
                         x=indexes_2d[0],
                         y=indexes_2d[1],
                         slice_=slice_,
                         ax=ax[0])
    plotting.combined_2d(p_raised,
                         levels,
                         x=indexes_2d[0],
                         y=indexes_2d[1],
                         slice_=slice_,
                         ax=ax[1])
    fig.show()
    return fig
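The two level sets compared throughout these examples differ in how the k contour levels are chosen: equi_value spaces them evenly in density value, while equi_prob_per_level spaces them so each step encloses an equal share of probability mass. A minimal sketch of the two strategies (an assumption about how iso_levels behaves, not its actual implementation):

import numpy as np

def equi_value(p, k):
    # k levels spaced evenly between the smallest and largest density value.
    return np.linspace(p.min(), p.max(), k + 2)[1:-1]

def equi_prob_per_level(p, k):
    # Choose levels so the enclosed probability mass grows in equal steps.
    flat = np.sort(p.ravel())[::-1]
    cum = np.cumsum(flat)
    cum /= cum[-1]
    targets = np.linspace(0, 1, k + 2)[1:-1]
    return np.array([flat[np.searchsorted(cum, t)] for t in targets])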
Example #3
def data_file_with_kde(filepath,
                       kernel_bandwidth=None,
                       k=7,
                       usecols=None,
                       index_col=None):

    # get data
    df = pd.read_csv(filepath, index_col=index_col, usecols=usecols)
    data = df.values.transpose()

    # derive pdf
    mykernel = pdf_kernel(data, kernel_bandwidth=kernel_bandwidth)

    # get support
    indexes = utils.normalize_to_indexes(data=data)
    p = support_kde(mykernel, indexes)

    # plot
    #plotting.plot_combined(p, indexes[0], indexes[1], k=[3, 5, 7, 10])

    levels = iso_levels.equi_prob_per_level(p, k=k)
    levels2 = iso_levels.equi_value(p, k=k)

    # get index of max
    max_idx = np.unravel_index(np.argmax(p, axis=None), p.shape)
    slice_ = utils.get_slice(p, indexes, 'y', indexes[1][max_idx[0]])
    #slice_ = utils.get_slice(p, indexes, 'y', 68)

    # print('old embrace ratio: {}'.format(stats.embrace_ratio(levels2, p)))
    # print('new embrace ratio: {}'.format(stats.embrace_ratio(levels, p)))

    fig, ax = plt.subplots(2, 3, figsize=(3 * 5, 8))

    plotting.combined_2d(p,
                         levels2,
                         x=indexes[0],
                         y=indexes[1],
                         slice_=slice_,
                         ax=ax[0])
    plotting.combined_2d(p,
                         levels,
                         x=indexes[0],
                         y=indexes[1],
                         slice_=slice_,
                         ax=ax[1])
    fig.show()
    return fig
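pdf_kernel and support_kde are not shown in this listing. A plausible minimal version of that part of the pipeline, using scipy.stats.gaussian_kde purely as a stand-in for the original helpers:

import numpy as np
from scipy import stats

def pdf_kernel(data, kernel_bandwidth=None):
    # data has shape (n_dims, n_samples), which is what gaussian_kde expects.
    return stats.gaussian_kde(data, bw_method=kernel_bandwidth)

def support_kde(kernel, indexes):
    # Evaluate the KDE on the full 2D grid spanned by the two index vectors.
    xx, yy = np.meshgrid(indexes[0], indexes[1], indexing='ij')
    points = np.vstack([xx.ravel(), yy.ravel()])
    return kernel(points).reshape(xx.shape)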
Example #4
    @property
    def y_proba_cv(self):
        if self.y_proba_cv_ is None:
            self.y_proba_cv_ = []
            for i, (train, test) in enumerate(self.split_):
                clf = self.classifiers_[i]
                if hasattr(clf, 'predict_proba') and callable(
                        clf.predict_proba):
                    result = clf.predict_proba(get_slice(self.X_, rows=test))
                    if is_pandas(self.y_):
                        result = pd.DataFrame(result,
                                              index=self.y_true_cv[i].index)
                    self.y_proba_cv_.append(result)
                else:
                    raise ValueError(
                        'Base classifier does not have "predict_proba"'
                        ' callable.')
        return self.y_proba_cv_
Example #5
def broad_and_normal_gaussians(k=5):

    # works well!
    # mu = np.array([0, -5.5, 0, 5.5])
    # sigma = np.array([1, 10, 6, 10])
    # weights = np.array([1, 1, 0.5, 1])

    # more complex and still works well
    mu = np.array([[0, 0], [-5.5, -1], [0, 0], [5.5, 2], [-3, 4]])
    sigma = np.array([1, 10, 6, 10, 9])
    # Expand each scalar into an isotropic 2x2 covariance matrix.
    sigma = [[[s, 0], [0, s]] for s in sigma]

    weights = np.array([1, 1, 0.5, 1, 1])

    indexes_2d = utils.normalize_to_indexes(low=[-12, -10],
                                            high=[12, 10],
                                            n=100)

    p = support_mixed_gaussian_2d(indexes_2d,
                                  mu,
                                  sigma,
                                  weights,
                                  from_scalar=False)

    levels = iso_levels.equi_prob_per_level(p, k=k)
    levels2 = iso_levels.equi_value(p, k=k)

    slice_ = utils.get_slice(p, indexes_2d, 'y', 0)

    fig, ax = plt.subplots(2, 3, figsize=(3 * 5, 8))
    plotting.combined_2d(p,
                         levels2,
                         x=indexes_2d[0],
                         y=indexes_2d[1],
                         slice_=slice_,
                         ax=ax[0])
    plotting.combined_2d(p,
                         levels,
                         x=indexes_2d[0],
                         y=indexes_2d[1],
                         slice_=slice_,
                         ax=ax[1])
    fig.show()
    return fig
Example #6
def iris_kde(kernel_bandwidth=None, k=6):

    # get data
    from sklearn import datasets
    iris = datasets.load_iris()
    data = iris.data[:, :2].transpose()

    # derive pdf
    mykernel = pdf_kernel(data, kernel_bandwidth=kernel_bandwidth)

    # get support
    indexes = utils.normalize_to_indexes(data=data)
    p = support_kde(mykernel, indexes)

    # plot
    #plotting.plot_combined(p, indexes[0], indexes[1], k=[3, 5, 10])

    levels = iso_levels.equi_prob_per_level(p, k=k)
    levels2 = iso_levels.equi_value(p, k=k)

    slice_ = utils.get_slice(p, indexes, 'y', 3)
    #slice_ = None

    fig, ax = plt.subplots(2, 3, figsize=(3 * 5, 8))
    plotting.combined_2d(p,
                         levels2,
                         x=indexes[0],
                         y=indexes[1],
                         slice_=slice_,
                         ax=ax[0])
    plotting.combined_2d(p,
                         levels,
                         x=indexes[0],
                         y=indexes[1],
                         slice_=slice_,
                         ax=ax[1])
    fig.show()
    return fig
Example #7
    @property
    def y_true_cv(self):
        if self.y_true_cv_ is None:
            self.y_true_cv_ = []
            for train, test in self.split_:
                self.y_true_cv_.append(get_slice(self.y_, rows=test))
        return self.y_true_cv_
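Together with y_pred_cv and y_proba_cv above, this completes a trio of lazily cached, per-fold views of the cross-validation results; the underlying *_cv_ attributes are presumably cleared by the reset_cv() call inside fit (Example #8).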
Example #8
    def fit(self, X, y, sample_weight=None, **kwargs):
        """Fit the base estimator

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Training data.

        y : array-like, shape (n_samples,)
            Target values.

        sample_weight : array-like, shape (n_samples,) or None
            Sample weights. If None, then samples are equally weighted.

        kwargs :
            Extra args are passed to the base classifier for fitting.

        Returns
        -------
        self : object
            Returns an instance of self.
        """
        base_X, base_y = X, y
        #X, y = check_X_y(X, y, accept_sparse=['csc', 'csr', 'coo'],
        #                 force_all_finite=False)
        #X, y = indexable(X, y)
        le = LabelBinarizer().fit(y)
        self.classes_ = le.classes_

        # Check that each cross-validation fold can have at least one
        # example per class
        if isinstance(self.cv, int):
            n_folds = self.cv
        elif hasattr(self.cv, "n_folds"):
            n_folds = self.cv.n_folds
        else:
            n_folds = None
        if n_folds and \
                np.any([np.sum(y == class_) < n_folds for class_ in
                        self.classes_]):
            raise ValueError("Requesting %d-fold cross-validation but provided"
                             " less than %d examples for at least one class." %
                             (n_folds, n_folds))

        classifiers = []
        base_estimator = self.base_estimator
        super_class = self.super_class

        if self.cv == "prefit":
            if super_class is not None:
                super_estimator = super_class(base_estimator,
                                              **self.super_params)
                super_fit_parameters = signature(
                    super_estimator.fit).parameters
                if sample_weight is not None and 'sample_weight' in super_fit_parameters:
                    super_estimator.fit(X, y, sample_weight)
                else:
                    super_estimator.fit(X, y)
                classifiers.append(super_estimator)
            else:
                classifiers.append(base_estimator)
            # A single pseudo-split (all rows train, nothing held out) so
            # that self.split_ unpacks like the cross-validated case.
            splits = [(list(range(len(X))), [])]
        else:
            cv = check_cv(self.cv, y, classifier=True)
            fit_parameters = signature(base_estimator.fit).parameters
            estimator_name = type(base_estimator).__name__
            if (sample_weight is not None
                    and "sample_weight" not in fit_parameters):
                warnings.warn("%s does not support sample_weight. Samples"
                              " weights are only used for the calibration"
                              " itself." % estimator_name)
                base_estimator_sample_weight = None
            else:
                if sample_weight is not None:
                    sample_weight = check_array(sample_weight, ensure_2d=False)
                    check_consistent_length(y, sample_weight)
                base_estimator_sample_weight = sample_weight
            # Iterate over a copy: removing keys from kwargs while iterating
            # it directly would raise a RuntimeError.
            for k in list(kwargs):
                if k not in fit_parameters:
                    warnings.warn('%s does not support %s, dropping.' %
                                  (estimator_name, k))
                    kwargs.pop(k)
            splits = list(cv.split(X, y))
            for i, (train, test) in enumerate(splits):
                this_estimator = clone(base_estimator)
                X_train = get_slice(X, rows=train)
                y_train = get_slice(y, rows=train)
                X_test = get_slice(X, rows=test)
                y_test = get_slice(y, rows=test)

                if callable(self.prefit_callback):
                    if self.prefit_callback(X_train, y_train, X_test, y_test,
                                            i, **self.prefit_params) is False:
                        raise ValueError('Fitting aborted by prefit_callback')

                if base_estimator_sample_weight is not None:
                    this_estimator.fit(
                        X_train,
                        y_train,
                        sample_weight=base_estimator_sample_weight[train],
                        **kwargs)
                else:
                    this_estimator.fit(X_train, y_train, **kwargs)

                if super_class is not None:
                    super_params = {**self.super_params}
                    if 'classes' in signature(super_class.__init__).parameters:
                        super_params['classes'] = self.classes_
                    super_estimator = super_class(this_estimator,
                                                  **super_params)

                    super_fit_parameters = signature(
                        super_estimator.fit).parameters
                    if sample_weight is not None and 'sample_weight' in super_fit_parameters:
                        super_estimator.fit(X_test, y_test, sample_weight)
                    else:
                        super_estimator.fit(X_test, y_test)

                    if callable(self.postfit_callback):
                        if self.postfit_callback(
                                X_train, y_train, X_test, y_test, i,
                                **self.postfit_params) is False:
                            raise ValueError(
                                'Fitting aborted by postfit_callback')
                    classifiers.append(super_estimator)
                else:
                    if callable(self.postfit_callback):
                        if self.postfit_callback(
                                X_train, y_train, X_test, y_test, i,
                                **self.postfit_params) is False:
                            raise ValueError(
                                'Fitting aborted by postfit_callback')

                    classifiers.append(this_estimator)

        self.reset_cv()
        self.X_, self.y_ = base_X, base_y
        self.classifiers_ = classifiers
        self.split_ = splits

        return self
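Putting the pieces together, the intended usage looks roughly like this. CrossValidatedClassifier is a hypothetical name for the wrapper class these methods belong to; the snippets do not show its real name:

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)

# CrossValidatedClassifier is a hypothetical stand-in for the wrapper class.
clf = CrossValidatedClassifier(base_estimator=LogisticRegression(max_iter=1000),
                               cv=5)
clf.fit(X, y)

# One entry per CV fold, each aligned with that fold's test split.
for y_true, y_pred in zip(clf.y_true_cv, clf.y_pred_cv):
    print((y_true == y_pred).mean())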
Example #9
def basic_idea():
    """Creates plot for initial explanatory and motivating example for paper.

    Also provide some search capabilities, i.e. allows to play with parameters to find exemplary distributions.
    """

    mu = [0, 1, 3.5, 7]
    sigma = [0.05, 0.5, 0.7, 2]
    weights = [0.6, 1, 0.7, 1]

    index_1d = utils.normalize_to_indexes(low=[-1], high=[10], n=2500, d=1)[0]
    p_single_1d = [
        support_gaussian_1d(index_1d, m, s) for m, s in zip(mu, sigma)
    ]
    p_mixture_1d = sum(map(operator.mul, weights, p_single_1d))
    levels = iso_levels.equi_prob_per_level(p_mixture_1d, k=7)
    levels2 = iso_levels.equi_value(p_mixture_1d, k=7)

    # I used this to identify suitable mu, sigma and weights
    # figure = plt.figure(figsize=(9, 4))
    # ax = figure.add_subplot(121)
    # plotting.density(levels, gp1d, ax=ax)
    # ax = figure.add_subplot(122)
    # plotting.density(levels2, gp1d, ax=ax)

    fig1d, ax = plt.subplots(2, 3, figsize=(3 * 5, 8))
    plotting.combined_1d(p_mixture_1d, levels2, index_1d, ax[0])
    plotting.combined_1d(p_mixture_1d, levels, index_1d, ax[1])
    fig1d.show()

    indexes_2d = utils.normalize_to_indexes(low=[-1, -3], high=[10, 3], n=100)
    p_mixture_2d = support_mixed_gaussian_2d(indexes_2d, mu, sigma, weights)

    levels = iso_levels.equi_prob_per_level(p_mixture_2d, k=7)
    levels2 = iso_levels.equi_value(p_mixture_2d, k=7)

    slice_idx = int(len(indexes_2d[1]) / 2)
    slice_val = indexes_2d[1][slice_idx]
    slice_ = utils.get_slice(p_mixture_2d, indexes_2d, 'y', slice_val)

    print('old embrace ratio: {}'.format(
        stats.embrace_ratio(levels2, p_mixture_2d)))
    print('new embrace ratio: {}'.format(
        stats.embrace_ratio(levels, p_mixture_2d)))

    # I used this to identify k=7 as particularly interesting
    #plotting.plot_combined(gp2d, indexes=gindex, k=list(range(2,10)))
    #plotting.plot_combined(gp2d, indexes=gindex, k=[7])

    fig2d, ax = plt.subplots(2, 3, figsize=(3 * 5, 8))

    plotting.combined_2d(p_mixture_2d,
                         levels2,
                         x=indexes_2d[0],
                         y=indexes_2d[1],
                         slice_=slice_,
                         ax=ax[0])
    plotting.combined_2d(p_mixture_2d,
                         levels,
                         x=indexes_2d[0],
                         y=indexes_2d[1],
                         slice_=slice_,
                         ax=ax[1])
    fig2d.show()
    return fig1d, fig2d
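support_gaussian_1d is another helper not shown in this listing; the mixture is then just the weighted sum of the per-component densities. A plausible stand-in (an assumption, not the original implementation):

import numpy as np

def support_gaussian_1d(index, mu, sigma):
    # Gaussian pdf with mean mu and std sigma, evaluated on the 1D grid.
    return np.exp(-0.5 * ((index - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))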
Example #10
def get_disparity_map(left_img_arr, right_img_arr, edge_offset):
    # Create a new array to hold disparities.
    disparities = np.zeros((len(left_img_arr) - 2 * edge_offset,
                            len(left_img_arr[0]) - 2 * edge_offset))

    # Calculate the number of usable rows/cols, accounting for minimum offset due to window size.
    num_valid_rows = len(left_img_arr) - 1 - edge_offset
    num_valid_cols = len(left_img_arr[0]) - 1 - edge_offset

    for row in range(edge_offset, num_valid_rows):
        # Instead of iterating through every column linearly, we have a queue of columns to iterate.
        columns = []
        # Stores columns from the RIGHT that have been matched, alongside with info about quality.
        matches = {}
        # Initially, fill with all columns.
        columns += list(range(edge_offset, num_valid_cols))

        # Iterate through all columns that still need a match.
        while columns:
            current_column = columns.pop()

            # Extract a window around the target pixel. We'll be reusing it, so we only pull it once.
            left_window = get_slice(row, current_column, edge_offset,
                                    left_img_arr)

            # Info about the best match.
            max_correlation = 0
            max_correlation_col = None

            # With a fixed left pixel, iterate through every potential match in the right
            for potential_matching_col in range(
                    current_column - 1,
                    max(current_column - ASSUMED_MAX_DISPARITY,
                        edge_offset - 1), -1):
                right_window = get_slice(row, potential_matching_col,
                                         edge_offset, right_img_arr)
                correlation = get_correlation(left_window, right_window)

                if max_correlation_col is None or correlation > max_correlation:
                    # Enforce the uniqueness constraint: if this beats an
                    # existing match for this right column, evict the old
                    # match and requeue its left column.
                    if (potential_matching_col not in matches
                            or correlation >
                            matches[potential_matching_col]['correlation']):
                        if potential_matching_col in matches:
                            columns.append(
                                matches[potential_matching_col]['left_col'])
                        if max_correlation_col is not None:
                            matches.pop(max_correlation_col)
                        max_correlation_col = potential_matching_col
                        max_correlation = correlation
                        matches[max_correlation_col] = {
                            'left_col': current_column,
                            'correlation': max_correlation
                        }

            # Assign the disparity of this pixel.
            distance = current_column - max_correlation_col if max_correlation_col is not None else 0
            disparities[row - edge_offset,
                        current_column - edge_offset] = distance
    return disparities
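The requeueing logic above enforces the stereo uniqueness constraint: each right-image column may be matched by at most one left-image column, and a displaced left column goes back on the queue to search again. The helpers get_slice and get_correlation are not shown; plausible minimal versions (assumptions, not the original code):

import numpy as np

def get_slice(row, col, edge_offset, img_arr):
    # Square window of side 2 * edge_offset + 1 centered on (row, col).
    return img_arr[row - edge_offset:row + edge_offset + 1,
                   col - edge_offset:col + edge_offset + 1]

def get_correlation(left_window, right_window):
    # Normalized cross-correlation of the two windows.
    l = left_window - left_window.mean()
    r = right_window - right_window.mean()
    denom = np.sqrt((l * l).sum() * (r * r).sum())
    return (l * r).sum() / denom if denom > 0 else 0.0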
Example #11
def plot_local_interp(mri_sequences, save_dir, interpretation, prediction,
                      classes, slices, axis, dpi, rotate=True, exclude=[],
                      vmax=None):
    """
    Creates the plots with the local interpretation.

    Parameters
    ----------
    mri_sequences: dict
        The MRI images for each sequence.
    save_dir: string
        Path to the directory where results will be saved.
    interpretation: dict
        Dictionary with the volumes of the interpretation.
    prediction: numpy array
        3D image with the segmentation.
    classes: list
        The classes to be plotted.
    slices: list
        The chosen slices.
    axis: int
        Axis along which we extract the slices.
    dpi: int
        The dpi of the images that will be plotted.
    rotate: boolean
        Whether the image should be rotated for plotting.
    exclude: list
        Classes to exclude from the plot.
    vmax: int
        The value to take as maximum while plotting.

    Returns
    -------
    None
        All plots are saved to 'local_interpretation.pdf' in save_dir.
    """
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    n_lines = len(classes) + 1 - len(exclude)
    n_cols = len(mri_sequences.keys()) + 1

    with PdfPages(os.path.join(save_dir, 'local_interpretation.pdf')) as pdf:

        for slice_n in slices:

            # Renamed from `slices` so the function argument isn't shadowed.
            seq_slices = get_slice_from_volumes(volumes=mri_sequences,
                                                slice_n=slice_n, axis=axis,
                                                rotate=rotate)
            seg_slice = get_slice(vol=prediction, slice_n=slice_n, axis=axis,
                                  rotate=rotate)

            plt.figure(figsize=(20, 9))
            plt.suptitle('Slice ' + str(slice_n) + ' of the ' +
                         _AXIS[axis] + ' plane')
            sub_id = 1

            for seq in sorted(seq_slices.keys()):
                plt.subplot(n_lines, n_cols, sub_id)
                plt.imshow(seq_slices[seq], cmap='gray')
                plt.title(seq)
                plt.gca().set_axis_off()
                sub_id += 1

            plt.subplot(n_lines, n_cols, sub_id)
            sub_id += 1

            plt.imshow(seq_slices['T1c'], alpha=1.0, cmap='gray')
            plt.gca().set_axis_off()
            # plt.hold was removed in Matplotlib 3; overlaying is the default.
            # Copy the colormap so set_bad doesn't mutate the shared instance.
            over_cmap = plt.get_cmap('jet').copy()
            over_cmap.set_bad(alpha=0)
            seg_slice[seg_slice <= 0] = np.nan
            plt.imshow(seg_slice, cmap=over_cmap)
            plt.gca().set_axis_off()
            plt.title('Segmentation')

            for c in sorted(interpretation.keys()):
                if c not in exclude:
                    t = 0
                    for seq in sorted(interpretation[c].keys()):
                        plt.subplot(n_lines, n_cols, sub_id)
                        sub_id += 1

                        plt.imshow(seq_slices['T1c'], alpha=1.0, cmap='gray')
                        plt.gca().set_axis_off()
                        over_cmap = plt.get_cmap('jet').copy()
                        over_cmap.set_bad(alpha=0)
                        tmp_interp = interpretation[c][seq]
                        tmp_interp = get_slice(vol=tmp_interp,
                                               slice_n=slice_n, axis=axis,
                                               rotate=rotate)
                        tmp_interp[tmp_interp <= 0] = np.nan
                        plt.imshow(tmp_interp, cmap=over_cmap, vmin=0, vmax=1)
                        cbar = plt.colorbar(shrink=.9, format='%0.2f')
                        cbar.ax.tick_params(labelsize=8)

                        t += 1
                        if t == 3:
                            plt.title(classes[c - 1])

                    tmp_seg = deepcopy(seg_slice)
                    tmp_seg[tmp_seg != c] = 0

                    plt.subplot(n_lines, n_cols, sub_id)
                    sub_id += 1
                    plt.imshow(seq_slices['T1c'], alpha=1.0, cmap='gray')
                    plt.gca().set_axis_off()
                    over_cmap = plt.get_cmap('jet').copy()
                    over_cmap.set_bad(alpha=0)
                    tmp_seg[tmp_seg <= 0] = np.nan
                    plt.imshow(tmp_seg, cmap=over_cmap, vmin=0,
                               vmax=prediction.max())
                    plt.gca().set_axis_off()

            pdf.savefig()
            plt.close()
Example #12
def run_once(volpath, models):
    '''
    Runs our best model on a provided volume and saves the mask,
    in a self-contained manner.
    '''
    print(
        "\nALPHA VERSION: For this version of the code, the provided volume"
        " should return slices in the following way for optimal performance:"
    )
    print("volume[0, :, :] sagittal, eyes facing down")
    print("volume[:, 0, :] coronal")
    print("volume[:, :, 0] axial, with eyes facing right\n")
    begin = time.time()
    save_path = volpath + "_e2dhipmask.nii.gz"
    device = get_device()
    orientations = ["sagital", "coronal", "axial"]
    CROP_SHAPE = 160
    slice_transform = Compose([CenterCrop(CROP_SHAPE, CROP_SHAPE), ToTensor()])

    sample_v = normalizeMri(nib.load(volpath).get_fdata().astype(np.float32))
    shape = sample_v.shape
    sum_vol_total = torch.zeros(shape)

    for model in models.values():
        model.eval()
        model.to(device)

    print("Performing segmentation...")
    for i, o in enumerate(orientations):
        try:
            slice_shape = myrotate(get_slice(sample_v, 0, o), 90).shape
            for j in range(shape[i]):
                # E2D
                ts = np.zeros((3, slice_shape[0], slice_shape[1]),
                              dtype=np.float32)
                for ii, jj in enumerate(range(j - 1, j + 2)):
                    if jj < 0:
                        jj = 0
                    elif jj == shape[i]:
                        jj = shape[i] - 1

                    if i == 0:
                        ts[ii] = myrotate(sample_v[jj, :, :], 90)
                    elif i == 1:
                        ts[ii] = myrotate(sample_v[:, jj, :], 90)
                    elif i == 2:
                        ts[ii] = myrotate(sample_v[:, :, jj], 90)

                s, _ = slice_transform(ts, ts[1])  # workaround: no mask available, reuse the center slice
                s = s.to(device)

                probs = models[o](s.unsqueeze(0))

                cpup = probs.squeeze().detach().cpu()
                finalp = torch.from_numpy(myrotate(
                    cpup.numpy(), -90)).float()  # back to volume orientation

                # Add to final consensus volume, uses original orientation/shape
                if i == 0:
                    toppad = shape[1] // 2 - CROP_SHAPE // 2
                    sidepad = shape[2] // 2 - CROP_SHAPE // 2

                    tf = 1 if shape[1] % 2 == 1 else 0
                    sf = 1 if shape[2] % 2 == 1 else 0
                    pad = F.pad(
                        finalp,
                        (sidepad + sf, sidepad, toppad, toppad + tf)) / 3

                    sum_vol_total[j, :, :] += pad
                elif i == 1:
                    toppad = shape[0] // 2 - CROP_SHAPE // 2
                    sidepad = shape[2] // 2 - CROP_SHAPE // 2

                    tf = 1 if shape[0] % 2 == 1 else 0
                    sf = 1 if shape[2] % 2 == 1 else 0
                    pad = F.pad(
                        finalp,
                        (sidepad + sf, sidepad, toppad, toppad + tf)) / 3

                    sum_vol_total[:, j, :] += pad
                elif i == 2:
                    toppad = shape[0] // 2 - CROP_SHAPE // 2
                    sidepad = shape[1] // 2 - CROP_SHAPE // 2

                    tf = 1 if shape[0] % 2 == 1 else 0
                    sf = 1 if shape[1] % 2 == 1 else 0
                    pad = F.pad(
                        finalp,
                        (sidepad + sf, sidepad, toppad, toppad + tf)) / 3

                    sum_vol_total[:, :, j] += pad

        except Exception as e:
            print(
                "Error: {}. Make sure your data is OK; please contact the"
                " author at https://github.com/dscarmo".format(e))
            traceback.print_exc()
            quit()

    final_nppred = get_largest_components(sum_vol_total.numpy(), mask_ths=0.5)

    print("Processing took {}s".format(time.time() - begin))
    print("Saving to {}".format(save_path))
    nib.save(nib.nifti1.Nifti1Image(final_nppred, None), save_path)
    return sample_v, final_nppred
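A hypothetical invocation, assuming models maps each orientation name to an already trained PyTorch model and the input is a NIfTI volume (file names are placeholders):

import torch

models = {o: torch.load(o + "_model.pt")
          for o in ("sagital", "coronal", "axial")}  # keys match `orientations`
volume, mask = run_once("subject01.nii.gz", models)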