def main():
    parser = _build_arg_parser()
    args = parser.parse_args()
    assert_inputs_exist(parser, args.in_bundles)
    assert_outputs_exist(parser, args, args.out_json)

    nbr_cpu = validate_nbr_processes(parser, args)
    pool = multiprocessing.Pool(nbr_cpu)
    bundles_references_tuple_extended = link_bundles_and_reference(
        parser, args, args.in_bundles)
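    # One pool task per (bundle, reference) tuple; compute_measures returns a
    # dict of measures, or None for an empty bundle (handled below).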
    all_measures_dict = pool.map(compute_measures,
                                 bundles_references_tuple_extended)
    pool.close()
    pool.join()

    output_measures_dict = {}
    for measure_dict in all_measures_dict:
        # Empty bundle should not make the script crash
        if measure_dict is not None:
            for measure_name in measure_dict.keys():
                # Create an empty list first
                if measure_name not in output_measures_dict:
                    output_measures_dict[measure_name] = []
                output_measures_dict[measure_name].append(
                    measure_dict[measure_name])

    with open(args.out_json, 'w') as outfile:
        json.dump(output_measures_dict,
                  outfile,
                  indent=args.indent,
                  sort_keys=args.sort_keys)
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_hdf5)
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       create_dir=True)

    keys = []
    for filename in args.in_hdf5:
        curr_file = h5py.File(filename, 'r')
        keys.extend(curr_file.keys())
        curr_file.close()

    nbr_cpu = validate_nbr_processes(parser, args, args.nbr_processes)
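    # Process each HDF5 key sequentially when a single process is requested;
    # otherwise one pool task per key, with the constant arguments repeated.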
    if nbr_cpu == 1:
        for key in keys:
            _average_wrapper([args.in_hdf5, key, args.binary, args.out_dir])
    else:
        pool = multiprocessing.Pool(nbr_cpu)
        _ = pool.map(_average_wrapper,
                     zip(itertools.repeat(args.in_hdf5),
                         keys,
                         itertools.repeat(args.binary),
                         itertools.repeat(args.out_dir)))
        pool.close()
        pool.join()
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    assert_inputs_exist(parser, args.in_sh)
    assert_outputs_exist(parser, args, args.out_bingham)

    sh_im = nib.load(args.in_sh)
    data = sh_im.get_fdata()

    # validate number of processes
    nbr_processes = validate_nbr_processes(parser, args)
    logging.info('Number of processes: {}'.format(nbr_processes))

    t0 = time.perf_counter()
    logging.info('Fitting Bingham functions.')
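    # Fit up to --max_lobes Bingham lobes per voxel from the SH image; the
    # absolute/relative thresholds and the minimal separation angle control
    # which lobes are kept (see bingham_fit_sh).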
    bingham = bingham_fit_sh(data,
                             args.max_lobes,
                             abs_th=args.at,
                             rel_th=args.rt,
                             min_sep_angle=args.min_sep_angle,
                             max_fit_angle=args.max_fit_angle,
                             nbr_processes=nbr_processes)
    t1 = time.perf_counter()
    logging.info('Fitting done in (s): {0}'.format(t1 - t0))
    nib.save(nib.Nifti1Image(bingham, sh_im.affine), args.out_bingham)
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_hdf5, args.in_fodf])
    assert_outputs_exist(parser, args, [args.out_hdf5])

    nbr_cpu = validate_nbr_processes(parser, args, args.nbr_processes)

    # Remove the output if it already exists; HDF5 will not overwrite the file
    if os.path.isfile(args.out_hdf5):
        os.remove(args.out_hdf5)

    fodf_img = nib.load(args.in_fodf)
    in_hdf5_file = h5py.File(args.in_hdf5, 'r')
    if not (np.allclose(
            in_hdf5_file.attrs['affine'], fodf_img.affine, atol=1e-03)
            and np.array_equal(in_hdf5_file.attrs['dimensions'],
                               fodf_img.shape[0:3])):
        parser.error('{} does not have a compatible header with {}'.format(
            args.in_hdf5, args.in_fodf))

    keys = list(in_hdf5_file.keys())
    in_hdf5_file.close()

    if nbr_cpu == 1:
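        # One call per HDF5 key (one streamline group per key); same work as
        # the pool branch below.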
        results_list = []
        for key in keys:
            results_list.append(
                _afd_rd_wrapper([
                    args.in_hdf5, key, fodf_img, args.sh_basis,
                    args.length_weighting
                ]))

    else:
        pool = multiprocessing.Pool(nbr_cpu)
        results_list = pool.map(
            _afd_rd_wrapper,
            zip(itertools.repeat(args.in_hdf5), keys,
                itertools.repeat(fodf_img), itertools.repeat(args.sh_basis),
                itertools.repeat(args.length_weighting)))
        pool.close()
        pool.join()

    shutil.copy(args.in_hdf5, args.out_hdf5)
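    # Replace any existing afd_fixel / rd_fixel datasets in the copied file;
    # the rest of the HDF5 structure is preserved.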
    with h5py.File(args.out_hdf5, 'a') as out_hdf5_file:
        for key, afd_fixel, rd_fixel in results_list:
            group = out_hdf5_file[key]
            if 'afd_fixel' in group:
                del group['afd_fixel']
            group.create_dataset('afd_fixel', data=afd_fixel)
            if 'rd_fixel' in group:
                del group['rd_fixel']
            group.create_dataset('rd_fixel', data=rd_fixel)
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    if not args.not_all:
        args.out_fd = args.out_fd or 'fd.nii.gz'
        args.out_fs = args.out_fs or 'fs.nii.gz'
        args.out_ff = args.out_ff or 'ff.nii.gz'

    arglist = [args.out_fd, args.out_fs, args.out_ff]
    if args.not_all and not any(arglist):
        parser.error('At least one output file must be specified.')

    outputs = [args.out_fd, args.out_fs, args.out_ff]
    assert_inputs_exist(parser, args.in_bingham, args.mask)
    assert_outputs_exist(parser, args, outputs)

    bingham_im = nib.load(args.in_bingham)
    bingham = bingham_im.get_fdata()
    mask = get_data_as_mask(nib.load(args.mask), dtype=bool)\
        if args.mask else None

    nbr_processes = validate_nbr_processes(parser, args)

    t0 = time.perf_counter()
    logging.info('Computing fiber density.')
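    # FD is always computed: FS (fiber spread) and FF (fiber fraction) below
    # are both derived from it.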
    fd = compute_fiber_density(bingham, m=args.nbr_integration_steps,
                               mask=mask, nbr_processes=nbr_processes)
    t1 = time.perf_counter()
    logging.info('FD computed in (s): {0}'.format(t1 - t0))
    if args.out_fd:
        nib.save(nib.Nifti1Image(fd, bingham_im.affine), args.out_fd)

    if args.out_fs:
        t0 = time.perf_counter()
        logging.info('Computing fiber spread.')
        fs = compute_fiber_spread(bingham, fd)
        t1 = time.perf_counter()
        logging.info('FS computed in (s): {0}'.format(t1 - t0))
        nib.save(nib.Nifti1Image(fs, bingham_im.affine), args.out_fs)

    if args.out_ff:
        t0 = time.perf_counter()
        logging.info('Computing fiber fraction.')
        ff = compute_fiber_fraction(fd)
        t1 = time.perf_counter()
        logging.info('FF computed in (s): {0}'.format(t1 - t0))
        nib.save(nib.Nifti1Image(ff, bingham_im.affine), args.out_ff)
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_hdf5, args.in_fodf])
    assert_outputs_exist(parser, args, [args.out_hdf5])

    # Remove the output if it already exists; HDF5 will not overwrite the file
    if os.path.isfile(args.out_hdf5):
        os.remove(args.out_hdf5)

    fodf_img = nib.load(args.in_fodf)

    nbr_cpu = validate_nbr_processes(parser, args, args.nbr_processes)
    in_hdf5_file = h5py.File(args.in_hdf5, 'r')
    keys = list(in_hdf5_file.keys())
    in_hdf5_file.close()
    if nbr_cpu == 1:
        results_list = []
        for key in keys:
            results_list.append(_afd_rd_wrapper([args.in_hdf5, key, fodf_img,
                                                 args.sh_basis,
                                                 args.length_weighting]))

    else:
        pool = multiprocessing.Pool(nbr_cpu)
        results_list = pool.map(_afd_rd_wrapper,
                                zip(itertools.repeat(args.in_hdf5),
                                    keys,
                                    itertools.repeat(fodf_img),
                                    itertools.repeat(args.sh_basis),
                                    itertools.repeat(args.length_weighting)))
        pool.close()
        pool.join()

    shutil.copy(args.in_hdf5, args.out_hdf5)
    out_hdf5_file = h5py.File(args.out_hdf5, 'a')
    for key, afd_fixel, rd_fixel in results_list:
        group = out_hdf5_file[key]
        if 'afd_fixel' in group:
            del group['afd_fixel']
        group.create_dataset('afd_fixel', data=afd_fixel)
        if 'rd_fixel' in group:
            del group['rd_fixel']
        group.create_dataset('rd_fixel', data=rd_fixel)
    out_hdf5_file.close()
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    # Checking args
    outputs = [args.out_sh]
    if args.out_sym:
        outputs.append(args.out_sym)
    assert_outputs_exist(parser, args, outputs)
    assert_inputs_exist(parser, args.in_sh)

    nbr_processes = validate_nbr_processes(parser, args)

    # Prepare data
    sh_img = nib.load(args.in_sh)
    data = sh_img.get_fdata(dtype=np.float32)

    sh_order, full_basis = get_sh_order_and_fullness(data.shape[-1])

    t0 = time.perf_counter()
    logging.info('Executing angle-aware bilateral filtering.')
    asym_sh = angle_aware_bilateral_filtering(
        data, sh_order=sh_order,
        sh_basis=args.sh_basis,
        in_full_basis=full_basis,
        sphere_str=args.sphere,
        sigma_spatial=args.sigma_spatial,
        sigma_angular=args.sigma_angular,
        sigma_range=args.sigma_range,
        use_gpu=args.use_gpu,
        nbr_processes=nbr_processes)
    t1 = time.perf_counter()
    logging.info('Elapsed time (s): {0}'.format(t1 - t0))

    logging.info('Saving filtered SH to file {0}.'.format(args.out_sh))
    nib.save(nib.Nifti1Image(asym_sh, sh_img.affine), args.out_sh)

    if args.out_sym:
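        # The filtered SH volume uses a full (asymmetric) basis; keeping only
        # the even-order coefficients yields a symmetric SH volume.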
        _, orders = sph_harm_ind_list(sh_order, full_basis=True)
        logging.info('Saving symmetric SH to file {0}.'.format(args.out_sym))
        nib.save(nib.Nifti1Image(asym_sh[..., orders % 2 == 0], sh_img.affine),
                 args.out_sym)
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_hdf5, args.in_labels],
                        args.force_labels_list)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)
    coloredlogs.install(level=log_level)

    measures_to_compute = []
    measures_output_filename = []
    if args.volume:
        measures_to_compute.append('volume')
        measures_output_filename.append(args.volume)
    if args.streamline_count:
        measures_to_compute.append('streamline_count')
        measures_output_filename.append(args.streamline_count)
    if args.length:
        measures_to_compute.append('length')
        measures_output_filename.append(args.length)
    if args.similarity:
        measures_to_compute.append('similarity')
        measures_output_filename.append(args.similarity[1])

    dict_maps_out_name = {}
    if args.maps is not None:
        for in_folder, out_name in args.maps:
            measures_to_compute.append(in_folder)
            dict_maps_out_name[in_folder] = out_name
            measures_output_filename.append(out_name)

    dict_metrics_out_name = {}
    if args.metrics is not None:
        for in_name, out_name in args.metrics:
            # Verify that all metrics are compatible with each other
            if not is_header_compatible(args.metrics[0][0], in_name):
                raise IOError('Metrics {} and {} do not share a compatible '
                              'header'.format(args.metrics[0][0], in_name))

            # This is necessary to support more than one map for weighting
            measures_to_compute.append((in_name, nib.load(in_name)))
            dict_metrics_out_name[in_name] = out_name
            measures_output_filename.append(out_name)

    dict_lesion_out_name = {}
    if args.lesion_load is not None:
        in_name = args.lesion_load[0]
        lesion_img = nib.load(in_name)
        lesion_data = get_data_as_mask(lesion_img, dtype=bool)
        lesion_atlas, _ = ndi.label(lesion_data)
        measures_to_compute.append(((in_name, np.unique(lesion_atlas)[1:]),
                                    nib.Nifti1Image(lesion_atlas,
                                                    lesion_img.affine)))

        out_name_1 = os.path.join(args.lesion_load[1], 'lesion_vol.npy')
        out_name_2 = os.path.join(args.lesion_load[1], 'lesion_count.npy')
        out_name_3 = os.path.join(args.lesion_load[1], 'lesion_sc.npy')

        dict_lesion_out_name[in_name+'vol'] = out_name_1
        dict_lesion_out_name[in_name+'count'] = out_name_2
        dict_lesion_out_name[in_name+'sc'] = out_name_3
        measures_output_filename.extend([out_name_1, out_name_2, out_name_3])

    assert_outputs_exist(parser, args, measures_output_filename)
    if not measures_to_compute:
        parser.error('No connectivity measures were selected, nothing '
                     'to compute.')

    logging.info('The following measures will be computed and saved: {}'.format(
        measures_output_filename))

    if args.include_dps:
        if not os.path.isdir(args.include_dps):
            os.makedirs(args.include_dps)
        logging.info('data_per_streamline weighting is activated.')

    img_labels = nib.load(args.in_labels)
    data_labels = get_data_as_label(img_labels)
    if not args.force_labels_list:
        labels_list = np.unique(data_labels)[1:].tolist()
    else:
        labels_list = np.loadtxt(
            args.force_labels_list, dtype=np.int16).tolist()

    comb_list = list(itertools.combinations(labels_list, r=2))
    if not args.no_self_connection:
        comb_list.extend(zip(labels_list, labels_list))

    nbr_cpu = validate_nbr_processes(parser, args)
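    # One job per node (label pair): each worker returns a small dict mapping
    # the pair to its computed measures.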
    measures_dict_list = []
    if nbr_cpu == 1:
        for comb in comb_list:
            measures_dict_list.append(_processing_wrapper([args.in_hdf5,
                                                           img_labels, comb,
                                                           measures_to_compute,
                                                           args.similarity,
                                                           args.density_weighting,
                                                           args.include_dps,
                                                           args.min_lesion_vol]))
    else:
        pool = multiprocessing.Pool(nbr_cpu)
        measures_dict_list = pool.map(
            _processing_wrapper,
            zip(itertools.repeat(args.in_hdf5),
                itertools.repeat(img_labels),
                comb_list,
                itertools.repeat(measures_to_compute),
                itertools.repeat(args.similarity),
                itertools.repeat(args.density_weighting),
                itertools.repeat(args.include_dps),
                itertools.repeat(args.min_lesion_vol)))
        pool.close()
        pool.join()

    # Removing None entries (combinations that do not exist)
    # Fusing the multiprocessing output into a single dictionary
    measures_dict_list = [it for it in measures_dict_list if it is not None]
    if not measures_dict_list:
        raise ValueError('Empty matrix, no entries to save.')
    measures_dict = measures_dict_list[0]
    for dix in measures_dict_list[1:]:
        measures_dict.update(dix)

    if args.no_self_connection:
        total_elem = len(labels_list)**2 - len(labels_list)
        results_elem = len(measures_dict.keys())*2 - len(labels_list)
    else:
        total_elem = len(labels_list)**2
        results_elem = len(measures_dict.keys())*2

    logging.info('Out of {} possible nodes, {} contain values'.format(
        total_elem, results_elem))

    # Filling out all the matrices (symmetric) in the order of labels_list
    nbr_of_measures = len(list(measures_dict.values())[0])
    matrix = np.zeros((len(labels_list), len(labels_list), nbr_of_measures))

    for in_label, out_label in measures_dict:
        curr_node_dict = measures_dict[(in_label, out_label)]
        measures_ordering = list(curr_node_dict.keys())
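        # All node dicts share the same measure ordering; the last one read is
        # reused below to name the output matrices.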

        for i, measure in enumerate(curr_node_dict):
            in_pos = labels_list.index(in_label)
            out_pos = labels_list.index(out_label)
            matrix[in_pos, out_pos, i] = curr_node_dict[measure]
            matrix[out_pos, in_pos, i] = curr_node_dict[measure]

    # Saving the matrices separately with the specified name or dps
    for i, measure in enumerate(measures_ordering):
        if measure == 'volume':
            matrix_basename = args.volume
        elif measure == 'streamline_count':
            matrix_basename = args.streamline_count
        elif measure == 'length':
            matrix_basename = args.length
        elif measure == 'similarity':
            matrix_basename = args.similarity[1]
        elif measure in dict_metrics_out_name:
            matrix_basename = dict_metrics_out_name[measure]
        elif measure in dict_maps_out_name:
            matrix_basename = dict_maps_out_name[measure]
        elif measure in dict_lesion_out_name:
            matrix_basename = dict_lesion_out_name[measure]
        else:
            matrix_basename = measure

        np.save(matrix_basename, matrix[:, :, i])
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    assert_inputs_exist(parser, args.in_bundles)
    assert_outputs_exist(parser, args, args.out_json)

    if (not args.streamlines_measures) and (not args.voxels_measures):
        parser.error('At least one of the two modes is needed')

    nbr_cpu = validate_nbr_processes(parser, args)

    all_binary_metrics = []
    bundles_references_tuple_extended = link_bundles_and_reference(
        parser, args, args.in_bundles)

    if args.streamlines_measures:
        # Gold standard related indices are computed once
        wb_sft = load_tractogram_with_reference(parser, args,
                                                args.streamlines_measures[1])
        wb_sft.to_vox()
        wb_sft.to_corner()
        wb_streamlines = wb_sft.streamlines

        gs_sft = load_tractogram_with_reference(parser, args,
                                                args.streamlines_measures[0])
        gs_sft.to_vox()
        gs_sft.to_corner()
        gs_streamlines = gs_sft.streamlines
        _, gs_dimensions, _, _ = gs_sft.space_attributes

        # Prepare the gold standard only once
        _, gs_streamlines_indices = perform_streamlines_operation(
            intersection, [wb_streamlines, gs_streamlines], precision=0)

        if nbr_cpu == 1:
            streamlines_dict = []
            for i in bundles_references_tuple_extended:
                streamlines_dict.append(
                    compute_streamlines_measures(
                        [i, wb_streamlines, gs_streamlines_indices]))
        else:
            pool = multiprocessing.Pool(nbr_cpu)
            streamlines_dict = pool.map(
                compute_streamlines_measures,
                zip(bundles_references_tuple_extended,
                    itertools.repeat(wb_streamlines),
                    itertools.repeat(gs_streamlines_indices)))
            pool.close()
            pool.join()
        all_binary_metrics.extend(streamlines_dict)

    if not args.voxels_measures:
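        # No voxel masks were given: derive the gold-standard and tracking
        # masks from binarized streamline density maps.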
        gs_binary_3d = compute_tract_counts_map(gs_streamlines, gs_dimensions)
        gs_binary_3d[gs_binary_3d > 0] = 1

        tracking_mask_data = compute_tract_counts_map(wb_streamlines,
                                                      gs_dimensions)
        tracking_mask_data[tracking_mask_data > 0] = 1
    else:
        gs_binary_3d = get_data_as_mask(nib.load(args.voxels_measures[0]))
        gs_binary_3d[gs_binary_3d > 0] = 1
        tracking_mask_data = get_data_as_mask(nib.load(
            args.voxels_measures[1]))
        tracking_mask_data[tracking_mask_data > 0] = 1

    if nbr_cpu == 1:
        voxels_dict = []
        for i in bundles_references_tuple_extended:
            voxels_dict.append(
                compute_voxel_measures([i, tracking_mask_data, gs_binary_3d]))
    else:
        # A fresh pool is needed here: the streamlines pool (if any) was
        # already closed above.
        pool = multiprocessing.Pool(nbr_cpu)
        voxels_dict = pool.map(
            compute_voxel_measures,
            zip(bundles_references_tuple_extended,
                itertools.repeat(tracking_mask_data),
                itertools.repeat(gs_binary_3d)))
        pool.close()
        pool.join()
    all_binary_metrics.extend(voxels_dict)

    # After all processing, write the json file and skip None value
    output_binary_dict = {}
    for binary_dict in all_binary_metrics:
        if binary_dict is not None:
            for measure_name in binary_dict.keys():
                if measure_name not in output_binary_dict:
                    output_binary_dict[measure_name] = []
                output_binary_dict[measure_name].append(
                    float(binary_dict[measure_name]))

    with open(args.out_json, 'w') as outfile:
        json.dump(output_binary_dict,
                  outfile,
                  indent=args.indent,
                  sort_keys=args.sort_keys)
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser,
                        args.in_sh,
                        optional=[args.in_bvec, args.in_bval, args.in_b0])
    assert_outputs_exist(parser,
                         args,
                         args.out_sf,
                         optional=[args.out_bvec, args.out_bval])

    if (args.in_bval and not args.out_bval) or (args.out_bval
                                                and not args.in_bval):
        parser.error("--out_bval is required if --in_bval is provided, "
                     "and vice-versa.")

    if args.in_bvec and not args.in_bval:
        parser.error(
            "--in_bval is required when using --in_bvec, in order to remove "
            "bvecs corresponding to b0 images.")

    if args.b0_scaling and not args.in_b0:
        parser.error("--in_b0 is required when using --b0_scaling.")

    nbr_processes = validate_nbr_processes(parser, args)

    # Load SH
    vol_sh = nib.load(args.in_sh)
    data_sh = vol_sh.get_fdata(dtype=np.float32)

    # Sample SF from SH
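    # The sampling directions come either from a named sphere (--sphere) or
    # from the input b-vectors (--in_bvec); one of the two is expected,
    # otherwise `sphere` would be undefined below.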
    if args.sphere:
        sphere = get_sphere(args.sphere)
    elif args.in_bvec:
        bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec)
        gtab = gradient_table(bvals, bvecs, b0_threshold=bvals.min())
        # Remove bvecs corresponding to b0 images
        bvecs = bvecs[np.logical_not(gtab.b0s_mask)]
        sphere = Sphere(xyz=bvecs)

    sf = convert_sh_to_sf(data_sh,
                          sphere,
                          input_basis=args.sh_basis,
                          input_full_basis=args.full_basis,
                          dtype=args.dtype,
                          nbr_processes=nbr_processes)
    new_bvecs = sphere.vertices.astype(np.float32)

    # Assign bval to SF if --in_bval was provided
    new_bvals = []
    if args.in_bval:
        # Load bvals
        bvals, _ = read_bvals_bvecs(args.in_bval, None)

        # Compute average bval
        b0_thr = check_b0_threshold(args.force_b0_threshold, bvals.min(),
                                    bvals.min())
        b0s_mask = bvals <= b0_thr
        avg_bval = np.mean(bvals[np.logical_not(b0s_mask)])

        new_bvals = ([avg_bval] * len(sphere.theta))

    # Add b0 images to SF (and bvals if necessary) if --in_b0 was provided
    if args.in_b0:
        # Load b0
        vol_b0 = nib.load(args.in_b0)
        data_b0 = vol_b0.get_fdata(dtype=args.dtype)
        if data_b0.ndim == 3:
            data_b0 = data_b0[..., np.newaxis]

        new_bvals = ([0] * data_b0.shape[-1]) + new_bvals

        # Append zeros to bvecs
        new_bvecs = np.concatenate((np.zeros(
            (data_b0.shape[-1], 3)), new_bvecs),
                                   axis=0)

        # Scale SF by b0
        if args.b0_scaling:
            # Clip SF signal between 0. and 1., then scale using mean b0
            sf = np.clip(sf, 0., 1.)
            scale_b0 = np.mean(data_b0, axis=-1, keepdims=True)
            sf = sf * scale_b0

        # Append b0 images to SF
        sf = np.concatenate((data_b0, sf), axis=-1)

    # Save new bvals
    if args.out_bval:
        np.savetxt(args.out_bval, np.array(new_bvals)[None, :], fmt='%.3f')

    # Save new bvecs
    if args.out_bvec:
        np.savetxt(args.out_bvec, new_bvecs.T, fmt='%.8f')

    # Save SF
    nib.save(nib.Nifti1Image(sf, vol_sh.affine), args.out_sf)
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_bundle)
    assert_outputs_exist(parser, args, args.out_bundle)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.DEBUG
    logging.basicConfig(level=log_level)

    sft = load_tractogram_with_reference(parser, args, args.in_bundle)
    streamlines = list(sft.streamlines)
    original_length = len(streamlines)
    logging.debug('Loaded {} streamlines...'.format(original_length))

    nbr_cpu = validate_nbr_processes(parser, args, args.nbr_processes)
    timer = time()

    logging.debug('Launching subsampling on {} processes.'.format(nbr_cpu))
    last_iteration = False
    iter_count = 0
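    # Subsample in rounds: split the streamlines across workers, subsample
    # each chunk, fuse the results and repeat until fewer than --convergence
    # streamlines are removed, then do one last single-thread pass.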
    while True:
        if len(streamlines) < 1000:
            logging.warning('Subsampling less than 1000 streamlines is risky.')
            break
        current_iteration_length = len(streamlines)
        skip = int(len(streamlines) / args.nbr_processes) + 1

        # Cheap trick to avoid duplication in memory, the pop removes from
        # one list to append it to the other, slower but allows bigger bundles
        split_streamlines_list = []
        for _ in range(args.nbr_processes):
            split_streamlines_list.append(streamlines[0:skip])
            del streamlines[0:skip]

        if nbr_cpu == 1 or args.nbr_processes == 1:
            resulting_streamlines = []
            for split in split_streamlines_list:
                resulting_streamlines.append(
                    multiprocess_subsampling([
                        split, args.min_distance, args.clustering_thr,
                        args.min_cluster_size
                    ]))
        else:
            # A pool cannot be reused once closed, so create a fresh one for
            # each iteration of the loop.
            pool = multiprocessing.Pool(nbr_cpu)
            resulting_streamlines = pool.map(
                multiprocess_subsampling,
                zip(split_streamlines_list, repeat(args.min_distance),
                    repeat(args.clustering_thr),
                    repeat(args.min_cluster_size)))
            pool.close()
            pool.join()

        # Fuse all subprocesses' results together
        streamlines = list(chain(*resulting_streamlines))
        difference_length = current_iteration_length - len(streamlines)
        logging.debug('Difference (before - after): {} '
                      'streamlines were removed'.format(difference_length))

        if last_iteration and difference_length < args.convergence:
            logging.debug('Before ({}) -> After ({}), '
                          'total runtime of {} sec.'.format(
                              original_length, len(streamlines),
                              round(time() - timer, 3)))
            break
        elif difference_length < args.convergence or iter_count >= 1000:
            logging.debug('The smart-subsampling converged, below {} '
                          'different streamlines. Adding a single-thread '
                          'iteration.'.format(args.convergence))
            args.nbr_processes = 1
            last_iteration = True
        else:
            logging.debug('Threshold of convergence was not achieved.'
                          ' Need another run...\n')
            iter_count += 1
            args.min_cluster_size = 1

            # Once the streamlines reached a low enough amount, switch to
            # single thread for full comparison
            if len(streamlines) < 10000:
                args.nbr_processes = 1
            random.shuffle(streamlines)

    # After convergence, we can simply save the output
    new_sft = StatefulTractogram.from_sft(streamlines, sft)
    save_tractogram(new_sft, args.out_bundle)
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser,
                        args.in_sh,
                        optional=[args.in_bvec, args.in_bval, args.in_b0])
    assert_outputs_exist(parser,
                         args,
                         args.out_sf,
                         optional=[args.out_bvec, args.out_bval])

    if (args.in_bval and not args.out_bval) or (args.out_bval
                                                and not args.in_bval):
        parser.error("--out_bval is required if --in_bval is provided, "
                     "and vice-versa.")

    nbr_processes = validate_nbr_processes(parser, args)

    # Load SH
    vol_sh = nib.load(args.in_sh)
    data_sh = vol_sh.get_fdata(dtype=np.float32)

    # Sample SF from SH
    if args.sphere:
        sphere = get_sphere(args.sphere)
    elif args.in_bvec:
        _, bvecs = read_bvals_bvecs(None, args.in_bvec)
        sphere = Sphere(xyz=bvecs)

    sf = convert_sh_to_sf(data_sh,
                          sphere,
                          input_basis=args.sh_basis,
                          input_full_basis=args.full_basis,
                          dtype=args.dtype,
                          nbr_processes=nbr_processes)
    new_bvecs = sphere.vertices.astype(np.float32)

    # Assign bval to SF if --in_bval was provided
    new_bvals = []
    if args.in_bval:
        # Load bvals
        bvals, _ = read_bvals_bvecs(args.in_bval, None)

        # Compute average bval
        check_b0_threshold(args.force_b0_threshold, bvals.min())
        b0s_mask = bvals <= bvals.min()
        avg_bval = np.mean(bvals[np.logical_not(b0s_mask)])

        new_bvals = ([avg_bval] * len(sphere.theta))

    # Add b0 images to SF (and bvals if necessary) if --in_b0 was provided
    if args.in_b0:
        # Load b0
        vol_b0 = nib.load(args.in_b0)
        data_b0 = vol_b0.get_fdata(dtype=args.dtype)
        if data_b0.ndim == 3:
            data_b0 = data_b0[..., np.newaxis]

        new_bvals = ([0] * data_b0.shape[-1]) + new_bvals

        # Append zeros to bvecs
        new_bvecs = np.concatenate((np.zeros(
            (data_b0.shape[-1], 3)), new_bvecs),
                                   axis=0)

        # Append b0 images to SF
        sf = np.concatenate((data_b0, sf), axis=-1)

    # Save new bvals
    if args.out_bval:
        np.savetxt(args.out_bval, np.array(new_bvals)[None, :], fmt='%.3f')

    # Save new bvecs
    if args.out_bvec:
        np.savetxt(args.out_bvec, new_bvecs.T, fmt='%.8f')

    # Save SF
    nib.save(nib.Nifti1Image(sf, vol_sh.affine), args.out_sf)
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_bundles)
    assert_outputs_exist(parser, args, [args.out_json])

    nbr_cpu = validate_nbr_processes(parser, args)

    if not os.path.isdir('tmp_measures/'):
        os.mkdir('tmp_measures/')

    if args.single_compare:
        # Put the single_compare bundle at the end so it is loaded only once.
        if args.single_compare in args.in_bundles:
            args.in_bundles.remove(args.single_compare)
        bundles_list = args.in_bundles + [args.single_compare]
        bundles_references_tuple_extended = link_bundles_and_reference(
            parser, args, bundles_list)

        single_compare_reference_tuple = bundles_references_tuple_extended.pop()
        comb_dict_keys = list(itertools.product(bundles_references_tuple_extended,
                                                [single_compare_reference_tuple]))
    else:
        bundles_list = args.in_bundles
        # Pre-compute the needed temporary files, to avoid conflicts when the
        # number of CPUs is higher than the number of bundles
        bundles_references_tuple = link_bundles_and_reference(parser,
                                                              args,
                                                              bundles_list)

        # This approach is only so pytest can run
        if nbr_cpu == 1:
            for i in range(len(bundles_references_tuple)):
                load_data_tmp_saving([bundles_references_tuple[i][0],
                                      bundles_references_tuple[i][1],
                                      True, args.disable_streamline_distance])
        else:
            pool = multiprocessing.Pool(nbr_cpu)
            pool.map(load_data_tmp_saving,
                     zip([tup[0] for tup in bundles_references_tuple],
                         [tup[1] for tup in bundles_references_tuple],
                         itertools.repeat(True),
                         itertools.repeat(args.disable_streamline_distance)))
            pool.close()
            pool.join()

        comb_dict_keys = list(itertools.combinations(
            bundles_references_tuple, r=2))

    if nbr_cpu == 1:
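        # compute_all_measures returns a dict of pairwise measures for one
        # combination of bundles (None when a bundle is empty).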
        all_measures_dict = []
        for i in comb_dict_keys:
            all_measures_dict.append(compute_all_measures([
                i, args.streamline_dice, args.disable_streamline_distance]))
    else:
        # The pre-computation pool (if any) is closed by now; create a fresh
        # one for the pairwise comparisons.
        pool = multiprocessing.Pool(nbr_cpu)
        all_measures_dict = pool.map(
            compute_all_measures,
            zip(comb_dict_keys,
                itertools.repeat(args.streamline_dice),
                itertools.repeat(args.disable_streamline_distance)))
        pool.close()
        pool.join()

    output_measures_dict = {}
    for measure_dict in all_measures_dict:
        # Empty bundle should not make the script crash
        if measure_dict is not None:
            for measure_name in measure_dict.keys():
                # Create an empty list first
                if measure_name not in output_measures_dict:
                    output_measures_dict[measure_name] = []
                output_measures_dict[measure_name].append(
                    float(measure_dict[measure_name]))

    with open(args.out_json, 'w') as outfile:
        json.dump(output_measures_dict, outfile,
                  indent=args.indent, sort_keys=args.sort_keys)

    if not args.keep_tmp:
        shutil.rmtree('tmp_measures/')
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    if not args.not_all:
        args.gfa = args.gfa or 'gfa.nii.gz'
        args.peaks = args.peaks or 'peaks.nii.gz'
        args.peak_indices = args.peak_indices or 'peaks_indices.nii.gz'
        args.sh = args.sh or 'sh.nii.gz'
        args.nufo = args.nufo or 'nufo.nii.gz'
        args.a_power = args.a_power or 'anisotropic_power.nii.gz'

    arglist = [
        args.gfa, args.peaks, args.peak_indices, args.sh, args.nufo,
        args.a_power
    ]
    if args.not_all and not any(arglist):
        parser.error('When using --not_all, you need to specify at least ' +
                     'one file to output.')

    assert_inputs_exist(parser, [args.in_dwi, args.in_bval, args.in_bvec])
    assert_outputs_exist(parser, args, arglist)
    nbr_processes = validate_nbr_processes(parser, args)
    parallel = nbr_processes > 1

    # Load data
    img = nib.load(args.in_dwi)
    data = img.get_fdata(dtype=np.float32)

    bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec)

    if not is_normalized_bvecs(bvecs):
        logging.warning('Your b-vectors do not seem normalized...')
        bvecs = normalize_bvecs(bvecs)

    check_b0_threshold(args.force_b0_threshold, bvals.min())
    gtab = gradient_table(bvals, bvecs, b0_threshold=bvals.min())

    sphere = get_sphere('symmetric724')

    mask = None
    if args.mask:
        mask = get_data_as_mask(nib.load(args.mask))

        # Sanity check on shape of mask
        if mask.shape != data.shape[:-1]:
            raise ValueError('Mask shape does not match data shape.')

    if args.use_qball:
        model = QballModel(gtab, sh_order=args.sh_order, smooth=DEFAULT_SMOOTH)
    else:
        model = CsaOdfModel(gtab,
                            sh_order=args.sh_order,
                            smooth=DEFAULT_SMOOTH)
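    # peaks_from_model extracts up to 5 peaks per voxel and also returns the
    # SH coefficients and GFA used for the optional outputs below.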

    odfpeaks = peaks_from_model(model=model,
                                data=data,
                                sphere=sphere,
                                relative_peak_threshold=.5,
                                min_separation_angle=25,
                                mask=mask,
                                return_odf=False,
                                normalize_peaks=True,
                                return_sh=True,
                                sh_order=int(args.sh_order),
                                sh_basis_type=args.sh_basis,
                                npeaks=5,
                                parallel=parallel,
                                nbr_processes=nbr_processes)

    if args.gfa:
        nib.save(nib.Nifti1Image(odfpeaks.gfa.astype(np.float32), img.affine),
                 args.gfa)

    if args.peaks:
        nib.save(
            nib.Nifti1Image(reshape_peaks_for_visualization(odfpeaks),
                            img.affine), args.peaks)

    if args.peak_indices:
        nib.save(nib.Nifti1Image(odfpeaks.peak_indices, img.affine),
                 args.peak_indices)

    if args.sh:
        nib.save(
            nib.Nifti1Image(odfpeaks.shm_coeff.astype(np.float32), img.affine),
            args.sh)

    if args.nufo:
        peaks_count = (odfpeaks.peak_indices > -1).sum(3)
        nib.save(nib.Nifti1Image(peaks_count.astype(np.int32), img.affine),
                 args.nufo)

    if args.a_power:
        odf_a_power = anisotropic_power(odfpeaks.shm_coeff)
        nib.save(nib.Nifti1Image(odf_a_power.astype(np.float32), img.affine),
                 args.a_power)
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()
    assert_inputs_exist(parser, args.in_bundles)
    assert_outputs_exist(parser, args, args.out_json)

    nbr_cpu = validate_nbr_processes(parser, args)
    bundles_references_tuple_extended = link_bundles_and_reference(
        parser, args, args.in_bundles)

    if nbr_cpu == 1:
        all_measures_dict = []
        for i in bundles_references_tuple_extended:
            all_measures_dict.append(compute_measures(i))
    else:
        pool = multiprocessing.Pool(nbr_cpu)
        all_measures_dict = pool.map(compute_measures,
                                     bundles_references_tuple_extended)
        pool.close()
        pool.join()

    output_measures_dict = {}
    for measure_dict in all_measures_dict:
        # Empty bundle should not make the script crash
        if measure_dict is not None:
            for measure_name in measure_dict.keys():
                # Create an empty list first
                if measure_name not in output_measures_dict:
                    output_measures_dict[measure_name] = []
                output_measures_dict[measure_name].append(
                    measure_dict[measure_name])
    # Add group statistics if requested
    if args.group_statistics:
        # length and span are weighted by streamline count
        group_total_length = np.sum(
            np.multiply(output_measures_dict['avg_length'],
                        output_measures_dict['streamlines_count']))
        group_total_span = np.sum(
            np.multiply(output_measures_dict['span'],
                        output_measures_dict['streamlines_count']))
        group_streamlines_count = \
            np.sum(output_measures_dict['streamlines_count'])
        group_avg_length = group_total_length / group_streamlines_count
        group_avg_span = group_total_span / group_streamlines_count
        group_avg_vol = np.average(output_measures_dict['volume'])
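        # Diameter of a cylinder with the group-average volume and length:
        # vol = pi * (d/2)^2 * length  =>  d = 2 * sqrt(vol / (pi * length))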
        group_avg_diam = \
            2 * np.sqrt(group_avg_vol / (np.pi * group_avg_length))
        output_measures_dict['group_stats'] = {}
        output_measures_dict['group_stats']['total_streamlines_count'] = \
            float(group_streamlines_count)
        output_measures_dict['group_stats']['avg_streamline_length'] = \
            group_avg_length
        # max and min length of all streamlines in all input bundles
        output_measures_dict['group_stats']['max_streamline_length'] = \
            float(np.max(output_measures_dict['max_length']))
        output_measures_dict['group_stats']['min_streamline_length'] = \
            float(np.min(output_measures_dict['min_length']))
        output_measures_dict['group_stats']['avg_streamline_span'] = \
            group_avg_span
        # Computed from the other group averages, not weighted by streamline count
        output_measures_dict['group_stats']['avg_volume'] = group_avg_vol
        output_measures_dict['group_stats']['avg_curl'] = \
            group_avg_length / group_avg_span
        output_measures_dict['group_stats']['avg_diameter'] = group_avg_diam
        output_measures_dict['group_stats']['avg_elongation'] = \
            group_avg_length / group_avg_diam
        output_measures_dict['group_stats']['avg_surface_area'] = \
            np.average(output_measures_dict['surface_area'])
        output_measures_dict['group_stats']['avg_irreg'] = \
            np.average(output_measures_dict['irregularity'])
        output_measures_dict['group_stats']['avg_end_surface_area_head'] = \
            np.average(output_measures_dict['end_surface_area_head'])
        output_measures_dict['group_stats']['avg_end_surface_area_tail'] = \
            np.average(output_measures_dict['end_surface_area_tail'])
        output_measures_dict['group_stats']['avg_radius_head'] = \
            np.average(output_measures_dict['radius_head'])
        output_measures_dict['group_stats']['avg_radius_tail'] = \
            np.average(output_measures_dict['radius_tail'])
        output_measures_dict['group_stats']['avg_irregularity_head'] = \
            np.average(
                output_measures_dict['irregularity_of_end_surface_head'])
        output_measures_dict['group_stats']['avg_irregularity_tail'] = \
            np.average(
                output_measures_dict['irregularity_of_end_surface_tail'])
        output_measures_dict['group_stats']['avg_fractal_dimension'] = \
            np.average(output_measures_dict['fractal_dimension'])
    with open(args.out_json, 'w') as outfile:
        json.dump(output_measures_dict,
                  outfile,
                  indent=args.indent,
                  sort_keys=args.sort_keys)