Example #1
def main(unused_argv):
    del unused_argv  # Unused

    tf.logging.set_verbosity(tf.logging.INFO)

    assert FLAGS.seq_len > 0
    assert FLAGS.perm_size > 0

    FLAGS.n_token = data_utils.VOCAB_SIZE
    tf.logging.info("n_token {}".format(FLAGS.n_token))

    if not tf.gfile.Exists(FLAGS.model_dir):
        tf.gfile.MakeDirs(FLAGS.model_dir)

    # Get train input function
    train_input_fn, train_record_info_dict = get_input_fn("train")

    tf.logging.info("num of batches {}".format(
        train_record_info_dict["num_batch"]))

    # Get train cache function
    train_cache_fn = get_cache_fn(FLAGS.mem_len)

    ##### Get model function
    model_fn = get_model_fn()

    ##### Create TPUEstimator
    # TPU Configuration
    run_config = model_utils.configure_tpu(FLAGS)

    # TPU Estimator
    estimator = tpu_estimator.TPUEstimator(
        model_fn=model_fn,
        train_cache_fn=train_cache_fn,
        use_tpu=FLAGS.use_tpu,
        config=run_config,
        params={"track_mean": FLAGS.track_mean},
        train_batch_size=FLAGS.train_batch_size,
        eval_on_tpu=FLAGS.use_tpu)

    hooks = None
    if FLAGS.debug:
        if FLAGS.debug_dump_dir:
            hooks = [tf_debug.DumpingDebugHook(FLAGS.debug_dump_dir)]
        else:
            hooks = [tf_debug.LocalCLIDebugHook()]
    #### Training
    estimator.train(input_fn=train_input_fn,
                    max_steps=FLAGS.train_steps,
                    hooks=hooks)
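
Every example below also calls get_cache_fn(FLAGS.mem_len) without showing it. For orientation, here is a minimal sketch of what such a cache initializer typically looks like in Transformer-XL/XLNet-style training: it returns a function that builds one zero-initialized memory tensor per layer, which the forked TPUEstimator feeds to model_fn as the initial recurrence memory. The flags FLAGS.n_layer and FLAGS.d_model are assumptions for illustration, not taken from the examples.

import tensorflow as tf


def get_cache_fn(mem_len):
    # Sketch of an assumed helper; not part of the examples on this page.
    def cache_fn(batch_size):
        mems = []
        if mem_len > 0:
            for _ in range(FLAGS.n_layer):  # FLAGS.n_layer assumed
                # One memory slab per layer: [mem_len, batch_size, d_model].
                mems.append(
                    tf.zeros([mem_len, batch_size, FLAGS.d_model],  # FLAGS.d_model assumed
                             dtype=tf.float32))
        return mems

    return cache_fn
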
Example #2
def main(unused_argv):
    del unused_argv  # Unused

    tf.logging.set_verbosity(tf.logging.INFO)

    assert FLAGS.seq_len > 0
    assert FLAGS.perm_size > 0

    FLAGS.n_token = data_utils.VOCAB_SIZE
    tf.logging.info('n_token {}'.format(FLAGS.n_token))

    if not tf.gfile.Exists(FLAGS.model_dir):
        tf.gfile.MakeDirs(FLAGS.model_dir)

    # Get train input function
    train_input_fn, train_record_info_dict = get_input_fn('train')

    tf.logging.info(
        'num of batches {}'.format(train_record_info_dict['num_batch'])
    )

    # Get train cache function
    train_cache_fn = get_cache_fn(FLAGS.mem_len)

    ##### Get model function
    model_fn = get_model_fn()

    ##### Create TPUEstimator
    # TPU Configuration
    run_config = model_utils.configure_tpu(FLAGS)

    # TPU Estimator
    estimator = tpu_estimator.TPUEstimator(
        model_fn=model_fn,
        train_cache_fn=train_cache_fn,
        use_tpu=FLAGS.use_tpu,
        config=run_config,
        params={'track_mean': FLAGS.track_mean},
        train_batch_size=FLAGS.train_batch_size,
        eval_on_tpu=FLAGS.use_tpu,
    )

    #### Training
    estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)
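
The snippets rely on a set of command-line flags (seq_len, perm_size, mem_len, use_tpu, train_batch_size, train_steps, model_dir, n_token, track_mean, ...) that they never define. A minimal sketch of the corresponding definitions with absl flags, with names taken from the examples and default values chosen purely for illustration:

from absl import flags

# Defaults below are illustrative only; the real scripts set them on the command line.
flags.DEFINE_integer("seq_len", 512, "Sequence length.")
flags.DEFINE_integer("perm_size", 256, "Span size for the permutation LM objective.")
flags.DEFINE_integer("mem_len", 384, "Number of cached tokens reused as Transformer-XL memory.")
flags.DEFINE_bool("use_tpu", False, "Whether to run on TPU.")
flags.DEFINE_integer("train_batch_size", 32, "Global training batch size.")
flags.DEFINE_integer("train_steps", 100000, "Total number of training steps.")
flags.DEFINE_string("model_dir", None, "Directory for checkpoints and summaries.")
flags.DEFINE_integer("n_token", None, "Vocabulary size; overwritten from data_utils.VOCAB_SIZE.")
flags.DEFINE_bool("track_mean", False, "Whether to track the mean loss.")

FLAGS = flags.FLAGS
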
Example #3
def main(unused_argv):
    del unused_argv  # Unused

    tf.logging.set_verbosity(tf.logging.INFO)

    # Get corpus info
    corpus_info = data_utils.get_corpus_info(FLAGS.corpus_info_path)
    n_token = corpus_info["vocab_size"]
    cutoffs = corpus_info["cutoffs"][1:-1]

    if FLAGS.save_steps == 0:
        FLAGS.save_steps = None

    if not FLAGS.do_eval_only:
        # Get train input function
        train_input_fn, train_record_info = data_utils.get_input_fn(
            record_info_dir=FLAGS.record_info_dir,
            split="train",
            per_host_bsz=FLAGS.train_batch_size // FLAGS.num_hosts,
            tgt_len=FLAGS.tgt_len,
            num_core_per_host=FLAGS.num_core_per_host,
            num_hosts=FLAGS.num_hosts,
            use_tpu=FLAGS.use_tpu)
        train_bin_sizes = train_record_info["bin_sizes"]
        num_train_batch = train_record_info["num_batch"]

        # Get train cache function
        train_cache_fn = get_cache_fn(FLAGS.mem_len)
    else:
        train_bin_sizes = []
        num_train_batch = None
        train_cache_fn = None

    if FLAGS.do_eval or FLAGS.do_eval_only:
        assert FLAGS.num_hosts == 1
        # Get eval input function
        eval_input_fn, eval_record_info = data_utils.get_input_fn(
            record_info_dir=FLAGS.record_info_dir,
            split=FLAGS.eval_split,
            per_host_bsz=FLAGS.eval_batch_size // FLAGS.num_hosts,
            tgt_len=FLAGS.tgt_len,
            num_core_per_host=FLAGS.num_core_per_host,
            num_hosts=FLAGS.num_hosts,
            use_tpu=FLAGS.use_tpu)
        eval_bin_sizes = eval_record_info["bin_sizes"]
        num_eval_batch = eval_record_info["num_batch"]

        if FLAGS.max_eval_batch > 0:
            num_eval_batch = min(FLAGS.max_eval_batch, num_eval_batch)

        # Get eval cache function
        eval_cache_fn = get_cache_fn(FLAGS.mem_len)
        model_fn = get_model_fn(n_token, cutoffs, train_bin_sizes,
                                eval_bin_sizes)
    else:
        eval_cache_fn = None
        model_fn = get_model_fn(n_token, cutoffs, train_bin_sizes, [])

    ##### Create estimator
    # TPU Configuration
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    per_host_input = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        session_config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=True),
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations,
            num_shards=FLAGS.num_core_per_host * FLAGS.num_hosts,
            per_host_input_for_training=per_host_input),
        keep_checkpoint_max=100000,  # effectively save all checkpoints
        save_checkpoints_secs=None,
        save_checkpoints_steps=FLAGS.save_steps)

    # warm start
    warm_start_from = None
    if FLAGS.warm_start_path is not None:
        warm_start_from = tf.estimator.WarmStartSettings(
            ckpt_to_initialize_from=FLAGS.warm_start_path)

    # TPU Estimator
    estimator = tpu_estimator.TPUEstimator(
        model_fn=model_fn,
        train_cache_fn=train_cache_fn,
        eval_cache_fn=eval_cache_fn,
        use_tpu=FLAGS.use_tpu,
        config=run_config,
        params={
            "data_dir": FLAGS.data_dir,
            "track_mean": FLAGS.track_mean
        },
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        warm_start_from=warm_start_from)

    if FLAGS.do_eval_only:
        if FLAGS.eval_ckpt_path is not None:
            ret = estimator.evaluate(input_fn=eval_input_fn,
                                     steps=num_eval_batch,
                                     checkpoint_path=FLAGS.eval_ckpt_path)
            tf.logging.info("=" * 200)
            log_str = "Eval results | "
            for key, val in ret.items():
                log_str += "{} {} | ".format(key, val)
            tf.logging.info(log_str)
            tf.logging.info("=" * 200)
        else:
            ckpt_state = tf.train.get_checkpoint_state(FLAGS.model_dir)
            eval_results = []
            for eval_checkpoint in ckpt_state.all_model_checkpoint_paths:
                if not exists(eval_checkpoint + ".index"):
                    continue
                global_step = int(eval_checkpoint.split("-")[-1])
                if (global_step < FLAGS.start_eval_steps
                        or global_step > FLAGS.train_steps):
                    continue
                ret = estimator.evaluate(input_fn=eval_input_fn,
                                         steps=num_eval_batch,
                                         checkpoint_path=eval_checkpoint)
                eval_results.append(ret)

            eval_results.sort(key=lambda x: x["perplexity"])

            tf.logging.info("=" * 200)
            log_str = "Best results | "
            for key, val in eval_results[0].items():
                log_str += "{} {} | ".format(key, val)
            tf.logging.info(log_str)
            tf.logging.info("=" * 200)
    else:
        if not FLAGS.do_eval:
            estimator.train(input_fn=train_input_fn, steps=FLAGS.train_steps)
        else:
            for step in range(0, FLAGS.train_steps, num_train_batch):
                train_steps = min(FLAGS.train_steps - step, num_train_batch)
                estimator.train(input_fn=train_input_fn, steps=train_steps)
                estimator.evaluate(input_fn=eval_input_fn,
                                   steps=num_eval_batch)
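
None of the examples show get_model_fn. As a rough orientation, the sketch below shows the general contract a TPUEstimator model_fn has to satisfy against the stock tf.contrib.tpu API of TF 1.x: it receives features, labels, mode and params and returns a TPUEstimatorSpec (training mode only here; evaluation would additionally supply eval_metrics). The forked tpu_estimator used in these examples also threads the train_cache_fn/eval_cache_fn memories through model_fn, which is not reproduced in this sketch; the placeholder graph stands in for the real XLNet/Transformer-XL network.

import tensorflow as tf


def get_model_fn():
    def model_fn(features, labels, mode, params):
        del features, labels, params  # unused in this placeholder graph

        # Placeholder graph: a single trainable scalar stands in for the real
        # network so that only the Estimator contract is visible.
        dummy = tf.get_variable("dummy", shape=[], dtype=tf.float32)
        loss = tf.reduce_mean(tf.square(dummy - 1.0))

        optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
        if FLAGS.use_tpu:
            # Average gradients across TPU cores before applying them.
            optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
        train_op = optimizer.minimize(
            loss, global_step=tf.train.get_global_step())

        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode, loss=loss, train_op=train_op)

    return model_fn
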
Example #4
def main(unused_argv):
    del unused_argv  # Unused

    tf.logging.set_verbosity(tf.logging.INFO)

    assert FLAGS.seq_len > 0
    assert FLAGS.perm_size > 0

    FLAGS.batch_size = FLAGS.batch_size * FLAGS.num_hosts

    FLAGS.n_token = data_utils.VOCAB_SIZE
    tf.logging.info("n_token {}".format(FLAGS.n_token))

    if FLAGS.bucket_uri is not None:
        FLAGS.model_dir = os.path.join(FLAGS.bucket_uri, FLAGS.model_dir)

    if not tf.gfile.Exists(FLAGS.model_dir):
        tf.gfile.MakeDirs(FLAGS.model_dir)

    # Get train input function
    train_input_fn, train_record_info_dict = get_input_fn("train")
    valid_input_fn, valid_record_info_dict = get_input_fn("valid")

    train_steps = train_record_info_dict["num_batch"]
    valid_steps = valid_record_info_dict["num_batch"]
    FLAGS.train_steps = train_steps
    FLAGS.save_steps = train_steps * FLAGS.epochs

    tf.logging.info("num of batches {}".format(
        train_record_info_dict["num_batch"]))

    # Get train cache function
    train_cache_fn = get_cache_fn(FLAGS.mem_len)
    eval_cache_fn = get_cache_fn(FLAGS.mem_len)

    ##### Get model function
    info_dict = {
        "id": FLAGS.run_id,
        "n_layers": FLAGS.n_layer,
        "d_model": FLAGS.d_model,
        "n_heads": FLAGS.n_head
    }
    _dir = get_logdir(os.path.join(FLAGS.bucket_uri, FLAGS.logDir), info_dict)
    model_fn = get_model_fn(_dir)

    ##### Create TPUEstimator
    # TPU Configuration
    run_config = model_utils.configure_tpu(FLAGS)

    # TPU Estimator
    estimator = tpu_estimator.TPUEstimator(
        model_fn=model_fn,
        train_cache_fn=train_cache_fn,
        eval_cache_fn=eval_cache_fn,
        use_tpu=FLAGS.use_tpu,
        config=run_config,
        params={"track_mean": FLAGS.track_mean},
        train_batch_size=FLAGS.batch_size,
        eval_batch_size=FLAGS.batch_size,
        eval_on_tpu=FLAGS.use_tpu)

    #### Training and Validation
    eval_errs = []
    xs = list(range(PATIENCE))
    train_times, eval_times = [], []
    stopped_early = False
    last_errs, slope = None, 0  # defined up front so the summary below cannot hit a NameError
    for i in range(FLAGS.epochs):

        if FLAGS.do_train:
            tf.logging.info("#### Starting training cycle")
            start = time.time()
            train_ret = estimator.train(input_fn=train_input_fn,
                                        steps=train_steps)
            end = time.time()
            train_times.append((end - start) / 60)
            tf.logging.info(
                "##################################### EPOCH {} #####################################"
                .format(i + 1))

        if FLAGS.do_eval:
            tf.logging.info("#### Starting evaluation/validation cycle")
            start = time.time()
            eval_ret = estimator.evaluate(input_fn=valid_input_fn,
                                          steps=valid_steps)
            end = time.time()
            eval_times.append((end - start) / 60)

        if FLAGS.do_early_stop:
            # Early Stopping based on gradient from last PATIENCE points
            eval_errs.append(eval_ret['avg_loss'])
            if len(eval_errs) > PATIENCE:
                last_errs = eval_errs[-PATIENCE:]
                slope = round(
                    np.polyfit(xs, last_errs, deg=1)[0], ROUNDING_PRECISION)
                if slope >= 0:
                    stopped_early = True
                    break

        if not FLAGS.do_train:
            break

    if FLAGS.do_save_results:
        best_loss = min(eval_errs)
        best_pplx = np.exp(best_loss)
        std = np.std(list(map(np.exp, eval_errs)))
        if last_errs is None:
            last_errs = []
            slope = 0
        result = {
            'loss': str(best_loss),
            'pplx': str(best_pplx),
            'std': str(std),
            'avg_train_time': str(np.mean(train_times)),
            'avg_eval_time': str(np.mean(eval_times)),
            'stopped_early': str(stopped_early),
            'last_errors': str(last_errs),
            'slope': str(slope),
            'epoch': str(i)
        }

        # Keep the full validation-loss history alongside the summary stats
        # instead of overwriting the result dict built above.
        result['loss_history'] = [str(e) for e in eval_errs]
        with tf.gfile.Open(
                os.path.join(FLAGS.bucket_uri, "results",
                             "{}.json".format(FLAGS.run_id)), "w") as fp:
            json.dump(result, fp)
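
The early-stopping criterion in the last example fits a straight line through the most recent PATIENCE validation losses and stops once the fitted slope is no longer negative. The same check in isolation (the PATIENCE and ROUNDING_PRECISION values are illustrative; the example above does not show them):

import numpy as np

PATIENCE = 5
ROUNDING_PRECISION = 4


def should_stop(eval_errs, patience=PATIENCE):
    # Fit a degree-1 polynomial to the last `patience` losses; a slope >= 0
    # means the validation loss has stopped improving.
    if len(eval_errs) <= patience:
        return False
    xs = list(range(patience))
    last_errs = eval_errs[-patience:]
    slope = round(np.polyfit(xs, last_errs, deg=1)[0], ROUNDING_PRECISION)
    return slope >= 0


print(should_stop([3.1, 2.8, 2.60, 2.60, 2.61, 2.62, 2.63]))  # True: the last five losses trend upward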