Пример #1
0
def test_gru_mixture_fprop():
    """Smoke-test the forward, loss and optimizer graphs of ``gru_mixture``.

    A dummy dataset and a 2-Gaussian GRU mixture model are built. Theano
    functions are then compiled for the model output, the loss and the
    optimizer directions, and each one is evaluated on a single batch. The
    test passes when none of these steps raises.
    """
    hidden_sizes = 50
    hyperparams = {
        'model': 'gru_mixture',
        'n_gaussians': 2,
        'SGD': "1e-2",
        'hidden_sizes': hidden_sizes,
        'learn_to_stop': False,
        'normalize': False,
        'feed_previous_direction': False
    }

    with Timer("Creating dataset", newline=True):
        volume_manager = neurotools.VolumeManager()
        trainset = make_dummy_dataset(volume_manager)
        print("Dataset sizes:", len(trainset))

        scheduler = batch_schedulers.TractographyBatchScheduler(
            trainset, batch_size=16, noisy_streamlines_sigma=None, seed=1234)
        nb_updates = scheduler.nb_updates_per_epoch
        print("An epoch will be composed of {} updates.".format(nb_updates))
        print(volume_manager.data_dimension, hidden_sizes,
              scheduler.target_size)

    with Timer("Creating model"):
        model = factories.model_factory(
            hyperparams,
            input_size=volume_manager.data_dimension,
            output_size=scheduler.target_size,
            volume_manager=volume_manager)
        initializer = factories.weigths_initializer_factory("orthogonal",
                                                            seed=1234)
        model.initialize(initializer)

    # Fprop on a batch that may be missing streamlines from one subject.
    # The output graph is built before `graph_updates` is read.
    symb_output = model.get_output(trainset.symb_inputs)
    forward = theano.function([trainset.symb_inputs],
                              symb_output,
                              updates=model.graph_updates)

    inputs, targets, mask = batch_scheduler_batch = scheduler._next_batch(2)
    forward(inputs)

    with Timer("Building optimizer"):
        loss = factories.loss_factory(hyperparams, model, trainset)
        optimizer = factories.optimizer_factory(hyperparams, loss)

    compute_loss = theano.function(
        [trainset.symb_inputs, trainset.symb_targets, trainset.symb_mask],
        loss.loss,
        updates=model.graph_updates)
    print("Loss:", compute_loss(inputs, targets, mask))

    compute_directions = theano.function(
        [trainset.symb_inputs, trainset.symb_targets, trainset.symb_mask],
        list(optimizer.directions.values()),
        updates=model.graph_updates)
    compute_directions(inputs, targets, mask)
Пример #2
0
def test_gru_multistep_fprop_k3():
    """Smoke-test the ``gru_multistep`` model with ``k=3`` and ``m=3``.

    The batch scheduler is created by ``factories.batch_scheduler_factory``
    in training mode, from the same hyperparameters as the model. The output,
    loss and optimizer-direction graphs are compiled and run on one batch.
    The test passes when nothing raises.
    """
    hidden_sizes = 50
    hyperparams = dict(
        model='gru_multistep',
        k=3,
        m=3,
        batch_size=16,
        SGD="1e-2",
        hidden_sizes=hidden_sizes,
        learn_to_stop=False,
        normalize=False,
        noisy_streamlines_sigma=None,
        shuffle_streamlines=True,
        seed=1234,
    )

    with Timer("Creating dataset", newline=True):
        volume_manager = neurotools.VolumeManager()
        trainset = make_dummy_dataset(volume_manager)
        print("Dataset sizes:", len(trainset))

        scheduler = factories.batch_scheduler_factory(hyperparams,
                                                      trainset,
                                                      train_mode=True)
        nb_updates = scheduler.nb_updates_per_epoch
        print("An epoch will be composed of {} updates.".format(nb_updates))
        print(volume_manager.data_dimension, hidden_sizes,
              scheduler.target_size)

    with Timer("Creating model"):
        model = factories.model_factory(
            hyperparams,
            input_size=volume_manager.data_dimension,
            output_size=scheduler.target_size,
            volume_manager=volume_manager)
        initializer = factories.weigths_initializer_factory("orthogonal",
                                                            seed=1234)
        model.initialize(initializer)

    # Fprop: the output graph is built before `graph_updates` is read.
    symb_output = model.get_output(trainset.symb_inputs)
    forward = theano.function([trainset.symb_inputs],
                              symb_output,
                              updates=model.graph_updates)

    inputs, targets, mask = scheduler._next_batch(2)
    forward(inputs)

    with Timer("Building optimizer"):
        loss = factories.loss_factory(hyperparams, model, trainset)
        optimizer = factories.optimizer_factory(hyperparams, loss)

    compute_loss = theano.function(
        [trainset.symb_inputs, trainset.symb_targets, trainset.symb_mask],
        loss.loss,
        updates=model.graph_updates)
    print("Loss:", compute_loss(inputs, targets, mask))

    compute_directions = theano.function(
        [trainset.symb_inputs, trainset.symb_targets, trainset.symb_mask],
        list(optimizer.directions.values()),
        updates=model.graph_updates)
    compute_directions(inputs, targets, mask)
