def main():
    """Plot, per subject, the mean resampled diffusion signal for each direction.

    For every subject file given on the command line, the DWI stored in the
    ``neurotools.TractographyData`` file is resampled with
    ``neurotools.resample_dwi``. The mean over all voxels whose summed signal
    is non-zero is plotted per direction, with a shaded band of
    +/- 0.1 standard deviation around it.
    """
    import itertools

    parser = build_argparser()
    args = parser.parse_args()

    fig, ax = plt.subplots(1)
    colors = ['blue', 'orange', 'magenta', 'pink', 'darkgreen']
    # Cycle through the palette: zipping against the bare 5-element list
    # silently dropped every subject after the fifth.
    for subject_file, color in zip(args.subjects, itertools.cycle(colors)):
        # Drop the 4-character file extension to get a display name.
        subject_id = os.path.basename(subject_file)[:-4]
        print("Loading {}...".format(subject_id))
        tracto_data = neurotools.TractographyData.load(subject_file)

        dwi = tracto_data.signal
        bvals = tracto_data.gradients.bvals
        bvecs = tracto_data.gradients.bvecs
        volume = neurotools.resample_dwi(dwi, bvals, bvecs).astype(np.float32)

        # Only consider voxels that carry some signal; index them once.
        idx = volume.sum(axis=-1).nonzero()
        voxels = volume[idx]
        means = voxels.mean(axis=0)
        # NOTE: the band is 0.1 standard deviation, not a full one.
        stds = 0.1 * voxels.std(axis=0)

        t = np.arange(1, len(means) + 1)
        ax.plot(t, means, lw=2, label="mean {}".format(subject_id), color=color)
        ax.fill_between(t, means + stds, means - stds, facecolor=color, alpha=0.5)

    plt.legend()
    plt.show()
def load_mask_classifier_dataset(subject_files, volume_manager, name="HCP", use_sh_coeffs=False):
    """Load mask-classifier subjects and register their diffusion volumes.

    Parameters
    ----------
    subject_files : iterable of str
        Paths readable by ``MaskClassifierData.load``; processed in sorted order.
    volume_manager : object
        Receives each processed volume through ``register``; the id it returns
        is stored on the subject as ``subject_id``.
    name : str
        Name passed to the resulting dataset.
    use_sh_coeffs : bool
        If True, represent the signal with spherical harmonic coefficients,
        otherwise resample it with ``neurotools.resample_dwi``.

    Returns
    -------
    MaskClassifierDataset
        Dataset over all loaded subjects, built with ``keep_on_cpu=True``.
    """
    loaded = []
    with Timer(" Loading subject(s)", newline=True):
        for path in sorted(subject_files):
            print(" {}".format(path))
            subject = MaskClassifierData.load(path)

            if use_sh_coeffs:
                # 45 spherical harmonic coefficients represent the diffusion signal.
                convert = neurotools.get_spherical_harmonics_coefficients
            else:
                # The diffusion signal is resampled to 100 directions.
                convert = neurotools.resample_dwi
            raw_signal = subject.signal
            grads = subject.gradients
            volume = convert(raw_signal, grads.bvals, grads.bvecs).astype(np.float32)

            # The raw signal is not needed after conversion; release its cache.
            subject.signal.uncache()
            subject.subject_id = volume_manager.register(volume)
            loaded.append(subject)

    return MaskClassifierDataset(loaded, name, keep_on_cpu=True)
def make_dummy_dataset(volume_manager, nb_subjects=3, seed=1234):
    """Build a small random ``TractographyDataset`` for tests.

    Each subject gets a random-size 4D image (spatial dims drawn in [5, 30),
    64 gradients), a gradient table with one b=0 and b=1000 elsewhere, and
    7 bundles of random streamlines centred on the volume. The resampled
    signal of every subject is registered in ``volume_manager``.
    All randomness comes from a single ``RandomState(seed)``.
    """
    rng = np.random.RandomState(seed)
    nb_bundles = 7
    nb_gradients = 64

    def _draw_bundle(center):
        # The bundle size is drawn first, then each streamline's length and points.
        nb_streamlines = rng.randint(5, 30)
        bundle = []
        for _ in range(nb_streamlines):
            nb_points = rng.randint(5, 100)
            bundle.append(rng.randn(nb_points, 3) * 5 + center)
        return bundle

    all_subjects = []
    for _ in range(nb_subjects):
        spatial = [rng.randint(5, 30) for _ in range(3)]
        volume_shape = np.array(spatial + [nb_gradients])

        image = nib.Nifti1Image(rng.rand(*volume_shape), affine=np.eye(4))
        bvals = [0] + [1000] * (nb_gradients - 1)
        directions = rng.randn(nb_gradients, 3)
        norms = np.sqrt(np.sum(directions**2, axis=1, keepdims=True))
        directions /= norms
        gtab = gradient_table(bvals, directions)
        resampled = neurotools.resample_dwi(image, gtab.bvals, gtab.bvecs).astype(np.float32)

        subject = neurotools.TractographyData(image, gtab)
        center = volume_shape[:3] / 2.
        for bundle_id in range(nb_bundles):
            subject.add(_draw_bundle(center), "bundle_{}".format(bundle_id))

        subject.subject_id = volume_manager.register(resampled)
        all_subjects.append(subject)

    return datasets.TractographyDataset(all_subjects, name="test", keep_on_cpu=True)
