Example No. 1
def test_gru_mixture_fprop():
    hidden_sizes = 50

    with Timer("Creating dataset", newline=True):
        volume_manager = neurotools.VolumeManager()
        trainset = make_dummy_dataset(volume_manager)
        print("Dataset sizes:", len(trainset))

        batch_scheduler = batch_schedulers.TractographyBatchScheduler(
            trainset, batch_size=16, noisy_streamlines_sigma=None, seed=1234)
        print("An epoch will be composed of {} updates.".format(
            batch_scheduler.nb_updates_per_epoch))
        print(volume_manager.data_dimension, hidden_sizes,
              batch_scheduler.target_size)

    with Timer("Creating model"):
        hyperparams = {
            'model': 'gru_mixture',
            'n_gaussians': 2,
            'SGD': "1e-2",
            'hidden_sizes': hidden_sizes,
            'learn_to_stop': False,
            'normalize': False,
            'feed_previous_direction': False
        }
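        # 'n_gaussians' sets the number of components in the Gaussian mixture the
        # GRU predicts at each step; the 'SGD': "1e-2" entry presumably selects
        # plain SGD with a 1e-2 learning rate in factories.optimizer_factory below.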
        model = factories.model_factory(
            hyperparams,
            input_size=volume_manager.data_dimension,
            output_size=batch_scheduler.target_size,
            volume_manager=volume_manager)
        model.initialize(
            factories.weigths_initializer_factory("orthogonal", seed=1234))

    # Test fprop with missing streamlines from one subject in a batch
    output = model.get_output(trainset.symb_inputs)
    fct = theano.function([trainset.symb_inputs],
                          output,
                          updates=model.graph_updates)

    batch_inputs, batch_targets, batch_mask = batch_scheduler._next_batch(2)
    out = fct(batch_inputs)

    with Timer("Building optimizer"):
        loss = factories.loss_factory(hyperparams, model, trainset)
        optimizer = factories.optimizer_factory(hyperparams, loss)

    fct_loss = theano.function(
        [trainset.symb_inputs, trainset.symb_targets, trainset.symb_mask],
        loss.loss,
        updates=model.graph_updates)

    loss_value = fct_loss(batch_inputs, batch_targets, batch_mask)
    print("Loss:", loss_value)

    fct_optim = theano.function(
        [trainset.symb_inputs, trainset.symb_targets, trainset.symb_mask],
        list(optimizer.directions.values()),
        updates=model.graph_updates)

    dirs = fct_optim(batch_inputs, batch_targets, batch_mask)
Example No. 2
def test_gru_regression_fprop():
    hidden_sizes = 50

    with Timer("Creating dataset", newline=True):
        volume_manager = neurotools.VolumeManager()
        trainset = make_dummy_dataset(volume_manager)
        print("Dataset sizes:", len(trainset))

        batch_scheduler = batch_schedulers.TractographyBatchScheduler(trainset,
                                                                      batch_size=16,
                                                                      noisy_streamlines_sigma=None,
                                                                      seed=1234)
        print("An epoch will be composed of {} updates.".format(batch_scheduler.nb_updates_per_epoch))
        print(volume_manager.data_dimension, hidden_sizes, batch_scheduler.target_size)

    with Timer("Creating model"):
        hyperparams = {'model': 'gru_regression',
                       'SGD': "1e-2",
                       'hidden_sizes': hidden_sizes,
                       'learn_to_stop': False,
                       'normalize': False}
        model = factories.model_factory(hyperparams,
                                        input_size=volume_manager.data_dimension,
                                        output_size=batch_scheduler.target_size,
                                        volume_manager=volume_manager)
        model.initialize(factories.weigths_initializer_factory("orthogonal", seed=1234))


    # Test fprop with missing streamlines from one subject in a batch
    output = model.get_output(trainset.symb_inputs)
    fct = theano.function([trainset.symb_inputs], output, updates=model.graph_updates)

    batch_inputs, batch_targets, batch_mask = batch_scheduler._next_batch(2)
    out = fct(batch_inputs)

    with Timer("Building optimizer"):
        loss = factories.loss_factory(hyperparams, model, trainset)
        optimizer = factories.optimizer_factory(hyperparams, loss)


    fct_loss = theano.function([trainset.symb_inputs, trainset.symb_targets, trainset.symb_mask],
                                loss.loss,
                                updates=model.graph_updates)

    loss_value = fct_loss(batch_inputs, batch_targets, batch_mask)
    print("Loss:", loss_value)


    fct_optim = theano.function([trainset.symb_inputs, trainset.symb_targets, trainset.symb_mask],
                                list(optimizer.directions.values()),
                                updates=model.graph_updates)

    dirs = fct_optim(batch_inputs, batch_targets, batch_mask)
Example No. 3
def test_gru_multistep_fprop_k3():
    hidden_sizes = 50

    hyperparams = {
        'model': 'gru_multistep',
        'k': 3,
        'm': 3,
        'batch_size': 16,
        'SGD': "1e-2",
        'hidden_sizes': hidden_sizes,
        'learn_to_stop': False,
        'normalize': False,
        'noisy_streamlines_sigma': None,
        'shuffle_streamlines': True,
        'seed': 1234
    }
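    # 'k' is presumably the number of future steps the multistep GRU predicts at
    # each time step, and 'm' the number of samples drawn when estimating the
    # multistep loss (an assumption based on the names only).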

    with Timer("Creating dataset", newline=True):
        volume_manager = neurotools.VolumeManager()
        trainset = make_dummy_dataset(volume_manager)
        print("Dataset sizes:", len(trainset))

        batch_scheduler = factories.batch_scheduler_factory(hyperparams,
                                                            trainset,
                                                            train_mode=True)
        print("An epoch will be composed of {} updates.".format(
            batch_scheduler.nb_updates_per_epoch))
        print(volume_manager.data_dimension, hidden_sizes,
              batch_scheduler.target_size)

    with Timer("Creating model"):
        model = factories.model_factory(
            hyperparams,
            input_size=volume_manager.data_dimension,
            output_size=batch_scheduler.target_size,
            volume_manager=volume_manager)
        model.initialize(
            factories.weigths_initializer_factory("orthogonal", seed=1234))

    # Test fprop
    output = model.get_output(trainset.symb_inputs)
    fct = theano.function([trainset.symb_inputs],
                          output,
                          updates=model.graph_updates)

    batch_inputs, batch_targets, batch_mask = batch_scheduler._next_batch(2)
    out = fct(batch_inputs)

    with Timer("Building optimizer"):
        loss = factories.loss_factory(hyperparams, model, trainset)
        optimizer = factories.optimizer_factory(hyperparams, loss)

    fct_loss = theano.function(
        [trainset.symb_inputs, trainset.symb_targets, trainset.symb_mask],
        loss.loss,
        updates=model.graph_updates)

    loss_value = fct_loss(batch_inputs, batch_targets, batch_mask)
    print("Loss:", loss_value)

    fct_optim = theano.function(
        [trainset.symb_inputs, trainset.symb_targets, trainset.symb_mask],
        list(optimizer.directions.values()),
        updates=model.graph_updates)

    dirs = fct_optim(batch_inputs, batch_targets, batch_mask)
Example No. 4
def main():
    parser = build_argparser()
    args = parser.parse_args()
    print(args)
    print("Using Theano v.{}".format(theano.version.short_version))

    hyperparams_to_exclude = ['max_epoch', 'force', 'name', 'view', 'shuffle_streamlines']
    # Use this for hyperparams added in a newer version but absent from older versions
    retrocompatibility_defaults = {'feed_previous_direction': False,
                                   'normalize': False}
    experiment_path, hyperparams, resuming = utils.maybe_create_experiment_folder(args, exclude=hyperparams_to_exclude,
                                                                                  retrocompatibility_defaults=retrocompatibility_defaults)

    # Log the command currently running.
    with open(pjoin(experiment_path, 'cmd.txt'), 'a') as f:
        f.write(" ".join(sys.argv) + "\n")

    print("Resuming:" if resuming else "Creating:", experiment_path)

    with Timer("Loading dataset", newline=True):
        trainset_volume_manager = VolumeManager()
        validset_volume_manager = VolumeManager()
        trainset = datasets.load_tractography_dataset(args.train_subjects, trainset_volume_manager, name="trainset", use_sh_coeffs=args.use_sh_coeffs)
        validset = datasets.load_tractography_dataset(args.valid_subjects, validset_volume_manager, name="validset", use_sh_coeffs=args.use_sh_coeffs)
        print("Dataset sizes:", len(trainset), " |", len(validset))

        if args.view:
            tsne_view(trainset, trainset_volume_manager)
            sys.exit(0)

        batch_scheduler = batch_scheduler_factory(hyperparams, dataset=trainset, train_mode=True)
        print("An epoch will be composed of {} updates.".format(batch_scheduler.nb_updates_per_epoch))
        print(trainset_volume_manager.data_dimension, args.hidden_sizes, batch_scheduler.target_size)

    with Timer("Creating model"):
        input_size = trainset_volume_manager.data_dimension
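        # When feeding the previous direction, the preceding streamline direction
        # (a 3D vector) is appended to each input, hence the +3 below.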
        if hyperparams['feed_previous_direction']:
            input_size += 3

        model = model_factory(hyperparams,
                              input_size=input_size,
                              output_size=batch_scheduler.target_size,
                              volume_manager=trainset_volume_manager)
        model.initialize(weigths_initializer_factory(args.weights_initialization,
                                                     seed=args.initialization_seed))

    with Timer("Building optimizer"):
        loss = loss_factory(hyperparams, model, trainset)

        if args.clip_gradient is not None:
            loss.append_gradient_modifier(DirectionClipping(threshold=args.clip_gradient))

        optimizer = optimizer_factory(hyperparams, loss)

    with Timer("Building trainer"):
        trainer = Trainer(optimizer, batch_scheduler)

        # Log training error
        loss_monitor = views.MonitorVariable(loss.loss)
        avg_loss = tasks.AveragePerEpoch(loss_monitor)
        trainer.append_task(avg_loss)

        # Print average training loss.
        trainer.append_task(tasks.Print("Avg. training loss:         : {}", avg_loss))

        # if args.learn_to_stop:
        #     l2err_monitor = views.MonitorVariable(T.mean(loss.mean_sqr_error))
        #     avg_l2err = tasks.AveragePerEpoch(l2err_monitor)
        #     trainer.append_task(avg_l2err)
        #
        #     crossentropy_monitor = views.MonitorVariable(T.mean(loss.cross_entropy))
        #     avg_crossentropy = tasks.AveragePerEpoch(crossentropy_monitor)
        #     trainer.append_task(avg_crossentropy)
        #
        #     trainer.append_task(tasks.Print("Avg. training L2 err:       : {}", avg_l2err))
        #     trainer.append_task(tasks.Print("Avg. training stopping:     : {}", avg_crossentropy))
        #     trainer.append_task(tasks.Print("L2 err : {0:.4f}", l2err_monitor, each_k_update=100))
        #     trainer.append_task(tasks.Print("stopping : {0:.4f}", crossentropy_monitor, each_k_update=100))

        # Print NLL mean/stderror.
        # train_loss = L2DistanceForSequences(model, trainset)
        # train_batch_scheduler = StreamlinesBatchScheduler(trainset, batch_size=1000,
        #                                                   noisy_streamlines_sigma=None,
        #                                                   nb_updates_per_epoch=None,
        #                                                   seed=1234)

        # train_error = views.LossView(loss=train_loss, batch_scheduler=train_batch_scheduler)
        # trainer.append_task(tasks.Print("Trainset - Error        : {0:.2f} | {1:.2f}", train_error.sum, train_error.mean))

        # HACK: To make sure all subjects in the volume_manager are used in a batch, we have to split the trainset/validset into two volume managers
        model.volume_manager = validset_volume_manager
        valid_loss = loss_factory(hyperparams, model, validset)
        valid_batch_scheduler = batch_scheduler_factory(hyperparams,
                                                        dataset=validset,
                                                        train_mode=False)

        valid_error = views.LossView(loss=valid_loss, batch_scheduler=valid_batch_scheduler)
        trainer.append_task(tasks.Print("Validset - Error        : {0:.2f} | {1:.2f}", valid_error.sum, valid_error.mean))

        # HACK: Restore trainset volume manager
        model.volume_manager = trainset_volume_manager

        lookahead_loss = valid_error.sum

        # Monitor the global gradient norm: the square root of the sum of squared
        # entries across all parameter gradients.
        direction_norm = views.MonitorVariable(T.sqrt(sum(map(lambda d: T.sqr(d).sum(), loss.gradients.values()))))
        # trainer.append_task(tasks.Print("||d|| : {0:.4f}", direction_norm))

        # logger = tasks.Logger(train_error.mean, valid_error.mean, valid_error.sum, direction_norm)
        logger = tasks.Logger(valid_error.mean, valid_error.sum, direction_norm)
        trainer.append_task(logger)

        if args.view:
            import pylab as plt

            def _plot(*args, **kwargs):
                plt.figure(1)
                plt.clf()
                plt.show(block=False)
                plt.subplot(121)
                # The logger tracks (valid mean, valid sum, ||d||) at indices 0, 1, 2.
                plt.plot(np.array(logger.get_variable_history(0)).flatten(), label="Valid (mean)")
                plt.plot(np.array(logger.get_variable_history(1)).flatten(), label="Valid (sum)")
                plt.legend()

                plt.subplot(122)
                plt.plot(np.array(logger.get_variable_history(2)).flatten(), label="||d||")
                plt.draw()

            trainer.append_task(tasks.Callback(_plot))

        # Callback function to stop training if NaN is detected.
        def detect_nan(obj, status):
            if np.isnan(model.parameters[0].get_value().sum()):
                print("NaN detected! Stopping training now.")
                sys.exit()

        trainer.append_task(tasks.Callback(detect_nan, each_k_update=1))

        # Callback function to save training progression.
        def save_training(obj, status):
            trainer.save(experiment_path)

        trainer.append_task(tasks.Callback(save_training))

        # Early stopping with a callback for saving every time model improves.
        def save_improvement(obj, status):
            """ Save best model and training progression. """
            if np.isnan(model.parameters[0].get_value().sum()):
                print("NaN detected! Not saving the model. Crashing now.")
                sys.exit()

            print("*** Best epoch: {0} ***\n".format(obj.best_epoch))
            model.save(experiment_path)

        # Print time for one epoch
        trainer.append_task(tasks.PrintEpochDuration())
        trainer.append_task(tasks.PrintTrainingDuration())
        trainer.append_task(tasks.PrintTime(each_k_update=100))  # Profiling

        # Add stopping criteria
        trainer.append_task(stopping_criteria.MaxEpochStopping(args.max_epoch))
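        # Early stopping monitors lookahead_loss (the validset error sum), calls
        # save_improvement whenever it improves, and stops training after
        # `lookahead` epochs without sufficient improvement (controlled by
        # `lookahead_eps`).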
        early_stopping = stopping_criteria.EarlyStopping(lookahead_loss, lookahead=args.lookahead, eps=args.lookahead_eps, callback=save_improvement)
        trainer.append_task(early_stopping)

    with Timer("Compiling Theano graph"):
        trainer.build_theano_graph()

    if resuming:
        if not os.path.isdir(pjoin(experiment_path, 'training')):
            print("No 'training/' folder. Assuming it failed before"
                  " the end of the first epoch. Starting a new training.")
        else:
            with Timer("Loading"):
                trainer.load(experiment_path)

    with Timer("Training"):
        trainer.train()
Example No. 5
def main():
    parser = build_argparser()
    args = parser.parse_args()
    print(args)
    print("Using Theano v.{}".format(theano.version.short_version))

    hyperparams_to_exclude = ['max_epoch', 'force', 'name', 'view', 'shuffle_streamlines']
    # Use this for hyperparams added in a newer version but absent from older versions
    retrocompatibility_defaults = {'feed_previous_direction': False,
                                   'predict_offset': False,
                                   'normalize': False,
                                   'sort_streamlines': False,
                                   'keep_step_size': False,
                                   'use_layer_normalization': False,
                                   'drop_prob': 0.,
                                   'use_zoneout': False,
                                   'skip_connections': False}
    experiment_path, hyperparams, resuming = utils.maybe_create_experiment_folder(args, exclude=hyperparams_to_exclude,
                                                                                  retrocompatibility_defaults=retrocompatibility_defaults)

    # Log the command currently running.
    with open(pjoin(experiment_path, 'cmd.txt'), 'a') as f:
        f.write(" ".join(sys.argv) + "\n")

    print("Resuming:" if resuming else "Creating:", experiment_path)

    with Timer("Loading dataset", newline=True):
        trainset_volume_manager = VolumeManager()
        validset_volume_manager = VolumeManager()
        trainset = datasets.load_tractography_dataset(args.train_subjects, trainset_volume_manager, name="trainset",
                                                      use_sh_coeffs=args.use_sh_coeffs)
        validset = datasets.load_tractography_dataset(args.valid_subjects, validset_volume_manager, name="validset",
                                                      use_sh_coeffs=args.use_sh_coeffs)
        print("Dataset sizes:", len(trainset), " |", len(validset))

        batch_scheduler = batch_scheduler_factory(hyperparams, dataset=trainset, train_mode=True)
        print("An epoch will be composed of {} updates.".format(batch_scheduler.nb_updates_per_epoch))
        print(trainset_volume_manager.data_dimension, args.hidden_sizes, batch_scheduler.target_size)

    with Timer("Creating model"):
        input_size = trainset_volume_manager.data_dimension
        if hyperparams['feed_previous_direction']:
            input_size += 3

        model = model_factory(hyperparams,
                              input_size=input_size,
                              output_size=batch_scheduler.target_size,
                              volume_manager=trainset_volume_manager)
        model.initialize(weigths_initializer_factory(args.weights_initialization,
                                                     seed=args.initialization_seed))

    with Timer("Building optimizer"):
        loss = loss_factory(hyperparams, model, trainset)

        if args.clip_gradient is not None:
            loss.append_gradient_modifier(DirectionClipping(threshold=args.clip_gradient))

        optimizer = optimizer_factory(hyperparams, loss)

    with Timer("Building trainer"):
        trainer = Trainer(optimizer, batch_scheduler)

        # Log training error
        loss_monitor = views.MonitorVariable(loss.loss)
        avg_loss = tasks.AveragePerEpoch(loss_monitor)
        trainer.append_task(avg_loss)

        # Print average training loss.
        trainer.append_task(tasks.Print("Avg. training loss:         : {}", avg_loss))

        # if args.learn_to_stop:
        #     l2err_monitor = views.MonitorVariable(T.mean(loss.mean_sqr_error))
        #     avg_l2err = tasks.AveragePerEpoch(l2err_monitor)
        #     trainer.append_task(avg_l2err)
        #
        #     crossentropy_monitor = views.MonitorVariable(T.mean(loss.cross_entropy))
        #     avg_crossentropy = tasks.AveragePerEpoch(crossentropy_monitor)
        #     trainer.append_task(avg_crossentropy)
        #
        #     trainer.append_task(tasks.Print("Avg. training L2 err:       : {}", avg_l2err))
        #     trainer.append_task(tasks.Print("Avg. training stopping:     : {}", avg_crossentropy))
        #     trainer.append_task(tasks.Print("L2 err : {0:.4f}", l2err_monitor, each_k_update=100))
        #     trainer.append_task(tasks.Print("stopping : {0:.4f}", crossentropy_monitor, each_k_update=100))

        # Print NLL mean/stderror.
        # train_loss = L2DistanceForSequences(model, trainset)
        # train_batch_scheduler = StreamlinesBatchScheduler(trainset, batch_size=1000,
        #                                                   noisy_streamlines_sigma=None,
        #                                                   nb_updates_per_epoch=None,
        #                                                   seed=1234)

        # train_error = views.LossView(loss=train_loss, batch_scheduler=train_batch_scheduler)
        # trainer.append_task(tasks.Print("Trainset - Error        : {0:.2f} | {1:.2f}", train_error.sum, train_error.mean))

        # HACK: To make sure all subjects in the volume_manager are used in a batch, we have to split the trainset/validset into two volume managers
        model.volume_manager = validset_volume_manager
        model.drop_prob = 0.  # Do not use dropout/zoneout for evaluation
        valid_loss = loss_factory(hyperparams, model, validset)
        valid_batch_scheduler = batch_scheduler_factory(hyperparams,
                                                        dataset=validset,
                                                        train_mode=False)

        valid_error = views.LossView(loss=valid_loss, batch_scheduler=valid_batch_scheduler)
        trainer.append_task(tasks.Print("Validset - Error        : {0:.2f} | {1:.2f}", valid_error.sum, valid_error.mean))

        if hyperparams['model'] == 'ffnn_regression':
            valid_batch_scheduler2 = batch_scheduler_factory(hyperparams,
                                                             dataset=validset,
                                                             train_mode=False)

            valid_l2 = loss_factory(hyperparams, model, validset, loss_type="expected_value")
            valid_l2_error = views.LossView(loss=valid_l2, batch_scheduler=valid_batch_scheduler2)
            trainer.append_task(tasks.Print("Validset - {}".format(valid_l2.__class__.__name__) + "\t: {0:.2f} | {1:.2f}", valid_l2_error.sum, valid_l2_error.mean))

        # HACK: Restore trainset volume manager
        model.volume_manager = trainset_volume_manager
        model.drop_prob = hyperparams['drop_prob']  # Restore dropout

        lookahead_loss = valid_error.sum

        direction_norm = views.MonitorVariable(T.sqrt(sum(map(lambda d: T.sqr(d).sum(), loss.gradients.values()))))
        # trainer.append_task(tasks.Print("||d|| : {0:.4f}", direction_norm))

        # logger = tasks.Logger(train_error.mean, valid_error.mean, valid_error.sum, direction_norm)
        logger = tasks.Logger(valid_error.mean, valid_error.sum, direction_norm)
        trainer.append_task(logger)

        if args.view:
            import pylab as plt

            def _plot(*args, **kwargs):
                plt.figure(1)
                plt.clf()
                plt.show(block=False)
                plt.subplot(121)
                plt.plot(np.array(logger.get_variable_history(0)).flatten(), label="Valid (mean)")
                plt.plot(np.array(logger.get_variable_history(1)).flatten(), label="Valid (sum)")
                plt.legend()

                plt.subplot(122)
                plt.plot(np.array(logger.get_variable_history(2)).flatten(), label="||d||")
                plt.draw()

            trainer.append_task(tasks.Callback(_plot))

        # Callback function to stop training if NaN is detected.
        def detect_nan(obj, status):
            if np.isnan(model.parameters[0].get_value().sum()):
                print("NaN detected! Stopping training now.")
                sys.exit()

        trainer.append_task(tasks.Callback(detect_nan, each_k_update=1))

        # Callback function to save training progression.
        def save_training(obj, status):
            trainer.save(experiment_path)

        trainer.append_task(tasks.Callback(save_training))

        # Early stopping with a callback for saving every time model improves.
        def save_improvement(obj, status):
            """ Save best model and training progression. """
            if np.isnan(model.parameters[0].get_value().sum()):
                print("NaN detected! Not saving the model. Crashing now.")
                sys.exit()

            print("*** Best epoch: {0} ***\n".format(obj.best_epoch))
            model.save(experiment_path)

        # Print time for one epoch
        trainer.append_task(tasks.PrintEpochDuration())
        trainer.append_task(tasks.PrintTrainingDuration())
        trainer.append_task(tasks.PrintTime(each_k_update=100))  # Profiling

        # Add stopping criteria
        trainer.append_task(stopping_criteria.MaxEpochStopping(args.max_epoch))
        early_stopping = stopping_criteria.EarlyStopping(lookahead_loss, lookahead=args.lookahead, eps=args.lookahead_eps, callback=save_improvement)
        trainer.append_task(early_stopping)

    with Timer("Compiling Theano graph"):
        trainer.build_theano_graph()

    if resuming:
        if not os.path.isdir(pjoin(experiment_path, 'training')):
            print("No 'training/' folder. Assuming it failed before"
                  " the end of the first epoch. Starting a new training.")
        else:
            with Timer("Loading"):
                trainer.load(experiment_path)

    with Timer("Training"):
        trainer.train()
Example No. 6
def test_gru_mixture_track():
    hidden_sizes = 50

    with Timer("Creating dummy volume", newline=True):
        volume_manager = neurotools.VolumeManager()
        dwi, gradients = make_dummy_dwi(nb_gradients=30,
                                        volume_shape=(10, 10, 10),
                                        seed=1234)
        volume = neurotools.resample_dwi(dwi, gradients.bvals,
                                         gradients.bvecs).astype(np.float32)

        volume_manager.register(volume)

    with Timer("Creating model"):
        hyperparams = {
            'model': 'gru_mixture',
            'SGD': "1e-2",
            'hidden_sizes': hidden_sizes,
            'learn_to_stop': False,
            'normalize': False,
            'activation': 'tanh',
            'feed_previous_direction': False,
            'predict_offset': False,
            'use_layer_normalization': False,
            'drop_prob': 0.,
            'use_zoneout': False,
            'skip_connections': False,
            'neighborhood_radius': None,
            'nb_seeds_per_voxel': 2,
            'step_size': 0.5,
            'batch_size': 200,
            'n_gaussians': 2,
            'seed': 1234
        }
        model = factories.model_factory(
            hyperparams,
            input_size=volume_manager.data_dimension,
            output_size=3,
            volume_manager=volume_manager)
        model.initialize(
            factories.weigths_initializer_factory("orthogonal", seed=1234))

    rng = np.random.RandomState(1234)
    mask = np.ones(volume.shape[:3])
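    # Pick seed voxels from a random binary mask, then jitter each seed uniformly
    # within its voxel so that 'nb_seeds_per_voxel' seeds fall inside every
    # selected voxel.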
    seeding_mask = rng.randint(2, size=mask.shape)
    seeds = []
    indices = np.array(np.where(seeding_mask)).T
    for idx in indices:
        seeds_in_voxel = idx + rng.uniform(
            -0.5, 0.5, size=(hyperparams['nb_seeds_per_voxel'], 3))
        seeds.extend(seeds_in_voxel)
    seeds = np.array(seeds, dtype=theano.config.floatX)
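    # Streamlines stop when they exit the mask, grow past 150 points, bend too
    # sharply, or the model's likelihood drops below 0.5. Note that the curvature
    # threshold is np.rad2deg(30) (about 1719 degrees), which effectively disables
    # that check here, assuming make_is_too_curvy expects an angle in degrees.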

    is_outside_mask = make_is_outside_mask(mask, np.eye(4), threshold=0.5)
    is_too_long = make_is_too_long(150)
    is_too_curvy = make_is_too_curvy(np.rad2deg(30))
    is_unlikely = make_is_unlikely(0.5)
    is_stopping = make_is_stopping({
        STOPPING_MASK: is_outside_mask,
        STOPPING_LENGTH: is_too_long,
        STOPPING_CURVATURE: is_too_curvy,
        STOPPING_LIKELIHOOD: is_unlikely
    })
    is_stopping.max_nb_points = 150

    args = SimpleNamespace()
    args.track_like_peter = False
    args.pft_nb_retry = 0
    args.pft_nb_backtrack_steps = 0
    args.use_max_component = False
    args.flip_x = False
    args.flip_y = False
    args.flip_z = False
    args.verbose = True

    tractogram = batch_track(model,
                             volume,
                             seeds,
                             step_size=hyperparams['step_size'],
                             is_stopping=is_stopping,
                             batch_size=hyperparams['batch_size'],
                             args=args)
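    # `batch_track` presumably returns a nibabel Tractogram; if so, it could be
    # written to disk with nib.streamlines.save(tractogram, "out.trk") for visual
    # inspection (hypothetical path; requires `import nibabel as nib`).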

    return True
Example No. 7
def test_gru_mixture_fprop_neighborhood():
    hyperparams = {
        'model': 'gru_mixture',
        'SGD': "1e-2",
        'hidden_sizes': 50,
        'batch_size': 16,
        'learn_to_stop': False,
        'normalize': True,
        'activation': 'tanh',
        'feed_previous_direction': False,
        'predict_offset': False,
        'use_layer_normalization': False,
        'drop_prob': 0.,
        'use_zoneout': False,
        'skip_connections': False,
        'seed': 1234,
        'noisy_streamlines_sigma': None,
        'keep_step_size': True,
        'sort_streamlines': False,
        'n_gaussians': 2,
        'neighborhood_radius': 0.5
    }
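    # 'neighborhood_radius': 0.5 presumably makes the model sample the diffusion
    # volume in a small neighborhood around each streamline point, which enlarges
    # the effective input size (printed as model.model_input_size below).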

    with Timer("Creating dataset", newline=True):
        volume_manager = neurotools.VolumeManager()
        trainset = make_dummy_dataset(volume_manager)
        print("Dataset sizes:", len(trainset))

        batch_scheduler = factories.batch_scheduler_factory(hyperparams,
                                                            dataset=trainset)
        print("An epoch will be composed of {} updates.".format(
            batch_scheduler.nb_updates_per_epoch))
        print(volume_manager.data_dimension, hyperparams['hidden_sizes'],
              batch_scheduler.target_size)

    with Timer("Creating model"):
        model = factories.model_factory(
            hyperparams,
            input_size=volume_manager.data_dimension,
            output_size=batch_scheduler.target_size,
            volume_manager=volume_manager)
        model.initialize(
            factories.weigths_initializer_factory("orthogonal", seed=1234))

        print("Input size: {}".format(model.model_input_size))

    # Test fprop with missing streamlines from one subject in a batch
    output = model.get_output(trainset.symb_inputs)
    fct = theano.function([trainset.symb_inputs],
                          output,
                          updates=model.graph_updates)

    batch_inputs, batch_targets, batch_mask = batch_scheduler._next_batch(2)
    out = fct(batch_inputs)

    with Timer("Building optimizer"):
        loss = factories.loss_factory(hyperparams, model, trainset)
        optimizer = factories.optimizer_factory(hyperparams, loss)

    fct_loss = theano.function(
        [trainset.symb_inputs, trainset.symb_targets, trainset.symb_mask],
        loss.loss,
        updates=model.graph_updates)

    loss_value = fct_loss(batch_inputs, batch_targets, batch_mask)
    print("Loss:", loss_value)

    fct_optim = theano.function(
        [trainset.symb_inputs, trainset.symb_targets, trainset.symb_mask],
        list(optimizer.directions.values()),
        updates=model.graph_updates)

    dirs = fct_optim(batch_inputs, batch_targets, batch_mask)

    return True