def get_learning_rate_warmup(self):
    """Get the learning-rate warmup op.

    Returns a tensor equal to ``inv_decay * learning_rate`` while
    ``global_step < warmup_steps`` and to plain ``learning_rate`` afterwards.

    Raises:
        ValueError: if ``hparams.warmup_scheme`` is not "t2t".
    """
    hparams = self.hparams
    warmup_steps = hparams.warmup_steps
    warmup_scheme = hparams.warmup_scheme
    utils.log("learning_rate=%g, warmup_steps=%d, warmup_scheme=%s" %
              (hparams.learning_rate, warmup_steps, warmup_scheme))

    # Apply inverse decay if global steps less than warmup steps.
    # Inspired by https://arxiv.org/pdf/1706.03762.pdf (Section 5.3)
    # When step < warmup_steps,
    #   learning_rate *= warmup_factor ** (warmup_steps - step)
    if warmup_scheme == "t2t":
        # 0.01^(1/warmup_steps): we start with a lr, 100 times smaller
        warmup_factor = tf.exp(tf.log(0.01) / warmup_steps)
        inv_decay = warmup_factor**(
            tf.to_float(warmup_steps - self.global_step))
    else:
        raise ValueError("Unknown warmup scheme %s" % warmup_scheme)

    # Fixed typo in the op name ("warump" -> "warmup").
    return tf.cond(self.global_step < hparams.warmup_steps,
                   lambda: inv_decay * self.learning_rate,
                   lambda: self.learning_rate,
                   name="learning_rate_warmup_cond")
def infer(ckpt, inference_input_file, inference_output_file, hparams):
    """Translate every sentence in inference_input_file with checkpoint
    `ckpt` and write the translations to inference_output_file."""
    infer_model = model_helper.create_infer_model(gnmt_model.GNMTModel,
                                                  hparams)

    # Source sentences to translate.
    infer_data = utils.load_data(inference_input_file)

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    with tf.Session(graph=infer_model.graph, config=session_config) as sess:
        loaded_infer_model = model_helper.load_model(
            infer_model.model, ckpt, sess, "infer")
        feed = {
            infer_model.src_placeholder: infer_data,
            infer_model.batch_size_placeholder: hparams.infer_batch_size,
        }
        sess.run(infer_model.iterator.initializer, feed_dict=feed)

        # Decode
        utils.log("Start decoding")
        loaded_infer_model.decode_and_evaluate(
            "infer",
            sess,
            inference_output_file,
            ref_file=None,
            beam_width=hparams.beam_width,
            tgt_eos=hparams.eos,
            num_translations_per_input=hparams.num_translations_per_input)
def run_external_eval(infer_model, infer_sess, model_dir, hparams,
                      summary_writer, save_best_dev=True, use_test_set=True,
                      avg_ckpts=False):
    """Compute external evaluation (bleu, rouge, etc.) for both dev / test.

    Returns (dev_scores, test_scores, global_step); the scores are None
    when no checkpoint has been trained yet (global_step == 0).
    """
    with infer_model.graph.as_default():
        loaded_infer_model, global_step = model_helper.create_or_load_model(
            infer_model.model, model_dir, infer_sess, "infer")

    dev_scores = None
    test_scores = None
    if global_step > 0:
        utils.log("External evaluation, global step %d" % global_step)

        # Dev set is always evaluated.
        dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
        dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
        dev_feed_dict = {
            infer_model.src_placeholder: utils.load_data(dev_src_file),
            infer_model.batch_size_placeholder: hparams.infer_batch_size,
        }
        dev_scores = external_eval(loaded_infer_model, global_step,
                                   infer_sess, hparams, infer_model.iterator,
                                   dev_feed_dict, dev_tgt_file, "dev",
                                   summary_writer,
                                   save_on_best=save_best_dev,
                                   avg_ckpts=avg_ckpts)

        # Test set is optional.
        if use_test_set and hparams.test_prefix:
            test_src_file = "%s.%s" % (hparams.test_prefix, hparams.src)
            test_tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)
            test_feed_dict = {
                infer_model.src_placeholder: utils.load_data(test_src_file),
                infer_model.batch_size_placeholder: hparams.infer_batch_size,
            }
            test_scores = external_eval(loaded_infer_model, global_step,
                                        infer_sess, hparams,
                                        infer_model.iterator, test_feed_dict,
                                        test_tgt_file, "test",
                                        summary_writer, save_on_best=False,
                                        avg_ckpts=avg_ckpts)
    return dev_scores, test_scores, global_step
def print_step_info(prefix, global_step, info, result_summary):
    """Log a one-line summary of the training statistics at this step."""
    message = ("%sstep %d lr %g step-time %.2fs wps %.2fK ppl %.2f gN %.2f %s"
               % (prefix, global_step, info["learning_rate"],
                  info["avg_step_time"], info["speed"], info["train_ppl"],
                  info["avg_grad_norm"], result_summary))
    utils.log(message)