def main():
    """Plot, per subject, the mean resampled diffusion signal for each direction.

    For every subject file given on the command line, the DWI stored in the
    ``neurotools.TractographyData`` file is resampled with
    ``neurotools.resample_dwi``. The mean over all voxels whose summed signal
    is non-zero is plotted per direction, with a shaded band of
    +/- 0.1 standard deviation around it.
    """
    import itertools

    parser = build_argparser()
    args = parser.parse_args()

    fig, ax = plt.subplots(1)
    colors = ['blue', 'orange', 'magenta', 'pink', 'darkgreen']
    # Cycle through the palette: zipping against the bare 5-element list
    # silently dropped every subject after the fifth.
    for subject_file, color in zip(args.subjects, itertools.cycle(colors)):
        # Drop the 4-character file extension to get a display name.
        subject_id = os.path.basename(subject_file)[:-4]
        print("Loading {}...".format(subject_id))
        tracto_data = neurotools.TractographyData.load(subject_file)

        dwi = tracto_data.signal
        bvals = tracto_data.gradients.bvals
        bvecs = tracto_data.gradients.bvecs
        volume = neurotools.resample_dwi(dwi, bvals, bvecs).astype(np.float32)

        # Only consider voxels that carry some signal; index them once.
        idx = volume.sum(axis=-1).nonzero()
        voxels = volume[idx]
        means = voxels.mean(axis=0)
        # NOTE: the band is 0.1 standard deviation, not a full one.
        stds = 0.1 * voxels.std(axis=0)

        t = np.arange(1, len(means) + 1)
        ax.plot(t, means, lw=2, label="mean {}".format(subject_id), color=color)
        ax.fill_between(t, means + stds, means - stds, facecolor=color, alpha=0.5)

    plt.legend()
    plt.show()
def load_tractography_dataset(subject_files, volume_manager, name="HCP", use_sh_coeffs=False):
    """Load tractography subjects and register their diffusion volumes.

    Parameters
    ----------
    subject_files : iterable of str
        Paths readable by ``TractographyData.load``; processed in sorted order.
    volume_manager : object
        Receives each processed volume through ``register``; the id it returns
        is stored on the subject as ``subject_id``.
    name : str
        Name passed to the resulting dataset.
    use_sh_coeffs : bool
        If True, represent the signal with spherical harmonic coefficients,
        otherwise resample it with ``neurotools.resample_dwi``.

    Returns
    -------
    TractographyDataset
        Dataset over all loaded subjects, built with ``keep_on_cpu=True``.
    """
    collected = []
    with Timer(" Loading subject(s)", newline=True):
        for filepath in sorted(subject_files):
            print(" {}".format(filepath))
            data = TractographyData.load(filepath)

            signal = data.signal
            gtab = data.gradients
            if not use_sh_coeffs:
                # The diffusion signal is resampled to 100 directions.
                volume = neurotools.resample_dwi(signal, gtab.bvals, gtab.bvecs)
            else:
                # 45 spherical harmonic coefficients represent the diffusion signal.
                volume = neurotools.get_spherical_harmonics_coefficients(signal, gtab.bvals, gtab.bvecs)
            volume = volume.astype(np.float32)

            # Drop the cached raw signal: only the converted volume is kept.
            data.signal.uncache()
            data.subject_id = volume_manager.register(volume)
            collected.append(data)

    return TractographyDataset(collected, name, keep_on_cpu=True)
def make_dummy_dataset(volume_manager, nb_subjects=3, seed=1234):
    """Build a small random ``TractographyDataset`` for tests.

    Each subject gets a dummy DWI from ``make_dummy_dwi`` (64 gradients) and
    7 bundles of random streamlines centred on the volume. The resampled
    signal of every subject is registered in ``volume_manager``.

    Parameters
    ----------
    volume_manager : object
        Receives each subject's resampled volume through ``register``.
    nb_subjects : int
        Number of subjects to generate.
    seed : int
        Seed for the streamline RNG. Subject ``i`` gets its DWI from ``seed + i``.
    """
    rng = np.random.RandomState(seed)
    nb_bundles = 7
    nb_gradients = 64

    subjects = []
    for subject_idx in range(nb_subjects):
        # Offset the seed per subject: passing the same ``seed`` to every call
        # gave all subjects the exact same DWI. Subject 0 is unchanged.
        dwi, gradients = make_dummy_dwi(nb_gradients, seed=seed + subject_idx)
        volume = neurotools.resample_dwi(dwi, gradients.bvals, gradients.bvecs).astype(np.float32)
        volume_shape = np.array(dwi.shape)

        tracto_data = neurotools.TractographyData(dwi, gradients)

        for bundle_id in range(nb_bundles):
            streamlines = [rng.randn(rng.randint(5, 100), 3) * 5 + volume_shape[:3] / 2.
                           for _ in range(rng.randint(5, 30))]
            tracto_data.add(streamlines, "bundle_{}".format(bundle_id))

        tracto_data.subject_id = volume_manager.register(volume)
        subjects.append(tracto_data)

    return datasets.TractographyDataset(subjects, name="test", keep_on_cpu=True)
