def main():
    """Export every top-level group of an HDF5 file as its own .trk file.

    The spatial attributes (``affine``, ``dimensions``, ``voxel_sizes``) are
    read from the root attributes of the HDF5 file and shared by all groups.
    Each key is rebuilt with ``reconstruct_streamlines_from_hdf5`` and saved
    as ``<out_dir>/<key>.trk`` in voxel space with the TrackVis origin. With
    ``--include_dps``, every dataset of the group other than ``data``,
    ``offsets`` and ``lengths`` is attached as data_per_streamline.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_hdf5)
    assert_output_dirs_exist_and_empty(parser,
                                       args,
                                       args.out_dir,
                                       create_dir=True)

    # The context manager closes the file even if reconstruction or
    # saving raises part-way through the loop.
    with h5py.File(args.in_hdf5, 'r') as hdf5_file:
        # These attributes live on the file root, so they are the same for
        # every key: read them (and build the header) once.
        affine = hdf5_file.attrs['affine']
        dimensions = hdf5_file.attrs['dimensions']
        voxel_sizes = hdf5_file.attrs['voxel_sizes']
        header = create_nifti_header(affine, dimensions, voxel_sizes)

        for key in hdf5_file.keys():
            streamlines = reconstruct_streamlines_from_hdf5(hdf5_file, key)
            sft = StatefulTractogram(streamlines,
                                     header,
                                     Space.VOX,
                                     origin=Origin.TRACKVIS)
            if args.include_dps:
                for dps_key in hdf5_file[key].keys():
                    # These three datasets encode the streamlines themselves.
                    if dps_key not in ['data', 'offsets', 'lengths']:
                        sft.data_per_streamline[dps_key] = \
                            hdf5_file[key][dps_key]

            save_tractogram(sft,
                            '{}.trk'.format(os.path.join(args.out_dir, key)))
Esempio n. 2
0
def main():
    """Fuse several bundles by voting at the voxel and streamline levels.

    Every bundle is moved to voxel space with a corner origin and must have a
    header compatible with the reference (``--reference`` or, by default, the
    first input bundle). A voxel is kept when at least
    ``int(ratio_voxels * nb_bundles)`` bundles cover it. With
    ``--same_tractogram``, a streamline of the union of all bundles is kept
    when at least ``int(ratio_streamlines * nb_bundles)`` bundles contain it.

    Raises:
        ValueError: if a bundle's header is not compatible with the reference.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_bundles)
    output_streamlines_filename = '{}streamlines.trk'.format(
        args.output_prefix)
    output_voxels_filename = '{}voxels.nii.gz'.format(args.output_prefix)
    assert_outputs_exist(parser, args,
                         [output_voxels_filename, output_streamlines_filename])

    if not 0 <= args.ratio_voxels <= 1 or not 0 <= args.ratio_streamlines <= 1:
        parser.error('Ratios must be between 0 and 1.')

    reference_file = args.reference if args.reference else args.in_bundles[0]
    nb_bundles = len(args.in_bundles)

    sft_list = []
    fusion_streamlines = []
    for name in args.in_bundles:
        tmp_sft = load_tractogram_with_reference(parser, args, name)
        tmp_sft.to_vox()
        tmp_sft.to_corner()

        if not is_header_compatible(reference_file, tmp_sft):
            raise ValueError('Headers are not compatible.')
        sft_list.append(tmp_sft)
        fusion_streamlines.append(tmp_sft.streamlines)

    # Union of all bundles; the per-bundle indices it returns are unused.
    fusion_streamlines, _ = union_robust(fusion_streamlines)

    transformation, dimensions, _, _ = get_reference_info(reference_file)
    volume = np.zeros(dimensions)
    # Sparse (union streamline x bundle) vote matrix.
    streamlines_vote = dok_matrix((len(fusion_streamlines), nb_bundles))

    for i, sft in enumerate(sft_list):
        binary = compute_tract_counts_map(sft.streamlines, dimensions)
        volume[binary > 0] += 1

        if args.same_tractogram:
            _, indices = intersection_robust(
                [fusion_streamlines, sft.streamlines])
            streamlines_vote[list(indices), [i]] += 1

    if args.same_tractogram:
        ratio_value = int(args.ratio_streamlines * nb_bundles)
        real_indices = np.where(
            np.sum(streamlines_vote, axis=1) >= ratio_value)[0]
        new_sft = StatefulTractogram.from_sft(fusion_streamlines[real_indices],
                                              sft_list[0])
        save_tractogram(new_sft, output_streamlines_filename)

    # Binarize: keep voxels reaching the vote threshold.
    volume[volume < int(args.ratio_voxels * nb_bundles)] = 0
    volume[volume > 0] = 1
    nib.save(nib.Nifti1Image(volume.astype(np.uint8), transformation),
             output_voxels_filename)
Esempio n. 3
0
def test_slr_flow():
    """Run SlrWithQbxFlow on the fornix and a translated copy of it.

    Both bundles are written to a temporary directory (the copy shifted by
    50 along x), the flow is run on the pair, and the test checks that the
    ``out_moved`` file it reports was created.
    """
    with TemporaryDirectory() as tmp_dir:
        fornix_file = get_fnames('fornix')
        fornix_sl = load_tractogram(fornix_file, 'same',
                                    bbox_valid_check=False).streamlines

        static_bundle = Streamlines(fornix_sl).copy()
        moving_bundle = static_bundle.copy()
        # Translate the copy so the registration has an offset to recover.
        moving_bundle._data += np.array([50, 0, 0])

        static_path = pjoin(tmp_dir, "f1.trk")
        moving_path = pjoin(tmp_dir, "f2.trk")
        for bundle, path in ((static_bundle, static_path),
                             (moving_bundle, moving_path)):
            bundle_sft = StatefulTractogram(bundle, fornix_file, Space.RASMM)
            save_tractogram(bundle_sft, path, bbox_valid_check=False)

        flow = SlrWithQbxFlow(force=True)
        flow.run(static_path, moving_path)

        moved_path = flow.last_generated_outputs['out_moved']
        npt.assert_equal(os.path.isfile(moved_path), True)
