def main():
    """Track one streamline per precomputed seed and save them as polydata.

    Steps, in order:

    1. Parse the command line and gather the tracking parameters in
       ``param``.
    2. Build the include/exclude mask (``ACT`` or ``CMC``), the spherical
       harmonic fields and the main and ``pft`` trackers.
    3. Load seed positions (``args.seed_points``) and seed directions
       (``args.seed_dir``) with ``np.load`` and move them from LPS
       coordinates into the space used for tracking.
    4. Call ``generate_streamline`` once per seed, map the streamlines back
       to LPS and write them to ``args.output_file``.
    5. Print the number of streamlines, their average length and the
       tracking time.

    Exits through ``parser.error`` on invalid options or when the output
    file already exists and ``-f`` was not given.
    """
    parser = buildArgsParser()
    args = parser.parse_args()
    param = {}

    # Default pft aperture when neither an angle nor a curvature is given.
    if args.pft_theta is None and args.pft_curvature is None:
        args.pft_theta = 20

    if not np.any([args.nt, args.npv, args.ns]):
        args.npv = 1

    # Maximum angle between two steps: explicit angle first, then curvature,
    # then an algorithm-dependent default.
    if args.theta is not None:
        theta = gm.math.radians(args.theta)
    elif args.curvature is not None and args.curvature > 0:
        # The explicit None test keeps this comparison valid on Python 3,
        # where ``None > 0`` raises TypeError.
        theta = get_max_angle_from_curvature(args.curvature, args.step_size)
    elif args.algo == 'prob':
        theta = gm.math.radians(20)
    else:
        theta = gm.math.radians(45)

    if args.pft_curvature is not None:
        pft_theta = get_max_angle_from_curvature(
            args.pft_curvature, args.step_size)
    else:
        pft_theta = gm.math.radians(args.pft_theta)

    # Command-line abbreviations -> interpolation names given to Dataset.
    interpolations = {'nn': 'nearest', 'tl': 'trilinear'}

    mask_interpolation = interpolations.get(args.mask_interp)
    if mask_interpolation is None:
        parser.error("--mask_interp has wrong value. See the help (-h).")
        return

    field_interpolation = interpolations.get(args.field_interp)
    if field_interpolation is None:
        parser.error("--sh_interp has wrong value. See the help (-h).")
        return

    param['random'] = args.random
    param['skip'] = args.skip
    param['algo'] = args.algo
    param['mask_interp'] = mask_interpolation
    param['field_interp'] = field_interpolation
    param['theta'] = theta
    param['sf_threshold'] = args.sf_threshold
    # The pft threshold falls back on the regular threshold when not given.
    if args.pft_sf_threshold is not None:
        param['pft_sf_threshold'] = args.pft_sf_threshold
    else:
        param['pft_sf_threshold'] = args.sf_threshold
    param['sf_threshold_init'] = args.sf_threshold_init
    param['step_size'] = args.step_size
    param['max_length'] = args.max_length
    param['min_length'] = args.min_length
    param['is_single_direction'] = args.is_single_direction
    param['nbr_seeds'] = args.nt if args.nt is not None else 0
    param['nbr_seeds_voxel'] = args.npv if args.npv is not None else 0
    param['nbr_streamlines'] = args.ns if args.ns is not None else 0
    # Lengths are converted to a number of steps, rounded up.
    param['max_no_dir'] = int(math.ceil(args.maxL_no_dir / param['step_size']))
    param['is_all'] = args.is_all
    param['is_act'] = args.is_act
    param['theta_pft'] = pft_theta
    if args.not_is_pft:
        param['nbr_particles'] = 0
        param['back_tracking'] = 0
        param['front_tracking'] = 0
    else:
        param['nbr_particles'] = args.nbr_particles
        param['back_tracking'] = int(
            math.ceil(args.back_tracking / args.step_size))
        param['front_tracking'] = int(
            math.ceil(args.front_tracking / args.step_size))
    param['nbr_iter'] = param['back_tracking'] + param['front_tracking']
    param['mmap_mode'] = None if args.isLoadData else 'r'

    if args.isVerbose:
        logging.basicConfig(level=logging.DEBUG)

    logging.debug('Tractography parameters:\n{0}'.format(param))

    if os.path.isfile(args.output_file):
        if args.isForce:
            logging.info('Overwriting "{0}".'.format(args.output_file))
        else:
            parser.error(
                '"{0}" already exists! Use -f to overwrite it.'
                .format(args.output_file))

    # The include image is loaded once: it feeds the mask and, further down,
    # provides the affine used to transform the seeds.
    include_img = nib.load(args.map_include_file)
    include_dataset = Dataset(include_img, param['mask_interp'])
    exclude_dataset = Dataset(
        nib.load(args.map_exclude_file), param['mask_interp'])
    mask_class = ACT if param['is_act'] else CMC
    mask = mask_class(include_dataset, exclude_dataset,
                      param['step_size'] / include_dataset.size[0])

    dataset = Dataset(nib.load(args.sh_file), param['field_interp'])
    field = SphericalHarmonicField(
        dataset, args.basis, param['sf_threshold'],
        param['sf_threshold_init'], param['theta'])

    if args.algo == 'det':
        tracker = deterministicMaximaTracker(field, param['step_size'])
    elif args.algo == 'prob':
        tracker = probabilisticTracker(field, param['step_size'])
    else:
        parser.error("--algo has wrong value. See the help (-h).")
        return

    # The pft tracker is always probabilistic and uses its own threshold
    # and aperture.
    pft_field = SphericalHarmonicField(
        dataset, args.basis, param['pft_sf_threshold'],
        param['sf_threshold_init'], param['theta_pft'])
    pft_tracker = probabilisticTracker(pft_field, param['step_size'])

    # ---- Seed input (ESO modification) --------------------------------
    seed_points = np.load(args.seed_points)
    seed_dirs = np.load(args.seed_dir)
    affine = include_img.get_affine()
    rotation = affine[:3, :3]
    inv_rotation = np.linalg.inv(rotation)
    translation = affine[:3, 3]
    scale = np.array(include_img.get_header().get_zooms())
    voxel_space = nib.aff2axcodes(affine)

    print(voxel_space)
    print(scale)

    # Seed points: LPS -> axis convention of the image (voxel_space).
    # Axes whose code differs from L/P/S are negated; the same axes are
    # negated again when the streamlines are mapped back.
    flipped_axes = []
    for axis, (lps_code, label) in enumerate(zip('LPS', 'XYZ')):
        if voxel_space[axis] != lps_code:
            print("flip " + label)
            seed_points[:, axis] = -seed_points[:, axis]
            flipped_axes.append(axis)

    # NOTE(review): points are row vectors, so ``.dot(inv_rotation)`` applies
    # the transpose of inv_rotation. That equals the true inverse mapping
    # only when the 3x3 block is symmetric (e.g. axis-aligned images) --
    # confirm before using with oblique affines. The inverse transform below
    # mirrors this convention, so the round trip stays consistent.
    seed_points = (seed_points - translation).dot(inv_rotation) * scale

    # Seed directions: no translation.
    # NOTE(review): X and Y are negated unconditionally here, whereas seed
    # points are flipped according to voxel_space; presumably this assumes an
    # RAS image -- confirm.
    seed_dirs[:, 0:2] = -seed_dirs[:, 0:2]
    seed_dirs = seed_dirs.dot(inv_rotation) * scale

    if args.inv_seed_dir:
        seed_dirs = seed_dirs * -1.0

    # Optionally restrict the run to the first ``args.test`` seeds.
    nb_seeds = len(seed_dirs)
    if args.test is not None and args.test < nb_seeds:
        nb_seeds = args.test

    # ---- Tracking (ESO modification) ----------------------------------
    start = time.time()
    streamlines = []
    for i in range(nb_seeds):
        s = generate_streamline(tracker, mask, seed_points[i], seed_dirs[i],
                                pft_tracker=pft_tracker, param=param)
        streamlines.append(s)
        # Progress indicator, rewritten in place on the same console line
        # (the 101 factor is kept from the original implementation).
        stdout.write("\r %d%%" % (i * 101 // nb_seeds))
        stdout.flush()

    stdout.write("\n done")
    stdout.flush()
    stop = time.time()

    # ---- Save fibers (ESO modification) -------------------------------
    # Inverse of the seed transform: tracking space -> voxel_space -> LPS.
    for i, s in enumerate(streamlines):
        s = (s / scale).dot(rotation) + translation
        for axis in flipped_axes:
            s[:, axis] = -s[:, axis]
        streamlines[i] = s

    lines_polydata = lines_to_vtk_polydata(streamlines, None, np.float32)
    save_polydata(lines_polydata, args.output_file, True)

    lengths = [len(s) for s in streamlines]
    if nb_seeds > 0:
        # float() prevents the mean number of points from being truncated by
        # integer division under Python 2.
        ave_length = (float(sum(lengths)) / nb_seeds) * param['step_size']
    else:
        ave_length = 0

    str_ave_length = "%.2f" % ave_length
    str_time = "%.2f" % (stop - start)
    print(str(nb_seeds) + " streamlines, with an average length of " +
          str_ave_length + " mm, done in " + str_time + " seconds.")
# Example #2
# 0
def main():
    """Track streamlines on a peaks field and save them as trk or tck.

    Parses the command line, validates the output file and the length
    bounds, gathers the tracking parameters in ``param``, builds the seeds
    (``args.seed_file``), the binary mask (``args.mask_file``) and the maxima
    field (``args.peaks_file``), runs ``track`` and saves the streamlines
    with the tractquerier or fibernavigator writer depending on
    ``args.outputTQ``.

    Exits through ``parser.error`` on invalid options, an unsupported output
    format, an existing output file without ``-f``, or an empty seed image.
    """
    parser = buildArgsParser()
    args = parser.parse_args()
    param = {}

    if args.isVerbose:
        logging.basicConfig(level=logging.DEBUG)

    # With --outputTQ, '.tq' is inserted before the file extension.
    if args.outputTQ:
        base_name, extension = os.path.splitext(args.output_file)
        output_filename = base_name + '.tq' + extension
    else:
        output_filename = args.output_file

    out_format = tc.detect_format(output_filename)
    if out_format not in [tc.formats.trk.TRK, tc.formats.tck.TCK]:
        parser.error("Invalid output streamline file format (must be trk or " +
                     "tck): {0}".format(output_filename))
        return

    if os.path.isfile(output_filename):
        if args.isForce:
            logging.debug('Overwriting "{0}".'.format(output_filename))
        else:
            parser.error(
                '"{0}" already exists! Use -f to overwrite it.'.format(
                    output_filename))

    if not args.min_length > 0:
        parser.error('minL must be > 0, {0}mm was provided.'.format(
            args.min_length))
    if args.max_length < args.min_length:
        parser.error(
            'maxL must be > than minL, (minL={0}mm, maxL={1}mm).'.format(
                args.min_length, args.max_length))

    if not np.any([args.nt, args.npv, args.ns]):
        args.npv = 1

    # Maximum angle between two steps: explicit angle first, then curvature,
    # then a 45 degree default.
    if args.theta is not None:
        theta = gm.math.radians(args.theta)
    elif args.curvature is not None and args.curvature > 0:
        # The explicit None test keeps this comparison valid on Python 3,
        # where ``None > 0`` raises TypeError.
        theta = get_max_angle_from_curvature(args.curvature, args.step_size)
    else:
        theta = gm.math.radians(45)

    if args.mask_interp == 'nn':
        mask_interpolation = 'nearest'
    elif args.mask_interp == 'tl':
        mask_interpolation = 'trilinear'
    else:
        parser.error("--mask_interp has wrong value. See the help (-h).")
        return

    param['random'] = args.random
    param['skip'] = args.skip
    param['algo'] = args.algo
    param['mask_interp'] = mask_interpolation
    param['field_interp'] = 'nearest'
    param['theta'] = theta
    param['sf_threshold'] = args.sf_threshold
    param['sf_threshold_init'] = args.sf_threshold_init
    param['step_size'] = args.step_size
    param['rk_order'] = args.rk_order
    param['max_length'] = args.max_length
    param['min_length'] = args.min_length
    # Length bounds expressed as numbers of points.
    param['max_nbr_pts'] = int(param['max_length'] / param['step_size'])
    param['min_nbr_pts'] = int(param['min_length'] / param['step_size']) + 1
    param['is_single_direction'] = args.is_single_direction
    param['nbr_seeds'] = args.nt if args.nt is not None else 0
    param['nbr_seeds_voxel'] = args.npv if args.npv is not None else 0
    param['nbr_streamlines'] = args.ns if args.ns is not None else 0
    param['max_no_dir'] = int(math.ceil(args.maxL_no_dir / param['step_size']))
    param['is_all'] = False
    param['is_keep_single_pts'] = False
    # r+ is necessary for the interpolation function in cython, which
    # needs read/write access.
    param['mmap_mode'] = None if args.isLoadData else 'r+'

    logging.debug('Tractography parameters:\n{0}'.format(param))

    seed_img = nib.load(args.seed_file)
    seed = Seed(seed_img)
    if args.npv:
        # Per-voxel seeding: totals scale with the number of seed voxels.
        param['nbr_seeds'] = len(seed.seeds) * param['nbr_seeds_voxel']
        param['skip'] = len(seed.seeds) * param['skip']
    if len(seed.seeds) == 0:
        parser.error('"{0}" does not have voxels value > 0.'.format(
            args.seed_file))

    mask = BinaryMask(Dataset(nib.load(args.mask_file), param['mask_interp']))

    dataset = Dataset(nib.load(args.peaks_file), param['field_interp'])
    field = MaximaField(dataset, param['sf_threshold'],
                        param['sf_threshold_init'], param['theta'])

    if args.algo == 'det':
        tracker = deterministicMaximaTracker(field, param)
    elif args.algo == 'prob':
        tracker = probabilisticTracker(field, param)
    else:
        parser.error("--algo has wrong value. See the help (-h).")
        return

    start = time.time()

    # Compression arguments are only passed when compression is requested,
    # so ``track`` otherwise keeps its own defaults.
    track_kwargs = {'nbr_processes': args.nbr_processes, 'pft_tracker': None}
    if args.compress:
        if args.compress < 0.001 or args.compress > 1:
            # logging.warning replaces the deprecated logging.warn alias.
            logging.warning(
                'You are using an error rate of {}.\n'.format(args.compress) +
                'We recommend setting it between 0.001 and 1.\n' +
                '0.001 will do almost nothing to the tracts while ' +
                '1 will higly compress/linearize the tracts')
        track_kwargs['compress'] = True
        track_kwargs['compression_error_threshold'] = args.compress

    streamlines = track(tracker, mask, seed, param, **track_kwargs)

    if args.outputTQ:
        save_streamlines_tractquerier(streamlines, args.seed_file,
                                      output_filename)
    else:
        save_streamlines_fibernavigator(streamlines, args.seed_file,
                                        output_filename)

    str_ave_length = "%.2f" % compute_average_streamlines_length(streamlines)
    str_time = "%.2f" % (time.time() - start)
    logging.debug(
        str(len(streamlines)) + " streamlines, with an average " +
        "length of " + str_ave_length + " mm, done in " + str_time +
        " seconds.")
def main():
    """Track one streamline per precomputed seed inside a binary mask.

    Steps, in order:

    1. Seed NumPy's random generator from the clock and parse the command
       line.
    2. Gather the tracking parameters in ``param`` and build the binary mask
       (``args.mask_file``), the spherical harmonic field (``args.sh_file``)
       and the tracker.
    3. Load seed positions (``args.seed_points``) and seed directions
       (``args.seed_dir``) with ``np.load`` and move them from LPS
       coordinates into the space used for tracking.
    4. Call ``generate_streamline`` once per seed, map the streamlines back
       to LPS and write them to ``args.output_file``.
    5. Print the number of streamlines, their average length and the
       elapsed time.

    Exits through ``parser.error`` on invalid options or when the output
    file already exists and ``-f`` was not given.
    """
    np.random.seed(int(time.time()))
    parser = buildArgsParser()
    args = parser.parse_args()

    param = {}

    if args.algo not in ["det", "prob"]:
        parser.error("--algo has wrong value. See the help (-h).")

    if args.basis not in ["mrtrix", "dipy", "fibernav"]:
        parser.error("--basis has wrong value. See the help (-h).")

    # Maximum angle between two steps: explicit angle first, then curvature,
    # then an algorithm-dependent default.
    if args.theta is not None:
        theta = gm.math.radians(args.theta)
    elif args.curvature is not None and args.curvature > 0:
        # The explicit None test keeps this comparison valid on Python 3,
        # where ``None > 0`` raises TypeError.
        theta = get_max_angle_from_curvature(args.curvature, args.step_size)
    elif args.algo == 'prob':
        theta = gm.math.radians(20)
    else:
        theta = gm.math.radians(45)

    # Command-line abbreviations -> interpolation names given to Dataset.
    interpolations = {'nn': 'nearest', 'tl': 'trilinear'}

    mask_interpolation = interpolations.get(args.mask_interp)
    if mask_interpolation is None:
        parser.error("--mask_interp has wrong value. See the help (-h).")
        return

    field_interpolation = interpolations.get(args.field_interp)
    if field_interpolation is None:
        parser.error("--sh_interp has wrong value. See the help (-h).")
        return

    param['algo'] = args.algo
    param['mask_interp'] = mask_interpolation
    param['field_interp'] = field_interpolation
    param['theta'] = theta
    param['sf_threshold'] = args.sf_threshold
    param['sf_threshold_init'] = args.sf_threshold_init
    param['step_size'] = args.step_size
    param['max_length'] = args.max_length
    param['min_length'] = args.min_length
    param['is_single_direction'] = False
    # Seed counts are fixed to 0; presumably unused here because the seeds
    # are read from args.seed_points -- confirm against generate_streamline.
    param['nbr_seeds'] = 0
    param['nbr_seeds_voxel'] = 0
    param['nbr_streamlines'] = 0
    # Length converted to a number of steps, rounded up.
    param['max_no_dir'] = int(math.ceil(args.maxL_no_dir / param['step_size']))
    param['is_all'] = False
    param['isVerbose'] = args.isVerbose

    if param['isVerbose']:
        logging.basicConfig(level=logging.DEBUG)
        logging.info('Tractography parameters:\n{0}'.format(param))

    if os.path.isfile(args.output_file):
        if args.isForce:
            logging.info('Overwriting "{0}".'.format(args.output_file))
        else:
            parser.error(
                '"{0}" already exists! Use -f to overwrite it.'
                .format(args.output_file))

    # The mask image also provides the affine used to transform the seeds.
    nib_mask = nib.load(args.mask_file)
    mask = BinaryMask(Dataset(nib_mask, param['mask_interp']))

    dataset = Dataset(nib.load(args.sh_file), param['field_interp'])
    field = SphericalHarmonicField(
        dataset, args.basis, param['sf_threshold'],
        param['sf_threshold_init'], param['theta'])

    if args.algo == 'det':
        tracker = deterministicMaximaTracker(field, param['step_size'])
    elif args.algo == 'prob':
        tracker = probabilisticTracker(field, param['step_size'])
    else:
        parser.error("--algo has wrong value. See the help (-h).")
        return

    start = time.time()

    # ---- Seed loading and transform (Etienne St-Onge) ------------------
    # TODO: test with rotation and scaling.
    seed_points = np.load(args.seed_points)
    seed_dirs = np.load(args.seed_dir)
    affine = nib_mask.get_affine()
    rotation = affine[:3, :3]
    inv_rotation = np.linalg.inv(rotation)
    translation = affine[:3, 3]
    scale = np.array(nib_mask.get_header().get_zooms())
    voxel_space = nib.aff2axcodes(affine)

    print(voxel_space)

    # Seed points: LPS -> axis convention of the image (voxel_space).
    # Axes whose code differs from L/P/S are negated; the same axes are
    # negated again when the streamlines are mapped back.
    flipped_axes = []
    for axis, (lps_code, label) in enumerate(zip('LPS', 'XYZ')):
        if voxel_space[axis] != lps_code:
            print("flip " + label)
            seed_points[:, axis] = -seed_points[:, axis]
            flipped_axes.append(axis)

    # NOTE(review): points are row vectors, so ``.dot(inv_rotation)`` applies
    # the transpose of inv_rotation. That equals the true inverse mapping
    # only when the 3x3 block is symmetric (e.g. axis-aligned images) --
    # confirm before using with oblique affines. The inverse transform below
    # mirrors this convention, so the round trip stays consistent.
    seed_points = (seed_points - translation).dot(inv_rotation) * scale

    # Seed directions: no translation.
    # NOTE(review): X and Y are negated unconditionally here, whereas seed
    # points are flipped according to voxel_space; presumably this assumes an
    # RAS image -- confirm.
    seed_dirs[:, 0:2] = -seed_dirs[:, 0:2]
    seed_dirs = seed_dirs.dot(inv_rotation) * scale

    if args.inv_seed_dir:
        seed_dirs = seed_dirs * -1.0

    # Optionally restrict the run to the first ``args.test`` seeds.
    nb_seeds = len(seed_dirs)
    if args.test is not None and args.test < nb_seeds:
        nb_seeds = args.test

    # Same text as the former ``print a, b, c`` statement, spacing included.
    print("{0}  nb seeds:  {1}".format(args.algo, nb_seeds))

    # ---- Tracking -------------------------------------------------------
    streamlines = []
    for i in range(nb_seeds):
        s = generate_streamline(tracker, mask, seed_points[i], seed_dirs[i],
                                pft_tracker=None, param=param)
        streamlines.append(s)
        # Progress indicator, rewritten in place on the same console line
        # (the 101 factor is kept from the original implementation).
        stdout.write("\r %d%%" % (i * 101 // nb_seeds))
        stdout.flush()
    stdout.write("\n done")
    stdout.flush()

    # ---- Transform back and save ---------------------------------------
    # Inverse of the seed transform: tracking space -> voxel_space -> LPS.
    for i, s in enumerate(streamlines):
        s = (s / scale).dot(rotation) + translation
        for axis in flipped_axes:
            s[:, axis] = -s[:, axis]
        streamlines[i] = s

    lines_polydata = lines_to_vtk_polydata(streamlines, None, np.float32)
    save_polydata(lines_polydata, args.output_file, True)

    lengths = [len(s) for s in streamlines]
    if nb_seeds > 0:
        # float() prevents the mean number of points from being truncated by
        # integer division under Python 2.
        ave_length = (float(sum(lengths)) / nb_seeds) * param['step_size']
    else:
        ave_length = 0

    str_ave_length = "%.2f" % ave_length
    str_time = "%.2f" % (time.time() - start)
    print(str(nb_seeds) + " streamlines, with an average length of " +
          str_ave_length + " mm, done in " + str_time + " seconds.")