Esempio n. 1
0
def model_factory(hyperparams, input_size, output_size, volume_manager):
    """Build the learn2track model selected by ``hyperparams['model']``.

    Each model class is imported lazily inside its own branch, so only the
    module for the selected model is loaded.

    Parameters
    ----------
    hyperparams : dict
        Hyperparameters. ``hyperparams['model']`` selects the model; the other
        keys read depend on the selected model (see the corresponding branch).
    input_size
        Forwarded to the model constructor as ``input_size``.
    output_size
        Forwarded as ``output_size`` (as ``target_dims`` for 'gru_multistep').
    volume_manager
        Forwarded unchanged as ``volume_manager``.

    Returns
    -------
    The newly constructed model instance.

    Raises
    ------
    KeyError
        If ``hyperparams`` lacks ``'model'`` or a key read by the selected branch.
    ValueError
        If ``hyperparams['model']`` is not a supported model name; the message
        names the rejected value and lists the supported ones.
    """
    model_name = hyperparams['model']

    if model_name == 'gru_regression':
        from learn2track.models import GRU_Regression
        return GRU_Regression(volume_manager=volume_manager,
                              input_size=input_size,
                              hidden_sizes=hyperparams['hidden_sizes'],
                              output_size=output_size,
                              activation=hyperparams['activation'],
                              use_previous_direction=hyperparams['feed_previous_direction'],
                              predict_offset=hyperparams['predict_offset'],
                              use_layer_normalization=hyperparams['use_layer_normalization'],
                              drop_prob=hyperparams['drop_prob'],
                              use_zoneout=hyperparams['use_zoneout'],
                              use_skip_connections=hyperparams['skip_connections'],
                              neighborhood_radius=hyperparams['neighborhood_radius'],
                              learn_to_stop=hyperparams['learn_to_stop'],
                              seed=hyperparams['seed'])

    elif model_name == 'gru_multistep':
        from learn2track.models import GRU_Multistep_Gaussian
        # This model names its output dimensionality ``target_dims``.
        return GRU_Multistep_Gaussian(volume_manager=volume_manager,
                                      input_size=input_size,
                                      hidden_sizes=hyperparams['hidden_sizes'],
                                      target_dims=output_size,
                                      k=hyperparams['k'],
                                      m=hyperparams['m'],
                                      seed=hyperparams['seed'],
                                      use_previous_direction=hyperparams['feed_previous_direction'],
                                      use_layer_normalization=hyperparams['use_layer_normalization'],
                                      drop_prob=hyperparams['drop_prob'],
                                      use_zoneout=hyperparams['use_zoneout'])

    elif model_name == 'gru_mixture':
        from learn2track.models import GRU_Mixture
        return GRU_Mixture(volume_manager=volume_manager,
                           input_size=input_size,
                           hidden_sizes=hyperparams['hidden_sizes'],
                           output_size=output_size,
                           n_gaussians=hyperparams['n_gaussians'],
                           activation=hyperparams['activation'],
                           use_previous_direction=hyperparams['feed_previous_direction'],
                           use_layer_normalization=hyperparams['use_layer_normalization'],
                           drop_prob=hyperparams['drop_prob'],
                           use_zoneout=hyperparams['use_zoneout'],
                           use_skip_connections=hyperparams['skip_connections'],
                           neighborhood_radius=hyperparams['neighborhood_radius'],
                           learn_to_stop=hyperparams['learn_to_stop'],
                           seed=hyperparams['seed'])

    elif model_name == 'gru_gaussian':
        from learn2track.models import GRU_Gaussian
        return GRU_Gaussian(volume_manager=volume_manager,
                            input_size=input_size,
                            hidden_sizes=hyperparams['hidden_sizes'],
                            output_size=output_size,
                            use_previous_direction=hyperparams['feed_previous_direction'],
                            use_layer_normalization=hyperparams['use_layer_normalization'],
                            drop_prob=hyperparams['drop_prob'],
                            use_zoneout=hyperparams['use_zoneout'],
                            use_skip_connections=hyperparams['skip_connections'],
                            neighborhood_radius=hyperparams['neighborhood_radius'],
                            learn_to_stop=hyperparams['learn_to_stop'],
                            seed=hyperparams['seed'])

    elif model_name == 'ffnn_regression':
        from learn2track.models import FFNN_Regression
        # NOTE(review): this branch reads 'dropout_prob' while every GRU branch
        # reads 'drop_prob'. Confirm the hyperparameter parser really uses a
        # different key for this model before unifying them.
        return FFNN_Regression(volume_manager=volume_manager,
                               input_size=input_size,
                               hidden_sizes=hyperparams['hidden_sizes'],
                               output_size=output_size,
                               activation=hyperparams['activation'],
                               use_previous_direction=hyperparams['feed_previous_direction'],
                               predict_offset=hyperparams['predict_offset'],
                               use_layer_normalization=hyperparams['use_layer_normalization'],
                               dropout_prob=hyperparams['dropout_prob'],
                               use_skip_connections=hyperparams['skip_connections'],
                               neighborhood_radius=hyperparams['neighborhood_radius'],
                               seed=hyperparams['seed'])

    else:
        # Keep this tuple in sync with the branches above; it only feeds the
        # error message so a mistyped model name is easy to diagnose.
        known_models = ('gru_regression', 'gru_multistep', 'gru_mixture',
                        'gru_gaussian', 'ffnn_regression')
        raise ValueError("Unknown model: {!r}. Expected one of: {}".format(
            model_name, ", ".join(known_models)))