Esempio n. 4
0
def main():
    """Split a bundle into inliers and outliers with ``remove_outliers``.

    ``--alpha`` must lie in ]0, 1]. Inliers are written to ``out_bundle``.
    Outliers are written to ``remaining_bundle`` when one is given. An empty
    input bundle, no inliers, or no outliers each only log a warning for the
    corresponding output.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_bundle)
    assert_outputs_exist(parser, args, args.out_bundle, args.remaining_bundle)
    if args.alpha <= 0 or args.alpha > 1:
        parser.error('--alpha should be ]0, 1]')

    sft = load_tractogram_with_reference(parser, args, args.in_bundle)
    if len(sft) == 0:
        logging.warning("Bundle file contains no streamline")
        return

    check_tracts_same_format(
        parser, [args.in_bundle, args.out_bundle, args.remaining_bundle])
    outliers, inliers = remove_outliers(sft.streamlines, args.alpha)

    inliers_sft = sft[inliers]
    outliers_sfts = sft[outliers]

    if len(inliers) == 0:
        # Trailing space: the two literals are concatenated into one message.
        logging.warning("All streamlines are considered outliers. "
                        "Please lower the --alpha parameter")
    else:
        save_tractogram(inliers_sft, args.out_bundle)

    if len(outliers) == 0:
        logging.warning("No outlier found. Please raise the --alpha parameter")
    elif args.remaining_bundle:
        save_tractogram(outliers_sfts, args.remaining_bundle)
def main():
    """Write one .tck file per subject stored in an HDF5 dataset.

    Subjects come from ``--subject_ids`` or, if not given, from every
    top-level key of the file. Each output is named
    ``[<prefix>_]<dataset stem>_<subject>.tck``.
    """
    args = parse_args()
    dataset_stem = pathlib.Path(args.dataset).stem

    with h5py.File(args.dataset, 'r') as hdf_file:
        subject_ids = args.subject_ids or list(hdf_file.keys())

        for subject_id in subject_ids:
            subject = hdf_file[subject_id]

            # Rebuild the ArraySequence directly from its flat buffers.
            streamlines_vox = nib.streamlines.ArraySequence()
            streamlines_vox._data = np.array(subject['streamlines/data'])
            streamlines_vox._offsets = np.array(
                subject['streamlines/offsets'])
            streamlines_vox._lengths = np.array(
                subject['streamlines/lengths'])

            volume_group = subject['input_volume']
            vox2rasmm_affine = np.array(volume_group.attrs['vox2rasmm'])
            img = nib.Nifti1Image(np.array(subject['input_volume/data']),
                                  vox2rasmm_affine)

            tractogram = StatefulTractogram(streamlines_vox, img,
                                            space=Space.VOX,
                                            shifted_origin=True)

            out_name = "{}_{}.tck".format(dataset_stem, subject_id)
            if args.prefix:
                out_name = "{}_{}".format(args.prefix, out_name)
            save_tractogram(tractogram, out_name)
def main():
    """Keep streamlines with an endpoint accepted by the mask, oriented from it.

    A streamline whose first point passes ``_is_coords_valid`` against the
    mask voxels (values > 0.5) is kept as-is. Otherwise, if its last point
    passes, it is kept reversed. All other streamlines are dropped.
    Comparisons happen in voxel space with a corner origin.

    Raises:
        ValueError: if ``args.output`` exists and ``--force`` is not set.
    """
    args = parse_args()

    if not args.force and os.path.exists(args.output):
        raise ValueError("Output already exists! Use --force to overwrite.")

    sft = load_tractogram(args.input, args.ref,
                          to_space=Space.RASMM,
                          trk_header_check=False,
                          bbox_valid_check=False)

    # The input tractogram may contain invalid streamlines; drop them.
    sft.remove_invalid_streamlines()

    # Voxel space + corner origin lets floor() map points onto voxel indices.
    sft.to_vox()
    sft.to_corner()

    mask_img = nib.load(args.mask)
    valid_voxels = np.where(mask_img.get_fdata() > 0.5)

    oriented = []
    for streamline in sft.streamlines:
        # Try the streamline as-is, then reversed, so that the kept copy
        # always starts at the endpoint that lies in the mask.
        for candidate in (streamline, streamline[::-1]):
            if _is_coords_valid(candidate[0], valid_voxels):
                oriented.append(candidate)
                break

    valid_sft = StatefulTractogram(oriented, args.ref, space=sft.space,
                                   shifted_origin=sft.shifted_origin)
    save_tractogram(valid_sft, args.output)
Esempio n. 7
0
def export_bundles(subses_dict, clean_bundles_file, bundles_file, bundle_dict,
                   tracking_params, segmentation_params):
    """Split the clean and raw bundle tractograms into one .trk per bundle.

    For each of ``clean_bundles_file`` and ``bundles_file``, every bundle of
    ``bundle_dict`` except "whole_brain" is selected by matching its ``uid``
    against the ``bundle`` data_per_streamline entry. The selection is moved
    into the voxel space of the DWI image (``subses_dict['dwi_file']``) and
    saved under ``<results_dir>/clean_bundles`` or ``<results_dir>/bundles``,
    together with a JSON sidecar recording the source file.

    Returns:
        True once all bundles have been written.
    """
    img = nib.load(subses_dict['dwi_file'])
    for this_bundles_file, folder in zip([clean_bundles_file, bundles_file],
                                         ['clean_bundles', 'bundles']):
        bundles_dir = op.join(subses_dict['results_dir'], folder)
        os.makedirs(bundles_dir, exist_ok=True)
        trk = nib.streamlines.load(this_bundles_file)
        tg = trk.tractogram
        streamlines = tg.streamlines
        for bundle in bundle_dict:
            if bundle == "whole_brain":
                continue
            uid = bundle_dict[bundle]['uid']
            idx = np.where(tg.data_per_streamline['bundle'] == uid)[0]
            # Apply the inverse DWI affine so the streamlines can be stored
            # as Space.VOX relative to that image.
            this_sl = dtu.transform_tracking_output(
                streamlines[idx], np.linalg.inv(img.affine))

            this_tgm = StatefulTractogram(this_sl, img, Space.VOX)
            fname = op.split(
                get_fname(subses_dict, f'-{bundle}'
                          f'_tractography.trk',
                          tracking_params=tracking_params,
                          segmentation_params=segmentation_params))
            fname = op.join(bundles_dir, fname[1])
            logger.info(f"Saving {fname}")
            save_tractogram(this_tgm, fname, bbox_valid_check=False)
            meta = dict(source=this_bundles_file)
            # Strip only the final extension. Splitting on the first '.'
            # truncated the path at any dot in a directory name.
            meta_fname = op.splitext(fname)[0] + '.json'
            afd.write_json(meta_fname, meta)
    return True
def _load_directly_and_verify(batch_loader, batch_idx_tuples):
    """Load a batch and save its input-coordinate mask and streamlines.

    The files are written for manual inspection: open the mask and check
    that it fits the streamlines. They are written to ``tmp_dir``, named with
    the ``now_s`` timestamp, and use ``ref`` as the spatial reference. These
    three names are module-level and defined outside this view
    (NOTE(review): confirm).
    """
    # batch_idx_tuples holds (s, idx) pairs; total requested streamlines.
    expected_nb_streamlines = sum(len(idx) for _, idx in batch_idx_tuples)

    batch_streamlines, _ids, inputs_tuple = batch_loader.load_batch(
        batch_idx_tuples, save_batch_input_mask=True)
    batch_input_masks, _batch_inputs = inputs_tuple
    logging.info("Requested {} streamlines; loaded batch contains {}."
                 .format(expected_nb_streamlines, len(batch_streamlines)))

    # Save subject 0's underlying input coordinates as a mask.
    filename = os.path.join(str(tmp_dir), 'test_batch1_underlying_mask_' +
                            now_s + '.nii.gz')
    logging.info("Saving subj 0's underlying coords mask to {}"
                 .format(filename))
    mask = batch_input_masks[0]
    ref_img = nib.load(ref)
    # Nifti1Image expects an affine as its second argument, not an image.
    # Store the binarized mask as uint8 (0/1) for broad tool compatibility.
    mask_data = np.asarray(mask, dtype=bool).astype(np.uint8)
    data_nii = nib.Nifti1Image(mask_data, ref_img.affine)
    nib.save(data_nii, filename)

    # Save the batch's streamlines as a tractogram.
    filename = os.path.join(str(tmp_dir), 'test_batch_reverse_split_' +
                            now_s + '.trk')
    logging.info("Saving subj 0's tractogram {}".format(filename))

    sft = StatefulTractogram(batch_streamlines, reference=ref, space=Space.VOX,
                             origin=Origin.TRACKVIS)
    save_tractogram(sft, filename)
def main():
    """Remove invalid and, optionally, degenerate streamlines.

    Streamlines flagged by ``remove_invalid_streamlines`` are always removed.
    ``--remove_single_point`` also drops streamlines with at most one point.
    ``--remove_overlapping_points`` drops streamlines whose point-wise
    gradient norm falls below 0.001 anywhere (e.g. repeated points). The
    per-point and per-streamline data of the kept streamlines are preserved.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram, args.reference)
    assert_outputs_exist(parser, args, args.out_tractogram)

    sft = load_tractogram_with_reference(parser,
                                         args,
                                         args.in_tractogram,
                                         bbox_check=False)
    ori_len = len(sft)
    sft.remove_invalid_streamlines()

    to_remove = []
    if args.remove_single_point:
        # Will try to do a PR in Dipy
        to_remove.extend(i for i, streamline in enumerate(sft.streamlines)
                         if len(streamline) <= 1)

    if args.remove_overlapping_points:
        for i, streamline in enumerate(sft.streamlines):
            # np.gradient raises on fewer than two samples. Such streamlines
            # have no consecutive points that could overlap anyway.
            if len(streamline) < 2:
                continue
            norm = np.linalg.norm(np.gradient(streamline, axis=0), axis=1)
            if (norm < 0.001).any():
                to_remove.append(i)

    # Duplicates in to_remove are harmless: setdiff1d works on unique values.
    indices = np.setdiff1d(np.arange(len(sft)), to_remove)
    new_sft = StatefulTractogram.from_sft(
        sft.streamlines[indices],
        sft,
        data_per_point=sft.data_per_point[indices],
        data_per_streamline=sft.data_per_streamline[indices])
    logging.warning('Removed {} invalid streamlines.'.format(ori_len -
                                                             len(new_sft)))
    save_tractogram(new_sft, args.out_tractogram)
