def main(unused_argv):
    """Train a model on snapshots stored in the HDF5 file at FLAGS.input_path.

    Steps:
      1. Read the full ``'v'`` dataset and every file-level attribute
         (each converted with ``.item()``) from the input file.
      2. If ``FLAGS.checkpoint_dir`` is set, create that directory.
      3. Build hyperparameters for ``FLAGS.equation``, passing the attributes
         JSON-encoded as ``equation_kwargs``, then parse ``FLAGS.hparams``
         into them.
      4. Run ``training.training_loop`` (with ``master=FLAGS.master``).
      5. If ``FLAGS.checkpoint_dir`` is set, write the returned metrics
         DataFrame to ``metrics.csv`` in that directory, without the index.

    Args:
      unused_argv: leftover command-line arguments; ignored.
    """
    logging.info('Loading training data')
    with utils.read_h5py(FLAGS.input_path) as h5_file:
        snapshots = h5_file['v'][...]
        attrs = h5_file.attrs
        equation_kwargs = dict(
            (name, value.item()) for name, value in attrs.items())

    logging.info('Inputs have shape %r', snapshots.shape)

    checkpoint_dir = FLAGS.checkpoint_dir
    if checkpoint_dir:
        tf.gfile.MakeDirs(checkpoint_dir)

    equation = FLAGS.equation
    serialized_kwargs = json.dumps(equation_kwargs)
    hparams = training.create_hparams(equation,
                                      equation_kwargs=serialized_kwargs)
    hparams.parse(FLAGS.hparams)

    logging.info('Starting training loop')
    metrics_df = training.training_loop(
        snapshots, checkpoint_dir, hparams, master=FLAGS.master)

    if checkpoint_dir:
        logging.info('Saving CSV with metrics')
        csv_path = os.path.join(checkpoint_dir, 'metrics.csv')
        with tf.gfile.GFile(csv_path, 'w') as csv_file:
            metrics_df.to_csv(csv_file, index=False)

    logging.info('Finished')
Пример #2
0
 def test_training_loop(self, **hparam_values):
     """Run a short training loop on seeded Gaussian noise and check metrics.

     Hyperparameters are fixed at learning_rates=[1e-3], learning_stops=[20]
     and eval_interval=10. Any extra keyword arguments are forwarded to
     ``training.create_hparams``; they are presumably supplied by a test
     parameterization decorator outside this view (TODO confirm). The loop
     must return a ``pd.DataFrame`` with exactly two rows.

     Args:
       **hparam_values: additional hyperparameter overrides.
     """
     with tf.Graph().as_default():
         rng = np.random.RandomState(0)
         noise = rng.randn(100, NUM_X_POINTS)
         equation_kwargs = json.dumps({'num_points': NUM_X_POINTS})
         hparams = training.create_hparams(
             learning_rates=[1e-3],
             learning_stops=[20],
             eval_interval=10,
             equation_kwargs=equation_kwargs,
             **hparam_values)
         metrics = training.training_loop(noise, self.tmpdir, hparams)
         self.assertIsInstance(metrics, pd.DataFrame)
         self.assertEqual(len(metrics), 2)
Пример #3
0
 def train(self, hparams):
     """Train a model on small-amplitude random noise.

     Inside a fresh TF graph, this draws a fixed-seed (0) array of standard
     normal samples of shape (100, NUM_X_POINTS), scales it by 0.01, and runs
     ``training.training_loop`` on it with ``self.checkpoint_dir``. The loop's
     return value is discarded.

     Args:
       hparams: hyperparameters passed straight to ``training.training_loop``.
     """
     with tf.Graph().as_default():
         rng = np.random.RandomState(0)
         noise = rng.randn(100, NUM_X_POINTS) * 0.01
         training.training_loop(noise, self.checkpoint_dir, hparams)