Ejemplo n.º 1
0
def load_mask_classifier_dataset(subject_files,
                                 volume_manager,
                                 name="HCP",
                                 use_sh_coeffs=False):
    """Load mask-classifier subjects and register their diffusion volumes.

    Every subject file (visited in sorted order) is loaded with
    ``MaskClassifierData.load``. Its diffusion signal is converted to either
    spherical harmonic coefficients (``use_sh_coeffs=True``) or a resampled
    DWI, cast to float32 and registered with ``volume_manager``. The id that
    ``register`` returns is stored on the subject as ``subject_id``.

    Parameters
    ----------
    subject_files : iterable of str
        Paths of the subject files to load.
    volume_manager : object
        Object exposing ``register(volume)``, which returns a subject id.
    name : str
        Name of the resulting dataset.
    use_sh_coeffs : bool
        Use spherical harmonic coefficients instead of a resampled signal.

    Returns
    -------
    MaskClassifierDataset
        Dataset holding every loaded subject, created with ``keep_on_cpu=True``.
    """
    loaded_subjects = []
    with Timer("  Loading subject(s)", newline=True):
        for path in sorted(subject_files):
            print("    {}".format(path))
            subject = MaskClassifierData.load(path)

            signal = subject.signal
            bvals, bvecs = subject.gradients.bvals, subject.gradients.bvecs
            # 45 spherical harmonic coefficients, or a 100-direction resampling.
            to_volume = (neurotools.get_spherical_harmonics_coefficients
                         if use_sh_coeffs else neurotools.resample_dwi)
            volume = to_volume(signal, bvals, bvecs).astype(np.float32)

            # The raw signal is no longer needed: release it to save memory.
            subject.signal.uncache()
            subject.subject_id = volume_manager.register(volume)
            loaded_subjects.append(subject)

    return MaskClassifierDataset(loaded_subjects, name, keep_on_cpu=True)
Ejemplo n.º 2
0
def load_tractography_dataset(subject_files, volume_manager, name="HCP", use_sh_coeffs=False):
    """Load tractography subjects and register their diffusion volumes.

    Each path in ``subject_files`` (sorted) is loaded with
    ``TractographyData.load``. The subject's diffusion signal is turned into a
    float32 volume and registered with ``volume_manager``. That volume is
    either spherical harmonic coefficients or a resampled DWI, depending on
    ``use_sh_coeffs``. The returned id becomes the subject's ``subject_id``.

    Parameters
    ----------
    subject_files : iterable of str
        Paths of the subject files to load.
    volume_manager : object
        Object exposing ``register(volume)``, which returns a subject id.
    name : str
        Name of the resulting dataset.
    use_sh_coeffs : bool
        Use spherical harmonic coefficients instead of a resampled signal.

    Returns
    -------
    TractographyDataset
        Dataset of all loaded subjects, created with ``keep_on_cpu=True``.
    """
    dataset_subjects = []
    with Timer("  Loading subject(s)", newline=True):
        for subject_path in sorted(subject_files):
            print("    {}".format(subject_path))
            data = TractographyData.load(subject_path)

            signal = data.signal
            bvals, bvecs = data.gradients.bvals, data.gradients.bvecs
            if not use_sh_coeffs:
                # Resample the diffusion signal to 100 directions.
                volume = neurotools.resample_dwi(signal, bvals, bvecs)
            else:
                # Represent the diffusion signal with 45 spherical harmonic coefficients.
                volume = neurotools.get_spherical_harmonics_coefficients(signal, bvals, bvecs)
            volume = volume.astype(np.float32)

            data.signal.uncache()  # The raw signal is no longer needed.
            data.subject_id = volume_manager.register(volume)
            dataset_subjects.append(data)

    return TractographyDataset(dataset_subjects, name, keep_on_cpu=True)
