def main(argv):
  del argv
  # FLAGS.hparams may arrive as a raw YAML string; parse it into a dict.
  if not hasattr(FLAGS.hparams, "items"):
    FLAGS.hparams = utils.YAMLDictParser().parse(FLAGS.hparams)

  log_dir = FLAGS.neutra_log_dir
  utils.BindHParams(FLAGS.hparams)
  # Optionally restore the gin config written by a previous run.
  if FLAGS.restore_from_config:
    with tf.io.gfile.GFile(os.path.join(log_dir, "config")) as f:
      gin.parse_config(f.read())

  tf.io.gfile.makedirs(log_dir)
  summary_writer = tf.summary.create_file_writer(log_dir, flush_millis=10000)
  summary_writer.set_as_default()
  tf.summary.experimental.set_step(0)

  # Retry up to 10 times if the run fails with a NaN.
  for i in range(10):
    try:
      checkpoint_log_dir = (
          FLAGS.checkpoint_log_dir
          if FLAGS.checkpoint_log_dir else FLAGS.neutra_log_dir)
      exp = neutra.NeuTraExperiment(log_dir=checkpoint_log_dir)
      # Save the operative gin config alongside the logs.
      with tf.io.gfile.GFile(os.path.join(log_dir, "config"), "w") as f:
        f.write(gin.config_str())
      logging.info("Config:\n%s", gin.config_str())

      # Restore from an existing checkpoint if one is present.
      checkpoint = checkpoint_log_dir + "/model.ckpt"
      if tf.io.gfile.exists(checkpoint + ".index"):
        logging.info("Restoring from %s", checkpoint)
        exp.checkpoint.restore(checkpoint)

      with utils.use_xla(False):
        if FLAGS.mode == "train":
          Train(exp)
        elif FLAGS.mode == "objective":
          TuneObjective(exp)
        elif FLAGS.mode == "benchmark":
          Benchmark(exp)
        elif FLAGS.mode == "eval":
          Eval(exp)
      break
    except tf.errors.InvalidArgumentError as e:
      if "NaN" in e.message:
        logging.error(e.message)
        logging.error("Got a NaN, try: %d", i)
      else:
        raise e
tf.enable_v2_behavior()

flags.DEFINE_string("neutra_log_dir", "/tmp/neutra",
                    "Output directory for experiment artifacts.")
flags.DEFINE_string("checkpoint_log_dir", None,
                    "Output directory for checkpoints, if specified.")
flags.DEFINE_enum(
    "mode", "train", ["eval", "benchmark", "train", "objective"],
    "Mode for this run. Standard trains the bijector, tunes the "
    "chain parameters and does the evals. Benchmark uses "
    "the tuned parameters and benchmarks the chain.")
flags.DEFINE_boolean(
    "restore_from_config", False,
    "Whether to restore the hyperparameters from the previous run.")
flags.DEFINE(utils.YAMLDictParser(), "hparams", "",
             "Hyperparameters to override.")
flags.DEFINE_string("tune_outputs_name", "tune_outputs",
                    "Name of the tune_outputs file.")
flags.DEFINE_string("eval_suffix", "", "Suffix for the eval outputs.")

FLAGS = flags.FLAGS


def Train(exp):
  log_dir = (
      FLAGS.checkpoint_log_dir
      if FLAGS.checkpoint_log_dir else FLAGS.neutra_log_dir)
  logging.info("Training")
  # Train the bijector and save the resulting q statistics as JSON.
  q_stats, secs_per_step = exp.Train()
  tf.io.gfile.makedirs(log_dir)
  utils.save_json(q_stats, os.path.join(log_dir, "q_stats"))
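# Example invocation (a sketch, not part of the original file: the module name
# "neutra_runner" and the "learning_rate" override are illustrative only, and
# the absl app.run(main) entry point is assumed to be defined elsewhere):
#
#   python -m neutra_runner --mode=train \
#       --neutra_log_dir=/tmp/neutra \
#       --hparams='learning_rate: 1e-3'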