def main():
    """Track streamlines with a trained learn2track model and save them as .tck.

    Steps:

    1. Load the experiment hyperparameters, the DWI and its gradient table.
    2. Build the model input: SH coefficients or a resampled signal,
       depending on ``hyperparams["use_sh_coeffs"]``.
    3. Restore the model.
    4. Generate seeds, either from streamline extremities (.trk/.tck) or by
       uniform jitter in binary seeding masks.
    5. Track in DWI voxel space, then convert the result to RAS+mm.
    6. Filter out empty and too-short streamlines, and optionally those
       stopped by curvature or with a high model loss.
    7. Save the kept streamlines, and optionally the rejected ones.
    """
    parser = build_argparser()
    args = parser.parse_args()

    # Get experiment folder
    experiment_path = args.name
    if not os.path.isdir(experiment_path):
        # If not a directory, it must be the name of the experiment.
        experiment_path = pjoin(".", "experiments", args.name)

    if not os.path.isdir(experiment_path):
        parser.error('Cannot find experiment: {0}!'.format(args.name))

    # Load experiments hyperparameters
    try:
        hyperparams = smartutils.load_dict_from_json_file(pjoin(experiment_path, "hyperparams.json"))
    except FileNotFoundError:
        hyperparams = smartutils.load_dict_from_json_file(pjoin(experiment_path, "..", "hyperparams.json"))

    with Timer("Loading DWIs"):
        # Load gradients table
        dwi_name = args.dwi
        if dwi_name.endswith(".gz"):
            dwi_name = dwi_name[:-3]
        if dwi_name.endswith(".nii"):
            dwi_name = dwi_name[:-4]

        try:
            bvals_filename = dwi_name + ".bvals"
            bvecs_filename = dwi_name + ".bvecs"
            bvals, bvecs = dipy.io.gradients.read_bvals_bvecs(bvals_filename, bvecs_filename)
        except FileNotFoundError:
            try:
                bvals_filename = dwi_name + ".bval"
                bvecs_filename = dwi_name + ".bvec"
                bvals, bvecs = dipy.io.gradients.read_bvals_bvecs(bvals_filename, bvecs_filename)
            except FileNotFoundError:
                print("Could not find .bvals/.bvecs or .bval/.bvec files...")
                raise

        dwi = nib.load(args.dwi)
        if hyperparams["use_sh_coeffs"]:
            # Use 45 spherical harmonic coefficients to represent the diffusion signal.
            weights = neurotools.get_spherical_harmonics_coefficients(dwi, bvals, bvecs).astype(np.float32)
        else:
            # Resample the diffusion signal to have 100 directions.
            weights = neurotools.resample_dwi(dwi, bvals, bvecs).astype(np.float32)

    affine_rasmm2dwivox = np.linalg.inv(dwi.affine)

    with Timer("Loading model"):
        if hyperparams["model"] == "gru_regression":
            from learn2track.models import GRU_Regression
            model_class = GRU_Regression
        elif hyperparams['model'] == 'gru_gaussian':
            from learn2track.models import GRU_Gaussian
            model_class = GRU_Gaussian
        elif hyperparams['model'] == 'gru_mixture':
            from learn2track.models import GRU_Mixture
            model_class = GRU_Mixture
        elif hyperparams['model'] == 'gru_multistep':
            from learn2track.models import GRU_Multistep_Gaussian
            model_class = GRU_Multistep_Gaussian
        elif hyperparams['model'] == 'ffnn_regression':
            from learn2track.models import FFNN_Regression
            model_class = FFNN_Regression
        else:
            raise ValueError("Unknown model!")

        kwargs = {}
        volume_manager = neurotools.VolumeManager()
        volume_manager.register(weights)
        kwargs['volume_manager'] = volume_manager

        # Load the actual model.
        model = model_class.create(pjoin(experiment_path), **kwargs)  # Create new instance and restore model.
        model.drop_prob = 0.  # Presumably disables dropout for inference.
        print(str(model))

    mask = None
    # Initialised so the stopping criterion below does not raise a NameError
    # when no --mask is given.
    affine_maskvox2dwivox = None
    if args.mask is not None:
        with Timer("Loading mask"):
            mask_nii = nib.load(args.mask)
            # Same array get_data() returned, without the deprecated API.
            mask = np.asanyarray(mask_nii.dataobj)
            # Compute the affine allowing to evaluate the mask at some coordinates correctly.

            # affine_maskvox2dwivox = mask_vox => rasmm space => dwi_vox
            affine_maskvox2dwivox = np.dot(affine_rasmm2dwivox, mask_nii.affine)
            if args.dilate_mask:
                # A bare `import scipy` does not guarantee scipy.ndimage is loaded.
                import scipy.ndimage
                mask = scipy.ndimage.binary_dilation(mask).astype(mask.dtype)

    with Timer("Generating seeds"):
        seeds = []

        for filename in args.seeds:
            if filename.endswith(('.trk', '.tck')):
                tfile = nib.streamlines.load(filename)
                # Send the streamlines to voxel since that's where we'll track.
                tfile.tractogram.apply_affine(affine_rasmm2dwivox)

                # Use extremities of the streamlines as seeding points.
                seeds += [s[0] for s in tfile.streamlines]
                seeds += [s[-1] for s in tfile.streamlines]

            else:
                # Assume it is a binary mask.
                rng = np.random.RandomState(args.seeding_rng_seed)
                nii_seeds = nib.load(filename)

                # affine_seedsvox2dwivox = mask_vox => rasmm space => dwi_vox
                affine_seedsvox2dwivox = np.dot(affine_rasmm2dwivox, nii_seeds.affine)

                nii_seeds_data = np.asanyarray(nii_seeds.dataobj)

                if args.dilate_seeding_mask:
                    import scipy.ndimage
                    nii_seeds_data = scipy.ndimage.binary_dilation(nii_seeds_data).astype(nii_seeds_data.dtype)

                # Jitter nb_seeds_per_voxel points uniformly inside each seeding voxel.
                indices = np.array(np.where(nii_seeds_data)).T
                for idx in indices:
                    seeds_in_voxel = idx + rng.uniform(-0.5, 0.5, size=(args.nb_seeds_per_voxel, 3))
                    seeds_in_voxel = nib.affines.apply_affine(affine_seedsvox2dwivox, seeds_in_voxel)
                    seeds.extend(seeds_in_voxel)

        seeds = np.array(seeds, dtype=theano.config.floatX)

    with Timer("Tracking in the diffusion voxel space"):
        voxel_sizes = np.asarray(dwi.header.get_zooms()[:3])
        if not np.all(voxel_sizes == dwi.header.get_zooms()[0]):
            print("* Careful voxel are anisotropic {}!".format(tuple(voxel_sizes)))

        # Since we are tracking in diffusion voxel space, convert step_size (in mm) to voxel.
        if args.step_size is not None:
            step_size = np.float32(args.step_size / voxel_sizes.max())
            # Max number of points = max length (mm) / step size (mm).
            max_nb_points = int(np.ceil(args.max_length / args.step_size))
        else:
            step_size = None
            max_nb_points = args.max_length

        if args.theta is not None:
            theta = np.deg2rad(args.theta)
        elif args.curvature is not None and args.curvature > 0:
            theta = get_max_angle_from_curvature(args.curvature, step_size)
        else:
            theta = np.deg2rad(45)

        print("Angle: {}".format(np.rad2deg(theta)))
        print("Step size (vox): {}".format(step_size))
        print("Max nb. points: {}".format(max_nb_points))

        is_outside_mask = make_is_outside_mask(mask, affine_maskvox2dwivox, threshold=args.mask_threshold)
        is_too_long = make_is_too_long(max_nb_points)
        is_too_curvy = make_is_too_curvy(np.rad2deg(theta))
        is_unlikely = make_is_unlikely(0.5)
        is_stopping = make_is_stopping({STOPPING_MASK: is_outside_mask,
                                        STOPPING_LENGTH: is_too_long,
                                        STOPPING_CURVATURE: is_too_curvy,
                                        STOPPING_LIKELIHOOD: is_unlikely})

        is_stopping.max_nb_points = max_nb_points  # Small hack

        tractogram = batch_track(model, weights, seeds,
                                 step_size=step_size,
                                 is_stopping=is_stopping,
                                 batch_size=args.batch_size,
                                 args=args)

        # Streamlines have been generated in voxel space.
        # Transform them back to RAS+mm space using the dwi's affine.
        tractogram.affine_to_rasmm = dwi.affine
        tractogram.to_world()  # Performed in-place.

    nb_streamlines = len(tractogram)

    if args.save_rejected:
        rejected_tractogram = Tractogram()
        rejected_tractogram.affine_to_rasmm = tractogram.affine_to_rasmm

    print("Generated {:,} (compressed) streamlines".format(nb_streamlines))
    with Timer("Cleaning streamlines", newline=True):
        # Flush streamlines that have no points.
        nb_points = np.array(list(map(len, tractogram)))
        if args.save_rejected:
            rejected_tractogram += tractogram[nb_points <= 0]
        tractogram = tractogram[nb_points > 0]
        print("Removed {:,} empty streamlines".format(nb_streamlines - len(tractogram)))

        # Remove small streamlines
        nb_streamlines = len(tractogram)
        lengths = dipy.tracking.streamline.length(tractogram.streamlines)
        if args.save_rejected:
            rejected_tractogram += tractogram[lengths < args.min_length]
        tractogram = tractogram[lengths >= args.min_length]
        lengths = lengths[lengths >= args.min_length]
        if len(lengths) > 0:
            print("Average length: {:.2f} mm.".format(lengths.mean()))
            print("Minimum length: {:.2f} mm. Maximum length: {:.2f}".format(lengths.min(), lengths.max()))
        print("Removed {:,} streamlines smaller than {:.2f} mm".format(nb_streamlines - len(tractogram),
                                                                        args.min_length))

        if args.discard_stopped_by_curvature:
            nb_streamlines = len(tractogram)
            stopping_curvature_flag_is_set = is_flag_set(tractogram.data_per_streamline['stopping_flags'][:, 0],
                                                         STOPPING_CURVATURE)
            if args.save_rejected:
                rejected_tractogram += tractogram[stopping_curvature_flag_is_set]
            tractogram = tractogram[np.logical_not(stopping_curvature_flag_is_set)]
            print("Removed {:,} streamlines stopped for having a curvature higher than {:.2f} degree".format(
                nb_streamlines - len(tractogram), np.rad2deg(theta)))

        if args.filter_threshold is not None:
            # Remove streamlines that produce a reconstruction error higher than a certain threshold.
            nb_streamlines = len(tractogram)
            losses = compute_loss_errors(tractogram.streamlines, model, hyperparams)
            print("Mean loss: {:.4f} ± {:.4f}".format(np.mean(losses),
                                                      np.std(losses, ddof=1) / np.sqrt(len(losses))))
            if args.save_rejected:
                rejected_tractogram += tractogram[losses > args.filter_threshold]
            tractogram = tractogram[losses <= args.filter_threshold]
            # The removed streamlines are those whose loss is ABOVE the threshold.
            print("Removed {:,} streamlines producing a loss higher than {:.2f}".format(
                nb_streamlines - len(tractogram), args.filter_threshold))

    with Timer("Saving {:,} (compressed) streamlines".format(len(tractogram))):
        filename = args.out
        if args.out is None:
            prefix = args.prefix
            if prefix is None:
                dwi_name = os.path.basename(args.dwi)
                if dwi_name.endswith(".nii.gz"):
                    dwi_name = dwi_name[:-7]
                else:  # .nii
                    dwi_name = dwi_name[:-4]

                prefix = os.path.basename(os.path.dirname(args.dwi)) + dwi_name
                prefix = prefix.replace(".", "_")

            # seed_mask_type / mask_type are currently unused: their filename
            # items below are commented out.
            seed_mask_type = args.seeds[0].replace(".", "_").replace("_", "").replace("/", "-")
            if "int" in args.seeds[0]:
                seed_mask_type = "int"
            elif "wm" in args.seeds[0]:
                seed_mask_type = "wm"
            elif "rois" in args.seeds[0]:
                seed_mask_type = "rois"
            elif "bundles" in args.seeds[0]:
                seed_mask_type = "bundles"

            # --mask is optional; `"fa" in None` would raise a TypeError.
            mask_path = args.mask if args.mask is not None else ""
            mask_type = ""
            if "fa" in mask_path:
                mask_type = "fa"
            elif "wm" in mask_path:
                mask_type = "wm"

            if args.dilate_seeding_mask:
                seed_mask_type += "D"

            if args.dilate_mask:
                mask_type += "D"

            filename_items = ["{}",
                              "useMaxComponent-{}",
                              # "seed-{}",
                              # "mask-{}",
                              "step-{:.2f}mm",
                              "nbSeeds-{}",
                              "maxAngleDeg-{:.1f}"
                              # "keepCurv-{}",
                              # "filtered-{}",
                              # "minLen-{}",
                              # "pftRetry-{}",
                              # "pftHist-{}",
                              # "trackLikePeter-{}",
                              ]
            # Values for the placeholders above, shared with the rejected filename.
            filename_values = (prefix,
                               args.use_max_component,
                               # seed_mask_type,
                               # mask_type,
                               args.step_size,
                               args.nb_seeds_per_voxel,
                               np.rad2deg(theta)
                               # not args.discard_stopped_by_curvature,
                               # args.filter_threshold,
                               # args.min_length,
                               # args.pft_nb_retry,
                               # args.pft_nb_backtrack_steps,
                               # args.track_like_peter
                               )
            filename = ('_'.join(filename_items) + ".tck").format(*filename_values)

        save_path = pjoin(experiment_path, filename)
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)  # Create dirs, if needed.

        print("Saving to {}".format(save_path))
        nib.streamlines.save(tractogram, save_path)

    if args.save_rejected:
        with Timer("Saving {:,} (compressed) rejected streamlines".format(len(rejected_tractogram))):
            if args.out is None:
                rejected_filename_items = filename_items.copy()
                rejected_filename_items.insert(1, "rejected")
                rejected_filename = ('_'.join(rejected_filename_items) + ".tck").format(*filename_values)
            else:
                # filename_items only exists for auto-generated names; derive the
                # rejected name from --out instead (previously a NameError).
                root, ext = os.path.splitext(args.out)
                rejected_filename = root + "_rejected" + ext

            rejected_save_path = pjoin(experiment_path, rejected_filename)
            rejected_save_dir = os.path.dirname(rejected_save_path)
            if rejected_save_dir:
                os.makedirs(rejected_save_dir, exist_ok=True)  # Create dirs, if needed.

            print("Saving rejected streamlines to {}".format(rejected_save_path))
            nib.streamlines.save(rejected_tractogram, rejected_save_path)