def build_graph(self, scope):
    """Build encoder + decoder; compute the loss unless in INFER mode."""
    utils.log('Creating {} graph ...'.format(self.mode))
    with tf.variable_scope(scope or "gnmt", dtype=tf.float32):
        self.build_encoder()
        self.build_decoder()
        if self.mode == tf.contrib.learn.ModeKeys.INFER:
            # No labels at inference time, so no loss to compute.
            self.loss = None
        else:
            self.compute_loss()
def create_or_load_model(model, model_dir, session, name):
    """Create translation model and initialize or load parameters in session."""
    latest_ckpt = tf.train.latest_checkpoint(model_dir)
    if not latest_ckpt:
        # No checkpoint yet: start from freshly initialized parameters.
        session.run(tf.global_variables_initializer())
        session.run(tf.tables_initializer())
        utils.log("Create {} model with fresh parameters".format(name))
    else:
        model.saver.restore(session, latest_ckpt)
        session.run(tf.tables_initializer())
        utils.log("Load {} model parameters from {}".format(name, latest_ckpt))
    global_step = model.global_step.eval(session=session)
    return model, global_step
def process_stats(stats, info, global_step, steps_per_stats):
    """Fill `info` with averaged statistics and detect perplexity overflow.

    Returns True when the training perplexity is NaN/inf/huge, signalling
    that training should stop early.
    """
    # Averages over the last `steps_per_stats` steps.
    info["avg_step_time"] = stats["step_time"] / steps_per_stats
    info["avg_grad_norm"] = stats["grad_norm"] / steps_per_stats
    info["train_ppl"] = utils.safe_exp(stats["loss"] / stats["predict_count"])
    info["speed"] = stats["total_count"] / (1000 * stats["step_time"])

    # Check for overflow
    train_ppl = info["train_ppl"]
    is_overflow = (math.isnan(train_ppl) or math.isinf(train_ppl)
                   or train_ppl > 1e20)
    if is_overflow:
        utils.log("step %d overflow, stop early" % (global_step, ))
    return is_overflow
def get_learning_rate_decay(self):
    """Get the learning-rate decay op for hparams.decay_scheme.

    "luongN" schemes halve the learning rate `decay_times` times over the
    tail of training; an empty scheme means no decay.

    Raises:
        ValueError: if the decay scheme is not recognised.
    """
    hparams = self.hparams
    luong_schemes = ["luong5", "luong10", "luong234"]
    if hparams.decay_scheme in luong_schemes:
        decay_factor = 0.5
        if hparams.decay_scheme == "luong5":
            start_decay_step = int(hparams.num_train_steps / 2)
            decay_times = 5
        elif hparams.decay_scheme == "luong10":
            start_decay_step = int(hparams.num_train_steps / 2)
            decay_times = 10
        elif hparams.decay_scheme == "luong234":
            start_decay_step = int(hparams.num_train_steps * 2 / 3)
            decay_times = 4
        remain_steps = hparams.num_train_steps - start_decay_step
        decay_steps = int(remain_steps / decay_times)
    elif not hparams.decay_scheme:
        # No decay: start_decay_step is never reached.
        start_decay_step = hparams.num_train_steps
        decay_steps = 0
        decay_factor = 1.0
    else:
        raise ValueError("Unknown decay scheme %s" % hparams.decay_scheme)
    utils.log("decay_scheme=%s, start_decay_step=%d, decay_steps %d, "
              "decay_factor %g" % (hparams.decay_scheme, start_decay_step,
                                   decay_steps, decay_factor))

    if hparams.decay_scheme in luong_schemes:
        return tf.cond(
            self.global_step < start_decay_step,
            lambda: self.learning_rate,
            lambda: tf.train.exponential_decay(
                self.learning_rate,
                (self.global_step - start_decay_step),
                decay_steps, decay_factor, staircase=True),
            name="learning_rate_decay_cond")
    return self.learning_rate
def compute_perplexity(self, sess, name):
    """Run eval until the iterator is exhausted and return the perplexity."""
    total_loss = 0
    total_predict_count = 0
    while True:
        try:
            loss, predict_count, batch_size = self.eval(sess)
        except tf.errors.OutOfRangeError:
            # Dataset exhausted.
            break
        total_loss += loss * batch_size
        total_predict_count += predict_count
    perplexity = utils.safe_exp(total_loss / total_predict_count)
    utils.log("%s perplexity: %.2f" % (name, perplexity))
    return perplexity
