예제 #1
0
def flip_streamlines(tract_filename, ref_anat, out_filename, flip_x, flip_y,
                     flip_z, flip_mode):
    """Flip every streamline of a tract file and save the result.

    Each streamline is shifted by the vector returned by
    ``get_shift_vector``, multiplied by the vector returned by
    ``get_axis_flip_vector``, shifted back, and finally written to
    ``out_filename`` using the header of the input file.

    Parameters
    ----------
    tract_filename : str
        Path of the tract file to read.
    ref_anat : str
        Reference anatomy; passed as ``anatFile`` to the reader and the
        writer, and to ``get_shift_vector``.
    out_filename : str
        Path of the tract file to create.
    flip_x, flip_y, flip_z
        Forwarded to ``get_axis_flip_vector``.
    flip_mode
        Forwarded to ``get_shift_vector``.
    """
    # The reader class is picked from the input file itself.
    in_format = tc.detect_format(tract_filename)
    in_file = in_format(tract_filename, anatFile=ref_anat)
    tracts = np.array(list(in_file))

    flip_vector = get_axis_flip_vector(flip_x, flip_y, flip_z)
    shift_vector = get_shift_vector(flip_mode, ref_anat, tracts)

    def _flip(points):
        # Shift, flip in place, then undo the shift.
        moved = points + shift_vector
        moved *= flip_vector
        moved -= shift_vector
        return moved

    flipped_tracts = [_flip(tract) for tract in tracts]

    header = in_file.hdr
    writer_format = tc.detect_format(out_filename)
    out_tracts = writer_format.create(out_filename, header, anatFile=ref_anat)
    out_tracts += flipped_tracts
    out_tracts.close()
예제 #2
0
def main():
    """Merge several tract files into a single output file.

    Command line entry point: arguments come from ``buildArgsParser`` and
    every validation failure is reported through ``parser.error``. Once
    the arguments are validated, the inputs are opened lazily with their
    detected formats and passed with the newly created output to
    ``tractconverter.merge``.

    The output file is refused when it is also one of the input files,
    because the output is created before the inputs are read.
    """
    parser = buildArgsParser()
    args = parser.parse_args()

    in_filenames = args.input
    out_filename = args.output
    isForcing = args.isForce
    isVerbose = args.isVerbose

    if isVerbose:
        logging.basicConfig(level=logging.DEBUG)

    for in_filename in in_filenames:
        if not os.path.isfile(in_filename):
            parser.error('"{0}" must be an existing file!'.format(in_filename))

        if not tractconverter.is_supported(in_filename):
            parser.error('Input file must be one of {0}!'.format(",".join(
                FORMATS.keys())))

    if not tractconverter.is_supported(out_filename):
        parser.error('Output file must be one of {0}!'.format(",".join(
            FORMATS.keys())))

    if os.path.isfile(out_filename):
        if not isForcing:
            parser.error('"{0}" already exist! Use -f to overwrite it.'.format(
                out_filename))
        elif out_filename in in_filenames:
            # Overwriting an input would destroy it before it is merged:
            # the output is created first and the inputs are read lazily.
            parser.error(
                'Cannot output to a file which is also an input file ({0}).'
                .format(out_filename))
        else:
            logging.info('Overwriting "{0}".'.format(out_filename))

    inFormats = [
        tractconverter.detect_format(in_filename)
        for in_filename in in_filenames
    ]
    outFormat = tractconverter.detect_format(out_filename)

    # NOTE: handling of an anatomy file (args.anat) is currently disabled.
    # TODO: Consider different anat, space.
    hdr = {}
    hdr[Header.DIMENSIONS] = (1, 1, 1)
    hdr[Header.ORIGIN] = (1, 1, 1)
    # The actual number of streamlines will be added later.
    hdr[Header.NB_FIBERS] = 0

    # Merge inputs to output. The generator opens each input only when
    # it is consumed.
    inputs = (in_format(in_filename)
              for in_filename, in_format in zip(in_filenames, inFormats))
    output = outFormat.create(out_filename, hdr)
    tractconverter.merge(inputs, output)
예제 #3
0
def main():
    """Convert a tract file to another supported format.

    Command line entry point: arguments come from ``buildArgsParser`` and
    every validation failure is reported through ``parser.error``. Once
    validated, the input is opened with its detected format, the output is
    created with the input header, and both are handed to
    ``tractconverter.convert``.
    """
    parser = buildArgsParser()
    args = parser.parse_args()

    in_filename = args.input
    out_filename = args.output
    anat_filename = args.anat
    isForcing = args.isForce
    isVerbose = args.isVerbose

    if isVerbose:
        logging.basicConfig(level=logging.DEBUG)

    if not os.path.isfile(in_filename):
        parser.error('"{0}" must be an existing file!'.format(in_filename))

    if not tractconverter.is_supported(in_filename):
        parser.error('Input file must be one of {0}!'.format(",".join(
            FORMATS.keys())))

    if not tractconverter.is_supported(out_filename):
        parser.error('Output file must be one of {0}!'.format(",".join(
            FORMATS.keys())))

    if os.path.isfile(out_filename):
        if not isForcing:
            parser.error('"{0}" already exist! Use -f to overwrite it.'.format(
                out_filename))
        elif out_filename == in_filename:
            parser.error(
                'Cannot use the same name for input and output files. Conversion would fail.'
            )
        else:
            logging.info('Overwriting "{0}".'.format(out_filename))

    inFormat = tractconverter.detect_format(in_filename)
    outFormat = tractconverter.detect_format(out_filename)

    if anat_filename is not None:
        if not any(map(anat_filename.endswith, EXT_ANAT.split('|'))):
            if isForcing:
                logging.info('Reading "{0}" as a {1} file.'.format(
                    anat_filename.split("/")[-1], EXT_ANAT))
            else:
                # Only one argument is given to format(), so the field
                # index must be {0}; {1} would raise IndexError.
                parser.error(
                    'Anatomy file must be one of {0}!'.format(EXT_ANAT))

        if not os.path.isfile(anat_filename):
            parser.error(
                '"{0}" must be an existing file!'.format(anat_filename))

    # Convert input to output (names chosen to avoid shadowing ``input``).
    in_tracts = inFormat(in_filename, anat_filename)
    out_tracts = outFormat.create(out_filename, in_tracts.hdr, anat_filename)
    tractconverter.convert(in_tracts, out_tracts)
예제 #4
0
def main():
    """Compute cross-section statistics of a bundle along a centroid.

    Command line entry point: loads the streamlines and the centroid with
    their detected formats, loads the reference anatomy (used as a mask
    only when ``args.mask`` is set, otherwise replaced by an array of
    ones), then calls ``compute_cross_sections`` and ``compute_stats``.
    """
    parser = buildArgsParser()
    args = parser.parse_args()

    #####################################
    # Applying options                  #
    #####################################
    if args.planes_for_stats:
        # NOTE(security): eval() executes arbitrary code taken from the
        # command line. If only Python literals are expected here,
        # ast.literal_eval would be a safer choice -- confirm before changing.
        planes_for_stats = eval(args.planes_for_stats)
    else:
        planes_for_stats = None

    #####################################
    # Checking if the files exist       #
    #####################################
    for myFile in [
            args.tracts_filename, args.centroid_filename, args.ref_anat_name
    ]:
        if not os.path.isfile(myFile):
            parser.error('"{0}" must be a file!'.format(myFile))

    if args.output_name:
        if os.path.exists(args.output_name):
            print(args.output_name, " already exist and will be overwritten.")

    #####################################
    # Loading tracts                    #
    #####################################
    tract_format = tc.detect_format(args.tracts_filename)
    tract = tract_format(args.tracts_filename, anatFile=args.ref_anat_name)
    streamlines = list(tract)

    centroid_format = tc.detect_format(args.centroid_filename)
    centroid_file = centroid_format(args.centroid_filename,
                                    anatFile=args.ref_anat_name)
    centroids = list(centroid_file)  # should contain only one
    if not centroids:
        # Without this guard, indexing below fails with a bare IndexError.
        parser.error('"{0}" does not contain any streamline.'.format(
            args.centroid_filename))
    if len(centroids) > 1:
        print(
            'Centroid should contain only one streamline. Here, the file contains more.'
        )
        print('The first streamline will be used as the centroid.')
    centroid = centroids[0]

    #####################################
    # Loading anat                      #
    # Preparing mask                    #
    #####################################
    anat = nib.load(args.ref_anat_name)
    affine = anat.get_affine()
    shape = anat.get_shape()
    if args.mask:
        anat = anat.get_data()
    else:
        anat = np.ones(shape)

    hitpoints = compute_cross_sections(streamlines, centroid)
    compute_stats(hitpoints, anat, affine, args.output_name, planes_for_stats)
예제 #5
0
def main():
    """Merge several tract files into a single output file.

    Command line entry point: arguments come from ``buildArgsParser`` and
    every validation failure is reported through ``parser.error``. After
    validation, the inputs are opened lazily with their detected formats
    and passed with the newly created output to ``tractconverter.merge``.
    """
    parser = buildArgsParser()
    args = parser.parse_args()

    sources = args.input
    target = args.output
    force = args.isForce

    if args.isVerbose:
        logging.basicConfig(level=logging.DEBUG)

    for source in sources:
        if not os.path.isfile(source):
            parser.error('"{0}" must be an existing file!'.format(source))

        if not tractconverter.is_supported(source):
            parser.error('Input file must be one of {0}!'.format(
                ",".join(FORMATS.keys())))

    if not tractconverter.is_supported(target):
        parser.error('Output file must be one of {0}!'.format(
            ",".join(FORMATS.keys())))

    if os.path.isfile(target):
        if not force:
            parser.error(
                '"{0}" already exist! Use -f to overwrite it.'.format(target))
        elif target in sources:
            parser.error(
                'Cannot output to a file which is also an input file ({0}).'
                .format(target))
        else:
            logging.info('Overwriting "{0}".'.format(target))

    source_formats = [tractconverter.detect_format(name) for name in sources]
    target_format = tractconverter.detect_format(target)

    # NOTE: handling of an anatomy file (args.anat) is currently disabled.
    # TODO: Consider different anat, space.
    hdr = {
        Header.DIMENSIONS: (1, 1, 1),
        Header.ORIGIN: (1, 1, 1),
        # The actual number of streamlines will be added later.
        Header.NB_FIBERS: 0,
    }

    # Inputs are opened one at a time, only when the merge consumes them.
    inputs = (fmt(name) for name, fmt in zip(sources, source_formats))
    output = target_format.create(target, hdr)
    tractconverter.merge(inputs, output)
예제 #6
0
def compression_wrapper(tract_filename, out_filename, error_rate):
    """Compress the streamlines of a tract file, one streamline at a time.

    Parameters
    ----------
    tract_filename : str
        Path of the tract file to read.
    out_filename : str
        Path of the tract file to create; it receives the input header.
    error_rate
        Forwarded to ``compress_streamlines``.
    """
    reader = tc.detect_format(tract_filename)(tract_filename)

    header = reader.hdr
    writer_format = tc.detect_format(out_filename)
    out_tracts = writer_format.create(out_filename, header)

    # TODO we should chunk this.
    for streamline in reader:
        compressed = compress_streamlines([streamline], error_rate)
        out_tracts += np.array(compressed)

    out_tracts.close()