Ejemplo n.º 3
0
def main():
    """Track streamlines with a trained model and save the cleaned tractogram.

    Everything is driven by the arguments parsed by ``build_argparser``:

    1. Locate the experiment folder and its ``hyperparams.json`` (falling back
       to the parent folder).
    2. Load the DWI and its gradient table (``.bvals``/``.bvecs`` or
       ``.bval``/``.bvec`` next to the image). Convert the signal to spherical
       harmonic coefficients or a resampled DWI, according to
       ``hyperparams["use_sh_coeffs"]``.
    3. Restore the model selected by ``hyperparams["model"]``.
    4. Optionally load (and dilate) a tracking mask.
    5. Build seeds from the extremities of ``.trk``/``.tck`` files, or by
       jittering ``args.nb_seeds_per_voxel`` points uniformly inside every
       voxel of binary seeding masks.
    6. Track in DWI voxel space, then bring the streamlines back to RAS+mm.
    7. Drop empty and too-short streamlines. Optionally also drop
       curvature-stopped ones and those whose loss exceeds
       ``args.filter_threshold``.
    8. Save the result under the experiment folder. When
       ``args.save_rejected`` is set, also save the discarded streamlines.
    """
    parser = build_argparser()
    args = parser.parse_args()

    # Get experiment folder
    experiment_path = args.name
    if not os.path.isdir(experiment_path):
        # If not a directory, it must be the name of the experiment.
        experiment_path = pjoin(".", "experiments", args.name)

    if not os.path.isdir(experiment_path):
        parser.error('Cannot find experiment: {0}!'.format(args.name))

    # Load experiment hyperparameters, falling back to the parent folder.
    try:
        hyperparams = smartutils.load_dict_from_json_file(
            pjoin(experiment_path, "hyperparams.json"))
    except FileNotFoundError:
        hyperparams = smartutils.load_dict_from_json_file(
            pjoin(experiment_path, "..", "hyperparams.json"))

    with Timer("Loading DWIs"):
        # Load gradients table: it sits next to the DWI, with the same name.
        dwi_name = args.dwi
        if dwi_name.endswith(".gz"):
            dwi_name = dwi_name[:-3]
        if dwi_name.endswith(".nii"):
            dwi_name = dwi_name[:-4]

        try:
            bvals, bvecs = dipy.io.gradients.read_bvals_bvecs(
                dwi_name + ".bvals", dwi_name + ".bvecs")
        except FileNotFoundError:
            try:
                bvals, bvecs = dipy.io.gradients.read_bvals_bvecs(
                    dwi_name + ".bval", dwi_name + ".bvec")
            except FileNotFoundError:
                print("Could not find .bvals/.bvecs or .bval/.bvec files...")
                raise

        dwi = nib.load(args.dwi)
        if hyperparams["use_sh_coeffs"]:
            # Use 45 spherical harmonic coefficients to represent the diffusion signal.
            weights = neurotools.get_spherical_harmonics_coefficients(
                dwi, bvals, bvecs).astype(np.float32)
        else:
            # Resample the diffusion signal to have 100 directions.
            weights = neurotools.resample_dwi(dwi, bvals,
                                              bvecs).astype(np.float32)

        affine_rasmm2dwivox = np.linalg.inv(dwi.affine)

    with Timer("Loading model"):
        if hyperparams["model"] == "gru_regression":
            from learn2track.models import GRU_Regression
            model_class = GRU_Regression
        elif hyperparams['model'] == 'gru_gaussian':
            from learn2track.models import GRU_Gaussian
            model_class = GRU_Gaussian
        elif hyperparams['model'] == 'gru_mixture':
            from learn2track.models import GRU_Mixture
            model_class = GRU_Mixture
        elif hyperparams['model'] == 'gru_multistep':
            from learn2track.models import GRU_Multistep_Gaussian
            model_class = GRU_Multistep_Gaussian
        elif hyperparams['model'] == 'ffnn_regression':
            from learn2track.models import FFNN_Regression
            model_class = FFNN_Regression
        else:
            raise ValueError("Unknown model!")

        volume_manager = neurotools.VolumeManager()
        volume_manager.register(weights)

        # Create a new instance and restore the trained model.
        model = model_class.create(pjoin(experiment_path),
                                   volume_manager=volume_manager)
        model.drop_prob = 0.  # Disable dropout at tracking time.
        print(str(model))

    # Both stay None when no tracking mask is given.
    # NOTE(review): confirm make_is_outside_mask accepts a None mask/affine.
    mask = None
    affine_maskvox2dwivox = None
    if args.mask is not None:
        with Timer("Loading mask"):
            mask_nii = nib.load(args.mask)
            mask = mask_nii.get_data()
            # Compute the affine allowing to evaluate the mask at some coordinates correctly.
            # affine_maskvox2dwivox = mask_vox => rasmm space => dwi_vox
            affine_maskvox2dwivox = np.dot(affine_rasmm2dwivox,
                                           mask_nii.affine)
            if args.dilate_mask:
                import scipy.ndimage
                mask = scipy.ndimage.binary_dilation(mask).astype(mask.dtype)

    with Timer("Generating seeds"):
        seeds = []

        for filename in args.seeds:
            if filename.endswith(('.trk', '.tck')):
                tfile = nib.streamlines.load(filename)
                # Send the streamlines to voxel since that's where we'll track.
                tfile.tractogram.apply_affine(affine_rasmm2dwivox)

                # Use extremities of the streamlines as seeding points.
                seeds += [s[0] for s in tfile.streamlines]
                seeds += [s[-1] for s in tfile.streamlines]

            else:
                # Assume it is a binary mask.
                rng = np.random.RandomState(args.seeding_rng_seed)
                nii_seeds = nib.load(filename)

                # affine_seedsvox2dwivox = seeds_vox => rasmm space => dwi_vox
                affine_seedsvox2dwivox = np.dot(affine_rasmm2dwivox,
                                                nii_seeds.affine)

                nii_seeds_data = nii_seeds.get_data()

                if args.dilate_seeding_mask:
                    import scipy.ndimage
                    nii_seeds_data = scipy.ndimage.binary_dilation(
                        nii_seeds_data).astype(nii_seeds_data.dtype)

                indices = np.array(np.where(nii_seeds_data)).T
                for idx in indices:
                    # Jitter seeds uniformly within the voxel, then map them to DWI voxel space.
                    seeds_in_voxel = idx + rng.uniform(
                        -0.5, 0.5, size=(args.nb_seeds_per_voxel, 3))
                    seeds_in_voxel = nib.affines.apply_affine(
                        affine_seedsvox2dwivox, seeds_in_voxel)
                    seeds.extend(seeds_in_voxel)

        seeds = np.array(seeds, dtype=theano.config.floatX)

    with Timer("Tracking in the diffusion voxel space"):
        voxel_sizes = np.asarray(dwi.header.get_zooms()[:3])
        if not np.all(voxel_sizes == dwi.header.get_zooms()[0]):
            print("* Careful voxel are anisotropic {}!".format(
                tuple(voxel_sizes)))

        if args.step_size is not None:
            # Since we are tracking in diffusion voxel space, convert step_size (in mm) to voxel.
            step_size = np.float32(args.step_size / voxel_sizes.max())
            # Max number of points: both lengths are in mm, so the ratio is unitless.
            max_nb_points = int(np.ceil(args.max_length / args.step_size))
        else:
            step_size = None
            max_nb_points = args.max_length

        if args.theta is not None:
            theta = np.deg2rad(args.theta)
        elif args.curvature is not None and args.curvature > 0:
            theta = get_max_angle_from_curvature(args.curvature, step_size)
        else:
            theta = np.deg2rad(45)

        print("Angle: {}".format(np.rad2deg(theta)))
        print("Step size (vox): {}".format(step_size))
        print("Max nb. points: {}".format(max_nb_points))

        is_outside_mask = make_is_outside_mask(mask,
                                               affine_maskvox2dwivox,
                                               threshold=args.mask_threshold)
        is_too_long = make_is_too_long(max_nb_points)
        is_too_curvy = make_is_too_curvy(np.rad2deg(theta))
        is_unlikely = make_is_unlikely(0.5)
        is_stopping = make_is_stopping({
            STOPPING_MASK: is_outside_mask,
            STOPPING_LENGTH: is_too_long,
            STOPPING_CURVATURE: is_too_curvy,
            STOPPING_LIKELIHOOD: is_unlikely
        })

        is_stopping.max_nb_points = max_nb_points  # Small hack

        tractogram = batch_track(model,
                                 weights,
                                 seeds,
                                 step_size=step_size,
                                 is_stopping=is_stopping,
                                 batch_size=args.batch_size,
                                 args=args)

        # Streamlines have been generated in voxel space.
        # Transform them back to RAS+mm space using the dwi's affine.
        tractogram.affine_to_rasmm = dwi.affine
        tractogram.to_world()  # Performed in-place.

    nb_streamlines = len(tractogram)

    rejected_tractogram = None
    if args.save_rejected:
        rejected_tractogram = Tractogram()
        rejected_tractogram.affine_to_rasmm = tractogram.affine_to_rasmm

    print("Generated {:,} (compressed) streamlines".format(nb_streamlines))
    with Timer("Cleaning streamlines", newline=True):
        # Flush streamlines that have no points.
        if args.save_rejected:
            rejected_tractogram += tractogram[
                np.array(list(map(len, tractogram))) <= 0]

        tractogram = tractogram[np.array(list(map(len, tractogram))) > 0]
        print("Removed {:,} empty streamlines".format(nb_streamlines -
                                                      len(tractogram)))

        # Remove small streamlines
        nb_streamlines = len(tractogram)
        lengths = dipy.tracking.streamline.length(tractogram.streamlines)

        if args.save_rejected:
            rejected_tractogram += tractogram[lengths < args.min_length]

        tractogram = tractogram[lengths >= args.min_length]
        lengths = lengths[lengths >= args.min_length]
        if len(lengths) > 0:
            print("Average length: {:.2f} mm.".format(lengths.mean()))
            print("Minimum length: {:.2f} mm. Maximum length: {:.2f}".format(
                lengths.min(), lengths.max()))
        print("Removed {:,} streamlines smaller than {:.2f} mm".format(
            nb_streamlines - len(tractogram), args.min_length))

        if args.discard_stopped_by_curvature:
            nb_streamlines = len(tractogram)
            stopping_curvature_flag_is_set = is_flag_set(
                tractogram.data_per_streamline['stopping_flags'][:, 0],
                STOPPING_CURVATURE)

            if args.save_rejected:
                rejected_tractogram += tractogram[
                    stopping_curvature_flag_is_set]

            tractogram = tractogram[np.logical_not(
                stopping_curvature_flag_is_set)]
            print(
                "Removed {:,} streamlines stopped for having a curvature higher than {:.2f} degree"
                .format(nb_streamlines - len(tractogram), np.rad2deg(theta)))

        if args.filter_threshold is not None:
            # Remove streamlines whose reconstruction error is higher than the threshold.
            nb_streamlines = len(tractogram)
            losses = compute_loss_errors(tractogram.streamlines, model,
                                         hyperparams)
            if len(losses) > 0:
                # The standard error of the mean needs at least two samples.
                if len(losses) > 1:
                    sem = np.std(losses, ddof=1) / np.sqrt(len(losses))
                else:
                    sem = 0.0
                print("Mean loss: {:.4f} ± {:.4f}".format(np.mean(losses), sem))

            if args.save_rejected:
                rejected_tractogram += tractogram[
                    losses > args.filter_threshold]

            tractogram = tractogram[losses <= args.filter_threshold]
            print(
                "Removed {:,} streamlines producing a loss higher than {:.2f}"
                .format(nb_streamlines - len(tractogram),
                        args.filter_threshold))

    with Timer("Saving {:,} (compressed) streamlines".format(len(tractogram))):
        filename = args.out
        if filename is None:
            prefix = args.prefix
            if prefix is None:
                dwi_name = os.path.basename(args.dwi)
                if dwi_name.endswith(".nii.gz"):
                    dwi_name = dwi_name[:-7]
                else:  # .nii
                    dwi_name = dwi_name[:-4]

                prefix = os.path.basename(os.path.dirname(args.dwi)) + dwi_name
                prefix = prefix.replace(".", "_")

            # Only these settings are encoded in the generated filename.
            filename_items = [
                "{}",
                "useMaxComponent-{}",
                "step-{:.2f}mm",
                "nbSeeds-{}",
                "maxAngleDeg-{:.1f}",
            ]
            filename_values = (prefix,
                               args.use_max_component,
                               args.step_size,
                               args.nb_seeds_per_voxel,
                               np.rad2deg(theta))
            filename = ('_'.join(filename_items) + ".tck").format(*filename_values)

            # Rejected streamlines: same name with "rejected" after the prefix.
            rejected_items = filename_items[:1] + ["rejected"] + filename_items[1:]
            rejected_filename = ('_'.join(rejected_items) + ".tck").format(*filename_values)
        else:
            # User-provided filename: add a "_rejected" suffix before the extension.
            root, ext = os.path.splitext(filename)
            rejected_filename = root + "_rejected" + ext

        save_path = pjoin(experiment_path, filename)
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)  # Create dirs, if needed.

        print("Saving to {}".format(save_path))
        nib.streamlines.save(tractogram, save_path)

    if args.save_rejected:
        with Timer("Saving {:,} (compressed) rejected streamlines".format(
                len(rejected_tractogram))):
            rejected_save_path = pjoin(experiment_path, rejected_filename)
            rejected_save_dir = os.path.dirname(rejected_save_path)
            if rejected_save_dir:
                os.makedirs(rejected_save_dir, exist_ok=True)  # Create dirs, if needed.

            print("Saving rejected streamlines to {}".format(rejected_save_path))
            nib.streamlines.save(rejected_tractogram, rejected_save_path)
