def test_angular_neighbors():
    """Compare angular_neighbors output on four vectors to hard-coded indices."""
    expected = np.array([[1, 2],
                         [0, 2],
                         [0, 1],
                         [0, 1]])
    directions = [[0, 0, 1],
                  [0, 0, 3],
                  [1, 2, 3],
                  [-1, -2, -3]]

    # Ask for two neighbors per input vector.
    found = angular_neighbors(directions, 2)

    assert_equal(found, expected)
def nlsam_denoise(data, sigma, bvals, bvecs, block_size,
                  mask=None, is_symmetric=False, n_cores=None,
                  split_b0s=False, subsample=True, n_iter=10,
                  b0_threshold=10, dtype=np.float64, use_threading=False,
                  verbose=False, mp_method=None):
    """Main nlsam denoising function which sets up everything nicely for
    the local block denoising.

    Input
    -----------
    data : ndarray
        Input volume to denoise. The last axis indexes the acquired volumes.
    sigma : ndarray
        Noise standard deviation estimation at each voxel.
        Converted to variance internally.
    bvals : 1D array
        the N b-values associated to each of the N diffusion volume.
    bvecs : N x 3 2D array
        the N 3D vectors for each acquired diffusion gradients.
    block_size : tuple, length = data.ndim
        Patch size + number of angular neighbors to process at once as
        similar data.

    Optional parameters
    -------------------
    mask : ndarray, default None
        Restrict computations to voxels inside the mask to reduce runtime.
        If None, every voxel is processed.
    is_symmetric : bool, default False
        If True, assumes that for each coordinate (x, y, z) in bvecs,
        (-x, -y, -z) was also acquired.
    n_cores : int, default None
        Number of processes to use for the denoising. Default is to use
        all available cores.
    split_b0s : bool, default False
        If True and the dataset contains multiple b0s, a different b0 will
        be used for each run of the denoising. If False, the b0s are
        averaged and the average b0 is used instead.
    subsample : bool, default True
        If True, find the smallest subset of indices required to process
        each dwi at least once.
    n_iter : int, default 10
        Maximum number of iterations for the reweighted l1 solver.
    b0_threshold : int, default 10
        A b-value below or equal to b0_threshold will be considered as a
        b0 image.
    dtype : np.float32 or np.float64, default np.float64
        Precision to use for inner computations. Note that np.float32
        should only be used for very, very large datasets (that is, your
        ram starts swapping) as it can lead to numerical precision errors.
    use_threading : bool, default False
        Do not use multiprocessing, but rather rely on the multithreading
        capabilities of your numerical solvers. While this mode is more
        memory friendly, it is also slower than using the multiprocessing
        mode (the default). Moreover, it also assumes that your
        blas/lapack/spams library are built with multithreading, so be
        sure to check that your computer is using multiple cores or the
        algorithm will just take much longer to complete.
    verbose : bool, default False
        print useful messages.
    mp_method : string
        Dispatch method for multiprocessing, forwarded to local_denoise.

    Output
    -----------
    data_denoised : ndarray
        The denoised dataset, stored as np.float32 with the shape of data.

    Raises
    ------
    ValueError
        If mask, sigma or block_size do not match the shape of data, if
        dtype is not np.float32/np.float64, if no b0 is found, or if there
        are more b0s than blocks to process.

    Notes
    -----
    When several b0s are present and split_b0s is False, the b0 volumes of
    the *input* array are overwritten in place by their average.
    """
    if verbose:
        logger.setLevel(logging.INFO)

    if mask is None:
        # np.bool was removed from NumPy, the builtin bool is the same dtype
        mask = np.ones(data.shape[:-1], dtype=bool)

    if data.shape[:-1] != mask.shape:
        raise ValueError(
            'data shape is {}, but mask shape {} is different!'.format(
                data.shape, mask.shape))

    if data.shape[:-1] != sigma.shape:
        raise ValueError(
            'data shape is {}, but sigma shape {} is different!'.format(
                data.shape, sigma.shape))

    if len(block_size) != len(data.shape):
        # block_size is a plain tuple, so it has no .shape attribute
        raise ValueError(
            'Block shape {} and data shape {} are not of the same '
            'length'.format(block_size, data.shape))

    if not ((dtype == np.float32) or (dtype == np.float64)):
        raise ValueError(
            'dtype should be either np.float32 or np.float64, but is {}'.
            format(dtype))

    b0_loc = np.where(bvals <= b0_threshold)[0]
    dwis = np.where(bvals > b0_threshold)[0]
    num_b0s = len(b0_loc)
    variance = sigma**2

    if num_b0s == 0:
        # Every block needs one b0, fail early instead of on StopIteration
        raise ValueError(
            'No b0 found with b0_threshold = {}, at least one is '
            'required.'.format(b0_threshold))

    # We also convert bvecs associated with b0s to exactly (0,0,0), which
    # is not always the case when we hack around with the scanner.
    bvecs = np.where(bvals[:, None] <= b0_threshold, 0, bvecs)

    logger.info("Found {} b0s at position {}".format(str(num_b0s),
                                                     str(b0_loc)))

    # Average all b0s if we don't split them in the training set
    if num_b0s > 1 and not split_b0s:
        num_b0s = 1
        data[..., b0_loc] = np.mean(data[..., b0_loc],
                                    axis=-1,
                                    keepdims=True)

    # Split the b0s in a cyclic fashion along the training data
    # If we only had one, cycle just return b0_loc indefinitely,
    # else we go through all indexes.
    np.random.shuffle(b0_loc)
    split_b0s_idx = cycle(b0_loc)

    # Double bvecs to find neighbors with assumed symmetry if needed
    if is_symmetric:
        logger.info('Data is assumed to be already symmetric.')
        sym_bvecs = bvecs
    else:
        sym_bvecs = np.vstack((bvecs, -bvecs))

    neighbors = angular_neighbors(sym_bvecs,
                                  block_size[-1] - 1) % data.shape[-1]
    # everything was doubled for symmetry
    neighbors = neighbors[:data.shape[-1]]

    # Full overlap for dictionary learning
    overlap = np.array(block_size, dtype=np.int16) - 1

    # One candidate block per dwi: the dwi itself followed by its neighbors
    dwi_set = set(dwis.tolist())
    full_indexes = [(dwi, ) + tuple(neighbors[dwi])
                    for dwi in range(data.shape[-1]) if dwi in dwi_set]

    if subsample:
        indexes = greedy_set_finder(full_indexes)
    else:
        indexes = full_indexes

    # If we have more b0s than indexes, then we have to add a few more
    # blocks since we won't do a full cycle. If we have more b0s than
    # indexes after that, then it breaks.
    if num_b0s > len(indexes):
        the_rest = [rest for rest in full_indexes if rest not in indexes]
        indexes += the_rest[:(num_b0s - len(indexes))]

    if num_b0s > len(indexes):
        error = (
            'Seems like you still have more b0s {} than available blocks {},'
            ' either average them or deactivate subsampling.'.format(
                num_b0s, len(indexes)))
        raise ValueError(error)

    b0_block_size = tuple(block_size[:-1]) + ((block_size[-1] + 1, ))
    data_denoised = np.zeros(data.shape, np.float32)
    # Number of times each volume was denoised, used to average the results
    divider = np.zeros(data.shape[-1])

    # Put all idx + b0 in this array in each iteration
    to_denoise = np.empty(data.shape[:-1] + (block_size[-1] + 1, ),
                          dtype=dtype)

    for i, idx in enumerate(indexes, start=1):
        b0_idx = next(split_b0s_idx)
        volumes = (b0_idx, ) + idx

        # Integer indexing keeps singleton spatial axes, unlike squeeze()
        to_denoise[..., 0] = data[..., b0_idx]
        to_denoise[..., 1:] = data[..., idx]
        divider[list(volumes)] += 1

        logger.info('Now denoising volumes {} / block {} out of {}.'.format(
            volumes, i, len(indexes)))

        data_denoised[..., volumes] += local_denoise(
            to_denoise,
            b0_block_size,
            overlap,
            variance,
            n_iter=n_iter,
            mask=mask,
            dtype=dtype,
            n_cores=n_cores,
            use_threading=use_threading,
            verbose=verbose,
            mp_method=mp_method)

    data_denoised /= divider
    return data_denoised
