def main(unused_argv):
  del unused_argv  # Unused

  tf.logging.set_verbosity(tf.logging.INFO)

  assert FLAGS.seq_len > 0
  assert FLAGS.perm_size > 0

  FLAGS.n_token = data_utils.VOCAB_SIZE
  tf.logging.info("n_token {}".format(FLAGS.n_token))

  if not tf.gfile.Exists(FLAGS.model_dir):
    tf.gfile.MakeDirs(FLAGS.model_dir)

  # Get train input function
  train_input_fn, train_record_info_dict = get_input_fn("train")

  tf.logging.info("num of batches {}".format(
      train_record_info_dict["num_batch"]))

  # Get train cache function
  train_cache_fn = get_cache_fn(FLAGS.mem_len)

  ##### Get model function
  model_fn = get_model_fn()

  ##### Create TPUEstimator
  # TPU Configuration
  run_config = model_utils.configure_tpu(FLAGS)

  # TPU Estimator
  estimator = tpu_estimator.TPUEstimator(
      model_fn=model_fn,
      train_cache_fn=train_cache_fn,
      use_tpu=FLAGS.use_tpu,
      config=run_config,
      params={"track_mean": FLAGS.track_mean},
      train_batch_size=FLAGS.train_batch_size,
      eval_on_tpu=FLAGS.use_tpu)

  hooks = None
  if FLAGS.debug:
    if FLAGS.debug_dump_dir:
      hooks = [tf_debug.DumpingDebugHook(FLAGS.debug_dump_dir)]
    else:
      hooks = [tf_debug.LocalCLIDebugHook()]

  #### Training
  estimator.train(input_fn=train_input_fn,
                  max_steps=FLAGS.train_steps,
                  hooks=hooks)
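
# `get_cache_fn` is repo-specific and not shown in this listing; below is a
# minimal sketch of what it plausibly does, given how it is called here (an
# assumption, not the repo's exact code): it returns a closure that builds
# the initial Transformer-XL style memory, one zero tensor of shape
# [mem_len, batch_size, d_model] per layer. It relies on the module-level
# `tf` and `FLAGS` already used above.
def get_cache_fn(mem_len):
  def cache_fn(batch_size):
    mems = []
    if mem_len > 0:
      for _ in range(FLAGS.n_layer):
        mems.append(
            tf.zeros([mem_len, batch_size, FLAGS.d_model], dtype=tf.float32))
    return mems
  return cache_fn
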
def main(unused_argv):
  del unused_argv  # Unused

  tf.logging.set_verbosity(tf.logging.INFO)

  assert FLAGS.seq_len > 0
  assert FLAGS.perm_size > 0

  FLAGS.n_token = data_utils.VOCAB_SIZE
  tf.logging.info('n_token {}'.format(FLAGS.n_token))

  if not tf.gfile.Exists(FLAGS.model_dir):
    tf.gfile.MakeDirs(FLAGS.model_dir)

  # Get train input function
  train_input_fn, train_record_info_dict = get_input_fn('train')
  tf.logging.info(
      'num of batches {}'.format(train_record_info_dict['num_batch']))

  # Get train cache function
  train_cache_fn = get_cache_fn(FLAGS.mem_len)

  ##### Get model function
  model_fn = get_model_fn()

  ##### Create TPUEstimator
  # TPU Configuration
  run_config = model_utils.configure_tpu(FLAGS)

  # TPU Estimator
  estimator = tpu_estimator.TPUEstimator(
      model_fn=model_fn,
      train_cache_fn=train_cache_fn,
      use_tpu=FLAGS.use_tpu,
      config=run_config,
      params={'track_mean': FLAGS.track_mean},
      train_batch_size=FLAGS.train_batch_size,
      eval_on_tpu=FLAGS.use_tpu)

  #### Training
  estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)
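
# Both drivers above assume the standard TF 1.x entry point; a minimal
# sketch (the flag definitions themselves are assumed to live elsewhere in
# the module). `tf.app.run()` parses the flags and calls `main(argv)`,
# which matches the `unused_argv` signature used throughout this listing.
if __name__ == "__main__":
  tf.app.run()
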
def main(unused_argv):
  del unused_argv  # Unused

  tf.logging.set_verbosity(tf.logging.INFO)

  # Get corpus info
  corpus_info = data_utils.get_corpus_info(FLAGS.corpus_info_path)
  n_token = corpus_info["vocab_size"]
  cutoffs = corpus_info["cutoffs"][1:-1]

  if FLAGS.save_steps == 0:
    FLAGS.save_steps = None

  if not FLAGS.do_eval_only:
    # Get train input function
    train_input_fn, train_record_info = data_utils.get_input_fn(
        record_info_dir=FLAGS.record_info_dir,
        split="train",
        per_host_bsz=FLAGS.train_batch_size // FLAGS.num_hosts,
        tgt_len=FLAGS.tgt_len,
        num_core_per_host=FLAGS.num_core_per_host,
        num_hosts=FLAGS.num_hosts,
        use_tpu=FLAGS.use_tpu)
    train_bin_sizes = train_record_info["bin_sizes"]
    num_train_batch = train_record_info["num_batch"]

    # Get train cache function
    train_cache_fn = get_cache_fn(FLAGS.mem_len)
  else:
    train_bin_sizes = []
    num_train_batch = None
    train_cache_fn = None

  if FLAGS.do_eval or FLAGS.do_eval_only:
    assert FLAGS.num_hosts == 1

    # Get eval input function
    eval_input_fn, eval_record_info = data_utils.get_input_fn(
        record_info_dir=FLAGS.record_info_dir,
        split=FLAGS.eval_split,
        per_host_bsz=FLAGS.eval_batch_size // FLAGS.num_hosts,
        tgt_len=FLAGS.tgt_len,
        num_core_per_host=FLAGS.num_core_per_host,
        num_hosts=FLAGS.num_hosts,
        use_tpu=FLAGS.use_tpu)
    eval_bin_sizes = eval_record_info["bin_sizes"]
    num_eval_batch = eval_record_info["num_batch"]

    if FLAGS.max_eval_batch > 0:
      num_eval_batch = min(FLAGS.max_eval_batch, num_eval_batch)

    # Get eval cache function
    eval_cache_fn = get_cache_fn(FLAGS.mem_len)
    model_fn = get_model_fn(n_token, cutoffs, train_bin_sizes, eval_bin_sizes)
  else:
    eval_cache_fn = None
    model_fn = get_model_fn(n_token, cutoffs, train_bin_sizes, [])

  ##### Create estimator
  # TPU Configuration
  tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
      FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  per_host_input = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      session_config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=True),
      tpu_config=tf.contrib.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations,
          num_shards=FLAGS.num_core_per_host * FLAGS.num_hosts,
          per_host_input_for_training=per_host_input),
      keep_checkpoint_max=100000,  # effectively save all checkpoints
      save_checkpoints_secs=None,
      save_checkpoints_steps=FLAGS.save_steps)

  # warm start
  warm_start_from = None
  if FLAGS.warm_start_path is not None:
    warm_start_from = tf.estimator.WarmStartSettings(
        ckpt_to_initialize_from=FLAGS.warm_start_path)

  # TPU Estimator
  estimator = tpu_estimator.TPUEstimator(
      model_fn=model_fn,
      train_cache_fn=train_cache_fn,
      eval_cache_fn=eval_cache_fn,
      use_tpu=FLAGS.use_tpu,
      config=run_config,
      params={"data_dir": FLAGS.data_dir,
              "track_mean": FLAGS.track_mean},
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      warm_start_from=warm_start_from)

  if FLAGS.do_eval_only:
    if FLAGS.eval_ckpt_path is not None:
      ret = estimator.evaluate(input_fn=eval_input_fn,
                               steps=num_eval_batch,
                               checkpoint_path=FLAGS.eval_ckpt_path)
      tf.logging.info("=" * 200)
      log_str = "Eval results | "
      for key, val in ret.items():
        log_str += "{} {} | ".format(key, val)
      tf.logging.info(log_str)
      tf.logging.info("=" * 200)
    else:
      ckpt_state = tf.train.get_checkpoint_state(FLAGS.model_dir)
      eval_results = []
      for eval_checkpoint in ckpt_state.all_model_checkpoint_paths:
        if not exists(eval_checkpoint + ".index"):
          continue
        global_step = int(eval_checkpoint.split("-")[-1])
        if (global_step < FLAGS.start_eval_steps or
            global_step > FLAGS.train_steps):
          continue
        ret = estimator.evaluate(input_fn=eval_input_fn,
                                 steps=num_eval_batch,
                                 checkpoint_path=eval_checkpoint)
        eval_results.append(ret)

      eval_results.sort(key=lambda x: x["perplexity"])

      tf.logging.info("=" * 200)
      log_str = "Best results | "
      for key, val in eval_results[0].items():
        log_str += "{} {} | ".format(key, val)
      tf.logging.info(log_str)
      tf.logging.info("=" * 200)
  else:
    if not FLAGS.do_eval:
      estimator.train(input_fn=train_input_fn, steps=FLAGS.train_steps)
    else:
      for step in range(0, FLAGS.train_steps, num_train_batch):
        train_steps = min(FLAGS.train_steps - step, num_train_batch)
        estimator.train(input_fn=train_input_fn, steps=train_steps)
        estimator.evaluate(input_fn=eval_input_fn, steps=num_eval_batch)
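
# `data_utils.get_corpus_info` is not shown in this listing; a plausible
# sketch, assuming the corpus info is a small JSON file written during
# preprocessing (uses the module-level `json` and `tf` imports seen
# elsewhere in these drivers):
def get_corpus_info(corpus_info_path):
  with tf.gfile.Open(corpus_info_path, "r") as fp:
    # e.g. {"vocab_size": ..., "cutoffs": [...], ...}
    return json.load(fp)
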
def main(unused_argv):
  del unused_argv  # Unused

  tf.logging.set_verbosity(tf.logging.INFO)

  assert FLAGS.seq_len > 0
  assert FLAGS.perm_size > 0

  FLAGS.batch_size = FLAGS.batch_size * FLAGS.num_hosts
  FLAGS.n_token = data_utils.VOCAB_SIZE
  tf.logging.info("n_token {}".format(FLAGS.n_token))

  if FLAGS.bucket_uri is not None:
    FLAGS.model_dir = os.path.join(FLAGS.bucket_uri, FLAGS.model_dir)
  if not tf.gfile.Exists(FLAGS.model_dir):
    tf.gfile.MakeDirs(FLAGS.model_dir)

  # Get train and validation input functions
  train_input_fn, train_record_info_dict = get_input_fn("train")
  valid_input_fn, valid_record_info_dict = get_input_fn("valid")
  train_steps = train_record_info_dict["num_batch"]
  valid_steps = valid_record_info_dict["num_batch"]
  FLAGS.train_steps = train_steps
  FLAGS.save_steps = train_steps * FLAGS.epochs
  tf.logging.info("num of batches {}".format(
      train_record_info_dict["num_batch"]))

  # Get train/eval cache functions
  train_cache_fn = get_cache_fn(FLAGS.mem_len)
  eval_cache_fn = get_cache_fn(FLAGS.mem_len)

  ##### Get model function
  info_dict = {
      "id": FLAGS.run_id,
      "n_layers": FLAGS.n_layer,
      "d_model": FLAGS.d_model,
      "n_heads": FLAGS.n_head
  }
  _dir = get_logdir(os.path.join(FLAGS.bucket_uri, FLAGS.logDir), info_dict)
  model_fn = get_model_fn(_dir)

  ##### Create TPUEstimator
  # TPU Configuration
  run_config = model_utils.configure_tpu(FLAGS)

  # TPU Estimator
  estimator = tpu_estimator.TPUEstimator(
      model_fn=model_fn,
      train_cache_fn=train_cache_fn,
      eval_cache_fn=eval_cache_fn,
      use_tpu=FLAGS.use_tpu,
      config=run_config,
      params={"track_mean": FLAGS.track_mean},
      train_batch_size=FLAGS.batch_size,
      eval_batch_size=FLAGS.batch_size,
      eval_on_tpu=FLAGS.use_tpu)

  #### Training and Validation
  eval_errs = []
  xs = list(range(PATIENCE))
  train_times, eval_times = [], []
  stopped_early = False
  # Defined up front so the results summary below is well-defined even when
  # early stopping never triggers.
  last_errs = None
  slope = 0
  for i in range(FLAGS.epochs):
    if FLAGS.do_train:
      tf.logging.info("#### Starting training cycle")
      start = time.time()
      estimator.train(input_fn=train_input_fn, steps=train_steps)
      end = time.time()
      train_times.append((end - start) / 60)
      tf.logging.info(
          "##################################### EPOCH {} "
          "#####################################".format(i + 1))

    if FLAGS.do_eval:
      tf.logging.info("#### Starting evaluation/validation cycle")
      start = time.time()
      eval_ret = estimator.evaluate(input_fn=valid_input_fn,
                                    steps=valid_steps)
      end = time.time()
      eval_times.append((end - start) / 60)

      if FLAGS.do_early_stop:
        # Early stopping based on the slope of a line fitted to the last
        # PATIENCE validation losses: stop once the loss is no longer falling.
        eval_errs.append(eval_ret["avg_loss"])
        if len(eval_errs) > PATIENCE:
          last_errs = eval_errs[-PATIENCE:]
          slope = round(np.polyfit(xs, last_errs, deg=1)[0],
                        ROUNDING_PRECISION)
          if slope >= 0:
            stopped_early = True
            break

    if not FLAGS.do_train:
      break

  if FLAGS.do_save_results:
    best_loss = min(eval_errs)
    best_pplx = np.exp(best_loss)
    std = np.std(list(map(np.exp, eval_errs)))
    if last_errs is None:
      last_errs = []
      slope = 0
    result = {
        "loss": str(best_loss),
        "pplx": str(best_pplx),
        "std": str(std),
        "avg_train_time": str(np.mean(train_times)),
        "avg_eval_time": str(np.mean(eval_times)),
        "stopped_early": str(stopped_early),
        "last_errors": str(last_errs),
        "slope": str(slope),
        "epoch": str(i)
    }
    with tf.gfile.Open(
        os.path.join(FLAGS.bucket_uri, "results",
                     "{}.json".format(FLAGS.run_id)), "w") as fp:
      json.dump(result, fp)
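
# `get_logdir` is likewise repo-specific; a hypothetical sketch of the
# directory naming scheme implied by `info_dict` above (the key ordering
# and the separator are assumptions made purely for illustration):
def get_logdir(base_dir, info_dict):
  name = "_".join("{}-{}".format(k, v) for k, v in sorted(info_dict.items()))
  return os.path.join(base_dir, name)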