def run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                      summary_writer, use_test_set=True):
    """Compute internal evaluation (perplexity) for dev and, optionally, test.

    Returns (dev_ppl, test_ppl); test_ppl is None when no test set is used.
    """
    with eval_model.graph.as_default():
        loaded_eval_model, global_step = model_helper.create_or_load_model(
            eval_model.model, model_dir, eval_sess, "eval")

    utils.log("Internal evaluation, global step %d" % global_step)

    # Dev perplexity.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    dev_feed_dict = {
        eval_model.src_file_placeholder: dev_src_file,
        eval_model.tgt_file_placeholder: dev_tgt_file
    }
    dev_ppl = internal_eval(loaded_eval_model, global_step, eval_sess,
                            eval_model.iterator, dev_feed_dict,
                            summary_writer, "dev")

    # Test perplexity (optional).
    test_ppl = None
    if use_test_set and hparams.test_prefix:
        test_src_file = "%s.%s" % (hparams.test_prefix, hparams.src)
        test_tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)
        test_feed_dict = {
            eval_model.src_file_placeholder: test_src_file,
            eval_model.tgt_file_placeholder: test_tgt_file
        }
        test_ppl = internal_eval(loaded_eval_model, global_step, eval_sess,
                                 eval_model.iterator, test_feed_dict,
                                 summary_writer, "test")
    return dev_ppl, test_ppl
def create_pretrained_emb_from_txt(self, vocab_file, embed_file,
                                   num_trainable_tokens=3, dtype=tf.float32,
                                   scope=None):
    """Load a pretrained embedding from embed_file.

    Returns an embedding matrix in which only the first
    `num_trainable_tokens` rows (typically the special tokens) remain
    trainable; the rest is held constant.
    """
    vocab, _ = vocab_utils.load_vocab(vocab_file)
    trainable_tokens = vocab[:num_trainable_tokens]

    utils.log("Using pretrained embedding: {}".format(embed_file))
    utils.log("with trainable tokens: ")

    emb_dict, emb_size = vocab_utils.load_embed_txt(embed_file)
    for token in trainable_tokens:
        utils.log("{}".format(token))
        if token not in emb_dict:
            # Trainable tokens missing from the embedding file start at zero.
            emb_dict[token] = [0.0] * emb_size

    emb_mat = tf.constant(
        np.array([emb_dict[token] for token in vocab],
                 dtype=dtype.as_numpy_dtype()))
    # Frozen part: every row past the trainable prefix.
    emb_mat_const = tf.slice(emb_mat, [num_trainable_tokens, 0], [-1, -1])
    with tf.variable_scope(scope or "pretrain_embeddings",
                           dtype=dtype) as scope:
        with tf.device(self.get_embed_device(num_trainable_tokens)):
            emb_mat_var = tf.get_variable(
                "emb_mat_var", [num_trainable_tokens, emb_size])
        return tf.concat([emb_mat_var, emb_mat_const], 0)
def init_embeddings(self, hparams, scope, dtype=tf.float32):
    """Create encoder/decoder embedding variables (shared when configured)."""
    with tf.variable_scope(scope or "embeddings", dtype=dtype) as scope:
        if hparams.share_vocab:
            # One embedding matrix serves both source and target.
            if self.src_vocab_size != self.tgt_vocab_size:
                raise ValueError(
                    "Share embedding but different src/tgt vocab sizes"
                    " %d vs. %d" % (self.src_vocab_size, self.tgt_vocab_size))
            assert self.src_embed_size == self.tgt_embed_size
            utils.log("Use the same embedding for source and target")
            vocab_file = hparams.src_vocab_file or hparams.tgt_vocab_file
            embed_file = hparams.src_embed_file or hparams.tgt_embed_file
            self.embedding_encoder = self.create_or_load_embed(
                "embedding_share", vocab_file, embed_file,
                self.src_vocab_size, self.src_embed_size, dtype)
            self.embedding_decoder = self.embedding_encoder
        else:
            with tf.variable_scope("encoder"):
                self.embedding_encoder = self.create_or_load_embed(
                    "embedding_encoder", hparams.src_vocab_file,
                    hparams.src_embed_file, self.src_vocab_size,
                    self.src_embed_size, dtype)
            with tf.variable_scope("decoder"):
                self.embedding_decoder = self.create_or_load_embed(
                    "embedding_decoder", hparams.tgt_vocab_file,
                    hparams.tgt_embed_file, self.tgt_vocab_size,
                    self.tgt_embed_size, dtype)