Esempio n. 10
0
def io_tractogram(extension):
    """Round-trip the module-level ``streamlines`` through ``test.<extension>``.

    The streamlines are saved with a synthetic NIfTI header (identity affine,
    50x50x50 grid, 2 x 1.5 x 1.5 voxels) and then reloaded. The test checks
    the reloaded spatial attributes, the streamline count, and the second
    streamline against ``streamline``.
    """
    expected_affine = np.eye(4)
    expected_dims = np.array([50, 50, 50])
    expected_vox = np.array([2, 1.5, 1.5])

    with InTemporaryDirectory():
        out_name = 'test.{}'.format(extension)
        nii_header = create_nifti_header(expected_affine, expected_dims,
                                         expected_vox)
        to_save = StatefulTractogram(streamlines, nii_header,
                                     space=Space.RASMM)
        save_tractogram(to_save, out_name, bbox_valid_check=False)

        # A TRK file can be its own reference; other formats need the header.
        reference = 'same' if extension == 'trk' else nii_header
        loaded = load_tractogram(out_name, reference, bbox_valid_check=False)
        # NOTE(review): newer DIPY exposes this as `space_attributes`; the
        # singular name is kept for the DIPY version this test targets.
        affine, dimensions, voxel_sizes, _ = loaded.space_attribute

        npt.assert_array_equal(expected_affine, affine)
        npt.assert_array_equal(expected_vox, voxel_sizes)
        npt.assert_array_equal(expected_dims, dimensions)
        npt.assert_equal(len(loaded), len(streamlines))
        npt.assert_array_almost_equal(loaded.streamlines[1], streamline,
                                      decimal=4)
def main():
    """Give every point of a tractogram one RGB color and save it as .trk.

    The color must be "#RRGGBB" or "0xRRGGBB" (hexadecimal). It is stored
    per point under the ``color`` data_per_point key. Only .trk outputs are
    accepted.
    """
    import re  # local import: only needed to validate the color argument

    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram)

    if not args.out_tractogram.endswith('.trk'):
        parser.error('Output file needs to end with .trk.')

    # Match the documented formats exactly instead of checking only the
    # length: e.g. "RRGGBBAA" used to reach int() and crash with an
    # uncaught ValueError.
    match = re.fullmatch(r'(?:#|0[xX])([0-9a-fA-F]{6})', args.color)
    if match is None:
        # parser.error exits, so `match` is always set below.
        parser.error('Hexadecimal RGB color should be formatted as "#RRGGBB"'
                     ' or 0xRRGGBB.')
    color_int = int(match.group(1), 16)
    red = color_int >> 16
    green = (color_int >> 8) & 0xFF
    blue = color_int & 0xFF

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

    sft.data_per_point["color"] = [np.tile([red, green, blue],
                                           (len(i), 1)) for i in sft.streamlines]

    sft = StatefulTractogram.from_sft(sft.streamlines, sft,
                                      data_per_point=sft.data_per_point)

    save_tractogram(sft, args.out_tractogram)
Esempio n. 12
0
    def _export_bundles(self, row):
        """Save each bundle of the clean and raw segmentations as a .trk.

        For the files returned by ``self._clean_bundles(row)`` and
        ``self._segment(row)``, every bundle of ``self.bundle_dict`` except
        "whole_brain" is selected by its ``uid`` in the ``bundle``
        data_per_streamline entry. The selection is moved into the voxel
        space of ``row['dwi_affine']`` and written under
        ``row['results_dir']/clean_bundles`` or ``.../bundles``, with a JSON
        sidecar recording the source file.
        """
        odf_model = self.tracking_params['odf_model']
        directions = self.tracking_params['directions']
        seg_algo = self.segmentation_params['seg_algo']

        for func, folder in zip([self._clean_bundles, self._segment],
                                ['clean_bundles', 'bundles']):
            bundles_file = func(row)

            bundles_dir = op.join(row['results_dir'], folder)
            os.makedirs(bundles_dir, exist_ok=True)
            trk = nib.streamlines.load(bundles_file)
            tg = trk.tractogram
            streamlines = tg.streamlines
            for bundle in self.bundle_dict:
                if bundle == "whole_brain":
                    continue
                uid = self.bundle_dict[bundle]['uid']
                idx = np.where(tg.data_per_streamline['bundle'] == uid)[0]
                # Inverse DWI affine: store the selection as Space.VOX.
                this_sl = dtu.transform_tracking_output(
                    streamlines[idx], np.linalg.inv(row['dwi_affine']))

                this_tgm = StatefulTractogram(this_sl, row['dwi_img'],
                                              Space.VOX)

                fname = op.split(
                    self._get_fname(
                        row, f'_space-RASMM_model-{odf_model}_desc-'
                        f'{directions}-{seg_algo}-{bundle}'
                        f'_tractography.trk'))
                fname = op.join(fname[0], bundles_dir, fname[1])
                save_tractogram(this_tgm, fname, bbox_valid_check=False)
                meta = dict(source=bundles_file)
                # Strip only the final extension. Splitting on the first '.'
                # truncated the path at any dot in a directory name.
                meta_fname = op.splitext(fname)[0] + '.json'
                afd.write_json(meta_fname, meta)
