def model_factory(hyperparams, input_size, output_size, volume_manager):
    """Build the learn2track model selected by ``hyperparams['model']``.

    Parameters
    ----------
    hyperparams : mapping
        Hyperparameters (presumably a dict). ``hyperparams['model']`` picks
        the architecture: one of 'gru_regression', 'gru_multistep',
        'gru_mixture', 'gru_gaussian' or 'ffnn_regression'. Which other keys
        are read depends on that choice (see each branch below).
    input_size :
        Forwarded as ``input_size`` to the selected model's constructor.
    output_size :
        Forwarded as ``output_size``, or as ``target_dims`` for
        'gru_multistep'.
    volume_manager :
        Forwarded unchanged as ``volume_manager`` to the constructor.

    Returns
    -------
    object
        The newly constructed model instance.

    Raises
    ------
    ValueError
        If ``hyperparams['model']`` is not one of the supported names.
    KeyError
        If ``hyperparams`` (assuming a dict) lacks 'model' or a key that the
        selected model's branch reads.
    """
    # Read the selector once instead of re-indexing it in every branch.
    model_name = hyperparams['model']

    if model_name == 'gru_regression':
        from learn2track.models import GRU_Regression
        return GRU_Regression(volume_manager=volume_manager,
                              input_size=input_size,
                              hidden_sizes=hyperparams['hidden_sizes'],
                              output_size=output_size,
                              activation=hyperparams['activation'],
                              use_previous_direction=hyperparams['feed_previous_direction'],
                              predict_offset=hyperparams['predict_offset'],
                              use_layer_normalization=hyperparams['use_layer_normalization'],
                              drop_prob=hyperparams['drop_prob'],
                              use_zoneout=hyperparams['use_zoneout'],
                              use_skip_connections=hyperparams['skip_connections'],
                              neighborhood_radius=hyperparams['neighborhood_radius'],
                              learn_to_stop=hyperparams['learn_to_stop'],
                              seed=hyperparams['seed'])

    elif model_name == 'gru_multistep':
        from learn2track.models import GRU_Multistep_Gaussian
        # This constructor takes the output dimensionality as `target_dims`.
        return GRU_Multistep_Gaussian(volume_manager=volume_manager,
                                      input_size=input_size,
                                      hidden_sizes=hyperparams['hidden_sizes'],
                                      target_dims=output_size,
                                      k=hyperparams['k'],
                                      m=hyperparams['m'],
                                      seed=hyperparams['seed'],
                                      use_previous_direction=hyperparams['feed_previous_direction'],
                                      use_layer_normalization=hyperparams['use_layer_normalization'],
                                      drop_prob=hyperparams['drop_prob'],
                                      use_zoneout=hyperparams['use_zoneout'])

    elif model_name == 'gru_mixture':
        from learn2track.models import GRU_Mixture
        return GRU_Mixture(volume_manager=volume_manager,
                           input_size=input_size,
                           hidden_sizes=hyperparams['hidden_sizes'],
                           output_size=output_size,
                           n_gaussians=hyperparams['n_gaussians'],
                           activation=hyperparams['activation'],
                           use_previous_direction=hyperparams['feed_previous_direction'],
                           use_layer_normalization=hyperparams['use_layer_normalization'],
                           drop_prob=hyperparams['drop_prob'],
                           use_zoneout=hyperparams['use_zoneout'],
                           use_skip_connections=hyperparams['skip_connections'],
                           neighborhood_radius=hyperparams['neighborhood_radius'],
                           learn_to_stop=hyperparams['learn_to_stop'],
                           seed=hyperparams['seed'])

    elif model_name == 'gru_gaussian':
        from learn2track.models import GRU_Gaussian
        return GRU_Gaussian(volume_manager=volume_manager,
                            input_size=input_size,
                            hidden_sizes=hyperparams['hidden_sizes'],
                            output_size=output_size,
                            use_previous_direction=hyperparams['feed_previous_direction'],
                            use_layer_normalization=hyperparams['use_layer_normalization'],
                            drop_prob=hyperparams['drop_prob'],
                            use_zoneout=hyperparams['use_zoneout'],
                            use_skip_connections=hyperparams['skip_connections'],
                            neighborhood_radius=hyperparams['neighborhood_radius'],
                            learn_to_stop=hyperparams['learn_to_stop'],
                            seed=hyperparams['seed'])

    elif model_name == 'ffnn_regression':
        from learn2track.models import FFNN_Regression
        # NOTE(review): this branch reads 'dropout_prob' while the GRU branches
        # read 'drop_prob' -- confirm the FFNN training script defines this key.
        return FFNN_Regression(volume_manager=volume_manager,
                               input_size=input_size,
                               hidden_sizes=hyperparams['hidden_sizes'],
                               output_size=output_size,
                               activation=hyperparams['activation'],
                               use_previous_direction=hyperparams['feed_previous_direction'],
                               predict_offset=hyperparams['predict_offset'],
                               use_layer_normalization=hyperparams['use_layer_normalization'],
                               dropout_prob=hyperparams['dropout_prob'],
                               use_skip_connections=hyperparams['skip_connections'],
                               neighborhood_radius=hyperparams['neighborhood_radius'],
                               seed=hyperparams['seed'])

    # Name the rejected value so a misconfigured run is easy to diagnose.
    raise ValueError("Unknown model: {!r}".format(model_name))