def check_vocab(vocab_file, out_dir, check_special_token=True, sos=None,
                eos=None, unk=None):
    """Verify that vocab_file exists and starts with the special tokens.

    If the first three entries are not [unk, sos, eos], a patched copy of
    the vocab is written into out_dir and used instead. Returns the final
    vocab size and the (possibly new) vocab file path.

    Raises:
        ValueError: if vocab_file does not exist.
    """
    if not os.path.exists(vocab_file):
        raise ValueError("vocab_file '%s' does not exist." % (vocab_file, ))

    utils.log("Vocab file %s exists" % vocab_file)
    vocab, vocab_size = load_vocab(vocab_file)
    if check_special_token:
        # Verify if the vocab starts with unk, sos, eos.
        # If not, prepend those tokens & generate a new vocab file.
        unk = unk or UNK
        sos = sos or SOS
        eos = eos or EOS
        assert len(vocab) >= 3
        if vocab[0] != unk or vocab[1] != sos or vocab[2] != eos:
            utils.log("The first 3 vocab words [%s, %s, %s]"
                      " are not [%s, %s, %s]" %
                      (vocab[0], vocab[1], vocab[2], unk, sos, eos))
            vocab = [unk, sos, eos] + vocab
            vocab_size += 3
            new_vocab_file = os.path.join(out_dir,
                                          os.path.basename(vocab_file))
            with open(new_vocab_file, "w", encoding='utf-8') as f:
                for word in vocab:
                    f.write("%s\n" % (word, ))
            vocab_file = new_vocab_file

    vocab_size = len(vocab)
    return vocab_size, vocab_file
def decode_and_evaluate(self, name, sess, trans_file, ref_file, beam_width,
                        tgt_eos, num_translations_per_input=1):
    """Decode a test set into trans_file and score it against ref_file.

    Returns a dict of evaluation scores (only 'BLEU', and only when
    ref_file is given and the translation file exists).
    """
    # Decode
    utils.log("Decoding to output {}.".format(trans_file))
    num_sentences = 0
    with open(trans_file, 'w', encoding='utf-8') as trans_f:
        trans_f.write("")  # Write empty string to ensure file is created.
        num_translations_per_input = max(
            min(num_translations_per_input, beam_width), 1)
        while True:
            try:
                nmt_outputs, _ = self.decode(sess)
            except tf.errors.OutOfRangeError:
                utils.log(
                    "Done, num sentences %d, num translations per input %d" %
                    (num_sentences, num_translations_per_input))
                break
            if beam_width == 0:
                # Add a fake beam dimension so indexing below is uniform.
                nmt_outputs = np.expand_dims(nmt_outputs, 0)
            batch_size = nmt_outputs.shape[1]
            num_sentences += batch_size
            for sent_id in range(batch_size):
                for beam_id in range(num_translations_per_input):
                    translation = utils.get_translation(
                        nmt_outputs[beam_id], sent_id, tgt_eos=tgt_eos)
                    trans_f.write(translation + "\n")

    # Evaluation
    evaluation_scores = {}
    if ref_file and os.path.exists(trans_file):
        score = evaluation_utils.evaluate(ref_file, trans_file, 'BLEU')
        evaluation_scores['BLEU'] = score
        utils.log("%s BLEU: %.1f" % (name, score))
    return evaluation_scores
def run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                      summary_writer, src_data, tgt_data):
    """Decode one randomly-chosen sentence from src_data and log the result."""
    with infer_model.graph.as_default():
        loaded_infer_model, global_step = model_helper.create_or_load_model(
            infer_model.model, model_dir, infer_sess, "infer")

    # Pick a random sentence and decode it.
    decode_id = random.randint(0, len(src_data) - 1)
    iterator_feed_dict = {
        infer_model.src_placeholder: [src_data[decode_id]],
        infer_model.batch_size_placeholder: 1,
    }
    infer_sess.run(infer_model.iterator.initializer,
                   feed_dict=iterator_feed_dict)
    nmt_outputs, attention_summary = loaded_infer_model.decode(infer_sess)

    if hparams.beam_width > 0:
        # get the top translation.
        nmt_outputs = nmt_outputs[0]
    translation = utils.get_translation(nmt_outputs, sent_id=0,
                                        tgt_eos=hparams.eos)
    utils.log("Sample src: {}".format(src_data[decode_id]))
    utils.log("Sample ref: {}".format(tgt_data[decode_id]))
    utils.log("NMT output: {}".format(translation))

    # Summary
    if attention_summary is not None:
        summary_writer.add_summary(attention_summary, global_step)