def main():
    """Track streamlines with a trained learn2track model and save them as .tck.

    Steps:

    1. Load the experiment hyperparameters, the DWI and its gradient table.
    2. Build the model input: SH coefficients or a resampled signal.
    3. Restore the model.
    4. Generate seeds from streamline extremities (.trk/.tck) or binary masks.
    5. Track in DWI voxel space, then convert the result to RAS+mm.
    6. Remove empty and too-short streamlines, and optionally those stopped
       by curvature or with a high model loss.
    7. Save the result under the experiment folder.
    """
    parser = build_argparser()
    args = parser.parse_args()

    # Get experiment folder
    experiment_path = args.name
    if not os.path.isdir(experiment_path):
        # If not a directory, it must be the name of the experiment.
        experiment_path = pjoin(".", "experiments", args.name)

    if not os.path.isdir(experiment_path):
        parser.error('Cannot find experiment: {0}!'.format(args.name))

    # Load experiments hyperparameters
    try:
        hyperparams = smartutils.load_dict_from_json_file(pjoin(experiment_path, "hyperparams.json"))
    except FileNotFoundError:
        hyperparams = smartutils.load_dict_from_json_file(pjoin(experiment_path, "..", "hyperparams.json"))

    with Timer("Loading DWIs"):
        # Load gradients table
        dwi_name = args.dwi
        if dwi_name.endswith(".gz"):
            dwi_name = dwi_name[:-3]
        if dwi_name.endswith(".nii"):
            dwi_name = dwi_name[:-4]
        bvals_filename = dwi_name + ".bvals"
        bvecs_filename = dwi_name + ".bvecs"
        bvals, bvecs = dipy.io.gradients.read_bvals_bvecs(bvals_filename, bvecs_filename)

        dwi = nib.load(args.dwi)
        if hyperparams["use_sh_coeffs"]:
            # Use 45 spherical harmonic coefficients to represent the diffusion signal.
            weights = neurotools.get_spherical_harmonics_coefficients(dwi, bvals, bvecs).astype(np.float32)
        else:
            # Resample the diffusion signal to have 100 directions.
            weights = neurotools.resample_dwi(dwi, bvals, bvecs).astype(np.float32)

    affine_rasmm2dwivox = np.linalg.inv(dwi.affine)

    with Timer("Loading model"):
        if hyperparams["model"] == "gru_regression":
            from learn2track.models import GRU_Regression
            model_class = GRU_Regression
        elif hyperparams['model'] == 'gru_mixture':
            from learn2track.models import GRU_Mixture
            model_class = GRU_Mixture
        elif hyperparams['model'] == 'gru_multistep':
            from learn2track.models import GRU_Multistep_Gaussian
            model_class = GRU_Multistep_Gaussian
        elif hyperparams['model'] == 'ffnn_regression':
            from learn2track.models import FFNN_Regression
            model_class = FFNN_Regression
        else:
            raise ValueError("Unknown model!")

        kwargs = {}
        volume_manager = neurotools.VolumeManager()
        volume_manager.register(weights)
        kwargs['volume_manager'] = volume_manager

        # Load the actual model.
        model = model_class.create(pjoin(experiment_path), **kwargs)  # Create new instance and restore model.
        print(str(model))

    mask = None
    # Initialised so the stopping criterion below does not raise a NameError
    # when no --mask is given.
    affine_maskvox2dwivox = None
    if args.mask is not None:
        with Timer("Loading mask"):
            mask_nii = nib.load(args.mask)
            # Same array get_data() returned, without the deprecated API.
            mask = np.asanyarray(mask_nii.dataobj)
            # Compute the affine allowing to evaluate the mask at some coordinates correctly.

            # affine_maskvox2dwivox = mask_vox => rasmm space => dwi_vox
            affine_maskvox2dwivox = np.dot(affine_rasmm2dwivox, mask_nii.affine)
            if args.dilate_mask:
                # A bare `import scipy` does not guarantee scipy.ndimage is loaded.
                import scipy.ndimage
                mask = scipy.ndimage.binary_dilation(mask).astype(mask.dtype)

    with Timer("Generating seeds"):
        seeds = []

        for filename in args.seeds:
            if filename.endswith(('.trk', '.tck')):
                tfile = nib.streamlines.load(filename)
                # Send the streamlines to voxel since that's where we'll track.
                tfile.tractogram.apply_affine(affine_rasmm2dwivox)

                # Use extremities of the streamlines as seeding points.
                seeds += [s[0] for s in tfile.streamlines]
                seeds += [s[-1] for s in tfile.streamlines]

            else:
                # Assume it is a binary mask.
                rng = np.random.RandomState(args.seeding_rng_seed)
                nii_seeds = nib.load(filename)

                # affine_seedsvox2dwivox = mask_vox => rasmm space => dwi_vox
                affine_seedsvox2dwivox = np.dot(affine_rasmm2dwivox, nii_seeds.affine)

                # Jitter nb_seeds_per_voxel points uniformly inside each seeding voxel.
                indices = np.array(np.where(np.asanyarray(nii_seeds.dataobj))).T
                for idx in indices:
                    seeds_in_voxel = idx + rng.uniform(-0.5, 0.5, size=(args.nb_seeds_per_voxel, 3))
                    seeds_in_voxel = nib.affines.apply_affine(affine_seedsvox2dwivox, seeds_in_voxel)
                    seeds.extend(seeds_in_voxel)

        seeds = np.array(seeds, dtype=theano.config.floatX)

    with Timer("Tracking in the diffusion voxel space"):
        voxel_sizes = np.asarray(dwi.header.get_zooms()[:3])
        if not np.all(voxel_sizes == dwi.header.get_zooms()[0]):
            print("* Careful voxel are anisotropic {}!".format(tuple(voxel_sizes)))

        # Since we are tracking in diffusion voxel space, convert step_size (in mm) to voxel.
        if args.step_size is not None:
            step_size = np.float32(args.step_size / voxel_sizes.max())
            # Max number of points = max length (mm) / step size (mm). Dividing by
            # the voxel-space step_size mixed units and mis-sized the limit.
            max_nb_points = int(np.ceil(args.max_length / args.step_size))
        else:
            step_size = None
            max_nb_points = args.max_length

        if args.theta is not None:
            theta = np.deg2rad(args.theta)
        elif args.curvature is not None and args.curvature > 0:
            theta = get_max_angle_from_curvature(args.curvature, step_size)
        else:
            theta = np.deg2rad(45)

        print("Angle: {}".format(np.rad2deg(theta)))
        print("Step size (vox): {}".format(step_size))
        print("Max nb. points: {}".format(max_nb_points))

        is_outside_mask = make_is_outside_mask(mask, affine_maskvox2dwivox, threshold=args.mask_threshold)
        is_too_long = make_is_too_long(max_nb_points)
        is_too_curvy = make_is_too_curvy(np.rad2deg(theta))
        is_stopping = make_is_stopping({STOPPING_MASK: is_outside_mask,
                                        STOPPING_LENGTH: is_too_long,
                                        STOPPING_CURVATURE: is_too_curvy})

        is_stopping.max_nb_points = max_nb_points  # Small hack

        tractogram = batch_track(model, weights, seeds,
                                 step_size=step_size,
                                 is_stopping=is_stopping,
                                 batch_size=args.batch_size,
                                 args=args)

        # Streamlines have been generated in voxel space.
        # Transform them back to RAS+mm space using the dwi's affine.
        tractogram.affine_to_rasmm = dwi.affine
        tractogram.to_world()  # Performed in-place.

    nb_streamlines = len(tractogram)
    print("Generated {:,} (compressed) streamlines".format(nb_streamlines))
    with Timer("Cleaning streamlines", newline=True):
        # Flush streamlines that have no points.
        tractogram = tractogram[np.array(list(map(len, tractogram))) > 0]
        print("Removed {:,} empty streamlines".format(nb_streamlines - len(tractogram)))

        # Remove small streamlines
        nb_streamlines = len(tractogram)
        lengths = dipy.tracking.streamline.length(tractogram.streamlines)
        tractogram = tractogram[lengths >= args.min_length]
        lengths = lengths[lengths >= args.min_length]
        if len(lengths) > 0:
            print("Average length: {:.2f} mm.".format(lengths.mean()))
            print("Minimum length: {:.2f} mm. Maximum length: {:.2f}".format(lengths.min(), lengths.max()))
        print("Removed {:,} streamlines smaller than {:.2f} mm".format(nb_streamlines - len(tractogram),
                                                                        args.min_length))

        if args.discard_stopped_by_curvature:
            nb_streamlines = len(tractogram)
            stopping_curvature_flag_is_set = is_flag_set(tractogram.data_per_streamline['stopping_flags'][:, 0],
                                                         STOPPING_CURVATURE)
            # Discard (not keep) the streamlines stopped by the curvature criterion.
            tractogram = tractogram[np.logical_not(stopping_curvature_flag_is_set)]
            print("Removed {:,} streamlines stopped for having a curvature higher than {:.2f} degree".format(
                nb_streamlines - len(tractogram), np.rad2deg(theta)))

        if args.filter_threshold is not None:
            # Remove streamlines that produce a reconstruction error higher than a certain threshold.
            nb_streamlines = len(tractogram)
            losses = compute_loss_errors(tractogram.streamlines, model, hyperparams)
            print("Mean loss: {:.4f} ± {:.4f}".format(np.mean(losses),
                                                      np.std(losses, ddof=1) / np.sqrt(len(losses))))
            tractogram = tractogram[losses <= args.filter_threshold]
            # The removed streamlines are those whose loss is ABOVE the threshold.
            print("Removed {:,} streamlines producing a loss higher than {:.2f}".format(
                nb_streamlines - len(tractogram), args.filter_threshold))

    with Timer("Saving {:,} (compressed) streamlines".format(len(tractogram))):
        filename = args.out
        if args.out is None:
            prefix = args.prefix
            if prefix is None:
                dwi_name = os.path.basename(args.dwi)
                if dwi_name.endswith(".nii.gz"):
                    dwi_name = dwi_name[:-7]
                else:  # .nii
                    dwi_name = dwi_name[:-4]

                prefix = os.path.basename(os.path.dirname(args.dwi)) + dwi_name
                prefix = prefix.replace(".", "_")

            mask_type = args.seeds[0].replace(".", "_").replace("_", "")
            if "int" in args.seeds[0]:
                mask_type = "int"
            elif "wm" in args.seeds[0]:
                mask_type = "wm"
            elif "rois" in args.seeds[0]:
                mask_type = "rois"

            filename = "{}-{}_seeding-{}_step-{:.2f}mm_nbSeeds-{}_maxAngle-{:.1f}deg_keepCurv-{}_filtered-{}_minLen-{}_pftRetry-{}_pftHist-{}_useMaxComponent-{}.tck".format(
                os.path.basename(args.name.rstrip('/'))[:6],
                prefix,
                mask_type,
                args.step_size,
                args.nb_seeds_per_voxel,
                np.rad2deg(theta),
                not args.discard_stopped_by_curvature,
                args.filter_threshold,
                args.min_length,
                args.pft_nb_retry,
                args.pft_nb_backtrack_steps,
                args.use_max_component
            )

        save_path = pjoin(experiment_path, filename)
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)  # Create dirs, if needed.

        print("Saving to {}".format(save_path))
        nib.streamlines.save(tractogram, save_path)