Пример #3
0
def main():
    """Track streamlines with a trained learn2track model and save them.

    The script proceeds in these steps:

    1. Locate the experiment folder and its ``hyperparams.json``, falling
       back on the parent folder.
    2. Load the DWI and its gradient table (``.bvals/.bvecs`` or
       ``.bval/.bvec``).
    3. Restore the model.
    4. Build seeds, either from the extremities of the streamlines in
       ``.trk``/``.tck`` files or by sampling inside binary seeding masks.
    5. Run ``batch_track`` in DWI voxel space.
    6. Filter the streamlines: drop empty ones and ones shorter than
       ``--min_length``; optionally drop ones stopped by curvature or with
       a loss above ``--filter_threshold``.
    7. Save the result, and optionally the rejected streamlines.
    """
    parser = build_argparser()
    args = parser.parse_args()

    # Get experiment folder
    experiment_path = args.name
    if not os.path.isdir(experiment_path):
        # If not a directory, it must be the name of the experiment.
        experiment_path = pjoin(".", "experiments", args.name)

    if not os.path.isdir(experiment_path):
        parser.error('Cannot find experiment: {0}!'.format(args.name))

    # Load experiments hyperparameters
    try:
        hyperparams = smartutils.load_dict_from_json_file(
            pjoin(experiment_path, "hyperparams.json"))
    except FileNotFoundError:
        hyperparams = smartutils.load_dict_from_json_file(
            pjoin(experiment_path, "..", "hyperparams.json"))

    with Timer("Loading DWIs"):
        # Load gradients table. It is expected next to the DWI, named with
        # either the .bvals/.bvecs or the .bval/.bvec extensions.
        dwi_name = args.dwi
        if dwi_name.endswith(".gz"):
            dwi_name = dwi_name[:-3]
        if dwi_name.endswith(".nii"):
            dwi_name = dwi_name[:-4]

        try:
            bvals_filename = dwi_name + ".bvals"
            bvecs_filename = dwi_name + ".bvecs"
            bvals, bvecs = dipy.io.gradients.read_bvals_bvecs(
                bvals_filename, bvecs_filename)
        except FileNotFoundError:
            try:
                bvals_filename = dwi_name + ".bval"
                bvecs_filename = dwi_name + ".bvec"
                bvals, bvecs = dipy.io.gradients.read_bvals_bvecs(
                    bvals_filename, bvecs_filename)
            except FileNotFoundError:
                print("Could not find .bvals/.bvecs or .bval/.bvec files...")
                raise

        dwi = nib.load(args.dwi)
        if hyperparams["use_sh_coeffs"]:
            # Use 45 spherical harmonic coefficients to represent the diffusion signal.
            weights = neurotools.get_spherical_harmonics_coefficients(
                dwi, bvals, bvecs).astype(np.float32)
        else:
            # Resample the diffusion signal to have 100 directions.
            weights = neurotools.resample_dwi(dwi, bvals,
                                              bvecs).astype(np.float32)

        affine_rasmm2dwivox = np.linalg.inv(dwi.affine)

    with Timer("Loading model"):
        if hyperparams["model"] == "gru_regression":
            from learn2track.models import GRU_Regression
            model_class = GRU_Regression
        elif hyperparams['model'] == 'gru_gaussian':
            from learn2track.models import GRU_Gaussian
            model_class = GRU_Gaussian
        elif hyperparams['model'] == 'gru_mixture':
            from learn2track.models import GRU_Mixture
            model_class = GRU_Mixture
        elif hyperparams['model'] == 'gru_multistep':
            from learn2track.models import GRU_Multistep_Gaussian
            model_class = GRU_Multistep_Gaussian
        elif hyperparams['model'] == 'ffnn_regression':
            from learn2track.models import FFNN_Regression
            model_class = FFNN_Regression
        else:
            raise ValueError("Unknown model!")

        kwargs = {}
        volume_manager = neurotools.VolumeManager()
        volume_manager.register(weights)
        kwargs['volume_manager'] = volume_manager

        # Load the actual model.
        model = model_class.create(
            pjoin(experiment_path),
            **kwargs)  # Create new instance and restore model.
        model.drop_prob = 0.
        print(str(model))

    mask = None
    # Bound even without a mask: it is passed to make_is_outside_mask below.
    # NOTE(review): assumes make_is_outside_mask accepts mask=None / affine=None.
    affine_maskvox2dwivox = None
    if args.mask is not None:
        with Timer("Loading mask"):
            mask_nii = nib.load(args.mask)
            mask = mask_nii.get_data()
            # Compute the affine allowing to evaluate the mask at some coordinates correctly.

            # affine_maskvox2dwivox = mask_vox => rasmm space => dwi_vox
            affine_maskvox2dwivox = np.dot(affine_rasmm2dwivox,
                                           mask_nii.affine)
            if args.dilate_mask:
                # `import scipy` alone does not guarantee `scipy.ndimage` is loaded.
                import scipy.ndimage
                mask = scipy.ndimage.binary_dilation(mask).astype(mask.dtype)

    with Timer("Generating seeds"):
        seeds = []

        for filename in args.seeds:
            if filename.endswith(('.trk', '.tck')):
                tfile = nib.streamlines.load(filename)
                # Send the streamlines to voxel since that's where we'll track.
                tfile.tractogram.apply_affine(affine_rasmm2dwivox)

                # Use extremities of the streamlines as seeding points.
                seeds += [s[0] for s in tfile.streamlines]
                seeds += [s[-1] for s in tfile.streamlines]

            else:
                # Assume it is a binary mask.
                rng = np.random.RandomState(args.seeding_rng_seed)
                nii_seeds = nib.load(filename)

                # affine_seedsvox2dwivox = mask_vox => rasmm space => dwi_vox
                affine_seedsvox2dwivox = np.dot(affine_rasmm2dwivox,
                                                nii_seeds.affine)

                nii_seeds_data = nii_seeds.get_data()

                if args.dilate_seeding_mask:
                    import scipy.ndimage
                    nii_seeds_data = scipy.ndimage.binary_dilation(
                        nii_seeds_data).astype(nii_seeds_data.dtype)

                indices = np.array(np.where(nii_seeds_data)).T
                for idx in indices:
                    # Jitter the seeds uniformly inside the voxel.
                    seeds_in_voxel = idx + rng.uniform(
                        -0.5, 0.5, size=(args.nb_seeds_per_voxel, 3))
                    seeds_in_voxel = nib.affines.apply_affine(
                        affine_seedsvox2dwivox, seeds_in_voxel)
                    seeds.extend(seeds_in_voxel)

        seeds = np.array(seeds, dtype=theano.config.floatX)

    with Timer("Tracking in the diffusion voxel space"):
        voxel_sizes = np.asarray(dwi.header.get_zooms()[:3])
        if not np.all(voxel_sizes == dwi.header.get_zooms()[0]):
            print("* Careful voxel are anisotropic {}!".format(
                tuple(voxel_sizes)))
        # Since we are tracking in diffusion voxel space, convert step_size (in mm) to voxel.

        if args.step_size is not None:
            step_size = np.float32(args.step_size / voxel_sizes.max())
            # Also convert max length (in mm) to voxel.
            max_nb_points = int(np.ceil(args.max_length / args.step_size))
        else:
            step_size = None
            max_nb_points = args.max_length

        if args.theta is not None:
            theta = np.deg2rad(args.theta)
        elif args.curvature is not None and args.curvature > 0:
            theta = get_max_angle_from_curvature(args.curvature, step_size)
        else:
            theta = np.deg2rad(45)

        print("Angle: {}".format(np.rad2deg(theta)))
        print("Step size (vox): {}".format(step_size))
        print("Max nb. points: {}".format(max_nb_points))

        is_outside_mask = make_is_outside_mask(mask,
                                               affine_maskvox2dwivox,
                                               threshold=args.mask_threshold)
        is_too_long = make_is_too_long(max_nb_points)
        is_too_curvy = make_is_too_curvy(np.rad2deg(theta))
        is_unlikely = make_is_unlikely(0.5)
        is_stopping = make_is_stopping({
            STOPPING_MASK: is_outside_mask,
            STOPPING_LENGTH: is_too_long,
            STOPPING_CURVATURE: is_too_curvy,
            STOPPING_LIKELIHOOD: is_unlikely
        })

        is_stopping.max_nb_points = max_nb_points  # Small hack

        tractogram = batch_track(model,
                                 weights,
                                 seeds,
                                 step_size=step_size,
                                 is_stopping=is_stopping,
                                 batch_size=args.batch_size,
                                 args=args)

        # Streamlines have been generated in voxel space.
        # Transform them back to RAS+mm space using the dwi's affine.
        tractogram.affine_to_rasmm = dwi.affine
        tractogram.to_world()  # Performed in-place.

    nb_streamlines = len(tractogram)

    if args.save_rejected:
        rejected_tractogram = Tractogram()
        rejected_tractogram.affine_to_rasmm = tractogram._affine_to_rasmm

    print("Generated {:,} (compressed) streamlines".format(nb_streamlines))
    with Timer("Cleaning streamlines", newline=True):
        # Flush streamlines that have no points.
        if args.save_rejected:
            rejected_tractogram += tractogram[
                np.array(list(map(len, tractogram))) <= 0]

        tractogram = tractogram[np.array(list(map(len, tractogram))) > 0]
        print("Removed {:,} empty streamlines".format(nb_streamlines -
                                                      len(tractogram)))

        # Remove small streamlines
        nb_streamlines = len(tractogram)
        lengths = dipy.tracking.streamline.length(tractogram.streamlines)

        if args.save_rejected:
            rejected_tractogram += tractogram[lengths < args.min_length]

        tractogram = tractogram[lengths >= args.min_length]
        lengths = lengths[lengths >= args.min_length]
        if len(lengths) > 0:
            print("Average length: {:.2f} mm.".format(lengths.mean()))
            print("Minimum length: {:.2f} mm. Maximum length: {:.2f}".format(
                lengths.min(), lengths.max()))
        print("Removed {:,} streamlines smaller than {:.2f} mm".format(
            nb_streamlines - len(tractogram), args.min_length))
        if args.discard_stopped_by_curvature:
            nb_streamlines = len(tractogram)
            stopping_curvature_flag_is_set = is_flag_set(
                tractogram.data_per_streamline['stopping_flags'][:, 0],
                STOPPING_CURVATURE)

            if args.save_rejected:
                rejected_tractogram += tractogram[
                    stopping_curvature_flag_is_set]

            tractogram = tractogram[np.logical_not(
                stopping_curvature_flag_is_set)]
            print(
                "Removed {:,} streamlines stopped for having a curvature higher than {:.2f} degree"
                .format(nb_streamlines - len(tractogram), np.rad2deg(theta)))

        if args.filter_threshold is not None:
            # Remove streamlines that produces a reconstruction error higher than a certain threshold.
            nb_streamlines = len(tractogram)
            losses = compute_loss_errors(tractogram.streamlines, model,
                                         hyperparams)
            # Mean loss ± its standard error.
            print("Mean loss: {:.4f} ± {:.4f}".format(
                np.mean(losses),
                np.std(losses, ddof=1) / np.sqrt(len(losses))))

            if args.save_rejected:
                rejected_tractogram += tractogram[
                    losses > args.filter_threshold]

            tractogram = tractogram[losses <= args.filter_threshold]
            # Streamlines whose loss is *above* the threshold were removed.
            print(
                "Removed {:,} streamlines producing a loss higher than {:.2f}"
                .format(nb_streamlines - len(tractogram),
                        args.filter_threshold))

    # Set only when the output name is generated (i.e. no --out given).
    filename_items = None
    filename_args = None
    with Timer("Saving {:,} (compressed) streamlines".format(len(tractogram))):
        filename = args.out
        if args.out is None:
            prefix = args.prefix
            if prefix is None:
                dwi_name = os.path.basename(args.dwi)
                if dwi_name.endswith(".nii.gz"):
                    dwi_name = dwi_name[:-7]
                else:  # .nii
                    dwi_name = dwi_name[:-4]

                prefix = os.path.basename(os.path.dirname(args.dwi)) + dwi_name
                prefix = prefix.replace(".", "_")

            # Seed/mask tags are only used by the commented-out filename items.
            seed_mask_type = args.seeds[0].replace(".", "_").replace(
                "_", "").replace("/", "-")
            if "int" in args.seeds[0]:
                seed_mask_type = "int"
            elif "wm" in args.seeds[0]:
                seed_mask_type = "wm"
            elif "rois" in args.seeds[0]:
                seed_mask_type = "rois"
            elif "bundles" in args.seeds[0]:
                seed_mask_type = "bundles"

            mask_type = ""
            if args.mask is not None:  # --mask is optional.
                if "fa" in args.mask:
                    mask_type = "fa"
                elif "wm" in args.mask:
                    mask_type = "wm"

            if args.dilate_seeding_mask:
                seed_mask_type += "D"

            if args.dilate_mask:
                mask_type += "D"

            filename_items = [
                "{}",
                "useMaxComponent-{}",
                # "seed-{}",
                # "mask-{}",
                "step-{:.2f}mm",
                "nbSeeds-{}",
                "maxAngleDeg-{:.1f}"
                # "keepCurv-{}",
                # "filtered-{}",
                # "minLen-{}",
                # "pftRetry-{}",
                # "pftHist-{}",
                # "trackLikePeter-{}",
            ]
            filename_args = (
                prefix,
                args.use_max_component,
                # seed_mask_type,
                # mask_type,
                args.step_size,
                args.nb_seeds_per_voxel,
                np.rad2deg(theta)
                # not args.discard_stopped_by_curvature,
                # args.filter_threshold,
                # args.min_length,
                # args.pft_nb_retry,
                # args.pft_nb_backtrack_steps,
                # args.track_like_peter
            )
            filename = ('_'.join(filename_items) + ".tck").format(
                *filename_args)

        save_path = pjoin(experiment_path, filename)
        save_dir = os.path.dirname(save_path)
        if save_dir:  # Create dirs, if needed.
            os.makedirs(save_dir, exist_ok=True)

        print("Saving to {}".format(save_path))
        nib.streamlines.save(tractogram, save_path)

    if args.save_rejected:
        with Timer("Saving {:,} (compressed) rejected streamlines".format(
                len(rejected_tractogram))):
            if filename_items is None:
                # An explicit --out was given: derive the rejected name from it.
                root, ext = os.path.splitext(filename)
                rejected_filename = root + "_rejected" + ext
            else:
                rejected_filename_items = filename_items.copy()
                rejected_filename_items.insert(1, "rejected")
                rejected_filename = (
                    '_'.join(rejected_filename_items) + ".tck"
                ).format(*filename_args)

            rejected_save_path = pjoin(experiment_path, rejected_filename)
            rejected_save_dir = os.path.dirname(rejected_save_path)
            if rejected_save_dir:  # Create dirs, if needed.
                os.makedirs(rejected_save_dir, exist_ok=True)

            print("Saving rejected streamlines to {}".format(
                rejected_save_path))
            nib.streamlines.save(rejected_tractogram, rejected_save_path)