def __init__(self, hparams, mode, iterator, source_vocab_table,
             target_vocab_table, reverse_target_vocab_table=None, scope=None):
    """Build the full model graph for the given mode (train/eval/infer).

    Args:
        hparams: hyper-parameter object (vocab sizes, layer counts, ...).
        mode: one of tf.contrib.learn.ModeKeys.{TRAIN, EVAL, INFER}.
        iterator: an iterator_utils.BatchedInput feeding the model.
        source_vocab_table / target_vocab_table: vocab lookup tables.
        reverse_target_vocab_table: id->word table, required for INFER.
        scope: optional variable scope name.

    Raises:
        ValueError: if hparams.optimizer is neither "sgd" nor "adam".
    """
    assert isinstance(iterator, iterator_utils.BatchedInput)
    self.hparams = hparams
    self.iterator = iterator
    self.mode = mode
    self.src_vocab_table = source_vocab_table
    self.tgt_vocab_table = target_vocab_table

    self.src_vocab_size = hparams.src_vocab_size
    self.tgt_vocab_size = hparams.tgt_vocab_size
    self.src_embed_size = hparams.embed_size
    self.tgt_embed_size = hparams.embed_size

    self.num_encoder_layers = hparams.num_encoder_layers
    self.num_decoder_layers = hparams.num_decoder_layers
    assert self.num_encoder_layers
    assert self.num_decoder_layers
    self.num_encoder_residual_layers = hparams.num_encoder_residual_layers
    self.num_decoder_residual_layers = hparams.num_decoder_residual_layers

    self.batch_size = tf.size(self.iterator.source_sequence_length)

    # Initializer
    initializer = self.get_initializer(hparams.init_op, hparams.random_seed,
                                       hparams.init_weight)
    tf.get_variable_scope().set_initializer(initializer)

    # Embeddings
    self.init_embeddings(hparams, scope)

    # Projection
    with tf.variable_scope(scope or "build_network"):
        with tf.variable_scope("decoder/output_projection"):
            self.output_layer = tf.layers.Dense(self.tgt_vocab_size,
                                                use_bias=False,
                                                name="output_projection")

    self.build_graph(scope)

    if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
        self.train_loss = self.loss
        self.word_count = tf.reduce_sum(
            self.iterator.source_sequence_length) + tf.reduce_sum(
                self.iterator.target_sequence_length)
    elif self.mode == tf.contrib.learn.ModeKeys.EVAL:
        self.eval_loss = self.loss
    elif self.mode == tf.contrib.learn.ModeKeys.INFER:
        self.infer_logits = self.logits
        self.sample_words = reverse_target_vocab_table.lookup(
            tf.to_int64(self.sample_id))

    if self.mode != tf.contrib.learn.ModeKeys.INFER:
        ## Count the number of predicted words for computing ppl.
        self.predict_count = tf.reduce_sum(
            self.iterator.target_sequence_length)

    self.global_step = tf.Variable(0, trainable=False)
    params = tf.trainable_variables()

    # Gradients and SGD update operation for training the model.
    # Arrange for the embedding vars to appear at the beginning.
    if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
        self.learning_rate = tf.constant(hparams.learning_rate)
        # warm-up
        self.learning_rate = self.get_learning_rate_warmup()
        # decay
        self.learning_rate = self.get_learning_rate_decay()

        # Optimizer
        if hparams.optimizer == "sgd":
            opt = tf.train.GradientDescentOptimizer(self.learning_rate)
        elif hparams.optimizer == "adam":
            opt = tf.train.AdamOptimizer(self.learning_rate)
        else:
            # BUG FIX: an unrecognised optimizer previously fell through and
            # `opt` was later referenced while undefined (NameError).
            raise ValueError("Unknown optimizer %s" % hparams.optimizer)

        # Gradients
        gradients = tf.gradients(self.train_loss, params)
        clipped_gradients, self.grad_norm = tf.clip_by_global_norm(
            gradients, hparams.max_gradient_norm)
        gradient_norm_summary = [
            tf.summary.scalar("grad_norm", self.grad_norm)
        ]
        gradient_norm_summary.append(
            tf.summary.scalar("clipped_gradient",
                              tf.global_norm(clipped_gradients)))

        self.update = opt.apply_gradients(zip(clipped_gradients, params),
                                          global_step=self.global_step)

        # Summary
        self.train_summary = tf.summary.merge([
            tf.summary.scalar("lr", self.learning_rate),
            tf.summary.scalar("train_loss", self.train_loss),
        ] + gradient_norm_summary)

    if self.mode == tf.contrib.learn.ModeKeys.INFER:
        if hparams.beam_width > 0:
            # Beam search keeps no alignment history; nothing to visualize.
            self.infer_summary = tf.no_op()
        else:
            attention_images = (
                self.final_context_state[0].alignment_history.stack())
            # Reshape to (batch, src_seq_len, tgt_seq_len,1)
            attention_images = tf.expand_dims(
                tf.transpose(attention_images, [1, 2, 0]), -1)
            # Scale to range [0, 255]
            attention_images *= 255
            self.infer_summary = tf.summary.image("attention_images",
                                                  attention_images)

    # Saver
    self.saver = tf.train.Saver(tf.global_variables(),
                                max_to_keep=hparams.num_keep_ckpts)

    # Print trainable variables
    utils.log("Trainable variables")
    for param in params:
        utils.log("%s, %s" % (param.name, str(param.get_shape())))