Esempio n. 13
0
def _save_if_needed(sft, hdf5_file, args, save_type, step_type, in_label,
                    out_label):
    """Store ``sft`` in the HDF5 file and/or as a .trk, depending on options.

    On the 'final' step, a group ``<in_label>_<out_label>`` is created
    holding the streamlines as ``data``/``offsets``/``lengths`` datasets plus
    one dataset per data_per_streamline key. Independently, when
    ``args.out_dir`` is set, the saving options enable ``save_type`` and the
    tractogram is non-empty, it is written as
    ``<out_paths[step_type]>/<in_label>_<out_label>.trk``.
    """
    if step_type == 'final':
        group = hdf5_file.create_group('{}_{}'.format(in_label, out_label))
        # Use one compact copy for all three arrays. For a sliced tractogram
        # the ArraySequence buffer is shared with its parent and ``_offsets``
        # index into that larger buffer. ``get_data()`` returns a compacted
        # copy, so the raw offsets did not match the stored data.
        streamlines = sft.streamlines.copy()
        group.create_dataset('data',
                             data=streamlines._data,
                             dtype=np.float32)
        group.create_dataset('offsets',
                             data=streamlines._offsets,
                             dtype=np.int64)
        group.create_dataset('lengths',
                             data=streamlines._lengths,
                             dtype=np.int32)
        for key in sft.data_per_streamline.keys():
            group.create_dataset(key,
                                 data=sft.data_per_streamline[key],
                                 dtype=np.float32)

    if args.out_dir:
        saving_options = _get_saving_options(args)
        out_paths = _get_output_paths(args)

        if saving_options[save_type] and len(sft):
            out_name = os.path.join(out_paths[step_type],
                                    '{}_{}.trk'.format(in_label, out_label))
            save_tractogram(sft, out_name)
def main():
    """Filter a tractogram by streamline length and save the result.

    Filtering is delegated to ``filter_streamlines_by_length`` with
    ``--minL``/``--maxL``, and the matching per-point and per-streamline data
    are kept. With ``--no_empty``, nothing is written when no streamline
    survives.
    """

    parser = _build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram)

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

    new_streamlines, new_per_point, new_per_streamline = \
        filter_streamlines_by_length(sft, args.minL, args.maxL)

    # len() works for lists, ArraySequence and numpy arrays alike, whereas
    # truth-testing a multi-element numpy array raises.
    if len(new_streamlines) == 0:
        if args.no_empty:
            logging.debug("The file {} won't be written "
                          "(0 streamline).".format(args.out_tractogram))

            return

        logging.debug('The file {} contains 0 streamline'.format(
            args.out_tractogram))

    # Built only when it will actually be saved.
    # NOTE(review): assumes the loaded sft (and thus the filtered
    # streamlines) is in RASMM; confirm load_tractogram_with_reference's
    # default space.
    new_sft = StatefulTractogram(new_streamlines,
                                 sft,
                                 Space.RASMM,
                                 data_per_streamline=new_per_streamline,
                                 data_per_point=new_per_point)

    save_tractogram(new_sft, args.out_tractogram)
Esempio n. 15
0
def test_io_streamline():
    """Exercise the legacy ``save_tractogram``/``load_tractogram`` API on TRK.

    Saves the module-level ``streamlines`` twice to ``test.trk``: once with
    explicit voxel sizes and shape, once with only the affine. Each save is
    checked by reading the file back with nibabel. Finally the file is
    loaded back through ``load_tractogram`` and compared point-wise.
    """
    with InTemporaryDirectory():
        trk_name = 'test.trk'
        identity = np.eye(4)
        voxel_sizes = np.array([2, 1.5, 1.5])
        shape = np.array([50, 50, 50])

        # Save with explicit geometry; the TRK header must reflect it.
        save_tractogram(trk_name, streamlines, identity,
                        vox_size=voxel_sizes, shape=shape)
        trk = nib.streamlines.load(trk_name)
        npt.assert_array_equal(identity, trk.affine)
        npt.assert_array_equal(voxel_sizes, trk.header.get('voxel_sizes'))
        npt.assert_array_equal(shape, trk.header.get('dimensions'))
        npt.assert_equal(len(trk.streamlines), len(streamlines))
        npt.assert_array_almost_equal(trk.streamlines[1], streamline,
                                      decimal=4)

        # Save with the affine only (no voxel sizes or shape given).
        save_tractogram(trk_name, streamlines, identity)
        trk = nib.streamlines.load(trk_name)
        npt.assert_array_equal(identity, trk.affine)
        npt.assert_equal(len(trk.streamlines), len(streamlines))
        npt.assert_array_almost_equal(trk.streamlines[1], streamline,
                                      decimal=5)

        # Load back through the legacy loader.
        loaded, _header = load_tractogram(trk_name)
        npt.assert_equal(len(loaded), len(streamlines))
        for got, expected in zip(loaded, streamlines):
            npt.assert_allclose(got, expected)