예제 #7
0
def main():
    """Convert a tract file to another supported format.

    Command line entry point: arguments come from ``buildArgsParser`` and
    every validation failure is reported through ``parser.error``. Once
    validated, the input is opened with its detected format, the output is
    created with the input header, and both are handed to
    ``tractconverter.convert``.
    """
    parser = buildArgsParser()
    args = parser.parse_args()

    in_filename = args.input
    out_filename = args.output
    anat_filename = args.anat
    isForcing = args.isForce
    isVerbose = args.isVerbose

    if isVerbose:
        logging.basicConfig(level=logging.DEBUG)

    if not os.path.isfile(in_filename):
        parser.error('"{0}" must be an existing file!'.format(in_filename))

    if not tractconverter.is_supported(in_filename):
        parser.error("Input file must be one of {0}!".format(",".join(FORMATS.keys())))

    if not tractconverter.is_supported(out_filename):
        parser.error("Output file must be one of {0}!".format(",".join(FORMATS.keys())))

    if os.path.isfile(out_filename):
        if not isForcing:
            parser.error('"{0}" already exist! Use -f to overwrite it.'.format(out_filename))
        elif out_filename == in_filename:
            parser.error("Cannot use the same name for input and output files. Conversion would fail.")
        else:
            logging.info('Overwriting "{0}".'.format(out_filename))

    inFormat = tractconverter.detect_format(in_filename)
    outFormat = tractconverter.detect_format(out_filename)

    if anat_filename is not None:
        if not any(map(anat_filename.endswith, EXT_ANAT.split("|"))):
            if isForcing:
                logging.info('Reading "{0}" as a {1} file.'.format(anat_filename.split("/")[-1], EXT_ANAT))
            else:
                # format() receives a single argument, so the placeholder
                # must be {0}; {1} would raise IndexError.
                parser.error("Anatomy file must be one of {0}!".format(EXT_ANAT))

        if not os.path.isfile(anat_filename):
            parser.error('"{0}" must be an existing file!'.format(anat_filename))

    # Convert input to output (names chosen to avoid shadowing ``input``).
    in_tracts = inFormat(in_filename, anat_filename)
    out_tracts = outFormat.create(out_filename, in_tracts.hdr, anat_filename)
    tractconverter.convert(in_tracts, out_tracts)
예제 #8
0
def main():
    """Subsample the streamlines of a tract file into a new file.

    Command line entry point: the input and output must share the same
    detected format. The streamlines are loaded in memory, passed to
    ``subsample_streamlines`` and the result is written with the input
    header, whose fiber count is updated.
    """
    parser = buildArgsParser()
    args = parser.parse_args()

    np.random.seed(int(time.time()))
    source = args.input
    target = args.output
    force = args.isForce

    if args.isVerbose:
        logging.basicConfig(level=logging.DEBUG)

    if not tractconverter.is_supported(source):
        parser.error('Input file must be one of {0}!'.format(
            ",".join(tractconverter.FORMATS.keys())))

    source_format = tractconverter.detect_format(source)
    target_format = tractconverter.detect_format(target)

    if not source_format == target_format:
        parser.error('Input and output must be of the same types!')

    if os.path.isfile(target):
        if not force:
            parser.error(
                '"{0}" already exist! Use -f to overwrite it.'.format(target))
        else:
            logging.info('Overwriting "{0}".'.format(target))

    tract = source_format(source)
    streamlines = list(tract)

    # A fresh RandomState is only handed over when requested.
    rng = np.random.RandomState() if args.non_fixed_seed else None

    results = subsample_streamlines(streamlines, args.minL, args.maxL, args.n,
                                    args.npts, args.arclength, rng)

    logging.info(
        '"{0}" contains {1} streamlines.'.format(target, len(results)))

    hdr = tract.hdr
    hdr[tractconverter.formats.header.Header.NB_FIBERS] = len(results)

    output = target_format.create(target, hdr)
    output += results
예제 #9
0
def save_streamlines_tractquerier(streamlines, ref_filename, output_filename):
    """Save streamlines to a trk/tck file for use in the tract querier.

    The points of every streamline are translated by the translation
    column of the ``VOXEL_TO_WORLD`` header entry before being written.

    Parameters
    ----------
    streamlines: list
        List of tuples (3D points, scalars, properties).
    ref_filename: str
        File name of a reference image.
    output_filename: str
        File name to save the streamlines to use in the tract querier.

    Raises
    ------
    ValueError
        If the detected output format is neither TRK nor TCK.
    """
    out_format = tc.detect_format(output_filename)
    allowed_formats = (tc.formats.trk.TRK, tc.formats.tck.TCK)

    if out_format not in allowed_formats:
        raise ValueError("Invalid output streamline file format " +
                         "(must be trk or tck): {0}".format(output_filename))

    header = tc.formats.header.get_header_from_anat(ref_filename)

    # This currently creates an invalid .tck file, since the streamlines are
    # only partly transformed. At least, not passing the anatFile param
    # applies an identity transform to the streamlines when saved.
    # Will be fixed with Guillaume Theaud's modifications and
    # @marccote nibabel branch.
    out_tracts = out_format.create(output_filename, header)
    translation = header[tc.formats.header.Header.VOXEL_TO_WORLD][:3, 3]

    shifted_points = []
    for streamline in streamlines:
        shifted_points.append(streamline[0] + translation)

    out_tracts += shifted_points
    out_tracts.close()
예제 #10
0
def validate_coordinates(anat, streamlines, nifti_compliant=True):
    """Check that no streamline point has a negative coordinate.

    Parameters
    ----------
    anat : str
        Reference image; its ``pixdim`` header field gives the voxel size.
    streamlines : str or tract file object
        Either a file name (opened with its detected format) or an already
        opened TCK/TRK object.
    nifti_compliant : bool
        When True, half a voxel is added to every point before the check;
        otherwise no shift is applied.

    Returns
    -------
    bool
        False as soon as one shifted point has a negative coordinate,
        True otherwise.

    Raises
    ------
    TypeError
        If the tract object is neither a TCK nor a TRK.
    """
    voxel_dim = nb.load(anat).get_header()['pixdim'][1:4]
    shift_factor = voxel_dim * (0.5 if nifti_compliant else 0.0)

    if isinstance(streamlines, six.string_types):
        tract_file = tc.detect_format(streamlines)(streamlines, anatFile=anat)
    else:
        tract_file = streamlines

    # TODO what check to do for .vtk?
    if not isinstance(tract_file, (tc.formats.tck.TCK, tc.formats.trk.TRK)):
        raise TypeError("This function currently only supports TCK and TRK.")

    for points in tract_file:
        if np.any(np.array(points + shift_factor) < 0):
            return False

    return True
예제 #11
0
def save_streamlines_fibernavigator(streamlines, ref_filename,
                                    output_filename):
    """Save streamlines to a trk/tck file for use in the FiberNavigator.

    Parameters
    ----------
    streamlines: list
        List of tuples (3D points, scalars, properties).
    ref_filename: str
        File name of a reference image.
    output_filename: str
        File name to save the streamlines to use in the FiberNavigator.

    Raises
    ------
    ValueError
        If the detected output format is neither TRK nor TCK.
    """
    out_format = tc.detect_format(output_filename)
    allowed_formats = (tc.formats.trk.TRK, tc.formats.tck.TCK)

    if out_format not in allowed_formats:
        raise ValueError("Invalid output streamline file format " +
                         "(must be trk or tck): {0}".format(output_filename))

    header = tc.formats.header.get_header_from_anat(ref_filename)

    # anatFile is only important for .tck for now, but has no negative
    # impact on other formats.
    out_tracts = out_format.create(output_filename, header,
                                   anatFile=ref_filename)

    # Only the 3D points of each tuple are saved.
    points_only = []
    for streamline in streamlines:
        points_only.append(streamline[0])

    out_tracts += points_only
    out_tracts.close()
예제 #12
0
def _lps_allowed(tracts_filename):
    """Return True when the opened tract file is a VTK object, else False."""
    opened = tc.detect_format(tracts_filename)(tracts_filename)
    return isinstance(opened, tc.formats.vtk.VTK)
예제 #13
0
파일: streamlines.py 프로젝트: BIG-S2/PSC
def is_trk(streamlines_filename):
    """Return True when the opened streamlines file is a TRK object."""
    opened = tc.detect_format(streamlines_filename)(streamlines_filename)
    return isinstance(opened, tc.formats.trk.TRK)
예제 #14
0
def main():
    """Remove from one tract file the streamlines found in other files.

    Command line entry point (deprecated, see the warning emitted below).
    The streamlines of ``args.input`` are loaded in memory,
    ``substract_streamlines`` removes in place those coming from the
    ``args.remove`` files, and the remainder is written to ``args.output``
    with the input header, whose fiber count is updated.
    """
    parser = build_args_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.INFO)
    logging.warning('The script scil_substract_streamlines is deprecated. '
                    'Use scil_streamlines_math instead.')

    if os.path.isfile(args.output):
        if not args.force:
            parser.error('{0} already exist! Use -f to overwrite it.'
                         .format(args.output))
        else:
            logging.info('Overwriting {0}.'.format(args.output))

    def _open(filename):
        # Log, then open the file with the class detected for it.
        logging.info('Loading streamlines from file {0} ...'.format(filename))
        return tractconverter.detect_format(filename)(filename)

    # The first filename contains the streamlines from which all others are
    # substracted; they are loaded in memory.
    streamlines = list(_open(args.input))

    # All the other filenames contain the streamlines to be removed.
    removal_files = [_open(filename) for filename in args.remove]

    # Remove the streamlines in place.
    substract_streamlines(streamlines,
                          itertools.chain.from_iterable(removal_files))

    # Save the new streamlines.
    logging.info('Saving remaining streamlines ...')
    input_format = tractconverter.detect_format(args.input)
    hdr = input_format(args.input).hdr
    hdr[tractconverter.formats.header.Header.NB_FIBERS] = len(streamlines)
    output = input_format.create(args.output, hdr)
    output += streamlines
예제 #15
0
def format_needs_orientation(tract_fname):
    """Return True when the opened tract file is a VTK object, else False."""
    opened = tc.detect_format(tract_fname)(tract_fname)
    return isinstance(opened, tc.formats.vtk.VTK)
예제 #16
0
def guess_orientation(tract_fname):
    """Return 'RAS' for a TCK tract file and 'Unknown' for anything else."""
    opened = tc.detect_format(tract_fname)(tract_fname)
    return 'RAS' if isinstance(opened, tc.formats.tck.TCK) else 'Unknown'
예제 #17
0
def format_needs_orientation(tract_fname):
    """Tell whether the tract file, once opened, is a VTK object."""
    file_class = tc.detect_format(tract_fname)
    opened = file_class(tract_fname)
    is_vtk = isinstance(opened, tc.formats.vtk.VTK)
    return is_vtk
예제 #18
0
def guess_orientation(tract_fname):
    """Guess the orientation label of a tract file.

    Returns 'RAS' when the opened file is a TCK object and 'Unknown' in
    every other case.
    """
    file_class = tc.detect_format(tract_fname)
    opened = file_class(tract_fname)

    if not isinstance(opened, tc.formats.tck.TCK):
        return 'Unknown'
    return 'RAS'
예제 #19
0
def _is_tracts_space_valid(tracts_filename, lps_oriented):
    """Report whether the streamlines fit inside a fixed bounding box.

    The bounding box of all streamline points is compared with hard-coded
    required bounds, which depend on ``lps_oriented``.

    Parameters
    ----------
    tracts_filename : str
        Path of the tract file (VTK, TCK or TRK are handled).
    lps_oriented : bool
        Selects which set of required bounds is used.

    Returns
    -------
    str
        A human-readable message: success, probable axis reversal (VTK
        only), failure with details, TrackVis header error, or a notice
        that no streamline could be read.
    """
    tracts_format = tc.detect_format(tracts_filename)
    tracts_file = tracts_format(tracts_filename)

    # Compute boundaries of volume
    if lps_oriented:
        required_mins = np.array([-1.0, -1.0, -1.0])
        required_maxs = np.array([179.0, 215.0, 179.0])
    else:
        required_mins = np.array([-179.0, -215.0, -1.0])
        required_maxs = np.array([1.0, 1.0, 179.0])

    # We compute them directly in the loop inside the format dependent code
    # to avoid 2 loops and to avoid loading everything in memory.
    minimas = []
    maximas = []

    # Load tracts
    if isinstance(tracts_file, (tc.formats.vtk.VTK, tc.formats.tck.TCK)):
        for s in tracts_file:
            minimas.append(np.min(s, axis=0))
            maximas.append(np.max(s, axis=0))
    elif isinstance(tracts_file, tc.formats.trk.TRK):
        # Use nb.trackvis to read directly in correct space
        try:
            streamlines, _ = nb.trackvis.read(tracts_filename,
                                              as_generator=True,
                                              points_space='rasmm')
        except nb.trackvis.HeaderError as er:
            msg = "\n------ ERROR ------\n\n" +\
                  "TrackVis header is malformed or incomplete.\n" +\
                  "Please make sure all fields are correctly set.\n\n" +\
                  "The error message reported by Nibabel was:\n" +\
                  str(er)
            return msg

        for s in streamlines:
            minimas.append(np.min(s[0], axis=0))
            maximas.append(np.max(s[0], axis=0))

    if not minimas:
        # np.min/np.max raise ValueError on an empty sequence; this happens
        # for an empty file or a format not handled above.
        return "No streamlines could be read from the tracts file; " + \
               "unable to validate the space."

    global_min = np.min(minimas, axis=0)
    global_max = np.max(maximas, axis=0)

    if np.all(global_min > required_mins) and \
            np.all(global_max < required_maxs):
        return "Tracts seem to be in the correct space"

    # For VTK, also test the bounds with the first two axes negated.
    xy_flip = np.array([-1.0, -1.0, 1.0])
    if isinstance(tracts_file, tc.formats.vtk.VTK) and \
            np.all(global_min * xy_flip > required_mins) and \
            np.all(global_max * xy_flip < required_maxs):
        return "Tracts seem to be reverted. Did you use the --lps flag?\n" +\
               "If so, it means the tracts are not in the correct space."

    return "Tracts do not seem to be in the correct space.\n\n" + \
           _print_required_and_found(required_mins, required_maxs,
                                     global_min, global_max, lps_oriented)
