def load_mask_classifier_dataset(subject_files, volume_manager, name="HCP", use_sh_coeffs=False):
    """Load mask-classifier subjects and register their processed diffusion volumes.

    Subject files are processed in sorted order. For each one, the diffusion
    signal is converted either to spherical-harmonic coefficients
    (``use_sh_coeffs=True``) or to a resampled DWI. The result is cast to
    float32 and registered in ``volume_manager``; the returned id is stored on
    the subject as ``subject_id``.

    Parameters
    ----------
    subject_files : iterable of str
        Paths loadable by ``MaskClassifierData.load``.
    volume_manager : neurotools.VolumeManager
        Receives one registered volume per subject.
    name : str
        Name given to the resulting dataset.
    use_sh_coeffs : bool
        Select the spherical-harmonic representation instead of resampling.

    Returns
    -------
    MaskClassifierDataset
        Built with ``keep_on_cpu=True``.
    """
    loaded_subjects = []
    with Timer(" Loading subject(s)", newline=True):
        for path in sorted(subject_files):
            print(" {}".format(path))
            subject = MaskClassifierData.load(path)

            raw_signal = subject.signal
            grad_bvals = subject.gradients.bvals
            grad_bvecs = subject.gradients.bvecs

            if not use_sh_coeffs:
                # Resample the diffusion signal (100 directions per the project's convention).
                features = neurotools.resample_dwi(raw_signal, grad_bvals, grad_bvecs)
            else:
                # Spherical harmonic coefficients (45 per the project's convention).
                features = neurotools.get_spherical_harmonics_coefficients(
                    raw_signal, grad_bvals, grad_bvecs)
            features = features.astype(np.float32)

            # The raw signal is no longer needed once features are computed.
            subject.signal.uncache()

            subject.subject_id = volume_manager.register(features)
            loaded_subjects.append(subject)

    return MaskClassifierDataset(loaded_subjects, name, keep_on_cpu=True)
def load_tractography_dataset(subject_files, volume_manager, name="HCP", use_sh_coeffs=False):
    """Load tractography subjects and register their processed diffusion volumes.

    Files are visited in sorted order. For each subject, the diffusion signal
    is turned into spherical-harmonic coefficients (``use_sh_coeffs=True``) or
    a resampled DWI, cast to float32, and registered in ``volume_manager``.
    The returned id is stored as ``subject_id`` on the subject.

    Parameters
    ----------
    subject_files : iterable of str
        Paths loadable by ``TractographyData.load``.
    volume_manager : neurotools.VolumeManager
        Receives one registered volume per subject.
    name : str
        Name given to the resulting dataset.
    use_sh_coeffs : bool
        Select the spherical-harmonic representation instead of resampling.

    Returns
    -------
    TractographyDataset
        Built with ``keep_on_cpu=True``.
    """
    collected = []
    with Timer(" Loading subject(s)", newline=True):
        for subject_path in sorted(subject_files):
            print(" {}".format(subject_path))
            tracto = TractographyData.load(subject_path)

            signal = tracto.signal
            bvals, bvecs = tracto.gradients.bvals, tracto.gradients.bvecs

            # Pick the signal representation: SH coefficients (45 per the project's
            # convention) or a resampling onto a fixed set of directions (100).
            to_features = (neurotools.get_spherical_harmonics_coefficients
                           if use_sh_coeffs else neurotools.resample_dwi)
            volume = to_features(signal, bvals, bvecs).astype(np.float32)

            tracto.signal.uncache()  # Release the raw signal; only `volume` is kept.

            tracto.subject_id = volume_manager.register(volume)
            collected.append(tracto)

    return TractographyDataset(collected, name, keep_on_cpu=True)
