def single_worker_inference(infer_model, ckpt, inference_input_file,
                            inference_context_file, inference_output_file,
                            hparams):
  """Inference with a single worker."""
  output_infer = inference_output_file

  # Read data
  infer_data = load_data(inference_input_file, hparams)
  if inference_context_file is not None:
    infer_context = load_data(inference_context_file, hparams)
  else:
    infer_context = None

  infer_feed_dict = {
      infer_model.src_placeholder: infer_data,
      infer_model.batch_size_placeholder: hparams.infer_batch_size
  }
  if infer_context is not None:
    infer_feed_dict[infer_model.ctx_placeholder] = infer_context

  with tf.Session(
      graph=infer_model.graph, config=utils.get_config_proto()) as sess:
    loaded_infer_model = model_helper.load_model(infer_model.model, ckpt, sess,
                                                 "infer")
    sess.run(infer_model.iterator.initializer, feed_dict=infer_feed_dict)
    # Decode
    utils.print_out("# Start decoding")
    if hparams.inference_indices:
      _decode_inference_indices(
          loaded_infer_model,
          sess,
          output_infer=output_infer,
          output_infer_summary_prefix=output_infer,
          inference_indices=hparams.inference_indices,
          tgt_eos=hparams.eos,
          subword_option=hparams.subword_option)
    else:
      nmt_utils.decode_and_evaluate(
          "infer",
          loaded_infer_model,
          sess,
          output_infer,
          ref_file=None,
          metrics=hparams.metrics,
          subword_option=hparams.subword_option,
          beam_width=hparams.beam_width,
          tgt_eos=hparams.eos,
          hparams=hparams,
          num_translations_per_input=hparams.num_translations_per_input)
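
# Hypothetical usage sketch for single_worker_inference (illustrative only):
# the checkpoint and file paths below are assumptions, and `model_creator` /
# `hparams` are presumed to have been built as elsewhere in this module.
#
#   infer_model = model_helper.create_infer_model(model_creator, hparams)
#   single_worker_inference(
#       infer_model,
#       ckpt="/tmp/model/translate.ckpt-1000",
#       inference_input_file="/tmp/data/infer.src",
#       inference_context_file=None,  # no side context
#       inference_output_file="/tmp/data/infer.out",
#       hparams=hparams)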
def __init__(self, hparams_path, source_prefix, out_dir,
             environment_server_address):
  """Constructor for the reformulator.

  Args:
    hparams_path: Path to json hparams file.
    source_prefix: A prefix that is added to every question before
      translation; should be used for adding tags like <en> <2en>. Can be
      empty or None, in which case the prefix is ignored.
    out_dir: Directory where the model output will be written.
    environment_server_address: Address of the environment server.

  Raises:
    ValueError: if model architecture is not known.
  """
  self.hparams = load_hparams(hparams_path, out_dir)
  assert self.hparams.num_buckets == 1, "No bucketing when in server mode."
  assert not self.hparams.server_mode, ("server_mode set to True but not "
                                        "running as server.")
  self.hparams.environment_server = environment_server_address

  if self.hparams.subword_option == "spm":
    self.sentpiece = sentencepiece_processor.SentencePieceProcessor()
    self.sentpiece.Load(self.hparams.subword_model.encode("utf-8"))

  self.source_prefix = source_prefix

  # Create the model
  if not self.hparams.attention:
    model_creator = nmt_model.Model
  elif self.hparams.attention_architecture == "standard":
    model_creator = attention_model.AttentionModel
  elif self.hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
    model_creator = gnmt_model.GNMTModel
  else:
    raise ValueError("Unknown model architecture")

  self.trie = trie_decoder_utils.DecoderTrie(
      vocab_path=self.hparams.tgt_vocab_file,
      eos_token=self.hparams.eos,
      subword_option=self.hparams.subword_option,
      subword_model=self.hparams.get("subword_model"),
      optimize_ngrams_len=self.hparams.optimize_ngrams_len)
  if self.hparams.trie_path is not None and tf.gfile.Exists(
      self.hparams.trie_path):
    self.trie.populate_from_text_file(self.hparams.trie_path)

  combined_graph = tf.Graph()
  self.train_model = model_helper.create_train_model(
      model_creator,
      self.hparams,
      graph=combined_graph,
      trie=self.trie,
      use_placeholders=True)

  # Create different inference models for beam search, sampling and greedy
  # decoding.
  default_infer_mode = self.hparams.infer_mode
  default_beam_width = self.hparams.beam_width
  self.infer_models = {}

  self.hparams.use_rl = False
  self.hparams.infer_mode = "greedy"
  self.hparams.beam_width = 0
  self.infer_models[reformulator_pb2.ReformulatorRequest
                    .GREEDY] = model_helper.create_infer_model(
                        model_creator, self.hparams, graph=combined_graph)

  self.hparams.infer_mode = "sample"
  self.hparams.beam_width = 0
  self.infer_models[reformulator_pb2.ReformulatorRequest
                    .SAMPLING] = model_helper.create_infer_model(
                        model_creator, self.hparams, graph=combined_graph)

  self.hparams.infer_mode = "beam_search"
  self.hparams.beam_width = max(1, default_beam_width)
  self.infer_models[reformulator_pb2.ReformulatorRequest
                    .BEAM_SEARCH] = model_helper.create_infer_model(
                        model_creator, self.hparams, graph=combined_graph)

  self.hparams.infer_mode = "trie_greedy"
  self.hparams.beam_width = 0
  self.infer_models[reformulator_pb2.ReformulatorRequest
                    .TRIE_GREEDY] = model_helper.create_infer_model(
                        model_creator,
                        self.hparams,
                        graph=combined_graph,
                        trie=self.trie)

  self.hparams.infer_mode = "trie_sample"
  self.hparams.beam_width = 0
  self.infer_models[reformulator_pb2.ReformulatorRequest
                    .TRIE_SAMPLE] = model_helper.create_infer_model(
                        model_creator,
                        self.hparams,
                        graph=combined_graph,
                        trie=self.trie)

  self.hparams.infer_mode = "trie_beam_search"
  self.hparams.beam_width = max(1, default_beam_width)
  self.infer_models[reformulator_pb2.ReformulatorRequest
                    .TRIE_BEAM_SEARCH] = model_helper.create_infer_model(
                        model_creator,
                        self.hparams,
                        graph=combined_graph,
                        trie=self.trie)

  # Restore the default inference settings once all infer models are built.
  self.hparams.infer_mode = default_infer_mode
  self.hparams.beam_width = default_beam_width
  self.hparams.use_rl = True

  self.sess = tf.Session(
      graph=combined_graph, config=misc_utils.get_config_proto())

  with combined_graph.as_default():
    # NOTE: hardcoded local checkpoint path; the original call restored from
    # `out_dir` instead of `p1`.
    p1 = ("C:/Users/aanamika/Documents/QuestionGeneration/active-qa-master/"
          "tmp/active-qa/temp/translate.ckpt-1460356")
    print("p1:", p1)
    self.sess.run(tf.global_variables_initializer())
    self.sess.run(tf.tables_initializer())
    _, global_step = model_helper.create_or_load_model(
        self.train_model.model, p1, self.sess, "train")

  self.last_save_step = global_step
  self.summary_writer = tf.summary.FileWriter(
      os.path.join(out_dir, "train_log"), self.train_model.graph)
  self.checkpoint_path = os.path.join(out_dir, "translate.ckpt")
  self.trie_save_path = os.path.join(out_dir, "trie")
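
# Hypothetical construction sketch (illustrative only): assumes the enclosing
# class is the Reformulator and that the paths and server address below exist.
#
#   reformulator = Reformulator(
#       hparams_path="/tmp/active-qa/hparams.json",
#       source_prefix="<en> <2en> ",
#       out_dir="/tmp/active-qa",
#       environment_server_address="localhost:10000")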
def train(hparams, scope=None, target_session=""):
  """Train a translation model."""
  log_device_placement = hparams.log_device_placement
  out_dir = hparams.out_dir
  num_train_steps = hparams.num_train_steps
  steps_per_stats = hparams.steps_per_stats
  steps_per_external_eval = hparams.steps_per_external_eval
  steps_per_eval = 10 * steps_per_stats
  if not steps_per_external_eval:
    steps_per_external_eval = 5 * steps_per_eval

  if not hparams.attention:
    model_creator = nmt_model.Model
  else:  # Attention
    if (hparams.encoder_type == "gnmt" or
        hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
      model_creator = gnmt_model.GNMTModel
    elif hparams.attention_architecture == "standard":
      model_creator = attention_model.AttentionModel
    else:
      raise ValueError("Unknown attention architecture %s" %
                       hparams.attention_architecture)

  combined_graph = tf.Graph()
  train_model = model_helper.create_train_model(
      model_creator, hparams, scope, graph=combined_graph)
  eval_model = model_helper.create_eval_model(
      model_creator, hparams, scope, graph=combined_graph)
  infer_model = model_helper.create_infer_model(
      model_creator, hparams, scope, graph=combined_graph)

  # Preload data for sample decoding.
  dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
  dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
  dev_ctx_file = None
  if hparams.ctx is not None:
    dev_ctx_file = "%s.%s" % (hparams.dev_prefix, hparams.ctx)

  sample_src_data = inference.load_data(dev_src_file)
  sample_tgt_data = inference.load_data(dev_tgt_file)
  sample_ctx_data = None
  if dev_ctx_file is not None:
    sample_ctx_data = inference.load_data(dev_ctx_file)
  sample_annot_data = None
  if hparams.dev_annotations is not None:
    sample_annot_data = inference.load_data(hparams.dev_annotations)

  summary_name = "train_log"
  model_dir = hparams.out_dir

  # Log and output files
  log_file = os.path.join(out_dir, "log_%d" % time.time())
  log_f = tf.gfile.GFile(log_file, mode="a")
  utils.print_out("# log_file=%s" % log_file, log_f)

  # TensorFlow model
  config_proto = utils.get_config_proto(
      log_device_placement=log_device_placement,
      num_intra_threads=hparams.num_intra_threads,
      num_inter_threads=hparams.num_inter_threads)
  sess = tf.Session(
      target=target_session, config=config_proto, graph=combined_graph)

  with train_model.graph.as_default():
    sess.run(tf.global_variables_initializer())
    sess.run(tf.tables_initializer())
    loaded_train_model, global_step = model_helper.create_or_load_model(
        train_model.model, model_dir, sess, "train")

  # Summary writer
  summary_writer = tf.summary.FileWriter(
      os.path.join(out_dir, summary_name), train_model.graph)

  # First evaluation
  run_full_eval(infer_model, sess, eval_model, sess, hparams, summary_writer,
                sample_src_data, sample_ctx_data, sample_tgt_data,
                sample_annot_data)

  last_stats_step = global_step
  last_eval_step = global_step
  last_external_eval_step = global_step

  # This is the training loop.
  stats, info, start_train_time = before_train(loaded_train_model, train_model,
                                               sess, global_step, hparams,
                                               log_f)
  while global_step < num_train_steps:
    ### Run a step ###
    start_time = time.time()
    try:
      step_result = loaded_train_model.train(sess)
      hparams.epoch_step += 1
    except tf.errors.OutOfRangeError:
      # Finished going through the training dataset. Go to next epoch.
      hparams.epoch_step = 0
      utils.print_out(
          "# Finished an epoch, step %d. Perform external evaluation" %
          global_step)
      run_sample_decode(infer_model, sess, hparams, summary_writer,
                        sample_src_data, sample_ctx_data, sample_tgt_data,
                        sample_annot_data)
      run_external_eval(infer_model, sess, hparams, summary_writer)
      sess.run(
          train_model.iterator.initializer,
          feed_dict={train_model.skip_count_placeholder: 0})
      continue

    # Process step_result, accumulate stats, and write summary
    global_step, info["learning_rate"], step_summary = update_stats(
        stats, start_time, step_result)
    summary_writer.add_summary(step_summary, global_step)

    # Once in a while, we print statistics.
    if global_step - last_stats_step >= steps_per_stats:
      last_stats_step = global_step
      is_overflow = process_stats(stats, info, global_step, steps_per_stats,
                                  log_f)
      print_step_info("  ", global_step, info, _get_best_results(hparams),
                      log_f)
      if is_overflow:
        break
      # Reset statistics
      stats = init_stats()

    if global_step - last_eval_step >= steps_per_eval:
      last_eval_step = global_step
      utils.print_out("# Save eval, global step %d" % global_step)
      utils.add_summary(summary_writer, global_step, "train_ppl",
                        info["train_ppl"])

      # Save checkpoint
      loaded_train_model.saver.save(
          sess,
          os.path.join(out_dir, "translate.ckpt"),
          global_step=global_step)

      # Evaluate on dev/test
      run_sample_decode(infer_model, sess, hparams, summary_writer,
                        sample_src_data, sample_ctx_data, sample_tgt_data,
                        sample_annot_data)
      dev_ppl, test_ppl = None, None
      # Only evaluate perplexity when doing supervised learning.
      if not hparams.use_rl:
        dev_ppl, test_ppl = run_internal_eval(eval_model, sess, hparams,
                                              summary_writer)

    if global_step - last_external_eval_step >= steps_per_external_eval:
      last_external_eval_step = global_step

      # Save checkpoint
      loaded_train_model.saver.save(
          sess,
          os.path.join(out_dir, "translate.ckpt"),
          global_step=global_step)
      run_sample_decode(infer_model, sess, hparams, summary_writer,
                        sample_src_data, sample_ctx_data, sample_tgt_data,
                        sample_annot_data)
      run_external_eval(infer_model, sess, hparams, summary_writer)

  # Done training
  loaded_train_model.saver.save(
      sess, os.path.join(out_dir, "translate.ckpt"), global_step=global_step)

  (result_summary, _, final_eval_metrics) = run_full_eval(
      infer_model, sess, eval_model, sess, hparams, summary_writer,
      sample_src_data, sample_ctx_data, sample_tgt_data, sample_annot_data)
  print_step_info("# Final, ", global_step, info, result_summary, log_f)
  utils.print_time("# Done training!", start_train_time)

  summary_writer.close()
  return final_eval_metrics, global_step
def multi_worker_inference(infer_model, ckpt, inference_input_file,
                           inference_context_file, inference_output_file,
                           hparams, num_workers, jobid):
  """Inference using multiple workers."""
  assert num_workers > 1

  final_output_infer = inference_output_file
  output_infer = "%s_%d" % (inference_output_file, jobid)
  output_infer_done = "%s_done_%d" % (inference_output_file, jobid)

  # Read data
  infer_data = load_data(inference_input_file, hparams)
  if inference_context_file is not None:
    infer_context = load_data(inference_context_file, hparams)
  else:
    infer_context = None

  # Split data to multiple workers
  total_load = len(infer_data)
  load_per_worker = int((total_load - 1) / num_workers) + 1
  start_position = jobid * load_per_worker
  end_position = min(start_position + load_per_worker, total_load)
  infer_data = infer_data[start_position:end_position]
  if infer_context is not None:
    infer_context = infer_context[start_position:end_position]

  infer_feed_dict = {
      infer_model.src_placeholder: infer_data,
      infer_model.batch_size_placeholder: hparams.infer_batch_size
  }
  if infer_context is not None:
    infer_feed_dict[infer_model.ctx_placeholder] = infer_context

  with tf.Session(
      graph=infer_model.graph, config=utils.get_config_proto()) as sess:
    loaded_infer_model = model_helper.load_model(infer_model.model, ckpt, sess,
                                                 "infer")
    sess.run(infer_model.iterator.initializer, feed_dict=infer_feed_dict)
    # Decode
    utils.print_out("# Start decoding")
    nmt_utils.decode_and_evaluate(
        "infer",
        loaded_infer_model,
        sess,
        output_infer,
        ref_file=None,
        metrics=hparams.metrics,
        subword_option=hparams.subword_option,
        beam_width=hparams.beam_width,
        tgt_eos=hparams.eos,
        hparams=hparams,
        num_translations_per_input=hparams.num_translations_per_input)

    # Change file name to indicate the file writing is completed.
    tf.gfile.Rename(output_infer, output_infer_done, overwrite=True)

    # Job 0 is responsible for the clean up.
    if jobid != 0:
      return

    # Now write all translations
    with codecs.getwriter("utf-8")(tf.gfile.GFile(
        final_output_infer, mode="wb")) as final_f:
      for worker_id in range(num_workers):
        worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id)
        while not tf.gfile.Exists(worker_infer_done):
          utils.print_out("  waiting for job %d to complete." % worker_id)
          time.sleep(10)
        with codecs.getreader("utf-8")(tf.gfile.GFile(
            worker_infer_done, mode="rb")) as f:
          for translation in f:
            final_f.write("%s" % translation)

      for worker_id in range(num_workers):
        worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id)
        tf.gfile.Remove(worker_infer_done)
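
# The shard arithmetic in multi_worker_inference splits the input via ceiling
# division, so every example is covered and only the last worker may receive
# a short shard. A minimal, self-contained sketch of the same bounds
# computation (`_shard_bounds` is a hypothetical helper, not part of the
# original code):
def _shard_bounds(total_load, num_workers, jobid):
  """Returns the [start, end) slice of the input handled by `jobid`."""
  load_per_worker = (total_load - 1) // num_workers + 1  # ceiling division
  start_position = jobid * load_per_worker
  end_position = min(start_position + load_per_worker, total_load)
  return start_position, end_position

# Example: 10 examples over 3 workers -> (0, 4), (4, 8), (8, 10).
assert [_shard_bounds(10, 3, j) for j in range(3)] == [(0, 4), (4, 8), (8, 10)]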