예제 #20
0
def _is_format_supported(tracts_filename):
    """Return True when the opened tract file is a TCK, TRK or VTK object."""
    opened = tc.detect_format(tracts_filename)(tracts_filename)
    supported_classes = (tc.formats.tck.TCK,
                         tc.formats.trk.TRK,
                         tc.formats.vtk.VTK)
    return isinstance(opened, supported_classes)
예제 #21
0
def label_streamlines(streamlines, labels, labels_Value, affine, hdr, f_name,
                      data_path):
    """Select, render and save the streamlines associated with a label.

    Parameters
    ----------
    streamlines
        Streamlines to filter with ``utils.target``.
    labels
        Label volume; the ROI mask is ``labels == labels_Value``.
    labels_Value
        Label value defining the ROI.
    affine
        Affine passed to ``utils.target``.
    hdr
        Header passed to ``nib.trackvis.write``.
    f_name : str
        Prefix of every output file ('_roi.png', '_roi_streamline.trk' and
        '_roi_streamline.trk.vtk' are appended to it).
    data_path
        Unused; kept for interface compatibility.

    Returns
    -------
    list
        The selected streamlines.
    """
    cc_slice = labels == labels_Value
    # NOTE(review): the selection targets the whole ``labels`` volume while
    # the complement below targets ``cc_slice``; the assert only holds when
    # both masks select the same streamlines -- confirm which is intended.
    cc_streamlines = list(utils.target(streamlines, labels, affine=affine))

    other_streamlines = list(
        utils.target(streamlines, cc_slice, affine=affine, include=False))
    assert len(other_streamlines) + len(cc_streamlines) == len(streamlines)

    # Apply the format explicitly: passing the count as a second print
    # argument left the '%d' placeholder unformatted.
    print("num of roi steamlines is %d" % len(cc_streamlines))

    # Make display objects
    cc_streamlines_actor = fvtk.line(cc_streamlines,
                                     line_colors(cc_streamlines))
    cc_ROI_actor = fvtk.contour(cc_slice,
                                levels=[1],
                                colors=[(1., 1., 0.)],
                                opacities=[1.])

    # Add display objects to canvas
    r = fvtk.ren()
    fvtk.add(r, cc_streamlines_actor)
    fvtk.add(r, cc_ROI_actor)

    # Save figures
    # NOTE(review): both records use the same out_path, so the second
    # presumably overwrites the first -- confirm whether that is intended.
    fvtk.record(r, n_frames=1, out_path=f_name + '_roi.png', size=(800, 800))
    fvtk.camera(r, [-1, 0, 0], [0, 0, 0], viewup=[0, 0, 1])
    fvtk.record(r, n_frames=1, out_path=f_name + '_roi.png', size=(800, 800))

    csd_streamlines_trk = ((sl, None, None) for sl in cc_streamlines)
    csd_sl_fname = f_name + '_roi_streamline.trk'
    nib.trackvis.write(csd_sl_fname,
                       csd_streamlines_trk,
                       hdr,
                       points_space='voxel')
    print('Saving "_roi_streamline.trk" sucessful.')

    # Convert the saved .trk file to .vtk with tractconverter.
    import tractconverter as tc
    trk_format = tc.detect_format(csd_sl_fname)
    trk_input = trk_format(csd_sl_fname)
    vtk_output = tc.FORMATS['vtk'].create(csd_sl_fname + ".vtk", trk_input.hdr)
    tc.convert(trk_input, vtk_output)

    return cc_streamlines
예제 #22
0
def get_tract_count(streamlines):
    """Return the number of streamlines.

    Parameters
    ----------
    streamlines : str, list or tract file object
        A file name (the count is read from the header of the opened
        file), a list (its length is used), or an opened TCK/TRK/VTK
        object (the count is read from its header).

    Returns
    -------
    The ``NB_FIBERS`` header value, or the length of the list.

    Raises
    ------
    TypeError
        If ``streamlines`` is none of the supported types.
    """
    if isinstance(streamlines, six.string_types):
        tc_format = tc.detect_format(streamlines)
        tract_file = tc_format(streamlines)
        return tract_file.hdr[tract_header.NB_FIBERS]

    if isinstance(streamlines, list):
        return len(streamlines)

    # Need to do it like this since there is no parent class in the formats.
    if isinstance(streamlines, (tc.formats.tck.TCK,
                                tc.formats.trk.TRK,
                                tc.formats.vtk.VTK)):
        return streamlines.hdr[tract_header.NB_FIBERS]

    # Fail with an explicit error instead of an unbound local variable.
    raise TypeError(
        "Cannot count streamlines of an object of type {0}.".format(
            type(streamlines).__name__))
예제 #23
0
def main():
    """Print information about a tract file.

    Command line entry point: validates that the input exists and is a
    supported format, then prints the object obtained by opening it with
    its detected format.
    """
    parser = buildArgsParser()
    args = parser.parse_args()

    in_filename = args.input

    if not os.path.isfile(in_filename):
        parser.error('"{0}" must be an existing file!'.format(in_filename))

    if not tractconverter.is_supported(in_filename):
        parser.error('Input file must be one of {0}!'.format(",".join(FORMATS.keys())))

    inFormat = tractconverter.detect_format(in_filename)

    # Print info about the input file. The call form of print with a single
    # argument behaves the same on Python 2 and Python 3.
    print(inFormat(in_filename, None))
def label_streamlines(streamlines,labels,labels_Value,affine,hdr,f_name,data_path):
    """Select, render and save the streamlines associated with a label.

    Parameters
    ----------
    streamlines
        Streamlines to filter with ``utils.target``.
    labels
        Label volume; the ROI mask is ``labels == labels_Value``.
    labels_Value
        Label value defining the ROI.
    affine
        Affine passed to ``utils.target``.
    hdr
        Header passed to ``nib.trackvis.write``.
    f_name : str
        Prefix of every output file ('_roi.png', '_roi_streamline.trk' and
        '_roi_streamline.trk.vtk' are appended to it).
    data_path
        Unused; kept for interface compatibility.

    Returns
    -------
    list
        The selected streamlines.
    """
    cc_slice = labels == labels_Value
    # NOTE(review): the selection targets the whole ``labels`` volume while
    # the complement below targets ``cc_slice``; the assert only holds when
    # both masks select the same streamlines -- confirm which is intended.
    cc_streamlines = list(utils.target(streamlines, labels, affine=affine))

    other_streamlines = list(
        utils.target(streamlines, cc_slice, affine=affine, include=False))
    assert len(other_streamlines) + len(cc_streamlines) == len(streamlines)

    # Apply the format explicitly: passing the count as a second print
    # argument left the '%d' placeholder unformatted.
    print("num of roi steamlines is %d" % len(cc_streamlines))

    # Make display objects
    cc_streamlines_actor = fvtk.line(cc_streamlines, line_colors(cc_streamlines))
    cc_ROI_actor = fvtk.contour(cc_slice, levels=[1], colors=[(1., 1., 0.)],
                                opacities=[1.])

    # Add display objects to canvas
    r = fvtk.ren()
    fvtk.add(r, cc_streamlines_actor)
    fvtk.add(r, cc_ROI_actor)

    # Save figures
    # NOTE(review): both records use the same out_path, so the second
    # presumably overwrites the first -- confirm whether that is intended.
    fvtk.record(r, n_frames=1, out_path=f_name+'_roi.png',
                size=(800, 800))
    fvtk.camera(r, [-1, 0, 0], [0, 0, 0], viewup=[0, 0, 1])
    fvtk.record(r, n_frames=1, out_path=f_name+'_roi.png',
                size=(800, 800))

    csd_streamlines_trk = ((sl, None, None) for sl in cc_streamlines)
    csd_sl_fname = f_name+'_roi_streamline.trk'
    nib.trackvis.write(csd_sl_fname, csd_streamlines_trk, hdr, points_space='voxel')
    print('Saving "_roi_streamline.trk" sucessful.')

    # Convert the saved .trk file to .vtk with tractconverter.
    import tractconverter as tc
    trk_format = tc.detect_format(csd_sl_fname)
    trk_input = trk_format(csd_sl_fname)
    vtk_output = tc.FORMATS['vtk'].create(csd_sl_fname+".vtk", trk_input.hdr)
    tc.convert(trk_input, vtk_output)

    return cc_streamlines
예제 #25
0
def main():
    """Print information about a tract file.

    Command line entry point: validates that the input exists and is a
    supported format, then prints the object obtained by opening it with
    its detected format.
    """
    parser = buildArgsParser()
    args = parser.parse_args()

    in_filename = args.input

    if not os.path.isfile(in_filename):
        parser.error('"{0}" must be an existing file!'.format(in_filename))

    if not tractconverter.is_supported(in_filename):
        parser.error('Input file must be one of {0}!'.format(",".join(
            FORMATS.keys())))

    inFormat = tractconverter.detect_format(in_filename)

    # Print info about the input file. The call form of print with a single
    # argument behaves the same on Python 2 and Python 3.
    print(inFormat(in_filename, None))
예제 #26
0
def compute_affine_for_dipy_functions(anat, streamlines):
    """Choose the affine to hand to dipy functions for these streamlines.

    For TCK or TRK streamlines, a 4x4 matrix whose upper-left 3x3 diagonal
    holds the voxel dimensions is returned; the stated rationale is that
    the TractConverter gives such streamlines aligned with a grid starting
    at the origin of the reference frame, in millimetric space, so no
    further transform should be applied. In every other case the affine of
    the reference image is returned.

    Parameters
    ----------
    anat : str
        Reference image file name.
    streamlines : str or tract file object
        A file name (opened with its detected format) or an already
        opened tract object.
    """
    ref_img = nib.load(anat)
    voxel_dim = ref_img.get_header()['pixdim'][1:4]
    image_affine = ref_img.get_affine()

    if isinstance(streamlines, six.string_types):
        tract_file = tc.detect_format(streamlines)(streamlines, anatFile=anat)
    else:
        tract_file = streamlines

    if not isinstance(tract_file, (tc.formats.tck.TCK, tc.formats.trk.TRK)):
        return image_affine

    # "Scale" identity: only the voxel sizes remain on the diagonal.
    scale_affine = np.eye(4)
    scale_affine[:3, :3] *= np.asarray(voxel_dim)
    return scale_affine