def main():
    """Track streamlines with a trained learn2track model and save them as ``.tck``.

    Steps (all driven by the arguments from ``build_argparser()``):
      1. Locate the experiment folder and load ``hyperparams.json`` (falling
         back to the parent folder).
      2. Load the DWI and its gradient table (``.bvals/.bvecs`` or
         ``.bval/.bvec``) and convert it into the model's input volume.
      3. Restore the model selected by ``hyperparams["model"]``.
      4. Optionally load (and dilate) a tracking mask.
      5. Build seeds from streamline extremities (``.trk``/``.tck``) or by
         uniformly sampling ``nb_seeds_per_voxel`` points in each non-zero
         voxel of a seeding mask.
      6. Track in DWI voxel space, move the streamlines back to RAS+mm and
         filter them (empty, too short, stopped by curvature, high loss).
      7. Save the kept streamlines and, with ``--save_rejected``, the
         rejected ones inside the experiment folder.
    """
    parser = build_argparser()
    args = parser.parse_args()

    # Get experiment folder
    experiment_path = args.name
    if not os.path.isdir(experiment_path):
        # If not a directory, it must be the name of the experiment.
        experiment_path = pjoin(".", "experiments", args.name)

    if not os.path.isdir(experiment_path):
        parser.error('Cannot find experiment: {0}!'.format(args.name))

    # Load experiments hyperparameters
    try:
        hyperparams = smartutils.load_dict_from_json_file(
            pjoin(experiment_path, "hyperparams.json"))
    except FileNotFoundError:
        hyperparams = smartutils.load_dict_from_json_file(
            pjoin(experiment_path, "..", "hyperparams.json"))

    with Timer("Loading DWIs"):
        # Load gradients table
        dwi_name = args.dwi
        if dwi_name.endswith(".gz"):
            dwi_name = dwi_name[:-3]
        if dwi_name.endswith(".nii"):
            dwi_name = dwi_name[:-4]

        try:
            bvals_filename = dwi_name + ".bvals"
            bvecs_filename = dwi_name + ".bvecs"
            bvals, bvecs = dipy.io.gradients.read_bvals_bvecs(
                bvals_filename, bvecs_filename)
        except FileNotFoundError:
            try:
                bvals_filename = dwi_name + ".bval"
                bvecs_filename = dwi_name + ".bvec"
                bvals, bvecs = dipy.io.gradients.read_bvals_bvecs(
                    bvals_filename, bvecs_filename)
            except FileNotFoundError as e:
                print("Could not find .bvals/.bvecs or .bval/.bvec files...")
                raise e

        dwi = nib.load(args.dwi)
        if hyperparams["use_sh_coeffs"]:
            # Use 45 spherical harmonic coefficients to represent the diffusion signal.
            weights = neurotools.get_spherical_harmonics_coefficients(
                dwi, bvals, bvecs).astype(np.float32)
        else:
            # Resample the diffusion signal to have 100 directions.
            weights = neurotools.resample_dwi(dwi, bvals, bvecs).astype(np.float32)

    affine_rasmm2dwivox = np.linalg.inv(dwi.affine)

    with Timer("Loading model"):
        if hyperparams["model"] == "gru_regression":
            from learn2track.models import GRU_Regression
            model_class = GRU_Regression
        elif hyperparams['model'] == 'gru_gaussian':
            from learn2track.models import GRU_Gaussian
            model_class = GRU_Gaussian
        elif hyperparams['model'] == 'gru_mixture':
            from learn2track.models import GRU_Mixture
            model_class = GRU_Mixture
        elif hyperparams['model'] == 'gru_multistep':
            from learn2track.models import GRU_Multistep_Gaussian
            model_class = GRU_Multistep_Gaussian
        elif hyperparams['model'] == 'ffnn_regression':
            from learn2track.models import FFNN_Regression
            model_class = FFNN_Regression
        else:
            raise ValueError("Unknown model!")

        kwargs = {}
        volume_manager = neurotools.VolumeManager()
        volume_manager.register(weights)
        kwargs['volume_manager'] = volume_manager

        # Load the actual model.
        model = model_class.create(pjoin(experiment_path), **kwargs)  # Create new instance and restore model.
        model.drop_prob = 0.
        print(str(model))

    mask = None
    # Stays None when no mask is given; previously this name was undefined in
    # that case and make_is_outside_mask(...) below raised NameError.
    # NOTE(review): assumes make_is_outside_mask accepts a None mask/affine — confirm.
    affine_maskvox2dwivox = None
    if args.mask is not None:
        with Timer("Loading mask"):
            mask_nii = nib.load(args.mask)
            # Equivalent of the removed (nibabel >= 5) `get_data()`.
            mask = np.asanyarray(mask_nii.dataobj)
            # Compute the affine allowing to evaluate the mask at some coordinates correctly.

            # affine_maskvox2dwivox = mask_vox => rasmm space => dwi_vox
            affine_maskvox2dwivox = np.dot(affine_rasmm2dwivox, mask_nii.affine)
            if args.dilate_mask:
                # `import scipy` alone does not guarantee `scipy.ndimage` is loaded.
                from scipy import ndimage
                mask = ndimage.binary_dilation(mask).astype(mask.dtype)

    with Timer("Generating seeds"):
        seeds = []

        for filename in args.seeds:
            if filename.endswith('.trk') or filename.endswith('.tck'):
                tfile = nib.streamlines.load(filename)
                # Send the streamlines to voxel since that's where we'll track.
                tfile.tractogram.apply_affine(affine_rasmm2dwivox)

                # Use extremities of the streamlines as seeding points.
                seeds += [s[0] for s in tfile.streamlines]
                seeds += [s[-1] for s in tfile.streamlines]

            else:
                # Assume it is a binary mask.
                rng = np.random.RandomState(args.seeding_rng_seed)
                nii_seeds = nib.load(filename)

                # affine_seedsvox2dwivox = mask_vox => rasmm space => dwi_vox
                affine_seedsvox2dwivox = np.dot(affine_rasmm2dwivox, nii_seeds.affine)

                nii_seeds_data = np.asanyarray(nii_seeds.dataobj)

                if args.dilate_seeding_mask:
                    from scipy import ndimage
                    nii_seeds_data = ndimage.binary_dilation(
                        nii_seeds_data).astype(nii_seeds_data.dtype)

                indices = np.array(np.where(nii_seeds_data)).T
                for idx in indices:
                    # Uniformly jitter `nb_seeds_per_voxel` points inside the voxel.
                    seeds_in_voxel = idx + rng.uniform(-0.5, 0.5, size=(args.nb_seeds_per_voxel, 3))
                    seeds_in_voxel = nib.affines.apply_affine(affine_seedsvox2dwivox, seeds_in_voxel)
                    seeds.extend(seeds_in_voxel)

        seeds = np.array(seeds, dtype=theano.config.floatX)

    with Timer("Tracking in the diffusion voxel space"):
        voxel_sizes = np.asarray(dwi.header.get_zooms()[:3])
        if not np.all(voxel_sizes == dwi.header.get_zooms()[0]):
            print("* Careful voxel are anisotropic {}!".format(tuple(voxel_sizes)))
        # Since we are tracking in diffusion voxel space, convert step_size (in mm) to voxel.

        if args.step_size is not None:
            step_size = np.float32(args.step_size / voxel_sizes.max())
            # Number of points needed to cover max_length (both in mm).
            max_nb_points = int(np.ceil(args.max_length / args.step_size))
        else:
            step_size = None
            max_nb_points = args.max_length

        if args.theta is not None:
            theta = np.deg2rad(args.theta)
        elif args.curvature is not None and args.curvature > 0:
            theta = get_max_angle_from_curvature(args.curvature, step_size)
        else:
            theta = np.deg2rad(45)

        print("Angle: {}".format(np.rad2deg(theta)))
        print("Step size (vox): {}".format(step_size))
        print("Max nb. points: {}".format(max_nb_points))

        is_outside_mask = make_is_outside_mask(mask, affine_maskvox2dwivox, threshold=args.mask_threshold)
        is_too_long = make_is_too_long(max_nb_points)
        is_too_curvy = make_is_too_curvy(np.rad2deg(theta))
        is_unlikely = make_is_unlikely(0.5)
        is_stopping = make_is_stopping({
            STOPPING_MASK: is_outside_mask,
            STOPPING_LENGTH: is_too_long,
            STOPPING_CURVATURE: is_too_curvy,
            STOPPING_LIKELIHOOD: is_unlikely,
        })

        is_stopping.max_nb_points = max_nb_points  # Small hack

        tractogram = batch_track(model, weights, seeds,
                                 step_size=step_size,
                                 is_stopping=is_stopping,
                                 batch_size=args.batch_size,
                                 args=args)

        # Streamlines have been generated in voxel space.
        # Transform them them back to RAS+mm space using the dwi's affine.
        tractogram.affine_to_rasmm = dwi.affine
        tractogram.to_world()  # Performed in-place.

    nb_streamlines = len(tractogram)

    if args.save_rejected:
        rejected_tractogram = Tractogram()
        # Use the public property rather than the private `_affine_to_rasmm`.
        rejected_tractogram.affine_to_rasmm = tractogram.affine_to_rasmm

    print("Generated {:,} (compressed) streamlines".format(nb_streamlines))
    with Timer("Cleaning streamlines", newline=True):
        # Flush streamlines that have no points.
        is_nonempty = np.array(list(map(len, tractogram))) > 0
        if args.save_rejected:
            rejected_tractogram += tractogram[np.logical_not(is_nonempty)]
        tractogram = tractogram[is_nonempty]
        print("Removed {:,} empty streamlines".format(nb_streamlines - len(tractogram)))

        # Remove small streamlines
        nb_streamlines = len(tractogram)
        lengths = dipy.tracking.streamline.length(tractogram.streamlines)
        is_long_enough = lengths >= args.min_length
        if args.save_rejected:
            rejected_tractogram += tractogram[np.logical_not(is_long_enough)]
        tractogram = tractogram[is_long_enough]
        lengths = lengths[is_long_enough]
        if len(lengths) > 0:
            print("Average length: {:.2f} mm.".format(lengths.mean()))
            print("Minimum length: {:.2f} mm. Maximum length: {:.2f}".format(lengths.min(), lengths.max()))
        print("Removed {:,} streamlines smaller than {:.2f} mm".format(nb_streamlines - len(tractogram),
                                                                        args.min_length))
        if args.discard_stopped_by_curvature:
            nb_streamlines = len(tractogram)
            stopping_curvature_flag_is_set = is_flag_set(
                tractogram.data_per_streamline['stopping_flags'][:, 0], STOPPING_CURVATURE)
            if args.save_rejected:
                rejected_tractogram += tractogram[stopping_curvature_flag_is_set]
            tractogram = tractogram[np.logical_not(stopping_curvature_flag_is_set)]
            print("Removed {:,} streamlines stopped for having a curvature higher than {:.2f} degree".format(
                nb_streamlines - len(tractogram), np.rad2deg(theta)))

        if args.filter_threshold is not None:
            # Remove streamlines that produces a reconstruction error higher than a certain threshold.
            nb_streamlines = len(tractogram)
            losses = compute_loss_errors(tractogram.streamlines, model, hyperparams)
            print("Mean loss: {:.4f} ± {:.4f}".format(np.mean(losses),
                                                      np.std(losses, ddof=1) / np.sqrt(len(losses))))
            if args.save_rejected:
                rejected_tractogram += tractogram[losses > args.filter_threshold]
            tractogram = tractogram[losses <= args.filter_threshold]
            # The filter drops losses *above* the threshold (a unitless loss, not mm).
            print("Removed {:,} streamlines producing a loss higher than {:.2f}".format(
                nb_streamlines - len(tractogram), args.filter_threshold))

    with Timer("Saving {:,} (compressed) streamlines".format(len(tractogram))):
        filename = args.out
        if args.out is None:
            prefix = args.prefix
            if prefix is None:
                dwi_name = os.path.basename(args.dwi)
                if dwi_name.endswith(".nii.gz"):
                    dwi_name = dwi_name[:-7]
                else:  # .nii
                    dwi_name = dwi_name[:-4]

                prefix = os.path.basename(os.path.dirname(args.dwi)) + dwi_name
                prefix = prefix.replace(".", "_")

            filename_items = [
                "{}",
                "useMaxComponent-{}",
                "step-{:.2f}mm",
                "nbSeeds-{}",
                "maxAngleDeg-{:.1f}",
            ]
            filename_values = (prefix,
                               args.use_max_component,
                               args.step_size,
                               args.nb_seeds_per_voxel,
                               np.rad2deg(theta))
            filename = ('_'.join(filename_items) + ".tck").format(*filename_values)

            # Same name, tagged "rejected" right after the prefix.
            rejected_filename_items = filename_items.copy()
            rejected_filename_items.insert(1, "rejected")
            rejected_filename = ('_'.join(rejected_filename_items) + ".tck").format(*filename_values)
        else:
            # An explicit output name was given (previously this path raised NameError
            # with --save_rejected): suffix it with "_rejected" before the extension.
            out_root, out_ext = os.path.splitext(args.out)
            rejected_filename = out_root + "_rejected" + out_ext

        save_path = pjoin(experiment_path, filename)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)  # Create dirs, if needed.

        print("Saving to {}".format(save_path))
        nib.streamlines.save(tractogram, save_path)

    if args.save_rejected:
        with Timer("Saving {:,} (compressed) rejected streamlines".format(len(rejected_tractogram))):
            rejected_save_path = pjoin(experiment_path, rejected_filename)
            os.makedirs(os.path.dirname(rejected_save_path), exist_ok=True)  # Create dirs, if needed.

            print("Saving rejected streamlines to {}".format(rejected_save_path))
            nib.streamlines.save(rejected_tractogram, rejected_save_path)