def main():
    """Extract a bundle from a whole-brain tractogram with RecoBundles.

    The model bundle is moved into the tractogram's space with
    ``--in_transfo`` (inverted with ``--inverse``). The whole brain is
    clustered, or a cluster map is reused from ``--in_pickle`` and optionally
    saved to ``--out_pickle``, and the recognized streamlines are written
    together with their per-point and per-streamline data. With
    ``--no_empty``, nothing is written when no streamline is recognized.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # The model bundle and an optional cluster-map pickle are read below,
    # so validate them with the other inputs.
    assert_inputs_exist(parser,
                        [args.in_tractogram, args.in_transfo, args.in_model],
                        optional=args.in_pickle)
    assert_outputs_exist(parser, args, args.out_tractogram,
                         optional=args.out_pickle)

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    wb_file = load_tractogram_with_reference(parser, args, args.in_tractogram)
    wb_streamlines = wb_file.streamlines
    model_file = load_tractogram_with_reference(parser, args, args.in_model)

    transfo = load_matrix_in_any_format(args.in_transfo)
    if args.inverse:
        # Invert the matrix already loaded instead of reading the file again.
        transfo = np.linalg.inv(transfo)

    before, after = compute_distance_barycenters(wb_file, model_file, transfo)
    if after > before:
        logging.warning('The distance between volumes barycenter should be '
                        'lower after registration. Maybe try using/removing '
                        '--inverse.')
        logging.info('Distance before: {}, Distance after: {}'.format(
            np.round(before, 3), np.round(after, 3)))
    model_streamlines = transform_streamlines(model_file.streamlines, transfo)

    rng = np.random.RandomState(args.seed)
    if args.in_pickle:
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # cluster maps from trusted sources.
        with open(args.in_pickle, 'rb') as infile:
            cluster_map = pickle.load(infile)
        reco_obj = RecoBundles(wb_streamlines,
                               cluster_map=cluster_map,
                               rng=rng,
                               verbose=args.verbose)
    else:
        reco_obj = RecoBundles(wb_streamlines,
                               clust_thr=args.tractogram_clustering_thr,
                               rng=rng,
                               verbose=args.verbose)

    if args.out_pickle:
        with open(args.out_pickle, 'wb') as outfile:
            pickle.dump(reco_obj.cluster_map, outfile)
    _, indices = reco_obj.recognize(ArraySequence(model_streamlines),
                                    args.model_clustering_thr,
                                    pruning_thr=args.pruning_thr,
                                    slr_num_threads=args.slr_threads)
    new_streamlines = wb_streamlines[indices]
    new_data_per_streamlines = wb_file.data_per_streamline[indices]
    new_data_per_points = wb_file.data_per_point[indices]

    if not args.no_empty or new_streamlines:
        sft = StatefulTractogram(new_streamlines,
                                 wb_file.space_attributes,
                                 Space.RASMM,
                                 data_per_streamline=new_data_per_streamlines,
                                 data_per_point=new_data_per_points)
        save_tractogram(sft, args.out_tractogram)
def main():
    """Apply a linear transform and an optional deformation to a tractogram.

    The warp itself is done by ``transform_warp_streamlines`` with the
    ``--inverse``, ``--remove_invalid`` and ``--cut_invalid`` options. With
    ``--keep_invalid``, the result is saved without the bounding-box check,
    and a warning is logged if it contains invalid streamlines.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    required_inputs = [args.in_moving_tractogram, args.in_target_file,
                       args.in_transfo]
    assert_inputs_exist(parser, required_inputs, args.in_deformation)
    assert_outputs_exist(parser, args, args.out_tractogram)

    moving_sft = load_tractogram_with_reference(parser, args,
                                                args.in_moving_tractogram,
                                                bbox_check=False)

    transfo = load_matrix_in_any_format(args.in_transfo)
    deformation_data = (
        None if args.in_deformation is None
        else np.squeeze(
            nib.load(args.in_deformation).get_fdata(dtype=np.float32)))

    new_sft = transform_warp_streamlines(moving_sft, transfo,
                                         args.in_target_file,
                                         inverse=args.inverse,
                                         deformation_data=deformation_data,
                                         remove_invalid=args.remove_invalid,
                                         cut_invalid=args.cut_invalid)

    if not args.keep_invalid:
        save_tractogram(new_sft, args.out_tractogram)
        return

    if not new_sft.is_bbox_in_vox_valid():
        logging.warning('Saving tractogram with invalid streamlines.')
    save_tractogram(new_sft, args.out_tractogram, bbox_valid_check=False)
Esempio n. 18
0
def main():
    """Smooth every streamline of a tractogram, optionally compressing it.

    Each streamline is smoothed with a Gaussian (``--gaussian``) or a spline
    (``--spline``). When ``--error_rate`` is set, it is then compressed with
    ``compress_streamlines``. data_per_streamline is kept; data_per_point is
    not carried over.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.DEBUG
    logging.basicConfig(level=log_level)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    smoothed_streamlines = []
    for streamline in sft.streamlines:
        if args.gaussian:
            tmp_streamlines = smooth_line_gaussian(streamline, args.gaussian)
        else:
            tmp_streamlines = smooth_line_spline(streamline, args.spline[0],
                                                 args.spline[1])

        if args.error_rate:
            tmp_streamlines = compress_streamlines(tmp_streamlines,
                                                   args.error_rate)
        # Always keep the smoothed streamline. Previously it was dropped when
        # no error rate was given, so the output was empty and out of sync
        # with data_per_streamline.
        smoothed_streamlines.append(tmp_streamlines)

    smoothed_sft = StatefulTractogram.from_sft(
        smoothed_streamlines, sft, data_per_streamline=sft.data_per_streamline)
    save_tractogram(smoothed_sft, args.out_tractogram)
Esempio n. 19
0
def convert_trk(tractogram, outtype='tck', output=None, force=False):
    ''' Convert a TRK file to either a TCK file or a VTK file.

    Parameters
    ----------
    tractogram : str
        Path of the input file. Files nibabel does not detect as TRK are
        skipped with a message.
    outtype : str
        'tck' saves with ``save_tractogram``. Any other value writes VTK via
        ``dipy.io.vtk.save_vtk_streamlines``.
    output : str, optional
        Output path. When omitted, the input's extension is replaced by
        ``outtype``, and an existing file at that path is only overwritten
        when ``force`` is True.
    force : bool
        Allow overwriting the default output path.
    '''
    if nib.streamlines.detect_format(tractogram) is not nib.streamlines.TrkFile:
        print("Skipping non TRK file: '{}'".format(tractogram))
        return

    if output:
        output_filename = output
    else:
        # splitext removes whatever extension is present instead of
        # assuming it is exactly four characters long.
        output_filename = (os.path.splitext(tractogram)[0] +
                           '.{}'.format(outtype))
        if os.path.isfile(output_filename) and not force:
            print("Skipping existing file: '{}'. Use -f to overwrite.".format(output_filename))
            return

    print("Converting file: {}\n".format(output_filename))
    # load tractogram, set origin to the corner
    trk = load_tractogram(tractogram, reference='same')
    trk.to_corner()  # set origin to the corner
    if outtype == 'tck':
        save_tractogram(trk, output_filename)
    else:
        # save_vtk_streamlines expects the streamlines, not a
        # StatefulTractogram.
        dipy.io.vtk.save_vtk_streamlines(trk.streamlines, output_filename)
Esempio n. 20
0
        def wrapper_as_file(*args, **kwargs):
            # Resolve the output path from the wrapped call's positional
            # arguments. If it already exists, skip the computation.
            # Otherwise run ``func``, save its result (NIfTI image,
            # StatefulTractogram, or anything with ``to_csv``) and write
            # its metadata as a JSON sidecar. Returns the output path.
            subses_dict = get_args(func, ["subses_dict"], args)[0]
            tracking_params = (
                get_args(func, ["tracking_params"], args)[0]
                if include_track else None)
            segmentation_params = (
                get_args(func, ["segmentation_params"], args)[0]
                if include_seg else None)
            this_file = get_fname(
                subses_dict, suffix,
                tracking_params=tracking_params,
                segmentation_params=segmentation_params)
            if op.exists(this_file):
                return this_file

            img_trk_or_csv, meta = func(*args, **kwargs)

            logger.info(f"Saving {this_file}")
            if isinstance(img_trk_or_csv, nib.Nifti1Image):
                nib.save(img_trk_or_csv, this_file)
            elif isinstance(img_trk_or_csv, StatefulTractogram):
                save_tractogram(
                    img_trk_or_csv, this_file, bbox_valid_check=False)
            else:
                img_trk_or_csv.to_csv(this_file)

            # Splitting on the first '.' also drops multi-part extensions
            # such as '.nii.gz' from the suffix.
            json_suffix = suffix.split('.')[0] + '.json'
            meta_fname = get_fname(
                subses_dict, json_suffix,
                tracking_params=tracking_params,
                segmentation_params=segmentation_params)
            afd.write_json(meta_fname, meta)
            return this_file
def random_streamline_color():
    """Check whether per-point random colors can be saved to a TRK file.

    Seeds NumPy's global RNG with 0 and draws one color value per axis for
    each of 13 streamlines, repeated over 8 points. The values are stored as
    ``color_x``/``color_y``/``color_z`` data_per_point and the tractogram is
    saved in a temporary directory. Returns True on success, or False if
    setting the data or saving raises TypeError/ValueError.
    """
    np.random.seed(0)
    sft = load_tractogram(filepath_dix['gs.tck'], filepath_dix['gs.nii'])

    coloring_dict = {}
    # Draw order (x, then y, then z) matches the seeded RNG sequence.
    for key in ('color_x', 'color_y', 'color_z'):
        per_streamline = np.random.randint(0, 255, (13, 1))
        per_point = np.repeat(per_streamline, 8, axis=1)
        coloring_dict[key] = np.expand_dims(per_point, axis=-1)

    try:
        sft.data_per_point = coloring_dict
        with InTemporaryDirectory():
            save_tractogram(sft, 'random_streamlines_color.trk')
        return True
    except (TypeError, ValueError):
        return False
Esempio n. 22
0
def main():
    """Cluster a tractogram with QuickBundlesX and save clusters/centroids.

    Clustering uses thresholds [40, 30, 20, --dist_thresh] via
    ``qbx_and_merge``. Each cluster is written as
    ``cluster_<i>.trk`` when ``--output_clusters_dir`` is given, and the
    centroids are written to ``--output_centroids`` when given.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, [], optional=args.output_centroids)
    if args.output_clusters_dir:
        assert_output_dirs_exist_and_empty(parser,
                                           args,
                                           args.output_clusters_dir,
                                           create_dir=True)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    streamlines = sft.streamlines
    thresholds = [40, 30, 20, args.dist_thresh]
    clusters = qbx_and_merge(streamlines,
                             thresholds,
                             nb_pts=args.nb_points,
                             verbose=False)

    # The clusters directory is optional. Without this guard the loop called
    # os.path.join(None, ...) and crashed.
    if args.output_clusters_dir:
        for i, cluster in enumerate(clusters):
            if len(cluster.indices) > 1:
                cluster_streamlines = itemgetter(*cluster.indices)(streamlines)
            else:
                # itemgetter with a single index returns the bare item rather
                # than a sequence, so index the ArraySequence instead.
                cluster_streamlines = streamlines[cluster.indices]

            new_sft = StatefulTractogram(cluster_streamlines, sft, Space.RASMM)
            save_tractogram(
                new_sft,
                os.path.join(args.output_clusters_dir,
                             'cluster_{}.trk'.format(i)))

    if args.output_centroids:
        new_sft = StatefulTractogram(clusters.centroids, sft, Space.RASMM)
        save_tractogram(new_sft, args.output_centroids)