def test_gru_mixture_track():
    """Smoke-test ``batch_track`` with a freshly initialised GRU_Mixture model.

    A dummy 10x10x10 DWI with 30 gradients is resampled and registered, a
    ``gru_mixture`` model is built through ``factories.model_factory``, and
    seeds are jittered inside a random binary seeding mask. Tracking then runs
    with mask, length, curvature and likelihood stopping criteria. All
    randomness is driven by seed 1234 so runs are reproducible.
    """
    hidden_sizes = 50

    with Timer("Creating dummy volume", newline=True):
        volume_manager = neurotools.VolumeManager()
        dwi, gradients = make_dummy_dwi(nb_gradients=30, volume_shape=(10, 10, 10), seed=1234)
        volume = neurotools.resample_dwi(dwi, gradients.bvals, gradients.bvecs).astype(np.float32)
        volume_manager.register(volume)

    with Timer("Creating model"):
        hyperparams = {'model': 'gru_mixture',
                       'SGD': "1e-2",
                       'hidden_sizes': hidden_sizes,
                       'learn_to_stop': False,
                       'normalize': False,
                       'activation': 'tanh',
                       'feed_previous_direction': False,
                       'predict_offset': False,
                       'use_layer_normalization': False,
                       'drop_prob': 0.,
                       'use_zoneout': False,
                       'skip_connections': False,
                       'neighborhood_radius': None,
                       'nb_seeds_per_voxel': 2,
                       'step_size': 0.5,
                       'batch_size': 200,
                       'n_gaussians': 2,
                       'seed': 1234}
        model = factories.model_factory(hyperparams,
                                        input_size=volume_manager.data_dimension,
                                        output_size=3,
                                        volume_manager=volume_manager)
        model.initialize(factories.weigths_initializer_factory("orthogonal", seed=1234))

    rng = np.random.RandomState(1234)
    mask = np.ones(volume.shape[:3])
    # Use the seeded RNG (not the global np.random) so the test is deterministic.
    seeding_mask = rng.randint(2, size=mask.shape)
    seeds = []
    indices = np.array(np.where(seeding_mask)).T
    for idx in indices:
        seeds_in_voxel = idx + rng.uniform(-0.5, 0.5, size=(hyperparams['nb_seeds_per_voxel'], 3))
        seeds.extend(seeds_in_voxel)
    seeds = np.array(seeds, dtype=theano.config.floatX)

    is_outside_mask = make_is_outside_mask(mask, np.eye(4), threshold=0.5)
    is_too_long = make_is_too_long(150)
    # make_is_too_curvy takes degrees (main() passes np.rad2deg(theta)); the
    # previous np.rad2deg(30) meant ~1719 degrees, i.e. never triggered.
    is_too_curvy = make_is_too_curvy(30)
    is_unlikely = make_is_unlikely(0.5)
    is_stopping = make_is_stopping({STOPPING_MASK: is_outside_mask,
                                    STOPPING_LENGTH: is_too_long,
                                    STOPPING_CURVATURE: is_too_curvy,
                                    STOPPING_LIKELIHOOD: is_unlikely})
    is_stopping.max_nb_points = 150

    args = SimpleNamespace()
    args.track_like_peter = False
    args.pft_nb_retry = 0
    args.pft_nb_backtrack_steps = 0
    args.use_max_component = False
    args.flip_x = False
    args.flip_y = False
    args.flip_z = False
    args.verbose = True

    tractogram = batch_track(model, volume, seeds,
                             step_size=hyperparams['step_size'],
                             is_stopping=is_stopping,
                             batch_size=hyperparams['batch_size'],
                             args=args)

    return True