def run(args):
    """Entry point: set up logging and hparams, then infer or train."""
    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)

    # File logger for this run.
    logger = logging.getLogger("nmt_zh")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(os.path.join(args.out_dir, "log"))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    default_hparams = create_hparams(args)
    # Load hparams.
    hparams = create_or_load_hparams(default_hparams.out_dir, default_hparams)
    utils.log('Running with hparams : {}'.format(hparams))

    random_seed = hparams.random_seed
    if random_seed is not None and random_seed > 0:
        utils.log('Set random seed to {}'.format(random_seed))
        random.seed(random_seed)
        np.random.seed(random_seed)
        tf.set_random_seed(random_seed)

    if not hparams.inference_input_file:
        utils.log('Training ...')
        train.train(hparams)
        return

    utils.log('Inferring ...')
    # infer
    trans_file = hparams.inference_output_file
    ckpt = hparams.ckpt or tf.train.latest_checkpoint(hparams.out_dir)
    utils.log('Use checkpoint: {}'.format(ckpt))
    utils.log('Start infer sentence in {}, output saved to {} ...'.format(
        hparams.inference_input_file, trans_file))
    infer.infer(ckpt, hparams.inference_input_file, trans_file, hparams)

    # eval
    ref_file = hparams.inference_ref_file
    if ref_file and os.path.exists(trans_file):
        utils.log('Evaluating infer output with reference in {} ...'.format(
            ref_file))
        score = evaluation_utils.evaluate(ref_file, trans_file, 'BLEU')
        utils.log("BLEU: %.1f" % (score, ))
def create_or_load_hparams(out_dir, default_hparams):
    """Create hparams or load hparams from out_dir.

    When no saved hparams exist in out_dir, `default_hparams` is extended
    with derived settings (num_train_steps, layer counts, vocab info,
    embedding files), saved to out_dir and returned.

    Raises:
        ValueError: if hparams.vocab_prefix is missing when creating anew.
    """
    hparams = utils.load_hparams(out_dir)
    if not hparams:
        hparams = default_hparams

        # Directories that track the best (and averaged-best) BLEU models.
        hparams.add_hparam("best_bleu", 0)
        best_bleu_dir = os.path.join(out_dir, "best_bleu")
        hparams.add_hparam("best_bleu_dir", best_bleu_dir)
        os.makedirs(best_bleu_dir)
        hparams.add_hparam("avg_best_bleu", 0)
        avg_best_bleu_dir = os.path.join(hparams.out_dir, "avg_best_bleu")
        hparams.add_hparam("avg_best_bleu_dir", avg_best_bleu_dir)
        os.makedirs(avg_best_bleu_dir)

        # Set num_train_steps from the number of parallel training sentences.
        train_src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
        train_tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
        with open(train_src_file, 'r', encoding='utf-8') as f:
            train_src_steps = len(f.readlines())
        with open(train_tgt_file, 'r', encoding='utf-8') as f:
            train_tgt_steps = len(f.readlines())
        hparams.add_hparam(
            "num_train_steps",
            min([train_src_steps, train_tgt_steps]) * hparams.epochs)

        # Set encoder/decoder layers
        hparams.add_hparam("num_encoder_layers", hparams.num_layers)
        hparams.add_hparam("num_decoder_layers", hparams.num_layers)

        # Set residual layers
        num_encoder_residual_layers = 0
        num_decoder_residual_layers = 0
        if hparams.num_encoder_layers > 1:
            num_encoder_residual_layers = hparams.num_encoder_layers - 1
        if hparams.num_decoder_layers > 1:
            num_decoder_residual_layers = hparams.num_decoder_layers - 1
        # The first unidirectional layer (after the bi-directional layer) in
        # the GNMT encoder can't have residual connection due to the input is
        # the concatenation of fw_cell and bw_cell's outputs.
        # BUG FIX: clamp at 0 so a 1- or 2-layer encoder does not end up with
        # a negative residual-layer count.
        num_encoder_residual_layers = max(0, hparams.num_encoder_layers - 2)
        # Compatible for GNMT models
        if hparams.num_encoder_layers == hparams.num_decoder_layers:
            num_decoder_residual_layers = num_encoder_residual_layers
        hparams.add_hparam("num_encoder_residual_layers",
                           num_encoder_residual_layers)
        hparams.add_hparam("num_decoder_residual_layers",
                           num_decoder_residual_layers)

        # Vocab
        # Get vocab file names first
        if hparams.vocab_prefix:
            src_vocab_file = hparams.vocab_prefix + "." + hparams.src
            tgt_vocab_file = hparams.vocab_prefix + "." + hparams.tgt
        else:
            raise ValueError("hparams.vocab_prefix must be provided.")

        # Source vocab
        src_vocab_size, src_vocab_file = vocab_utils.check_vocab(
            src_vocab_file,
            hparams.out_dir,
            sos=hparams.sos,
            eos=hparams.eos,
            unk=vocab_utils.UNK)
        # Target vocab
        if hparams.share_vocab:
            utils.log("Using source vocab for target")
            tgt_vocab_file = src_vocab_file
            tgt_vocab_size = src_vocab_size
        else:
            tgt_vocab_size, tgt_vocab_file = vocab_utils.check_vocab(
                tgt_vocab_file,
                hparams.out_dir,
                sos=hparams.sos,
                eos=hparams.eos,
                unk=vocab_utils.UNK)
        hparams.add_hparam("src_vocab_size", src_vocab_size)
        hparams.add_hparam("tgt_vocab_size", tgt_vocab_size)
        hparams.add_hparam("src_vocab_file", src_vocab_file)
        hparams.add_hparam("tgt_vocab_file", tgt_vocab_file)

        # Pretrained Embeddings (empty string means "train from scratch"):
        hparams.add_hparam("src_embed_file", "")
        hparams.add_hparam("tgt_embed_file", "")
        if hparams.embed_prefix:
            src_embed_file = hparams.embed_prefix + "." + hparams.src
            tgt_embed_file = hparams.embed_prefix + "." + hparams.tgt
            if os.path.exists(src_embed_file):
                hparams.src_embed_file = src_embed_file
            if os.path.exists(tgt_embed_file):
                hparams.tgt_embed_file = tgt_embed_file

        # Save HParams
        utils.save_hparams(out_dir, hparams)
    return hparams