Ejemplo n.º 4
0
def main():
    """Track streamlines with a trained model and save the cleaned tractogram.

    Everything is driven by the arguments parsed by ``build_argparser``:

    1. Locate the experiment folder and its ``hyperparams.json`` (falling back
       to the parent folder).
    2. Load the DWI and its gradient table (``.bvals``/``.bvecs`` or
       ``.bval``/``.bvec`` next to the image). Convert the signal to spherical
       harmonic coefficients or a resampled DWI, according to
       ``hyperparams["use_sh_coeffs"]``.
    3. Restore the model selected by ``hyperparams["model"]``.
    4. Optionally load (and dilate) a tracking mask.
    5. Build seeds from the extremities of ``.trk``/``.tck`` files, or by
       jittering ``args.nb_seeds_per_voxel`` points uniformly inside every
       voxel of binary seeding masks.
    6. Track in DWI voxel space, then bring the streamlines back to RAS+mm.
    7. Drop empty and too-short streamlines. Optionally also drop
       curvature-stopped ones and those whose loss exceeds
       ``args.filter_threshold``.
    8. Save the result under the experiment folder.
    """
    parser = build_argparser()
    args = parser.parse_args()

    # Get experiment folder
    experiment_path = args.name
    if not os.path.isdir(experiment_path):
        # If not a directory, it must be the name of the experiment.
        experiment_path = pjoin(".", "experiments", args.name)

    if not os.path.isdir(experiment_path):
        parser.error('Cannot find experiment: {0}!'.format(args.name))

    # Load experiment hyperparameters, falling back to the parent folder.
    try:
        hyperparams = smartutils.load_dict_from_json_file(pjoin(experiment_path, "hyperparams.json"))
    except FileNotFoundError:
        hyperparams = smartutils.load_dict_from_json_file(pjoin(experiment_path, "..", "hyperparams.json"))

    with Timer("Loading DWIs"):
        # Load gradients table: it sits next to the DWI, with the same name.
        dwi_name = args.dwi
        if dwi_name.endswith(".gz"):
            dwi_name = dwi_name[:-3]
        if dwi_name.endswith(".nii"):
            dwi_name = dwi_name[:-4]

        try:
            bvals, bvecs = dipy.io.gradients.read_bvals_bvecs(dwi_name + ".bvals", dwi_name + ".bvecs")
        except FileNotFoundError:
            try:
                bvals, bvecs = dipy.io.gradients.read_bvals_bvecs(dwi_name + ".bval", dwi_name + ".bvec")
            except FileNotFoundError:
                print("Could not find .bvals/.bvecs or .bval/.bvec files...")
                raise

        dwi = nib.load(args.dwi)
        if hyperparams["use_sh_coeffs"]:
            # Use 45 spherical harmonic coefficients to represent the diffusion signal.
            weights = neurotools.get_spherical_harmonics_coefficients(dwi, bvals, bvecs).astype(np.float32)
        else:
            # Resample the diffusion signal to have 100 directions.
            weights = neurotools.resample_dwi(dwi, bvals, bvecs).astype(np.float32)

        affine_rasmm2dwivox = np.linalg.inv(dwi.affine)

    with Timer("Loading model"):
        if hyperparams["model"] == "gru_regression":
            from learn2track.models import GRU_Regression
            model_class = GRU_Regression
        elif hyperparams['model'] == 'gru_mixture':
            from learn2track.models import GRU_Mixture
            model_class = GRU_Mixture
        elif hyperparams['model'] == 'gru_multistep':
            from learn2track.models import GRU_Multistep_Gaussian
            model_class = GRU_Multistep_Gaussian
        elif hyperparams['model'] == 'ffnn_regression':
            from learn2track.models import FFNN_Regression
            model_class = FFNN_Regression
        else:
            raise ValueError("Unknown model!")

        volume_manager = neurotools.VolumeManager()
        volume_manager.register(weights)

        # Create a new instance and restore the trained model.
        model = model_class.create(pjoin(experiment_path), volume_manager=volume_manager)
        print(str(model))

    # Both stay None when no tracking mask is given.
    # NOTE(review): confirm make_is_outside_mask accepts a None mask/affine.
    mask = None
    affine_maskvox2dwivox = None
    if args.mask is not None:
        with Timer("Loading mask"):
            mask_nii = nib.load(args.mask)
            mask = mask_nii.get_data()
            # Compute the affine allowing to evaluate the mask at some coordinates correctly.
            # affine_maskvox2dwivox = mask_vox => rasmm space => dwi_vox
            affine_maskvox2dwivox = np.dot(affine_rasmm2dwivox, mask_nii.affine)
            if args.dilate_mask:
                import scipy.ndimage
                mask = scipy.ndimage.binary_dilation(mask).astype(mask.dtype)

    with Timer("Generating seeds"):
        seeds = []

        for filename in args.seeds:
            if filename.endswith(('.trk', '.tck')):
                tfile = nib.streamlines.load(filename)
                # Send the streamlines to voxel since that's where we'll track.
                tfile.tractogram.apply_affine(affine_rasmm2dwivox)

                # Use extremities of the streamlines as seeding points.
                seeds += [s[0] for s in tfile.streamlines]
                seeds += [s[-1] for s in tfile.streamlines]

            else:
                # Assume it is a binary mask.
                rng = np.random.RandomState(args.seeding_rng_seed)
                nii_seeds = nib.load(filename)

                # affine_seedsvox2dwivox = seeds_vox => rasmm space => dwi_vox
                affine_seedsvox2dwivox = np.dot(affine_rasmm2dwivox, nii_seeds.affine)

                indices = np.array(np.where(nii_seeds.get_data())).T
                for idx in indices:
                    # Jitter seeds uniformly within the voxel, then map them to DWI voxel space.
                    seeds_in_voxel = idx + rng.uniform(-0.5, 0.5, size=(args.nb_seeds_per_voxel, 3))
                    seeds_in_voxel = nib.affines.apply_affine(affine_seedsvox2dwivox, seeds_in_voxel)
                    seeds.extend(seeds_in_voxel)

        seeds = np.array(seeds, dtype=theano.config.floatX)

    with Timer("Tracking in the diffusion voxel space"):
        voxel_sizes = np.asarray(dwi.header.get_zooms()[:3])
        if not np.all(voxel_sizes == dwi.header.get_zooms()[0]):
            print("* Careful voxel are anisotropic {}!".format(tuple(voxel_sizes)))

        if args.step_size is not None:
            # Since we are tracking in diffusion voxel space, convert step_size (in mm) to voxel.
            step_size = np.float32(args.step_size / voxel_sizes.max())
            # Max number of points: both lengths are in mm so the ratio is unitless
            # (dividing by the voxel step size would mix mm and voxels).
            max_nb_points = int(np.ceil(args.max_length / args.step_size))
        else:
            step_size = None
            max_nb_points = args.max_length

        if args.theta is not None:
            theta = np.deg2rad(args.theta)
        elif args.curvature is not None and args.curvature > 0:
            theta = get_max_angle_from_curvature(args.curvature, step_size)
        else:
            theta = np.deg2rad(45)

        print("Angle: {}".format(np.rad2deg(theta)))
        print("Step size (vox): {}".format(step_size))
        print("Max nb. points: {}".format(max_nb_points))

        is_outside_mask = make_is_outside_mask(mask, affine_maskvox2dwivox, threshold=args.mask_threshold)
        is_too_long = make_is_too_long(max_nb_points)
        is_too_curvy = make_is_too_curvy(np.rad2deg(theta))
        is_stopping = make_is_stopping({STOPPING_MASK: is_outside_mask,
                                        STOPPING_LENGTH: is_too_long,
                                        STOPPING_CURVATURE: is_too_curvy})

        is_stopping.max_nb_points = max_nb_points  # Small hack

        tractogram = batch_track(model, weights, seeds,
                                 step_size=step_size,
                                 is_stopping=is_stopping,
                                 batch_size=args.batch_size,
                                 args=args)

        # Streamlines have been generated in voxel space.
        # Transform them back to RAS+mm space using the dwi's affine.
        tractogram.affine_to_rasmm = dwi.affine
        tractogram.to_world()  # Performed in-place.

    nb_streamlines = len(tractogram)
    print("Generated {:,} (compressed) streamlines".format(nb_streamlines))
    with Timer("Cleaning streamlines", newline=True):
        # Flush streamlines that have no points.
        tractogram = tractogram[np.array(list(map(len, tractogram))) > 0]
        print("Removed {:,} empty streamlines".format(nb_streamlines - len(tractogram)))

        # Remove small streamlines
        nb_streamlines = len(tractogram)
        lengths = dipy.tracking.streamline.length(tractogram.streamlines)
        tractogram = tractogram[lengths >= args.min_length]
        lengths = lengths[lengths >= args.min_length]
        if len(lengths) > 0:
            print("Average length: {:.2f} mm.".format(lengths.mean()))
            print("Minimum length: {:.2f} mm. Maximum length: {:.2f}".format(lengths.min(), lengths.max()))
        print("Removed {:,} streamlines smaller than {:.2f} mm".format(nb_streamlines - len(tractogram),
                                                                       args.min_length))
        if args.discard_stopped_by_curvature:
            nb_streamlines = len(tractogram)
            stopping_curvature_flag_is_set = is_flag_set(tractogram.data_per_streamline['stopping_flags'][:, 0],
                                                         STOPPING_CURVATURE)
            # Discard (not keep) the streamlines that were stopped by the curvature criterion.
            tractogram = tractogram[np.logical_not(stopping_curvature_flag_is_set)]
            print("Removed {:,} streamlines stopped for having a curvature higher than {:.2f} degree".format(
                nb_streamlines - len(tractogram), np.rad2deg(theta)))

        if args.filter_threshold is not None:
            # Remove streamlines whose reconstruction error is higher than the threshold.
            nb_streamlines = len(tractogram)
            losses = compute_loss_errors(tractogram.streamlines, model, hyperparams)
            if len(losses) > 0:
                # The standard error of the mean needs at least two samples.
                sem = np.std(losses, ddof=1) / np.sqrt(len(losses)) if len(losses) > 1 else 0.0
                print("Mean loss: {:.4f} ± {:.4f}".format(np.mean(losses), sem))
            tractogram = tractogram[losses <= args.filter_threshold]
            print("Removed {:,} streamlines producing a loss higher than {:.2f}".format(nb_streamlines - len(tractogram),
                                                                                        args.filter_threshold))

    with Timer("Saving {:,} (compressed) streamlines".format(len(tractogram))):
        filename = args.out
        if args.out is None:
            prefix = args.prefix
            if prefix is None:
                dwi_name = os.path.basename(args.dwi)
                if dwi_name.endswith(".nii.gz"):
                    dwi_name = dwi_name[:-7]
                else:  # .nii
                    dwi_name = dwi_name[:-4]

                prefix = os.path.basename(os.path.dirname(args.dwi)) + dwi_name
                prefix = prefix.replace(".", "_")

            # Path separators are replaced so the seed path doesn't create sub-folders.
            mask_type = args.seeds[0].replace(".", "_").replace("_", "").replace("/", "-")
            if "int" in args.seeds[0]:
                mask_type = "int"
            elif "wm" in args.seeds[0]:
                mask_type = "wm"
            elif "rois" in args.seeds[0]:
                mask_type = "rois"

            filename = "{}-{}_seeding-{}_step-{:.2f}mm_nbSeeds-{}_maxAngle-{:.1f}deg_keepCurv-{}_filtered-{}_minLen-{}_pftRetry-{}_pftHist-{}_useMaxComponent-{}.tck".format(
                os.path.basename(args.name.rstrip('/'))[:6],
                prefix,
                mask_type,
                args.step_size,
                args.nb_seeds_per_voxel,
                np.rad2deg(theta),
                not args.discard_stopped_by_curvature,
                args.filter_threshold,
                args.min_length,
                args.pft_nb_retry,
                args.pft_nb_backtrack_steps,
                args.use_max_component
                )

        save_path = pjoin(experiment_path, filename)
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)  # Create dirs, if needed.

        print("Saving to {}".format(save_path))
        nib.streamlines.save(tractogram, save_path)