Пример #4
0
def test_gru_mixture_track():
    """Smoke-test tracking with an untrained ``gru_mixture`` model.

    The test proceeds as follows:

    - A dummy 10x10x10 DWI with 30 gradients is resampled and registered in
      a ``VolumeManager``.
    - A 2-Gaussian GRU mixture model is built from ``hyperparams``.
    - Seeds are drawn inside a random seeding mask.
    - ``batch_track`` runs with mask, length, curvature and likelihood
      stopping criteria.

    All randomness goes through ``RandomState(1234)`` (or ``seed=1234``
    arguments), so runs are reproducible.

    Returns:
        True if tracking completed without raising.
    """
    hidden_sizes = 50
    max_nb_points = 150
    # make_is_too_curvy takes an angle in degrees. The original passed
    # np.rad2deg(30) (about 1719 degrees), which never triggers.
    max_angle_deg = 30

    with Timer("Creating dummy volume", newline=True):
        volume_manager = neurotools.VolumeManager()
        dwi, gradients = make_dummy_dwi(nb_gradients=30,
                                        volume_shape=(10, 10, 10),
                                        seed=1234)
        volume = neurotools.resample_dwi(dwi, gradients.bvals,
                                         gradients.bvecs).astype(np.float32)

        volume_manager.register(volume)

    with Timer("Creating model"):
        hyperparams = {
            'model': 'gru_mixture',
            'SGD': "1e-2",
            'hidden_sizes': hidden_sizes,
            'learn_to_stop': False,
            'normalize': False,
            'activation': 'tanh',
            'feed_previous_direction': False,
            'predict_offset': False,
            'use_layer_normalization': False,
            'drop_prob': 0.,
            'use_zoneout': False,
            'skip_connections': False,
            'neighborhood_radius': None,
            'nb_seeds_per_voxel': 2,
            'step_size': 0.5,
            'batch_size': 200,
            'n_gaussians': 2,
            'seed': 1234
        }
        model = factories.model_factory(
            hyperparams,
            input_size=volume_manager.data_dimension,
            output_size=3,
            volume_manager=volume_manager)
        model.initialize(
            factories.weigths_initializer_factory("orthogonal", seed=1234))

    rng = np.random.RandomState(1234)
    mask = np.ones(volume.shape[:3])
    # Use the seeded RNG (not the global np.random) so the seeds are reproducible.
    seeding_mask = rng.randint(2, size=mask.shape)
    seeds = []
    indices = np.array(np.where(seeding_mask)).T
    for idx in indices:
        # Jitter the seeds uniformly inside each selected voxel.
        seeds_in_voxel = idx + rng.uniform(
            -0.5, 0.5, size=(hyperparams['nb_seeds_per_voxel'], 3))
        seeds.extend(seeds_in_voxel)
    seeds = np.array(seeds, dtype=theano.config.floatX)

    # Tracking happens directly in volume voxel space, hence the identity affine.
    is_outside_mask = make_is_outside_mask(mask, np.eye(4), threshold=0.5)
    is_too_long = make_is_too_long(max_nb_points)
    is_too_curvy = make_is_too_curvy(max_angle_deg)
    is_unlikely = make_is_unlikely(0.5)
    is_stopping = make_is_stopping({
        STOPPING_MASK: is_outside_mask,
        STOPPING_LENGTH: is_too_long,
        STOPPING_CURVATURE: is_too_curvy,
        STOPPING_LIKELIHOOD: is_unlikely
    })
    is_stopping.max_nb_points = max_nb_points

    # Minimal stand-in for the tracking script's CLI arguments.
    args = SimpleNamespace()
    args.track_like_peter = False
    args.pft_nb_retry = 0
    args.pft_nb_backtrack_steps = 0
    args.use_max_component = False
    args.flip_x = False
    args.flip_y = False
    args.flip_z = False
    args.verbose = True

    tractogram = batch_track(model,
                             volume,
                             seeds,
                             step_size=hyperparams['step_size'],
                             is_stopping=is_stopping,
                             batch_size=hyperparams['batch_size'],
                             args=args)

    return True