Esempio n. 23
0
def save_fg(fiber_groups, img, output_dir, bname):
    """Write each fiber group to ``<output_dir>/<bname>_<key>_reco.trk``.

    For every key of ``fiber_groups``, prints the key and its streamline
    count. The group's ``"sl"`` streamlines are then saved as RASMM relative
    to ``img``, without the bounding-box check.
    """
    for group_name in fiber_groups:
        group_streamlines = fiber_groups[group_name]["sl"].streamlines
        print(group_name, len(group_streamlines))
        group_sft = StatefulTractogram(group_streamlines, img, Space.RASMM)
        out_path = op.join(output_dir, f"{bname}_{group_name}_reco.trk")
        save_tractogram(group_sft, out_path, bbox_valid_check=False)
Esempio n. 24
0
def test_recobundles_flow():
    """Recognize a shifted fornix sub-bundle with RecoBundlesFlow.

    The first 15 fornix streamlines, shifted by 40 along x, form the model
    and are appended to the whole fornix. The test checks that RecoBundles
    recognizes exactly as many streamlines as the model. It also checks that
    LabelsBundlesFlow, applied with the recognized labels, yields a bundle
    whose bundle-min distance to the model is below 1.
    """
    with TemporaryDirectory() as out_dir:
        data_path = get_fnames('fornix')

        fornix = load_tractogram(data_path, 'same',
                                 bbox_valid_check=False).streamlines

        f = Streamlines(fornix)
        f1 = f.copy()

        f2 = f1[:15].copy()
        f2._data += np.array([40, 0, 0])

        f.extend(f2)

        f2_path = pjoin(out_dir, "f2.trk")
        sft = StatefulTractogram(f2, data_path, Space.RASMM)
        save_tractogram(sft, f2_path, bbox_valid_check=False)

        f1_path = pjoin(out_dir, "f1.trk")
        sft = StatefulTractogram(f, data_path, Space.RASMM)
        save_tractogram(sft, f1_path, bbox_valid_check=False)

        rb_flow = RecoBundlesFlow(force=True)
        rb_flow.run(f1_path,
                    f2_path,
                    greater_than=0,
                    clust_thr=10,
                    model_clust_thr=5.,
                    reduction_thr=10,
                    out_dir=out_dir)

        labels = rb_flow.last_generated_outputs['out_recognized_labels']
        recog_trk = rb_flow.last_generated_outputs['out_recognized_transf']

        rec_bundle = load_tractogram(recog_trk, 'same',
                                     bbox_valid_check=False).streamlines
        # Compare the values directly so a failure reports both counts.
        npt.assert_equal(len(rec_bundle), len(f2))

        label_flow = LabelsBundlesFlow(force=True)
        label_flow.run(f1_path, labels)

        recog_bundle = label_flow.last_generated_outputs['out_bundle']
        rec_bundle_org = load_tractogram(recog_bundle,
                                         'same',
                                         bbox_valid_check=False).streamlines

        BMD = BundleMinDistanceMetric()
        nb_pts = 20
        static = set_number_of_points(f2, nb_pts)
        moving = set_number_of_points(rec_bundle_org, nb_pts)

        BMD.setup(static, moving)
        x0 = np.array([0, 0, 0, 0, 0, 0, 1., 1., 1, 0, 0, 0])  # affine
        bmd_value = BMD.distance(x0.tolist())

        # Reports the actual distance on failure.
        npt.assert_array_less(bmd_value, 1)
