Example #1
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    required = args.in_label
    assert_inputs_exist(parser, required)

    label_img = nib.load(args.in_label)
    label_img_data = get_data_as_label(label_img)

    if args.scilpy_lut:
        with open(os.path.join(get_lut_dir(), args.scilpy_lut + '.json')) as f:
            label_dict = json.load(f)
        (label_indices, label_names) = zip(*label_dict.items())
    else:
        with open(args.custom_lut) as f:
            label_dict = json.load(f)
        (label_indices, label_names) = zip(*label_dict.items())

    output_filenames = []
    for label, name in zip(label_indices, label_names):
        if int(label) != 0:
            if args.out_prefix:
                output_filenames.append(os.path.join(args.out_dir,
                                                     '{0}_{1}.nii.gz'.format(
                                                         args.out_prefix,
                                                         name)))
            else:
                output_filenames.append(os.path.join(args.out_dir,
                                                     '{0}.nii.gz'.format(
                                                        name)))

    assert_outputs_exist(parser, args, output_filenames)

    if args.out_dir and not os.path.isdir(args.out_dir):
        os.mkdir(args.out_dir)

    # Extract the voxels that match the label and save them to a file.
    cnt_filename = 0
    for label in label_indices:
        if int(label) != 0:
            split_label = np.zeros(label_img.shape,
                                   dtype=np.uint16)
            split_label[np.where(label_img_data == int(label))] = label

            split_image = nib.Nifti1Image(split_label,
                                          label_img.affine,
                                          header=label_img.header)
            nib.save(split_image, output_filenames[cnt_filename])
            cnt_filename += 1
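
Stripped of the argument parsing and the scilpy helpers (get_data_as_label, assert_inputs_exist, ...), the core of this example is one masked assignment per label. Below is a minimal, self-contained sketch of that step using synthetic data and hypothetical output file names; it is an illustration, not the script itself.

import numpy as np
import nibabel as nib

# Synthetic label volume standing in for the loaded atlas
label_data = np.zeros((4, 4, 4), dtype=np.uint16)
label_data[0, 0, 0] = 1
label_data[1:3, 1:3, 1:3] = 2
img = nib.Nifti1Image(label_data, np.eye(4))

# One output volume per non-zero label, as in the loop above
for value in np.unique(label_data):
    if value == 0:
        continue
    split_label = np.zeros(label_data.shape, dtype=np.uint16)
    split_label[label_data == value] = value
    nib.save(nib.Nifti1Image(split_label, img.affine, header=img.header),
             'label_{}.nii.gz'.format(value))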
Example #2
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    required = args.in_labels
    assert_inputs_exist(parser, required)

    label_img = nib.load(args.in_labels)
    label_img_data = get_data_as_label(label_img)

    if args.range:
        label_indices = [item for sublist in args.range for item in sublist]
    else:
        label_indices = np.unique(label_img_data)
    label_names = [str(i) for i in label_indices]

    output_filenames = []
    for label, name in zip(label_indices, label_names):
        if int(label) != 0:
            if args.out_prefix:
                output_filenames.append(os.path.join(args.out_dir,
                                                     '{0}_{1}.nii.gz'.format(
                                                         args.out_prefix,
                                                         name)))
            else:
                output_filenames.append(os.path.join(args.out_dir,
                                                     '{0}.nii.gz'.format(
                                                        name)))

    assert_outputs_exist(parser, args, output_filenames)

    if args.out_dir and not os.path.isdir(args.out_dir):
        os.mkdir(args.out_dir)

    # Extract the voxels that match the label and save them to a file.
    cnt_filename = 0
    for label in label_indices:
        if int(label) != 0:
            split_label = np.zeros(label_img.shape,
                                   dtype=np.uint16)
            split_label[np.where(label_img_data == int(label))] = label

            split_image = nib.Nifti1Image(split_label,
                                          label_img.affine,
                                          header=label_img.header)
            nib.save(split_image, output_filenames[cnt_filename])
            cnt_filename += 1
Example #3
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    required = args.in_labels, args.in_seed_maps, args.in_labels_lut
    assert_inputs_exist(parser, required)

    # Load atlas image
    label_img = nib.load(args.in_labels)
    label_img_data = get_data_as_label(label_img)

    # Load atlas lut
    with open(args.in_labels_lut) as f:
        label_dict = json.load(f)
    (label_indices, label_names) = zip(*label_dict.items())

    # Load seed image
    seed_img = nib.load(args.in_seed_maps)
    seed_img_data = seed_img.get_fdata(dtype=np.float32)

    for label, name in zip(label_indices, label_names):
        label = int(label)
        if label != 0:
            curr_data = (seed_img_data[np.where(label_img_data == label)])
            nb_vx_roi = np.count_nonzero(label_img_data == label)
            nb_seed_vx = np.count_nonzero(curr_data)

            if nb_seed_vx != 0:
                mean_seed = np.sum(curr_data) / nb_seed_vx
                max_seed = np.max(curr_data)
                std_seed = np.sqrt(
                    np.mean(abs(curr_data[curr_data != 0] - mean_seed)**2))

                print(
                    json.dumps({
                        'ROI-idx': label,
                        'ROI-name': str(name),
                        'nb-vx-roi': int(nb_vx_roi),
                        'nb-vx-seed': int(nb_seed_vx),
                        'max': int(max_seed),
                        'mean': float(mean_seed),
                        'std': float(std_seed)
                    }))
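
The per-ROI statistics above are plain NumPy reductions over the voxels selected by the label mask; nothing depends on the CLI scaffolding. A small sketch with toy arrays (shapes and values are made up):

import numpy as np

# Toy label map and seed-count map of matching shape
label_img_data = np.zeros((5, 5, 5), dtype=np.uint16)
label_img_data[1:4, 1:4, 1:4] = 2
seed_img_data = np.random.rand(5, 5, 5).astype(np.float32)

label = 2
curr_data = seed_img_data[label_img_data == label]
nb_vx_roi = np.count_nonzero(label_img_data == label)   # voxels in the ROI
nb_seed_vx = np.count_nonzero(curr_data)                # seeded voxels in the ROI

if nb_seed_vx:
    mean_seed = np.sum(curr_data) / nb_seed_vx
    max_seed = np.max(curr_data)
    std_seed = np.sqrt(np.mean((curr_data[curr_data != 0] - mean_seed) ** 2))
    print(nb_vx_roi, nb_seed_vx, float(mean_seed), float(max_seed), float(std_seed))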
Example #4
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.voxel_label_map)

    voxel_label_map_img = nib.load(args.voxel_label_map)
    voxel_label_map_data = get_data_as_label(voxel_label_map_img)
    voxel_size = voxel_label_map_img.header['pixdim'][1:4]

    labels = np.unique(voxel_label_map_data.astype(np.uint8))[1:]
    num_digits_labels = len(str(np.max(labels)))
    voxel_volume = np.prod(voxel_size)
    stats = {args.bundle_name: {'volume': {}}}
    for i in labels:
        key = '{}'.format(i).zfill(num_digits_labels)
        stats[args.bundle_name]['volume'][key] = \
            len(voxel_label_map_data[voxel_label_map_data == i]) * voxel_volume

    print(json.dumps(stats, indent=args.indent, sort_keys=args.sort_keys))
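
The volume of each label is simply its voxel count multiplied by the voxel volume taken from the header. A compact sketch with a synthetic map and an assumed 2 mm isotropic voxel size:

import numpy as np

voxel_size = np.array([2.0, 2.0, 2.0])        # assumed; normally read from the NIfTI header
voxel_volume = np.prod(voxel_size)            # 8 mm^3 per voxel

voxel_label_map_data = np.zeros((10, 10, 10), dtype=np.uint8)
voxel_label_map_data[2:5, 2:5, 2:5] = 1
voxel_label_map_data[6:8, 6:8, 6:8] = 2

labels = np.unique(voxel_label_map_data)[1:]  # skip background (0)
num_digits_labels = len(str(np.max(labels)))
stats = {'bundle': {'volume': {}}}
for i in labels:
    n_voxels = np.count_nonzero(voxel_label_map_data == i)
    stats['bundle']['volume'][str(i).zfill(num_digits_labels)] = n_voxels * voxel_volume
print(stats)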
Example #5
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_labels)
    assert_outputs_exist(parser, args, args.out_labels)

    # Load volume
    label_img = nib.load(args.in_labels)
    labels_volume = get_data_as_label(label_img)

    # Remove given labels from the volume
    for index in np.unique(args.indices):
        mask = labels_volume == index
        labels_volume[mask] = args.background
        if np.count_nonzero(mask) == 0:
            logging.warning("Label {} was not in the volume".format(index))

    # Save final volume
    nii = nib.Nifti1Image(labels_volume, label_img.affine, label_img.header)
    nib.save(nii, args.out_labels)
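
Removing labels reduces to a masked assignment per index; since the mask is precomputed, the "label not in volume" warning can be checked either before or after overwriting. A minimal sketch with synthetic data, hypothetical indices and a background value of 0:

import numpy as np

labels_volume = np.array([[0, 1, 2],
                          [2, 3, 1],
                          [0, 0, 3]], dtype=np.uint16)
indices_to_remove = [2, 5]      # label 5 is intentionally absent
background = 0

for index in np.unique(indices_to_remove):
    mask = labels_volume == index
    if not np.count_nonzero(mask):
        print('Label {} was not in the volume'.format(index))
    labels_volume[mask] = background

print(labels_volume)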
Example #6
def _processing_wrapper(args):
    hdf5_filename = args[0]
    labels_img = args[1]
    in_label, out_label = args[2]
    measures_to_compute = copy.copy(args[3])
    if args[4] is not None:
        similarity_directory = args[4][0]
    weighted = args[5]
    include_dps = args[6]
    min_lesion_vol = args[7]

    hdf5_file = h5py.File(hdf5_filename, 'r')
    key = '{}_{}'.format(in_label, out_label)
    if key not in hdf5_file:
        return
    streamlines = reconstruct_streamlines_from_hdf5(hdf5_file, key)
    if len(streamlines) == 0:
        return

    affine, dimensions, voxel_sizes, _ = get_reference_info(labels_img)
    measures_to_return = {}

    if not (np.allclose(hdf5_file.attrs['affine'], affine, atol=1e-03)
            and np.array_equal(hdf5_file.attrs['dimensions'], dimensions)):
        raise ValueError('Provided hdf5 has an incompatible header.')

    # Precompute to save one transformation, insert later
    if 'length' in measures_to_compute:
        streamlines_copy = list(streamlines)
        # scil_decompose_connectivity.py requires isotropic voxels
        mean_length = np.average(length(streamlines_copy))*voxel_sizes[0]

    # If density is not required, do not compute it
    # Only required for volume, similarity and any metrics
    if not ((len(measures_to_compute) == 1 and
             ('length' in measures_to_compute or
              'streamline_count' in measures_to_compute)) or
            (len(measures_to_compute) == 2 and
             ('length' in measures_to_compute and
              'streamline_count' in measures_to_compute))):

        density = compute_tract_counts_map(streamlines,
                                           dimensions)

    if 'volume' in measures_to_compute:
        measures_to_return['volume'] = np.count_nonzero(density) * \
            np.prod(voxel_sizes)
        measures_to_compute.remove('volume')
    if 'streamline_count' in measures_to_compute:
        measures_to_return['streamline_count'] = len(streamlines)
        measures_to_compute.remove('streamline_count')
    if 'length' in measures_to_compute:
        measures_to_return['length'] = mean_length
        measures_to_compute.remove('length')
    if 'similarity' in measures_to_compute and similarity_directory:
        density_sim = load_node_nifti(similarity_directory,
                                      in_label, out_label,
                                      labels_img)
        if density_sim is None:
            ba_vox = 0
        else:
            ba_vox = compute_bundle_adjacency_voxel(density, density_sim)

        measures_to_return['similarity'] = ba_vox
        measures_to_compute.remove('similarity')

    for measure in measures_to_compute:
        # Maps
        if isinstance(measure, str) and os.path.isdir(measure):
            map_dirname = measure
            map_data = load_node_nifti(map_dirname,
                                       in_label, out_label,
                                       labels_img)
            measures_to_return[map_dirname] = np.average(
                map_data[map_data > 0])
        elif isinstance(measure, tuple):
            if not isinstance(measure[0], tuple) \
                    and os.path.isfile(measure[0]):
                metric_filename = measure[0]
                metric_img = measure[1]
                if not is_header_compatible(metric_img, labels_img):
                    logging.error('{} does not have a compatible header'
                                  .format(metric_filename))
                    raise IOError

                metric_data = metric_img.get_fdata(dtype=np.float64)
                if weighted:
                    avg_value = np.average(metric_data, weights=density)
                else:
                    avg_value = np.average(metric_data[density > 0])
                measures_to_return[metric_filename] = avg_value
            # lesion
            else:
                lesion_filename = measure[0][0]
                computed_lesion_labels = measure[0][1]
                lesion_img = measure[1]
                if not is_header_compatible(lesion_img, labels_img):
                    logging.error('{} does not have a compatible header'
                                  .format(lesion_filename))
                    raise IOError

                voxel_sizes = lesion_img.header.get_zooms()[0:3]
                lesion_img.set_filename('tmp.nii.gz')
                lesion_atlas = get_data_as_label(lesion_img)
                tmp_dict = compute_lesion_stats(
                    density.astype(bool), lesion_atlas,
                    voxel_sizes=voxel_sizes, single_label=True,
                    min_lesion_vol=min_lesion_vol,
                    precomputed_lesion_labels=computed_lesion_labels)

                tmp_ind = _streamlines_in_mask(list(streamlines),
                                               lesion_atlas.astype(np.uint8),
                                               np.eye(3), [0, 0, 0])
                streamlines_count = len(np.where(tmp_ind == 1)[0])

                if tmp_dict:
                    measures_to_return[lesion_filename+'vol'] = \
                        tmp_dict['lesion_total_volume']
                    measures_to_return[lesion_filename+'count'] = \
                        tmp_dict['lesion_count']
                    measures_to_return[lesion_filename+'sc'] = \
                        streamlines_count
                else:
                    measures_to_return[lesion_filename+'vol'] = 0
                    measures_to_return[lesion_filename+'count'] = 0
                    measures_to_return[lesion_filename+'sc'] = 0

    if include_dps:
        for dps_key in hdf5_file[key].keys():
            if dps_key not in ['data', 'offsets', 'lengths']:
                out_file = os.path.join(include_dps, dps_key)
                if 'commit' in dps_key:
                    measures_to_return[out_file] = np.sum(
                        hdf5_file[key][dps_key])
                else:
                    measures_to_return[out_file] = np.average(
                        hdf5_file[key][dps_key])

    return {(in_label, out_label): measures_to_return}
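
For the metric branch above, the measure per connection is an average of the metric map over the voxels the connection covers, optionally weighted by the streamline density. A toy sketch of only that step (the arrays are made up):

import numpy as np

# Hypothetical streamline density map and metric map on the same grid
density = np.zeros((6, 6, 6), dtype=np.int32)
density[2:4, 2:4, 2:4] = 3
metric_data = np.random.rand(6, 6, 6)

density_weighted = True
if density_weighted:
    # Voxels crossed by more streamlines contribute more
    avg_value = np.average(metric_data, weights=density)
else:
    # Plain mean over the voxels the connection touches
    avg_value = np.average(metric_data[density > 0])
print(avg_value)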
Example #7
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_hdf5, args.in_labels],
                        args.force_labels_list)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)
    coloredlogs.install(level=log_level)

    measures_to_compute = []
    measures_output_filename = []
    if args.volume:
        measures_to_compute.append('volume')
        measures_output_filename.append(args.volume)
    if args.streamline_count:
        measures_to_compute.append('streamline_count')
        measures_output_filename.append(args.streamline_count)
    if args.length:
        measures_to_compute.append('length')
        measures_output_filename.append(args.length)
    if args.similarity:
        measures_to_compute.append('similarity')
        measures_output_filename.append(args.similarity[1])

    dict_maps_out_name = {}
    if args.maps is not None:
        for in_folder, out_name in args.maps:
            measures_to_compute.append(in_folder)
            dict_maps_out_name[in_folder] = out_name
            measures_output_filename.append(out_name)

    dict_metrics_out_name = {}
    if args.metrics is not None:
        for in_name, out_name in args.metrics:
            # Verify that all metrics are compatible with each other
            if not is_header_compatible(args.metrics[0][0], in_name):
                raise IOError('Metrics {} and {} do not share a compatible '
                              'header'.format(args.metrics[0][0], in_name))

            # This is necessary to support more than one map for weighting
            measures_to_compute.append((in_name, nib.load(in_name)))
            dict_metrics_out_name[in_name] = out_name
            measures_output_filename.append(out_name)

    dict_lesion_out_name = {}
    if args.lesion_load is not None:
        in_name = args.lesion_load[0]
        lesion_img = nib.load(in_name)
        lesion_data = get_data_as_mask(lesion_img, dtype=bool)
        lesion_atlas, _ = ndi.label(lesion_data)
        measures_to_compute.append(((in_name, np.unique(lesion_atlas)[1:]),
                                    nib.Nifti1Image(lesion_atlas,
                                                    lesion_img.affine)))

        out_name_1 = os.path.join(args.lesion_load[1], 'lesion_vol.npy')
        out_name_2 = os.path.join(args.lesion_load[1], 'lesion_count.npy')
        out_name_3 = os.path.join(args.lesion_load[1], 'lesion_sc.npy')

        dict_lesion_out_name[in_name+'vol'] = out_name_1
        dict_lesion_out_name[in_name+'count'] = out_name_2
        dict_lesion_out_name[in_name+'sc'] = out_name_3
        measures_output_filename.extend([out_name_1, out_name_2, out_name_3])

    assert_outputs_exist(parser, args, measures_output_filename)
    if not measures_to_compute:
        parser.error('No connectivity measures were selected, nothing '
                     'to compute.')

    logging.info('The following measures will be computed and saved: {}'.format(
        measures_output_filename))

    if args.include_dps:
        if not os.path.isdir(args.include_dps):
            os.makedirs(args.include_dps)
        logging.info('data_per_streamline weighting is activated.')

    img_labels = nib.load(args.in_labels)
    data_labels = get_data_as_label(img_labels)
    if not args.force_labels_list:
        labels_list = np.unique(data_labels)[1:].tolist()
    else:
        labels_list = np.loadtxt(
            args.force_labels_list, dtype=np.int16).tolist()

    comb_list = list(itertools.combinations(labels_list, r=2))
    if not args.no_self_connection:
        comb_list.extend(zip(labels_list, labels_list))

    nbr_cpu = validate_nbr_processes(parser, args)
    measures_dict_list = []
    if nbr_cpu == 1:
        for comb in comb_list:
            measures_dict_list.append(_processing_wrapper([args.in_hdf5,
                                                           img_labels, comb,
                                                           measures_to_compute,
                                                           args.similarity,
                                                           args.density_weighting,
                                                           args.include_dps,
                                                           args.min_lesion_vol]))
    else:
        pool = multiprocessing.Pool(nbr_cpu)
        measures_dict_list = pool.map(_processing_wrapper,
                                      zip(itertools.repeat(args.in_hdf5),
                                          itertools.repeat(img_labels),
                                          comb_list,
                                          itertools.repeat(
                                              measures_to_compute),
                                          itertools.repeat(args.similarity),
                                          itertools.repeat(
                                          args.density_weighting),
                                          itertools.repeat(args.include_dps),
                                          itertools.repeat(args.min_lesion_vol)))
        pool.close()
        pool.join()

    # Removing None entries (combinations that do not exist)
    # Fusing the multiprocessing output into a single dictionary
    measures_dict_list = [it for it in measures_dict_list if it is not None]
    if not measures_dict_list:
        raise ValueError('Empty matrix, no entries to save.')
    measures_dict = measures_dict_list[0]
    for dix in measures_dict_list[1:]:
        measures_dict.update(dix)

    if args.no_self_connection:
        total_elem = len(labels_list)**2 - len(labels_list)
        results_elem = len(measures_dict.keys())*2 - len(labels_list)
    else:
        total_elem = len(labels_list)**2
        results_elem = len(measures_dict.keys())*2

    logging.info('Out of {} possible nodes, {} contain values'.format(
        total_elem, results_elem))

    # Filling out all the matrices (symmetric) in the order of labels_list
    nbr_of_measures = len(list(measures_dict.values())[0])
    matrix = np.zeros((len(labels_list), len(labels_list), nbr_of_measures))

    for in_label, out_label in measures_dict:
        curr_node_dict = measures_dict[(in_label, out_label)]
        measures_ordering = list(curr_node_dict.keys())

        for i, measure in enumerate(curr_node_dict):
            in_pos = labels_list.index(in_label)
            out_pos = labels_list.index(out_label)
            matrix[in_pos, out_pos, i] = curr_node_dict[measure]
            matrix[out_pos, in_pos, i] = curr_node_dict[measure]

    # Saving the matrices separately with the specified name or dps key
    for i, measure in enumerate(measures_ordering):
        if measure == 'volume':
            matrix_basename = args.volume
        elif measure == 'streamline_count':
            matrix_basename = args.streamline_count
        elif measure == 'length':
            matrix_basename = args.length
        elif measure == 'similarity':
            matrix_basename = args.similarity[1]
        elif measure in dict_metrics_out_name:
            matrix_basename = dict_metrics_out_name[measure]
        elif measure in dict_maps_out_name:
            matrix_basename = dict_maps_out_name[measure]
        elif measure in dict_lesion_out_name:
            matrix_basename = dict_lesion_out_name[measure]
        else:
            matrix_basename = measure

        np.save(matrix_basename, matrix[:, :, i])
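
The final loop turns a dictionary keyed by (in_label, out_label) into one symmetric matrix slice per measure. A self-contained sketch of that fill with made-up labels and values:

import numpy as np

labels_list = [3, 7, 9]
measures_dict = {
    (3, 7): {'volume': 120.0, 'length': 42.5},
    (7, 9): {'volume': 80.0, 'length': 31.0},
}

nbr_of_measures = len(next(iter(measures_dict.values())))
matrix = np.zeros((len(labels_list), len(labels_list), nbr_of_measures))

for (in_label, out_label), node_dict in measures_dict.items():
    in_pos = labels_list.index(in_label)
    out_pos = labels_list.index(out_label)
    for i, measure in enumerate(node_dict):
        # Fill both triangles so each slice is a symmetric connectivity matrix
        matrix[in_pos, out_pos, i] = node_dict[measure]
        matrix[out_pos, in_pos, i] = node_dict[measure]

print(matrix[:, :, 0])   # the 'volume' slice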
Example #8
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_labels],
                        args.reference)
    assert_outputs_exist(parser, args, args.out_hdf5)

    # h5py will not overwrite an existing file, so remove it first
    if os.path.isfile(args.out_hdf5):
        os.remove(args.out_hdf5)

    if (args.save_raw_connections or args.save_intermediate
            or args.save_discarded) and not args.out_dir:
        parser.error('To save outputs in the streamlines form, provide the '
                     'output directory using --out_dir.')

    if args.out_dir:
        if os.path.abspath(args.out_dir) == os.getcwd():
            parser.error('Do not use the current path as output directory.')
        assert_output_dirs_exist_and_empty(parser,
                                           args,
                                           args.out_dir,
                                           create_dir=True)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)
    coloredlogs.install(level=log_level)
    set_sft_logger_level('WARNING')

    img_labels = nib.load(args.in_labels)
    data_labels = get_data_as_label(img_labels)
    real_labels = np.unique(data_labels)[1:]
    if args.out_labels_list:
        np.savetxt(args.out_labels_list, real_labels, fmt='%i')

    # Voxel size must be isotropic, for speed/performance considerations
    vox_sizes = img_labels.header.get_zooms()
    if not np.allclose(np.mean(vox_sizes), vox_sizes, atol=1e-03):
        parser.error('Labels must be isotropic')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    sft = load_tractogram_with_reference(parser,
                                         args,
                                         args.in_tractogram,
                                         bbox_check=False)
    sft.remove_invalid_streamlines()
    time2 = time.time()
    logging.info('    Loading {} streamlines took {} sec.'.format(
        len(sft), round(time2 - time1, 2)))

    if not is_header_compatible(sft, img_labels):
        raise IOError('{} and {} do not have a compatible header'.format(
            args.in_tractogram, args.in_labels))

    sft.to_vox()
    sft.to_corner()
    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()

    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)

    time2 = time.time()
    logging.info('    Streamlines intersection took {} sec.'.format(
        round(time2 - time1, 2)))

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices, data_labels, real_labels,
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info('    Connectivity computation took {} sec.'.format(
        round(time2 - time1, 2)))

    # Prepare directories and information needed to save.
    _create_required_output_dirs(args)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info('    This can be long, be patient.')
    time1 = time.time()

    # Saving will be done from streamlines already in the right space
    comb_list = list(itertools.combinations(real_labels, r=2))
    comb_list.extend(zip(real_labels, real_labels))

    iteration_counter = 0
    with h5py.File(args.out_hdf5, 'w') as hdf5_file:
        affine, dimensions, voxel_sizes, voxel_order = get_reference_info(sft)
        hdf5_file.attrs['affine'] = affine
        hdf5_file.attrs['dimensions'] = dimensions
        hdf5_file.attrs['voxel_sizes'] = voxel_sizes
        hdf5_file.attrs['voxel_order'] = voxel_order

        # Each connection is processed independently. Multiprocessing would be
        # a burden on the I/O of most SSDs/HDDs.
        for in_label, out_label in comb_list:
            if iteration_counter > 0 and iteration_counter % 100 == 0:
                logging.info('Split {} nodes out of {}'.format(
                    iteration_counter, len(comb_list)))
            iteration_counter += 1

            pair_info = []
            if in_label not in con_info:
                continue
            elif out_label in con_info[in_label]:
                pair_info.extend(con_info[in_label][out_label])

            if out_label not in con_info:
                continue
            elif in_label in con_info[out_label]:
                pair_info.extend(con_info[out_label][in_label])

            if not len(pair_info):
                continue

            connecting_streamlines = []
            connecting_ids = []
            for connection in pair_info:
                strl_idx = connection['strl_idx']
                curr_streamlines = compute_streamline_segment(
                    sft.streamlines[strl_idx], indices[strl_idx],
                    connection['in_idx'], connection['out_idx'],
                    points_to_idx[strl_idx])
                connecting_streamlines.append(curr_streamlines)
                connecting_ids.append(strl_idx)

            # Each step is processed from the previous 'success'
            #   1. raw         -> length pass/fail
            #   2. length pass -> loops pass/fail
            #   3. loops pass  -> outlier detection pass/fail
            #   4. outlier detection pass -> qb curvature pass/fail
            #   5. qb curvature pass == final connections
            connecting_streamlines = ArraySequence(connecting_streamlines)
            raw_dps = sft.data_per_streamline[connecting_ids]
            raw_sft = StatefulTractogram.from_sft(connecting_streamlines,
                                                  sft,
                                                  data_per_streamline=raw_dps,
                                                  data_per_point={})
            _save_if_needed(raw_sft, hdf5_file, args, 'raw', 'raw', in_label,
                            out_label)

            # Doing all post-processing
            if not args.no_pruning:
                valid_length_ids, invalid_length_ids = _prune_segments(
                    raw_sft.streamlines, args.min_length, args.max_length,
                    vox_sizes[0])

                invalid_length_sft = raw_sft[invalid_length_ids]
                valid_length = connecting_streamlines[valid_length_ids]
                _save_if_needed(invalid_length_sft, hdf5_file, args,
                                'discarded', 'invalid_length', in_label,
                                out_label)
            else:
                valid_length = connecting_streamlines
                valid_length_ids = range(len(connecting_streamlines))

            if not len(valid_length):
                continue

            valid_length_sft = raw_sft[valid_length_ids]
            _save_if_needed(valid_length_sft, hdf5_file, args, 'intermediate',
                            'valid_length', in_label, out_label)

            if not args.no_remove_loops:
                no_loop_ids = remove_loops_and_sharp_turns(
                    valid_length, args.loop_max_angle)
                loop_ids = np.setdiff1d(np.arange(len(valid_length)),
                                        no_loop_ids)

                loops_sft = valid_length_sft[loop_ids]
                no_loops = valid_length[no_loop_ids]
                _save_if_needed(loops_sft, hdf5_file, args, 'discarded',
                                'loops', in_label, out_label)
            else:
                no_loops = valid_length
                no_loop_ids = range(len(valid_length))

            if not len(no_loops):
                continue
            no_loops_sft = valid_length_sft[no_loop_ids]
            _save_if_needed(no_loops_sft, hdf5_file, args, 'intermediate',
                            'no_loops', in_label, out_label)

            if not args.no_remove_outliers:
                outliers_ids, inliers_ids = remove_outliers(
                    no_loops,
                    args.outlier_threshold,
                    nb_samplings=10,
                    fast_approx=True)

                outliers_sft = no_loops_sft[outliers_ids]
                inliers = no_loops[inliers_ids]
                _save_if_needed(outliers_sft, hdf5_file, args, 'discarded',
                                'outliers', in_label, out_label)
            else:
                inliers = no_loops
                inliers_ids = range(len(no_loops))

            if not len(inliers):
                continue

            inliers_sft = no_loops_sft[inliers_ids]
            _save_if_needed(inliers_sft, hdf5_file, args, 'intermediate',
                            'inliers', in_label, out_label)

            if not args.no_remove_curv_dev:
                no_qb_curv_ids = remove_loops_and_sharp_turns(
                    inliers,
                    args.loop_max_angle,
                    use_qb=True,
                    qb_threshold=args.curv_qb_distance)
                qb_curv_ids = np.setdiff1d(np.arange(len(inliers)),
                                           no_qb_curv_ids)

                qb_curv_sft = inliers_sft[qb_curv_ids]
                _save_if_needed(qb_curv_sft, hdf5_file, args, 'discarded',
                                'qb_curv', in_label, out_label)
            else:
                no_qb_curv_ids = range(len(inliers))

            no_qb_curv_sft = inliers_sft[no_qb_curv_ids]
            _save_if_needed(no_qb_curv_sft, hdf5_file, args, 'final', 'final',
                            in_label, out_label)

    time2 = time.time()
    logging.info(
        '    Connections post-processing and saving took {} sec.'.format(
            round(time2 - time1, 2)))
Example #9
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram)
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    roi_opt_list, only_filtering_list = prepare_filtering_list(parser, args)
    o_dict = {}

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

    # Streamline count before filtering
    o_dict['streamline_count_before_filtering'] = len(sft.streamlines)

    for i, roi_opt in enumerate(roi_opt_list):
        curr_dict = {}
        # Atlas needs an extra argument (value in the LUT)
        if roi_opt[0] == 'atlas_roi':
            filter_type, filter_arg, filter_arg_2, \
                filter_mode, filter_criteria = roi_opt
        else:
            filter_type, filter_arg, filter_mode, filter_criteria = roi_opt

        curr_dict['filename'] = os.path.abspath(filter_arg)
        curr_dict['type'] = filter_type
        curr_dict['mode'] = filter_mode
        curr_dict['criteria'] = filter_criteria

        is_exclude = False if filter_criteria == 'include' else True

        if filter_type == 'drawn_roi' or filter_type == 'atlas_roi':
            img = nib.load(filter_arg)
            if not is_header_compatible(img, sft):
                parser.error('Headers from the tractogram and the mask are '
                             'not compatible.')
            if filter_type == 'drawn_roi':
                mask = get_data_as_mask(img)
            else:
                atlas = get_data_as_label(img)
                mask = np.zeros(atlas.shape, dtype=np.uint16)
                mask[atlas == int(filter_arg_2)] = 1
            filtered_sft, indexes = filter_grid_roi(sft, mask,
                                                    filter_mode, is_exclude)

        # For every case, the input number must be greater than or equal to 0
        # and below the dimension, since this is a voxel-space operation.
        elif filter_type in ['x_plane', 'y_plane', 'z_plane']:
            filter_arg = int(filter_arg)
            _, dim, _, _ = sft.space_attributes
            mask = np.zeros(dim, dtype=np.int16)
            error_msg = None
            if filter_type == 'x_plane':
                if 0 <= filter_arg < dim[0]:
                    mask[filter_arg, :, :] = 1
                else:
                    error_msg = 'X plane ' + str(filter_arg)

            elif filter_type == 'y_plane':
                if 0 <= filter_arg < dim[1]:
                    mask[:, filter_arg, :] = 1
                else:
                    error_msg = 'Y plane ' + str(filter_arg)

            elif filter_type == 'z_plane':
                if 0 <= filter_arg < dim[2]:
                    mask[:, :, filter_arg] = 1
                else:
                    error_msg = 'Z plane ' + str(filter_arg)

            if error_msg:
                parser.error('{} is not valid according to the '
                             'tractogram header.'.format(error_msg))

            filtered_sft, indexes = filter_grid_roi(sft, mask,
                                                    filter_mode, is_exclude)

        elif filter_type == 'bdo':
            geometry, radius, center = read_info_from_mb_bdo(filter_arg)
            if geometry == 'Ellipsoid':
                filtered_sft, indexes = filter_ellipsoid(sft,
                                                         radius, center,
                                                         filter_mode, is_exclude)
            elif geometry == 'Cuboid':
                filtered_sft, indexes = filter_cuboid(sft,
                                                      radius, center,
                                                      filter_mode, is_exclude)

        logging.debug('The filtering options {0} resulted in '
                      '{1} streamlines'.format(roi_opt, len(filtered_sft)))

        sft = filtered_sft

        if only_filtering_list:
            filtering_name = 'Filter_' + str(i)
            curr_dict['streamline_count_after_filtering'] = len(sft.streamlines)
            o_dict[filtering_name] = curr_dict

    # Streamline count after filtering
    o_dict['streamline_count_final_filtering'] = len(sft.streamlines)
    if args.display_counts:
        print(json.dumps(o_dict, indent=args.indent))

    if not filtered_sft:
        if args.no_empty:
            logging.debug("The file {} won't be written (0 streamline)".format(
                args.out_tractogram))

            return

        logging.debug('The file {} contains 0 streamlines'.format(
            args.out_tractogram))

    save_tractogram(sft, args.out_tractogram)
Example #10
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_matrix,
                        [args.length, args.inverse_length, args.bundle_volume])
    assert_outputs_exist(parser, args, args.out_matrix)

    in_matrix = load_matrix_in_any_format(args.in_matrix)

    # Parcel volume and surface normalization require the atlas
    # This script should be used directly after scil_decompose_connectivity.py
    if args.parcel_volume or args.parcel_surface:
        atlas_tuple = args.parcel_volume if args.parcel_volume \
            else args.parcel_surface
        atlas_filepath, labels_filepath = atlas_tuple
        assert_inputs_exist(parser, [atlas_filepath, labels_filepath])

        atlas_img = nib.load(atlas_filepath)
        atlas_data = get_data_as_label(atlas_img)

        voxels_size = atlas_img.header.get_zooms()[:3]
        if voxels_size[0] != voxels_size[1] \
           or voxels_size[0] != voxels_size[2]:
            parser.error('Atlas must have an isotropic resolution.')

        voxels_vol = np.prod(atlas_img.header.get_zooms()[:3])
        voxels_sur = np.prod(atlas_img.header.get_zooms()[:2])

        # Excluding background (0)
        labels_list = np.loadtxt(labels_filepath)
        if len(labels_list) != in_matrix.shape[0] \
                and len(labels_list) != in_matrix.shape[1]:
            parser.error('Atlas should have the same number of labels as the '
                         'input matrix.')

    # Normalizations can be combined
    out_matrix = in_matrix
    if args.length:
        length_mat = load_matrix_in_any_format(args.length)
        out_matrix[length_mat > 0] *= length_mat[length_mat > 0]
    elif args.inverse_length:
        length_mat = load_matrix_in_any_format(args.inverse_length)
        out_matrix[length_mat > 0] /= length_mat[length_mat > 0]

    if args.bundle_volume:
        volume_mat = load_matrix_in_any_format(args.bundle_volume)
        out_matrix[volume_mat > 0] /= volume_mat[volume_mat > 0]

    # Node-wise computations are necessary for this type of normalization
    if args.parcel_volume or args.parcel_surface:
        out_matrix = copy(in_matrix)
        pos_list = range(len(labels_list))
        all_comb = list(itertools.combinations(pos_list, r=2))
        all_comb.extend(zip(pos_list, pos_list))

        # Prevent useless computations for approximate_surface_node()
        factor_list = []
        for label in labels_list:
            if args.parcel_volume:
                factor_list.append(
                    np.count_nonzero(atlas_data == label) * voxels_vol)
            else:
                if np.count_nonzero(atlas_data == label):
                    roi = np.zeros(atlas_data.shape)
                    roi[atlas_data == label] = 1
                    factor_list.append(
                        approximate_surface_node(roi) * voxels_sur)
                else:
                    factor_list.append(0)

        for pos_1, pos_2 in all_comb:
            factor = factor_list[pos_1] + factor_list[pos_2]
            if abs(factor) > 0.001:
                out_matrix[pos_1, pos_2] /= factor
                out_matrix[pos_2, pos_1] /= factor

    # Load as image
    ref_matrix = nib.Nifti1Image(in_matrix, np.eye(4))
    # Simple scaling of the whole matrix facilitates comparison across subjects
    if args.max_at_one:
        out_matrix = nib.Nifti1Image(out_matrix, np.eye(4))
        out_matrix = normalize_max([out_matrix], ref_matrix)
    elif args.sum_to_one:
        out_matrix = nib.Nifti1Image(out_matrix, np.eye(4))
        out_matrix = normalize_sum([out_matrix], ref_matrix)
    elif args.log_10:
        out_matrix = nib.Nifti1Image(out_matrix, np.eye(4))
        out_matrix = base_10_log([out_matrix], ref_matrix)

    save_matrix_in_any_format(args.out_matrix, out_matrix)
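
The parcel-wise normalization divides each edge by the summed size (volume or surface) of its two endpoint parcels, skipping empty parcels. A toy sketch with a hypothetical 3x3 matrix and parcel volumes:

import itertools
import numpy as np

out_matrix = np.array([[0., 10., 4.],
                       [10., 0., 6.],
                       [4., 6., 0.]])
factor_list = [100.0, 50.0, 0.0]      # e.g. parcel volumes; the last parcel is empty

pos_list = range(len(factor_list))
all_comb = list(itertools.combinations(pos_list, r=2))
all_comb.extend(zip(pos_list, pos_list))

for pos_1, pos_2 in all_comb:
    factor = factor_list[pos_1] + factor_list[pos_2]
    if abs(factor) > 0.001:           # skip pairs involving empty parcels
        out_matrix[pos_1, pos_2] /= factor
        out_matrix[pos_2, pos_1] /= factor

print(out_matrix)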
Example #11
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_x_map, args.in_y_map])

    if args.out_dir is None:
        args.out_dir = './'

    if args.in_atlas:
        assert_inputs_exist(parser, args.atlas_lut)

    # Load x and y images
    maps = [args.in_x_map, args.in_y_map]

    maps_data = []
    for curr_map in maps:
        maps_image = nib.load(curr_map)
        if args.not_exclude_zero:
            maps_data.append(maps_image.get_fdata(dtype=np.float32))
        else:
            data = maps_image.get_fdata(dtype=np.float32)
            data[np.where(data == 0)] = np.nan
            maps_data.append(data)

    if args.in_bin_mask:
        if args.label is None:
            args.label = 'Masking data'
        # Load and apply binary mask
        mask_image = nib.load(args.in_bin_mask)
        mask_data = get_data_as_mask(mask_image)
        for curr_map in maps_data:
            curr_map[np.where(mask_data == 0)] = np.nan

    if args.in_prob_maps:
        if args.label is None:
            args.label = 'Threshold prob_map 1'
        # Load tissue probability maps
        prob_masks = []
        for curr_map in args.in_prob_maps:
            prob_image = nib.load(curr_map)
            prob_masks.append(prob_image.get_fdata(dtype=np.float32))

        # Deepcopy to apply the second probability map on same data
        maps_prob = copy.deepcopy(maps_data)

        # Threshold probability images with tissue probability maps
        for curr_map in maps_data:
            curr_map[np.where(prob_masks[0] < args.thr)] = np.nan

        for curr_map in maps_prob:
            curr_map[np.where(prob_masks[1] < args.thr)] = np.nan

    if args.in_atlas:
        label_image = nib.load(args.in_atlas)
        label_data = get_data_as_label(label_image)

        with open(args.atlas_lut) as f:
            label_dict = json.load(f)
        lut_indices, lut_names = zip(*label_dict.items())

        if args.specific_label:
            label_indices = []
            label_names = []
            for key in args.specific_label:
                label_indices.append(lut_indices[key - 1])
                label_names.append(lut_names[key - 1])
        else:
            (label_indices, label_names) = (lut_indices, lut_names)

    # Scatter plots
    # Plot each label separately using unmasked data
    if args.in_atlas:
        if args.in_folder:
            args.out_dir = os.path.join(args.out_dir, 'Label_plots/')
            if not os.path.isdir(args.out_dir):
                os.mkdir(args.out_dir)

        for label, name in zip(label_indices, label_names):
            label = int(label)
            fig, ax = plt.subplots()
            x = (maps_data[0][np.where(label_data == label)])
            y = (maps_data[1][np.where(label_data == label)])

            ax.scatter(x,
                       y,
                       label=name,
                       color=args.colors[0],
                       s=args.marker_size,
                       marker=args.marker,
                       alpha=args.transparency)
            plt.xlabel(args.x_label)
            plt.ylabel(args.y_label)
            plt.title(args.title)
            plt.legend()

            out_name = os.path.join(args.out_dir, args.out_name + '_' + name +
                                    '.png')
            plt.savefig(out_name, dpi=args.dpi, bbox_inches='tight')
            plt.close()

    else:
        # Plot unmasked or masked data (by binary mask or first probability map)
        fig, ax = plt.subplots()
        plt.xlabel(args.x_label)
        plt.ylabel(args.y_label)
        plt.title(args.title)

        ax.scatter(maps_data[0],
                   maps_data[1],
                   label=args.label,
                   color=args.colors[0],
                   s=args.marker_size,
                   marker=args.marker,
                   alpha=args.transparency)

        # Add data thresholded with the second probability map
        if args.in_prob_maps:
            ax.scatter(maps_prob[0],
                       maps_prob[1],
                       label=args.label_prob,
                       color=args.colors[1],
                       s=args.marker_size,
                       marker=args.marker,
                       alpha=args.transparency)

        plt.legend()

        if args.show_only:
            plt.show()
        else:
            plt.savefig(os.path.join(args.out_dir, args.out_name),
                        dpi=args.dpi,
                        bbox_inches='tight')
Example #12
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    if (not args.bundle) and (not args.bundle_mask) \
            and (not args.bundle_labels_map):
        parser.error('One of the options --bundle, --bundle_mask or '
                     '--bundle_labels_map must be used.')

    assert_inputs_exist(parser, [args.in_lesion],
                        optional=[args.bundle, args.bundle_mask,
                                  args.bundle_labels_map])
    assert_outputs_exist(parser, args, args.out_json,
                         optional=[args.out_lesion_stats,
                                   args.out_streamlines_stats])

    lesion_img = nib.load(args.in_lesion)
    lesion_data = get_data_as_mask(lesion_img, dtype=bool)

    if args.bundle:
        bundle_name, _ = split_name_with_nii(os.path.basename(args.bundle))
        sft = load_tractogram_with_reference(parser, args, args.bundle)
        sft.to_vox()
        sft.to_corner()
        streamlines = sft.get_streamlines_copy()
        map_data = compute_tract_counts_map(streamlines,
                                            lesion_data.shape)
        map_data[map_data > 0] = 1
    elif args.bundle_mask:
        bundle_name, _ = split_name_with_nii(
            os.path.basename(args.bundle_mask))
        map_img = nib.load(args.bundle_mask)
        map_data = get_data_as_mask(map_img)
    else:
        bundle_name, _ = split_name_with_nii(os.path.basename(
            args.bundle_labels_map))
        map_img = nib.load(args.bundle_labels_map)
        map_data = get_data_as_label(map_img)

    is_single_label = args.bundle_labels_map is None
    voxel_sizes = lesion_img.header.get_zooms()[0:3]
    lesion_atlas, _ = ndi.label(lesion_data)

    lesion_load_dict = compute_lesion_stats(
        map_data, lesion_atlas, single_label=is_single_label,
        voxel_sizes=voxel_sizes, min_lesion_vol=args.min_lesion_vol)

    if args.out_lesion_atlas:
        lesion_atlas *= map_data.astype(bool)
        nib.save(nib.Nifti1Image(lesion_atlas, lesion_img.affine),
                 args.out_lesion_atlas)

    volume_dict = {bundle_name: lesion_load_dict}
    with open(args.out_json, 'w') as outfile:
        json.dump(volume_dict, outfile,
                  sort_keys=args.sort_keys, indent=args.indent)

    if args.out_streamlines_stats or args.out_lesion_stats:
        lesion_dict = {}
        for lesion in np.unique(lesion_atlas)[1:]:
            curr_vol = np.count_nonzero(lesion_atlas[lesion_atlas == lesion]) \
                * np.prod(voxel_sizes)
            if curr_vol >= args.min_lesion_vol:
                key = str(lesion).zfill(4)
                lesion_dict[key] = {'volume': curr_vol}
                if args.bundle:
                    tmp = np.zeros(lesion_atlas.shape)
                    tmp[lesion_atlas == lesion] = 1
                    new_sft, _ = filter_grid_roi(sft, tmp, 'any', False)
                    lesion_dict[key]['strs_count'] = len(new_sft)

        lesion_vol_dict = {bundle_name: {}}
        streamlines_count_dict = {bundle_name: {'streamlines_count': {}}}
        for key in lesion_dict.keys():
            lesion_vol_dict[bundle_name][key] = lesion_dict[key]['volume']
            if args.bundle:
                streamlines_count_dict[bundle_name]['streamlines_count'][key] = \
                    lesion_dict[key]['strs_count']

        if args.out_lesion_stats:
            with open(args.out_lesion_stats, 'w') as outfile:
                json.dump(lesion_vol_dict, outfile,
                          sort_keys=args.sort_keys, indent=args.indent)
        if args.out_streamlines_stats:
            with open(args.out_streamlines_stats, 'w') as outfile:
                json.dump(streamlines_count_dict, outfile,
                          sort_keys=args.sort_keys, indent=args.indent)
Example #13
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    image_files = []
    indices_per_volume = []
    # Separate argument per volume
    used_indices_all = False
    for v_args in args.volume_ids:
        if len(v_args) < 2:
            parser.error("No indices was given for a given volume.")

        image_files.append(v_args[0])
        if "all" in v_args:
            used_indices_all = True
            indices_per_volume.append("all")
        else:
            indices_per_volume.append(np.asarray(v_args[1:], dtype=int))

    if used_indices_all and args.out_labels_ids:
        parser.error("'all' indices cannot be used with --out_labels_ids.")

    if args.merge_groups and (args.group_in_m or args.unique):
        parser.error("Cannot use --unique and --group_in_m with "
                     "--merge_groups.")

    # Check inputs / output
    assert_inputs_exist(parser, image_files)
    assert_outputs_exist(parser, args, args.output)

    # Load volume and do checks
    data_list = []
    first_img = nib.load(image_files[0])
    for i in range(len(image_files)):
        # Load images
        volume_nib = nib.load(image_files[i])
        data = get_data_as_label(volume_nib)
        data_list.append(data)
        assert (is_header_compatible(first_img, image_files[i]))

        if (isinstance(indices_per_volume[i], str)
                and indices_per_volume[i] == "all"):
            indices_per_volume[i] = np.unique(data)

    filtered_ids_per_vol = []
    # Remove background labels
    for id_list in indices_per_volume:
        id_list = np.asarray(id_list)
        new_ids = id_list[~np.in1d(id_list, args.background)]
        filtered_ids_per_vol.append(new_ids)

    # Prepare output indices
    if args.out_labels_ids:
        out_labels = args.out_labels_ids
        if not args.merge_groups \
                and len(out_labels) != len(np.hstack(indices_per_volume)):
            parser.error("--out_labels_ids, requires the same amount"
                         " of total given input indices.")
        elif len(out_labels) != len(args.volume_ids):
            parser.error("--out_labels_ids, requires the same amount"
                         " of total given groups (to merge).")
    elif args.unique:
        stack = np.hstack(filtered_ids_per_vol)
        ids = np.arange(len(stack) + 1)
        out_labels = np.setdiff1d(ids, args.background)[:len(stack)]
    elif args.group_in_m:
        m_list = []
        for i in range(len(filtered_ids_per_vol)):
            prefix = i * 10000
            m_list.append(prefix + np.asarray(filtered_ids_per_vol[i]))
        out_labels = np.hstack(m_list)
    else:
        if args.merge_groups:
            out_labels = np.arange(len(args.volume_ids))+1
        else:
            out_labels = np.hstack(filtered_ids_per_vol)

    if len(np.unique(out_labels)) != len(out_labels):
        logging.error("The same output label number was used "
                      "for multiple inputs")

    # Create the resulting volume
    current_id = 0
    resulting_labels = (np.ones_like(data_list[0], dtype=np.uint16)
                        * args.background)
    for i in range(len(image_files)):
        # Add given labels for each volume
        for index in filtered_ids_per_vol[i]:
            if args.merge_groups:
                for j, curr_volume_ids in enumerate(args.volume_ids):
                    if str(index) in curr_volume_ids[1:]:
                        where_at = j
                mask = data_list[i] == index
                resulting_labels[mask] = out_labels[where_at]
            else:
                mask = data_list[i] == index
                resulting_labels[mask] = out_labels[current_id]
                current_id += 1

                if np.count_nonzero(mask) == 0:
                    logging.warning(
                        "Label {} was not in the volume".format(index))

    # Save final combined volume
    nib.save(nib.Nifti1Image(resulting_labels, first_img.affine,
                             header=first_img.header),
             args.output)
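
Combining label volumes is repeated masked assignment into one output array, with each input label remapped to an output value. A minimal sketch with two synthetic volumes and hypothetical output labels:

import numpy as np

background = 0
vol_a = np.zeros((4, 4, 4), dtype=np.uint16)
vol_a[0:2, 0:2, 0:2] = 10
vol_b = np.zeros((4, 4, 4), dtype=np.uint16)
vol_b[2:4, 2:4, 2:4] = 7

data_list = [vol_a, vol_b]
ids_per_vol = [[10], [7]]
out_labels = [1, 2]                   # remapped output values

resulting_labels = np.ones_like(data_list[0], dtype=np.uint16) * background
current_id = 0
for data, ids in zip(data_list, ids_per_vol):
    for index in ids:
        mask = data == index
        resulting_labels[mask] = out_labels[current_id]
        current_id += 1
        if not np.count_nonzero(mask):
            print('Label {} was not in the volume'.format(index))

print(np.unique(resulting_labels))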
Example #14
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)

    assert_inputs_exist(parser, args.in_file, optional=args.mask)
    assert_outputs_exist(parser, args, args.out_file)

    if args.nbr_processes is None:
        args.nbr_processes = -1

    # load volume
    volume_nib = nib.load(args.in_file)
    data = get_data_as_label(volume_nib)
    vox_size = np.reshape(volume_nib.header.get_zooms(), (1, 3))
    img_shape = data.shape

    # Check for labels present in both label_not_to_dilate & label_to_fill
    fill_and_not = np.in1d(args.label_not_to_dilate, args.label_to_fill)
    if np.any(fill_and_not):
        logging.error("Error, both in not_to_dilate and to_fill: {}".format(
                      np.asarray(args.label_not_to_dilate)[fill_and_not]))

    # Create background mask
    is_background_mask = np.zeros(img_shape, dtype=bool)
    for i in args.label_to_fill:
        is_background_mask = np.logical_or(is_background_mask, data == i)

    # Create not_to_dilate mask (initialized to background)
    not_to_dilate = np.copy(is_background_mask)
    for i in args.label_not_to_dilate:

        not_to_dilate = np.logical_or(not_to_dilate, data == i)

    # Add mask
    if args.mask:
        mask_nib = nib.load(args.mask)
        mask_data = get_data_as_mask(mask_nib)
        to_dilate_mask = np.logical_and(is_background_mask, mask_data)
    else:
        to_dilate_mask = is_background_mask

    # Create label mask
    is_label_mask = ~not_to_dilate

    if args.label_to_dilate is not None:
        # Check if in both: to_dilate & not_to_dilate
        dil_and_not = np.in1d(args.label_to_dilate, args.label_not_to_dilate)
        if np.any(dil_and_not):
            logging.error("Error, both in dilate and Not to dilate: {}".format(
                          np.asarray(args.label_to_dilate)[dil_and_not]))

        # Check if in both: to_dilate & to_fill
        dil_and_fill = np.in1d(args.label_to_dilate, args.label_to_fill)
        if np.any(dil_and_fill):
            logging.error("Error, both in dilate and to fill: {}".format(
                          np.asarray(args.label_to_dilate)[dil_and_fill]))

        # Create new label to dilate list
        new_label_mask = np.zeros_like(data, dtype=bool)
        for i in args.label_to_dilate:
            new_label_mask = np.logical_or(new_label_mask, data == i)

        # Combine both new_label_mask and not_to_dilate
        is_label_mask = np.logical_and(new_label_mask, ~not_to_dilate)

    # Get the list of indices
    background_pos = np.argwhere(to_dilate_mask) * vox_size
    label_pos = np.argwhere(is_label_mask) * vox_size
    ckd_tree = cKDTree(label_pos)

    # Compute the nearest labels for each voxel of the background
    dist, indices = ckd_tree.query(
        background_pos, k=1, distance_upper_bound=args.distance,
        n_jobs=args.nbr_processes)

    # Associate indices to the nearest label (in distance)
    valid_nearest = np.squeeze(np.isfinite(dist))
    id_background = np.flatnonzero(to_dilate_mask)[valid_nearest]
    id_label = np.flatnonzero(is_label_mask)[indices[valid_nearest]]

    # Change values of those background
    data = data.flatten()
    data[id_background.T] = data[id_label.T]
    data = data.reshape(img_shape)

    # Save image
    nib.save(nib.Nifti1Image(data.astype(np.uint16), volume_nib.affine,
                             header=volume_nib.header),
             args.out_file)
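
The dilation mechanism is a nearest-neighbour query: each background voxel within the allowed distance takes the value of its closest labelled voxel. A compact 2D sketch of that mechanism with a hypothetical distance limit (the script above works in 3D and scales positions by the header voxel size):

import numpy as np
from scipy.spatial import cKDTree

data = np.array([[1, 0, 0, 0],
                 [0, 0, 0, 0],
                 [0, 0, 0, 2]], dtype=np.uint16)
vox_size = np.array([[1.0, 1.0]])     # toy 2D case, 1 mm voxels
max_distance = 2.0                    # hypothetical dilation limit

to_dilate_mask = data == 0            # background voxels to fill
is_label_mask = data > 0              # voxels that can propagate their value

background_pos = np.argwhere(to_dilate_mask) * vox_size
label_pos = np.argwhere(is_label_mask) * vox_size

dist, indices = cKDTree(label_pos).query(background_pos, k=1,
                                         distance_upper_bound=max_distance)

# Voxels farther than max_distance come back with dist == inf; keep the rest
valid = np.isfinite(dist)
flat = data.flatten()
id_background = np.flatnonzero(to_dilate_mask)[valid]
id_label = np.flatnonzero(is_label_mask)[indices[valid]]
flat[id_background] = flat[id_label]
print(flat.reshape(data.shape))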
Example #15
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_inputs_exist(parser, args.in_wmparc)
    assert_output_dirs_exist_and_empty(parser,
                                       args,
                                       args.out_path,
                                       create_dir=True)

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    if args.angle <= 0:
        parser.error('Angle "{}" '.format(args.angle) +
                     'must be greater than 0')
    if args.ctx_dilation_radius < 0:
        parser.error(
            'Cortex dilation radius "{}" '.format(args.ctx_dilation_radius) +
            'must be greater than or equal to 0')

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

    img_wmparc = nib.load(args.in_wmparc)
    if not is_header_compatible(img_wmparc, sft):
        parser.error('Headers from the tractogram and the wmparc are '
                     'not compatible.')
    if args.csf_bin:
        img_csf = nib.load(args.csf_bin)
        if not is_header_compatible(img_csf, sft):
            parser.error('Headers from the tractogram and the CSF mask are '
                         'not compatible.')

    if args.minL == 0 and np.isinf(args.maxL):
        logging.debug("You have not specified minL nor maxL. Output will "
                      "not be filtered according to length!")
    if np.isinf(args.angle):
        logging.debug("You have not specified the angle. Loops will "
                      "not be filtered!")
    if args.ctx_dilation_radius == 0:
        logging.debug("You have not specified the cortex dilation radius. "
                      "The wmparc atlas will not be dilated!")

    o_dict = {}
    step_dict = ['length', 'no_loops', 'no_end_csf', 'end_in_atlas']
    wm_labels = load_wmparc_labels()

    in_sft_name = os.path.splitext(os.path.basename(args.in_tractogram))[0]
    out_sft_rootname = in_sft_name + "_filtered"
    _, ext = os.path.splitext(args.in_tractogram)
    out_sft_name = os.path.join(args.out_path,
                                out_sft_rootname + ext)

    # STEP 1 - Filter length
    step = step_dict[0]
    steps_combined = step

    new_sft = filter_streamlines_by_length(sft, args.minL, args.maxL)

    # Streamline count before and after filtering lengths
    o_dict[in_sft_name + ext] =\
        dict({'streamline_count': len(sft.streamlines)})
    o_dict[in_sft_name + '_' + steps_combined + ext] =\
        dict({'streamline_count': len(new_sft.streamlines)})

    if args.save_intermediate_tractograms:
        outliers_sft = compute_outliers(sft, new_sft)
        new_path = create_dir(args.out_path, step)
        save_intermediate_sft(new_sft, outliers_sft, new_path, in_sft_name,
                              step, steps_combined, ext, args.no_empty)
        o_dict[in_sft_name + '_' + step + '_outliers' + ext] =\
            dict({'streamline_count': len(outliers_sft.streamlines)})

    if len(new_sft.streamlines) == 0:
        if args.no_empty:
            logging.debug("The file {} won't be written".format(out_sft_name) +
                          "(0 streamlines after " + step + " filtering).")

            if args.verbose:
                display_count(o_dict, args.indent, args.sort_keys)

            if args.save_counts:
                save_count(o_dict, args.out_path, args.indent, args.sort_keys)

            return

        logging.debug(
            'The file {} contains 0 streamlines after '.format(out_sft_name) +
            step + ' filtering')
        save_tractogram(new_sft, out_sft_name)

        if args.verbose:
            display_count(o_dict, args.indent, args.sort_keys)

        if args.save_counts:
            save_count(o_dict, args.out_path, args.indent, args.sort_keys)

        return

    sft = new_sft

    # STEP 2 - Filter loops
    step = step_dict[1]
    steps_combined += "_" + step

    ids_c = remove_loops_and_sharp_turns(sft.streamlines, args.angle)
    new_sft = filter_tractogram_data(sft, ids_c)

    # Streamline count after filtering loops
    o_dict[in_sft_name + '_' + steps_combined + ext] =\
        dict({'streamline_count': len(new_sft.streamlines)})

    if args.save_intermediate_tractograms:
        outliers_sft = compute_outliers(sft, new_sft)
        new_path = create_dir(args.out_path, step)
        save_intermediate_sft(new_sft, outliers_sft, new_path, in_sft_name,
                              step, steps_combined, ext, args.no_empty)
        o_dict[in_sft_name + '_' + step + '_outliers' + ext] =\
            dict({'streamline_count': len(outliers_sft.streamlines)})

    if len(new_sft.streamlines) == 0:
        if args.no_empty:
            logging.debug("The file {} won't be written".format(out_sft_name) +
                          "(0 streamlines after " + step + " filtering).")

            if args.verbose:
                display_count(o_dict, args.indent, args.sort_keys)

            if args.save_counts:
                save_count(o_dict, args.out_path, args.indent, args.sort_keys)

            return

        logging.debug(
            'The file {} contains 0 streamlines after '.format(out_sft_name) +
            step + ' filtering')

        save_tractogram(new_sft, out_sft_name)

        if args.verbose:
            display_count(o_dict, args.indent, args.sort_keys)

        if args.save_counts:
            save_count(o_dict, args.out_path, args.indent, args.sort_keys)

        return

    sft = new_sft

    # STEP 3 - Filter CSF endings
    step = step_dict[2]
    steps_combined += "_" + step

    # Mask creation
    if args.csf_bin:
        mask = get_data_as_mask(img_csf)
    else:
        atlas = get_data_as_label(img_wmparc)
        mask = binarize_labels(atlas, wm_labels["csf_labels"])

    # Filter tractogram
    new_sft, _ = filter_grid_roi(sft, mask, 'any', True)

    # Streamline count after filtering CSF endings
    o_dict[in_sft_name + '_' + steps_combined + ext] =\
        dict({'streamline_count': len(new_sft.streamlines)})

    if args.save_volumes:
        new_path = create_dir(args.out_path, step)
        if not args.csf_bin:
            nib.save(
                nib.Nifti1Image(mask, img_wmparc.affine, img_wmparc.header),
                os.path.join(new_path, 'csf_bin.nii.gz'))

    if args.save_intermediate_tractograms:
        outliers_sft = compute_outliers(sft, new_sft)
        new_path = create_dir(args.out_path, step)
        save_intermediate_sft(new_sft, outliers_sft, new_path, in_sft_name,
                              step, steps_combined, ext, args.no_empty)
        o_dict[in_sft_name + '_' + step + '_outliers' + ext] =\
            dict({'streamline_count': len(outliers_sft.streamlines)})

    if len(new_sft.streamlines) == 0:
        if args.no_empty:
            logging.debug("The file {} won't be written".format(out_sft_name) +
                          "(0 streamlines after " + step + " filtering).")

            if args.verbose:
                display_count(o_dict, args.indent, args.sort_keys)

            if args.save_counts:
                save_count(o_dict, args.out_path, args.indent, args.sort_keys)

            return

        logging.debug(
            'The file {} contains 0 streamlines after '.format(out_sft_name) +
            step + ' filtering')

        save_tractogram(new_sft, out_sft_name)

        if args.verbose:
            display_count(o_dict, args.indent, args.sort_keys)

        if args.save_counts:
            save_count(o_dict, args.out_path, args.indent, args.sort_keys)

        return

    sft = new_sft

    # STEP 4 - Filter WM endings
    step = step_dict[3]
    steps_combined += "_" + step

    # Mask creation
    ctx_fs_labels = wm_labels["ctx_lh_fs_labels"] + \
        wm_labels["ctx_rh_fs_labels"]
    vox_size = np.reshape(img_wmparc.header.get_zooms(), (1, 3))
    atlas_wm = get_data_as_label(img_wmparc)
    atlas_shape = atlas_wm.shape
    wmparc_ctx = binarize_labels(atlas_wm, ctx_fs_labels)
    wmparc_nuclei = binarize_labels(atlas_wm, wm_labels["nuclei_fs_labels"])

    # Dilation of cortex
    if args.ctx_dilation_radius:
        ctx_mask = dilate_mask(wmparc_ctx, atlas_shape, vox_size,
                               args.ctx_dilation_radius)
    else:
        ctx_mask = wmparc_ctx

    freesurfer_mask = np.zeros(atlas_shape, dtype=np.uint16)
    freesurfer_mask[np.logical_or(wmparc_nuclei, ctx_mask)] = 1

    # Filter tractogram
    new_sft, _ = filter_grid_roi(sft, freesurfer_mask, 'both_ends', False)

    # Streamline count after final filtering
    o_dict[out_sft_rootname + ext] =\
        dict({'streamline_count': len(new_sft.streamlines)})

    if args.save_volumes:
        new_path = create_dir(args.out_path, step)
        nib.save(
            nib.Nifti1Image(freesurfer_mask, img_wmparc.affine,
                            img_wmparc.header),
            os.path.join(new_path, 'atlas_bin.nii.gz'))

    if args.save_intermediate_tractograms:
        outliers_sft = compute_outliers(sft, new_sft)
        new_path = create_dir(args.out_path, step)
        save_intermediate_sft(new_sft, outliers_sft, new_path, in_sft_name,
                              step, steps_combined, ext, args.no_empty)
        o_dict[in_sft_name + '_' + step + '_outliers' + ext] =\
            dict({'streamline_count': len(outliers_sft.streamlines)})

    # Finish filtering
    if args.verbose:
        display_count(o_dict, args.indent, args.sort_keys)

    if args.save_counts:
        save_count(o_dict, args.out_path, args.indent, args.sort_keys)

    if len(new_sft.streamlines) == 0:
        if args.no_empty:
            logging.debug("The file {} won't be written".format(out_sft_name) +
                          "(0 streamlines after " + step + " filtering).")
            return
        logging.debug(
            'The file {} contains 0 streamlines after '.format(out_sft_name) +
            step + ' filtering')

    sft = new_sft
    save_tractogram(sft, out_sft_name)
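
Example #15 relies on a compute_outliers(sft, new_sft) helper that is not shown in this listing. The sketch below is a hypothetical stand-in, not scilpy's implementation: it builds a kept/rejected split from the indices a filtering step retains, using dipy's StatefulTractogram.from_sft. The function name split_kept_and_rejected is made up, and per-streamline data are simply dropped.

import numpy as np
from dipy.io.stateful_tractogram import StatefulTractogram


def split_kept_and_rejected(sft, kept_ids):
    """Split a tractogram into the streamlines kept by a filtering step
    (given by kept_ids) and the rejected remainder."""
    all_ids = np.arange(len(sft.streamlines))
    rejected_ids = np.setdiff1d(all_ids, kept_ids)

    # from_sft() reuses the space/reference attributes of the input;
    # data_per_streamline / data_per_point are not carried over in this sketch.
    kept_sft = StatefulTractogram.from_sft(sft.streamlines[kept_ids], sft)
    rejected_sft = StatefulTractogram.from_sft(sft.streamlines[rejected_ids], sft)
    return kept_sft, rejected_sft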
Example #16
#!/usr/bin/env python

import sys

import numpy as np
import nibabel as nib
from scilpy.io.image import get_data_as_label

# Load the label image and collect its unique label values, dropping the
# first entry (assumed to be the 0 background label).
img_labels = nib.load(sys.argv[1])
data_labels = get_data_as_label(img_labels)
real_labels = np.unique(data_labels)[1:]

# Write the labels to a text file, one integer per line.
np.savetxt(sys.argv[2], real_labels, fmt='%i')
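
If this last snippet is saved as a standalone script, say print_unique_labels.py (a hypothetical filename), it can be run as python print_unique_labels.py atlas_labels.nii.gz labels.txt, where the first argument is any label image readable by nibabel and the second is the text file that will receive one label value per line; both filenames are placeholders.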