Ejemplo n.º 5
0
def load_tractography_dataset_from_dwi_and_tractogram(dwi,
                                                      tractogram,
                                                      volume_manager,
                                                      use_sh_coeffs=False,
                                                      bvals=None,
                                                      bvecs=None,
                                                      step_size=None,
                                                      mean_centering=True):
    """Build a one-subject ``TractographyDataset`` from a DWI and a tractogram.

    Parameters
    ----------
    dwi : str
        Path to the diffusion-weighted image, loaded with ``nib.load``.
    tractogram : str
        Path to a streamline file, loaded with ``nib.streamlines.load``.
    volume_manager : object
        Object exposing ``register(volume)``. The returned id is stored as
        the subject's ``subject_id``.
    use_sh_coeffs : bool
        Represent the signal with spherical harmonic coefficients instead of
        a resampled DWI.
    bvals, bvecs : str, optional
        Gradient files. By default they are the DWI path with its
        ``.nii.gz``/``.nii``/``.gz`` extension replaced by ``.bvals`` /
        ``.bvecs``.
    step_size : float, optional
        If given, every streamline is resampled to equally spaced points whose
        spacing does not exceed ``step_size`` (in mm, per the log message).
        Only the streamlines are kept in that case.
    mean_centering : bool
        Forwarded to the neurotools signal conversion.

    Returns
    -------
    TractographyDataset
        Dataset named "dataset" holding the single subject. Its streamlines
        are in the DWI voxel space.
    """
    # Load signal
    signal = nib.load(dwi)
    signal.get_data()  # Forces loading volume in-memory.
    # Strip the image extension to locate the gradient files next to the DWI.
    basename = re.sub(r'(\.nii\.gz|\.nii|\.gz)$', '', dwi)
    if bvals is None:
        bvals = basename + '.bvals'
    if bvecs is None:
        bvecs = basename + '.bvecs'

    gradients = gradient_table(bvals, bvecs)
    tracto_data = TractographyData(signal, gradients)

    # Load streamlines
    tractogram = nib.streamlines.load(tractogram).tractogram

    # Resample streamlines to have a fixed step size, if needed.
    if step_size is not None:
        print("Resampling streamlines to have a step size of {}mm".format(
            step_size))
        streamlines = tractogram.streamlines
        # Cast the internal index arrays to a platform int before using them.
        streamlines._lengths = streamlines._lengths.astype(int)
        streamlines._offsets = streamlines._offsets.astype(int)
        lengths = length(streamlines)
        # N points make N - 1 segments, so ceil(L / step) + 1 points give
        # segments no longer than `step_size`. Keep at least 2 points so
        # zero-length streamlines are still resampled to a valid polyline.
        nb_points = np.maximum(np.ceil(lengths / step_size).astype(int) + 1, 2)
        new_streamlines = (set_number_of_points(s, n)
                           for s, n in zip(streamlines, nb_points))
        tractogram = nib.streamlines.Tractogram(new_streamlines,
                                                affine_to_rasmm=np.eye(4))

    # Compute matrix that brings streamlines back to diffusion voxel space.
    rasmm2vox_affine = np.linalg.inv(signal.affine)
    tractogram.apply_affine(rasmm2vox_affine)

    # Add streamlines to the TractographyData
    tracto_data.add(tractogram.streamlines, "tractogram")

    dwi_signal = tracto_data.signal
    data_bvals = tracto_data.gradients.bvals
    data_bvecs = tracto_data.gradients.bvecs

    if use_sh_coeffs:
        # Use 45 spherical harmonic coefficients to represent the diffusion signal.
        volume = neurotools.get_spherical_harmonics_coefficients(
            dwi_signal, data_bvals, data_bvecs,
            mean_centering=mean_centering).astype(np.float32)
    else:
        # Resample the diffusion signal to have 100 directions.
        volume = neurotools.resample_dwi(dwi_signal,
                                         data_bvals,
                                         data_bvecs,
                                         mean_centering=mean_centering).astype(np.float32)

    tracto_data.signal.uncache()  # Free some memory as we don't need the original signal.
    tracto_data.subject_id = volume_manager.register(volume)

    return TractographyDataset([tracto_data], "dataset", keep_on_cpu=True)