예제 #27
0
def compute_max_values(tracts_filename):
    """Print the bounding box (min and max per axis) of a tract file.

    Parameters
    ----------
    tracts_filename : str
        Path of the tract file (VTK, TCK or TRK are handled).

    Returns
    -------
    str or None
        An error message when the TrackVis header cannot be read or when
        no streamline could be read; None after the bounds were printed.
    """
    tracts_format = tc.detect_format(tracts_filename)
    tracts_file = tracts_format(tracts_filename)

    # We compute them directly in the loop inside the format dependent code
    # to avoid 2 loops and to avoid loading everything in memory.
    minimas = []
    maximas = []

    # Load tracts
    if isinstance(tracts_file, (tc.formats.vtk.VTK, tc.formats.tck.TCK)):
        for s in tracts_file:
            minimas.append(np.min(s, axis=0))
            maximas.append(np.max(s, axis=0))
    elif isinstance(tracts_file, tc.formats.trk.TRK):
        # Use nb.trackvis to read the streamlines directly.
        try:
            streamlines, _ = nb.trackvis.read(tracts_filename,
                                              as_generator=True)
        except nb.trackvis.HeaderError as er:
            msg = "\n------ ERROR ------\n\n" +\
                  "TrackVis header is malformed or incomplete.\n" +\
                  "Please make sure all fields are correctly set.\n\n" +\
                  "The error message reported by Nibabel was:\n" +\
                  str(er)
            return msg

        for s in streamlines:
            minimas.append(np.min(s[0], axis=0))
            maximas.append(np.max(s[0], axis=0))

    if not minimas:
        # np.min/np.max raise ValueError on an empty sequence; this happens
        # for an empty file or a format not handled above.
        return "No streamlines could be read from the tracts file; " + \
               "unable to compute the min and max values."

    global_min = np.min(minimas, axis=0)
    global_max = np.max(maximas, axis=0)

    print("Min: {0}".format(global_min))
    print("Max: {0}".format(global_max))
예제 #28
0
def count(tract_filename, roi_anat, roi_idx_range):
    """Count the streamlines crossing each requested ROI label.

    Parameters
    ----------
    tract_filename : str
        Path of the tract file.
    roi_anat : str
        Label image; also passed as ``anatFile`` when opening the tracts.
    roi_idx_range : iterable
        Label values to count streamlines for.

    Returns
    -------
    list of tuple
        ``(roi_idx, number of distinct streamlines)`` for every label that
        has at least one voxel in the image.

    Raises
    ------
    ValueError
        If the squeezed label image has more than 3 dimensions.
    """
    roi_img = nib.load(roi_anat)
    voxel_dim = roi_img.get_header()['pixdim'][1:4]
    anat_dim = roi_img.get_header().get_data_shape()

    # TRK streamlines are loaded with load_all() and shifted by half a
    # voxel; any other format is iterated as is.
    tracts_format = tc.detect_format(tract_filename)
    tracts_file = tracts_format(tract_filename, anatFile=roi_anat)

    if tracts_format is tc.FORMATS["trk"]:
        half_voxel = voxel_dim / 2.
        tracts = np.array([s - half_voxel for s in tracts_file.load_all()])
    else:
        tracts = np.array(list(tracts_file))

    _, tes = track_counts(tracts, anat_dim, voxel_dim, True)

    # If the data is a 4D volume with only one element in 4th dimension,
    # this will make it 3D, to correctly work with the tes variable.
    roi_data = roi_img.get_data().squeeze()

    if len(roi_data.shape) > 3:
        raise ValueError('Tract counting will fail with an anatomy of ' +
                         'more than 3 dimensions.')

    roi_counts = []

    for roi_idx in roi_idx_range:
        # One set of streamline indices per voxel carrying this label.
        per_voxel = [set(tes.get(voxel, []))
                     for voxel in izip(*np.where(roi_data == roi_idx))]

        if per_voxel:
            crossing = set.union(*per_voxel)
            roi_counts.append((roi_idx, len(crossing)))

    return roi_counts
예제 #29
0
def main():
    """Split a tractogram into clean streamlines and loops / sharp turns.

    Command-line entry point. The input and every requested output must share
    the same tract format. Existing outputs are only overwritten when the
    force flag is set. The clean streamlines are written to
    ``args.output_clean`` and, when ``args.output_loops`` is given, the
    rejected ones are written there.

    All validation failures are reported through ``parser.error``.
    """
    parser = buildArgsParser()
    args = parser.parse_args()

    in_filename = args.input
    out_filename_loops = args.output_loops
    out_filename_clean = args.output_clean
    streamlines_c = []
    loops = []

    if not tc.is_supported(in_filename):
        parser.error('Input file must be one of {0}!'.format(",".join(
            tc.FORMATS.keys())))

    in_format = tc.detect_format(in_filename)
    out_format_clean = tc.detect_format(out_filename_clean)

    if not in_format == out_format_clean:
        parser.error('Input and output must be of the same types!')

    if out_filename_loops:
        out_format_loops = tc.detect_format(out_filename_loops)
        if not in_format == out_format_loops:
            parser.error('Input and output loops must be of the same types!')
        if os.path.isfile(out_filename_loops):
            if args.isForce:
                logging.info('Overwriting "{0}".'.format(out_filename_loops))
            else:
                parser.error(
                    '"{0}" already exist! Use -f to overwrite it.'.format(
                        out_filename_loops))

    if os.path.isfile(out_filename_clean):
        if args.isForce:
            logging.info('Overwriting "{0}".'.format(out_filename_clean))
        else:
            parser.error('"{0}" already exist! Use -f to overwrite it.'.format(
                out_filename_clean))

    # The original messages glued the value to the text ('"0"must be ...').
    if args.threshold <= 0:
        parser.error('"{0}" must be greater than 0'.format(args.threshold))

    if args.angle <= 0:
        parser.error('"{0}" must be greater than 0'.format(args.angle))

    tract = in_format(in_filename)
    streamlines = [i for i in tract]

    if len(streamlines) > 1:
        streamlines_c, loops = remove_loops_and_sharp_turns(
            streamlines, args.QB, args.angle, args.threshold)
    else:
        parser.error('Zero or one streamline in ' + '{0}'.format(in_filename) +
                     '. The file must have more than one streamline.')

    header_keys = tc.formats.header.Header
    is_vtk = in_format is tc.formats.vtk.VTK

    # Work on a copy so the header of the loaded file is left untouched.
    hdr = tract.hdr.copy()
    nb_points_init = hdr[header_keys.NB_POINTS]
    nb_points_clean = 0

    if len(streamlines_c) > 0:
        hdr[header_keys.NB_FIBERS] = len(streamlines_c)
        if is_vtk:
            # For VTK the header point count is recomputed from the kept
            # streamlines.
            nb_points_clean = sum(len(s) for s in streamlines_c)
            hdr[header_keys.NB_POINTS] = nb_points_clean
        output_clean = out_format_clean.create(out_filename_clean, hdr)
        output_clean += streamlines_c
        output_clean.close()
    else:
        logging.warning("No clean streamlines in {0}".format(args.input))

    if len(loops) == 0:
        logging.warning("No loops in {0}".format(args.input))

    if out_filename_loops and len(loops) > 0:
        hdr[header_keys.NB_FIBERS] = len(loops)
        if is_vtk:
            # Points that did not end up in the clean file are attributed to
            # the loops file.
            hdr[header_keys.NB_POINTS] = nb_points_init - nb_points_clean
        output_loops = out_format_loops.create(out_filename_loops, hdr)
        output_loops += loops
        output_loops.close()
예제 #30
0
def _get_tracts_over_grid(tract_fname, ref_anat_fname, tract_attributes,
                           start_at_corner=True):
    """Yield the streamlines of a tracts file expressed on the reference grid.

    Every streamline is multiplied by the inverse of the reference image
    affine (corrected for LPS-oriented VTK files), then a constant shift is
    added. Files that are neither TCK, VTK nor TRK yield nothing.

    Parameters
    ----------
    tract_fname : str
        Path of the tracts file.
    ref_anat_fname : str
        Path of the reference anatomy supplying the affine.
    tract_attributes : dict
        Information about the dataset. Only "orientation" is read, and only
        for VTK files (should be LPS or RAS).
        TODO move to only get the attribute
    start_at_corner : bool, optional
        Selects the shift added after the transform: 0.5 for TCK/VTK and 0.0
        for TRK when True, the opposite values when False.

    Yields
    ------
    ndarray
        One array of transformed points per streamline.

    Raises
    ------
    AttributeError
        If a VTK file comes without an "orientation" attribute.
    ValueError
        If the orientation is "NOT_FOUND", or if nibabel rejects the TrackVis
        header.
    """
    tracts_format = tc.detect_format(tract_fname)
    tracts_file = tracts_format(tract_fname)

    # Affine of the supporting anatomy.
    ref_img = nb.load(ref_anat_fname)
    index_to_world_affine = ref_img.get_header().get_best_affine()

    if isinstance(tracts_file, tc.formats.vtk.VTK):
        # VTK points are considered to be in world space, so the orientation
        # is needed to correct the affine that brings them back to voxel.
        orientation = tract_attributes.get("orientation", None)
        if orientation is None:
            raise AttributeError('Missing the "orientation" attribute for VTK')
        if orientation == "NOT_FOUND":
            raise ValueError('Invalid value of "NOT_FOUND" for orientation')
        if orientation == "LPS":
            # The Nifti affine goes from voxel to RAS: negate the first two
            # rows to account for LPS.
            index_to_world_affine[:2, :] *= -1.0

    # Transposed so that points can be right-multiplied as row vectors.
    index_to_world_affine = index_to_world_affine.T.astype('<f4')
    world_to_index_affine = linalg.inv(index_to_world_affine)

    def _to_grid(points, offset):
        # Append a column of ones, apply the inverse affine, drop the extra
        # column again and add the shift.
        ones = np.ones([points.shape[0], 1], dtype='<f4')
        return np.dot(c_[points, ones],
                      world_to_index_affine)[:, :-1] + offset

    if isinstance(tracts_file, (tc.formats.tck.TCK, tc.formats.vtk.VTK)):
        offset = 0.5 if start_at_corner else 0.0

        for points in tracts_file:
            yield _to_grid(points, offset)
    elif isinstance(tracts_file, tc.formats.trk.TRK):
        # Use nb.trackvis to read directly in rasmm space, then bring the
        # points back with the same inverse affine.
        # TODO this should be made more robust, using all fields in header.
        try:
            streamlines, _ = nb.trackvis.read(tract_fname,
                                              as_generator=True,
                                              points_space='rasmm')
        except nb.trackvis.HeaderError as err:
            print(err)
            raise ValueError("\n------ ERROR ------\n\n" +
                             "TrackVis header is malformed or incomplete.\n" +
                             "Please make sure all fields are correctly set.\n\n" +
                             "The error message reported by Nibabel was:\n" +
                             str(err))

        offset = 0.0 if start_at_corner else 0.5

        # Each item is a tuple whose first element holds the points.
        for record in streamlines:
            yield _to_grid(record[0], offset)