def load_tractography_dataset_from_dwi_and_tractogram(dwi, tractogram, volume_manager, use_sh_coeffs=False,
                                                      bvals=None, bvecs=None, step_size=None, mean_centering=True):
    """Build a single-subject ``TractographyDataset`` from a DWI and a tractogram file.

    Parameters
    ----------
    dwi : str
        Path to the diffusion image (.nii or .nii.gz).
    tractogram : str
        Path to a streamlines file readable by ``nib.streamlines.load``.
    volume_manager : object
        Receives the processed volume through ``register``.
    use_sh_coeffs : bool
        If True, represent the signal with spherical harmonic coefficients,
        otherwise resample it with ``neurotools.resample_dwi``.
    bvals, bvecs : str or None
        Gradient files. Default to ``<dwi without .nii/.nii.gz>.bvals/.bvecs``.
    step_size : float or None
        If given, resample every streamline so consecutive points are about
        ``step_size`` apart (each streamline keeps at least 2 points).
    mean_centering : bool
        Forwarded to the neurotools signal conversion function.

    Returns
    -------
    TractographyDataset
        Dataset named "dataset" holding one subject, with streamlines in DWI
        voxel space, built with ``keep_on_cpu=True``.
    """
    # Load signal
    signal = nib.load(dwi)
    signal.get_data()  # Forces loading volume in-memory.
    # Strip ".nii.gz", ".nii" or ".gz" to locate the gradient files next to
    # the DWI. The previous pattern missed a plain ".nii" and was not a raw string.
    basename = re.sub(r'(\.nii\.gz|\.nii|\.gz)$', '', dwi)
    bvals = basename + '.bvals' if bvals is None else bvals
    bvecs = basename + '.bvecs' if bvecs is None else bvecs

    gradients = gradient_table(bvals, bvecs)
    tracto_data = TractographyData(signal, gradients)

    # Load streamlines
    tfile = nib.streamlines.load(tractogram)
    tractogram = tfile.tractogram

    # Resample streamline to have a fixed step size, if needed.
    if step_size is not None:
        print("Resampling streamlines to have a step size of {}mm".format(step_size))
        streamlines = tractogram.streamlines
        streamlines._lengths = streamlines._lengths.astype(int)
        streamlines._offsets = streamlines._offsets.astype(int)
        lengths = length(streamlines)
        # set_number_of_points requires at least 2 points; streamlines shorter
        # than step_size would otherwise request 0 or 1 and raise.
        nb_points = np.maximum(np.ceil(lengths / step_size).astype(int), 2)
        new_streamlines = (set_number_of_points(s, n) for s, n in zip(streamlines, nb_points))

        tractogram = nib.streamlines.Tractogram(new_streamlines, affine_to_rasmm=np.eye(4))

    # Compute matrix that brings streamlines back to diffusion voxel space.
    rasmm2vox_affine = np.linalg.inv(signal.affine)
    tractogram.apply_affine(rasmm2vox_affine)

    # Add streamlines to the TractogramData
    tracto_data.add(tractogram.streamlines, "tractogram")

    dwi_img = tracto_data.signal
    data_bvals = tracto_data.gradients.bvals
    data_bvecs = tracto_data.gradients.bvecs

    if use_sh_coeffs:
        # Use 45 spherical harmonic coefficients to represent the diffusion signal.
        volume = neurotools.get_spherical_harmonics_coefficients(dwi_img, data_bvals, data_bvecs,
                                                                 mean_centering=mean_centering).astype(np.float32)
    else:
        # Resample the diffusion signal to have 100 directions.
        volume = neurotools.resample_dwi(dwi_img, data_bvals, data_bvecs,
                                         mean_centering=mean_centering).astype(np.float32)

    tracto_data.signal.uncache()  # Free some memory as we don't need the original signal.
    subject_id = volume_manager.register(volume)
    tracto_data.subject_id = subject_id

    return TractographyDataset([tracto_data], "dataset", keep_on_cpu=True)
