예제 #1
0
    def testNeuTraExperiment(self):
        """Smoke-test the NeuTra experiment lifecycle on a tiny 2-D target.

        Clears any previous gin configuration, binds the target to a
        2-dimensional ``easy_gaussian``, and builds an experiment with an
        affine bijector, batch sizes of 2, and ``self.temp_dir`` as log dir.
        Inside a single session it then calls Initialize, TrainBijector
        (argument ``1`` -- presumably a step count, TODO confirm), Eval,
        Benchmark, and a random-method Tune capped at one trial. The test has
        no explicit assertions; it fails only if one of these calls raises.
        """
        gin.clear_config()
        target_bindings = (
            ("target_spec.name", "easy_gaussian"),
            ("target_spec.num_dims", 2),
        )
        for binding_name, binding_value in target_bindings:
            gin.bind_parameter(binding_name, binding_value)

        experiment_kwargs = dict(
            train_batch_size=2,
            test_chain_batch_size=2,
            bijector="affine",
            log_dir=self.temp_dir,
        )
        exp = neutra.NeuTraExperiment(**experiment_kwargs)

        with tf.Session() as session:
            exp.Initialize(session)
            exp.TrainBijector(session, 1)
            exp.Eval(session)
            exp.Benchmark(session)
            exp.Tune(session, method="random", max_num_trials=1)
def main(argv):
    """Run a NeuTra experiment configured by command-line flags.

    Steps, in order:
      1. Bind hyperparameters from ``FLAGS.hparams``, then, if
         ``FLAGS.restore_from_config`` is set, parse the gin config stored at
         ``<FLAGS.neutra_log_dir>/config``.
      2. Create the log directory and install a default summary file writer
         (flushing every 10 s) with summaries always recorded.
      3. Build a ``neutra.NeuTraExperiment`` and write the operative gin
         config to ``<log_dir>/config`` (also logged).
      4. In a session: initialize the experiment and summaries, restore the
         latest checkpoint in ``log_dir`` if there is one, then run the stages
         selected by ``FLAGS.mode``:

           * ``"standard"``:  Train, Benchmark
           * ``"benchmark"``: Benchmark
           * ``"eval"``:      Benchmark, Eval
           * ``"all"``:       Train, Benchmark, Eval

    Args:
      argv: Positional command-line arguments; unused.

    Raises:
      ValueError: If ``FLAGS.mode`` is not one of the modes listed above. This
        is checked before any files are created or written.
    """
    del argv  # Unused.

    # Stages for each supported mode, in execution order. Validate up front so
    # an unknown mode fails fast instead of silently writing a config,
    # restoring a checkpoint and then exiting without doing any work.
    stages_by_mode = {
        "standard": (Train, Benchmark),
        "benchmark": (Benchmark,),
        "eval": (Benchmark, Eval),
        "all": (Train, Benchmark, Eval),
    }
    mode = FLAGS.mode
    if mode not in stages_by_mode:
        raise ValueError("Unknown mode %r; expected one of: %s." %
                         (mode, ", ".join(sorted(stages_by_mode))))
    stages = stages_by_mode[mode]

    log_dir = FLAGS.neutra_log_dir
    config_path = os.path.join(log_dir, "config")

    utils.BindHParams(FLAGS.hparams)
    if FLAGS.restore_from_config:
        # Parsed after the hparams above, so settings in the saved config are
        # applied on top of them.
        with tf.gfile.Open(config_path) as f:
            gin.parse_config(f.read())

    tf.gfile.MakeDirs(log_dir)
    summary_writer = tf.contrib.summary.create_file_writer(log_dir,
                                                           flush_millis=10000)
    summary_writer.set_as_default()
    with tf.contrib.summary.always_record_summaries():
        exp = neutra.NeuTraExperiment(log_dir=log_dir)

        # Persist the resolved gin config at the same path that
        # --restore_from_config reads from.
        operative_config = gin.operative_config_str()
        with tf.gfile.Open(config_path, "w") as f:
            f.write(operative_config)
        tf.logging.info("Config:\n%s", operative_config)

        with tf.Session() as sess:
            exp.Initialize(sess)
            tf.contrib.summary.initialize(graph=tf.get_default_graph())

            checkpoint = tf.train.latest_checkpoint(log_dir)
            if checkpoint:
                tf.logging.info("Restoring from %s", checkpoint)
                exp.saver.restore(sess, checkpoint)

            for stage in stages:
                stage(exp, sess)