Example 1
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.input_sh)
    assert_outputs_exist(parser, args, args.output_name)

    input_basis = args.sh_basis
    output_basis = 'descoteaux07' if input_basis == 'tournier07' else 'tournier07'

    sph_harm_basis_ori = sph_harm_lookup.get(input_basis)
    sph_harm_basis_des = sph_harm_lookup.get(output_basis)

    sphere = get_sphere('repulsion724').subdivide(1)
    img = nib.load(args.input_sh)
    # get_data() is deprecated in nibabel; read the array via dataobj
    data = np.asanyarray(img.dataobj)
    sh_order = find_order_from_nb_coeff(data)

    b_ori, m_ori, n_ori = sph_harm_basis_ori(sh_order, sphere.theta,
                                             sphere.phi)
    b_des, m_des, n_des = sph_harm_basis_des(sh_order, sphere.theta,
                                             sphere.phi)
    l_des = -n_des * (n_des + 1)
    inv_b_des = smooth_pinv(b_des, 0 * l_des)

    indices = np.argwhere(np.any(data, axis=3))
    for i, ind in enumerate(indices):
        ind = tuple(ind)
        sf_1 = np.dot(data[ind], b_ori.T)
        data[ind] = np.dot(sf_1, inv_b_des.T)

    nib.save(nib.Nifti1Image(data, img.affine, header=img.header),
             args.output_name)
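Example 1 converts every nonzero voxel between SH bases by sampling onto a sphere in the source basis and refitting in the target basis. A minimal sketch of just that step, assuming dipy's sph_harm_lookup and smooth_pinv (both from dipy.reconst.shm) and a hypothetical 1-D coefficient vector coeffs:

import numpy as np
from dipy.data import get_sphere
from dipy.reconst.shm import smooth_pinv, sph_harm_lookup

order = 8  # assumed SH order
sphere = get_sphere('repulsion724').subdivide(1)
b_ori, _, _ = sph_harm_lookup['tournier07'](order, sphere.theta, sphere.phi)
b_des, _, n_des = sph_harm_lookup['descoteaux07'](order, sphere.theta,
                                                  sphere.phi)
inv_b_des = smooth_pinv(b_des, 0 * n_des)  # zero regularization, as above
# coeffs_des = coeffs @ b_ori.T @ inv_b_des.T  # SH -> SF -> SH round trip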
Example 2
def get_ventricles_max_fodf(data, fa, md, zoom, args):
    order = find_order_from_nb_coeff(data)
    sphere = get_sphere('repulsion100')
    b_matrix = get_b_matrix(order, sphere, args.sh_basis)
    sum_of_max = 0
    count = 0

    mask = np.zeros(data.shape[:-1])

    if np.min(data.shape[:-1]) > 40:
        step = 20
    elif np.min(data.shape[:-1]) > 20:
        step = 10
    else:
        step = 5

    # 1000 works well at 2x2x2 = 8 mm^3
    # Hence, we multiply by the volume of a voxel
    vol = (zoom[0] * zoom[1] * zoom[2])
    if vol != 0:
        # vol is a float, so plain division replaces the old_div helper
        max_number_of_voxels = 1000 * 8 / vol
    else:
        max_number_of_voxels = 1000

    all_i = list(
        range(int(data.shape[0] / 2) - step,
              int(data.shape[0] / 2) + step))
    all_j = list(
        range(int(data.shape[1] / 2) - step,
              int(data.shape[1] / 2) + step))
    all_k = list(
        range(int(data.shape[2] / 2) - step,
              int(data.shape[2] / 2) + step))
    for i in all_i:
        for j in all_j:
            for k in all_k:
                if count > max_number_of_voxels - 1:
                    continue
                if fa[i, j, k] < args.fa_threshold \
                        and md[i, j, k] > args.md_threshold:
                    sf = np.dot(data[i, j, k], b_matrix.T)
                    sum_of_max += sf.max()
                    count += 1
                    mask[i, j, k] = 1

    logging.debug('Number of voxels detected: {}'.format(count))
    if count == 0:
        logging.warning('No voxels found for evaluation! Change your fa '
                        'and/or md thresholds')
        return 0, mask

    logging.debug('Average max fodf value: {}'.format(sum_of_max / count))
    return sum_of_max / count, mask
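The voxel cap in Example 2 is plain arithmetic: 1000 voxels were found to work well at 8 mm^3 (2x2x2 mm), so the cap scales inversely with voxel volume. A worked example at an assumed 1.25 mm isotropic resolution:

zoom = (1.25, 1.25, 1.25)               # assumed voxel size in mm
vol = zoom[0] * zoom[1] * zoom[2]       # ~1.95 mm^3
max_number_of_voxels = 1000 * 8 / vol   # ~4096 voxels examined at most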
Example 3
def _get_direction_getter(args, mask_data):
    sh_data = nib.load(args.sh_file).get_fdata(dtype=np.float64)
    sphere = HemiSphere.from_sphere(get_sphere(args.sphere))
    theta = get_theta(args.theta, args.algo)

    if args.algo in ['det', 'prob']:
        if args.algo == 'det':
            dg_class = DeterministicMaximumDirectionGetter
        else:
            dg_class = ProbabilisticDirectionGetter
        return dg_class.from_shcoeff(shcoeff=sh_data,
                                     max_angle=theta,
                                     sphere=sphere,
                                     basis_type=args.sh_basis,
                                     relative_peak_threshold=args.sf_threshold)

    # Code for type EUDX. We don't use peaks_from_model
    # because we want the peaks from the provided sh.
    sh_shape_3d = sh_data.shape[:-1]
    npeaks = 5
    peak_dirs = np.zeros((sh_shape_3d + (npeaks, 3)))
    peak_values = np.zeros((sh_shape_3d + (npeaks, )))
    peak_indices = np.full((sh_shape_3d + (npeaks, )), -1, dtype='int')
    b_matrix = get_b_matrix(find_order_from_nb_coeff(sh_data), sphere,
                            args.sh_basis)

    for idx in np.ndindex(sh_shape_3d):
        if not mask_data[idx]:
            continue

        directions, values, indices = get_maximas(sh_data[idx], sphere,
                                                  b_matrix, args.sf_threshold,
                                                  0)
        if values.shape[0] != 0:
            n = min(npeaks, values.shape[0])
            peak_dirs[idx][:n] = directions[:n]
            peak_values[idx][:n] = values[:n]
            peak_indices[idx][:n] = indices[:n]

    dg = PeaksAndMetrics()
    dg.sphere = sphere
    dg.peak_dirs = peak_dirs
    dg.peak_values = peak_values
    dg.peak_indices = peak_indices
    dg.ang_thr = theta
    dg.qa_thr = args.sf_threshold
    return dg
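Whichever branch runs, the returned object behaves as a dipy direction getter. A hedged usage sketch, assuming a recent dipy and that stopping_criterion, seeds and affine come from the surrounding script:

from dipy.tracking.local_tracking import LocalTracking

# dg = _get_direction_getter(args, mask_data)
# streamlines = LocalTracking(dg, stopping_criterion, seeds, affine,
#                             step_size=0.5)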
Example 4
def get_maps(data, mask, args, npeaks=5):
    nufo_map = np.zeros(data.shape[0:3])
    afd_map = np.zeros(data.shape[0:3])
    afd_sum = np.zeros(data.shape[0:3])

    peaks_dirs = np.zeros(list(data.shape[0:3]) + [npeaks, 3])
    order = find_order_from_nb_coeff(data)
    sphere = get_sphere(args.sphere)
    b_matrix = get_b_matrix(order, sphere, args.sh_basis)

    for index in ndindex(data.shape[:-1]):
        if mask[index]:
            if np.isnan(data[index]).any():
                nufo_map[index] = 0
                afd_map[index] = 0
            else:
                maximas, afd, _ = get_maximas(
                    data[index], sphere, b_matrix, args.r_threshold, args.at)
                # sf = np.dot(data[index], B.T)

                n = min(npeaks, maximas.shape[0])
                nufo_map[index] = maximas.shape[0]
                if n == 0:
                    afd_map[index] = 0.0
                    nufo_map[index] = 0.0
                else:
                    afd_map[index] = afd.max()
                    peaks_dirs[index][:n] = maximas[:n]

                    # sum of all coefficients, sqrt(power spectrum)
                    # sum C^2 = sum fODF^2
                    afd_sum[index] = np.sqrt(np.dot(data[index], data[index]))

                    # sum of all peaks contributions to the afd
                    # integral of all the lobes. Numerical sum.
                    # With an infinite number of SH, this should == to afd_sum
                    # sf[np.nonzero(sf < args.at)] = 0.
                    # afd_sum[index] = sf.sum()/n*4*np.pi/B.shape[0]

    return nufo_map, afd_map, afd_sum, peaks_dirs
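The afd_sum line uses Parseval's relation: for an orthonormal real SH basis, the integral of the squared fODF over the sphere equals the sum of squared coefficients, so sqrt(c . c) measures the total fODF energy without any sphere sampling. A quick numerical check of that identity, assuming coeffs plays the role of data[index] above:

sf = np.dot(coeffs, b_matrix.T)
numerical = np.sum(sf ** 2) * 4 * np.pi / b_matrix.shape[0]  # quadrature sum
analytic = np.dot(coeffs, coeffs)                            # sum of C^2
# numerical ~= analytic, up to quadrature error on the sphere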
Example 5
def afd_and_rd_sums_along_streamlines(streamlines, fodf_data, fodf_basis,
                                      jump):
    order = find_order_from_nb_coeff(fodf_data)
    sphere = get_repulsion200_sphere()
    b_matrix, _, n = get_b_matrix(order, sphere, fodf_basis, return_all=True)
    legendre0_at_n = lpn(order, 0)[0][n]
    sphere_norm = np.linalg.norm(sphere.vertices)
    if sphere_norm == 0:
        raise ValueError(
            "Norm of {} triangulated sphere is 0.".format('repulsion200'))

    afd_sum_map = np.zeros(shape=fodf_data.shape[:-1])
    rd_sum_map = np.zeros(shape=fodf_data.shape[:-1])
    count_map = np.zeros(shape=fodf_data.shape[:-1])
    for streamline in streamlines:
        for point_idx, (p0, p1) in enumerate(pairwise(streamline)):
            if point_idx % jump != 0:
                continue

            closest_vertex_idx = _nearest_neighbor_idx_on_sphere(
                p1 - p0, sphere, sphere_norm)
            if closest_vertex_idx == -1:
                # Points were identical so skip them
                continue

            vox_idx = _get_nearest_voxel_index(p0, p1)

            b_at_idx = b_matrix[closest_vertex_idx]
            fodf_at_index = fodf_data[vox_idx]

            afd_val = np.dot(b_at_idx, fodf_at_index)

            p_matrix = np.eye(fodf_at_index.shape[0]) * legendre0_at_n
            rd_val = np.dot(np.dot(b_at_idx.T, p_matrix), fodf_at_index)

            afd_sum_map[vox_idx] += afd_val
            rd_sum_map[vox_idx] += rd_val
            count_map[vox_idx] += 1

    return afd_sum_map, rd_sum_map, count_map
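pairwise() above is presumably the standard itertools recipe, yielding consecutive point pairs along each streamline; a sketch for completeness:

from itertools import tee

def pairwise(iterable):
    # s -> (s0, s1), (s1, s2), (s2, s3), ...
    a, b = tee(iterable)
    next(b, None)
    return zip(a, b)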
Example 6
def main():
    logging.basicConfig(level=logging.INFO)
    parser = _build_arg_parser()
    args = parser.parse_args()

    required = [args.bundle_filename, args.fod_filename, args.mask_filename]
    assert_inputs_exist(parser, required)

    out_efod = os.path.join(args.output_dir,
                            '{0}efod.nii.gz'.format(args.output_prefix))
    out_priors = os.path.join(args.output_dir,
                              '{0}priors.nii.gz'.format(args.output_prefix))
    out_todi_mask = os.path.join(
        args.output_dir, '{0}todi_mask.nii.gz'.format(args.output_prefix))
    out_endpoints_mask = os.path.join(
        args.output_dir, '{0}endpoints_mask.nii.gz'.format(args.output_prefix))
    required = [out_efod, out_priors, out_todi_mask, out_endpoints_mask]
    assert_outputs_exist(parser, args, required)

    if args.output_dir and not os.path.isdir(args.output_dir):
        os.mkdir(args.output_dir)

    img_sh = nib.load(args.fod_filename)
    sh_shape = img_sh.shape
    sh_order = find_order_from_nb_coeff(sh_shape)
    img_mask = nib.load(args.mask_filename)

    sft = load_tractogram(args.bundle_filename,
                          args.fod_filename,
                          trk_header_check=True)
    sft.to_vox()
    streamlines = sft.streamlines
    if len(streamlines) < 1:
        raise ValueError('The input bundle contains no streamline.')

    # Compute TODI from streamlines
    with TrackOrientationDensityImaging(img_mask.shape,
                                        'repulsion724') as todi_obj:
        todi_obj.compute_todi(streamlines, length_weights=True)
        todi_obj.smooth_todi_dir()
        todi_obj.smooth_todi_spatial(sigma=args.todi_sigma)

        # Fancy masking of 1d indices to limit spatial dilation to WM
        sub_mask_3d = np.logical_and(
            np.asanyarray(img_mask.dataobj),
            todi_obj.reshape_to_3d(todi_obj.get_mask()))
        sub_mask_1d = sub_mask_3d.flatten()[todi_obj.get_mask()]
        todi_sf = todi_obj.get_todi()[sub_mask_1d]**2

    # The priors should always be between 0 and 1
    # A minimum threshold is set to prevent misaligned FOD from disappearing
    todi_sf /= np.max(todi_sf, axis=-1, keepdims=True)
    todi_sf[todi_sf < args.sf_threshold] = args.sf_threshold

    # Memory friendly saving, as soon as possible saving then delete
    priors_3d = np.zeros(sh_shape)
    sphere = get_sphere('repulsion724')
    priors_3d[sub_mask_3d] = sf_to_sh(todi_sf,
                                      sphere,
                                      sh_order=sh_order,
                                      basis_type=args.sh_basis)
    nib.save(nib.Nifti1Image(priors_3d, img_mask.affine), out_priors)
    del priors_3d

    input_sh_3d = np.asanyarray(img_sh.dataobj).astype(np.float64)
    input_sf_1d = sh_to_sf(input_sh_3d[sub_mask_3d],
                           sphere,
                           sh_order=sh_order,
                           basis_type=args.sh_basis)

    # Creation of the enhanced-FOD (direction-wise multiplication)
    mult_sf_1d = input_sf_1d * todi_sf
    del todi_sf

    input_max_value = np.max(input_sf_1d, axis=-1, keepdims=True)
    mult_max_value = np.max(mult_sf_1d, axis=-1, keepdims=True)
    mult_positive_mask = np.squeeze(mult_max_value) > 0.0
    mult_sf_1d[mult_positive_mask] = mult_sf_1d[mult_positive_mask] * \
        input_max_value[mult_positive_mask] / \
        mult_max_value[mult_positive_mask]

    # Memory friendly saving
    input_sh_3d[sub_mask_3d] = sf_to_sh(mult_sf_1d,
                                        sphere,
                                        sh_order=sh_order,
                                        basis_type=args.sh_basis)
    nib.save(nib.Nifti1Image(input_sh_3d, img_mask.affine), out_efod)
    del input_sh_3d

    nib.save(nib.Nifti1Image(sub_mask_3d.astype(np.int16), img_mask.affine),
             out_todi_mask)

    # Read the mask once instead of calling the deprecated get_data()
    # on every iteration
    mask_data = np.asanyarray(img_mask.dataobj)
    endpoints_mask = np.zeros(img_mask.shape, dtype=np.int16)
    for streamline in streamlines:
        if mask_data[tuple(streamline[0].astype(np.int16))]:
            endpoints_mask[tuple(streamline[0].astype(np.int16))] = 1
            endpoints_mask[tuple(streamline[-1].astype(np.int16))] = 1
    nib.save(nib.Nifti1Image(endpoints_mask, img_mask.affine),
             out_endpoints_mask)
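The enhanced-FOD step multiplies the input and the prior direction-wise, then restores each voxel's original peak amplitude, so the priors reshape the fODF without rescaling it. In one voxel, with illustrative values:

input_sf = np.array([0.20, 0.50, 0.10])   # input fODF samples
prior = np.array([1.00, 0.40, 0.90])      # TODI prior, within [0, 1]
mult = input_sf * prior                   # direction-wise multiplication
mult *= input_sf.max() / mult.max()       # peak restored to 0.50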
Example 7
def afd_and_rd_sums_along_streamlines(sft, fodf, fodf_basis, length_weighting):
    """
    Compute the mean Apparent Fiber Density (AFD) and mean Radial fODF (radfODF)
    maps along a bundle.

    Parameters
    ----------
    sft : StatefulTractogram
        StatefulTractogram containing the streamlines needed.
    fodf : nibabel.image
        fODF with shape (X, Y, Z, #coeffs).
        #coeffs depends on the sh_order.
    fodf_basis : string
        Has to be descoteaux07 or tournier07.
    length_weighting : bool
        If set, will weigh the AFD values according to segment lengths.

    Returns
    -------
    afd_sum_map : np.array
        AFD map.
    rd_sum_map : np.array
        radfODF map.
    weight_map : np.array
        Segment lengths.
    """

    sft.to_vox()
    sft.to_corner()

    fodf_data = np.asanyarray(fodf.dataobj)
    order = find_order_from_nb_coeff(fodf_data)
    sphere = get_sphere('repulsion724')
    b_matrix, _, n = get_b_matrix(order, sphere, fodf_basis, return_all=True)
    legendre0_at_n = lpn(order, 0)[0][n]
    sphere_norm = np.linalg.norm(sphere.vertices)

    afd_sum_map = np.zeros(shape=fodf_data.shape[:-1])
    rd_sum_map = np.zeros(shape=fodf_data.shape[:-1])
    weight_map = np.zeros(shape=fodf_data.shape[:-1])

    p_matrix = np.eye(fodf_data.shape[3]) * legendre0_at_n
    all_crossed_indices = grid_intersections(sft.streamlines)
    for crossed_indices in all_crossed_indices:
        segments = crossed_indices[1:] - crossed_indices[:-1]
        seg_lengths = np.linalg.norm(segments, axis=1)

        # Remove points where the segment is zero.
        # This removes numpy warnings of division by zero.
        non_zero_lengths = np.nonzero(seg_lengths)[0]
        segments = segments[non_zero_lengths]
        seg_lengths = seg_lengths[non_zero_lengths]

        # Find the sphere vertex closest to each segment's direction
        cos_angles = np.dot(segments, sphere.vertices.T) \
            / (seg_lengths[:, None] * sphere_norm)
        angles = np.arccos(cos_angles)
        closest_vertex_indices = np.argmin(angles, axis=1)

        # Those starting points are used for the segment vox_idx computations
        strl_start = crossed_indices[non_zero_lengths]
        vox_indices = (strl_start + (0.5 * segments)).astype(int)

        normalization_weights = np.ones_like(seg_lengths)
        if length_weighting:
            normalization_weights = seg_lengths / np.linalg.norm(
                fodf.header.get_zooms()[:3])

        for vox_idx, closest_vertex_index, norm_weight in zip(
                vox_indices, closest_vertex_indices, normalization_weights):
            vox_idx = tuple(vox_idx)
            b_at_idx = b_matrix[closest_vertex_index]
            fodf_at_index = fodf_data[vox_idx]

            afd_val = np.dot(b_at_idx, fodf_at_index)
            rd_val = np.dot(np.dot(b_at_idx.T, p_matrix), fodf_at_index)

            afd_sum_map[vox_idx] += afd_val * norm_weight
            rd_sum_map[vox_idx] += rd_val * norm_weight
            weight_map[vox_idx] += norm_weight

    rd_sum_map[rd_sum_map < 0.] = 0.

    return afd_sum_map, rd_sum_map, weight_map
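With length_weighting enabled, each segment contributes in proportion to its length relative to the voxel diagonal, so weight_map approximates the streamline length crossing each voxel. A small numerical illustration at an assumed 2 mm isotropic resolution:

zooms = (2.0, 2.0, 2.0)               # assumed voxel size in mm
voxel_diag = np.linalg.norm(zooms)    # ~3.46 mm
norm_weight = 1.0 / voxel_diag        # weight of a 1 mm segment, ~0.29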
Example 8
def _get_direction_getter(args):
    odf_data = nib.load(args.in_odf).get_fdata(dtype=np.float32)
    sphere = HemiSphere.from_sphere(get_sphere(args.sphere))
    theta = get_theta(args.theta, args.algo)

    non_zeros_count = np.count_nonzero(np.sum(odf_data, axis=-1))
    non_first_val_count = np.count_nonzero(np.argmax(odf_data, axis=-1))

    if args.algo in ['det', 'prob']:
        if non_first_val_count / non_zeros_count > 0.5:
            logging.warning('Input detected as peaks. Input should be '
                            'fodf for det/prob, verify input just in case.')
        if args.algo == 'det':
            dg_class = DeterministicMaximumDirectionGetter
        else:
            dg_class = ProbabilisticDirectionGetter
        return dg_class.from_shcoeff(
            shcoeff=odf_data, max_angle=theta, sphere=sphere,
            basis_type=args.sh_basis,
            relative_peak_threshold=args.sf_threshold)
    elif args.algo == 'eudx':
        # Code for type EUDX. We don't use peaks_from_model
        # because we want the peaks from the provided sh.
        odf_shape_3d = odf_data.shape[:-1]
        dg = PeaksAndMetrics()
        dg.sphere = sphere
        dg.ang_thr = theta
        dg.qa_thr = args.sf_threshold

        # Heuristic to find out if the input are peaks or fodf
        # fodf are always around 0.15 and peaks around 0.75
        if non_first_val_count / non_zeros_count > 0.5:
            logging.info('Input detected as peaks.')
            nb_peaks = odf_data.shape[-1] // 3
            # One (x, y, z) triplet per peak; avoids the hardcoded 15,
            # which only held for 5 peaks
            slices = np.arange(0, 3 * nb_peaks + 1, 3)
            peak_values = np.zeros(odf_shape_3d+(nb_peaks,))
            peak_indices = np.zeros(odf_shape_3d+(nb_peaks,))

            for idx in np.argwhere(np.sum(odf_data, axis=-1)):
                idx = tuple(idx)
                for i in range(nb_peaks):
                    peak_values[idx][i] = np.linalg.norm(
                        odf_data[idx][slices[i]:slices[i+1]], axis=-1)
                    peak_indices[idx][i] = sphere.find_closest(
                        odf_data[idx][slices[i]:slices[i+1]])

            dg.peak_dirs = odf_data
        else:
            logging.info('Input detected as fodf.')
            npeaks = 5
            peak_dirs = np.zeros((odf_shape_3d + (npeaks, 3)))
            peak_values = np.zeros((odf_shape_3d + (npeaks, )))
            peak_indices = np.full((odf_shape_3d + (npeaks, )), -1, dtype='int')
            b_matrix = get_b_matrix(
                find_order_from_nb_coeff(odf_data), sphere, args.sh_basis)

            for idx in np.argwhere(np.sum(odf_data, axis=-1)):
                idx = tuple(idx)
                directions, values, indices = get_maximas(odf_data[idx],
                                                          sphere, b_matrix,
                                                          args.sf_threshold, 0)
                if values.shape[0] != 0:
                    n = min(npeaks, values.shape[0])
                    peak_dirs[idx][:n] = directions[:n]
                    peak_values[idx][:n] = values[:n]
                    peak_indices[idx][:n] = indices[:n]

            dg.peak_dirs = peak_dirs

        dg.peak_values = peak_values
        dg.peak_indices = peak_indices

        return dg
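The peaks-versus-fODF heuristic rests on the SH layout: for fODF coefficients the first (isotropic) term usually dominates, so argmax along the last axis is 0 in most voxels, whereas stacked peak vectors show no such bias. The test in isolation:

nonzero = np.count_nonzero(np.sum(odf_data, axis=-1))
not_first = np.count_nonzero(np.argmax(odf_data, axis=-1))
looks_like_peaks = not_first / nonzero > 0.5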
Example 9
    def track(self):
        """
        GPU streamlines generator yielding streamlines with corresponding
        seed positions one by one.
        """
        t0 = perf_counter()

        # Load the sphere
        sphere = get_sphere('symmetric724')

        # Convert theta to cos(theta)
        max_cos_theta = np.cos(np.deg2rad(self.theta))

        cl_kernel = CLKernel('track', 'tracking', 'local_tracking.cl')

        # Set tracking parameters
        # TODO: Add relative sf_threshold parameter.
        cl_kernel.set_define('IM_X_DIM', self.sh.shape[0])
        cl_kernel.set_define('IM_Y_DIM', self.sh.shape[1])
        cl_kernel.set_define('IM_Z_DIM', self.sh.shape[2])
        cl_kernel.set_define('IM_N_COEFFS', self.sh.shape[3])
        cl_kernel.set_define('N_DIRS', len(sphere.vertices))

        cl_kernel.set_define('N_THETAS', len(self.theta))
        cl_kernel.set_define('STEP_SIZE', '{}f'.format(self.step_size))
        cl_kernel.set_define('MAX_LENGTH', self.max_strl_points)
        cl_kernel.set_define('FORWARD_ONLY',
                             'true' if self.forward_only else 'false')

        # Create CL program
        n_input_params = 7
        n_output_params = 2
        cl_manager = CLManager(cl_kernel, n_input_params, n_output_params)

        # Input buffers
        # Constant input buffers
        cl_manager.add_input_buffer(0, self.sh)
        cl_manager.add_input_buffer(1, sphere.vertices)

        sh_order = find_order_from_nb_coeff(self.sh)
        B_mat = sh_to_sf_matrix(sphere,
                                sh_order,
                                self.sh_basis,
                                return_inv=False)
        cl_manager.add_input_buffer(2, B_mat)
        cl_manager.add_input_buffer(3, self.mask.astype(np.float32))

        cl_manager.add_input_buffer(6, max_cos_theta)

        logging.debug(
            'Initialized OpenCL program in {:.2f}s.'.format(perf_counter() -
                                                            t0))

        # Generate streamlines in batches
        t0 = perf_counter()
        nb_processed_streamlines = 0
        nb_valid_streamlines = 0
        for seed_batch in self.seed_batches:
            # Generate random values for sf sampling
            # TODO: Implement random number generator directly
            #       on the GPU to generate values on-the-fly.
            rand_vals = self.rng.uniform(
                0.0, 1.0, (len(seed_batch), self.max_strl_points))

            # Update buffers
            cl_manager.add_input_buffer(4, seed_batch)
            cl_manager.add_input_buffer(5, rand_vals)

            # output streamlines buffer
            cl_manager.add_output_buffer(
                0, (len(seed_batch), self.max_strl_points, 3))
            # output streamlines length buffer
            cl_manager.add_output_buffer(1, (len(seed_batch), 1))

            # Run the kernel
            tracks, n_points = cl_manager.run((len(seed_batch), 1, 1))
            n_points = n_points.squeeze().astype(np.int16)
            for (strl, seed, n_pts) in zip(tracks, seed_batch, n_points):
                if n_pts >= self.min_strl_points:
                    strl = strl[:n_pts]
                    nb_valid_streamlines += 1

                    # output is yielded so that we can use lazy tractogram.
                    yield strl, seed

            # per-batch logging information
            nb_processed_streamlines += len(seed_batch)
            logging.info('{0:>8}/{1} streamlines generated'.format(
                nb_processed_streamlines, self.n_seeds))

        logging.info('Tracked {0} streamlines in {1:.2f}s.'.format(
            nb_valid_streamlines,
            perf_counter() - t0))
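Because track() yields each (streamline, seed) pair as soon as its batch finishes, the caller can stream results to disk instead of holding every streamline in memory. A hedged consumption sketch (the tracker construction is hypothetical, not shown in the source):

# tracker = GPUTracker(...)  # hypothetical constructor
for streamline, seed in tracker.track():
    pass  # e.g. hand off to a lazy tractogram writer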