Example #1
0
def test_angular_neighbors():
    """Compare angular_neighbors(..., 2) on four vectors with hand-computed indices."""
    expected = np.array([[1, 2], [0, 2], [0, 1], [0, 1]])
    directions = [[0, 0, 1], [0, 0, 3], [1, 2, 3], [-1, -2, -3]]

    assert_equal(angular_neighbors(directions, 2), expected)
def test_angular_neighbors():
    """Check the two neighbour indices returned for each of four input vectors."""
    inputs = [
        [0, 0, 1],
        [0, 0, 3],
        [1, 2, 3],
        [-1, -2, -3],
    ]
    expected_rows = [[1, 2], [0, 2], [0, 1], [0, 1]]

    result = angular_neighbors(inputs, 2)
    assert_equal(result, np.array(expected_rows))
Example #3
0
def nlsam_denoise(data,
                  sigma,
                  bvals,
                  bvecs,
                  block_size,
                  mask=None,
                  is_symmetric=False,
                  n_cores=None,
                  split_b0s=False,
                  subsample=True,
                  n_iter=10,
                  b0_threshold=10,
                  dtype=np.float64,
                  use_threading=False,
                  verbose=False,
                  mp_method=None):
    """Main nlsam denoising function which sets up everything nicely for the local
    block denoising.

    Parameters
    ----------
    data : ndarray
        Input volume to denoise; the last axis indexes the diffusion volumes.
        Note: when ``split_b0s`` is False and several b0s are present, the b0
        volumes of ``data`` are overwritten in place by their average.
    sigma : ndarray, shape data.shape[:-1]
        Noise standard deviation estimation at each voxel.
        Converted to variance internally.
    bvals : 1D ndarray
        The N b-values associated to each of the N diffusion volumes.
    bvecs : N x 3 2D array
        The N 3D vectors for each acquired diffusion gradient. The vectors of
        b0 volumes are treated as exactly (0, 0, 0) internally (the caller's
        array is not modified).
    block_size : tuple, length = data.ndim
        Patch size + number of angular neighbors to process at once as similar data.

    Optional parameters
    -------------------
    mask : ndarray, default None
        Restrict computations to voxels inside the mask to reduce runtime.
        Defaults to a mask covering every voxel.
    is_symmetric : bool, default False
        If True, assumes that for each coordinate (x, y, z) in bvecs,
        (-x, -y, -z) was also acquired.
    n_cores : int, default None
        Number of processes to use for the denoising. Default is to use
        all available cores.
    split_b0s : bool, default False
        If True and the dataset contains multiple b0s, a different b0 will be used for
        each run of the denoising. If False, the b0s are averaged and the average b0
        is used instead.
    subsample : bool, default True
        If True, find the smallest subset of indices required to process each
        dwi at least once.
    n_iter : int, default 10
        Maximum number of iterations for the reweighted l1 solver.
    b0_threshold : int, default 10
        A b-value at or below b0_threshold will be considered as a b0 image.
    dtype : np.float32 or np.float64, default np.float64
        Precision to use for inner computations. Note that np.float32 should only be
        used for very, very large datasets (that is, your RAM starts swapping) as it
        can lead to numerical precision errors.
    use_threading : bool, default False
        Do not use multiprocessing, but rather rely on the multithreading capabilities
        of your numerical solvers. While this mode is more memory friendly, it is also
        slower than using the multiprocessing mode (the default). Moreover, it also
        assumes that your blas/lapack/spams library are built with multithreading, so
        be sure to check that your computer is using multiple cores or the algorithm
        will just take much longer to complete.
    verbose : bool, default False
        Print useful messages.
    mp_method : string
        Dispatch method for multiprocessing.

    Returns
    -------
    data_denoised : ndarray of float32, same shape as data
        The denoised dataset, averaged over every block each volume took part in.

    Raises
    ------
    ValueError
        If mask, sigma or block_size do not match data, if dtype is neither
        np.float32 nor np.float64, if no b0 volume is found, or if there are
        still more b0s than blocks to pair them with.
    """

    if verbose:
        logger.setLevel(logging.INFO)

    if mask is None:
        # np.bool was removed in NumPy 1.24; the builtin bool is the same dtype.
        mask = np.ones(data.shape[:-1], dtype=bool)

    if data.shape[:-1] != mask.shape:
        raise ValueError(
            'data shape is {}, but mask shape {} is different!'.format(
                data.shape, mask.shape))

    if data.shape[:-1] != sigma.shape:
        raise ValueError(
            'data shape is {}, but sigma shape {} is different!'.format(
                data.shape, sigma.shape))

    if len(block_size) != len(data.shape):
        # block_size is a plain tuple (no .shape); report it as given, in the
        # same order as the message text.
        raise ValueError(
            'Block shape {} and data shape {} are not of the same '
            'length'.format(block_size, data.shape))

    if not ((dtype == np.float32) or (dtype == np.float64)):
        raise ValueError(
            'dtype should be either np.float32 or np.float64, but is {}'.
            format(dtype))

    b0_loc = np.where(bvals <= b0_threshold)[0]
    dwis = np.where(bvals > b0_threshold)[0]
    num_b0s = len(b0_loc)

    # Every block is denoised together with one b0; without any, the cycle
    # below would be empty and fail with an obscure StopIteration.
    if num_b0s == 0:
        raise ValueError(
            'No b0 volume found (no b-value <= {}), but at least one is '
            'required.'.format(b0_threshold))

    variance = sigma**2

    # We also convert bvecs associated with b0s to exactly (0,0,0), which
    # is not always the case when we hack around with the scanner.
    bvecs = np.where(bvals[:, None] <= b0_threshold, 0, bvecs)

    logger.info("Found {} b0s at position {}".format(str(num_b0s),
                                                     str(b0_loc)))

    # Average all b0s if we don't split them in the training set.
    # This writes into the caller's array.
    if num_b0s > 1 and not split_b0s:
        num_b0s = 1
        data[..., b0_loc] = np.mean(data[..., b0_loc], axis=-1, keepdims=True)

    # Split the b0s in a cyclic fashion along the training data
    # If we only had one, cycle just return b0_loc indefinitely,
    # else we go through all indexes.
    np.random.shuffle(b0_loc)
    split_b0s_idx = cycle(b0_loc)

    # Double bvecs to find neighbors with assumed symmetry if needed
    if is_symmetric:
        logger.info('Data is assumed to be already symmetric.')
        sym_bvecs = bvecs
    else:
        sym_bvecs = np.vstack((bvecs, -bvecs))

    neighbors = angular_neighbors(sym_bvecs,
                                  block_size[-1] - 1) % data.shape[-1]
    # everything was doubled for symmetry, keep one row per volume
    neighbors = neighbors[:data.shape[-1]]

    # Full overlap for dictionary learning
    overlap = np.array(block_size, dtype=np.int16) - 1

    full_indexes = [(dwi, ) + tuple(neighbors[dwi])
                    for dwi in range(data.shape[-1]) if dwi in dwis]

    if subsample:
        indexes = greedy_set_finder(full_indexes)
    else:
        indexes = full_indexes

    # If we have more b0s than indexes, then we have to add a few more blocks since
    # we won't do a full cycle. If we have more b0s than indexes after that, then it breaks.
    if num_b0s > len(indexes):
        the_rest = [rest for rest in full_indexes if rest not in indexes]
        indexes += the_rest[:(num_b0s - len(indexes))]

    if num_b0s > len(indexes):
        error = (
            'Seems like you still have more b0s {} than available blocks {},'
            ' either average them or deactivate subsampling.'.format(
                num_b0s, len(indexes)))
        raise ValueError(error)

    # One b0 is prepended to each block, hence the + 1 on the last axis.
    b0_block_size = tuple(block_size[:-1]) + ((block_size[-1] + 1, ))
    data_denoised = np.zeros(data.shape, np.float32)
    # Number of blocks each volume took part in, used to average at the end.
    divider = np.zeros(data.shape[-1])

    # Put all idx + b0 in this array in each iteration
    to_denoise = np.empty(data.shape[:-1] + (block_size[-1] + 1, ),
                          dtype=dtype)

    for i, idx in enumerate(indexes, start=1):
        b0_loc = tuple((next(split_b0s_idx), ))
        to_denoise[..., 0] = data[..., b0_loc].squeeze()
        to_denoise[..., 1:] = data[..., idx]
        divider[list(b0_loc + idx)] += 1

        logger.info('Now denoising volumes {} / block {} out of {}.'.format(
            b0_loc + idx, i, len(indexes)))

        data_denoised[..., b0_loc + idx] += local_denoise(
            to_denoise,
            b0_block_size,
            overlap,
            variance,
            n_iter=n_iter,
            mask=mask,
            dtype=dtype,
            n_cores=n_cores,
            use_threading=use_threading,
            verbose=verbose,
            mp_method=mp_method)

    data_denoised /= divider
    return data_denoised