Ejemplo n.º 6
0
def _predict_in_batches(predict_fn, inputs, batch_size):
    """Apply ``predict_fn`` to ``inputs`` in batches, halving the batch size on failure.

    Parameters
    ----------
    predict_fn : callable
        Compiled prediction function; ``predict_fn(batch)[-1]`` must be an
        iterable of per-row predictions for ``batch``.
    inputs : ndarray
        Rows to predict on (sliced along the first axis).
    batch_size : int
        First batch size to try. On ``MemoryError`` / ``RuntimeError`` (treated
        as out-of-memory) it is halved and the whole pass restarts, down to a
        batch size of 1.

    Returns
    -------
    list
        One prediction per row of ``inputs``, or an empty list if no batch
        size (>= 1) succeeded.
    """
    nb_inputs = len(inputs)
    while batch_size >= 1:
        print("Trying to to process batches of size {} out of {}".format(
            batch_size, nb_inputs))
        nb_batches = (nb_inputs + batch_size - 1) // batch_size  # ceil division
        # Restart from an empty list on every attempt so that batches finished
        # before a failure are not duplicated by the retry.
        probs = []
        try:
            for batch_count in range(nb_batches):
                start = batch_count * batch_size
                end = start + batch_size
                probs.extend(predict_fn(inputs[start:end])[-1])
                print("Generated batch {} out of {}".format(
                    batch_count + 1, nb_batches))
        except (MemoryError, RuntimeError):
            # Out-of-memory may surface as either exception type; retry smaller.
            print("{} coordinates at the same time is too much!".format(
                batch_size))
            batch_size //= 2
        else:
            return probs
    return []