def main():
    """Split a tractogram into clean streamlines and loops / sharp turns.

    Indices kept by ``remove_loops_and_sharp_turns`` (``ids_c``) are saved to
    ``args.out_tractogram``; all other indices (``ids_l``) are optionally
    saved to ``args.looping_tractogram``. With ``--display_counts`` the
    streamline counts before and after filtering are printed as JSON.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram,
                         optional=args.looping_tractogram)
    check_tracts_same_format(parser, [args.in_tractogram, args.out_tractogram,
                                      args.looping_tractogram])

    if args.threshold <= 0:
        parser.error('Threshold "{}" '.format(args.threshold) +
                     'must be greater than 0')

    if args.angle <= 0:
        parser.error('Angle "{}" '.format(args.angle) +
                     'must be greater than 0')

    tractogram = load_tractogram_with_reference(
        parser, args, args.in_tractogram)

    streamlines = tractogram.streamlines

    # parser.error() exits, so everything below sees at least 2 streamlines.
    if len(streamlines) <= 1:
        parser.error(
            'Zero or one streamline in {}'.format(args.in_tractogram) +
            '. The file must have more than one streamline.')

    ids_c = remove_loops_and_sharp_turns(
        streamlines, args.angle, use_qb=args.qb,
        qb_threshold=args.threshold)
    ids_l = np.setdiff1d(np.arange(len(streamlines)), ids_c)

    if len(ids_c) > 0:
        sft_c = filter_tractogram_data(tractogram, ids_c)
        save_tractogram(sft_c, args.out_tractogram)
    else:
        logging.warning(
            'No clean streamlines in {}'.format(args.in_tractogram))

    if args.display_counts:
        sc_bf = len(tractogram.streamlines)
        # Count from ids_c rather than sft_c: sft_c is only created when at
        # least one streamline survived, so using it here raised NameError.
        sc_af = len(ids_c)
        print(json.dumps({'streamline_count_before_filtering': int(sc_bf),
                         'streamline_count_after_filtering': int(sc_af)},
                         indent=args.indent))

    if len(ids_l) == 0:
        logging.warning('No loops in {}'.format(args.in_tractogram))
    elif args.looping_tractogram:
        sft_l = filter_tractogram_data(tractogram, ids_l)
        save_tractogram(sft_l, args.looping_tractogram)
Esempio n. 26
0
def main():
    """Export selected connection groups of an HDF5 file as ``.trk`` files.

    Every group key (or, with ``--save_empty``, every label pair built from
    ``--labels_list``) is optionally filtered by ``--edge_keys`` or
    ``--node_keys`` and written to ``<out_dir>/<key>.trk`` in voxel space
    with a TRACKVIS origin. Keys with no streamlines are skipped unless
    ``--save_empty`` is given. With ``--include_dps`` every dataset of the
    group other than ``data``/``offsets``/``lengths`` is attached as
    data_per_streamline.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_hdf5)
    assert_output_dirs_exist_and_empty(parser,
                                       args,
                                       args.out_dir,
                                       create_dir=True)
    if args.save_empty and args.labels_list is None:
        parser.error("The option --save_empty requires --labels_list.")

    with h5py.File(args.in_hdf5, 'r') as hdf5_file:
        if args.save_empty:
            # Every unordered label pair plus self-connections, so that
            # connections absent from the file can still be written empty.
            # ndmin=1 keeps a single-label list iterable.
            all_labels = np.loadtxt(args.labels_list, dtype='str', ndmin=1)
            comb_list = list(itertools.combinations(all_labels, r=2))
            comb_list.extend(zip(all_labels, all_labels))
            keys = [i[0] + '_' + i[1] for i in comb_list]
        else:
            keys = list(hdf5_file.keys())

        if args.edge_keys is not None:
            selected_keys = [key for key in keys if key in args.edge_keys]
        elif args.node_keys is not None:
            selected_keys = []
            for node in args.node_keys:
                selected_keys.extend([
                    key for key in keys
                    if key.startswith(node + '_') or key.endswith('_' + node)
                ])
            # A key joining two requested nodes matches twice; keep the first
            # occurrence so each file is written only once.
            selected_keys = list(dict.fromkeys(selected_keys))
        else:
            selected_keys = keys

        affine = hdf5_file.attrs['affine']
        dimensions = hdf5_file.attrs['dimensions']
        voxel_sizes = hdf5_file.attrs['voxel_sizes']
        header = create_nifti_header(affine, dimensions, voxel_sizes)
        for key in selected_keys:
            streamlines = reconstruct_streamlines_from_hdf5(hdf5_file, key)

            if len(streamlines) == 0 and not args.save_empty:
                continue

            sft = StatefulTractogram(streamlines,
                                     header,
                                     Space.VOX,
                                     origin=Origin.TRACKVIS)
            # With --save_empty, keys come from the labels list and may not
            # exist in the file; indexing such a key would raise KeyError.
            if args.include_dps and key in hdf5_file:
                for dps_key in hdf5_file[key].keys():
                    if dps_key not in ['data', 'offsets', 'lengths']:
                        sft.data_per_streamline[dps_key] = hdf5_file[key][
                            dps_key]

            save_tractogram(sft,
                            '{}.trk'.format(os.path.join(args.out_dir, key)))
Esempio n. 27
0
def main():
    """Extract U-shaped streamlines from a tractogram.

    Indices returned by ``detect_ushape(sft, minU, maxU)`` are saved to
    ``args.out_tractogram``; all other streamlines are optionally saved to
    ``args.remaining_tractogram``. With ``--no_empty``, an output that would
    contain zero streamlines is not written. The other output and the
    ``--display_counts`` report are still produced.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser,
                         args,
                         args.out_tractogram,
                         optional=args.remaining_tractogram)
    check_tracts_same_format(
        parser,
        [args.in_tractogram, args.out_tractogram, args.remaining_tractogram])

    if not (-1 <= args.minU <= 1 and -1 <= args.maxU <= 1):
        parser.error('Min-Max ufactor "{},{}" '.format(args.minU, args.maxU) +
                     'must be between -1 and 1.')

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

    ids_c = detect_ushape(sft, args.minU, args.maxU)
    ids_l = np.setdiff1d(np.arange(len(sft.streamlines)), ids_c)

    # Do not return early: with no u-shapes, the remaining tractogram holds
    # every input streamline and --display_counts must still be reported.
    if len(ids_c) == 0 and args.no_empty:
        logging.debug("The file {} won't be written "
                      "(0 streamline).".format(args.out_tractogram))
    else:
        if len(ids_c) == 0:
            logging.debug('The file {} contains 0 streamline.'.format(
                args.out_tractogram))
        save_tractogram(sft[ids_c], args.out_tractogram)

    if args.display_counts:
        sc_bf = len(sft.streamlines)
        sc_af = len(ids_c)
        print(
            json.dumps(
                {
                    'streamline_count_before_filtering': int(sc_bf),
                    'streamline_count_after_filtering': int(sc_af)
                },
                indent=args.indent))

    if args.remaining_tractogram:
        if len(ids_l) == 0:
            if args.no_empty:
                logging.debug("The file {} won't be written (0 streamline"
                              ").".format(args.remaining_tractogram))
                return

            logging.warning('No remaining streamlines.')

        save_tractogram(sft[ids_l], args.remaining_tractogram)
def save_tracts_from_voxel_space(tract_fname,
                                 ref_anat_fname,
                                 tracts,
                                 data_per_streamline=None,
                                 data_per_point=None):
    """Write streamlines expressed in voxel coordinates to ``tract_fname``.

    Parameters
    ----------
    tract_fname : str
        Destination tractogram path.
    ref_anat_fname : str
        Reference used to build the StatefulTractogram.
    tracts : sequence of streamlines
        Streamlines in voxel space (``Space.VOX``).
    data_per_streamline, data_per_point : dict, optional
        Extra data attached to the tractogram before saving.

    The bounding-box validity check is disabled when saving.
    """
    tractogram = StatefulTractogram(
        tracts, ref_anat_fname, Space.VOX,
        data_per_point=data_per_point,
        data_per_streamline=data_per_streamline)
    save_tractogram(tractogram, tract_fname, bbox_valid_check=False)
Esempio n. 29
0
def main():
    """Extract U-shaped streamlines from a tractogram (multi-streamline input).

    Indices returned by ``detect_ushape(sft, minU, maxU)`` are saved to
    ``args.out_tractogram``; the others are optionally saved to
    ``args.remaining_tractogram``. Inputs with zero or one streamline are
    rejected through ``parser.error``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser,
                         args,
                         args.out_tractogram,
                         optional=args.remaining_tractogram)
    check_tracts_same_format(
        parser,
        [args.in_tractogram, args.out_tractogram, args.remaining_tractogram])

    if not (-1 <= args.minU <= 1 and -1 <= args.maxU <= 1):
        parser.error('Min-Max ufactor "{},{}" '.format(args.minU, args.maxU) +
                     'must be between -1 and 1.')

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

    # parser.error() exits, so everything below sees at least 2 streamlines.
    if len(sft.streamlines) <= 1:
        parser.error(
            'Zero or one streamline in {}'.format(args.in_tractogram) +
            '. The file must have more than one streamline.')

    ids_c = detect_ushape(sft, args.minU, args.maxU)
    ids_l = np.setdiff1d(np.arange(len(sft.streamlines)), ids_c)

    if len(ids_c) > 0:
        save_tractogram(sft[ids_c], args.out_tractogram)
    else:
        logging.warning('No u-shape streamlines in {}'.format(
            args.in_tractogram))

    if args.display_counts:
        sc_bf = len(sft.streamlines)
        sc_af = len(ids_c)
        print(
            json.dumps(
                {
                    'streamline_count_before_filtering': int(sc_bf),
                    'streamline_count_after_filtering': int(sc_af)
                },
                indent=args.indent))

    if len(ids_l) == 0:
        # Name the input file: the remaining-output path is None when that
        # output was not requested, which made the message read "in None".
        logging.warning('No remaining streamlines '
                        'in {}'.format(args.in_tractogram))
    elif args.remaining_tractogram:
        save_tractogram(sft[ids_l], args.remaining_tractogram)
