Example #1
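This example appears to be the tracking script from learn2track: it loads a trained model and its hyperparameters, prepares the diffusion signal (spherical harmonics or resampled directions), generates seed points, tracks streamlines in DWI voxel space, filters the result, and saves it as a .tck file.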
def main():
    parser = build_argparser()
    args = parser.parse_args()

    # Get experiment folder
    experiment_path = args.name
    if not os.path.isdir(experiment_path):
        # If not a directory, it must be the name of the experiment.
        experiment_path = pjoin(".", "experiments", args.name)

    if not os.path.isdir(experiment_path):
        parser.error('Cannot find experiment: {0}!'.format(args.name))

    # Load the experiment's hyperparameters
    try:
        hyperparams = smartutils.load_dict_from_json_file(
            pjoin(experiment_path, "hyperparams.json"))
    except FileNotFoundError:
        hyperparams = smartutils.load_dict_from_json_file(
            pjoin(experiment_path, "..", "hyperparams.json"))

    with Timer("Loading DWIs"):
        # Load the gradient table (strip .nii/.nii.gz to find the bvals/bvecs files)
        dwi_name = args.dwi
        if dwi_name.endswith(".gz"):
            dwi_name = dwi_name[:-3]
        if dwi_name.endswith(".nii"):
            dwi_name = dwi_name[:-4]

        try:
            bvals_filename = dwi_name + ".bvals"
            bvecs_filename = dwi_name + ".bvecs"
            bvals, bvecs = dipy.io.gradients.read_bvals_bvecs(
                bvals_filename, bvecs_filename)
        except FileNotFoundError:
            try:
                bvals_filename = dwi_name + ".bval"
                bvecs_filename = dwi_name + ".bvec"
                bvals, bvecs = dipy.io.gradients.read_bvals_bvecs(
                    bvals_filename, bvecs_filename)
            except FileNotFoundError as e:
                print("Could not find .bvals/.bvecs or .bval/.bvec files...")
                raise e

        dwi = nib.load(args.dwi)
        if hyperparams["use_sh_coeffs"]:
            # Use 45 spherical harmonic coefficients to represent the diffusion signal.
            weights = neurotools.get_spherical_harmonics_coefficients(
                dwi, bvals, bvecs).astype(np.float32)
        else:
            # Resample the diffusion signal to have 100 directions.
            weights = neurotools.resample_dwi(dwi, bvals,
                                              bvecs).astype(np.float32)

        # Inverse of the DWI affine: maps RAS+mm coordinates to DWI voxel coordinates.
        affine_rasmm2dwivox = np.linalg.inv(dwi.affine)

    with Timer("Loading model"):
        if hyperparams["model"] == "gru_regression":
            from learn2track.models import GRU_Regression
            model_class = GRU_Regression
        elif hyperparams['model'] == 'gru_gaussian':
            from learn2track.models import GRU_Gaussian
            model_class = GRU_Gaussian
        elif hyperparams['model'] == 'gru_mixture':
            from learn2track.models import GRU_Mixture
            model_class = GRU_Mixture
        elif hyperparams['model'] == 'gru_multistep':
            from learn2track.models import GRU_Multistep_Gaussian
            model_class = GRU_Multistep_Gaussian
        elif hyperparams['model'] == 'ffnn_regression':
            from learn2track.models import FFNN_Regression
            model_class = FFNN_Regression
        else:
            raise ValueError("Unknown model: {}!".format(hyperparams['model']))

        kwargs = {}
        volume_manager = neurotools.VolumeManager()
        volume_manager.register(weights)
        kwargs['volume_manager'] = volume_manager

        # Load the actual model.
        model = model_class.create(
            pjoin(experiment_path),
            **kwargs)  # Create new instance and restore model.
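        # Presumably disables dropout for deterministic inference.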
        model.drop_prob = 0.
        print(str(model))

    mask = None
    affine_maskvox2dwivox = None  # Stays None when no mask is given; used below.
    if args.mask is not None:
        with Timer("Loading mask"):
            mask_nii = nib.load(args.mask)
            mask = mask_nii.get_data()
            # Compute the affine allowing to evaluate the mask at some coordinates correctly.

            # affine_maskvox2dwivox = mask_vox => rasmm space => dwi_vox
            affine_maskvox2dwivox = np.dot(affine_rasmm2dwivox,
                                           mask_nii.affine)
            if args.dilate_mask:
                import scipy.ndimage
                mask = scipy.ndimage.binary_dilation(mask).astype(
                    mask.dtype)

    with Timer("Generating seeds"):
        seeds = []

        for filename in args.seeds:
            if filename.endswith('.trk') or filename.endswith('.tck'):
                tfile = nib.streamlines.load(filename)
                # Send the streamlines to voxel space since that's where we'll track.
                tfile.tractogram.apply_affine(affine_rasmm2dwivox)

                # Use extremities of the streamlines as seeding points.
                seeds += [s[0] for s in tfile.streamlines]
                seeds += [s[-1] for s in tfile.streamlines]

            else:
                # Assume it is a binary mask.
                rng = np.random.RandomState(args.seeding_rng_seed)
                nii_seeds = nib.load(filename)

                # affine_seedsvox2dwivox = mask_vox => rasmm space => dwi_vox
                affine_seedsvox2dwivox = np.dot(affine_rasmm2dwivox,
                                                nii_seeds.affine)

                nii_seeds_data = nii_seeds.get_data()

                if args.dilate_seeding_mask:
                    import scipy.ndimage
                    nii_seeds_data = scipy.ndimage.binary_dilation(
                        nii_seeds_data).astype(nii_seeds_data.dtype)

                indices = np.array(np.where(nii_seeds_data)).T
                for idx in indices:
                    seeds_in_voxel = idx + rng.uniform(
                        -0.5, 0.5, size=(args.nb_seeds_per_voxel, 3))
                    seeds_in_voxel = nib.affines.apply_affine(
                        affine_seedsvox2dwivox, seeds_in_voxel)
                    seeds.extend(seeds_in_voxel)

        # Stack all seeds into a single array with Theano's configured float dtype.
        seeds = np.array(seeds, dtype=theano.config.floatX)

    with Timer("Tracking in the diffusion voxel space"):
        voxel_sizes = np.asarray(dwi.header.get_zooms()[:3])
        if not np.all(voxel_sizes == dwi.header.get_zooms()[0]):
            print("* Careful, voxels are anisotropic: {}!".format(
                tuple(voxel_sizes)))
        # Since we are tracking in diffusion voxel space, convert step_size (in mm) to voxel.

        if args.step_size is not None:
            step_size = np.float32(args.step_size / voxel_sizes.max())
            # Also convert the max length (in mm) into a max number of points.
            max_nb_points = int(np.ceil(args.max_length / args.step_size))
        else:
            step_size = None
            max_nb_points = args.max_length

        if args.theta is not None:
            theta = np.deg2rad(args.theta)
        elif args.curvature is not None and args.curvature > 0:
            theta = get_max_angle_from_curvature(args.curvature, step_size)
        else:
            theta = np.deg2rad(45)

        print("Angle: {}".format(np.rad2deg(theta)))
        print("Step size (vox): {}".format(step_size))
        print("Max nb. points: {}".format(max_nb_points))

        is_outside_mask = make_is_outside_mask(mask,
                                               affine_maskvox2dwivox,
                                               threshold=args.mask_threshold)
        is_too_long = make_is_too_long(max_nb_points)
        is_too_curvy = make_is_too_curvy(np.rad2deg(theta))
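        # Assumed semantics: make_is_unlikely(0.5) stops tracking once the
        # model's probability of continuing falls below 0.5.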
        is_unlikely = make_is_unlikely(0.5)
        is_stopping = make_is_stopping({
            STOPPING_MASK: is_outside_mask,
            STOPPING_LENGTH: is_too_long,
            STOPPING_CURVATURE: is_too_curvy,
            STOPPING_LIKELIHOOD: is_unlikely
        })

        is_stopping.max_nb_points = max_nb_points  # Small hack

        tractogram = batch_track(model,
                                 weights,
                                 seeds,
                                 step_size=step_size,
                                 is_stopping=is_stopping,
                                 batch_size=args.batch_size,
                                 args=args)

        # Streamlines have been generated in voxel space.
        # Transform them back to RAS+mm space using the dwi's affine.
        tractogram.affine_to_rasmm = dwi.affine
        tractogram.to_world()  # Performed in-place.

    nb_streamlines = len(tractogram)

    if args.save_rejected:
        rejected_tractogram = Tractogram()
        rejected_tractogram.affine_to_rasmm = tractogram._affine_to_rasmm

    print("Generated {:,} (compressed) streamlines".format(nb_streamlines))
    with Timer("Cleaning streamlines", newline=True):
        # Flush streamlines that have no points.
        if args.save_rejected:
            rejected_tractogram += tractogram[
                np.array(list(map(len, tractogram))) <= 0]

        tractogram = tractogram[np.array(list(map(len, tractogram))) > 0]
        print("Removed {:,} empty streamlines".format(nb_streamlines -
                                                      len(tractogram)))

        # Remove small streamlines
        nb_streamlines = len(tractogram)
        lengths = dipy.tracking.streamline.length(tractogram.streamlines)

        if args.save_rejected:
            rejected_tractogram += tractogram[lengths < args.min_length]

        tractogram = tractogram[lengths >= args.min_length]
        lengths = lengths[lengths >= args.min_length]
        if len(lengths) > 0:
            print("Average length: {:.2f} mm.".format(lengths.mean()))
            print("Minimum length: {:.2f} mm. Maximum length: {:.2f} mm.".format(
                lengths.min(), lengths.max()))
        print("Removed {:,} streamlines smaller than {:.2f} mm".format(
            nb_streamlines - len(tractogram), args.min_length))
        if args.discard_stopped_by_curvature:
            nb_streamlines = len(tractogram)
            stopping_curvature_flag_is_set = is_flag_set(
                tractogram.data_per_streamline['stopping_flags'][:, 0],
                STOPPING_CURVATURE)

            if args.save_rejected:
                rejected_tractogram += tractogram[
                    stopping_curvature_flag_is_set]

            tractogram = tractogram[np.logical_not(
                stopping_curvature_flag_is_set)]
            print(
                "Removed {:,} streamlines stopped for having a curvature higher than {:.2f} degrees"
                .format(nb_streamlines - len(tractogram), np.rad2deg(theta)))

        if args.filter_threshold is not None:
            # Remove streamlines that produce a reconstruction error higher than a certain threshold.
            nb_streamlines = len(tractogram)
            losses = compute_loss_errors(tractogram.streamlines, model,
                                         hyperparams)
            print("Mean loss: {:.4f} ± {:.4f}".format(
                np.mean(losses),
                np.std(losses, ddof=1) / np.sqrt(len(losses))))

            if args.save_rejected:
                rejected_tractogram += tractogram[
                    losses > args.filter_threshold]

            tractogram = tractogram[losses <= args.filter_threshold]
            print(
                "Removed {:,} streamlines producing a loss higher than {:.2f} mm"
                .format(nb_streamlines - len(tractogram),
                        args.filter_threshold))

    with Timer("Saving {:,} (compressed) streamlines".format(len(tractogram))):
        filename = args.out
        if args.out is None:
            prefix = args.prefix
            if prefix is None:
                dwi_name = os.path.basename(args.dwi)
                if dwi_name.endswith(".nii.gz"):
                    dwi_name = dwi_name[:-7]
                else:  # .nii
                    dwi_name = dwi_name[:-4]

                prefix = os.path.basename(os.path.dirname(args.dwi)) + dwi_name
                prefix = prefix.replace(".", "_")

            seed_mask_type = args.seeds[0].replace(".", "_").replace(
                "_", "").replace("/", "-")
            if "int" in args.seeds[0]:
                seed_mask_type = "int"
            elif "wm" in args.seeds[0]:
                seed_mask_type = "wm"
            elif "rois" in args.seeds[0]:
                seed_mask_type = "rois"
            elif "bundles" in args.seeds[0]:
                seed_mask_type = "bundles"

            mask_type = ""
            if args.mask is not None:
                if "fa" in args.mask:
                    mask_type = "fa"
                elif "wm" in args.mask:
                    mask_type = "wm"

            if args.dilate_seeding_mask:
                seed_mask_type += "D"

            if args.dilate_mask:
                mask_type += "D"

            filename_items = [
                "{}",
                "useMaxComponent-{}",
                # "seed-{}",
                # "mask-{}",
                "step-{:.2f}mm",
                "nbSeeds-{}",
                "maxAngleDeg-{:.1f}"
                # "keepCurv-{}",
                # "filtered-{}",
                # "minLen-{}",
                # "pftRetry-{}",
                # "pftHist-{}",
                # "trackLikePeter-{}",
            ]
            filename = ('_'.join(filename_items) + ".tck").format(
                prefix,
                args.use_max_component,
                # seed_mask_type,
                # mask_type,
                args.step_size,
                args.nb_seeds_per_voxel,
                np.rad2deg(theta)
                # not args.discard_stopped_by_curvature,
                # args.filter_threshold,
                # args.min_length,
                # args.pft_nb_retry,
                # args.pft_nb_backtrack_steps,
                # args.track_like_peter
            )

        save_path = pjoin(experiment_path, filename)
        try:  # Create dirs, if needed.
            os.makedirs(os.path.dirname(save_path))
        except OSError:  # Directory may already exist.
            pass

        print("Saving to {}".format(save_path))
        nib.streamlines.save(tractogram, save_path)

    if args.save_rejected:
        with Timer("Saving {:,} (compressed) rejected streamlines".format(
                len(rejected_tractogram))):
            rejected_filename_items = filename_items.copy()
            rejected_filename_items.insert(1, "rejected")
            rejected_filename = (
                '_'.join(rejected_filename_items) + ".tck"
            ).format(
                prefix,
                args.use_max_component,
                # seed_mask_type,
                # mask_type,
                args.step_size,
                args.nb_seeds_per_voxel,
                np.rad2deg(theta)
                # not args.discard_stopped_by_curvature,
                # args.filter_threshold,
                # args.min_length,
                # args.pft_nb_retry,
                # args.pft_nb_backtrack_steps,
                # args.track_like_peter
            )

            rejected_save_path = pjoin(experiment_path, rejected_filename)
            try:  # Create dirs, if needed.
                os.makedirs(os.path.dirname(rejected_save_path))
            except OSError:  # Directory may already exist.
                pass

            print("Saving rejected streamlines to {}".format(
                rejected_save_path))
            nib.streamlines.save(rejected_tractogram, rejected_save_path)
Example #2
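This example is a streamline set-operation script: it loads streamlines from several input tractograms, applies the requested operation from the OPERATIONS table, carries the matching per-streamline and per-point metadata along, optionally writes the surviving indices per input file to JSON, and saves the result using a header based on the first input.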
def main():

    parser = build_args_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    if os.path.isfile(args.output):
        if args.force:
            logging.info('Overwriting {0}.'.format(args.output))
        else:
            parser.error('{0} already exists! Use -f to overwrite it.'.format(
                args.output))

    # Load all input streamlines.
    data = [load_data(f) for f in args.inputs]
    streamlines, data_per_streamline, data_per_point = zip(*data)
    nb_streamlines = [len(s) for s in streamlines]

    # Apply the requested operation to each input file.
    logging.info('Performing operation \'{}\'.'.format(args.operation))
    new_streamlines, indices = perform_streamlines_operation(
        OPERATIONS[args.operation], streamlines, args.precision)

    # Get the metadata of the streamlines.
    new_data_per_streamline = {}
    new_data_per_point = {}
    if not args.no_data:

        for key in data_per_streamline[0].keys():
            all_data = np.vstack([s[key] for s in data_per_streamline])
            new_data_per_streamline[key] = all_data[indices, :]

        # Add the indices to the metadata if requested.
        if args.save_meta_indices:
            new_data_per_streamline['ids'] = indices

        for key in data_per_point[0].keys():
            all_data = list(chain(*[s[key] for s in data_per_point]))
            new_data_per_point[key] = [all_data[i] for i in indices]

    # Save the indices to a file if requested.
    if args.save_indices is not None:
        start = 0
        indices_dict = {'filenames': args.inputs}
        for name, nb in zip(args.inputs, nb_streamlines):
            end = start + nb
            file_indices = \
                [i - start for i in indices if i >= start and i < end]
            indices_dict[name] = file_indices
            start = end
        with open(args.save_indices, 'wt') as f:
            json.dump(indices_dict, f)

    # Save the new streamlines.
    logging.info('Saving streamlines to {0}.'.format(args.output))
    reference_file = load(args.inputs[0], True)
    new_tractogram = Tractogram(new_streamlines,
                                data_per_streamline=new_data_per_streamline,
                                data_per_point=new_data_per_point)

    # If the reference is a .tck, the affine will be None.
    affine = reference_file.tractogram.affine_to_rasmm
    if affine is None:
        affine = np.eye(4)
    new_tractogram.affine_to_rasmm = affine

    new_header = reference_file.header.copy()
    new_header['nb_streamlines'] = len(new_streamlines)
    save(new_tractogram, args.output, header=new_header)
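A note on the design: `perform_streamlines_operation` returns not only the resulting streamlines but also their indices into the concatenated inputs; those indices are what allow the metadata to be sliced consistently and, in the save-indices block, to be mapped back to per-file offsets.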