def test_GraphModel(self, n_recurrences, mlp_sizes):
  graphs_tuple = self._get_graphs_tuple()
  output_op = graph_model.GraphBasedModel(
      n_recurrences=n_recurrences, mlp_sizes=mlp_sizes)(graphs_tuple)
  self.assertListEqual(output_op.shape.as_list(), [len(self._types)])
  # Tests if the model runs without crashing.
  with self.session():
    tf.global_variables_initializer().run()
    output_op.eval()
def train_model(train_file_pattern,
                test_file_pattern,
                max_files_to_load=None,
                n_epochs=1000,
                time_index=9,
                augment_data_using_rotations=True,
                learning_rate=1e-4,
                grad_clip=1.0,
                n_recurrences=7,
                mlp_sizes=(64, 64),
                mlp_kwargs=None,
                edge_threshold=2.0,
                measurement_store_interval=1000,
                checkpoint_path=None):
  """Trains GraphModel using tensorflow.

  Args:
    train_file_pattern: pattern matching the files with the training data.
    test_file_pattern: pattern matching the files with the test data.
    max_files_to_load: the maximum number of train and test files to load.
      If None, all files will be loaded.
    n_epochs: the number of passes through the training dataset (epochs).
    time_index: the time index (0-9) of the target mobilities.
    augment_data_using_rotations: data is augmented by using random rotations.
    learning_rate: the learning rate used by the optimizer.
    grad_clip: all gradients are clipped to the given value.
    n_recurrences: the number of message passing steps in the graphnet.
    mlp_sizes: the number of neurons in each layer of the MLP.
    mlp_kwargs: additional keyword arguments passed to the MLP.
    edge_threshold: particles at distance less than threshold are connected by
      an edge.
    measurement_store_interval: number of steps between storing objective
      values (loss and correlation).
    checkpoint_path: path used to store the checkpoint with the highest
      correlation on the test set.

  Returns:
    Correlation on the test dataset of best model encountered during training.
  """
  if mlp_kwargs is None:
    mlp_kwargs = dict(
        initializers=dict(w=tf.variance_scaling_initializer(1.0),
                          b=tf.variance_scaling_initializer(0.1)))

  # Loads train and test dataset.
  dataset_kwargs = dict(
      time_index=time_index,
      max_files_to_load=max_files_to_load)
  training_data = load_data(train_file_pattern, **dataset_kwargs)
  test_data = load_data(test_file_pattern, **dataset_kwargs)

  # Defines wrapper functions, which can directly be passed to the
  # tf.data.Dataset.map function.
  def _make_graph_from_static_structure(static_structure):
    """Converts static structure to graph, targets and types."""
    return (graph_model.make_graph_from_static_structure(
        static_structure.positions,
        static_structure.types,
        static_structure.box,
        edge_threshold),
            static_structure.targets,
            static_structure.types)

  def _apply_random_rotation(graph, targets, types):
    """Applies random rotations to the graph and forwards targets and types."""
    return graph_model.apply_random_rotation(graph), targets, types

  # Defines the data pipeline based on tf.data.Dataset following the official
  # guideline: https://www.tensorflow.org/guide/datasets#consuming_numpy_arrays.
  # We use initializable iterators to avoid embedding the training and test
  # data directly into the graph.
  # Instead we feed the data to the iterators during the initialization of the
  # iterators before the main training loop.
  placeholders = GlassSimulationData._make(
      tf.placeholder(s.dtype, (None,) + s.shape) for s in training_data[0])
  dataset = tf.data.Dataset.from_tensor_slices(placeholders)
  dataset = dataset.map(_make_graph_from_static_structure)
  dataset = dataset.cache()
  dataset = dataset.shuffle(400)
  # Augments data. This has to be done after calling dataset.cache!
  if augment_data_using_rotations:
    dataset = dataset.map(_apply_random_rotation)
  dataset = dataset.repeat()
  train_iterator = dataset.make_initializable_iterator()

  dataset = tf.data.Dataset.from_tensor_slices(placeholders)
  dataset = dataset.map(_make_graph_from_static_structure)
  dataset = dataset.cache()
  dataset = dataset.repeat()
  test_iterator = dataset.make_initializable_iterator()

  # Creates tensorflow graph.
  # Note: We decouple the training and test datasets from the input pipeline
  # by creating a new iterator from a string-handle placeholder with the same
  # output types and shapes as the training dataset.
  dataset_handle = tf.placeholder(tf.string, shape=[])
  iterator = tf.data.Iterator.from_string_handle(
      dataset_handle, train_iterator.output_types,
      train_iterator.output_shapes)
  graph, targets, types = iterator.get_next()

  model = graph_model.GraphBasedModel(
      n_recurrences, mlp_sizes, mlp_kwargs)
  prediction = model(graph)

  # Defines loss and minimization operations.
  loss_ops = get_loss_ops(prediction, targets, types)
  minimize_op = get_minimize_op(loss_ops.l2_loss, learning_rate, grad_clip)

  best_so_far = -1
  train_stats = []
  test_stats = []

  saver = tf.train.Saver()

  with tf.train.SingularMonitoredSession() as session:
    # Initializes train and test iterators with the training and test datasets.
    # The obtained training and test string-handles can be passed to the
    # dataset_handle placeholder to select the dataset.
    train_handle = session.run(train_iterator.string_handle())
    test_handle = session.run(test_iterator.string_handle())
    feed_dict = {p: [x[i] for x in training_data]
                 for i, p in enumerate(placeholders)}
    session.run(train_iterator.initializer, feed_dict=feed_dict)
    feed_dict = {p: [x[i] for x in test_data]
                 for i, p in enumerate(placeholders)}
    session.run(test_iterator.initializer, feed_dict=feed_dict)

    # Trains model using stochastic gradient descent on the training dataset.
    n_training_steps = len(training_data) * n_epochs
    for i in range(n_training_steps):
      feed_dict = {dataset_handle: train_handle}
      train_loss, _ = session.run((loss_ops, minimize_op), feed_dict=feed_dict)
      train_stats.append(train_loss)

      if (i+1) % measurement_store_interval == 0:
        # Evaluates model on test dataset.
        for _ in range(len(test_data)):
          feed_dict = {dataset_handle: test_handle}
          test_stats.append(session.run(loss_ops, feed_dict=feed_dict))

        # Outputs performance statistics on training and test dataset.
        _log_stats_and_return_mean_correlation('Train', train_stats)
        correlation = _log_stats_and_return_mean_correlation(
            'Test', test_stats)
        train_stats = []
        test_stats = []

        # Updates best model based on the observed correlation on the test
        # dataset.
        if correlation > best_so_far:
          best_so_far = correlation
          if checkpoint_path:
            saver.save(session.raw_session(), checkpoint_path)

  return best_so_far
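

# A minimal usage sketch of train_model (not part of the original module).
# The file patterns and checkpoint path below are hypothetical placeholders;
# the call assumes this module defines load_data, get_loss_ops, get_minimize_op
# and imports graph_model as used above, and that the data files match the
# GlassSimulationData format expected by load_data.
if __name__ == '__main__':
  best_test_correlation = train_model(
      train_file_pattern='/tmp/glass_data/train/*.pickle',   # hypothetical path
      test_file_pattern='/tmp/glass_data/test/*.pickle',     # hypothetical path
      max_files_to_load=10,
      n_epochs=100,
      checkpoint_path='/tmp/glass_model/best_model.ckpt')    # hypothetical path
  print('Best correlation on the test set: %f' % best_test_correlation)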