Esempio n. 30
0
def main():
    """Flip a tractogram along the requested axes and save the result."""
    parser = _build_arg_parser()
    args = parser.parse_args()

    in_path = args.in_tractogram
    out_path = args.out_tractogram
    assert_inputs_exist(parser, in_path)
    assert_outputs_exist(parser, args, out_path)

    sft = load_tractogram_with_reference(parser, args, in_path)
    # Move to voxel space with a corner origin before flipping.
    sft.to_vox()
    sft.to_corner()

    flipped = flip_sft(sft, args.axes)
    save_tractogram(flipped, out_path)
Esempio n. 31
0
def save_clusters(cluster_lists, indexes_list, directory, basenames_list):
    """Concatenate the streamlines of selected clusters, optionally saving each.

    Parameters
    ----------
    cluster_lists : sequence
        Objects exposing a ``streamlines`` attribute; ``cluster_lists[0]`` is
        also used as the reference of every saved tractogram.
    indexes_list : iterable of int
        Indices of the clusters to keep, in output order.
    directory : str or None
        When truthy, each selected cluster is saved (RASMM space, no bbox
        check) to ``directory/basenames_list[idx]``.
    basenames_list : sequence of str
        File names indexed like ``cluster_lists``.

    Returns
    -------
    list
        All streamlines of the selected clusters, in order.
    """
    merged = []
    for index in indexes_list:
        cluster_streamlines = cluster_lists[index].streamlines
        merged.extend(cluster_streamlines)

        if not directory:
            continue

        cluster_sft = StatefulTractogram(cluster_streamlines,
                                         cluster_lists[0],
                                         Space.RASMM)
        out_path = os.path.join(directory, basenames_list[index])
        save_tractogram(cluster_sft, out_path, bbox_valid_check=False)

    return merged
Esempio n. 32
0
def test_horizon_flow():
    """Smoke-test ``horizon`` and ``HorizonFlow`` on three tiny streamlines.

    Runs ``horizon`` non-interactively first without images
    (``world_coords=False``), then with a random volume
    (``world_coords=True``). It then writes the data to a temporary
    .nii.gz/.trk pair and checks that ``HorizonFlow`` in stealth mode
    produces ``tmp_x.png`` in the output directory.
    """

    # s1 and s3 run along x (s3 offset by 2 in y); s2 runs along y.
    # Each has 5 points spaced 10 units apart.
    s1 = 10 * np.array([[0, 0, 0],
                        [1, 0, 0],
                        [2, 0, 0],
                        [3, 0, 0],
                        [4, 0, 0]], dtype='f8')

    s2 = 10 * np.array([[0, 0, 0],
                        [0, 1, 0],
                        [0, 2, 0],
                        [0, 3, 0],
                        [0, 4, 0]], dtype='f8')

    s3 = 10 * np.array([[0, 0, 0],
                        [1, 0.2, 0],
                        [2, 0.2, 0],
                        [3, 0.2, 0],
                        [4, 0.2, 0]], dtype='f8')

    # Debug output only; nothing is asserted on these shapes.
    print(s1.shape)
    print(s2.shape)
    print(s3.shape)

    streamlines = Streamlines()
    streamlines.append(s1)
    streamlines.append(s2)
    streamlines.append(s3)

    tractograms = [streamlines]
    images = None

    # First run: streamlines only, clustering enabled, non-interactive.
    horizon(tractograms, images=images, cluster=True, cluster_thr=5,
            random_colors=False, length_lt=np.inf, length_gt=0,
            clusters_lt=np.inf, clusters_gt=0,
            world_coords=False, interactive=False)
    # Affine that scales the first axis by 2 (diagonal, no translation).
    affine = np.diag([2., 1, 1, 1]).astype('f8')
    # Random intensities in [0, 255) on a 150^3 grid.
    data = 255 * np.random.rand(150, 150, 150)
    # Images are passed as a list of (data, affine) tuples.
    images = [(data, affine)]

    # Second run: same streamlines plus the image, in world coordinates.
    horizon(tractograms, images=images, cluster=True, cluster_thr=5,
            random_colors=False, length_lt=np.inf, length_gt=0,
            clusters_lt=np.inf, clusters_gt=0,
            world_coords=True, interactive=False)

    with TemporaryDirectory() as out_dir:

        fimg = os.path.join(out_dir, 'test.nii.gz')
        ftrk = os.path.join(out_dir, 'test.trk')

        save_nifti(fimg, data, affine)
        # NOTE(review): this passes (filename, streamlines, affine), unlike
        # the save_tractogram(sft, filename, ...) calls elsewhere in this
        # file. Confirm which dipy API version this test targets.
        save_tractogram(ftrk, streamlines, affine)

        input_files = [ftrk, fimg]

        npt.assert_equal(len(input_files), 2)

        hz_flow = HorizonFlow()

        # Stealth mode renders a screenshot to out_stealth_png in out_dir.
        hz_flow.run(input_files=input_files, stealth=True,
                    out_dir=out_dir, out_stealth_png='tmp_x.png')

        npt.assert_equal(os.path.exists(os.path.join(out_dir, 'tmp_x.png')),
                         True)