Пример #5
0
def test_gru_mixture_fprop_neighborhood():
    """Smoke-test a ``gru_mixture`` model fed with a 0.5-radius neighborhood.

    The dataset, scheduler and model are built from ``hyperparams``. The
    model's input size is reported. The output, loss and optimizer-direction
    graphs are compiled and evaluated on one batch.

    Returns:
        True when every step ran without raising.
    """
    hyperparams = dict(
        model='gru_mixture',
        SGD="1e-2",
        hidden_sizes=50,
        batch_size=16,
        learn_to_stop=False,
        normalize=True,
        activation='tanh',
        feed_previous_direction=False,
        predict_offset=False,
        use_layer_normalization=False,
        drop_prob=0.,
        use_zoneout=False,
        skip_connections=False,
        seed=1234,
        noisy_streamlines_sigma=None,
        keep_step_size=True,
        sort_streamlines=False,
        n_gaussians=2,
        neighborhood_radius=0.5,
    )
    hidden = hyperparams['hidden_sizes']

    with Timer("Creating dataset", newline=True):
        volume_manager = neurotools.VolumeManager()
        trainset = make_dummy_dataset(volume_manager)
        print("Dataset sizes:", len(trainset))

        scheduler = factories.batch_scheduler_factory(hyperparams,
                                                      dataset=trainset)
        nb_updates = scheduler.nb_updates_per_epoch
        print("An epoch will be composed of {} updates.".format(nb_updates))
        print(volume_manager.data_dimension, hidden, scheduler.target_size)

    with Timer("Creating model"):
        model = factories.model_factory(
            hyperparams,
            input_size=volume_manager.data_dimension,
            output_size=scheduler.target_size,
            volume_manager=volume_manager)
        initializer = factories.weigths_initializer_factory("orthogonal",
                                                            seed=1234)
        model.initialize(initializer)

        print("Input size: {}".format(model.model_input_size))

    # Fprop on a batch that may be missing streamlines from one subject.
    # The output graph is built before `graph_updates` is read.
    symb_output = model.get_output(trainset.symb_inputs)
    forward = theano.function([trainset.symb_inputs],
                              symb_output,
                              updates=model.graph_updates)

    inputs, targets, mask = scheduler._next_batch(2)
    forward(inputs)

    with Timer("Building optimizer"):
        loss = factories.loss_factory(hyperparams, model, trainset)
        optimizer = factories.optimizer_factory(hyperparams, loss)

    compute_loss = theano.function(
        [trainset.symb_inputs, trainset.symb_targets, trainset.symb_mask],
        loss.loss,
        updates=model.graph_updates)
    print("Loss:", compute_loss(inputs, targets, mask))

    compute_directions = theano.function(
        [trainset.symb_inputs, trainset.symb_targets, trainset.symb_mask],
        list(optimizer.directions.values()),
        updates=model.graph_updates)
    compute_directions(inputs, targets, mask)

    return True
