Example #1
0
def main(argv):
    """Entry point: binds hyperparameters, then runs the selected mode.

    Retries the experiment up to 10 times when a run aborts with a
    ``tf.errors.InvalidArgumentError`` whose message mentions "NaN"
    (a diverged training run); any other InvalidArgumentError is
    re-raised immediately.

    Args:
        argv: Unused positional command-line arguments from absl.app.
    """
    del argv  # Unused.
    # If absl did not already parse the flag, hparams is still the raw
    # YAML string; parse it into a dict-like object once.
    if not hasattr(FLAGS.hparams, "items"):
        FLAGS.hparams = utils.YAMLDictParser().parse(FLAGS.hparams)

    log_dir = FLAGS.neutra_log_dir
    utils.BindHParams(FLAGS.hparams)
    if FLAGS.restore_from_config:
        with tf.io.gfile.GFile(os.path.join(log_dir, "config")) as f:
            gin.parse_config(f.read())

    tf.io.gfile.makedirs(log_dir)
    summary_writer = tf.summary.create_file_writer(log_dir, flush_millis=10000)
    summary_writer.set_as_default()
    tf.summary.experimental.set_step(0)

    for i in range(10):
        try:
            # Checkpoints may live in a separate directory from summaries.
            checkpoint_log_dir = (FLAGS.checkpoint_log_dir
                                  if FLAGS.checkpoint_log_dir else
                                  FLAGS.neutra_log_dir)
            exp = neutra.NeuTraExperiment(log_dir=checkpoint_log_dir)
            # Persist the resolved gin config so a later run can restore it
            # via --restore_from_config.
            with tf.io.gfile.GFile(os.path.join(log_dir, "config"), "w") as f:
                f.write(gin.config_str())
            logging.info("Config:\n%s", gin.config_str())

            # Use os.path.join (consistent with the config path above) so the
            # path stays well-formed even if checkpoint_log_dir has a
            # trailing slash.
            checkpoint = os.path.join(checkpoint_log_dir, "model.ckpt")
            if tf.io.gfile.exists(checkpoint + ".index"):
                logging.info("Restoring from %s", checkpoint)
                exp.checkpoint.restore(checkpoint)

            with utils.use_xla(False):
                if FLAGS.mode == "train":
                    Train(exp)
                elif FLAGS.mode == "objective":
                    TuneObjective(exp)
                elif FLAGS.mode == "benchmark":
                    Benchmark(exp)
                elif FLAGS.mode == "eval":
                    Eval(exp)
                break
        except tf.errors.InvalidArgumentError as e:
            if "NaN" in e.message:
                # Diverged run: log and retry with a fresh experiment.
                logging.error(e.message)
                logging.error("Got a NaN, try: %d", i)
            else:
                raise  # Bare raise preserves the original traceback.
    else:
        # Previously the function returned silently when every retry hit a
        # NaN; surface the failure instead of masking it.
        logging.error("Giving up after 10 NaN retries; no mode completed.")
Example #2
0
# Run TF in eager/v2 mode (file otherwise uses TF1-era compat entry point).
tf.enable_v2_behavior()

# Command-line flags controlling experiment I/O locations and run mode.
flags.DEFINE_string("neutra_log_dir", "/tmp/neutra",
                    "Output directory for experiment artifacts.")
flags.DEFINE_string("checkpoint_log_dir", None,
                    "Output directory for checkpoints, if specified.")
flags.DEFINE_enum(
    "mode", "train", ["eval", "benchmark", "train", "objective"],
    "Mode for this run. Standard trains bijector, tunes the "
    "chain parameters and does the evals. Benchmark uses "
    "the tuned parameters and benchmarks the chain.")
flags.DEFINE_boolean(
    "restore_from_config", False,
    "Whether to restore the hyperparameters from the "
    "previous run.")
# Custom parser flag: accepts a YAML mapping string of hparam overrides
# (parsed again in main() if absl did not run the parser).
flags.DEFINE(utils.YAMLDictParser(), "hparams", "",
             "Hyperparameters to override.")
flags.DEFINE_string("tune_outputs_name", "tune_outputs",
                    "Name of the tune_outputs file.")
flags.DEFINE_string("eval_suffix", "", "Suffix for the eval outputs.")

# Module-level handle to parsed flag values, used throughout this file.
FLAGS = flags.FLAGS


def Train(exp):
    log_dir = (FLAGS.checkpoint_log_dir
               if FLAGS.checkpoint_log_dir else FLAGS.neutra_log_dir)
    logging.info("Training")
    q_stats, secs_per_step = exp.Train()
    tf.io.gfile.makedirs(log_dir)
    utils.save_json(q_stats, os.path.join(log_dir, "q_stats"))