def testNeuTraExperiment(self):
  """Smoke-tests the full NeuTra pipeline end-to-end on a tiny 2-D Gaussian.

  Exercises init -> bijector training (1 step) -> eval -> benchmark -> tune
  with minimal batch sizes, so it checks wiring rather than model quality.
  """
  # Reset gin so bindings from other tests cannot leak into this one.
  gin.clear_config()
  gin.bind_parameter("target_spec.name", "easy_gaussian")
  gin.bind_parameter("target_spec.num_dims", 2)
  # Tiny batches and the simplest ("affine") bijector keep the test fast.
  exp = neutra.NeuTraExperiment(
      train_batch_size=2,
      test_chain_batch_size=2,
      bijector="affine",
      log_dir=self.temp_dir)
  with tf.Session() as sess:
    exp.Initialize(sess)
    # A single training step: we only verify the ops run, not convergence.
    exp.TrainBijector(sess, 1)
    exp.Eval(sess)
    exp.Benchmark(sess)
    exp.Tune(sess, method="random", max_num_trials=1)
def main(argv):
  """Runs a NeuTra experiment according to command-line flags.

  Restores gin config and the latest checkpoint from FLAGS.neutra_log_dir
  (when present), writes the resolved config back for reproducibility, then
  dispatches on FLAGS.mode: "standard" (train + benchmark), "benchmark",
  "eval" (benchmark + eval), or "all".

  Args:
    argv: Unused positional command-line arguments.

  Raises:
    ValueError: If FLAGS.mode is not one of the recognized modes.
  """
  del argv  # Unused.
  log_dir = FLAGS.neutra_log_dir
  utils.BindHParams(FLAGS.hparams)
  if FLAGS.restore_from_config:
    # Re-parse the gin config saved by a previous run of this experiment.
    with tf.gfile.Open(os.path.join(log_dir, "config")) as f:
      gin.parse_config(f.read())
  tf.gfile.MakeDirs(log_dir)

  summary_writer = tf.contrib.summary.create_file_writer(
      log_dir, flush_millis=10000)
  summary_writer.set_as_default()
  with tf.contrib.summary.always_record_summaries():
    exp = neutra.NeuTraExperiment(log_dir=log_dir)
    # Persist the operative config next to the checkpoints so the run can be
    # reproduced (and restored via --restore_from_config). Computed once and
    # reused for logging.
    config_str = gin.operative_config_str()
    with tf.gfile.Open(os.path.join(log_dir, "config"), "w") as f:
      f.write(config_str)
    tf.logging.info("Config:\n%s", config_str)

    with tf.Session() as sess:
      exp.Initialize(sess)
      tf.contrib.summary.initialize(graph=tf.get_default_graph())
      checkpoint = tf.train.latest_checkpoint(log_dir)
      if checkpoint:
        tf.logging.info("Restoring from %s", checkpoint)
        exp.saver.restore(sess, checkpoint)

      if FLAGS.mode == "standard":
        Train(exp, sess)
        Benchmark(exp, sess)
      elif FLAGS.mode == "benchmark":
        Benchmark(exp, sess)
      elif FLAGS.mode == "eval":
        Benchmark(exp, sess)
        Eval(exp, sess)
      elif FLAGS.mode == "all":
        Train(exp, sess)
        Benchmark(exp, sess)
        Eval(exp, sess)
      else:
        # Fail loudly on a typo'd --mode instead of silently doing nothing.
        raise ValueError("Unknown mode: %s" % FLAGS.mode)