def train(hparams, scope=None):
    """Main training loop: train, periodically evaluate, save checkpoints.

    Builds separate train/eval/infer graphs and sessions, trains for
    hparams.epochs epochs (or until perplexity overflow), running sample
    decodes, internal (ppl) and external (BLEU) evaluations on schedules
    derived from hparams.steps_per_stats. Returns
    (final_eval_metrics, global_step).
    """
    model_dir = hparams.out_dir
    avg_ckpts = hparams.avg_ckpts
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    # Internal eval every 10 stats windows; external every 5 internal evals
    # unless overridden by hparams.
    steps_per_eval = 10 * steps_per_stats
    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval
    summary_name = "summary"

    # Separate graphs/sessions for train, eval (ppl) and infer (BLEU/sample).
    model_creator = gnmt_model.GNMTModel
    train_model = model_helper.create_train_model(model_creator, hparams)
    eval_model = model_helper.create_eval_model(model_creator, hparams)
    infer_model = model_helper.create_infer_model(model_creator, hparams)

    config_proto = tf.ConfigProto()
    config_proto.gpu_options.allow_growth = True
    train_sess = tf.Session(graph=train_model.graph, config=config_proto)
    eval_sess = tf.Session(graph=eval_model.graph, config=config_proto)
    infer_sess = tf.Session(graph=infer_model.graph, config=config_proto)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(
        os.path.join(model_dir, summary_name), train_model.graph)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = utils.load_data(dev_src_file)
    sample_tgt_data = utils.load_data(dev_tgt_file)

    # First evaluation
    result_summary, _, _ = run_full_eval(model_dir, infer_model, infer_sess,
                                         eval_model, eval_sess, hparams,
                                         summary_writer, sample_src_data,
                                         sample_tgt_data, avg_ckpts)
    utils.log('First evaluation: {}'.format(result_summary))

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    stats = init_stats()
    info = {
        "train_ppl": 0.0,
        "speed": 0.0,
        "avg_step_time": 0.0,
        "avg_grad_norm": 0.0,
        "learning_rate":
            loaded_train_model.learning_rate.eval(session=train_sess)
    }
    utils.log("Start step %d, lr %g" % (global_step, info["learning_rate"]))

    # Initialize all of the iterators
    train_sess.run(train_model.iterator.initializer)
    epoch = 1
    while True:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset. Go to next epoch.
            utils.log(
                "Finished epoch %d, step %d. Perform external evaluation" %
                (epoch, global_step))
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data,
                              sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)
            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)
            # Re-initialize the iterator for the next epoch.
            train_sess.run(train_model.iterator.initializer)
            if epoch < hparams.epochs:
                epoch += 1
                continue
            else:
                break

        # Process step_result, accumulate stats, and write summary
        global_step, info["learning_rate"], step_summary = update_stats(
            stats, start_time, step_result)
        summary_writer.add_summary(step_summary, global_step)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = process_stats(stats, info, global_step,
                                        steps_per_stats)
            print_step_info(" ", global_step, info,
                            "BLEU %.2f" % (hparams.best_bleu, ))
            if is_overflow:
                # Perplexity overflowed (NaN/inf): stop training early.
                break
            # Reset statistics
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step
            utils.log("Save eval, global step %d" % (global_step, ))
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              info["train_ppl"])
            # Save checkpoint
            loaded_train_model.saver.save(
                train_sess,
                os.path.join(model_dir, "translate.ckpt"),
                global_step=global_step)
            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data,
                              sample_tgt_data)
            run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                              summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step
            # Save checkpoint
            loaded_train_model.saver.save(
                train_sess,
                os.path.join(model_dir, "translate.ckpt"),
                global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data,
                              sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)
            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

    # Done training: save a final checkpoint and run a full evaluation.
    loaded_train_model.saver.save(
        train_sess,
        os.path.join(model_dir, "translate.ckpt"),
        global_step=global_step)
    (result_summary, _, final_eval_metrics) = run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data, avg_ckpts)
    print_step_info("Final, ", global_step, info, result_summary)
    utils.log("Done training!")
    summary_writer.close()

    # Re-evaluate the saved best-BLEU model with its own summary writer.
    utils.log("Start evaluating saved best models.")
    best_model_dir = hparams.best_bleu_dir
    summary_writer = tf.summary.FileWriter(
        os.path.join(best_model_dir, summary_name), infer_model.graph)
    result_summary, best_global_step, _ = run_full_eval(
        best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
        hparams, summary_writer, sample_src_data, sample_tgt_data)
    print_step_info("Best BLEU, ", best_global_step, info, result_summary)
    summary_writer.close()

    if avg_ckpts:
        # Same for the averaged-checkpoint best model.
        best_model_dir = hparams.avg_best_bleu_dir
        summary_writer = tf.summary.FileWriter(
            os.path.join(best_model_dir, summary_name), infer_model.graph)
        result_summary, best_global_step, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hparams, summary_writer, sample_src_data, sample_tgt_data)
        print_step_info("Averaged Best BLEU, ", best_global_step, info,
                        result_summary)
        summary_writer.close()

    return final_eval_metrics, global_step
