def _get_learning_rate_decay(self, hparams):
    """Get learning rate decay."""
    if hparams.decay_scheme == "luong10":
        start_decay_step = int(hparams.num_train_steps / 2)
        remain_steps = hparams.num_train_steps - start_decay_step
        decay_steps = int(remain_steps / 10)  # decay 10 times
        decay_factor = 0.5
    elif hparams.decay_scheme == "luong234":
        start_decay_step = int(hparams.num_train_steps * 2 / 3)
        remain_steps = hparams.num_train_steps - start_decay_step
        decay_steps = int(remain_steps / 4)  # decay 4 times
        decay_factor = 0.5
    elif not hparams.decay_scheme:  # no decay
        start_decay_step = hparams.num_train_steps
        decay_steps = 0
        decay_factor = 1.0
    else:
        raise ValueError("Unknown decay scheme %s" % hparams.decay_scheme)
    utils.print_out(
        "  decay_scheme=%s, start_decay_step=%d, decay_steps %d, "
        "decay_factor %g" % (hparams.decay_scheme, start_decay_step,
                             decay_steps, decay_factor))

    return tf.cond(
        self.global_step < start_decay_step,
        lambda: self.learning_rate,
        lambda: tf.train.exponential_decay(
            self.learning_rate,
            (self.global_step - start_decay_step),
            decay_steps, decay_factor, staircase=True),
        name="learning_rate_decay_cond")
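# A minimal sketch (hypothetical helper, not part of the original module)
# of the schedule arithmetic above. For hparams.num_train_steps = 12000,
# "luong10" halves the learning rate every 600 steps starting at step 6000,
# and "luong234" halves it every 1000 steps starting at step 8000.
def _demo_decay_schedule(num_train_steps=12000):
    """Print start_decay_step/decay_steps for both named decay schemes."""
    for scheme, num_decays in [("luong10", 10), ("luong234", 4)]:
        start = (num_train_steps // 2 if scheme == "luong10"
                 else num_train_steps * 2 // 3)
        interval = (num_train_steps - start) // num_decays
        print("%s: start_decay_step=%d, decay_steps=%d, decay_factor=0.5"
              % (scheme, start, interval))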
def _cell_list(unit_type, num_units, num_layers, num_residual_layers,
               forget_bias, dropout, mode, num_gpus, base_gpu=0,
               single_cell_fn=None, residual_fn=None):
    """Create a list of RNN cells."""
    if not single_cell_fn:
        single_cell_fn = _single_cell

    # Multi-GPU: cell i is placed on the device chosen by get_device_str.
    cell_list = []
    for i in range(num_layers):
        utils.print_out("  cell %d" % i, new_line=False)
        single_cell = single_cell_fn(
            unit_type=unit_type,
            num_units=num_units,
            forget_bias=forget_bias,
            dropout=dropout,
            mode=mode,
            residual_connection=(i >= num_layers - num_residual_layers),
            device_str=get_device_str(i + base_gpu, num_gpus),
            residual_fn=residual_fn)
        utils.print_out("")
        cell_list.append(single_cell)

    return cell_list
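# A hypothetical usage sketch of _cell_list: build a 4-layer LSTM stack where
# the top 2 layers get residual connections, then wrap the list in a
# MultiRNNCell, which is presumably what the encoder/decoder cell builders in
# this codebase do with the returned list.
def _demo_build_stack():
    cells = _cell_list(
        unit_type="lstm", num_units=512, num_layers=4,
        num_residual_layers=2, forget_bias=1.0, dropout=0.2,
        mode=tf.contrib.learn.ModeKeys.TRAIN, num_gpus=1)
    return tf.contrib.rnn.MultiRNNCell(cells)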
def load_model(model, ckpt, session, name):
    """Restore model parameters from ckpt and initialize lookup tables."""
    start_time = time.time()
    model.saver.restore(session, ckpt)
    session.run(tf.tables_initializer())
    utils.print_out(
        "  loaded %s model parameters from %s, time %.2fs"
        % (name, ckpt, time.time() - start_time))
    return model
def check_vocab(vocab_file, out_dir, check_special_token=True, sos=None,
                eos=None, unk=None):
    """Check that vocab_file exists and starts with the special tokens.

    If the special tokens are missing, prepend them and write a new vocab
    file to out_dir.
    """
    if tf.gfile.Exists(vocab_file):
        utils.print_out("# Vocab file %s exists" % vocab_file)
        vocab, vocab_size = load_vocab(vocab_file)
        if check_special_token:
            # Verify that the vocab starts with unk, sos, eos.
            # If not, prepend those tokens & generate a new vocab file.
            if not unk:
                unk = UNK
            if not sos:
                sos = SOS
            if not eos:
                eos = EOS
            assert len(vocab) >= 3
            if vocab[0] != unk or vocab[1] != sos or vocab[2] != eos:
                utils.print_out(
                    "The first 3 vocab words [%s, %s, %s]"
                    " are not [%s, %s, %s]"
                    % (vocab[0], vocab[1], vocab[2], unk, sos, eos))
                vocab = [unk, sos, eos] + vocab
                vocab_size += 3
                new_vocab_file = os.path.join(out_dir,
                                              os.path.basename(vocab_file))
                with codecs.getwriter("utf-8")(
                        tf.gfile.GFile(new_vocab_file, "wb")) as f:
                    for word in vocab:
                        f.write("%s\n" % word)
                vocab_file = new_vocab_file
    else:
        raise ValueError("vocab_file '%s' does not exist." % vocab_file)

    vocab_size = len(vocab)
    return vocab_size, vocab_file
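# A minimal sketch (hypothetical helper) of the vocab file layout check_vocab
# expects: one token per line, with the three special tokens first. The path
# and toy tokens below are illustrative only.
def _demo_make_toy_vocab(path="/tmp/toy_vocab.txt"):
    with codecs.getwriter("utf-8")(tf.gfile.GFile(path, "wb")) as f:
        for token in ["<unk>", "<s>", "</s>", "the", "of", "a"]:
            f.write("%s\n" % token)
    return path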
def _build_encoder(self, hparams):
    """Build an encoder."""
    num_layers = hparams.num_layers
    num_residual_layers = hparams.num_residual_layers
    iterator = self.iterator

    source = iterator.source
    if self.time_major:
        source = tf.transpose(source)

    with tf.variable_scope("encoder") as scope:
        dtype = scope.dtype
        # Look up embedding, encoder_emb_inp: [max_time, batch_size, num_units]
        encoder_emb_inp = tf.nn.embedding_lookup(self.embedding_encoder,
                                                 source)

        # encoder_outputs: [max_time, batch_size, num_units]
        if hparams.encoder_type == "uni":
            utils.print_out("  num_layers = %d, num_residual_layers=%d"
                            % (num_layers, num_residual_layers))
            cell = self._build_encoder_cell(hparams, num_layers,
                                            num_residual_layers)
            encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
                cell,
                encoder_emb_inp,
                dtype=dtype,
                sequence_length=iterator.source_sequence_length,
                time_major=self.time_major,
                swap_memory=True)
        elif hparams.encoder_type == "bi":
            num_bi_layers = int(num_layers / 2)
            num_bi_residual_layers = int(num_residual_layers / 2)
            utils.print_out("  num_bi_layers = %d, num_bi_residual_layers=%d"
                            % (num_bi_layers, num_bi_residual_layers))

            encoder_outputs, bi_encoder_state = self._build_bidirectional_rnn(
                inputs=encoder_emb_inp,
                sequence_length=iterator.source_sequence_length,
                dtype=dtype,
                hparams=hparams,
                num_bi_layers=num_bi_layers,
                num_bi_residual_layers=num_bi_residual_layers)

            if num_bi_layers == 1:
                encoder_state = bi_encoder_state
            else:
                # Alternately interleave forward and backward states.
                encoder_state = []
                for layer_id in range(num_bi_layers):
                    encoder_state.append(bi_encoder_state[0][layer_id])  # forward
                    encoder_state.append(bi_encoder_state[1][layer_id])  # backward
                encoder_state = tuple(encoder_state)
        else:
            raise ValueError("Unknown encoder_type %s" % hparams.encoder_type)
    return encoder_outputs, encoder_state
def _get_infer_maximum_iterations(self, hparams, source_sequence_length):
    """Maximum decoding steps at inference time."""
    if hparams.tgt_max_len_infer:
        maximum_iterations = hparams.tgt_max_len_infer
        utils.print_out("  decoding maximum_iterations %d"
                        % maximum_iterations)
    else:
        # TODO(thangluong): add decoding_length_factor flag
        decoding_length_factor = 2.0
        max_encoder_length = tf.reduce_max(source_sequence_length)
        maximum_iterations = tf.to_int32(
            tf.round(tf.to_float(max_encoder_length) * decoding_length_factor))
    return maximum_iterations
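# A minimal sketch (hypothetical helper, plain Python) of the arithmetic
# above: with decoding_length_factor = 2.0, a batch whose longest source
# sequence has 50 tokens is decoded for at most round(50 * 2.0) = 100 steps.
def _demo_max_iterations(source_lengths=(12, 50, 33), factor=2.0):
    return int(round(max(source_lengths) * factor))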
def _sample_decode(model, global_step, sess, hparams, iterator, src_data,
                   tgt_data, iterator_src_placeholder,
                   iterator_batch_size_placeholder, summary_writer):
    """Pick a sentence and decode."""
    decode_id = random.randint(0, len(src_data) - 1)
    utils.print_out("  # %d" % decode_id)

    iterator_feed_dict = {
        iterator_src_placeholder: [src_data[decode_id]],
        iterator_batch_size_placeholder: 1,
    }
    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)

    nmt_outputs, attention_summary = model.decode(sess)

    if hparams.beam_width > 0:
        # Get the top summarization from the beam.
        nmt_outputs = nmt_outputs[0]

    summarization = ars_utils.get_summarization(
        nmt_outputs,
        sent_id=0,
        tgt_eos=hparams.eos,
        subword_option=hparams.subword_option)
    utils.print_out("    src: %s" % src_data[decode_id])
    utils.print_out("    ref: %s" % tgt_data[decode_id])
    utils.print_out(b"    ars: " + summarization)

    # Summary
    if attention_summary is not None:
        summary_writer.add_summary(attention_summary, global_step)
def create_or_load_model(model, model_dir, session, name):
    """Create a model and initialize or load its parameters in session."""
    latest_ckpt = tf.train.latest_checkpoint(model_dir)
    if latest_ckpt:
        model = load_model(model, latest_ckpt, session, name)
    else:
        start_time = time.time()
        session.run(tf.global_variables_initializer())
        session.run(tf.tables_initializer())
        utils.print_out(
            "  created %s model with fresh parameters, time %.2fs"
            % (name, time.time() - start_time))

    global_step = model.global_step.eval(session=session)
    return model, global_step
def single_worker_inference(infer_model, ckpt, inference_input_file,
                            inference_output_file, hparams):
    """Inference with a single worker."""
    output_infer = inference_output_file

    # Read data
    infer_data = load_data(inference_input_file, hparams)

    with tf.Session(graph=infer_model.graph,
                    config=utils.get_config_proto()) as sess:
        loaded_infer_model = model_helper.load_model(infer_model.model, ckpt,
                                                     sess, "infer")
        sess.run(
            infer_model.iterator.initializer,
            feed_dict={
                infer_model.src_placeholder: infer_data,
                infer_model.batch_size_placeholder: hparams.infer_batch_size
            })

        # Decode
        utils.print_out("# Start decoding")
        if hparams.inference_indices:
            _decode_inference_indices(
                loaded_infer_model,
                sess,
                output_infer=output_infer,
                output_infer_summary_prefix=output_infer,
                inference_indices=hparams.inference_indices,
                tgt_eos=hparams.eos,
                subword_option=hparams.subword_option)
        else:
            ars_utils.decode_and_evaluate(
                "infer",
                loaded_infer_model,
                sess,
                output_infer,
                ref_file=None,
                metrics=hparams.metrics,
                subword_option=hparams.subword_option,
                beam_width=hparams.beam_width,
                tgt_eos=hparams.eos,
                num_summarizations_per_input=hparams.num_summarizations_per_input)
def build_graph(self, hparams, scope=None):
    """Creates a sequence-to-sequence model with dynamic RNN decoder API.

    Args:
        hparams: Hyperparameter configurations.
        scope: VariableScope for the created subgraph; default
            "dynamic_seq2seq".

    Returns:
        A namedtuple (logits, loss, final_context_state, sample_id), where:
            logits: float32 Tensor [batch_size x num_decoder_symbols].
            loss: the total loss / batch_size.
            final_context_state: the final state of the decoder RNN.
            sample_id: the sampled output ids.

    Raises:
        ValueError: if encoder_type differs from mono and bi, or
            attention_option is not (luong | scaled_luong | bahdanau |
            normed_bahdanau).
    """
    utils.print_out("# creating %s graph ..." % self.mode)
    dtype = tf.float32
    num_layers = hparams.num_layers
    num_gpus = hparams.num_gpus

    with tf.variable_scope(scope or "dynamic_seq2seq", dtype=dtype):
        # Encoder
        encoder_outputs, encoder_state = self._build_encoder(hparams)

        ## Decoder
        logits, sample_id, final_context_state = self._build_decoder(
            encoder_outputs, encoder_state, hparams)

        ## Loss
        if self.mode != tf.contrib.learn.ModeKeys.INFER:
            with tf.device(
                    model_helper.get_device_str(num_layers - 1, num_gpus)):
                loss = self._compute_loss(logits)
        else:
            loss = None

    build_graph_res = collections.namedtuple(
        "build_graph_res",
        ["logits", "loss", "final_context_state", "sample_id"])
    return build_graph_res(logits, loss, final_context_state, sample_id)
def _decode_inference_indices(model, sess, output_infer,
                              output_infer_summary_prefix, inference_indices,
                              tgt_eos, subword_option):
    """Decoding only a specific set of sentences."""
    utils.print_out("  decoding to output %s , num sents %d."
                    % (output_infer, len(inference_indices)))
    start_time = time.time()
    with codecs.getwriter("utf-8")(
            tf.gfile.GFile(output_infer, mode="wb")) as trans_f:
        trans_f.write("")  # Write empty string to ensure file is created.
        for decode_id in inference_indices:
            nmt_outputs, infer_summary = model.decode(sess)

            # Get text summarization
            assert nmt_outputs.shape[0] == 1
            summarization = ars_utils.get_summarization(
                nmt_outputs,
                sent_id=0,
                tgt_eos=tgt_eos,
                subword_option=subword_option)

            if infer_summary is not None:  # Attention models
                image_file = (output_infer_summary_prefix + str(decode_id) +
                              ".png")
                utils.print_out("  save attention image to %s*" % image_file)
                image_summ = tf.Summary()
                image_summ.ParseFromString(infer_summary)
                # Encoded image data is bytes, so write in binary mode.
                with tf.gfile.GFile(image_file, mode="wb") as img_f:
                    img_f.write(image_summ.value[0].image.encoded_image_string)

            trans_f.write("%s\n" % summarization)
            utils.print_out(summarization + b"\n")
    utils.print_time("  done", start_time)
def _create_pretrained_emb_from_txt(vocab_file, embed_file,
                                    num_trainable_tokens=3, dtype=tf.float32,
                                    scope=None):
    """Load a pretrained embedding from embed_file; return an embedding matrix.

    Args:
        vocab_file: Path to the vocab file, one token per line.
        embed_file: Path to a GloVe-formatted embedding txt file.
        num_trainable_tokens: Make the first n tokens in the vocab file
            trainable variables. Default is 3, covering "<unk>", "<s>" and
            "</s>".
    """
    vocab, _ = vocab_utils.load_vocab(vocab_file)
    trainable_tokens = vocab[:num_trainable_tokens]

    utils.print_out("# Using pretrained embedding: %s." % embed_file)
    utils.print_out("  with trainable tokens: ")

    emb_dict, emb_size = vocab_utils.load_embed_txt(embed_file)
    for token in trainable_tokens:
        utils.print_out("    %s" % token)
        if token not in emb_dict:
            emb_dict[token] = [0.0] * emb_size

    emb_mat = np.array(
        [emb_dict[token] for token in vocab], dtype=dtype.as_numpy_dtype())
    emb_mat = tf.constant(emb_mat)
    emb_mat_const = tf.slice(emb_mat, [num_trainable_tokens, 0], [-1, -1])
    with tf.variable_scope(scope or "pretrain_embeddings", dtype=dtype):
        emb_mat_var = tf.get_variable("emb_mat_var",
                                      [num_trainable_tokens, emb_size])
    return tf.concat([emb_mat_var, emb_mat_const], 0)
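# A minimal sketch of the GloVe text format that load_embed_txt is expected
# to parse: one token per line followed by its vector components, separated
# by spaces. This parser is an illustrative stand-in, not the actual
# vocab_utils implementation.
def _demo_parse_glove_line(line="the 0.418 0.24968 -0.41242"):
    parts = line.rstrip().split(" ")
    token, vector = parts[0], [float(x) for x in parts[1:]]
    return token, vector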
def run_main(flags, default_hparams, train_fn, inference_fn,
             target_session=""):
    """Run main."""
    # Job
    jobid = flags.jobid
    num_workers = flags.num_workers
    utils.print_out("# Job id %d" % jobid)

    # Random
    random_seed = flags.random_seed
    if random_seed is not None and random_seed > 0:
        utils.print_out("# Set random seed to %d" % random_seed)
        random.seed(random_seed + jobid)
        np.random.seed(random_seed + jobid)

    ## Train / Decode
    out_dir = flags.out_dir
    if not tf.gfile.Exists(out_dir):
        tf.gfile.MakeDirs(out_dir)

    # Load hparams.
    hparams = create_or_load_hparams(
        out_dir, default_hparams, flags.hparams_path,
        save_hparams=(jobid == 0))

    if flags.inference_input_file:
        # Inference indices
        hparams.inference_indices = None
        if flags.inference_list:
            hparams.inference_indices = [
                int(token) for token in flags.inference_list.split(",")
            ]

        # Inference
        trans_file = flags.inference_output_file
        ckpt = flags.ckpt
        if not ckpt:
            ckpt = tf.train.latest_checkpoint(out_dir)
        inference_fn(ckpt, flags.inference_input_file, trans_file, hparams,
                     num_workers, jobid)
    else:
        # Train
        train_fn(hparams, target_session=target_session)
def ensure_compatible_hparams(hparams, default_hparams, hparams_path):
    """Make sure the loaded hparams are compatible with new changes."""
    default_hparams = utils.maybe_parse_standard_hparams(default_hparams,
                                                         hparams_path)

    # For compatibility, if there are new fields in default_hparams,
    # add them to the current hparams.
    default_config = default_hparams.values()
    config = hparams.values()
    for key in default_config:
        if key not in config:
            hparams.add_hparam(key, default_config[key])

    # Update all hparams' keys if override_loaded_hparams=True
    if default_hparams.override_loaded_hparams:
        for key in default_config:
            if getattr(hparams, key) != default_config[key]:
                utils.print_out(
                    "# Updating hparams.%s: %s -> %s"
                    % (key, str(getattr(hparams, key)),
                       str(default_config[key])))
                setattr(hparams, key, default_config[key])
    return hparams
def check_stats(stats, global_step, steps_per_stats, hparams, log_f):
    """Print statistics and also check for overflow."""
    # Print statistics for the previous stats interval.
    avg_step_time = stats["step_time"] / steps_per_stats
    avg_grad_norm = stats["grad_norm"] / steps_per_stats
    train_ppl = utils.safe_exp(stats["loss"] / stats["predict_count"])
    speed = stats["total_count"] / (1000 * stats["step_time"])
    utils.print_out(
        "  global step %d lr %g "
        "step-time %.2fs wps %.2fK ppl %.2f gN %.2f %s"
        % (global_step, stats["learning_rate"], avg_step_time, speed,
           train_ppl, avg_grad_norm, _get_best_results(hparams)), log_f)

    # Check for overflow
    is_overflow = False
    if math.isnan(train_ppl) or math.isinf(train_ppl) or train_ppl > 1e20:
        utils.print_out("  step %d overflow, stop early" % global_step,
                        log_f)
        is_overflow = True

    return is_overflow
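# A minimal sketch (hypothetical helper) of the perplexity statistic above,
# under the assumption that utils.safe_exp computes
# exp(total_loss / total_predicted_words) and returns inf on overflow;
# wps is words per second reported in thousands.
def _demo_train_ppl(total_loss=52430.0, predict_count=12800):
    try:
        return math.exp(total_loss / predict_count)  # ~60.1 for these values
    except OverflowError:
        return float("inf")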
def _get_learning_rate_warmup(self, hparams):
    """Get learning rate warmup."""
    warmup_steps = hparams.warmup_steps
    warmup_scheme = hparams.warmup_scheme
    utils.print_out("  learning_rate=%g, warmup_steps=%d, warmup_scheme=%s"
                    % (hparams.learning_rate, warmup_steps, warmup_scheme))

    # Apply inverse decay if global steps less than warmup steps.
    # Inspired by https://arxiv.org/pdf/1706.03762.pdf (Section 5.3)
    # When step < warmup_steps,
    #   learning_rate *= warmup_factor ** (warmup_steps - step)
    if warmup_scheme == "t2t":
        # 0.01^(1/warmup_steps): we start with a lr 100 times smaller
        warmup_factor = tf.exp(tf.log(0.01) / warmup_steps)
        inv_decay = warmup_factor**(
            tf.to_float(warmup_steps - self.global_step))
    else:
        raise ValueError("Unknown warmup scheme %s" % warmup_scheme)

    return tf.cond(
        self.global_step < hparams.warmup_steps,
        lambda: inv_decay * self.learning_rate,
        lambda: self.learning_rate,
        name="learning_rate_warmup_cond")
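# A minimal sketch (hypothetical helper, plain Python; assumes this module
# imports math) of the t2t warmup curve above: the effective rate starts at
# lr * 0.01 and rises exponentially to lr at step == warmup_steps.
def _demo_warmup_lr(lr=1.0, warmup_steps=100):
    warmup_factor = math.exp(math.log(0.01) / warmup_steps)
    return [lr * warmup_factor**(warmup_steps - step)
            for step in range(warmup_steps + 1)]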
def extend_hparams(hparams):
    """Extend training hparams."""
    # Sanity checks
    if hparams.encoder_type == "bi" and hparams.num_layers % 2 != 0:
        raise ValueError("For bi, num_layers %d should be even"
                         % hparams.num_layers)
    if (hparams.attention_architecture in ["gnmt"] and
            hparams.num_layers < 2):
        raise ValueError("For gnmt attention architecture, "
                         "num_layers %d should be >= 2" % hparams.num_layers)

    if hparams.subword_option and hparams.subword_option not in ["spm", "bpe"]:
        raise ValueError("subword option must be either spm or bpe")

    # Flags
    utils.print_out("# hparams:")
    utils.print_out("  src=%s" % hparams.src)
    utils.print_out("  tgt=%s" % hparams.tgt)
    utils.print_out("  train_prefix=%s" % hparams.train_prefix)
    utils.print_out("  dev_prefix=%s" % hparams.dev_prefix)
    utils.print_out("  test_prefix=%s" % hparams.test_prefix)
    utils.print_out("  out_dir=%s" % hparams.out_dir)

    # Set num_residual_layers
    if hparams.residual and hparams.num_layers > 1:
        if hparams.encoder_type == "gnmt":
            # The first unidirectional layer (after the bi-directional layer)
            # in the GNMT encoder can't have a residual connection because its
            # input is the concatenation of fw_cell and bw_cell's outputs.
            num_residual_layers = hparams.num_layers - 2
        else:
            num_residual_layers = hparams.num_layers - 1
    else:
        num_residual_layers = 0
    hparams.add_hparam("num_residual_layers", num_residual_layers)

    ## Vocab
    # Get the vocab file name first
    if hparams.vocab_dir:
        vocab_file_path = os.path.join(hparams.vocab_dir,
                                       hparams.vocab_filename)
    else:
        raise ValueError("hparams.vocab_dir must be provided.")

    # Shared source/target vocab
    vocab_size, vocab_file = vocab_utils.check_vocab(
        vocab_file_path,
        hparams.out_dir,
        check_special_token=hparams.check_special_token,
        sos=hparams.sos,
        eos=hparams.eos,
        unk=vocab_utils.UNK)
    hparams.add_hparam("vocab_size", vocab_size)
    hparams.add_hparam("vocab_file", vocab_file)

    # Embedding files
    if hparams.share_emb:
        if tf.gfile.Exists(hparams.src_embed_file):
            utils.print_out("  using source embeddings for target")
            hparams.tgt_embed_file = hparams.src_embed_file
        elif tf.gfile.Exists(hparams.tgt_embed_file):
            utils.print_out("  using target embeddings for source")
            hparams.src_embed_file = hparams.tgt_embed_file
    else:
        if not tf.gfile.Exists(hparams.src_embed_file):
            raise ValueError("source embedding file %s not found"
                             % hparams.src_embed_file)
        if not tf.gfile.Exists(hparams.tgt_embed_file):
            raise ValueError("target embedding file %s not found"
                             % hparams.tgt_embed_file)

    # Check out_dir
    if not tf.gfile.Exists(hparams.out_dir):
        utils.print_out("# Creating output directory %s ..."
                        % hparams.out_dir)
        tf.gfile.MakeDirs(hparams.out_dir)

    # Evaluation
    for metric in hparams.metrics:
        hparams.add_hparam("best_" + metric, 0)  # larger is better
        best_metric_dir = os.path.join(hparams.out_dir, "best_" + metric)
        hparams.add_hparam("best_" + metric + "_dir", best_metric_dir)
        tf.gfile.MakeDirs(best_metric_dir)

    return hparams
def _single_cell(unit_type, num_units, forget_bias, dropout, mode,
                 residual_connection=False, device_str=None,
                 residual_fn=None):
    """Create an instance of a single RNN cell."""
    # dropout (= 1 - keep_prob) is set to 0 during eval and infer
    dropout = dropout if mode == tf.contrib.learn.ModeKeys.TRAIN else 0.0

    # Cell Type
    if unit_type == "lstm":
        utils.print_out("  LSTM, forget_bias=%g" % forget_bias,
                        new_line=False)
        single_cell = tf.contrib.rnn.BasicLSTMCell(num_units,
                                                   forget_bias=forget_bias)
    elif unit_type == "gru":
        utils.print_out("  GRU", new_line=False)
        single_cell = tf.contrib.rnn.GRUCell(num_units)
    elif unit_type == "layer_norm_lstm":
        utils.print_out("  Layer Normalized LSTM, forget_bias=%g"
                        % forget_bias, new_line=False)
        single_cell = tf.contrib.rnn.LayerNormBasicLSTMCell(
            num_units, forget_bias=forget_bias, layer_norm=True)
    elif unit_type == "nas":
        utils.print_out("  NASCell", new_line=False)
        single_cell = tf.contrib.rnn.NASCell(num_units)
    else:
        raise ValueError("Unknown unit type %s!" % unit_type)

    # Dropout (= 1 - keep_prob)
    if dropout > 0.0:
        single_cell = tf.contrib.rnn.DropoutWrapper(
            cell=single_cell, input_keep_prob=(1.0 - dropout))
        utils.print_out("  %s, dropout=%g "
                        % (type(single_cell).__name__, dropout),
                        new_line=False)

    # Residual
    if residual_connection:
        single_cell = tf.contrib.rnn.ResidualWrapper(
            single_cell, residual_fn=residual_fn)
        utils.print_out("  %s" % type(single_cell).__name__, new_line=False)

    # Device Wrapper
    if device_str:
        single_cell = tf.contrib.rnn.DeviceWrapper(single_cell, device_str)
        utils.print_out("  %s, device=%s"
                        % (type(single_cell).__name__, device_str),
                        new_line=False)

    return single_cell
def create_emb_for_encoder_and_decoder(share_emb, vocab_size, src_embed_size,
                                       tgt_embed_size, dtype=tf.float32,
                                       num_partitions=0, vocab_file=None,
                                       src_embed_file=None,
                                       tgt_embed_file=None, scope=None):
    """Create embedding matrices for both encoder and decoder.

    Args:
        share_emb: A boolean. Whether to share the embedding matrix between
            encoder and decoder.
        vocab_size: An integer. The vocab size.
        src_embed_size: An integer. The embedding dimension for the encoder's
            embedding.
        tgt_embed_size: An integer. The embedding dimension for the decoder's
            embedding.
        dtype: dtype of the embedding matrix. Default to float32.
        num_partitions: number of partitions used for the embedding vars.
        vocab_file: Path to the vocab file; required for pretrained
            embeddings.
        src_embed_file: Path to a pretrained source embedding file, if any.
        tgt_embed_file: Path to a pretrained target embedding file, if any.
        scope: VariableScope for the created subgraph. Default to
            "embeddings".

    Returns:
        embedding_encoder: Encoder's embedding matrix.
        embedding_decoder: Decoder's embedding matrix.

    Raises:
        ValueError: if share_emb is set but source and target have different
            embedding sizes.
    """
    if num_partitions <= 1:
        partitioner = None
    else:
        # Note: num_partitions > 1 is required for distributed training
        # because embedding_lookup tries to colocate a single-partition
        # embedding variable with the lookup ops, which may place embedding
        # variables on worker jobs.
        partitioner = tf.fixed_size_partitioner(num_partitions)

    if (src_embed_file or tgt_embed_file) and partitioner:
        raise ValueError(
            "Can't set num_partitions > 1 when using pretrained embeddings")

    with tf.variable_scope(scope or "embeddings", dtype=dtype,
                           partitioner=partitioner):
        # Share embedding
        if share_emb:
            utils.print_out("# Use the same embeddings for source and target")
            embed_file = src_embed_file or tgt_embed_file
            if vocab_file and embed_file:
                if src_embed_size != tgt_embed_size:
                    raise ValueError(
                        "Share embedding but different src/tgt emb sizes"
                        " %d vs. %d" % (src_embed_size, tgt_embed_size))
                embedding = _create_pretrained_emb_from_txt(vocab_file,
                                                            embed_file)
            else:
                embedding = tf.get_variable(
                    "embedding_share", [vocab_size, src_embed_size], dtype)
            embedding_encoder = embedding
            embedding_decoder = embedding
        else:
            with tf.variable_scope("encoder", partitioner=partitioner):
                if vocab_file and src_embed_file:
                    embedding_encoder = _create_pretrained_emb_from_txt(
                        vocab_file, src_embed_file)
                else:
                    embedding_encoder = tf.get_variable(
                        "embedding_encoder", [vocab_size, src_embed_size],
                        dtype)

            with tf.variable_scope("decoder", partitioner=partitioner):
                if vocab_file and tgt_embed_file:
                    embedding_decoder = _create_pretrained_emb_from_txt(
                        vocab_file, tgt_embed_file)
                else:
                    embedding_decoder = tf.get_variable(
                        "embedding_decoder", [vocab_size, tgt_embed_size],
                        dtype)

    return embedding_encoder, embedding_decoder
def multi_worker_inference(infer_model, ckpt, inference_input_file,
                           inference_output_file, hparams, num_workers,
                           jobid):
    """Inference using multiple workers."""
    assert num_workers > 1

    final_output_infer = inference_output_file
    output_infer = "%s_%d" % (inference_output_file, jobid)
    output_infer_done = "%s_done_%d" % (inference_output_file, jobid)

    # Read data
    infer_data = load_data(inference_input_file, hparams)

    # Split data among the workers
    total_load = len(infer_data)
    load_per_worker = int((total_load - 1) / num_workers) + 1
    start_position = jobid * load_per_worker
    end_position = min(start_position + load_per_worker, total_load)
    infer_data = infer_data[start_position:end_position]

    with tf.Session(graph=infer_model.graph,
                    config=utils.get_config_proto()) as sess:
        loaded_infer_model = model_helper.load_model(infer_model.model, ckpt,
                                                     sess, "infer")
        sess.run(
            infer_model.iterator.initializer, {
                infer_model.src_placeholder: infer_data,
                infer_model.batch_size_placeholder: hparams.infer_batch_size
            })

        # Decode
        utils.print_out("# Start decoding")
        ars_utils.decode_and_evaluate(
            "infer",
            loaded_infer_model,
            sess,
            output_infer,
            ref_file=None,
            metrics=hparams.metrics,
            subword_option=hparams.subword_option,
            beam_width=hparams.beam_width,
            tgt_eos=hparams.eos,
            num_summarizations_per_input=hparams.num_summarizations_per_input)

        # Rename the file to indicate that writing is complete.
        tf.gfile.Rename(output_infer, output_infer_done, overwrite=True)

        # Job 0 is responsible for the clean up.
        if jobid != 0:
            return

        # Now write all summarizations
        with codecs.getwriter("utf-8")(
                tf.gfile.GFile(final_output_infer, mode="wb")) as final_f:
            for worker_id in range(num_workers):
                worker_infer_done = "%s_done_%d" % (inference_output_file,
                                                    worker_id)
                while not tf.gfile.Exists(worker_infer_done):
                    utils.print_out("  waiting for job %d to complete."
                                    % worker_id)
                    time.sleep(10)

                with codecs.getreader("utf-8")(
                        tf.gfile.GFile(worker_infer_done, mode="rb")) as f:
                    for summarization in f:
                        final_f.write("%s" % summarization)

            for worker_id in range(num_workers):
                worker_infer_done = "%s_done_%d" % (inference_output_file,
                                                    worker_id)
                tf.gfile.Remove(worker_infer_done)
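# A minimal sketch (hypothetical helper) of the sharding arithmetic above:
# with 1003 inputs and 4 workers, load_per_worker is ceil(1003 / 4) = 251,
# so workers take slices [0:251], [251:502], [502:753], and [753:1003].
def _demo_worker_shard(total_load=1003, num_workers=4, jobid=0):
    load_per_worker = int((total_load - 1) / num_workers) + 1
    start = jobid * load_per_worker
    end = min(start + load_per_worker, total_load)
    return start, end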
def train(hparams, scope=None, target_session=""):
    """Train a summarization model."""
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if not hparams.attention:
        model_creator = nmt_model.Model
    elif hparams.attention_architecture == "standard":
        model_creator = attention_model.AttentionModel
    elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
        model_creator = gnmt_model.GNMTModel
    else:
        raise ValueError("Unknown model architecture")

    train_model = model_helper.create_train_model(model_creator, hparams,
                                                  scope)
    eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    avg_step_time = 0.0

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)

    train_sess = tf.Session(
        target=target_session, config=config_proto, graph=train_model.graph)
    eval_sess = tf.Session(
        target=target_session, config=config_proto, graph=eval_model.graph)
    infer_sess = tf.Session(
        target=target_session, config=config_proto, graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, summary_name), train_model.graph)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    stats = init_stats()
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    utils.print_out(
        "# Start step %d, lr %g, %s"
        % (global_step,
           loaded_train_model.learning_rate.eval(session=train_sess),
           time.ctime()), log_f)

    # Initialize all of the iterators
    skip_count = hparams.batch_size * hparams.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements"
                    % skip_count)
    train_sess.run(
        train_model.iterator.initializer,
        feed_dict={train_model.skip_count_placeholder: skip_count})

    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset. Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation"
                % global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data,
                              sample_tgt_data)
            train_sess.run(
                train_model.iterator.initializer,
                feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary and accumulate statistics
        global_step = update_stats(stats, summary_writer, start_time,
                                   step_result)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = check_stats(stats, global_step, steps_per_stats,
                                      hparams, log_f)
            if is_overflow:
                break

            # Reset statistics
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step
            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(
                train_sess,
                os.path.join(out_dir, "summary.ckpt"),
                global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data,
                              sample_tgt_data)
            dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess,
                                                  model_dir, hparams,
                                                  summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(
                train_sess,
                os.path.join(out_dir, "summary.ckpt"),
                global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data,
                              sample_tgt_data)

    # Done training
    loaded_train_model.saver.save(
        train_sess,
        os.path.join(out_dir, "summary.ckpt"),
        global_step=global_step)
    utils.print_time("# Done training!", start_train_time)
def __init__(self, hparams, mode, iterator, vocab_table,
             reverse_vocab_table=None, scope=None, extra_args=None):
    """Create the model.

    Args:
        hparams: Hyperparameter configurations.
        mode: TRAIN | EVAL | INFER
        iterator: Dataset Iterator that feeds data.
        vocab_table: Lookup table mapping words to ids.
        reverse_vocab_table: Lookup table mapping ids to target words. Only
            required in INFER mode. Defaults to None.
        scope: scope of the model.
        extra_args: model_helper.ExtraArgs, for passing customizable
            functions.
    """
    assert isinstance(iterator, iterator_utils.BatchedInput)
    self.iterator = iterator
    self.mode = mode
    self.vocab_table = vocab_table
    self.vocab_size = hparams.vocab_size
    self.num_layers = hparams.num_layers
    self.num_gpus = hparams.num_gpus
    self.time_major = hparams.time_major

    # extra_args: to make it flexible for adding external customizable code
    self.single_cell_fn = None
    if extra_args:
        self.single_cell_fn = extra_args.single_cell_fn

    # Initializer
    initializer = model_helper.get_initializer(
        hparams.init_op, hparams.random_seed, hparams.init_weight)
    tf.get_variable_scope().set_initializer(initializer)

    # Embeddings
    self.init_embeddings(hparams, scope)
    self.batch_size = tf.size(self.iterator.source_sequence_length)

    # Projection
    with tf.variable_scope(scope or "build_network"):
        with tf.variable_scope("decoder/output_projection"):
            self.output_layer = layers_core.Dense(
                hparams.vocab_size, use_bias=False, name="output_projection")

    ## Train graph
    res = self.build_graph(hparams, scope=scope)
    self.sample_id = res.sample_id

    if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
        self.train_loss = res.loss
        self.word_count = tf.reduce_sum(
            self.iterator.source_sequence_length) + tf.reduce_sum(
                self.iterator.target_sequence_length)
    elif self.mode == tf.contrib.learn.ModeKeys.EVAL:
        self.eval_loss = res.loss
    elif self.mode == tf.contrib.learn.ModeKeys.INFER:
        self.infer_logits, _, self.final_context_state, self.sample_id = res
        self.sample_words = reverse_vocab_table.lookup(
            tf.to_int64(self.sample_id))

    if self.mode != tf.contrib.learn.ModeKeys.INFER:
        ## Count the number of predicted words for computing ppl.
        self.predict_count = tf.reduce_sum(
            self.iterator.target_sequence_length)

    self.global_step = tf.Variable(0, trainable=False)
    params = tf.trainable_variables()

    # Gradients and SGD update operation for training the model.
    # Arrange for the embedding vars to appear at the beginning.
    if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
        self.learning_rate = tf.constant(hparams.learning_rate)
        # warm-up
        self.learning_rate = self._get_learning_rate_warmup(hparams)
        # decay
        self.learning_rate = self._get_learning_rate_decay(hparams)

        # Optimizer
        if hparams.optimizer == "sgd":
            opt = tf.train.GradientDescentOptimizer(self.learning_rate)
            tf.summary.scalar("lr", self.learning_rate)
        elif hparams.optimizer == "adam":
            opt = tf.train.AdamOptimizer(self.learning_rate)

        # Gradients
        gradients = tf.gradients(
            self.train_loss,
            params,
            colocate_gradients_with_ops=hparams.colocate_gradients_with_ops)

        clipped_grads, grad_norm_summary, grad_norm = (
            model_helper.gradient_clip(
                gradients, max_gradient_norm=hparams.max_gradient_norm))
        self.grad_norm = grad_norm

        self.update = opt.apply_gradients(
            zip(clipped_grads, params), global_step=self.global_step)

        # Summary
        self.train_summary = tf.summary.merge([
            tf.summary.scalar("lr", self.learning_rate),
            tf.summary.scalar("train_loss", self.train_loss),
        ] + grad_norm_summary)

    if self.mode == tf.contrib.learn.ModeKeys.INFER:
        self.infer_summary = self._get_infer_summary(hparams)

    # Saver
    self.saver = tf.train.Saver(
        tf.global_variables(), max_to_keep=hparams.num_keep_ckpts)

    # Print trainable variables
    utils.print_out("# Trainable variables")
    for param in params:
        utils.print_out("  %s, %s, %s" % (param.name, str(param.get_shape()),
                                          param.op.device))
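# A minimal sketch of what model_helper.gradient_clip is assumed to do (the
# actual helper also returns gradient-norm summaries): clip by global norm so
# the combined gradient norm never exceeds max_gradient_norm.
def _demo_gradient_clip(gradients, max_gradient_norm=5.0):
    clipped_grads, grad_norm = tf.clip_by_global_norm(gradients,
                                                      max_gradient_norm)
    return clipped_grads, grad_norm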