# Horovod variant of the training entry point; the tf.distribute
# (DistributionStrategy) path is kept commented out below.
def train_fn(hparams):
  """Train function."""
  model_fn = make_model_fn(hparams)
  input_fn = make_input_fn(hparams, tf.contrib.learn.ModeKeys.TRAIN)

  log_step_count_steps = hparams.log_step_count_steps
  save_checkpoints_steps = hparams.save_checkpoints_steps
  # if hparams.use_dist_strategy:
  #   distribution_strategy = get_distribution_strategy(hparams.num_gpus)
  #   config = tf.estimator.RunConfig(
  #       train_distribute=distribution_strategy,
  #       log_step_count_steps=log_step_count_steps,
  #       keep_checkpoint_max=None,
  #       save_checkpoints_steps=save_checkpoints_steps)
  # else:
  sess_config = tf.ConfigProto(allow_soft_placement=True)
  # Pin this Horovod worker to its local GPU.
  sess_config.gpu_options.visible_device_list = str(hvd.local_rank())
  if hparams.use_autojit_xla:
    # Enable auto-clustering so eligible ops are JIT-compiled with XLA.
    sess_config.graph_options.optimizer_options.global_jit_level = (
        tf.OptimizerOptions.ON_1)
  if not hparams.use_pintohost_optimizer:
    sess_config.graph_options.rewrite_options.pin_to_host_optimization = (
        rewriter_config_pb2.RewriterConfig.OFF)
  config = tf.estimator.RunConfig(
      log_step_count_steps=log_step_count_steps,
      session_config=sess_config,
      keep_checkpoint_max=None,
      save_checkpoints_steps=save_checkpoints_steps)

  misc_utils.print_out("sess master is %s" % config.master)
  estimator = tf.estimator.Estimator(
      model_fn=model_fn, model_dir=hparams.output_dir, config=config)

  benchmark_hook = BenchmarkHook(hparams.batch_size, hparams.warmup_steps + 5)
  # Broadcast initial variables from rank 0 so all workers start from the
  # same weights.
  train_hooks = [hvd.BroadcastGlobalVariablesHook(0), benchmark_hook]
  if hparams.profile:
    train_hooks.append(
        tf.train.ProfilerHook(
            output_dir=hparams.output_dir,
            save_steps=hparams.profile_save_steps,
            show_dataflow=True,
            show_memory=True))

  max_steps = hparams.debug_num_train_steps
  # Estimator.train() returns the estimator itself, so the global step can be
  # read back directly after training finishes.
  global_step = estimator.train(
      input_fn=input_fn,
      max_steps=max_steps,
      hooks=train_hooks,
  ).get_variable_value(tf.GraphKeys.GLOBAL_STEP)

  sentences_sec, latencies = benchmark_hook.get_average_speed_and_latencies()
  return global_step, sentences_sec, latencies
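# NOTE: the Horovod variant above assumes one training process per GPU and
# that Horovod was initialized before the Estimator is built. A minimal,
# hypothetical driver sketch (`run_horovod_training` is illustrative and not
# part of this module):
#
#   import horovod.tensorflow as hvd
#
#   def run_horovod_training(hparams):
#     hvd.init()  # launched as e.g.: horovodrun -np 8 python train.py ...
#     global_step, sentences_sec, latencies = train_fn(hparams)
#     if hvd.rank() == 0:  # only the chief reports benchmark numbers
#       misc_utils.print_out("sentences/sec: %.1f" % sentences_sec)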
# Non-Horovod variant of the training entry point; multi-GPU scaling is done
# via tf.distribute (DistributionStrategy) instead of Horovod.
def train_fn(hparams):
  """Train function."""
  model_fn = make_model_fn(hparams)
  input_fn = make_input_fn(hparams, tf.contrib.learn.ModeKeys.TRAIN)

  log_step_count_steps = hparams.log_step_count_steps
  save_checkpoints_steps = hparams.save_checkpoints_steps

  if hparams.use_dist_strategy:
    distribution_strategy = get_distribution_strategy(hparams.num_gpus)
    config = tf.estimator.RunConfig(
        train_distribute=distribution_strategy,
        log_step_count_steps=log_step_count_steps,
        keep_checkpoint_max=None,
        save_checkpoints_steps=save_checkpoints_steps)
  else:
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    if hparams.use_autojit_xla:
      sess_config.graph_options.optimizer_options.global_jit_level = (
          tf.OptimizerOptions.ON_1)
    if not hparams.use_pintohost_optimizer:
      sess_config.graph_options.rewrite_options.pin_to_host_optimization = (
          rewriter_config_pb2.RewriterConfig.OFF)
    config = tf.estimator.RunConfig(
        log_step_count_steps=log_step_count_steps,
        session_config=sess_config,
        keep_checkpoint_max=None,
        save_checkpoints_steps=save_checkpoints_steps)

  misc_utils.print_out("sess master is %s" % config.master)
  estimator = tf.estimator.Estimator(
      model_fn=model_fn, model_dir=hparams.output_dir, config=config)

  benchmark_hook = BenchmarkHook(hparams.batch_size, hparams.warmup_steps + 5)
  train_hooks = [benchmark_hook]
  if hparams.profile:
    train_hooks.append(
        tf.train.ProfilerHook(
            output_dir=hparams.output_dir,
            save_steps=hparams.profile_save_steps,
            show_dataflow=True,
            show_memory=True))

  max_steps = hparams.debug_num_train_steps
  estimator.train(
      input_fn=input_fn,
      max_steps=max_steps,
      hooks=train_hooks,
  )

  return benchmark_hook.get_average_speed_and_latencies()
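# For reference, a sketch of the hyperparameters the train_fn variants above
# read directly (values below are illustrative, not the project's defaults;
# make_model_fn/make_input_fn consume additional fields not listed here):
#
#   hparams = tf.contrib.training.HParams(
#       output_dir="/tmp/gnmt_workdir",  # hypothetical path
#       batch_size=128,
#       warmup_steps=10,
#       log_step_count_steps=10,
#       save_checkpoints_steps=1000,
#       use_dist_strategy=False,
#       num_gpus=1,
#       use_autojit_xla=False,
#       use_pintohost_optimizer=True,
#       profile=False,
#       profile_save_steps=100,
#       debug_num_train_steps=None)
#   sentences_sec, latencies = train_fn(hparams)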
def get_metrics(hparams, model_fn, ckpt=None, only_translate=False):
  """Run inference and compute metrics."""
  pred_estimator = tf.estimator.Estimator(
      model_fn=model_fn, model_dir=hparams.output_dir)

  benchmark_hook = BenchmarkHook(hparams.infer_batch_size)
  predictions = pred_estimator.predict(
      make_input_fn(hparams, tf.contrib.learn.ModeKeys.INFER),
      checkpoint_path=ckpt,
      hooks=[benchmark_hook])
  translations = []
  output_tokens = []
  beam_id = 0
  for prediction in predictions:
    # Get the top translation: predict() yields beam_width entries per batch,
    # so keep only the first (top-scoring) beam and skip the rest.
    if beam_id == 0:
      for sent_id in range(hparams.infer_batch_size):
        if sent_id >= prediction["predictions"].shape[0]:
          break
        trans, output_length = nmt_utils.get_translation(
            prediction["predictions"],
            sent_id=sent_id,
            tgt_eos=hparams.eos,
            subword_option=hparams.subword_option)
        translations.append(trans)
        output_tokens.append(output_length)
    beam_id += 1
    if beam_id == hparams.beam_width:
      beam_id = 0

  if only_translate:
    trans_file = hparams.translate_file + '.trans.tok'
  else:
    trans_file = os.path.join(
        hparams.output_dir,
        "newstest2014_out_{}.tok.de".format(
            pred_estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP)))
  trans_dir = os.path.dirname(trans_file)
  if not tf.gfile.Exists(trans_dir):
    tf.gfile.MakeDirs(trans_dir)
  tf.logging.info("Writing to file %s" % trans_file)
  with codecs.getwriter("utf-8")(
      tf.gfile.GFile(trans_file, mode="wb")) as trans_f:
    trans_f.write("")  # Write empty string to ensure file is created.
    for translation in translations:
      trans_f.write((translation + b"\n").decode("utf-8"))

  if only_translate:
    return None, benchmark_hook.get_average_speed_and_latencies(), sum(
        output_tokens)

  # Evaluation
  output_dir = os.path.join(pred_estimator.model_dir, "eval")
  tf.gfile.MakeDirs(output_dir)
  summary_writer = tf.summary.FileWriter(output_dir)

  ref_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)
  # Hardcoded.
  metric = "bleu"
  score = get_sacrebleu(trans_file, hparams.detokenizer_file)
  misc_utils.print_out("bleu is %.5f" % score)
  with tf.Graph().as_default():
    summaries = []
    summaries.append(tf.Summary.Value(tag=metric, simple_value=score))
    tf_summary = tf.Summary(value=list(summaries))
    summary_writer.add_summary(
        tf_summary,
        pred_estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP))
  summary_writer.close()
  return score, benchmark_hook.get_average_speed_and_latencies(), sum(
      output_tokens)
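# BenchmarkHook is defined elsewhere in this project. As a rough guide to its
# contract, here is a minimal tf.train.SessionRunHook sketch with the same
# get_average_speed_and_latencies() interface (illustrative only; the real
# hook may differ):
#
#   import time
#
#   class SimpleBenchmarkHook(tf.train.SessionRunHook):
#     """Times each session.run and reports sentences/sec and latencies."""
#
#     def __init__(self, batch_size, warmup_steps=0):
#       self.batch_size = batch_size
#       self.warmup_steps = warmup_steps
#       self.latencies = []
#
#     def before_run(self, run_context):
#       self._step_start = time.time()
#
#     def after_run(self, run_context, run_values):
#       self.latencies.append(time.time() - self._step_start)
#
#     def get_average_speed_and_latencies(self):
#       timed = self.latencies[self.warmup_steps:] or self.latencies
#       avg_latency = sum(timed) / len(timed)
#       return self.batch_size / avg_latency, timed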