def load_model(model, ckpt, session, name):
    """Restore `model`'s parameters from checkpoint `ckpt` into `session`.

    Also runs the table initializers (vocab lookup tables) and returns the
    same `model` object for chaining.
    """
    model.saver.restore(session, ckpt)
    session.run(tf.tables_initializer())
    utils.log("Load {} model parameters from {}".format(name, ckpt))
    return model
def avg_checkpoints(model_dir, num_last_checkpoints, global_step,
                    global_step_name):
    """Average the last N checkpoints in the model_dir.

    Returns the directory containing the averaged checkpoint, or None when
    there is no checkpoint state or fewer than `num_last_checkpoints`
    checkpoints to average.
    """
    checkpoint_state = tf.train.get_checkpoint_state(model_dir)
    if not checkpoint_state:
        utils.log("No checkpoint file found in directory: {}".format(model_dir))
        return None

    # Checkpoints are ordered from oldest to newest.
    checkpoints = (
        checkpoint_state.all_model_checkpoint_paths[-num_last_checkpoints:])

    if len(checkpoints) < num_last_checkpoints:
        # Fixed typo in the log message ("is avaliable" -> "are available").
        utils.log(
            "Skipping averaging checkpoints because not enough checkpoints "
            "are available."
        )
        return None

    avg_model_dir = os.path.join(model_dir, "avg_checkpoints")
    if not os.path.exists(avg_model_dir):
        utils.log(
            "Creating new directory {} for saving averaged checkpoints."
            .format(avg_model_dir))
        os.makedirs(avg_model_dir)

    utils.log("Reading and averaging variables in checkpoints:")
    var_list = tf.contrib.framework.list_variables(checkpoints[0])
    var_values, var_dtypes = {}, {}
    # Accumulate each (non-global-step) variable across all checkpoints.
    for (name, shape) in var_list:
        if name != global_step_name:
            var_values[name] = np.zeros(shape)
    for checkpoint in checkpoints:
        utils.log("{}".format(checkpoint))
        reader = tf.contrib.framework.load_checkpoint(checkpoint)
        for name in var_values:
            tensor = reader.get_tensor(name)
            var_dtypes[name] = tensor.dtype
            var_values[name] += tensor
    for name in var_values:
        var_values[name] /= len(checkpoints)

    # Build a graph with same variables in the checkpoints, and save the
    # averaged variables into the avg_model_dir.
    with tf.Graph().as_default():
        tf_vars = [
            # BUG FIX: dtype was previously indexed with the stale loop
            # variable `name` (the last key of the loops above), so every
            # variable got that single dtype. Index with `v` instead.
            tf.get_variable(v, shape=var_values[v].shape,
                            dtype=var_dtypes[v]) for v in var_values
        ]
        placeholders = [
            tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars
        ]
        assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
        global_step_var = tf.Variable(
            global_step, name=global_step_name, trainable=False)
        saver = tf.train.Saver(tf.all_variables())

        with tf.Session() as sess:
            sess.run(tf.initialize_all_variables())
            for p, assign_op, (name, value) in zip(placeholders, assign_ops,
                                                   six.iteritems(var_values)):
                sess.run(assign_op, {p: value})
            # Use the built saver to save the averaged checkpoint. Only keep 1
            # checkpoint and the best checkpoint will be moved to
            # avg_best_metric_dir.
            saver.save(sess, os.path.join(avg_model_dir, "translate.ckpt"))

    return avg_model_dir