예제 #31
0
def main():
    """Run maxima-field tractography from the command line.

    Validates the arguments, gathers the tracking parameters in a dict, builds
    the seed, mask and maxima field from the input images, tracks with the
    requested algorithm ('det' or 'prob') and saves the streamlines to a
    trk/tck file. When ``args.outputTQ`` is set, '.tq' is inserted before the
    output extension and the tractquerier saver is used.

    Invalid arguments are reported through ``parser.error``.
    """
    parser = buildArgsParser()
    args = parser.parse_args()
    param = {}

    if args.isVerbose:
        logging.basicConfig(level=logging.DEBUG)

    if args.outputTQ:
        filename_parts = os.path.splitext(args.output_file)
        output_filename = filename_parts[0] + '.tq' + filename_parts[1]
    else:
        output_filename = args.output_file

    out_format = tc.detect_format(output_filename)
    if out_format not in [tc.formats.trk.TRK, tc.formats.tck.TCK]:
        parser.error("Invalid output streamline file format (must be trk or " +
                     "tck): {0}".format(output_filename))
        return

    if os.path.isfile(output_filename):
        if args.isForce:
            logging.debug('Overwriting "{0}".'.format(output_filename))
        else:
            parser.error(
                '"{0}" already exists! Use -f to overwrite it.'.format(
                    output_filename))

    if not args.min_length > 0:
        parser.error('minL must be > 0, {0}mm was provided.'.format(
            args.min_length))
    if args.max_length < args.min_length:
        parser.error(
            'maxL must be > than minL, (minL={0}mm, maxL={1}mm).'.format(
                args.min_length, args.max_length))

    # Default to one seed per voxel when no seeding option was given.
    if not np.any([args.nt, args.npv, args.ns]):
        args.npv = 1

    # An explicit angle wins over the curvature; 45 degrees is the fallback.
    if args.theta is not None:
        theta = gm.math.radians(args.theta)
    elif args.curvature > 0:
        theta = get_max_angle_from_curvature(args.curvature, args.step_size)
    else:
        theta = gm.math.radians(45)

    if args.mask_interp == 'nn':
        mask_interpolation = 'nearest'
    elif args.mask_interp == 'tl':
        mask_interpolation = 'trilinear'
    else:
        parser.error("--mask_interp has wrong value. See the help (-h).")
        return

    param['random'] = args.random
    param['skip'] = args.skip
    param['algo'] = args.algo
    param['mask_interp'] = mask_interpolation
    param['field_interp'] = 'nearest'
    param['theta'] = theta
    param['sf_threshold'] = args.sf_threshold
    param['sf_threshold_init'] = args.sf_threshold_init
    param['step_size'] = args.step_size
    param['rk_order'] = args.rk_order
    param['max_length'] = args.max_length
    param['min_length'] = args.min_length
    param['max_nbr_pts'] = int(param['max_length'] / param['step_size'])
    param['min_nbr_pts'] = int(param['min_length'] / param['step_size']) + 1
    param['is_single_direction'] = args.is_single_direction
    param['nbr_seeds'] = args.nt if args.nt is not None else 0
    param['nbr_seeds_voxel'] = args.npv if args.npv is not None else 0
    param['nbr_streamlines'] = args.ns if args.ns is not None else 0
    param['max_no_dir'] = int(math.ceil(args.maxL_no_dir / param['step_size']))
    param['is_all'] = False
    param['is_keep_single_pts'] = False
    # r+ is necessary for the interpolation function in cython, which
    # needs read/write access.
    param['mmap_mode'] = None if args.isLoadData else 'r+'

    logging.debug('Tractography parameters:\n{0}'.format(param))

    seed_img = nib.load(args.seed_file)
    seed = Seed(seed_img)
    if args.npv:
        # Per-voxel seeding: totals scale with the number of seed voxels.
        param['nbr_seeds'] = len(seed.seeds) * param['nbr_seeds_voxel']
        param['skip'] = len(seed.seeds) * param['skip']
    if len(seed.seeds) == 0:
        parser.error('"{0}" does not have voxels value > 0.'.format(
            args.seed_file))

    mask = BinaryMask(Dataset(nib.load(args.mask_file), param['mask_interp']))

    dataset = Dataset(nib.load(args.peaks_file), param['field_interp'])
    field = MaximaField(dataset, param['sf_threshold'],
                        param['sf_threshold_init'], param['theta'])

    if args.algo == 'det':
        tracker = deterministicMaximaTracker(field, param)
    elif args.algo == 'prob':
        tracker = probabilisticTracker(field, param)
    else:
        parser.error("--algo has wrong value. See the help (-h).")
        return

    start = time.time()

    # Compression arguments are only passed when compression was requested,
    # so that track() keeps its own defaults otherwise.
    track_kwargs = {'nbr_processes': args.nbr_processes, 'pft_tracker': None}
    if args.compress:
        if args.compress < 0.001 or args.compress > 1:
            # logging.warn is a deprecated alias of logging.warning.
            logging.warning(
                'You are using an error rate of {}.\n'.format(args.compress) +
                'We recommend setting it between 0.001 and 1.\n' +
                '0.001 will do almost nothing to the tracts while ' +
                '1 will highly compress/linearize the tracts')
        track_kwargs['compress'] = True
        track_kwargs['compression_error_threshold'] = args.compress

    streamlines = track(tracker, mask, seed, param, **track_kwargs)

    if args.outputTQ:
        save_streamlines_tractquerier(streamlines, args.seed_file,
                                      output_filename)
    else:
        save_streamlines_fibernavigator(streamlines, args.seed_file,
                                        output_filename)

    str_ave_length = "%.2f" % compute_average_streamlines_length(streamlines)
    str_time = "%.2f" % (time.time() - start)
    logging.debug(
        str(len(streamlines)) + " streamlines, with an average " +
        "length of " + str_ave_length + " mm, done in " + str_time +
        " seconds.")
예제 #32
0
def main():
    """Cluster a tractogram with QuickBundles and save the centroids.

    Command-line entry point. The streamlines of ``args.tracts_filename`` are
    clustered with an MDF metric on resampled streamlines, and the cluster
    centroids are written to ``args.output_name`` using the header of the
    input file with its fiber count updated.

    Missing inputs and an empty tractogram are reported through
    ``parser.error``.
    """
    parser = buildArgsParser()
    args = parser.parse_args()

    #####################################
    # Checking if the files exist       #
    #####################################
    for myFile in [args.tracts_filename, args.ref_anat_name]:
        if not os.path.isfile(myFile):
            parser.error('"{0}" must be a file!'.format(myFile))

    if os.path.exists(args.output_name):
        print(args.output_name, " already exist and will be overwritten.")

    #####################################
    # Loading tracts                    #
    #####################################
    tracts_format = tc.detect_format(args.tracts_filename)
    tract_file = tracts_format(args.tracts_filename,
                               anatFile=args.ref_anat_name)
    tracts = list(tract_file)
    hdr = tract_file.hdr

    # Without this guard an empty file crashes on tracts[0] below.
    if not tracts:
        parser.error('"{0}" does not contain any streamline.'.format(
            args.tracts_filename))

    #####################################
    # Checking if needs subsampling     #
    #####################################
    # When all streamlines already share the same number of points, that
    # number is kept; otherwise the user-provided value is used.
    first_length = len(tracts[0])
    if all(len(my_tract) == first_length for my_tract in tracts):
        nb_of_points = first_length
    else:
        nb_of_points = args.nb_of_points

    #####################################
    # Compute QuickBundles             #
    #####################################
    print('Starting the QuickBundles...')
    # This feature tells QuickBundles to resample each streamlines on the fly.
    feature = ResampleFeature(nb_points=nb_of_points)
    # 'qb' is `dipy.segment.clustering.QuickBundles` object.
    qb = QuickBundles(threshold=args.dist_thresh, metric=MDF(feature))
    # 'clusters' is `dipy.segment.clustering.ClusterMap` object.
    clusters = qb.cluster(tracts)
    centroids = clusters.centroids
    print('    --- done. Number of centroids:', len(centroids))
    print('              Number of points per tract:', nb_of_points)
    print('Cluster sizes:', list(map(len, clusters)))

    #####################################
    # Saving                            #
    #####################################
    print('Saving...')
    out_format = tc.detect_format(args.output_name)
    # Copy the header so the one attached to the input file is not modified.
    qb_header = hdr.copy()
    qb_header[Header.NB_FIBERS] = len(centroids)
    out_centroids = out_format.create(args.output_name,
                                      qb_header,
                                      anatFile=args.ref_anat_name)
    out_centroids += list(centroids)
    out_centroids.close()

    print('    --- done.')
    #converting trk file to vtk file
    input_anatomy_ref = '/home/bao/tiensy/Lauren_registration/data_compare_mapping/anatomy/' + subj + '_data_brain.nii.gz'  
    #it is a link to '/home/bao/Personal/PhD_at_Trento/Research/ALS_Nivedita_Bao/Segmentation_CST_francesca/' + subj + '/DTI/data_brain.nii.gz'
    #cst_vtk_fname = '/home/bao/tiensy/Tractography_Mapping/data/trackvis_tractography/ROI_seg_tvis/ROI_seg_tvis_native/' + subj + '_corticospinal_R_tvis.vtk'
    cst_ext_vtk_fname = '/home/bao/tiensy/Tractography_Mapping/data/trackvis_tractography/ROI_seg_tvis/ROI_seg_tvis_native/' + subj + '_cst_R_tvis_ext.vtk'    
   
        
    in_format = str(".trk")
    out_format = ".vtk"
    #input_file = cst_trk_fname
    #output_file = cst_vtk_fname
    
    input_file = cst_ext_trk_fname
    output_file = cst_ext_vtk_fname
    
    input_format = tractconverter.detect_format(input_file)
    in_put = input_format(input_file, input_anatomy_ref)
    out_put = tractconverter.FORMATS['vtk'].create(output_file, in_put.hdr, input_anatomy_ref)
    tractconverter.convert(in_put, out_put)
    print "Done", output_file
    
    
    
