Beispiel #1
0
def main():
    """Compress the streamlines of a tractogram and save the result.

    Parses the command line and validates the input/output paths and that
    both tractograms share the same file format. It logs a warning when the
    requested error rate is below 0.001 or above 1. It then lazily loads the
    input tractogram, compresses its streamlines with
    ``compress_streamlines_wrapper`` and saves them, reusing the input header.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Validate paths and formats before doing any heavy I/O.
    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram)
    check_tracts_same_format(parser, args.in_tractogram, args.out_tractogram)

    rate = args.error_rate
    rate_out_of_range = rate < 0.001 or rate > 1
    if rate_out_of_range:
        # Advisory only: an out-of-range rate is still used as given.
        message = ('You are using an error rate of {}.\n'
                   'We recommend setting it between 0.001 and 1.\n'
                   '0.001 will do almost nothing to the streamlines\n'
                   'while 1 will highly compress/linearize the streamlines')
        logging.warning(message.format(rate))

    in_tractogram = nib.streamlines.load(args.in_tractogram, lazy_load=True)
    streamlines_source = compress_streamlines_wrapper(in_tractogram, rate)

    # Identity affine: presumably the compressed streamlines are already in
    # RAS+mm, as nibabel-loaded streamlines are (TODO confirm with wrapper).
    nib.streamlines.save(
        LazyTractogram(streamlines_source, affine_to_rasmm=np.eye(4)),
        args.out_tractogram,
        header=in_tractogram.header)
def main():
    """Combine several tractograms with a streamline operation and save it.

    ``args.operation`` selects the behaviour:

    * ``'lazy_concatenate'``: stream the streamlines of every input file, in
      order, into the output through a generator. No spatial or metadata
      checks are done. The header of the first ``.trk`` input (if any) is
      reused. Later ``.trk`` headers are only checked for compatibility, and
      a warning is logged on mismatch.
    * ``'concatenate'``: keep every streamline of every input.
    * any other key of ``OPERATIONS``: apply that operation to the inputs'
      streamlines (its ``'_robust'`` variant when ``args.robust`` is set) and
      keep the streamlines at the returned indices.

    When ``args.save_indices`` is set, the kept indices are also written as
    JSON. The file maps each input filename to the indices, local to that
    file, of its streamlines that were kept.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    assert_inputs_exist(parser, args.in_tractograms)
    assert_outputs_exist(parser, args, args.out_tractogram,
                         optional=args.save_indices)

    if args.operation == 'lazy_concatenate':
        logging.info('Using lazy_concatenate, no spatial or metadata related '
                     'checks are performed.\nMetadata will be lost, only '
                     'trk/tck file are supported.')

        def list_generator_from_nib(filenames):
            """Yield every streamline of every file in ``filenames``, in order."""
            for in_file in filenames:
                tractogram_file = nib.streamlines.load(in_file, lazy_load=True)
                yield from tractogram_file.streamlines

        # Reuse the first .trk header; only warn if later ones disagree.
        header = None
        for in_file in args.in_tractograms:
            _, ext = os.path.splitext(in_file)
            if ext == '.trk':
                if header is None:
                    header = nib.streamlines.load(
                        in_file, lazy_load=True).header
                elif not is_header_compatible(header, in_file):
                    logging.warning('Incompatible headers in the list.')

        # LazyTractogram takes a generator *function*. Build a fresh generator
        # on every call so each iteration is a complete pass over the files.
        # Returning one pre-built generator would leave any later pass empty.
        out_tractogram = LazyTractogram(
            lambda: list_generator_from_nib(args.in_tractograms),
            affine_to_rasmm=np.eye(4))
        nib.streamlines.save(out_tractogram, args.out_tractogram,
                             header=header)
        return

    # Load all input streamlines.
    sft_list = []
    for f in args.in_tractograms:
        sft_list.append(load_tractogram_with_reference(
            parser, args, f, bbox_check=not args.ignore_invalid))

    # Apply the requested operation to each input file.
    logging.info('Performing operation \'{}\'.'.format(args.operation))
    new_sft = concatenate_sft(sft_list, args.no_metadata, args.fake_metadata)
    if args.operation == 'concatenate':
        indices = np.arange(len(new_sft), dtype=np.uint32)
    else:
        streamlines_list = [sft.streamlines for sft in sft_list]
        op_name = args.operation
        if args.robust:
            op_name += '_robust'
            _, indices = OPERATIONS[op_name](streamlines_list,
                                             precision=args.precision)
        else:
            _, indices = perform_streamlines_operation(
                OPERATIONS[op_name], streamlines_list,
                precision=args.precision)

    # Save the indices to a file if requested.
    if args.save_indices:
        # `indices` are positions in the concatenation of all inputs (in
        # input order); split them back into per-file local indices.
        start = 0
        out_dict = {}
        nb_streamlines_per_file = [len(sft) for sft in sft_list]
        for name, nb in zip(args.in_tractograms, nb_streamlines_per_file):
            end = start + nb
            # Cast to built-in int so json can serialize numpy integer types.
            out_dict[name] = [int(i - start)
                              for i in indices if start <= i < end]
            start = end

        with open(args.save_indices, 'wt') as f:
            json.dump(out_dict, f,
                      indent=args.indent,
                      sort_keys=args.sort_keys)

    # Save the new streamlines (and metadata)
    logging.info('Saving {} streamlines to {}.'.format(len(indices),
                                                       args.out_tractogram))
    save_tractogram(new_sft[indices], args.out_tractogram,
                    bbox_valid_check=not args.ignore_invalid)