def main():
    """Generate a binary mask for a DWI volume with a trained mask classifier.

    Resolves the experiment folder from ``args.name`` (either a path or a
    name under ``./experiments``). It then loads ``hyperparams.json`` from that
    folder or its parent. The DWI and its gradient table (``<dwi>.bvals`` /
    ``<dwi>.bvecs``) are loaded and converted to spherical-harmonic
    coefficients or a resampled signal, depending on ``use_sh_coeffs``.
    Every voxel is classified with the restored model, and voxels whose
    prediction is > 0.5 are set to 1. The result is saved as a NIfTI image
    (with the DWI affine) inside the experiment folder.
    """
    parser = build_argparser()
    args = parser.parse_args()

    # Get experiment folder
    experiment_path = args.name
    if not os.path.isdir(experiment_path):
        # If not a directory, it must be the name of the experiment.
        experiment_path = pjoin(".", "experiments", args.name)

    if not os.path.isdir(experiment_path):
        parser.error('Cannot find experiment: {0}!'.format(args.name))

    # Load experiments hyperparameters (fall back to the parent folder's file).
    try:
        hyperparams = smartutils.load_dict_from_json_file(
            pjoin(experiment_path, "hyperparams.json"))
    except FileNotFoundError:
        hyperparams = smartutils.load_dict_from_json_file(
            pjoin(experiment_path, "..", "hyperparams.json"))

    with Timer("Loading DWIs"):
        # Gradients table lives next to the DWI: strip ".gz" then ".nii".
        dwi_name = args.dwi
        if dwi_name.endswith(".gz"):
            dwi_name = dwi_name[:-3]
        if dwi_name.endswith(".nii"):
            dwi_name = dwi_name[:-4]
        bvals_filename = dwi_name + ".bvals"
        bvecs_filename = dwi_name + ".bvecs"
        bvals, bvecs = dipy.io.gradients.read_bvals_bvecs(
            bvals_filename, bvecs_filename)

        dwi = nib.load(args.dwi)
        if hyperparams["use_sh_coeffs"]:
            # Use 45 spherical harmonic coefficients to represent the diffusion signal.
            weights = neurotools.get_spherical_harmonics_coefficients(
                dwi, bvals, bvecs).astype(np.float32)
        else:
            # Resample the diffusion signal to have 100 directions.
            weights = neurotools.resample_dwi(dwi, bvals,
                                              bvecs).astype(np.float32)

    with Timer("Loading model"):
        if hyperparams["model"] == "ffnn_classification":
            from learn2track.models import FFNN_Classification
            model_class = FFNN_Classification
        else:
            raise ValueError("Unknown model!")

        kwargs = {}
        volume_manager = neurotools.VolumeManager()
        volume_manager.register(weights)
        kwargs['volume_manager'] = volume_manager

        # Load the actual model.
        model = model_class.create(
            pjoin(experiment_path),
            **kwargs)  # Create new instance and restore model.
        print(str(model))

    with Timer("Generating mask"):
        symb_input = T.matrix(name="input")
        model_symb_pred = model.get_output(symb_input)
        f = theano.function(inputs=[symb_input], outputs=[model_symb_pred])

        generated_mask = np.zeros(dwi.shape[:3], dtype=np.float32)

        # Every voxel is classified: all_coords.shape = (n_coords, 3)
        all_coords = np.argwhere(generated_mask == 0)

        # Append a volume-id column; assumes the single volume registered
        # above received id 0 (TODO confirm against VolumeManager.register).
        volume_ids = np.zeros((all_coords.shape[0], 1))
        all_coords_and_volume_ids = np.concatenate((all_coords, volume_ids),
                                                   axis=1).astype(np.float32)

        batch_size = args.batch_size if args.batch_size else len(
            all_coords_and_volume_ids)
        probs = _predict_in_batches(f, all_coords_and_volume_ids, batch_size)
        if not probs:
            raise RuntimeError("Could not generate predictions...")

        generated_mask[np.where(generated_mask == 0)] = np.array(probs) > 0.5

    with Timer("Saving generated mask"):
        filename = args.out
        if args.out is None:
            prefix = args.prefix
            if prefix is None:
                dwi_name = os.path.basename(args.dwi)
                if dwi_name.endswith(".nii.gz"):
                    dwi_name = dwi_name[:-7]
                else:  # .nii
                    dwi_name = dwi_name[:-4]

                prefix = os.path.basename(os.path.dirname(args.dwi)) + dwi_name
                prefix = prefix.replace(".", "_")

            filename = "{}.nii.gz".format(prefix)

        save_path = pjoin(experiment_path, filename)
        save_dir = os.path.dirname(save_path)
        if save_dir:
            # Create parent dirs if needed (e.g. when --out has sub-folders);
            # real failures such as permission errors are no longer hidden.
            os.makedirs(save_dir, exist_ok=True)

        print("Saving to {}".format(save_path))
        mask = nib.Nifti1Image(generated_mask, dwi.affine)
        nib.save(mask, save_path)