def main():
    """Track streamlines with a trained learn2track model and save them as ``.tck``.

    Loads hyperparameters and the model of the experiment given on the command
    line, prepares the DWI input volume, generates seeds (streamline
    extremities or jittered points inside a seeding mask), tracks in DWI voxel
    space, filters the result (empty, too short, stopped by curvature, high
    reconstruction loss) and saves it inside the experiment folder.
    """
    parser = build_argparser()
    args = parser.parse_args()

    # Get experiment folder
    experiment_path = args.name
    if not os.path.isdir(experiment_path):
        # If not a directory, it must be the name of the experiment.
        experiment_path = pjoin(".", "experiments", args.name)

    if not os.path.isdir(experiment_path):
        parser.error('Cannot find experiment: {0}!'.format(args.name))

    # Load experiments hyperparameters
    try:
        hyperparams = smartutils.load_dict_from_json_file(pjoin(experiment_path, "hyperparams.json"))
    except FileNotFoundError:
        hyperparams = smartutils.load_dict_from_json_file(pjoin(experiment_path, "..", "hyperparams.json"))

    with Timer("Loading DWIs"):
        # Load gradients table
        dwi_name = args.dwi
        if dwi_name.endswith(".gz"):
            dwi_name = dwi_name[:-3]
        if dwi_name.endswith(".nii"):
            dwi_name = dwi_name[:-4]

        # Accept both the ".bvals/.bvecs" and the ".bval/.bvec" naming conventions.
        try:
            bvals_filename = dwi_name + ".bvals"
            bvecs_filename = dwi_name + ".bvecs"
            bvals, bvecs = dipy.io.gradients.read_bvals_bvecs(bvals_filename, bvecs_filename)
        except FileNotFoundError:
            try:
                bvals_filename = dwi_name + ".bval"
                bvecs_filename = dwi_name + ".bvec"
                bvals, bvecs = dipy.io.gradients.read_bvals_bvecs(bvals_filename, bvecs_filename)
            except FileNotFoundError as e:
                print("Could not find .bvals/.bvecs or .bval/.bvec files...")
                raise e

        dwi = nib.load(args.dwi)
        if hyperparams["use_sh_coeffs"]:
            # Use 45 spherical harmonic coefficients to represent the diffusion signal.
            weights = neurotools.get_spherical_harmonics_coefficients(dwi, bvals, bvecs).astype(np.float32)
        else:
            # Resample the diffusion signal to have 100 directions.
            weights = neurotools.resample_dwi(dwi, bvals, bvecs).astype(np.float32)

    affine_rasmm2dwivox = np.linalg.inv(dwi.affine)

    with Timer("Loading model"):
        if hyperparams["model"] == "gru_regression":
            from learn2track.models import GRU_Regression
            model_class = GRU_Regression
        elif hyperparams['model'] == 'gru_gaussian':
            from learn2track.models import GRU_Gaussian
            model_class = GRU_Gaussian
        elif hyperparams['model'] == 'gru_mixture':
            from learn2track.models import GRU_Mixture
            model_class = GRU_Mixture
        elif hyperparams['model'] == 'gru_multistep':
            from learn2track.models import GRU_Multistep_Gaussian
            model_class = GRU_Multistep_Gaussian
        elif hyperparams['model'] == 'ffnn_regression':
            from learn2track.models import FFNN_Regression
            model_class = FFNN_Regression
        else:
            raise ValueError("Unknown model!")

        kwargs = {}
        volume_manager = neurotools.VolumeManager()
        volume_manager.register(weights)
        kwargs['volume_manager'] = volume_manager

        # Load the actual model.
        model = model_class.create(pjoin(experiment_path), **kwargs)  # Create new instance and restore model.
        print(str(model))

    mask = None
    # Stays None when no mask is given (previously undefined -> NameError below).
    # NOTE(review): assumes make_is_outside_mask accepts a None mask/affine — confirm.
    affine_maskvox2dwivox = None
    if args.mask is not None:
        with Timer("Loading mask"):
            mask_nii = nib.load(args.mask)
            # Equivalent of the removed (nibabel >= 5) `get_data()`.
            mask = np.asanyarray(mask_nii.dataobj)
            # Compute the affine allowing to evaluate the mask at some coordinates correctly.

            # affine_maskvox2dwivox = mask_vox => rasmm space => dwi_vox
            affine_maskvox2dwivox = np.dot(affine_rasmm2dwivox, mask_nii.affine)
            if args.dilate_mask:
                # `import scipy` alone does not guarantee `scipy.ndimage` is loaded.
                from scipy import ndimage
                mask = ndimage.binary_dilation(mask).astype(mask.dtype)

    with Timer("Generating seeds"):
        seeds = []

        for filename in args.seeds:
            if filename.endswith('.trk') or filename.endswith('.tck'):
                tfile = nib.streamlines.load(filename)
                # Send the streamlines to voxel since that's where we'll track.
                tfile.tractogram.apply_affine(affine_rasmm2dwivox)

                # Use extremities of the streamlines as seeding points.
                seeds += [s[0] for s in tfile.streamlines]
                seeds += [s[-1] for s in tfile.streamlines]

            else:
                # Assume it is a binary mask.
                rng = np.random.RandomState(args.seeding_rng_seed)
                nii_seeds = nib.load(filename)

                # affine_seedsvox2dwivox = mask_vox => rasmm space => dwi_vox
                affine_seedsvox2dwivox = np.dot(affine_rasmm2dwivox, nii_seeds.affine)

                indices = np.array(np.where(np.asanyarray(nii_seeds.dataobj))).T
                for idx in indices:
                    # Uniformly jitter `nb_seeds_per_voxel` points inside the voxel.
                    seeds_in_voxel = idx + rng.uniform(-0.5, 0.5, size=(args.nb_seeds_per_voxel, 3))
                    seeds_in_voxel = nib.affines.apply_affine(affine_seedsvox2dwivox, seeds_in_voxel)
                    seeds.extend(seeds_in_voxel)

        seeds = np.array(seeds, dtype=theano.config.floatX)

    with Timer("Tracking in the diffusion voxel space"):
        voxel_sizes = np.asarray(dwi.header.get_zooms()[:3])
        if not np.all(voxel_sizes == dwi.header.get_zooms()[0]):
            print("* Careful voxel are anisotropic {}!".format(tuple(voxel_sizes)))
        # Since we are tracking in diffusion voxel space, convert step_size (in mm) to voxel.

        if args.step_size is not None:
            step_size = np.float32(args.step_size / voxel_sizes.max())
            # Number of points needed to cover max_length. Both quantities must be in
            # mm; dividing by the voxel-unit `step_size` mixed units.
            max_nb_points = int(np.ceil(args.max_length / args.step_size))
        else:
            step_size = None
            max_nb_points = args.max_length

        if args.theta is not None:
            theta = np.deg2rad(args.theta)
        elif args.curvature is not None and args.curvature > 0:
            theta = get_max_angle_from_curvature(args.curvature, step_size)
        else:
            theta = np.deg2rad(45)

        print("Angle: {}".format(np.rad2deg(theta)))
        print("Step size (vox): {}".format(step_size))
        print("Max nb. points: {}".format(max_nb_points))

        is_outside_mask = make_is_outside_mask(mask, affine_maskvox2dwivox, threshold=args.mask_threshold)
        is_too_long = make_is_too_long(max_nb_points)
        is_too_curvy = make_is_too_curvy(np.rad2deg(theta))
        is_stopping = make_is_stopping({STOPPING_MASK: is_outside_mask,
                                        STOPPING_LENGTH: is_too_long,
                                        STOPPING_CURVATURE: is_too_curvy})

        is_stopping.max_nb_points = max_nb_points  # Small hack

        tractogram = batch_track(model, weights, seeds,
                                 step_size=step_size,
                                 is_stopping=is_stopping,
                                 batch_size=args.batch_size,
                                 args=args)

        # Streamlines have been generated in voxel space.
        # Transform them them back to RAS+mm space using the dwi's affine.
        tractogram.affine_to_rasmm = dwi.affine
        tractogram.to_world()  # Performed in-place.

    nb_streamlines = len(tractogram)
    print("Generated {:,} (compressed) streamlines".format(nb_streamlines))
    with Timer("Cleaning streamlines", newline=True):
        # Flush streamlines that have no points.
        tractogram = tractogram[np.array(list(map(len, tractogram))) > 0]
        print("Removed {:,} empty streamlines".format(nb_streamlines - len(tractogram)))

        # Remove small streamlines
        nb_streamlines = len(tractogram)
        lengths = dipy.tracking.streamline.length(tractogram.streamlines)
        tractogram = tractogram[lengths >= args.min_length]
        lengths = lengths[lengths >= args.min_length]
        if len(lengths) > 0:
            print("Average length: {:.2f} mm.".format(lengths.mean()))
            print("Minimum length: {:.2f} mm. Maximum length: {:.2f}".format(lengths.min(), lengths.max()))
        print("Removed {:,} streamlines smaller than {:.2f} mm".format(nb_streamlines - len(tractogram),
                                                                        args.min_length))
        if args.discard_stopped_by_curvature:
            nb_streamlines = len(tractogram)
            stopping_curvature_flag_is_set = is_flag_set(tractogram.data_per_streamline['stopping_flags'][:, 0],
                                                         STOPPING_CURVATURE)
            # Keep the streamlines NOT stopped by curvature (the selection was inverted).
            tractogram = tractogram[np.logical_not(stopping_curvature_flag_is_set)]
            print("Removed {:,} streamlines stopped for having a curvature higher than {:.2f} degree".format(
                nb_streamlines - len(tractogram), np.rad2deg(theta)))

        if args.filter_threshold is not None:
            # Remove streamlines that produces a reconstruction error higher than a certain threshold.
            nb_streamlines = len(tractogram)
            losses = compute_loss_errors(tractogram.streamlines, model, hyperparams)
            print("Mean loss: {:.4f} ± {:.4f}".format(np.mean(losses),
                                                      np.std(losses, ddof=1) / np.sqrt(len(losses))))
            tractogram = tractogram[losses <= args.filter_threshold]
            # The filter drops losses *above* the threshold (a unitless loss, not mm).
            print("Removed {:,} streamlines producing a loss higher than {:.2f}".format(
                nb_streamlines - len(tractogram), args.filter_threshold))

    with Timer("Saving {:,} (compressed) streamlines".format(len(tractogram))):
        filename = args.out
        if args.out is None:
            prefix = args.prefix
            if prefix is None:
                dwi_name = os.path.basename(args.dwi)
                if dwi_name.endswith(".nii.gz"):
                    dwi_name = dwi_name[:-7]
                else:  # .nii
                    dwi_name = dwi_name[:-4]

                prefix = os.path.basename(os.path.dirname(args.dwi)) + dwi_name
                prefix = prefix.replace(".", "_")

            mask_type = args.seeds[0].replace(".", "_").replace("_", "")
            if "int" in args.seeds[0]:
                mask_type = "int"
            elif "wm" in args.seeds[0]:
                mask_type = "wm"
            elif "rois" in args.seeds[0]:
                mask_type = "rois"

            filename = "{}-{}_seeding-{}_step-{:.2f}mm_nbSeeds-{}_maxAngle-{:.1f}deg_keepCurv-{}_filtered-{}_minLen-{}_pftRetry-{}_pftHist-{}_useMaxComponent-{}.tck".format(
                os.path.basename(args.name.rstrip('/'))[:6],
                prefix,
                mask_type,
                args.step_size,
                args.nb_seeds_per_voxel,
                np.rad2deg(theta),
                not args.discard_stopped_by_curvature,
                args.filter_threshold,
                args.min_length,
                args.pft_nb_retry,
                args.pft_nb_backtrack_steps,
                args.use_max_component
            )

        save_path = pjoin(experiment_path, filename)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)  # Create dirs, if needed.

        print("Saving to {}".format(save_path))
        nib.streamlines.save(tractogram, save_path)