def main():
    """Predict a binary mask over a whole DWI volume with a trained classifier.

    Steps:

    1. Load the experiment hyperparameters, the DWI and its gradient table.
    2. Build the model input (SH coefficients or a resampled signal).
    3. Restore an ``ffnn_classification`` model.
    4. Evaluate it on every voxel coordinate, lowering the batch size on
       MemoryError/RuntimeError.
    5. Threshold the predictions at 0.5 and save the result as a NIfTI image
       with the DWI's affine.
    """
    parser = build_argparser()
    args = parser.parse_args()

    # Get experiment folder
    experiment_path = args.name
    if not os.path.isdir(experiment_path):
        # If not a directory, it must be the name of the experiment.
        experiment_path = pjoin(".", "experiments", args.name)

    if not os.path.isdir(experiment_path):
        parser.error('Cannot find experiment: {0}!'.format(args.name))

    # Load experiments hyperparameters
    try:
        hyperparams = smartutils.load_dict_from_json_file(pjoin(experiment_path, "hyperparams.json"))
    except FileNotFoundError:
        hyperparams = smartutils.load_dict_from_json_file(pjoin(experiment_path, "..", "hyperparams.json"))

    with Timer("Loading DWIs"):
        # Load gradients table
        dwi_name = args.dwi
        if dwi_name.endswith(".gz"):
            dwi_name = dwi_name[:-3]
        if dwi_name.endswith(".nii"):
            dwi_name = dwi_name[:-4]
        bvals_filename = dwi_name + ".bvals"
        bvecs_filename = dwi_name + ".bvecs"
        bvals, bvecs = dipy.io.gradients.read_bvals_bvecs(bvals_filename, bvecs_filename)

        dwi = nib.load(args.dwi)
        if hyperparams["use_sh_coeffs"]:
            # Use 45 spherical harmonic coefficients to represent the diffusion signal.
            weights = neurotools.get_spherical_harmonics_coefficients(dwi, bvals, bvecs).astype(np.float32)
        else:
            # Resample the diffusion signal to have 100 directions.
            weights = neurotools.resample_dwi(dwi, bvals, bvecs).astype(np.float32)

    with Timer("Loading model"):
        if hyperparams["model"] == "ffnn_classification":
            from learn2track.models import FFNN_Classification
            model_class = FFNN_Classification
        else:
            raise ValueError("Unknown model!")

        kwargs = {}
        volume_manager = neurotools.VolumeManager()
        volume_manager.register(weights)
        kwargs['volume_manager'] = volume_manager

        # Load the actual model.
        model = model_class.create(pjoin(experiment_path), **kwargs)  # Create new instance and restore model.
        print(str(model))

    with Timer("Generating mask"):
        symb_input = T.matrix(name="input")
        model_symb_pred = model.get_output(symb_input)
        f = theano.function(inputs=[symb_input], outputs=[model_symb_pred])

        generated_mask = np.zeros(dwi.shape[:3], dtype=np.float32)

        # all_coords.shape = (n_coords, 3)
        all_coords = np.argwhere(generated_mask == 0)

        # Every coordinate belongs to the single registered volume (id 0).
        volume_ids = np.zeros((all_coords.shape[0], 1))
        all_coords_and_volume_ids = np.concatenate((all_coords, volume_ids), axis=1).astype(np.float32)

        batch_size = args.batch_size if args.batch_size else len(all_coords_and_volume_ids)
        probs = []
        # `>= 1` so that a batch size of 1 is also tried (it used to be skipped).
        while batch_size >= 1:
            print("Trying to process batches of size {} out of {}".format(batch_size,
                                                                          len(all_coords_and_volume_ids)))
            nb_batches = int(np.ceil(len(all_coords_and_volume_ids) / batch_size))
            try:
                for batch_count in range(nb_batches):
                    start = batch_count * batch_size
                    end = (batch_count + 1) * batch_size
                    probs.extend(f(all_coords_and_volume_ids[start:end])[-1])
                    print("Generated batch {} out of {}".format(batch_count + 1, nb_batches))
                break
            except (MemoryError, RuntimeError):
                print("{} coordinates at the same time is too much!".format(batch_size))
                batch_size //= 2
                # Discard predictions from the failed attempt: keeping them would
                # duplicate coordinates when the retry starts from the beginning.
                probs = []
        if not probs:
            raise RuntimeError("Could not generate predictions...")

        generated_mask[np.where(generated_mask == 0)] = np.array(probs) > 0.5

    with Timer("Saving generated mask"):
        filename = args.out
        if args.out is None:
            prefix = args.prefix
            if prefix is None:
                dwi_name = os.path.basename(args.dwi)
                if dwi_name.endswith(".nii.gz"):
                    dwi_name = dwi_name[:-7]
                else:  # .nii
                    dwi_name = dwi_name[:-4]

                prefix = os.path.basename(os.path.dirname(args.dwi)) + dwi_name
                prefix = prefix.replace(".", "_")

            filename = "{}.nii.gz".format(prefix)

        save_path = pjoin(experiment_path, filename)
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)  # Create dirs, if needed.

        print("Saving to {}".format(save_path))
        mask = nib.Nifti1Image(generated_mask, dwi.affine)
        nib.save(mask, save_path)