Example #1
  def test_GraphModel(self, n_recurrences, mlp_sizes):
    graphs_tuple = self._get_graphs_tuple()
    output_op = graph_model.GraphBasedModel(
        n_recurrences=n_recurrences, mlp_sizes=mlp_sizes)(graphs_tuple)
    self.assertListEqual(output_op.shape.as_list(), [len(self._types)])
    # Tests if the model runs without crashing.
    with self.session():
      tf.global_variables_initializer().run()
      output_op.eval()
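
The fixture _get_graphs_tuple is defined elsewhere in the test class. As a rough, non-authoritative sketch, a stand-in could build a graph_nets GraphsTuple for a small fully connected toy graph (the sizes and feature dimensions below are illustrative assumptions, not the original fixture's):

# Sketch only: a toy stand-in for the test fixture, not the original helper.
import numpy as np
import tensorflow.compat.v1 as tf
from graph_nets import graphs

def _toy_graphs_tuple(n_nodes=4, node_dim=3, edge_dim=2):
  # Fully connected directed graph without self-loops.
  senders, receivers = zip(*[(i, j) for i in range(n_nodes)
                             for j in range(n_nodes) if i != j])
  return graphs.GraphsTuple(
      nodes=tf.constant(np.random.randn(n_nodes, node_dim).astype(np.float32)),
      edges=tf.constant(
          np.random.randn(len(senders), edge_dim).astype(np.float32)),
      globals=tf.zeros([1, 1], dtype=tf.float32),
      senders=tf.constant(senders, dtype=tf.int32),
      receivers=tf.constant(receivers, dtype=tf.int32),
      n_node=tf.constant([n_nodes], dtype=tf.int32),
      n_edge=tf.constant([len(senders)], dtype=tf.int32))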
Example #2
def train_model(train_file_pattern,
                test_file_pattern,
                max_files_to_load=None,
                n_epochs=1000,
                time_index=9,
                augment_data_using_rotations=True,
                learning_rate=1e-4,
                grad_clip=1.0,
                n_recurrences=7,
                mlp_sizes=(64, 64),
                mlp_kwargs=None,
                edge_threshold=2.0,
                measurement_store_interval=1000,
                checkpoint_path=None):
  """Trains GraphModel using tensorflow.

  Args:
    train_file_pattern: pattern matching the files with the training data.
    test_file_pattern: pattern matching the files with the test data.
    max_files_to_load: the maximum number of train and test files to load.
      If None, all files will be loaded.
    n_epochs: the number of passes through the training dataset (epochs).
    time_index: the time index (0-9) of the target mobilities.
    augment_data_using_rotations: whether to augment the data using random
      rotations.
    learning_rate: the learning rate used by the optimizer.
    grad_clip: all gradients are clipped to the given value.
    n_recurrences: the number of message passing steps in the graphnet.
    mlp_sizes: the number of neurons in each layer of the MLP.
    mlp_kwargs: additional keyword arguments passed to the MLP.
    edge_threshold: particles at distance less than threshold are connected by
      an edge.
    measurement_store_interval: number of steps between storing objective values
      (loss and correlation).
    checkpoint_path: path used to store the checkpoint with the highest
      correlation on the test set.

  Returns:
    The correlation on the test dataset of the best model encountered during
    training.
  """
  if mlp_kwargs is None:
    mlp_kwargs = dict(initializers=dict(w=tf.variance_scaling_initializer(1.0),
                                        b=tf.variance_scaling_initializer(0.1)))
  # Loads train and test dataset.
  dataset_kwargs = dict(
      time_index=time_index,
      max_files_to_load=max_files_to_load)
  training_data = load_data(train_file_pattern, **dataset_kwargs)
  test_data = load_data(test_file_pattern, **dataset_kwargs)

  # Defines wrapper functions that can be passed directly to
  # tf.data.Dataset.map.
  def _make_graph_from_static_structure(static_structure):
    """Converts static structure to graph, targets and types."""
    return (graph_model.make_graph_from_static_structure(
        static_structure.positions,
        static_structure.types,
        static_structure.box,
        edge_threshold),
            static_structure.targets,
            static_structure.types)

  def _apply_random_rotation(graph, targets, types):
    """Applies random rotations to the graph and forwards targets and types."""
    return graph_model.apply_random_rotation(graph), targets, types

  # Defines data-pipeline based on tf.data.Dataset following the official
  # guideline: https://www.tensorflow.org/guide/datasets#consuming_numpy_arrays.
  # We use initializable iterators to avoid embedding the training and test data
  # directly into the graph.
  # Instead, we feed the data to the iterators during their initialization,
  # before the main training loop.
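  # Note: GlassSimulationData (defined elsewhere in this module) appears to be
  # a namedtuple of per-example numpy arrays; the fields referenced here and
  # in the wrapper functions above are positions, types, box and targets.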
  placeholders = GlassSimulationData._make(
      tf.placeholder(s.dtype, (None,) + s.shape) for s in training_data[0])
  dataset = tf.data.Dataset.from_tensor_slices(placeholders)
  dataset = dataset.map(_make_graph_from_static_structure)
  dataset = dataset.cache()
  dataset = dataset.shuffle(400)
  # Augments data. This has to be done after calling dataset.cache!
  if augment_data_using_rotations:
    dataset = dataset.map(_apply_random_rotation)
  dataset = dataset.repeat()
  train_iterator = dataset.make_initializable_iterator()

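  # Builds the test pipeline. It mirrors the training pipeline above but omits
  # shuffling and rotation augmentation, so evaluation runs deterministically
  # over the test set.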
  dataset = tf.data.Dataset.from_tensor_slices(placeholders)
  dataset = dataset.map(_make_graph_from_static_structure)
  dataset = dataset.cache()
  dataset = dataset.repeat()
  test_iterator = dataset.make_initializable_iterator()

  # Creates tensorflow graph.
  # Note: We decouple the training and test datasets from the input pipeline
  # by creating a new iterator from a string-handle placeholder with the same
  # output types and shapes as the training dataset.
  dataset_handle = tf.placeholder(tf.string, shape=[])
  iterator = tf.data.Iterator.from_string_handle(
      dataset_handle, train_iterator.output_types, train_iterator.output_shapes)
  graph, targets, types = iterator.get_next()

  model = graph_model.GraphBasedModel(
      n_recurrences, mlp_sizes, mlp_kwargs)
  prediction = model(graph)

  # Defines loss and minimization operations.
  loss_ops = get_loss_ops(prediction, targets, types)
  minimize_op = get_minimize_op(loss_ops.l2_loss, learning_rate, grad_clip)

  best_so_far = -1
  train_stats = []
  test_stats = []

  saver = tf.train.Saver()

  with tf.train.SingularMonitoredSession() as session:
    # Initializes train and test iterators with the training and test datasets.
    # The obtained training and test string-handles can be passed to the
    # dataset_handle placeholder to select the dataset.
    train_handle = session.run(train_iterator.string_handle())
    test_handle = session.run(test_iterator.string_handle())
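    # The feed dicts below transpose the list of per-example records into one
    # list per field, so that each placeholder receives all examples of the
    # corresponding field stacked along the leading (batch) dimension.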
    feed_dict = {p: [x[i] for x in training_data]
                 for i, p in enumerate(placeholders)}
    session.run(train_iterator.initializer, feed_dict=feed_dict)
    feed_dict = {p: [x[i] for x in test_data]
                 for i, p in enumerate(placeholders)}
    session.run(test_iterator.initializer, feed_dict=feed_dict)

    # Trains model using stochastic gradient descent on the training dataset.
    n_training_steps = len(training_data) * n_epochs
    for i in range(n_training_steps):
      feed_dict = {dataset_handle: train_handle}
      train_loss, _ = session.run((loss_ops, minimize_op), feed_dict=feed_dict)
      train_stats.append(train_loss)

      if (i+1) % measurement_store_interval == 0:
        # Evaluates model on test dataset.
        for _ in range(len(test_data)):
          feed_dict = {dataset_handle: test_handle}
          test_stats.append(session.run(loss_ops, feed_dict=feed_dict))

        # Outputs performance statistics on training and test dataset.
        _log_stats_and_return_mean_correlation('Train', train_stats)
        correlation = _log_stats_and_return_mean_correlation('Test', test_stats)
        train_stats = []
        test_stats = []

        # Updates best model based on the observed correlation on the test
        # dataset.
        if correlation > best_so_far:
          best_so_far = correlation
          if checkpoint_path:
            saver.save(session.raw_session(), checkpoint_path)

  return best_so_far
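
As a usage illustration only (the file patterns, checkpoint path, and settings below are made-up placeholders, not values from the original project), the entry point above might be invoked roughly like this:

# Sketch only: illustrative invocation with placeholder paths and small settings.
best_test_correlation = train_model(
    train_file_pattern='/tmp/glass_data/train/*',
    test_file_pattern='/tmp/glass_data/test/*',
    max_files_to_load=10,   # keep the smoke run small
    n_epochs=10,
    checkpoint_path='/tmp/glass_model/best.ckpt')
print('Best test correlation:', best_test_correlation)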