def load_tractography_dataset_from_dwi_and_tractogram(dwi, tractogram, volume_manager, use_sh_coeffs=False,
                                                      bvals=None, bvecs=None, step_size=None, mean_centering=True):
    """Build a single-subject TractographyDataset from a DWI file and a tractogram file.

    Parameters
    ----------
    dwi : str
        Path to the diffusion image (``.nii`` or ``.nii.gz``).
    tractogram : str
        Path to a streamlines file readable by ``nib.streamlines.load``.
    volume_manager : neurotools.VolumeManager
        Receives the processed diffusion volume; the returned id is stored as
        the subject's ``subject_id``.
    use_sh_coeffs : bool
        Represent the signal with spherical-harmonic coefficients instead of
        resampling it onto a fixed set of directions.
    bvals, bvecs : str or None
        Gradient files. They default to ``<dwi without .nii/.gz>.bvals`` and
        ``.bvecs``.
    step_size : float or None
        If given, resample every streamline so that consecutive points are at
        most ``step_size`` apart (in the tractogram's length units).
    mean_centering : bool
        Forwarded to the neurotools signal-processing function.

    Returns
    -------
    TractographyDataset
        Contains the single subject, built with ``keep_on_cpu=True``.
    """
    # Load signal
    signal = nib.load(dwi)
    # Forces loading volume in-memory.
    # NOTE(review): `get_data()` was removed in nibabel 5.0; update if nibabel is upgraded.
    signal.get_data()

    # Strip ".nii", ".gz" or ".nii.gz" to find sibling gradient files. The former
    # pattern '(\.gz|\.nii.gz)$' never stripped a plain ".nii" (giving
    # "dwi.nii.bvals"), and it was not a raw string.
    basename = re.sub(r'(\.nii)?(\.gz)?$', '', dwi)
    bvals = basename + '.bvals' if bvals is None else bvals
    bvecs = basename + '.bvecs' if bvecs is None else bvecs

    gradients = gradient_table(bvals, bvecs)
    tracto_data = TractographyData(signal, gradients)

    # Load streamlines
    tfile = nib.streamlines.load(tractogram)
    tractogram = tfile.tractogram

    # Resample streamline to have a fixed step size, if needed.
    if step_size is not None:
        print("Resampling streamlines to have a step size of {}mm".format(step_size))
        streamlines = tractogram.streamlines
        streamlines._lengths = streamlines._lengths.astype(int)
        streamlines._offsets = streamlines._offsets.astype(int)
        lengths = length(streamlines)
        # n points span n - 1 segments, so ceil(L / step) + 1 points keep every
        # segment <= step_size. set_number_of_points also requires at least 2
        # points, which ceil(L / step) alone violated for streamlines shorter
        # than one step.
        nb_points = np.maximum(np.ceil(lengths / step_size).astype(int) + 1, 2)
        new_streamlines = (set_number_of_points(s, n) for s, n in zip(streamlines, nb_points))
        tractogram = nib.streamlines.Tractogram(new_streamlines, affine_to_rasmm=np.eye(4))

    # Compute matrix that brings streamlines back to diffusion voxel space.
    rasmm2vox_affine = np.linalg.inv(signal.affine)
    tractogram.apply_affine(rasmm2vox_affine)

    # Add streamlines to the TractogramData
    tracto_data.add(tractogram.streamlines, "tractogram")

    dwi = tracto_data.signal
    bvals = tracto_data.gradients.bvals
    bvecs = tracto_data.gradients.bvecs

    if use_sh_coeffs:
        # Use 45 spherical harmonic coefficients to represent the diffusion signal.
        volume = neurotools.get_spherical_harmonics_coefficients(
            dwi, bvals, bvecs, mean_centering=mean_centering).astype(np.float32)
    else:
        # Resample the diffusion signal to have 100 directions.
        volume = neurotools.resample_dwi(dwi, bvals, bvecs, mean_centering=mean_centering).astype(np.float32)

    tracto_data.signal.uncache()  # Free some memory as we don't need the original signal.
    subject_id = volume_manager.register(volume)
    tracto_data.subject_id = subject_id

    return TractographyDataset([tracto_data], "dataset", keep_on_cpu=True)