Пример #6
0
def main():
    """Generate a binary mask from a trained ``ffnn_classification`` model.

    The script proceeds in these steps:

    1. Load the experiment's hyperparameters.
    2. Load the DWI and its gradient table (``.bvals/.bvecs`` or
       ``.bval/.bvec``).
    3. Restore the model.
    4. Evaluate the model at every voxel coordinate, thresholding the
       output at 0.5.
    5. Save the resulting mask as a NIfTI image with the DWI's affine.

    If evaluating a batch raises ``MemoryError`` or ``RuntimeError``, the
    batch size is halved and all voxels are processed again from scratch.

    Raises:
        ValueError: if the experiment's model is not supported.
        RuntimeError: if no batch size (down to 1) succeeded.
    """
    parser = build_argparser()
    args = parser.parse_args()

    # Get experiment folder
    experiment_path = args.name
    if not os.path.isdir(experiment_path):
        # If not a directory, it must be the name of the experiment.
        experiment_path = pjoin(".", "experiments", args.name)

    if not os.path.isdir(experiment_path):
        parser.error('Cannot find experiment: {0}!'.format(args.name))

    # Load experiments hyperparameters
    try:
        hyperparams = smartutils.load_dict_from_json_file(
            pjoin(experiment_path, "hyperparams.json"))
    except FileNotFoundError:
        hyperparams = smartutils.load_dict_from_json_file(
            pjoin(experiment_path, "..", "hyperparams.json"))

    with Timer("Loading DWIs"):
        # Load gradients table: .bvals/.bvecs first, then .bval/.bvec.
        dwi_name = args.dwi
        if dwi_name.endswith(".gz"):
            dwi_name = dwi_name[:-3]
        if dwi_name.endswith(".nii"):
            dwi_name = dwi_name[:-4]
        try:
            bvals, bvecs = dipy.io.gradients.read_bvals_bvecs(
                dwi_name + ".bvals", dwi_name + ".bvecs")
        except FileNotFoundError:
            try:
                bvals, bvecs = dipy.io.gradients.read_bvals_bvecs(
                    dwi_name + ".bval", dwi_name + ".bvec")
            except FileNotFoundError:
                print("Could not find .bvals/.bvecs or .bval/.bvec files...")
                raise

        dwi = nib.load(args.dwi)
        if hyperparams["use_sh_coeffs"]:
            # Use 45 spherical harmonic coefficients to represent the diffusion signal.
            weights = neurotools.get_spherical_harmonics_coefficients(
                dwi, bvals, bvecs).astype(np.float32)
        else:
            # Resample the diffusion signal to have 100 directions.
            weights = neurotools.resample_dwi(dwi, bvals,
                                              bvecs).astype(np.float32)

    with Timer("Loading model"):
        if hyperparams["model"] == "ffnn_classification":
            from learn2track.models import FFNN_Classification
            model_class = FFNN_Classification
        else:
            raise ValueError("Unknown model!")

        kwargs = {}
        volume_manager = neurotools.VolumeManager()
        volume_manager.register(weights)
        kwargs['volume_manager'] = volume_manager

        # Load the actual model.
        model = model_class.create(
            pjoin(experiment_path),
            **kwargs)  # Create new instance and restore model.
        print(str(model))

    with Timer("Generating mask"):
        symb_input = T.matrix(name="input")
        model_symb_pred = model.get_output(symb_input)
        f = theano.function(inputs=[symb_input], outputs=[model_symb_pred])

        generated_mask = np.zeros(dwi.shape[:3], dtype=np.float32)

        # all_coords.shape = (n_coords, 3): every voxel, since the mask is all zeros.
        all_coords = np.argwhere(generated_mask == 0)

        # Each row is (x, y, z, volume_id); every coordinate refers to volume 0.
        volume_ids = np.zeros((all_coords.shape[0], 1))
        all_coords_and_volume_ids = np.concatenate((all_coords, volume_ids),
                                                   axis=1).astype(np.float32)
        nb_coords = len(all_coords_and_volume_ids)

        batch_size = args.batch_size if args.batch_size else nb_coords
        probs = None
        # Halve the batch size on failure, down to (and including) 1.
        while batch_size >= 1:
            print("Trying to to process batches of size {} out of {}".format(
                batch_size, nb_coords))
            nb_batches = int(np.ceil(nb_coords / batch_size))
            # Fresh buffer per attempt: a failed attempt must not leave
            # partial predictions behind.
            attempt_probs = []
            try:
                for batch_count in range(nb_batches):
                    start = batch_count * batch_size
                    end = (batch_count + 1) * batch_size
                    attempt_probs.extend(
                        f(all_coords_and_volume_ids[start:end])[-1])
                    print("Generated batch {} out of {}".format(
                        batch_count + 1, nb_batches))
            except (MemoryError, RuntimeError):
                print("{} coordinates at the same time is too much!".format(
                    batch_size))
                batch_size //= 2
            else:
                probs = attempt_probs
                break
        if not probs:
            raise RuntimeError("Could not generate predictions...")

        # np.argwhere and np.where both enumerate voxels in C order, so the
        # predictions line up with the coordinates they were computed for.
        generated_mask[np.where(generated_mask == 0)] = np.array(probs) > 0.5

    with Timer("Saving generated mask"):
        filename = args.out
        if args.out is None:
            prefix = args.prefix
            if prefix is None:
                dwi_name = os.path.basename(args.dwi)
                if dwi_name.endswith(".nii.gz"):
                    dwi_name = dwi_name[:-7]
                else:  # .nii
                    dwi_name = dwi_name[:-4]

                prefix = os.path.basename(os.path.dirname(args.dwi)) + dwi_name
                prefix = prefix.replace(".", "_")

            filename = "{}.nii.gz".format(prefix)

        save_path = pjoin(experiment_path, filename)
        save_dir = os.path.dirname(save_path)
        if save_dir:  # Create dirs, if needed.
            os.makedirs(save_dir, exist_ok=True)

        print("Saving to {}".format(save_path))
        mask = nib.Nifti1Image(generated_mask, dwi.affine)
        nib.save(mask, save_path)