Beispiel #3
0
def main():
    """Run particle filtering tractography (PFT) from an SH fODF image.

    The steps are:

    * Parse and validate the command line.
    * Check that the SH image has isotropic voxels.
    * Build a deterministic or probabilistic direction getter from the SH
      coefficients.
    * Build a CMC or ACT tissue classifier from the include/exclude maps.
    * Draw seeds from the seeding mask.
    * Track with ``ParticleFilteringTracking``.

    Streamlines whose length lies outside ``[min_length, max_length]`` are
    dropped. The rest are optionally compressed and streamed to
    ``args.output_file`` with a header built from the seed image.
    """
    parser = _build_args_parser()
    args = parser.parse_args()

    if args.isVerbose:
        logging.basicConfig(level=logging.DEBUG)

    assert_inputs_exist(parser, [
        args.sh_file, args.seed_file, args.map_include_file,
        args.map_exclude_file
    ])
    assert_outputs_exist(parser, args, [args.output_file])

    if not nib.streamlines.is_supported(args.output_file):
        parser.error('Invalid output streamline file format (must be trk or ' +
                     'tck): {0}'.format(args.output_file))

    if not args.min_length > 0:
        parser.error('minL must be > 0, {}mm was provided.'.format(
            args.min_length))
    if args.max_length < args.min_length:
        parser.error(
            'maxL must be > than minL, (minL={}mm, maxL={}mm).'.format(
                args.min_length, args.max_length))

    if args.compress:
        if args.compress < 0.001 or args.compress > 1:
            logging.warning(
                'You are using an error rate of {}.\nWe recommend setting it '
                'between 0.001 and 1.\n0.001 will do almost nothing to the '
                'tracts while 1 will highly compress/linearize the tracts'.
                format(args.compress))

    if args.particles <= 0:
        parser.error('--particles must be >= 1.')

    if args.back_tracking <= 0:
        parser.error('PFT backtracking distance must be > 0.')

    if args.forward_tracking <= 0:
        parser.error('PFT forward tracking distance must be > 0.')

    # NOTE(review): a value of 0 is falsy, so `--npv 0` / `--nt 0` skip these
    # checks and silently fall back to the seeding defaults further down.
    # Whether that is reachable depends on the parser defaults (not visible
    # here) — confirm.
    if args.npv and args.npv <= 0:
        parser.error('Number of seeds per voxel must be > 0.')

    if args.nt and args.nt <= 0:
        parser.error('Total number of seeds must be > 0.')

    fodf_sh_img = nib.load(args.sh_file)
    sh_zooms = fodf_sh_img.header.get_zooms()
    # The step size and length limits are later converted to voxel units with
    # a single voxel size, so the three spatial zooms must agree.
    if not np.allclose(np.mean(sh_zooms[:3]), sh_zooms[0], atol=1.e-3):
        parser.error(
            'SH file is not isotropic. Tracking cannot be run robustly.')

    tracking_sphere = HemiSphere.from_sphere(get_sphere('repulsion724'))

    # Check if sphere is unit, since we couldn't find such check in Dipy.
    if not np.allclose(np.linalg.norm(tracking_sphere.vertices, axis=1), 1.):
        raise RuntimeError('Tracking sphere should be unit normed.')

    if args.algo == 'det':
        dgklass = DeterministicMaximumDirectionGetter
    else:
        dgklass = ProbabilisticDirectionGetter

    theta = get_theta(args.theta, args.algo)

    # Reminder for the future:
    # pmf_threshold == clip pmf under this
    # relative_peak_threshold is for initial directions filtering
    # min_separation_angle is the initial separation angle for peak extraction
    dg = dgklass.from_shcoeff(fodf_sh_img.get_data().astype(np.double),
                              max_angle=theta,
                              sphere=tracking_sphere,
                              basis_type=args.sh_basis,
                              pmf_threshold=args.sf_threshold,
                              relative_peak_threshold=args.sf_threshold_init)

    map_include_img = nib.load(args.map_include_file)
    map_exclude_img = nib.load(args.map_exclude_file)
    # Mean of the include map's spatial pixdim entries (indices 1-3); only
    # used by the CMC classifier.
    include_voxel_size = np.average(map_include_img.header['pixdim'][1:4])

    if not args.act:
        tissue_classifier = CmcTissueClassifier(
            map_include_img.get_data(),
            map_exclude_img.get_data(),
            step_size=args.step_size,
            average_voxel_size=include_voxel_size)
    else:
        tissue_classifier = ActTissueClassifier(map_include_img.get_data(),
                                                map_exclude_img.get_data())

    if args.npv:
        nb_seeds = args.npv
        seed_per_vox = True
    elif args.nt:
        nb_seeds = args.nt
        seed_per_vox = False
    else:
        nb_seeds = 1
        seed_per_vox = True

    # Tracking runs with an identity affine (np.eye(4) below), i.e. in voxel
    # units. Convert mm quantities using the SH voxel size checked above.
    voxel_size = sh_zooms[0]
    vox_step_size = args.step_size / voxel_size
    seed_img = nib.load(args.seed_file)
    seeds = track_utils.random_seeds_from_mask(
        seed_img.get_data(),
        seeds_count=nb_seeds,
        seed_count_per_voxel=seed_per_vox,
        random_seed=args.seed)

    # Note that max steps is used once for the forward pass, and
    # once for the backwards. This doesn't, in fact, control the real
    # max length
    max_steps = int(args.max_length / args.step_size) + 1
    pft_streamlines = ParticleFilteringTracking(
        dg,
        tissue_classifier,
        seeds,
        np.eye(4),
        max_cross=1,
        step_size=vox_step_size,
        maxlen=max_steps,
        pft_back_tracking_dist=args.back_tracking,
        pft_front_tracking_dist=args.forward_tracking,
        particle_count=args.particles,
        return_all=args.keep_all,
        random_seed=args.seed)

    scaled_min_length = args.min_length / voxel_size
    scaled_max_length = args.max_length / voxel_size
    filtered_streamlines = (
        s for s in pft_streamlines
        if scaled_min_length <= length(s) <= scaled_max_length)
    if args.compress:
        filtered_streamlines = (compress_streamlines(s, args.compress)
                                for s in filtered_streamlines)

    # NOTE(review): this factory returns the same single-use generator on
    # every call, so the tractogram can only be iterated once. Building a
    # fresh one would re-run tracking, so it is left as a one-pass stream.
    tractogram = LazyTractogram(lambda: filtered_streamlines,
                                affine_to_rasmm=seed_img.affine)

    filetype = nib.streamlines.detect_format(args.output_file)
    header = create_header_from_anat(seed_img, base_filetype=filetype)

    # Use generator to save the streamlines on-the-fly
    nib.streamlines.save(tractogram, args.output_file, header=header)