def main():
    """Predict a binary mask over a whole DWI volume with a trained classifier.

    Loads the experiment's hyperparameters and its ``ffnn_classification``
    model, then evaluates the model on every voxel coordinate of the DWI. The
    coordinates are processed in batches, and the batch size is halved
    whenever a batch raises MemoryError or RuntimeError. The output is
    thresholded at 0.5 and saved as a NIfTI image (with the DWI's affine)
    inside the experiment folder.
    """
    parser = build_argparser()
    args = parser.parse_args()

    # Get experiment folder
    experiment_path = args.name
    if not os.path.isdir(experiment_path):
        # If not a directory, it must be the name of the experiment.
        experiment_path = pjoin(".", "experiments", args.name)

    if not os.path.isdir(experiment_path):
        parser.error('Cannot find experiment: {0}!'.format(args.name))

    # Load experiments hyperparameters
    try:
        hyperparams = smartutils.load_dict_from_json_file(
            pjoin(experiment_path, "hyperparams.json"))
    except FileNotFoundError:
        hyperparams = smartutils.load_dict_from_json_file(
            pjoin(experiment_path, "..", "hyperparams.json"))

    with Timer("Loading DWIs"):
        # Load gradients table
        dwi_name = args.dwi
        if dwi_name.endswith(".gz"):
            dwi_name = dwi_name[:-3]
        if dwi_name.endswith(".nii"):
            dwi_name = dwi_name[:-4]

        # Accept both the ".bvals/.bvecs" and the ".bval/.bvec" naming conventions.
        try:
            bvals_filename = dwi_name + ".bvals"
            bvecs_filename = dwi_name + ".bvecs"
            bvals, bvecs = dipy.io.gradients.read_bvals_bvecs(
                bvals_filename, bvecs_filename)
        except FileNotFoundError:
            try:
                bvals_filename = dwi_name + ".bval"
                bvecs_filename = dwi_name + ".bvec"
                bvals, bvecs = dipy.io.gradients.read_bvals_bvecs(
                    bvals_filename, bvecs_filename)
            except FileNotFoundError as e:
                print("Could not find .bvals/.bvecs or .bval/.bvec files...")
                raise e

        dwi = nib.load(args.dwi)
        if hyperparams["use_sh_coeffs"]:
            # Use 45 spherical harmonic coefficients to represent the diffusion signal.
            weights = neurotools.get_spherical_harmonics_coefficients(
                dwi, bvals, bvecs).astype(np.float32)
        else:
            # Resample the diffusion signal to have 100 directions.
            weights = neurotools.resample_dwi(dwi, bvals, bvecs).astype(np.float32)

    with Timer("Loading model"):
        if hyperparams["model"] == "ffnn_classification":
            from learn2track.models import FFNN_Classification
            model_class = FFNN_Classification
        else:
            raise ValueError("Unknown model!")

        kwargs = {}
        volume_manager = neurotools.VolumeManager()
        volume_manager.register(weights)
        kwargs['volume_manager'] = volume_manager

        # Load the actual model.
        model = model_class.create(pjoin(experiment_path), **kwargs)  # Create new instance and restore model.
        print(str(model))

    with Timer("Generating mask"):
        symb_input = T.matrix(name="input")
        model_symb_pred = model.get_output(symb_input)
        f = theano.function(inputs=[symb_input], outputs=[model_symb_pred])

        generated_mask = np.zeros(dwi.shape[:3]).astype(np.float32)

        # all_coords.shape = (n_coords, 3)
        all_coords = np.argwhere(generated_mask == 0)

        # Append a volume id column (0: the only registered volume).
        volume_ids = np.zeros((all_coords.shape[0], 1))
        all_coords_and_volume_ids = np.concatenate((all_coords, volume_ids), axis=1).astype(np.float32)

        batch_size = args.batch_size if args.batch_size else len(all_coords_and_volume_ids)
        probs = None  # Set only once a full pass over all coordinates succeeds.
        # `>= 1` so that a batch size of 1 is still attempted (the old `> 1` skipped it).
        while batch_size >= 1:
            print("Trying to to process batches of size {} out of {}".format(
                batch_size, len(all_coords_and_volume_ids)))
            nb_batches = int(np.ceil(len(all_coords_and_volume_ids) / batch_size))
            # Fresh buffer per attempt: a failed attempt previously left partial
            # predictions in `probs`, which were then duplicated on retry.
            attempt_probs = []
            try:
                for batch_count in range(nb_batches):
                    start = batch_count * batch_size
                    end = (batch_count + 1) * batch_size
                    attempt_probs.extend(f(all_coords_and_volume_ids[start:end])[-1])
                    print("Generated batch {} out of {}".format(batch_count + 1, nb_batches))
            except (MemoryError, RuntimeError):
                print("{} coordinates at the same time is too much!".format(batch_size))
                batch_size //= 2
            else:
                probs = attempt_probs
                break

        if probs is None:
            raise RuntimeError("Could not generate predictions...")

        generated_mask[np.where(generated_mask == 0)] = np.array(probs) > 0.5

    with Timer("Saving generated mask"):
        filename = args.out
        if args.out is None:
            prefix = args.prefix
            if prefix is None:
                dwi_name = os.path.basename(args.dwi)
                if dwi_name.endswith(".nii.gz"):
                    dwi_name = dwi_name[:-7]
                else:  # .nii
                    dwi_name = dwi_name[:-4]

                prefix = os.path.basename(os.path.dirname(args.dwi)) + dwi_name
                prefix = prefix.replace(".", "_")

            filename = "{}.nii.gz".format(prefix)

        save_path = pjoin(experiment_path, filename)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)  # Create dirs, if needed.

        print("Saving to {}".format(save_path))
        mask = nib.Nifti1Image(generated_mask, dwi.affine)
        nib.save(mask, save_path)