'''

input_path =  '/home/bao/tiensy/Lauren_registration/data_compare_mapping/tractography/'
output_path = '/home/bao/tiensy/Lauren_registration/data_compare_mapping/tractography/'



for id_obj in np.arange(len(sub)):
def main():
    """Extract connectivity matrices and per-connection features.

    Command-line entry point. The steps are:

    1. Load the parcellation, the tracts and the FA / MD images.
    2. Keep streamlines with more than one point whose length lies strictly
       between ``args.minlen`` and ``args.maxlen``.
    3. Keep the requested labels, dilate them and reduce their id range.
    4. Build the connectivity matrix and the streamline grouping.
    5. For every pair of regions, downsample the streamlines and drop those
       belonging to small QuickBundles clusters.
    6. Save the count matrices, streamlines, label images and the FA, MD,
       connected-volume and fiber-length features as .mat / .nii.gz files.

    Missing input files and invalid tract coordinates are reported through
    ``parser.error``.
    """
    parser = buildArgsParser()
    args = parser.parse_args()

    required_files = (
        ("Tracts file", args.tracts),
        ("Label file", args.aparc),
        ("Requested region file", args.labels),
        ("Freesurfer LUT file", args.lut),
        ("FA Image file", args.faimage),
        ("MD Image file", args.mdimage),
    )
    for description, path in required_files:
        if not os.path.isfile(path):
            parser.error("{0}: {1} does not exist.".format(description, path))

    # Validate that tracts can be processed
    if not validate_coordinates(args.aparc, args.tracts, nifti_compliant=True):
        parser.error("The tracts file contains points that are invalid.\n" +
                     "Use the remove_invalid_coordinates.py script to clean.")

    # Load label image
    labels_img = nib.load(args.aparc)
    full_labels = labels_img.get_data().astype('int')

    # Load fibers
    tract_format = tc.detect_format(args.tracts)
    tract = tract_format(args.tracts, args.aparc)

    affine = compute_affine_for_dipy_functions(args.aparc, args.tracts)

    # Load FA and MD images
    fa_img = nib.load(args.faimage)
    fa_data = fa_img.get_data()

    md_img = nib.load(args.mdimage)
    md_data = md_img.get_data()

    # ========= processing streamlines =================
    fiberlen_range = np.asarray([args.minlen, args.maxlen])

    streamlines = [t for t in tract]
    # Single-argument print(...) calls behave the same on Python 2 and 3.
    print("Subject " + args.sub_id + " has " + str(
        len(streamlines)) + " raw streamlines.")

    f_streamlines = []  # filtered streamlines
    for sl in streamlines:
        # Avoid streamlines having only one point, as they crash the
        # Dipy connectivity matrix function.
        if sl.shape[0] > 1:
            flen = length(sl)
            # Keep fibers whose length is strictly inside the requested range.
            if fiberlen_range[0] < flen < fiberlen_range[1]:
                f_streamlines.append(sl)

    print("Subject " + args.sub_id + " has " + str(
        len(f_streamlines)) + " streamlines with lengths between " + str(
            args.minlen) + " and " + str(args.maxlen) + ".")

    # ============= process the parcellation =====================
    dilation_para = np.array([args.dilation_dist, args.dilation_windsize])

    # Compute the mapping from label name to label id
    label_id_mapping = compute_labels_map(args.lut)

    # Find which labels were requested by the user.
    requested_labels_mapping = compute_requested_labels(
        args.labels, label_id_mapping)

    # Filter to keep only needed ones
    filtered_labels = np.zeros(full_labels.shape, dtype='int')
    for label_val in requested_labels_mapping:
        label_mask = full_labels == label_val
        # Report requested labels that are absent from the parcellation.
        if not np.any(label_mask):
            print(label_val)
            print(requested_labels_mapping[label_val])

        filtered_labels[label_mask] = label_val

    # Cortex band dilation
    dilated_labels = cortexband_dilation_wm(filtered_labels, full_labels,
                                            dilation_para)

    # Reduce the range of labels to avoid a sparse matrix,
    # because the ids of labels can range from 0 to the 12000's.
    reduced_labels, _ = dpu.reduce_labels(filtered_labels)
    reduced_dilated_labels, _ = dpu.reduce_labels(dilated_labels)

    # Compute connectivity matrix and extract the fibers
    M, grouping = nconnectivity_matrix(f_streamlines,
                                       reduced_dilated_labels,
                                       fiberlen_range,
                                       args.cnpoint,
                                       affine=affine,
                                       symmetric=True,
                                       return_mapping=True,
                                       mapping_as_streamlines=True)

    Msize = len(M)
    # Row / column 0 is dropped from every saved matrix.
    CM_before_outlierremove = M[1:, 1:]
    nstream_bf = np.sum(CM_before_outlierremove)
    print(args.sub_id + ' ' + str(
        nstream_bf
    ) + ' streamlines in the connectivity matrix before outlier removal.')

    # ===================== process the streamlines =============
    print('Processing streamlines to remove outliers ..............')

    outlier_para = 3  # clusters smaller than this are treated as outliers
    average_thrd = 8  # QuickBundles distance threshold

    M_after_ourlierremove = np.zeros((Msize, Msize))
    cell_streamlines = []
    cell_id = []
    for i in range(1, Msize):
        for j in range(i + 1, Msize):
            tmp_streamlines = list(grouping[i, j])
            # Downsample every streamline of the cell to 100 points.
            tmp_streamlines_downsampled = [
                downsample(s, 100) for s in tmp_streamlines
            ]
            # Remove outliers; we need to rewrite the QuickBundle method to
            # speed up this process.
            qb = QuickBundles(threshold=average_thrd)
            clusters = qb.cluster(tmp_streamlines_downsampled)
            outlier_clusters = clusters < outlier_para  # small clusters
            nonoutlier_clusters = clusters[np.logical_not(outlier_clusters)]

            tmp_nonoutlier_index = []
            for tmp_cluster in nonoutlier_clusters:
                tmp_nonoutlier_index.extend(tmp_cluster.indices)

            clean_streamline_downsampled = [
                tmp_streamlines_downsampled[ind]
                for ind in tmp_nonoutlier_index
            ]
            cell_streamlines.append(clean_streamline_downsampled)
            cell_id.append([i, j])
            M_after_ourlierremove[i, j] = len(clean_streamline_downsampled)

    CM_after_ourlierremove = M_after_ourlierremove[1:, 1:]
    nstream_af = np.sum(CM_after_ourlierremove)
    print(args.sub_id + ' ' + str(
        nstream_af
    ) + ' streamlines in the connectivity matrix after outlier removal.')

    # Save streamlines and count matrices
    out_prefix = args.sub_id + "_" + args.pre
    cmCountMatrix_fname = out_prefix + "_cm_count_raw.mat"
    cmCountMatrix_processed_fname = out_prefix + "_cm_count_processed.mat"
    cmStreamlineMatrix_fname = out_prefix + "_cm_streamlines.mat"
    reduced_labels_fname = out_prefix + "_reduced_labels.nii.gz"
    dilated_labels_fname = out_prefix + "_dilated_labels.nii.gz"
    RoiInfo_fname = out_prefix + "_RoiInfo.mat"

    # Save the raw and processed count matrices
    sio.savemat(cmCountMatrix_fname, {'cm': CM_before_outlierremove})
    sio.savemat(cmCountMatrix_processed_fname, {'cm': CM_after_ourlierremove})

    # Save the streamline matrix
    sio.savemat(cmStreamlineMatrix_fname, {'slines': cell_streamlines})
    sio.savemat(RoiInfo_fname, {'ROIinfo': cell_id})
    print(args.sub_id + 'cell_streamlines.mat, ROIinfo.mat has been saved')

    # NOTE(review): the file is named "*_reduced_labels" but holds the
    # filtered (non-reduced) labels -- confirm this is intended.
    filtered_labels_img = nib.Nifti1Image(filtered_labels,
                                          labels_img.get_affine(),
                                          labels_img.get_header())
    nib.save(filtered_labels_img, reduced_labels_fname)
    print(args.sub_id + 'filtered labels have saved')

    dilated_labels_img = nib.Nifti1Image(dilated_labels,
                                         labels_img.get_affine(),
                                         labels_img.get_header())
    nib.save(dilated_labels_img, dilated_labels_fname)
    print(args.sub_id + 'dilated labels have saved')

    # ============ process the streamlines and extract features ============
    cm_fa_curve = fa_extraction_use_cellinput(cell_streamlines,
                                              cell_id,
                                              fa_data,
                                              Msize,
                                              affine=affine)
    tmp_cm_fa_mean, tmp_cm_fa_max, _ = fa_mean_extraction(cm_fa_curve, Msize)

    # Extract MD values along the streamlines (same helpers as for FA)
    cm_md_curve = fa_extraction_use_cellinput(cell_streamlines,
                                              cell_id,
                                              md_data,
                                              Msize,
                                              affine=affine)
    tmp_cm_md_mean, tmp_cm_md_max, _ = fa_mean_extraction(cm_md_curve, Msize)

    # Connected surface area: extract the connective volume ratio
    (tmp_cm_volumn,
     tmp_cm_volumn_ratio) = rois_connectedvol_cellinput(reduced_labels,
                                                        Msize,
                                                        cell_streamlines,
                                                        cell_id,
                                                        affine=affine)

    # Fiber length
    tmp_connectcm_len = rois_fiberlen_cellinput(Msize, cell_streamlines)

    # Save cm features
    cm_md_mean = tmp_cm_md_mean[1:, 1:]
    cm_md_max = tmp_cm_md_max[1:, 1:]

    cm_fa_mean = tmp_cm_fa_mean[1:, 1:]
    cm_fa_max = tmp_cm_fa_max[1:, 1:]

    cm_volumn = tmp_cm_volumn[1:, 1:]
    cm_volumn_ratio = tmp_cm_volumn_ratio[1:, 1:]

    connectcm_len = tmp_connectcm_len[1:, 1:]

    sio.savemat(args.pre + "_cm_processed_mdmean_100.mat",
                {'cm_mdmean': cm_md_mean})
    sio.savemat(args.pre + "_cm_processed_mdmax_100.mat",
                {'cm_mdmax': cm_md_max})
    sio.savemat(args.pre + "_cm_processed_famean_100.mat",
                {'cm_famean': cm_fa_mean})
    sio.savemat(args.pre + "_cm_processed_famax_100.mat",
                {'cm_famax': cm_fa_max})
    sio.savemat(args.pre + "_cm_processed_volumn_100.mat",
                {'cm_volumn': cm_volumn})
    sio.savemat(args.pre + "_cm_processed_volumn_ratio_100.mat",
                {'cm_volumn_ratio': cm_volumn_ratio})
    # The lengths used to be written to the volume-ratio file name, which
    # silently overwrote the matrix saved just above.
    sio.savemat(args.pre + "_cm_processed_len_100.mat",
                {'cm_len': connectcm_len})

    # Save the diffusion functions matrix
    cell_fa = []
    for i in range(1, Msize):
        for j in range(i + 1, Msize):
            cell_fa.append(list(cm_fa_curve[i, j]))

    sio.savemat(args.pre + "_cm_processed_sfa_100.mat", {'sfa': cell_fa})
    print('cell_fa.mat, fa_roiinfo.mat have been saved')

    cell_md = []
    for i in range(1, Msize):
        for j in range(i + 1, Msize):
            cell_md.append(list(cm_md_curve[i, j]))

    sio.savemat(args.pre + "_cm_processed_smd_100.mat", {'smd': cell_md})
예제 #35
0
def main():
    """Compute a label-to-label connectivity matrix from a tractogram.

    Command-line entry point. The requested labels are extracted from the
    parcellation and their id range reduced. The connectivity matrix of the
    streamlines is then computed and saved to ``args.out_matrix`` (.npy)
    without its first row / column. A pickled mapping from row number to
    label information is saved to ``args.out_row_map`` (.pkl).

    Missing inputs, existing outputs without the overwrite flag, wrong output
    extensions and invalid tract coordinates are reported through
    ``parser.error``.
    """
    parser = buildArgsParser()
    args = parser.parse_args()

    if not os.path.isfile(args.tracts):
        parser.error("Tracts file: {0} does not exist.".format(args.tracts))

    # TODO check scilpy supports

    if not os.path.isfile(args.aparc):
        parser.error("Label file: {0} does not exist.".format(args.aparc))

    if not os.path.isfile(args.labels):
        parser.error("Requested region file: {0} does not exist.".format(
            args.labels))

    if not os.path.isfile(args.lut):
        parser.error("Freesurfer LUT file: {0} does not exist.".format(
            args.lut))

    for out_path in (args.out_matrix, args.out_row_map):
        if os.path.isfile(out_path) and not args.force_overwrite:
            parser.error(
                "Output: {0} already exists. To overwrite, use -f.".format(
                    out_path))

    if os.path.splitext(args.out_matrix)[1] != ".npy":
        parser.error("Connectivity matrix must be saved in a .npy file.")

    if os.path.splitext(args.out_row_map)[1] != ".pkl":
        parser.error("Mapping must be saved in a .pkl file.")

    # Validate that tracts can be processed
    if not validate_coordinates(args.aparc, args.tracts, nifti_compliant=True):
        parser.error("The tracts file contains points that are invalid.\n" +
                     "Use the remove_invalid_coordinates.py script to clean.")

    # Load labels
    labels_img = nib.load(args.aparc)
    full_labels = labels_img.get_data().astype('int')

    # Compute the mapping from label name to label id
    label_id_mapping = compute_labels_map(args.lut)

    # Find which labels were requested by the user.
    requested_labels_mapping = compute_requested_labels(
        args.labels, label_id_mapping)

    # Filter to keep only needed ones
    filtered_labels = np.zeros(full_labels.shape, dtype='int')
    for label_val in requested_labels_mapping:
        filtered_labels[full_labels == label_val] = label_val

    # Reduce the range of labels to avoid a sparse matrix,
    # because the ids of labels can range from 0 to the 12000's.
    reduced_labels, labels_lut = dpu.reduce_labels(filtered_labels)

    # Load tracts
    tract_format = tc.detect_format(args.tracts)
    tract = tract_format(args.tracts, args.aparc)

    # Avoid streamlines having only one point, as they crash the
    # Dipy connectivity matrix function.
    f_streamlines = [sl for sl in tract if sl.shape[0] > 1]

    # Compute affine
    affine = compute_affine_for_dipy_functions(args.aparc, args.tracts)

    # Compute matrix
    M = dpu.connectivity_matrix(f_streamlines,
                                reduced_labels,
                                affine=affine,
                                symmetric=True,
                                return_mapping=False,
                                mapping_as_streamlines=False)
    # Remove background connectivity
    M = M[1:, 1:]

    # Save needed files
    np.save(args.out_matrix, np.array(M))

    # Compute the mapping between row numbers, labels and ids.
    sorted_lut = sorted(labels_lut)
    row_name_map = {}
    # Skip first for BG
    for row_idx, lab_val in enumerate(sorted_lut[1:]):
        # Find the associated Freesurfer names
        label_info = requested_labels_mapping[lab_val]

        # Find the mean y position of the label to be able to spatially sort.
        positions = np.where(full_labels == lab_val)
        mean_y = np.mean(positions[1])

        row_name_map[row_idx] = {
            'free_name': label_info['free_name'],
            'lut_name': label_info['lut_name'],
            'free_label': lab_val,
            'mean_y_pos': mean_y
        }

    # pickle produces bytes: the file must be opened in binary mode, text
    # mode fails on Python 3.
    with open(args.out_row_map, 'wb') as f:
        pickle.dump(row_name_map, f)
예제 #36
0
def main():
    """Apply a registration matrix to every streamline of a tracts file.

    Command-line entry point. Each point is flipped along x within the extent
    of the initial anatomy, transformed with the rotation and translation read
    from ``args.matrix_name``, then flipped back along x. The result is
    written to ``args.output_name`` with the header of the input tracts and
    ``args.final_anat_name`` as the anatomy.

    Missing input files are reported through ``parser.error``.
    """
    parser = buildArgsParser()
    args = parser.parse_args()

    print( 'Fiber registration...')

    #####################################
    # Checking if the files exist       #
    #####################################
    for myFile in [args.init_anat_name, args.final_anat_name, args.matrix_name, args.tracts_name]:
        if not os.path.isfile(myFile):
            parser.error('"{0}" must be a file!'.format(myFile))

    if os.path.exists(args.output_name):
        print (args.output_name, " already exist and will be overwritten.")

    #####################################
    # Loading anatomies                 #
    # Applying voxel sizes              #
    #####################################
    init_anat_img = nib.load(args.init_anat_name)
    voxel_sizes = np.array(init_anat_img.get_header()['pixdim'][1:4])
    # Only the three spatial dimensions are used, so a 4D reference image no
    # longer breaks the product with the 3 voxel sizes.
    init_shape = np.array(init_anat_img.shape[:3]) * voxel_sizes
    x_extent = init_shape[0]

    #####################################
    # Loading tracts                    #
    #####################################
    tracts_format = tc.detect_format(args.tracts_name)
    tracts_file = tracts_format(args.tracts_name, anatFile=args.init_anat_name)
    hdr = tracts_file.hdr
    tracts_init = list(tracts_file)

    #####################################
    # Loading the matrix                #
    #####################################
    big_matrix = np.loadtxt(args.matrix_name)
    rot = big_matrix[0:3, 0:3]
    trans = big_matrix[0:3, 3]

    #####################################
    # Registration                      #
    #####################################
    tracts_final = []
    for tract in tracts_init:
        # Work on a float64 copy, one row per point, so the loaded
        # streamlines stay untouched.
        points = np.array(tract, dtype=np.float64).reshape(-1, 3)
        points[:, 0] = x_extent - points[:, 0]  # flipping
        # Row-vector form of `rot . p + trans`, applied to all points at once.
        registered = np.dot(points, rot.T) + trans
        registered[:, 0] = x_extent - registered[:, 0]  # flipping back
        tracts_final.append(registered)

    #####################################
    # Saving                            #
    #####################################
    out_format = tc.detect_format(args.output_name)
    out_tracts = out_format.create(args.output_name, hdr, anatFile=args.final_anat_name)
    out_tracts += tracts_final

    out_tracts.close()
    print('...Done')
def _save_map(data, affine, ref, zooms, filename):
    """Write ``data`` to ``filename`` as a NIfTI image.

    Parameters
    ----------
    data : ndarray
        Voxel data to store.
    affine : ndarray
        Affine handed to ``nib.Nifti1Image``.
    ref : nibabel image
        Reference image; its header qform and sform are copied into the
        header of the new image.
    zooms : sequence of float
        Voxel sizes passed to ``set_zooms`` (one entry per data axis).
    filename : str
        Output path.
    """
    img = nib.Nifti1Image(data, affine)
    header = img.get_header()
    ref_header = ref.get_header()
    header.set_zooms(zooms)
    # NOTE(review): the reference q/sform are copied *after* the zooms are
    # set, exactly as the original code did. How nibabel reconciles these
    # with ``affine`` at save time has not been verified here.
    header.set_qform(ref_header.get_qform())
    header.set_sform(ref_header.get_sform())
    img.to_filename(filename)


def main():
    """Compute track-density (TDI), average-path-length (APM) and
    colour direction-encoded (cdec) maps from a tractogram.

    Every streamline point is binned into a grid whose voxel size is
    ``args.res``. For each voxel the script accumulates the number of
    points (TDI), the summed length of the streamlines those points belong
    to (divided by the TDI at the end to give the APM) and the absolute,
    normalised end-to-end direction of those streamlines (cdec, scaled to
    0-255 and stored as uint8).

    Voxels that no streamline point falls into are written as 0 in all
    three maps (the division by the density is skipped there instead of
    producing NaN).

    Command-line arguments come from ``buildArgsParser()``; nothing is
    returned, the requested maps are written to disk.
    """
    parser = buildArgsParser()
    args = parser.parse_args()

    if args.isVerbose:
        logging.basicConfig(level=logging.DEBUG)

    # Without --not_all every map is produced; fill in default file names
    # for the ones the user did not name explicitly.
    if not args.not_all:
        args.tdi_file = args.tdi_file or 'tdi.nii.gz'
        args.apm_file = args.apm_file or 'apm.nii.gz'
        args.cdec_file = args.cdec_file or 'cdec.nii.gz'

    arglist = [args.tdi_file, args.apm_file, args.cdec_file]
    if args.not_all and not any(arglist):
        parser.error('When using --not_all, you need to specify at least ' +
                     'one file to output.')
    for out in arglist:
        # With --not_all some outputs may be unset; os.path.isfile() raises
        # a TypeError on None, so only check the names actually given.
        if out and os.path.isfile(out):
            if args.overwrite:
                logging.info('Overwriting "{0}".'.format(out))
            else:
                parser.error(
                    '"{0}" already exists! Use -f to overwrite it.'.format(
                        out))

    ref = nib.load(args.ref_file)
    # NOTE(review): only the first voxel dimension is read, so the
    # reference is presumably expected to be isotropic -- confirm.
    ref_res = ref.get_header()['pixdim'][1]
    up_factor = ref_res / args.res
    data_shape = [int(d) for d in np.array(ref.shape) * up_factor]

    logging.info("Reference resolution: " + str(ref_res))
    logging.info("Reference shape: " + str(ref.shape))
    logging.info("Target resolution: " + str(args.res))
    logging.info("Target shape: " + str(data_shape))

    cdec_map = np.zeros(data_shape + [3], dtype='float32')
    tdi_map = np.zeros(data_shape, dtype='float32')
    apm_map = np.zeros(data_shape, dtype='float32')

    tract_format = tc.detect_format(args.tract_file)
    tract = tract_format(args.tract_file)
    streamlines = [i for i in tract]
    nb_streamlines = len(streamlines)

    for i, streamline in enumerate(streamlines):
        if not i % 10000:
            logging.info(str(i) + "/" + str(nb_streamlines))

        streamline = np.asarray(streamline)
        if len(streamline) == 0:
            # Nothing to bin, and streamline[0] below would fail.
            continue

        streamline_length = length(streamline)

        # Absolute, unit-norm vector between the two end points; left at
        # zero when both end points coincide.
        dec_vec = np.array(streamline[0] - streamline[-1])
        dec_vec_norm = np.linalg.norm(dec_vec)
        if dec_vec_norm > 0:
            dec_vec = np.abs(dec_vec / dec_vec_norm)
        else:
            dec_vec[:] = 0

        # floor() rather than a plain integer cast: a cast truncates toward
        # zero, which would bin points lying in (-1, 0) into voxel 0 even
        # though they are outside the grid.
        voxels = np.floor(streamline / args.res).astype('int32')
        inside = np.all((voxels >= 0) & (voxels < data_shape), axis=1)

        # Explicit loop on purpose: a voxel crossed by several points of
        # the same streamline must be incremented once per point.
        for ind in map(tuple, voxels[inside]):
            tdi_map[ind] += 1
            apm_map[ind] += streamline_length
            cdec_map[ind] += dec_vec

    # Divide the sum of streamline lengths by the streamline density, only
    # where the density is non-zero (avoids 0/0 -> NaN).
    visited = tdi_map > 0
    apm_map[visited] /= tdi_map[visited]

    # Normalise the cdec map to unit vectors scaled to 0-255, again only
    # where the norm is non-zero.
    cdec_norm = np.sqrt((cdec_map * cdec_map).sum(axis=3))
    coloured = cdec_norm > 0
    cdec_map[coloured] = (cdec_map[coloured] /
                          cdec_norm[coloured][:, np.newaxis] * 255)

    # Work on a copy so the array returned by the reference image is not
    # modified in place.
    affine = np.array(ref.get_affine(), copy=True)
    # NOTE(review): overwriting the diagonal discards any axis flip or
    # rotation stored there -- kept as in the original, confirm intent.
    affine[0][0] = affine[1][1] = affine[2][2] = args.res

    zooms = [args.res, args.res, args.res]

    if args.tdi_file:
        _save_map(tdi_map, affine, ref, zooms, args.tdi_file)

    if args.apm_file:
        _save_map(apm_map, affine, ref, zooms, args.apm_file)

    if args.cdec_file:
        _save_map(cdec_map.astype('uint8'), affine, ref, zooms + [1],
                  args.cdec_file)
def _filter_streamlines_by_length(streamlines, min_len, max_len):
    """Return the streamlines with more than one point whose length lies
    strictly between ``min_len`` and ``max_len``."""
    kept = []
    for sl in streamlines:
        # Avoid streamlines having only one point, as they crash the
        # Dipy connectivity matrix function.
        if sl.shape[0] > 1 and min_len < length(sl) < max_len:
            kept.append(sl)
    return kept


def _remove_outlier_streamlines(grouping, Msize, cluster_threshold,
                                min_cluster_size):
    """Downsample and de-noise the streamlines of every label pair.

    For each pair ``(i, j)`` with ``1 <= i < j < Msize`` the streamlines in
    ``grouping[i, j]`` are downsampled to 100 points and clustered with
    QuickBundles; streamlines belonging to clusters smaller than
    ``min_cluster_size`` are dropped.

    Returns
    -------
    cell_streamlines : list
        One list of kept (downsampled) streamlines per pair, in pair order.
    cell_id : list
        The ``[i, j]`` pair matching each entry of ``cell_streamlines``.
    counts : ndarray, shape (Msize, Msize)
        Number of kept streamlines per pair; only entries with ``i < j``
        are filled.
    """
    counts = np.zeros((Msize, Msize))
    cell_streamlines = []
    cell_id = []
    for i in range(1, Msize):
        for j in range(i + 1, Msize):
            downsampled = [downsample(s, 100) for s in grouping[i, j]]

            # remove outliers, we need to rewrite the QuickBundle method to
            # speed up this process
            qb = QuickBundles(threshold=cluster_threshold)
            clusters = qb.cluster(downsampled)
            outlier_clusters = clusters < min_cluster_size  # small clusters
            kept_clusters = clusters[np.logical_not(outlier_clusters)]

            kept_indices = []
            for cluster in kept_clusters:
                kept_indices.extend(cluster.indices)

            kept = [downsampled[ind] for ind in kept_indices]
            cell_streamlines.append(kept)
            cell_id.append([i, j])
            counts[i, j] = len(kept)
    return cell_streamlines, cell_id, counts


def _extract_and_save_features(out_prefix, cell_streamlines, cell_id,
                               fa_data, md_data, reduced_labels, Msize,
                               affine):
    """Compute the per-connection feature matrices and save each one.

    Each matrix is written, without its first row and column, to
    ``out_prefix + "_cm_processed_<feature>_100.mat"``.

    Returns
    -------
    cm_fa_curve, cm_md_curve
        The FA and MD values returned by ``fa_extraction_use_cellinput``,
        needed by the caller to save the per-streamline curves.
    """
    cm_fa_curve = fa_extraction_use_cellinput(cell_streamlines,
                                              cell_id,
                                              fa_data,
                                              Msize,
                                              affine=affine)
    fa_mean, fa_max, _ = fa_mean_extraction(cm_fa_curve, Msize)

    # extract MD values along the streamlines
    cm_md_curve = fa_extraction_use_cellinput(cell_streamlines,
                                              cell_id,
                                              md_data,
                                              Msize,
                                              affine=affine)
    md_mean, md_max, _ = fa_mean_extraction(cm_md_curve, Msize)

    # connected surface area
    # extract the connective volume ratio
    volumn, volumn_ratio = rois_connectedvol_cellinput(reduced_labels,
                                                       Msize,
                                                       cell_streamlines,
                                                       cell_id,
                                                       affine=affine)

    # fiber length
    fiber_len = rois_fiberlen_cellinput(Msize, cell_streamlines)

    # (file-name tag, key inside the .mat file, matrix)
    features = [
        ("mdmean", 'cm_mdmean', md_mean),
        ("mdmax", 'cm_mdmax', md_max),
        ("famean", 'cm_famean', fa_mean),
        ("famax", 'cm_famax', fa_max),
        ("volumn", 'cm_volumn', volumn),
        ("volumn_ratio", 'cm_volumn_ratio', volumn_ratio),
        ("fiberlen", 'cm_len', fiber_len),
    ]
    for tag, key, matrix in features:
        sio.savemat(out_prefix + "_cm_processed_" + tag + "_100.mat",
                    {key: matrix[1:, 1:]})

    return cm_fa_curve, cm_md_curve


def _collect_curves(curves, n_rows, Msize):
    """List ``list(curves[i, j])`` for ``1 <= i < n_rows``, ``i < j < Msize``."""
    return [
        list(curves[i, j]) for i in range(1, n_rows)
        for j in range(i + 1, Msize)
    ]


def _save_label_image(labels, ref_img, filename):
    """Save ``labels`` as a NIfTI image using the affine/header of ``ref_img``."""
    img = nib.Nifti1Image(labels, ref_img.get_affine(), ref_img.get_header())
    nib.save(img, filename)


def main():
    """Build a structural connectome from a tractogram and a parcellation.

    Steps, in order:

    1. check that every input file exists and that the tract coordinates
       are valid;
    2. keep streamlines whose length is between ``args.minlen`` and
       ``args.maxlen``;
    3. overlay the requested labels and the brain stem (relabelled 99) from
       the original parcellation onto the dilated one, then reduce the
       label range;
    4. compute the connectivity matrix and the per-connection streamlines;
    5. remove outlier streamlines per connection with QuickBundles;
    6. save count matrices, streamlines, label image and FA/MD/volume/
       length features, either for the whole brain
       (``args.saving_indicator == 1``) or for the connections starting
       from the first ``len(sub_cortical_labels)`` labels
       (``args.saving_indicator == 0``).

    Arguments come from ``buildArgsParser()``; nothing is returned.
    """
    parser = buildArgsParser()
    args = parser.parse_args()

    required_files = [
        (args.tracts, "Tracts file: {0} does not exist."),
        (args.org_aparc, "Original label file: {0} does not exist."),
        (args.dilated_aparc, "Dilated label file: {0} does not exist."),
        (args.subcortical_labels,
         "Requested region file: {0} does not exist."),
        (args.lut, "Freesurfer LUT file: {0} does not exist."),
        (args.faimage, "FA Image file: {0} does not exist."),
        (args.mdimage, "MD Image file: {0} does not exist."),
    ]
    for path, message in required_files:
        if not os.path.isfile(path):
            parser.error(message.format(path))

    # Validate that tracts can be processed
    if not validate_coordinates(
            args.org_aparc, args.tracts, nifti_compliant=True):
        parser.error("The tracts file contains points that are invalid.\n" +
                     "Use the remove_invalid_coordinates.py script to clean.")

    # Load label images
    org_labels_img = nib.load(args.org_aparc)
    org_labels_data = org_labels_img.get_data().astype('int')

    dilated_labels_img = nib.load(args.dilated_aparc)
    dilated_labels_data = dilated_labels_img.get_data().astype('int')

    # Load fibers
    tract_format = tc.detect_format(args.tracts)
    tract = tract_format(args.tracts, args.org_aparc)
    affine = compute_affine_for_dipy_functions(args.org_aparc, args.tracts)

    # Load FA and MD images
    fa_data = nib.load(args.faimage).get_data()
    md_data = nib.load(args.mdimage).get_data()

    # ========= processing streamlines =================
    fiberlen_range = np.asarray([args.minlen, args.maxlen])

    streamlines = list(tract)
    print("Subject " + args.sub_id + " has " + str(len(streamlines)) +
          " streamlines.")

    f_streamlines = _filter_streamlines_by_length(
        streamlines, fiberlen_range[0], fiberlen_range[1])

    # Report the number actually kept (the previous code printed one less).
    print("Subject " + args.sub_id + " has " + str(len(f_streamlines)) +
          " streamlines after filtering.")

    # ============= process the parcellation =====================

    # Compute the mapping from label name to label id
    label_id_mapping = compute_labels_map(args.lut)

    # Find which labels were requested by the user.
    requested_labels_mapping = compute_requested_labels(
        args.subcortical_labels, label_id_mapping)

    # Increase aparc_filtered_labels with subcortical regions
    # 17 LH_Hippocampus
    # 53 RH_Hippocampus
    # 11 LH_Caudate
    # 50 RH_Caudate
    # 12 LH_Putamen
    # 51 RH_Putamen
    # 13 LH_Pallidum
    # 52 RH_Pallidum
    # 18 LH_Amygdala
    # 54 RH_Amygdala
    # 26 LH_Accumbens
    # 58 RH_Accumbens
    # 10 LH_Thalamus-Proper
    # 49 RH_Thalamus-Proper
    # 4 LH_Lateral-Ventricle
    # 43 RH_Lateral-Ventricle
    # 8 LH_Cerebellum-Cortex
    # 47 RH_Cerebellum-Cortex
    #
    # 16 _Brain-Stem (# 7,8 LH_Cerebellum) (# 41 RH_Cerebellum)

    sub_cortical_labels = [
        17, 53, 11, 50, 12, 51, 13, 52, 18, 54, 26, 58, 10, 49, 4, 43, 8, 47
    ]  # 16
    Brain_Stem_cerebellum = [16]  #1

    # NOTE: this is the same array as dilated_labels_data, which is
    # therefore modified in place (it is not used again afterwards).
    aparc_filtered_labels = dilated_labels_data
    for label_val in requested_labels_mapping:
        label_mask = org_labels_data == label_val
        if not label_mask.any():
            # Requested label absent from the original parcellation.
            print(label_val)
            print(requested_labels_mapping[label_val])

        aparc_filtered_labels[label_mask] = label_val

    for brain_stem_id in Brain_Stem_cerebellum:
        brain_stem_mask = org_labels_data == brain_stem_id
        if not brain_stem_mask.any():
            print('no labels of ')
            print(brain_stem_id)
        # let the brain stem's label be 99
        aparc_filtered_labels[brain_stem_mask] = 99

    # Reduce the range of labels to avoid a sparse matrix,
    # because the ids of labels can range from 0 to the 12000's.
    reduced_dilated_labels, labels_lut = dpu.reduce_labels(
        aparc_filtered_labels)

    # Compute connectivity matrix and extract the fibers
    M, grouping = nconnectivity_matrix(f_streamlines,
                                       reduced_dilated_labels,
                                       fiberlen_range,
                                       args.cnpoint,
                                       affine=affine,
                                       symmetric=True,
                                       return_mapping=True,
                                       mapping_as_streamlines=True,
                                       keepfiberinroi=True)

    Msize = len(M)
    # Row/column 0 is dropped from every saved matrix -- presumably the
    # background label; TODO confirm.
    CM_before_outlierremove = M[1:, 1:]
    print(args.sub_id + ' ' + str(np.sum(CM_before_outlierremove)) +
          ' streamlines in the connectivity matrix before outlier removal.')

    # ===================== process the streamlines =============
    print('Processing streamlines to remove outliers ..............')

    outlier_para = 3  # clusters with fewer streamlines are outliers
    average_thrd = 8  # QuickBundles threshold

    (cell_streamlines, cell_id,
     M_after_ourlierremove) = _remove_outlier_streamlines(
         grouping, Msize, average_thrd, outlier_para)

    CM_after_ourlierremove = M_after_ourlierremove[1:, 1:]
    print(args.sub_id + ' ' + str(np.sum(CM_after_ourlierremove)) +
          ' streamlines in the connectivity matrix after outlier removal.')

    #===================== save the data =======================

    if args.saving_indicator == 1:  # save the whole brain connectivity

        file_prefix = args.sub_id + "_" + args.pre + "_allbrain"
        feature_prefix = args.pre + "_allbrain"

        # save the raw and the outlier-filtered count matrices
        sio.savemat(file_prefix + "_cm_count_raw.mat",
                    {'cm': CM_before_outlierremove})
        sio.savemat(file_prefix + "_cm_count_processed.mat",
                    {'cm': CM_after_ourlierremove})

        # save the streamline matrix
        sio.savemat(file_prefix + "_cm_streamlines.mat",
                    {'slines': cell_streamlines})
        sio.savemat(file_prefix + "_RoiInfo.mat", {'ROIinfo': cell_id})
        print(args.sub_id + 'cell_streamlines.mat, ROIinfo.mat has been saved')

        # NOTE(review): the file is named "reduced_dilated_labels" but the
        # non-reduced aparc_filtered_labels array is what gets written, as
        # in the original code -- confirm which one is intended.
        _save_label_image(aparc_filtered_labels, org_labels_img,
                          file_prefix + "_reduced_dilated_labels.nii.gz")
        print(args.sub_id + 'all brain dilated labels have saved')

        # ============ process the streamlines and extract features ========
        cm_fa_curve, cm_md_curve = _extract_and_save_features(
            feature_prefix, cell_streamlines, cell_id, fa_data, md_data,
            reduced_dilated_labels, Msize, affine)

        # save the per-streamline FA and MD curves
        sfa_fname = feature_prefix + "_cm_processed_sfa_100.mat"
        sio.savemat(sfa_fname,
                    {'sfa': _collect_curves(cm_fa_curve, Msize, Msize)})
        print(sfa_fname + ' has been saved')

        smd_fname = feature_prefix + "_cm_processed_smd_100.mat"
        sio.savemat(smd_fname,
                    {'smd': _collect_curves(cm_md_curve, Msize, Msize)})
        print(smd_fname + ' has been saved')

    if args.saving_indicator == 0:
        # save the part of the connection: connection between subcortical
        # region

        Nsubcortical_reg = len(sub_cortical_labels) + 1  # should be 19

        file_prefix = args.sub_id + "_" + args.pre + "_partbrain_subcort"
        feature_prefix = args.pre + "_partbrain_subcort"

        # save the raw and the outlier-filtered count matrices
        sio.savemat(file_prefix + "_cm_count_raw.mat",
                    {'cm': CM_before_outlierremove})
        sio.savemat(file_prefix + "_cm_count_processed.mat",
                    {'cm': CM_after_ourlierremove})

        # NOTE(review): same naming mismatch as in the whole-brain branch;
        # aparc_filtered_labels (not the reduced labels) is written.
        _save_label_image(aparc_filtered_labels, org_labels_img,
                          file_prefix + "_reduced_dilated_labels.nii.gz")
        print(args.sub_id + ' all brain dilated labels have saved')

        # ============ process the streamlines and extract features ========
        # Features are computed from all connections; only the curves and
        # streamlines saved below are restricted to the subcortical rows.
        cm_fa_curve, cm_md_curve = _extract_and_save_features(
            feature_prefix, cell_streamlines, cell_id, fa_data, md_data,
            reduced_dilated_labels, Msize, affine)

        # save the per-streamline FA and MD curves
        sfa_fname = feature_prefix + "_cm_processed_sfa_100.mat"
        sio.savemat(
            sfa_fname,
            {'sfa': _collect_curves(cm_fa_curve, Nsubcortical_reg, Msize)})
        print(sfa_fname + ' has been saved.')

        # NOTE(review): this file uses "_partbrain" while every other file
        # of this branch uses "_partbrain_subcort". The name is kept
        # because downstream code may rely on it -- confirm and align.
        smd_fname = args.pre + "_partbrain" + "_cm_processed_smd_100.mat"
        sio.savemat(
            smd_fname,
            {'smd': _collect_curves(cm_md_curve, Nsubcortical_reg, Msize)})
        print(smd_fname + ' has been saved.')

        # save the streamline matrix
        # cell_streamlines is ordered by (i, j) with i ascending, so the
        # subcortical pairs are exactly its leading entries.
        subcortical_cell_id = [[i, j] for i in range(1, Nsubcortical_reg)
                               for j in range(i + 1, Msize)]
        subcortical_cell_streamlines = cell_streamlines[:len(
            subcortical_cell_id)]

        cmStreamlineMatrix_fname = file_prefix + "_cm_streamlines.mat"
        sio.savemat(cmStreamlineMatrix_fname,
                    {'slines': subcortical_cell_streamlines})
        sio.savemat(file_prefix + "_RoiInfo.mat",
                    {'ROIinfo': subcortical_cell_id})
        print(cmStreamlineMatrix_fname + ' has been saved')