def get_global_D(datasets,
                 outfilename,
                 block_size,
                 ncores=None,
                 batchsize=32,
                 niter=500,
                 use_std=False,
                 positivity=False,
                 fit_intercept=True,
                 center=True,
                 b0_threshold=20,
                 split_b0s=True,
                 **kwargs):
    """Learn one dictionary from patches pooled over several datasets.

    Parameters
    ----------
    datasets : list of dict
        One dict per dataset holding file paths under the keys 'data', 'mask',
        'bval', 'bvec' and, when ``use_std`` is True, 'std'.
    outfilename : str
        Used to build the ``saveback`` name
        ``'Dic_' + outfilename + '_size_<block_size>.npy'`` given to ``online_DL``.
    block_size : tuple
        Spatial patch size followed by the number of dwis per block; one b0 is
        added to the last dimension internally.
    ncores, batchsize, niter, positivity, fit_intercept
        Forwarded to ``online_DL``.
    use_std : bool, default False
        If True, load each dataset's 'std' volume and pass the per-patch median
        variance to ``online_DL``. A missing file is reported and skipped.
    center : bool, default True
        If True, subtract each voxel's mean over the last axis before
        extracting patches.
    b0_threshold : float, default 20
        Volumes with a b-value <= b0_threshold are treated as b0s.
    split_b0s : bool, default True
        If False and several b0s are present, they are replaced by their average.
    **kwargs
        Accepted for compatibility; unused.

    Returns
    -------
    D
        Whatever ``online_DL`` returns (presumably the learned dictionary).

    Raises
    ------
    ValueError
        If a dataset has more b0s than blocks to pair them with.
    """

    # get the data shape so we can preallocate some arrays
    # we also have to assume all datasets have the same 3D shape obviously
    shape = nib.load(datasets[0]['data']).header.get_data_shape()
    current_block_size = block_size

    n_atoms = int(np.prod(current_block_size) * 2)
    # One b0 is prepended to every block, hence the + 1 on the last axis.
    b0_block_size = tuple(current_block_size[:-1]) + (
        (current_block_size[-1] + 1, ))
    # NOTE(review): despite its name this equals the block size itself; check
    # what extract_patches expects as its third argument.
    overlap = b0_block_size
    n_features = np.prod(b0_block_size)
    to_denoise = np.empty(shape[:-1] + (current_block_size[-1] + 1, ),
                          dtype=np.float32)

    train_list = []
    variance_large = []

    for filename in datasets:

        print('Now feeding dataset {}'.format(filename['data']))

        # np.bool was removed in NumPy 1.24; the builtin bool is equivalent.
        mask = nib.load(
            filename['mask']).get_fdata(caching='unchanged').astype(bool)
        data = nib.load(filename['data']).get_fdata(
            caching='unchanged').astype(np.float32) * mask[..., None]
        bvals = np.loadtxt(filename['bval'])
        bvecs = np.loadtxt(filename['bvec'])

        # Accept gradient tables stored either as 3 x N or as N x 3.
        if np.shape(bvecs)[0] == 3:
            bvecs = bvecs.T

        b0_loc = np.where(bvals <= b0_threshold)[0]
        dwis = np.where(bvals > b0_threshold)[0]
        num_b0s = len(b0_loc)

        # We also convert bvecs associated with b0s to exactly (0,0,0), which
        # is not always the case when we hack around with the scanner.
        bvecs = np.where(bvals[:, None] <= b0_threshold, 0, bvecs)

        # Average all b0s if we don't split them in the training set
        if num_b0s > 1 and not split_b0s:
            num_b0s = 1
            data[..., b0_loc] = np.mean(data[..., b0_loc],
                                        axis=-1,
                                        keepdims=True)

        # Split the b0s in a cyclic fashion along the training data
        # If we only had one, cycle just return b0_loc indefinitely,
        # else we go through all indexes.
        np.random.shuffle(b0_loc)
        split_b0s_idx = cycle(b0_loc)
        sym_bvecs = np.vstack((bvecs, -bvecs))

        neighbors = angular_neighbors(
            sym_bvecs, current_block_size[-1] - 1) % data.shape[-1]
        # everything was doubled for symmetry, keep one row per volume
        neighbors = neighbors[:data.shape[-1]]

        full_indexes = [(dwi, ) + tuple(neighbors[dwi])
                        for dwi in range(data.shape[-1]) if dwi in dwis]
        indexes = greedy_set_finder(full_indexes)

        # If we have more b0s than indexes, then we have to add a few more blocks since
        # we won't do a full cycle. If we have more b0s than indexes after that, then it breaks.
        if num_b0s > len(indexes):
            the_rest = [rest for rest in full_indexes if rest not in indexes]
            indexes += the_rest[:(num_b0s - len(indexes))]

        if num_b0s > len(indexes):
            error = (
                'Seems like you still have more b0s {} than available blocks {},'
                ' either average them or deactivate subsampling.'.format(
                    num_b0s, len(indexes)))
            raise ValueError(error)

        # whole global centering
        if center:
            data -= data.mean(axis=-1, keepdims=True)

        # if we mix datasets, then we may need to change the storage array size
        if to_denoise.shape[:-1] != data.shape[:-1]:
            del to_denoise
            to_denoise = np.empty(data.shape[:-1] +
                                  (current_block_size[-1] + 1, ),
                                  dtype=np.float32)

        # The noise map is identical for every block of this dataset, so load
        # it and reduce it to one median variance per patch only once.
        variance_patches = None
        if use_std:
            try:
                # get_data() raises from nibabel 5.0 on; get_fdata() replaces it.
                variance_vol = nib.load(filename['std']).get_fdata()**2 * mask
            except IOError:
                print('Volume {} not found!'.format(filename['std']))
            else:
                variance_vol = np.broadcast_to(variance_vol[..., None],
                                               data.shape)
                variance_vol = extract_patches(variance_vol, b0_block_size,
                                               overlap)
                var_axis = tuple(range(variance_vol.ndim // 2,
                                       variance_vol.ndim))
                variance_patches = np.median(variance_vol, axis=var_axis)
                del variance_vol

        for idx in indexes:
            b0_loc = tuple((next(split_b0s_idx), ))

            to_denoise[..., 0] = data[..., b0_loc].squeeze()
            to_denoise[..., 1:] = data[..., idx]

            patches = extract_patches(to_denoise, b0_block_size, overlap)
            axis = tuple(range(patches.ndim // 2, patches.ndim))
            # Keep only patches where more than half of the entries are positive.
            mask_patch = np.sum(patches > 0, axis=axis) > n_features // 2
            patches = patches[mask_patch].reshape(-1, n_features)

            if variance_patches is not None:
                variance = variance_patches[mask_patch].ravel()
                print('variance shape', variance.shape)
            else:
                variance = [None]

            train_list += [patches]
            variance_large += list(variance)

        print('train', len(train_list), data.shape, b0_block_size, overlap,
              patches.shape)

        del data, mask, patches, variance, variance_patches

    print('Fed everything in')

    # Copy every block of patches into one preallocated training matrix.
    lengths = [block.shape[0] for block in train_list]
    train_data = np.empty((np.sum(lengths), n_features))
    print(train_data.shape)

    step = 0
    for length, block in zip(lengths, train_list):
        train_data[step:step + length] = block.reshape(-1, n_features)
        step += length

    del train_list

    # we have variance as a N elements list - so check one element to see if it's an array
    if variance_large[0] is not None:
        variance_large = np.asarray(variance_large).ravel()
    else:
        variance_large = None

    savename = 'Dic_' + outfilename + '_size_{}.npy'.format(
        block_size).replace(' ', '')

    D = online_DL(train_data,
                  ncores=ncores,
                  positivity=positivity,
                  fit_intercept=fit_intercept,
                  standardize=True,
                  nlambdas=100,
                  niter=niter,
                  batchsize=batchsize,
                  n_atoms=n_atoms,
                  variance=variance_large,
                  progressbar=True,
                  disable_mkl=True,
                  saveback=savename,
                  use_joblib=False)

    return D
Example #5
0
def nlsam_denoise(data,
                  sigma,
                  bvals,
                  bvecs,
                  block_size,
                  mask=None,
                  is_symmetric=False,
                  n_cores=None,
                  subsample=True,
                  n_iter=10,
                  b0_threshold=10,
                  verbose=False,
                  mp_method=None):
    """Main nlsam denoising function which sets up everything nicely for the local
    block denoising.

    Several b0s are averaged into one before denoising; in the output every
    original b0 position receives that same denoised average b0.

    Parameters
    ----------
    data : ndarray
        Input volume to denoise; the last axis indexes the diffusion volumes.
    sigma : ndarray, shape data.shape[:-1]
        Noise standard deviation estimation at each voxel.
        Converted to variance internally.
    bvals : 1D array
        The N b-values associated to each of the N diffusion volumes.
    bvecs : N x 3 2D array
        The N 3D vectors for each acquired diffusion gradient.
    block_size : tuple, length = data.ndim
        Patch size + number of angular neighbors to process at once as similar data.

    Optional parameters
    -------------------
    mask : ndarray, default None
        Restrict computations to voxels inside the mask to reduce runtime.
        Defaults to a mask covering every voxel.
    is_symmetric : bool, default False
        If True, assumes that for each coordinate (x, y, z) in bvecs,
        (-x, -y, -z) was also acquired.
    n_cores : int, default None
        Number of processes to use for the denoising. Default is to use
        all available cores.
    subsample : bool, default True
        If True, find the smallest subset of indices required to process each
        dwi at least once.
    n_iter : int, default 10
        Maximum number of iterations for the reweighted l1 solver.
    b0_threshold : int, default 10
        A b-value at or below b0_threshold will be considered as a b0 image.
    verbose : bool, default False
        Print useful messages.
    mp_method : string
        Dispatch method for multiprocessing.

    Returns
    -------
    data_denoised : ndarray, same shape as data
        The denoised dataset.

    Raises
    ------
    ValueError
        If mask, sigma or block_size do not match data, or if no b0 is found.
    """

    if verbose:
        logger.setLevel(logging.INFO)

    if mask is None:
        # np.bool was removed in NumPy 1.24; the builtin bool is the same dtype.
        mask = np.ones(data.shape[:-1], dtype=bool)

    if data.shape[:-1] != mask.shape:
        raise ValueError(
            'data shape is {}, but mask shape {} is different!'.format(
                data.shape, mask.shape))

    if data.shape[:-1] != sigma.shape:
        raise ValueError(
            'data shape is {}, but sigma shape {} is different!'.format(
                data.shape, sigma.shape))

    if len(block_size) != len(data.shape):
        # block_size is a plain tuple (no .shape); report it as given, in the
        # same order as the message text.
        raise ValueError(
            'Block shape {} and data shape {} are not of the same '
            'length'.format(block_size, data.shape))

    b0_loc = tuple(np.where(bvals <= b0_threshold)[0])
    num_b0s = len(b0_loc)

    # Every block is denoised together with a b0, so at least one is required.
    if num_b0s == 0:
        raise ValueError(
            'No b0 volume found (no b-value <= {}), but at least one is '
            'required.'.format(b0_threshold))

    variance = sigma**2
    orig_shape = data.shape

    logger.info("Found {} b0s at position {}".format(str(num_b0s),
                                                     str(b0_loc)))

    # Average multiple b0s, and just use the average for the rest of the script
    # patching them in at the end
    if num_b0s > 1:
        mean_b0 = np.mean(data[..., b0_loc], axis=-1)
        dwis = tuple(np.where(bvals > b0_threshold)[0])
        data = data[..., dwis]
        bvals = np.take(bvals, dwis, axis=0)
        bvecs = np.take(bvecs, dwis, axis=0)

        rest_of_b0s = b0_loc[1:]
        b0_loc = b0_loc[0]

        # b0_loc is the first b0, so exactly b0_loc dwis precede it and this
        # puts the averaged b0 back at its original position.
        data = np.insert(data, b0_loc, mean_b0, axis=-1)
        bvals = np.insert(bvals, b0_loc, [0.], axis=0)
        bvecs = np.insert(bvecs, b0_loc, [0., 0., 0.], axis=0)
        b0_loc = tuple([b0_loc])
        num_b0s = 1
    else:
        rest_of_b0s = None

    # Double bvecs to find neighbors with assumed symmetry if needed
    if is_symmetric:
        logger.info('Data is assumed to be already symmetric.')
        sym_bvecs = np.delete(bvecs, b0_loc, axis=0)
    else:
        sym_bvecs = np.vstack(
            (np.delete(bvecs, b0_loc, axis=0), np.delete(-bvecs,
                                                         b0_loc,
                                                         axis=0)))

    neighbors = (angular_neighbors(sym_bvecs, block_size[-1] - num_b0s) %
                 (data.shape[-1] - num_b0s))[:data.shape[-1] - num_b0s]

    # Full overlap for dictionary learning
    overlap = np.array(block_size, dtype=np.int16) - 1
    b0 = np.squeeze(data[..., b0_loc])
    # From here on, data only holds the dwis; indexes refer to this b0-free array.
    data = np.delete(data, b0_loc, axis=-1)

    indexes = [(i, ) + tuple(neighbors[i]) for i in range(len(neighbors))]

    if subsample:
        indexes = greedy_set_finder(indexes)

    b0_block_size = tuple(block_size[:-1]) + ((block_size[-1] + num_b0s, ))

    denoised_shape = data.shape[:-1] + (data.shape[-1] + num_b0s, )
    data_denoised = np.zeros(denoised_shape, np.float32)

    # Put all idx + b0 in this array in each iteration
    to_denoise = np.empty(data.shape[:-1] + (block_size[-1] + 1, ),
                          dtype=np.float64)

    b0_pos = b0_loc[0]

    for i, idx in enumerate(indexes):
        # Map positions in the b0-free array back to positions in the array
        # that still holds the b0: volumes at or after the b0's slot move up.
        # This must be element-wise; comparing the tuples directly would be
        # lexicographic and give a single bool for the whole block.
        idx_arr = np.asarray(idx)
        dwi_idx = tuple(np.where(idx_arr < b0_pos, idx_arr, idx_arr + num_b0s))
        logger.info('Now denoising volumes {} / block {} out of {}.'.format(
            idx, i + 1, len(indexes)))

        # Refill the b0 every iteration in case local_denoise alters its input.
        to_denoise[..., 0] = b0
        to_denoise[..., 1:] = data[..., idx]

        data_denoised[...,
                      b0_loc + dwi_idx] += local_denoise(to_denoise,
                                                         b0_block_size,
                                                         overlap,
                                                         variance,
                                                         n_iter=n_iter,
                                                         mask=mask,
                                                         dtype=np.float64,
                                                         n_cores=n_cores,
                                                         verbose=verbose,
                                                         mp_method=mp_method)

    # Number of blocks each dwi appeared in; the b0 appeared in all of them.
    divider = np.bincount(np.array(indexes, dtype=np.int16).ravel())
    divider = np.insert(divider, b0_loc, len(indexes))

    # Spatial dims already match orig_shape, so no cropping is needed (and the
    # previous hard-coded 4D slice failed for any other dimensionality).
    data_denoised = data_denoised / divider

    # Put back the original number of b0s
    if rest_of_b0s is not None:

        b0_denoised = np.squeeze(data_denoised[..., b0_loc])
        data_denoised_insert = np.empty(orig_shape, dtype=np.float32)
        n = 0

        for i in range(orig_shape[-1]):
            if i in rest_of_b0s:
                data_denoised_insert[..., i] = b0_denoised
                n += 1
            else:
                data_denoised_insert[..., i] = data_denoised[..., i - n]

        data_denoised = data_denoised_insert

    return data_denoised
def main():
    """Command-line entry point: rebuild datasets with a pre-learned dictionary.

    Reads eight positional arguments from sys.argv (see ``usage``):
    path_data path_D scanner to_scanner block_size block_up use_std positivity.
    Every 'dwi.nii' / 'dwi_fw.nii' found under path_data whose directory
    matches ``scanner`` is paired (by discovery order, via zip) with the
    matching file for ``to_scanner``. Each source dataset is rebuilt block by
    block through ``rebuild`` with the dictionary loaded from path_D, at the
    size implied by ``block_up``. The result is saved with the to_scanner
    dataset's affine and header. Subsampled .bval/.bvec files are also written
    next to the to_scanner dataset. Outputs that already exist are skipped.
    """

    if len(sys.argv) == 1:
        usage = 'Usage : path_data path_D scanner to_scanner block_size block_up use_std positivity'
        print(usage)
        sys.exit(1)

    path, path_D, scanner, to_scanner, block_size, block_up, use_std, positivity = sys.argv[
        1:]
    path = os.path.expanduser(path)
    path_D = os.path.expanduser(path_D)

    # block_size / block_up are given as Python literals (e.g. tuples).
    block_size = literal_eval(block_size)
    block_up = literal_eval(block_up)
    use_std = use_std.lower() == 'true'
    positivity = positivity.lower() == 'true'

    # Hard-coded run settings; several of them also alter the output filename below.
    ncores = 100
    # N = 1  # we should find a way to actually estimate that
    bias_correct_std = False
    fix_mean = False
    # reweighting = True
    # fit_intercept = False
    fit_intercept = True
    # center = False
    center = True
    use_crossval = True

    D = np.load(path_D)
    datasets = []
    to_datasets = []
    # print(D.min(), D.max())

    # Allow the all keyword to feed in all the scanners
    # if it's a regular scanner, we put it in a list to parse it properly afterwards
    if scanner.lower() == 'all/st':
        scanners = 'GE/st', 'Prisma/st', 'Connectom/st'
    elif scanner.lower() == 'all/sa':
        scanners = 'Prisma/sa', 'Connectom/sa'
    else:
        scanners = [scanner]

    # print(scanner)
    # Collect source and destination files; they are later paired by order,
    # so both lists must end up the same length.
    for root, dirs, files in os.walk(path):
        dirs.sort()
        for name in files:
            for scanner in scanners:
                if (name == 'dwi.nii' or name
                        == 'dwi_fw.nii') and scanner.lower() in root.lower():
                    datasets += [os.path.join(root, name)]

                if (name == 'dwi.nii' or name == 'dwi_fw.nii'
                    ) and to_scanner.lower() in root.lower():
                    to_datasets += [os.path.join(root, name)]

    if len(datasets) != len(to_datasets):
        raise ValueError('Size mismatch between lists! {}, {}'.format(
            len(datasets), len(to_datasets)))

    for dataset, to_dataset in zip(datasets, to_datasets):

        print('Now rebuilding {}'.format(dataset))
        # NOTE(review): path_D went through expanduser, so any directory part
        # of it ends up inside the output filename; confirm only bare
        # filenames are passed here.
        predicted_D = path_D.replace('.npy', '')
        output_filename = dataset.replace(
            '.nii', '_predicted_' + predicted_D + '.nii.gz')
        output_filename = output_filename.replace('.nii.gz', '_recon.nii.gz')

        # Encode the run settings into the output filename.
        if center:
            output_filename = output_filename.replace('_predicted_',
                                                      '_predicted_center_')

        if use_crossval:
            output_filename = output_filename.replace('_predicted_',
                                                      '_predicted_with_cv2_')
        else:
            output_filename = output_filename.replace('_predicted_',
                                                      '_predicted_with_aic_')

        if use_std:
            output_filename = output_filename.replace('_predicted_',
                                                      '_predicted_with_std_')

        if positivity:
            output_filename = output_filename.replace('_predicted_',
                                                      '_predicted_pos_')

        if fix_mean:
            output_filename = output_filename.replace('_predicted_',
                                                      '_predicted_meanfix_')

        if not fit_intercept:
            output_filename = output_filename.replace(
                '_predicted_', '_predicted_no_intercept_')

        if os.path.isfile(output_filename):
            print('File already exists! Skipping {}'.format(output_filename))
        else:
            # if True:
            vol = nib.load(dataset)

            # once we have loaded the data, replace the name if we used a fw dataset since everything else will match
            dataset = dataset.replace('_fw', '')
            to_dataset = to_dataset.replace('_fw', '')

            # NOTE(review): get_data() is deprecated in nibabel and raises from
            # nibabel 5.0 on; get_fdata() is the replacement.
            indexer = get_indexer(dataset)
            data = vol.get_data()[..., indexer]

            # affine = vol.affine
            # header = vol.header

            mask = nib.load(dataset.replace('.nii', '_mask.nii.gz')).get_data()
            std = nib.load(dataset.replace('.nii', '_std.nii.gz')).get_data()
            # NOTE(review): bvecs are indexed along axis 0 with the volume
            # indexer, which assumes an N x 3 table; get_global_D transposes
            # 3 x N files but this code does not.
            bvals = np.loadtxt(dataset.replace('.nii', '.bval'))[indexer]
            bvecs = np.loadtxt(dataset.replace('.nii', '.bvec'))[indexer]

            # data = data[:, :, 30:36]
            # mask = mask[:, :, 30:36]
            # if use_std:
            #     std = std[:, :, 30:36]

            if center:
                # # data_mean = np.mean(data, axis=-1, keepdims=True)
                # with warnings.catch_warnings():
                #     warnings.filterwarnings('ignore')
                #     data_mean = np.nanmean(np.where(data != 0, data, np.nan), axis=-1, keepdims=True)

                # # print(np.sum(np.isfinite(data_mean)), np.sum(np.isnan(data_mean)))
                # data_mean[np.isnan(data_mean)] = 0

                # we need to pull down the 3D volume mean for the upsampling part to make sense though
                # Zeros are excluded from the mean (turned into NaN), giving one
                # mean per volume over the three spatial axes.
                data_mean = np.nanmean(np.where(data != 0, data, np.nan),
                                       axis=(0, 1, 2),
                                       keepdims=True)
                # NOTE(review): if the file stores integers, get_data() may
                # return an integer array and this in-place subtraction of a
                # float mean would fail; confirm the stored dtype.
                data -= data_mean
                # print('centering removed {}'.format(data_mean.mean()))
                # 1/0

            if use_std:
                if bias_correct_std:
                    print('this is now disabled somehow :/')
                    # print(std.shape, data.shape)
                    # std = np.broadcast_to(std[..., None], data.shape)
                    # mask_4D = np.broadcast_to(mask[..., None], data.shape)
                    # std = np.median(corrected_sigma(data, std, mask_4D, N), axis=-1)
                variance = std**2
            else:
                std = None
                variance = None

            # number of blocks is without b0, so we +1 the last dimension
            if len(block_up) < data.ndim:
                # last_size = int(D.shape[0] / np.prod(block_size))
                current_block_size = block_size[:-1] + (block_size[-1] + 1, )
                current_block_up = block_up + (block_size[-1] + 1, )
            else:
                current_block_size = block_size[:-1] + (block_size[-1] + 1, )
                current_block_up = block_up[:-1] + (block_up[-1] + 1, )

            print('Output filename is {}'.format(output_filename))
            print(D.shape, current_block_size, current_block_up)

            # Per-axis upsampling ratio between the rebuilt and the input blocks.
            factor = np.divide(current_block_up, current_block_size)

            split_b0s = True
            b0_threshold = 20
            b0_loc = np.where(bvals <= b0_threshold)[0]
            dwis = np.where(bvals > b0_threshold)[0]
            num_b0s = len(b0_loc)

            # We also convert bvecs associated with b0s to exactly (0,0,0), which
            # is not always the case when we hack around with the scanner.
            bvecs = np.where(bvals[:, None] <= b0_threshold, 0, bvecs)

            # Average all b0s if we don't split them in the training set
            if num_b0s > 1 and not split_b0s:
                num_b0s = 1
                data[..., b0_loc] = np.mean(data[..., b0_loc],
                                            axis=-1,
                                            keepdims=True)

            # Split the b0s in a cyclic fashion along the training data
            # If we only had one, cycle just return b0_loc indefinitely,
            # else we go through all indexes.
            np.random.shuffle(b0_loc)
            split_b0s_idx = cycle(b0_loc)
            sym_bvecs = np.vstack((bvecs, -bvecs))

            # current_block_size[-1] already includes the extra b0 slot, so
            # - 2 requests block_size[-1] - 1 neighbours, the same count
            # get_global_D uses.
            neighbors = angular_neighbors(
                sym_bvecs, current_block_size[-1] - 2) % data.shape[-1]
            neighbors = neighbors[:data.shape[
                -1]]  # everything was doubled for symmetry

            full_indexes = [(dwi, ) + tuple(neighbors[dwi])
                            for dwi in range(data.shape[-1]) if dwi in dwis]
            indexes = greedy_set_finder(full_indexes)

            # If we have more b0s than indexes, then we have to add a few more blocks since
            # we won't do a full cycle. If we have more b0s than indexes after that, then it breaks.
            if num_b0s > len(indexes):
                the_rest = [
                    rest for rest in full_indexes if rest not in indexes
                ]
                indexes += the_rest[:(num_b0s - len(indexes))]

            if num_b0s > len(indexes):
                error = (
                    'Seems like you still have more b0s {} than available blocks {},'
                    ' either average them or deactivate subsampling.'.format(
                        num_b0s, len(indexes)))
                raise ValueError(error)

            # Stuff happens here / we only pimp up dwis, not b0s
            # actually we can't really pimp 4D stuff, SH are there for that
            # Only the three spatial axes are scaled; the volume count is kept.
            predicted_size = (int(data.shape[0] * factor[0]),
                              int(data.shape[1] * factor[1]),
                              int(data.shape[2] * factor[2]), data.shape[-1])

            predicted = np.zeros(predicted_size, dtype=np.float32)
            # Number of blocks each volume took part in, for averaging below.
            divider = np.zeros(predicted.shape[-1])

            print(data.shape, predicted.shape, factor)

            # Put all idx + b0 in this array in each iteration
            to_denoise = np.empty(data.shape[:-1] + (current_block_size[-1], ),
                                  dtype=np.float64)
            print(to_denoise.shape)

            for i, idx in enumerate(indexes, start=1):
                b0_loc = tuple((next(split_b0s_idx), ))
                to_denoise[..., 0] = data[..., b0_loc].squeeze()
                to_denoise[..., 1:] = data[..., idx]
                divider[list(b0_loc + idx)] += 1

                print('Now denoising volumes {} / block {} out of {}.'.format(
                    b0_loc + idx, i, len(indexes)))
                predicted[..., b0_loc + idx] += rebuild(
                    to_denoise,
                    mask,
                    D,
                    block_size=current_block_size,
                    block_up=current_block_up,
                    ncores=ncores,
                    positivity=positivity,
                    fix_mean=fix_mean,
                    # center=center,
                    fit_intercept=fit_intercept,
                    use_crossval=use_crossval,
                    variance=variance)
                # break
            predicted /= divider

            if center:
                predicted += data_mean
                # el cheapo mask after upsampling
                # Voxels exactly equal to the re-added mean are zeroed.
                predicted[predicted == data_mean] = 0

            # clip negatives, which happens at the borders
            predicted.clip(min=0., out=predicted)

            # header voxel size is all screwed up, so replace with the destination header and affine by the matching dataset
            to_affine = nib.load(to_dataset).affine
            to_header = nib.load(to_dataset).header
            # header['pixdim'] = to_header['pixdim']

            imgfile = nib.Nifti1Image(predicted, to_affine, to_header)
            nib.save(imgfile, output_filename)

            # subsample to common bvals/bvecs/tr/te
            to_data = nib.load(to_dataset).get_data()
            to_bvals = np.loadtxt(to_dataset.replace('.nii', '.bval'))
            to_bvecs = np.loadtxt(to_dataset.replace('.nii', '.bvec'))

            to_indexer = get_indexer(to_dataset)
            to_data = to_data[..., to_indexer]
            to_bvals = to_bvals[to_indexer]
            to_bvecs = to_bvecs[to_indexer]

            np.savetxt(to_dataset.replace('dwi.nii', 'dwi_subsample.bval'),
                       to_bvals,
                       fmt='%1.4f')
            np.savetxt(to_dataset.replace('dwi.nii', 'dwi_subsample.bvec'),
                       to_bvecs,
                       fmt='%1.4f')

            # # match to the new bvecs
            # rotated = match_bvecs(predicted, bvals, bvecs, to_bvals, to_bvecs).clip(min=0.)
            # imgfile = nib.Nifti1Image(rotated, to_affine, to_header)
            # nib.save(imgfile, output_filename.replace('predicted', 'rotated'))

            # # match to the new bvecs with nnls
            # rotated_nnls = match_bvecs(predicted, bvals, bvecs, to_bvals, to_bvecs, use_nnls=True)
            # imgfile = nib.Nifti1Image(rotated_nnls, to_affine, to_header)
            # nib.save(imgfile, output_filename.replace('predicted', 'rotated_nnls'))

            # Release the large per-dataset arrays before the next iteration.
            del data, to_data, predicted, vol, imgfile, mask, variance, std
            del to_denoise, divider