def get_global_D(datasets,
                 outfilename,
                 block_size,
                 ncores=None,
                 batchsize=32,
                 niter=500,
                 use_std=False,
                 positivity=False,
                 fit_intercept=True,
                 center=True,
                 b0_threshold=20,
                 split_b0s=True,
                 **kwargs):
    """Learn one dictionary from patches extracted out of several datasets.

    Parameters
    ----------
    datasets : sequence of dict
        One entry per dataset. Each entry is indexed with the keys 'data',
        'mask', 'bval' and 'bvec' (paths to the files), plus 'std' when
        use_std is True.
    outfilename : str
        Used to build the name given to online_DL as ``saveback``.
    block_size : tuple
        Patch size, the last entry being the number of dwis per block
        (one b0 is added to it internally).
    ncores, batchsize, niter, positivity, fit_intercept
        Forwarded as is to online_DL.
    use_std : bool, default False
        If True, the volume found under the 'std' key is squared and its
        per patch median is passed to online_DL as ``variance``.
    center : bool, default True
        If True, the mean along the last axis is removed from each dataset
        before extracting patches.
    b0_threshold : int, default 20
        A b-value below or equal to b0_threshold is considered as a b0.
    split_b0s : bool, default True
        If False and several b0s are present, they are replaced by their
        average.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    D
        Whatever online_DL returns for the collected training patches.

    Raises
    ------
    ValueError
        If a dataset has more b0s than blocks to process.
    """
    # get the data shape so we can preallocate some arrays
    # we also have to assume all datasets have the same 3D shape obviously
    shape = nib.load(datasets[0]['data']).header.get_data_shape()

    current_block_size = block_size
    n_atoms = int(np.prod(current_block_size) * 2)
    b0_block_size = tuple(
        current_block_size[:-1]) + ((current_block_size[-1] + 1, ))
    overlap = b0_block_size
    patch_length = np.prod(b0_block_size)

    to_denoise = np.empty(shape[:-1] + (current_block_size[-1] + 1, ),
                          dtype=np.float32)
    train_list = []
    variance_large = []

    for filename in datasets:
        print('Now feeding dataset {}'.format(filename['data']))

        # np.bool was removed from NumPy, the builtin bool is the same dtype
        mask = nib.load(
            filename['mask']).get_fdata(caching='unchanged').astype(bool)
        data = nib.load(filename['data']).get_fdata(
            caching='unchanged').astype(np.float32) * mask[..., None]

        bvals = np.loadtxt(filename['bval'])
        bvecs = np.loadtxt(filename['bvec'])

        # Accept bvecs stored as 3 x N
        if np.shape(bvecs)[0] == 3:
            bvecs = bvecs.T

        b0_loc = np.where(bvals <= b0_threshold)[0]
        dwis = np.where(bvals > b0_threshold)[0]
        num_b0s = len(b0_loc)

        # We also convert bvecs associated with b0s to exactly (0,0,0), which
        # is not always the case when we hack around with the scanner.
        bvecs = np.where(bvals[:, None] <= b0_threshold, 0, bvecs)

        # Average all b0s if we don't split them in the training set
        if num_b0s > 1 and not split_b0s:
            num_b0s = 1
            data[..., b0_loc] = np.mean(data[..., b0_loc],
                                        axis=-1,
                                        keepdims=True)

        # Split the b0s in a cyclic fashion along the training data
        # If we only had one, cycle just return b0_loc indefinitely,
        # else we go through all indexes.
        np.random.shuffle(b0_loc)
        split_b0s_idx = cycle(b0_loc)

        sym_bvecs = np.vstack((bvecs, -bvecs))
        neighbors = angular_neighbors(
            sym_bvecs, current_block_size[-1] - 1) % data.shape[-1]
        # everything was doubled for symmetry
        neighbors = neighbors[:data.shape[-1]]

        dwi_set = set(dwis.tolist())
        full_indexes = [(dwi, ) + tuple(neighbors[dwi])
                        for dwi in range(data.shape[-1]) if dwi in dwi_set]
        indexes = greedy_set_finder(full_indexes)

        # If we have more b0s than indexes, then we have to add a few more
        # blocks since we won't do a full cycle. If we have more b0s than
        # indexes after that, then it breaks.
        if num_b0s > len(indexes):
            the_rest = [rest for rest in full_indexes if rest not in indexes]
            indexes += the_rest[:(num_b0s - len(indexes))]

        if num_b0s > len(indexes):
            error = (
                'Seems like you still have more b0s {} than available blocks {},'
                ' either average them or deactivate subsampling.'.format(
                    num_b0s, len(indexes)))
            raise ValueError(error)

        # whole global centering
        if center:
            data -= data.mean(axis=-1, keepdims=True)

        # NOTE(review): the nesting below was reconstructed from the data
        # flow (mask_patch is needed by the variance code) - confirm.
        for i, idx in enumerate(indexes):
            b0_idx = next(split_b0s_idx)

            # if we mix datasets, then we may need to change the storage
            # array size
            if to_denoise.shape[:-1] != data.shape[:-1]:
                del to_denoise
                to_denoise = np.empty(data.shape[:-1] +
                                      (current_block_size[-1] + 1, ),
                                      dtype=np.float32)

            # Integer indexing keeps singleton spatial axes, unlike squeeze()
            to_denoise[..., 0] = data[..., b0_idx]
            to_denoise[..., 1:] = data[..., idx]

            patches = extract_patches(to_denoise, b0_block_size, overlap)
            axis = tuple(range(patches.ndim // 2, patches.ndim))
            # Only keep patches where more than half the entries are positive
            mask_patch = np.sum(patches > 0, axis=axis) > patch_length // 2
            patches = patches[mask_patch].reshape(-1, patch_length)

            if use_std:
                try:
                    # img.get_data() is gone from recent nibabel, reading
                    # dataobj gives back the same array
                    variance = np.asanyarray(
                        nib.load(filename['std']).dataobj)**2 * mask
                    variance = np.broadcast_to(variance[..., None],
                                               data.shape)
                    variance = extract_patches(variance, b0_block_size,
                                               overlap)
                    axis = tuple(range(variance.ndim // 2, variance.ndim))
                    variance = np.median(variance,
                                         axis=axis)[mask_patch].ravel()
                    print('variance shape', variance.shape)
                except IOError:
                    print('Volume {} not found!'.format(filename['std']))
                    variance = [None]
            else:
                variance = [None]

            train_list += [patches]
            variance_large += list(variance)

        print('train', len(train_list), data.shape, b0_block_size, overlap,
              patches.shape)
        del data, mask, patches, variance

    print('Fed everything in')

    # Stack all the patches in a single preallocated 2D array
    lengths = [patches.shape[0] for patches in train_list]
    train_data = np.empty((np.sum(lengths), patch_length))
    print(train_data.shape)

    step = 0
    for length, patches in zip(lengths, train_list):
        train_data[step:step + length] = patches.reshape(-1, patch_length)
        step += length

    del train_list

    # we have variance as a N elements list - so check one element to see
    # if it's an array, an empty list also means we have no variance
    if variance_large and variance_large[0] is not None:
        variance_large = np.asarray(variance_large).ravel()
    else:
        variance_large = None

    savename = 'Dic_' + outfilename + '_size_{}.npy'.format(
        block_size).replace(' ', '')

    D = online_DL(train_data,
                  ncores=ncores,
                  positivity=positivity,
                  fit_intercept=fit_intercept,
                  standardize=True,
                  nlambdas=100,
                  niter=niter,
                  batchsize=batchsize,
                  n_atoms=n_atoms,
                  variance=variance_large,
                  progressbar=True,
                  disable_mkl=True,
                  saveback=savename,
                  use_joblib=False)

    return D
def nlsam_denoise(data, sigma, bvals, bvecs, block_size,
                  mask=None, is_symmetric=False, n_cores=None,
                  subsample=True, n_iter=10, b0_threshold=10,
                  verbose=False, mp_method=None):
    """Main nlsam denoising function which sets up everything nicely for
    the local block denoising.

    Input
    -----------
    data : ndarray
        Input volume to denoise. The last axis indexes the acquired volumes.
    sigma : ndarray
        Noise standard deviation estimation at each voxel.
        Converted to variance internally.
    bvals : 1D array
        the N b-values associated to each of the N diffusion volume.
    bvecs : N x 3 2D array
        the N 3D vectors for each acquired diffusion gradients.
    block_size : tuple, length = data.ndim
        Patch size + number of angular neighbors to process at once as
        similar data.

    Optional parameters
    -------------------
    mask : ndarray, default None
        Restrict computations to voxels inside the mask to reduce runtime.
        If None, every voxel is processed.
    is_symmetric : bool, default False
        If True, assumes that for each coordinate (x, y, z) in bvecs,
        (-x, -y, -z) was also acquired.
    n_cores : int, default None
        Number of processes to use for the denoising. Default is to use
        all available cores.
    subsample : bool, default True
        If True, find the smallest subset of indices required to process
        each dwi at least once.
    n_iter : int, default 10
        Maximum number of iterations for the reweighted l1 solver.
    b0_threshold : int, default 10
        A b-value below or equal to b0_threshold will be considered as a
        b0 image.
    verbose : bool, default False
        print useful messages.
    mp_method : string
        Dispatch method for multiprocessing, forwarded to local_denoise.

    Output
    -----------
    data_denoised : ndarray
        The denoised dataset. When several b0s are present, they are
        averaged before denoising and the denoised average is put back at
        the position of every original b0.

    Raises
    ------
    ValueError
        If mask, sigma or block_size do not match the shape of data, or if
        no b0 is found.
    """
    if verbose:
        logger.setLevel(logging.INFO)

    if mask is None:
        # np.bool was removed from NumPy, the builtin bool is the same dtype
        mask = np.ones(data.shape[:-1], dtype=bool)

    if data.shape[:-1] != mask.shape:
        raise ValueError(
            'data shape is {}, but mask shape {} is different!'.format(
                data.shape, mask.shape))

    if data.shape[:-1] != sigma.shape:
        raise ValueError(
            'data shape is {}, but sigma shape {} is different!'.format(
                data.shape, sigma.shape))

    if len(block_size) != len(data.shape):
        # block_size is a plain tuple, so it has no .shape attribute
        raise ValueError(
            'Block shape {} and data shape {} are not of the same '
            'length'.format(block_size, data.shape))

    b0_loc = tuple(np.where(bvals <= b0_threshold)[0])
    num_b0s = len(b0_loc)
    variance = sigma**2
    orig_shape = data.shape

    if num_b0s == 0:
        # Every block is built around one b0, so we can not go on without it
        raise ValueError(
            'No b0 found with b0_threshold = {}, at least one is '
            'required.'.format(b0_threshold))

    logger.info("Found {} b0s at position {}".format(str(num_b0s),
                                                     str(b0_loc)))

    # Average multiple b0s, and just use the average for the rest of the
    # script, patching them in at the end
    if num_b0s > 1:
        mean_b0 = np.mean(data[..., b0_loc], axis=-1)
        dwis = tuple(np.where(bvals > b0_threshold)[0])
        data = data[..., dwis]
        bvals = np.take(bvals, dwis, axis=0)
        bvecs = np.take(bvecs, dwis, axis=0)

        rest_of_b0s = set(b0_loc[1:])
        first_b0 = b0_loc[0]

        # The averaged b0 goes back where the first b0 was
        data = np.insert(data, first_b0, mean_b0, axis=-1)
        bvals = np.insert(bvals, first_b0, [0.], axis=0)
        bvecs = np.insert(bvecs, first_b0, [0., 0., 0.], axis=0)
        b0_loc = tuple([first_b0])
        num_b0s = 1
    else:
        rest_of_b0s = None

    b0_pos = b0_loc[0]

    # Double bvecs to find neighbors with assumed symmetry if needed
    if is_symmetric:
        logger.info('Data is assumed to be already symmetric.')
        sym_bvecs = np.delete(bvecs, b0_loc, axis=0)
    else:
        sym_bvecs = np.vstack((np.delete(bvecs, b0_loc, axis=0),
                               np.delete(-bvecs, b0_loc, axis=0)))

    n_dwis = data.shape[-1] - num_b0s
    neighbors = (angular_neighbors(sym_bvecs, block_size[-1] - num_b0s) %
                 n_dwis)[:n_dwis]

    # Full overlap for dictionary learning
    overlap = np.array(block_size, dtype=np.int16) - 1

    # Integer indexing keeps singleton spatial axes, unlike np.squeeze
    b0 = data[..., b0_pos]
    # From now on data only holds the dwis, so indexes skip over the b0
    data = np.delete(data, b0_loc, axis=-1)

    indexes = [(i, ) + tuple(neighbors[i]) for i in range(len(neighbors))]

    if subsample:
        indexes = greedy_set_finder(indexes)

    b0_block_size = tuple(block_size[:-1]) + ((block_size[-1] + num_b0s, ))

    denoised_shape = data.shape[:-1] + (data.shape[-1] + num_b0s, )
    data_denoised = np.zeros(denoised_shape, np.float32)

    # Put all idx + b0 in this array in each iteration
    to_denoise = np.empty(data.shape[:-1] + (block_size[-1] + 1, ),
                          dtype=np.float64)

    for i, idx in enumerate(indexes):
        # Map dwi-only positions to positions in the output which still
        # has the b0: everything at or after the b0 moves up by one.
        # This has to be done per element, comparing the two tuples
        # directly would give a single boolean for the whole block.
        idx_array = np.array(idx)
        dwi_idx = tuple(
            np.where(idx_array < b0_pos, idx_array, idx_array + num_b0s))

        logger.info('Now denoising volumes {} / block {} out of {}.'.format(
            idx, i + 1, len(indexes)))

        to_denoise[..., 0] = b0
        to_denoise[..., 1:] = data[..., idx]

        data_denoised[..., b0_loc + dwi_idx] += local_denoise(
            to_denoise,
            b0_block_size,
            overlap,
            variance,
            n_iter=n_iter,
            mask=mask,
            dtype=np.float64,
            n_cores=n_cores,
            verbose=verbose,
            mp_method=mp_method)

    # Each dwi is divided by the number of blocks it was part of, while the
    # b0 was part of every block. minlength covers unused trailing dwis.
    divider = np.bincount(np.array(indexes, dtype=np.int16).ravel(),
                          minlength=data.shape[-1])
    divider = np.insert(divider, b0_loc, len(indexes))

    data_denoised = data_denoised[:orig_shape[0],
                                  :orig_shape[1],
                                  :orig_shape[2],
                                  :orig_shape[3]] / divider

    # Put back the original number of b0s
    if rest_of_b0s is not None:
        b0_denoised = data_denoised[..., b0_pos]
        data_denoised_insert = np.empty(orig_shape, dtype=np.float32)
        n = 0
        for i in range(orig_shape[-1]):
            if i in rest_of_b0s:
                data_denoised_insert[..., i] = b0_denoised
                n += 1
            else:
                data_denoised_insert[..., i] = data_denoised[..., i - n]

        data_denoised = data_denoised_insert

    return data_denoised
def main():
    """Rebuild every matching dataset with a previously learned dictionary.

    Command line arguments, in order: path_data path_D scanner to_scanner
    block_size block_up use_std positivity.

    The tree under path_data is walked to find the files named dwi.nii or
    dwi_fw.nii whose folder contains the source scanner name (``all/st``
    and ``all/sa`` expand to fixed lists of scanners) and the ones whose
    folder contains to_scanner. Each source dataset is passed block by
    block to rebuild with the dictionary loaded from path_D, then saved
    with the affine and header of the paired target dataset. The target
    bvals/bvecs selected by get_indexer are also saved next to the target.
    Datasets whose output file already exists are skipped.

    Exits with status 1 after printing the usage if the number of
    arguments is wrong.
    """
    usage = 'Usage : path_data path_D scanner to_scanner block_size block_up use_std positivity'

    # Exactly 8 arguments are unpacked below, anything else is a usage error
    if len(sys.argv) != 9:
        print(usage)
        sys.exit(1)

    path, path_D, scanner, to_scanner, block_size, block_up, use_std, positivity = sys.argv[1:]

    path = os.path.expanduser(path)
    path_D = os.path.expanduser(path_D)
    block_size = literal_eval(block_size)
    block_up = literal_eval(block_up)
    use_std = use_std.lower() == 'true'
    positivity = positivity.lower() == 'true'

    # Hard-coded options of the reconstruction
    ncores = 100
    bias_correct_std = False
    fix_mean = False
    fit_intercept = True
    center = True
    use_crossval = True

    D = np.load(path_D)
    datasets = []
    to_datasets = []

    # Allow the all keyword to feed in all the scanners
    # if it's a regular scanner, we put it in a list to parse it properly
    # afterwards
    if scanner.lower() == 'all/st':
        scanners = 'GE/st', 'Prisma/st', 'Connectom/st'
    elif scanner.lower() == 'all/sa':
        scanners = 'Prisma/sa', 'Connectom/sa'
    else:
        scanners = [scanner]

    dwi_names = ('dwi.nii', 'dwi_fw.nii')

    for root, dirs, files in os.walk(path):
        dirs.sort()
        root_lower = root.lower()

        for name in files:
            if name not in dwi_names:
                continue

            # NOTE(review): the target is appended once per source scanner,
            # which is the only nesting where the all/* modes can pass the
            # length check below - confirm against the original layout.
            for from_scanner in scanners:
                if from_scanner.lower() in root_lower:
                    datasets += [os.path.join(root, name)]

                if to_scanner.lower() in root_lower:
                    to_datasets += [os.path.join(root, name)]

    if len(datasets) != len(to_datasets):
        raise ValueError('Size mismatch between lists! {}, {}'.format(
            len(datasets), len(to_datasets)))

    for dataset, to_dataset in zip(datasets, to_datasets):
        print('Now rebuilding {}'.format(dataset))

        # Each option adds its own tag right after '_predicted_', so the
        # order of the replacements below decides the final name
        predicted_D = path_D.replace('.npy', '')
        output_filename = dataset.replace(
            '.nii', '_predicted_' + predicted_D + '.nii.gz')
        output_filename = output_filename.replace('.nii.gz', '_recon.nii.gz')

        if center:
            output_filename = output_filename.replace('_predicted_',
                                                      '_predicted_center_')

        if use_crossval:
            output_filename = output_filename.replace(
                '_predicted_', '_predicted_with_cv2_')
        else:
            output_filename = output_filename.replace(
                '_predicted_', '_predicted_with_aic_')

        if use_std:
            output_filename = output_filename.replace(
                '_predicted_', '_predicted_with_std_')

        if positivity:
            output_filename = output_filename.replace('_predicted_',
                                                      '_predicted_pos_')

        if fix_mean:
            output_filename = output_filename.replace(
                '_predicted_', '_predicted_meanfix_')

        if not fit_intercept:
            output_filename = output_filename.replace(
                '_predicted_', '_predicted_no_intercept_')

        if os.path.isfile(output_filename):
            print('File already exists! Skipping {}'.format(output_filename))
            continue

        vol = nib.load(dataset)

        # once we have loaded the data, replace the name if we used a fw
        # dataset since everything else will match
        dataset = dataset.replace('_fw', '')
        to_dataset = to_dataset.replace('_fw', '')

        indexer = get_indexer(dataset)
        # img.get_data() is gone from recent nibabel, reading dataobj
        # gives back the same array with the same dtype
        data = np.asanyarray(vol.dataobj)[..., indexer]

        mask = np.asanyarray(
            nib.load(dataset.replace('.nii', '_mask.nii.gz')).dataobj)
        std = np.asanyarray(
            nib.load(dataset.replace('.nii', '_std.nii.gz')).dataobj)
        bvals = np.loadtxt(dataset.replace('.nii', '.bval'))[indexer]
        bvecs = np.loadtxt(dataset.replace('.nii', '.bvec'))[indexer]

        if center:
            # we need to pull down the 3D volume mean for the upsampling
            # part to make sense though, zeros are left out of the mean
            data_mean = np.nanmean(np.where(data != 0, data, np.nan),
                                   axis=(0, 1, 2),
                                   keepdims=True)
            data -= data_mean

        if use_std:
            if bias_correct_std:
                print('this is now disabled somehow :/')
            variance = std**2
        else:
            std = None
            variance = None

        # number of blocks is without b0, so we +1 the last dimension
        current_block_size = block_size[:-1] + (block_size[-1] + 1, )

        if len(block_up) < data.ndim:
            current_block_up = block_up + (block_size[-1] + 1, )
        else:
            current_block_up = block_up[:-1] + (block_up[-1] + 1, )

        print('Output filename is {}'.format(output_filename))
        print(D.shape, current_block_size, current_block_up)

        factor = np.divide(current_block_up, current_block_size)

        split_b0s = True
        b0_threshold = 20

        b0_loc = np.where(bvals <= b0_threshold)[0]
        dwis = np.where(bvals > b0_threshold)[0]
        num_b0s = len(b0_loc)

        # We also convert bvecs associated with b0s to exactly (0,0,0), which
        # is not always the case when we hack around with the scanner.
        bvecs = np.where(bvals[:, None] <= b0_threshold, 0, bvecs)

        # Average all b0s if we don't split them in the training set
        if num_b0s > 1 and not split_b0s:
            num_b0s = 1
            data[..., b0_loc] = np.mean(data[..., b0_loc],
                                        axis=-1,
                                        keepdims=True)

        # Split the b0s in a cyclic fashion along the training data
        # If we only had one, cycle just return b0_loc indefinitely,
        # else we go through all indexes.
        np.random.shuffle(b0_loc)
        split_b0s_idx = cycle(b0_loc)

        sym_bvecs = np.vstack((bvecs, -bvecs))
        neighbors = angular_neighbors(
            sym_bvecs, current_block_size[-1] - 2) % data.shape[-1]
        # everything was doubled for symmetry
        neighbors = neighbors[:data.shape[-1]]

        dwi_set = set(dwis.tolist())
        full_indexes = [(dwi, ) + tuple(neighbors[dwi])
                        for dwi in range(data.shape[-1]) if dwi in dwi_set]
        indexes = greedy_set_finder(full_indexes)

        # If we have more b0s than indexes, then we have to add a few more
        # blocks since we won't do a full cycle. If we have more b0s than
        # indexes after that, then it breaks.
        if num_b0s > len(indexes):
            the_rest = [rest for rest in full_indexes if rest not in indexes]
            indexes += the_rest[:(num_b0s - len(indexes))]

        if num_b0s > len(indexes):
            error = (
                'Seems like you still have more b0s {} than available blocks {},'
                ' either average them or deactivate subsampling.'.format(
                    num_b0s, len(indexes)))
            raise ValueError(error)

        # Only the three spatial axes are upsampled, we keep the same
        # number of volumes
        predicted_size = (int(data.shape[0] * factor[0]),
                          int(data.shape[1] * factor[1]),
                          int(data.shape[2] * factor[2]),
                          data.shape[-1])
        predicted = np.zeros(predicted_size, dtype=np.float32)
        # Number of times each volume was rebuilt, used to average them
        divider = np.zeros(predicted.shape[-1])

        print(data.shape, predicted.shape, factor)

        # Put all idx + b0 in this array in each iteration
        to_denoise = np.empty(data.shape[:-1] + (current_block_size[-1], ),
                              dtype=np.float64)
        print(to_denoise.shape)

        for i, idx in enumerate(indexes, start=1):
            b0_idx = next(split_b0s_idx)
            volumes = (b0_idx, ) + idx

            # Integer indexing keeps singleton spatial axes, unlike squeeze()
            to_denoise[..., 0] = data[..., b0_idx]
            to_denoise[..., 1:] = data[..., idx]
            divider[list(volumes)] += 1

            print('Now denoising volumes {} / block {} out of {}.'.format(
                volumes, i, len(indexes)))

            predicted[..., volumes] += rebuild(
                to_denoise,
                mask,
                D,
                block_size=current_block_size,
                block_up=current_block_up,
                ncores=ncores,
                positivity=positivity,
                fix_mean=fix_mean,
                fit_intercept=fit_intercept,
                use_crossval=use_crossval,
                variance=variance)

        predicted /= divider

        if center:
            predicted += data_mean
            # el cheapo mask after upsampling: whatever is exactly the mean
            # was zero before we added it back
            predicted[predicted == data_mean] = 0

        # clip negatives, which happens at the borders
        predicted.clip(min=0., out=predicted)

        # header voxel size is all screwed up, so replace with the
        # destination header and affine by the matching dataset
        to_vol = nib.load(to_dataset)
        imgfile = nib.Nifti1Image(predicted, to_vol.affine, to_vol.header)
        nib.save(imgfile, output_filename)

        # subsample to common bvals/bvecs/tr/te, only the gradient tables
        # are written so the target volume itself does not need to be read
        to_indexer = get_indexer(to_dataset)
        to_bvals = np.loadtxt(to_dataset.replace('.nii', '.bval'))[to_indexer]
        to_bvecs = np.loadtxt(to_dataset.replace('.nii', '.bvec'))[to_indexer]

        np.savetxt(to_dataset.replace('dwi.nii', 'dwi_subsample.bval'),
                   to_bvals,
                   fmt='%1.4f')
        np.savetxt(to_dataset.replace('dwi.nii', 'dwi_subsample.bvec'),
                   to_bvecs,
                   fmt='%1.4f')

        # Free the large arrays before moving on to the next dataset
        del data, predicted, vol, to_vol, imgfile, mask, variance, std
        del to_denoise, divider