def before_train(loaded_train_model, train_model, train_sess, global_step,
                 hparams, log_f):
    """Misc tasks to do before training."""
    stats = init_stats()
    info = {
        "train_ppl": 0.0,
        "speed": 0.0,
        "avg_step_time": 0.0,
        "avg_grad_norm": 0.0,
        "avg_sequence_count": 0.0,
        "learning_rate": loaded_train_model.learning_rate.eval(
            session=train_sess)
    }
    start_train_time = time.time()
    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step, info["learning_rate"], time.ctime()), log_f)

    # Initialize all of the iterators
    skip_count = hparams.batch_size * hparams.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
    train_sess.run(
        train_model.iterator.initializer,
        feed_dict={train_model.skip_count_placeholder: skip_count})

    return stats, info, start_train_time

def ensure_compatible_hparams(hparams, default_hparams, hparams_path=""):
    """Make sure the loaded hparams are compatible with new changes."""
    default_hparams = utils.maybe_parse_standard_hparams(
        default_hparams, hparams_path)

    # Set num encoder/decoder layers (for old checkpoints)
    if hasattr(hparams, "num_layers"):
        if not hasattr(hparams, "num_encoder_layers"):
            hparams.add_hparam("num_encoder_layers", hparams.num_layers)
        if not hasattr(hparams, "num_decoder_layers"):
            hparams.add_hparam("num_decoder_layers", hparams.num_layers)

    # For compatibility reasons, if there are new fields in default_hparams,
    # we add them to the current hparams
    default_config = default_hparams.values()
    config = hparams.values()
    for key in default_config:
        if key not in config:
            hparams.add_hparam(key, default_config[key])

    # Update all hparams' keys if override_loaded_hparams=True
    if getattr(default_hparams, "override_loaded_hparams", None):
        overwritten_keys = default_config.keys()
    else:
        # For inference
        overwritten_keys = INFERENCE_KEYS

    for key in overwritten_keys:
        if getattr(hparams, key) != default_config[key]:
            utils.print_out(
                "# Updating hparams.%s: %s -> %s" %
                (key, str(getattr(hparams, key)), str(default_config[key])))
            setattr(hparams, key, default_config[key])
    return hparams

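# Illustrative sketch (not part of the original module; toy values are
# hypothetical): how a field present only in default_hparams gets folded into
# the loaded hparams, mirroring the "add missing keys" loop above.
def _example_hparams_merge():
    from tensorflow.contrib.training import HParams
    loaded = HParams(num_layers=2)
    defaults = HParams(num_layers=2, infer_mode="greedy")
    for key, value in defaults.values().items():
        if key not in loaded.values():
            loaded.add_hparam(key, value)
    assert loaded.infer_mode == "greedy"
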
def _cell_list(unit_type, num_units, num_layers, num_residual_layers,
               forget_bias, dropout, mode, num_gpus, base_gpu=0,
               single_cell_fn=None, residual_fn=None):
    """Create a list of RNN cells."""
    if not single_cell_fn:
        single_cell_fn = _single_cell

    # Multi-GPU
    cell_list = []
    for i in range(num_layers):
        utils.print_out(" cell %d" % i, new_line=False)
        single_cell = single_cell_fn(
            unit_type=unit_type,
            num_units=num_units,
            forget_bias=forget_bias,
            dropout=dropout,
            mode=mode,
            residual_connection=(i >= num_layers - num_residual_layers),
            device_str=get_device_str(i + base_gpu, num_gpus),
            residual_fn=residual_fn)
        utils.print_out("")
        cell_list.append(single_cell)

    return cell_list

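# Sketch of typical usage (an assumption mirroring the usual create_rnn_cell
# pattern, not code from the original module): callers wrap the returned list
# in a MultiRNNCell when there is more than one layer.
def _example_build_multi_cell():
    cells = _cell_list(
        unit_type="lstm", num_units=512, num_layers=4, num_residual_layers=2,
        forget_bias=1.0, dropout=0.2, mode=tf.contrib.learn.ModeKeys.TRAIN,
        num_gpus=1)
    return cells[0] if len(cells) == 1 else tf.contrib.rnn.MultiRNNCell(cells)
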
def print_variables_in_ckpt(ckpt_path):
    """Print a list of variables in a checkpoint together with their shapes."""
    utils.print_out("# Variables in ckpt %s" % ckpt_path)
    reader = tf.train.NewCheckpointReader(ckpt_path)
    variable_map = reader.get_variable_to_shape_map()
    for key in sorted(variable_map.keys()):
        utils.print_out(" %s: %s" % (key, variable_map[key]))

def print_step_info(prefix, global_step, info, result_summary, log_f):
    """Print all info at the current global step."""
    utils.print_out(
        "%sstep %d lr %g step-time %.2fs wps %.2fK ppl %.2f gN %.2f %s, %s" %
        (prefix, global_step, info["learning_rate"], info["avg_step_time"],
         info["speed"], info["train_ppl"], info["avg_grad_norm"],
         result_summary, time.ctime()), log_f)

def quantize_checkpoint(session, ckpt_path):
    """Quantize the currently loaded model and save a checkpoint to ckpt_path."""
    save_list = [tsr for tsr in tf.global_variables()
                 if tsr not in tf.trainable_variables()]
    saver = tf.train.Saver(save_list)
    session.run(
        tf.variables_initializer(tf.get_collection(_QUANTIZATION_COLLECTION)))
    saver.save(session, ckpt_path)
    utils.print_out('Saved quantized checkpoint as %s' % ckpt_path)

def check_vocab(vocab_file, out_dir, check_special_token=True, sos=None,
                eos=None, unk=None):
    """Check that vocab_file exists; prepend special tokens if needed."""
    if tf.gfile.Exists(vocab_file):
        utils.print_out("# Vocab file %s exists" % vocab_file)
        vocab, vocab_size = load_vocab(vocab_file)
        if check_special_token:
            # Verify if the vocab starts with unk, sos, eos
            # If not, prepend those tokens & generate a new vocab file
            if not unk:
                unk = UNK
            if not sos:
                sos = SOS
            if not eos:
                eos = EOS
            assert len(vocab) >= 3
            if vocab[0] != unk or vocab[1] != sos or vocab[2] != eos:
                utils.print_out(
                    "The first 3 vocab words [%s, %s, %s]"
                    " are not [%s, %s, %s]" %
                    (vocab[0], vocab[1], vocab[2], unk, sos, eos))
                vocab = [unk, sos, eos] + vocab
                vocab_size += 3
                new_vocab_file = os.path.join(out_dir,
                                              os.path.basename(vocab_file))
                with codecs.getwriter("utf-8")(
                        tf.gfile.GFile(new_vocab_file, "wb")) as f:
                    for word in vocab:
                        f.write("%s\n" % word)
                vocab_file = new_vocab_file
    else:
        raise ValueError("vocab_file '%s' does not exist." % vocab_file)

    vocab_size = len(vocab)
    return vocab_size, vocab_file

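# Usage sketch (hypothetical paths and tokens, not part of the original
# module): write a vocab that lacks the special tokens, then let check_vocab
# prepend <unk>, <s>, </s> and emit a corrected copy in out_dir.
def _example_check_vocab(vocab_dir="/tmp/vocab_in", out_dir="/tmp/vocab_out"):
    tf.gfile.MakeDirs(vocab_dir)
    tf.gfile.MakeDirs(out_dir)
    vocab_path = os.path.join(vocab_dir, "vocab.demo")
    with codecs.getwriter("utf-8")(tf.gfile.GFile(vocab_path, "wb")) as f:
        f.write("the\nto\nand\n")
    vocab_size, new_path = check_vocab(vocab_path, out_dir)
    print("vocab size %d, file %s" % (vocab_size, new_path))  # size == 6
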
def _external_eval(model, global_step, sess, hparams, iterator,
                   iterator_feed_dict, tgt_file, label, summary_writer,
                   save_on_best, avg_ckpts=False):
    """External evaluation such as BLEU and ROUGE scores."""
    out_dir = hparams.out_dir
    decode = global_step > 0

    if avg_ckpts:
        label = "avg_" + label

    if decode:
        utils.print_out("# External evaluation, global step %d" % global_step)

    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)

    output = os.path.join(out_dir, "output_%s" % label)
    scores = nmt_utils.decode_and_evaluate(
        label,
        model,
        sess,
        output,
        ref_file=tgt_file,
        metrics=hparams.metrics,
        subword_option=hparams.subword_option,
        beam_width=hparams.beam_width,
        tgt_eos=hparams.eos,
        decode=decode,
        infer_mode=hparams.infer_mode)

    # Save on best metrics
    if decode:
        for metric in hparams.metrics:
            if avg_ckpts:
                best_metric_label = "avg_best_" + metric
            else:
                best_metric_label = "best_" + metric

            utils.add_summary(summary_writer, global_step,
                              "%s_%s" % (label, metric), scores[metric])
            # metric: larger is better
            if save_on_best and scores[metric] > getattr(
                    hparams, best_metric_label):
                setattr(hparams, best_metric_label, scores[metric])
                model.saver.save(
                    sess,
                    os.path.join(
                        getattr(hparams, best_metric_label + "_dir"),
                        "translate.ckpt"),
                    global_step=model.global_step)
        utils.save_hparams(out_dir, hparams)
    return scores

def _get_infer_maximum_iterations(self, hparams, source_sequence_length):
    """Maximum decoding steps at inference time."""
    if hparams.tgt_max_len_infer:
        maximum_iterations = hparams.tgt_max_len_infer
        utils.print_out(" decoding maximum_iterations %d" % maximum_iterations)
    else:
        # TODO(thangluong): add decoding_length_factor flag
        decoding_length_factor = 2.0
        max_encoder_length = tf.reduce_max(source_sequence_length)
        maximum_iterations = tf.to_int32(
            tf.round(tf.to_float(max_encoder_length) * decoding_length_factor))
    return maximum_iterations

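# Worked example for the cap above (hypothetical numbers): with
# tgt_max_len_infer unset and a longest source sentence of 20 tokens,
# maximum_iterations = round(20 * 2.0) = 40 decoding steps.
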
def _sample_decode(model, global_step, sess, hparams, iterator, src_data,
                   tgt_data, iterator_src_placeholder,
                   iterator_batch_size_placeholder, summary_writer):
    """Pick a sentence and decode."""
    decode_id = random.randint(0, len(src_data) - 1)
    utils.print_out(" # %d" % decode_id)

    iterator_feed_dict = {
        iterator_src_placeholder: [src_data[decode_id]],
        iterator_batch_size_placeholder: 1,
    }
    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)

    nmt_outputs, attention_summary = model.decode(sess)

    if hparams.infer_mode == "beam_search":
        # get the top translation.
        nmt_outputs = nmt_outputs[0]

    translation = nmt_utils.get_translation(
        nmt_outputs,
        sent_id=0,
        tgt_eos=hparams.eos,
        subword_option=hparams.subword_option)
    utils.print_out(" src: %s" % src_data[decode_id])
    utils.print_out(" ref: %s" % tgt_data[decode_id])
    utils.print_out(b" nmt: " + translation)

    # Summary
    if attention_summary is not None:
        summary_writer.add_summary(attention_summary, global_step)

def create_or_load_model(model, model_dir, session, name):
    """Create translation model and initialize or load parameters in session."""
    latest_ckpt = tf.train.latest_checkpoint(model_dir)
    if latest_ckpt:
        model = load_model(model, latest_ckpt, session, name)
    else:
        start_time = time.time()
        session.run(tf.global_variables_initializer())
        session.run(tf.tables_initializer())
        utils.print_out(" created %s model with fresh parameters, time %.2fs" %
                        (name, time.time() - start_time))

    global_step = model.global_step.eval(session=session)
    return model, global_step

def _get_learning_rate_decay(self, hparams):
    """Get learning rate decay."""
    start_decay_step, decay_steps, decay_factor = self._get_decay_info(hparams)
    utils.print_out(
        " decay_scheme=%s, start_decay_step=%d, decay_steps=%d, "
        "decay_factor=%g" %
        (hparams.decay_scheme, start_decay_step, decay_steps, decay_factor))

    return tf.cond(
        self.global_step < start_decay_step,
        lambda: self.learning_rate,
        lambda: tf.train.exponential_decay(
            self.learning_rate,
            (self.global_step - start_decay_step),
            decay_steps,
            decay_factor,
            staircase=True),
        name="learning_rate_decay_cond")

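# Plain-Python sketch of the staircase schedule above (hypothetical numbers,
# not part of the original module): with start_decay_step=8000,
# decay_steps=1000, decay_factor=0.5, the rate halves every 1000 steps once
# step 8000 is reached.
def _example_decayed_lr(lr, step, start_decay_step=8000, decay_steps=1000,
                        decay_factor=0.5):
    if step < start_decay_step:
        return lr
    return lr * decay_factor ** ((step - start_decay_step) // decay_steps)
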
def process_stats(stats, info, global_step, steps_per_stats, log_f):
    """Update info and check for overflow."""
    # Per-step info
    info["avg_step_time"] = stats["step_time"] / steps_per_stats
    info["avg_grad_norm"] = stats["grad_norm"] / steps_per_stats
    info["avg_sequence_count"] = stats["sequence_count"] / steps_per_stats
    info["speed"] = stats["word_count"] / (1000 * stats["step_time"])

    # Per-predict info
    info["train_ppl"] = (
        utils.safe_exp(stats["train_loss"] / stats["predict_count"]))

    # Check for overflow
    is_overflow = False
    train_ppl = info["train_ppl"]
    if math.isnan(train_ppl) or math.isinf(train_ppl) or train_ppl > 1e20:
        utils.print_out(" step %d overflow, stop early" % global_step, log_f)
        is_overflow = True

    return is_overflow

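# Worked example (hypothetical numbers): accumulating train_loss=4e5 over
# predict_count=1e5 target words gives train_ppl = exp(4.0) ~= 54.6, and
# word_count=2e6 words over step_time=50s gives
# speed = 2e6 / (1000 * 50) = 40K words per second.
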
def load_embed_txt(embed_file):
    """Load embed_file into a python dictionary.

    Note: the embed_file should be a Glove/word2vec formatted txt file. Here is
    an example assuming embed_size=5:

    the -0.071549 0.093459 0.023738 -0.090339 0.056123
    to 0.57346 0.5417 -0.23477 -0.3624 0.4037
    and 0.20327 0.47348 0.050877 0.002103 0.060547

    For word2vec format, the first line will be: <num_words> <emb_size>.

    Args:
      embed_file: file path to the embedding file.
    Returns:
      a dictionary that maps word to vector, and the size of embedding
      dimensions.
    """
    emb_dict = dict()
    emb_size = None

    is_first_line = True
    with codecs.getreader("utf-8")(tf.gfile.GFile(embed_file, "rb")) as f:
        for line in f:
            tokens = line.rstrip().split(" ")
            if is_first_line:
                is_first_line = False
                if len(tokens) == 2:  # header line
                    emb_size = int(tokens[1])
                    continue
            word = tokens[0]
            vec = list(map(float, tokens[1:]))
            emb_dict[word] = vec
            if emb_size:
                if emb_size != len(vec):
                    utils.print_out(
                        "Ignoring %s since embedding size is inconsistent." %
                        word)
                    del emb_dict[word]
            else:
                emb_size = len(vec)
    return emb_dict, emb_size

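# Usage sketch (hypothetical path and values, not part of the original
# module): write a two-word GloVe-style file and load it back.
def _example_load_embed_txt(tmp_path="/tmp/embed.demo.txt"):
    with codecs.getwriter("utf-8")(tf.gfile.GFile(tmp_path, "wb")) as f:
        f.write("the 0.1 0.2 0.3\n")
        f.write("to 0.4 0.5 0.6\n")
    emb_dict, emb_size = load_embed_txt(tmp_path)
    assert emb_size == 3 and len(emb_dict["the"]) == 3
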
def load_quantized_model(model, ckpt_path, session, name):
    """Load a quantized model and dequantize its variables."""
    start_time = time.time()
    dequant_ops = []
    for tsr in tf.trainable_variables():
        with tf.variable_scope(tsr.name.split(':')[0], reuse=True):
            quant_tsr = tf.get_variable('quantized', dtype=tf.qint8)
            min_range = tf.get_variable('min_range')
            max_range = tf.get_variable('max_range')
            dequant_ops.append(
                tsr.assign(
                    tf.dequantize(quant_tsr, min_range, max_range, 'SCALED')))

    restore_list = [tsr for tsr in tf.global_variables()
                    if tsr not in tf.trainable_variables()]
    saver = tf.train.Saver(restore_list)
    try:
        saver.restore(session, ckpt_path)
    except tf.errors.NotFoundError as e:
        utils.print_out("Can't load checkpoint")
        print_variables_in_ckpt(ckpt_path)
        utils.print_out("%s" % str(e))

    session.run(tf.tables_initializer())
    session.run(dequant_ops)
    utils.print_out(
        " loaded %s model parameters from %s, time %.2fs" %
        (name, ckpt_path, time.time() - start_time))
    return model

def _decode_inference_indices(model, sess, output_infer,
                              output_infer_summary_prefix, inference_indices,
                              tgt_eos, subword_option):
    """Decoding only a specific set of sentences."""
    utils.print_out(" decoding to output %s, num sents %d." %
                    (output_infer, len(inference_indices)))
    start_time = time.time()
    with codecs.getwriter("utf-8")(
            tf.gfile.GFile(output_infer, mode="wb")) as trans_f:
        trans_f.write("")  # Write empty string to ensure file is created.
        for decode_id in inference_indices:
            nmt_outputs, infer_summary = model.decode(sess)

            # get text translation
            assert nmt_outputs.shape[0] == 1
            translation = nmt_utils.get_translation(
                nmt_outputs,
                sent_id=0,
                tgt_eos=tgt_eos,
                subword_option=subword_option)

            if infer_summary is not None:  # Attention models
                image_file = (output_infer_summary_prefix + str(decode_id) +
                              ".png")
                utils.print_out(" save attention image to %s*" % image_file)
                image_summ = tf.Summary()
                image_summ.ParseFromString(infer_summary)
                # encoded_image_string is bytes, so write in binary mode.
                with tf.gfile.GFile(image_file, mode="wb") as img_f:  # pylint: disable=no-member
                    img_f.write(image_summ.value[0].image.encoded_image_string)

            trans_f.write("%s\n" % translation)
            utils.print_out(translation + b"\n")
    utils.print_time(" done", start_time)

def single_worker_inference(sess, infer_model, loaded_infer_model,
                            inference_input_file, inference_output_file,
                            hparams):
    """Inference with a single worker."""
    output_infer = inference_output_file

    # Read data
    infer_data = load_data(inference_input_file, hparams)

    with infer_model.graph.as_default():
        sess.run(
            infer_model.iterator.initializer,
            feed_dict={
                infer_model.src_placeholder: infer_data,
                infer_model.batch_size_placeholder: hparams.infer_batch_size
            })
        # Decode
        utils.print_out("# Start decoding")
        if hparams.inference_indices:
            _decode_inference_indices(
                loaded_infer_model,
                sess,
                output_infer=output_infer,
                output_infer_summary_prefix=output_infer,
                inference_indices=hparams.inference_indices,
                tgt_eos=hparams.eos,
                subword_option=hparams.subword_option)
        else:
            nmt_utils.decode_and_evaluate(
                "infer",
                loaded_infer_model,
                sess,
                output_infer,
                ref_file=None,
                metrics=hparams.metrics,
                subword_option=hparams.subword_option,
                beam_width=hparams.beam_width,
                tgt_eos=hparams.eos,
                num_translations_per_input=hparams.num_translations_per_input,
                infer_mode=hparams.infer_mode)

def _create_pretrained_emb_from_txt(vocab_file, embed_file,
                                    num_trainable_tokens=3, dtype=tf.float32,
                                    scope=None):
    """Load a pretrained embedding from embed_file; return an embedding matrix.

    Args:
      vocab_file: Path to the vocabulary file.
      embed_file: Path to a GloVe-formatted embedding txt file.
      num_trainable_tokens: Make the first n tokens in the vocab file
        trainable variables. Default is 3, which is "<unk>", "<s>" and "</s>".
    """
    vocab, _ = vocab_utils.load_vocab(vocab_file)
    trainable_tokens = vocab[:num_trainable_tokens]

    utils.print_out("# Using pretrained embedding: %s." % embed_file)
    utils.print_out(" with trainable tokens: ")

    emb_dict, emb_size = vocab_utils.load_embed_txt(embed_file)
    for token in trainable_tokens:
        utils.print_out(" %s" % token)
        if token not in emb_dict:
            emb_dict[token] = [0.0] * emb_size

    emb_mat = np.array(
        [emb_dict[token] for token in vocab], dtype=dtype.as_numpy_dtype())
    emb_mat = tf.constant(emb_mat)
    emb_mat_const = tf.slice(emb_mat, [num_trainable_tokens, 0], [-1, -1])
    with tf.variable_scope(scope or "pretrain_embeddings",
                           dtype=dtype) as scope:
        with tf.device(_get_embed_device(num_trainable_tokens)):
            emb_mat_var = tf.get_variable(
                "emb_mat_var", [num_trainable_tokens, emb_size])
    return tf.concat([emb_mat_var, emb_mat_const], 0)

def _get_learning_rate_warmup(self, hparams):
    """Get learning rate warmup."""
    warmup_steps = hparams.warmup_steps
    warmup_scheme = hparams.warmup_scheme
    utils.print_out(
        " learning_rate=%g, warmup_steps=%d, warmup_scheme=%s" %
        (hparams.learning_rate, warmup_steps, warmup_scheme))

    # Apply inverse decay if global steps less than warmup steps.
    # Inspired by https://arxiv.org/pdf/1706.03762.pdf (Section 5.3)
    # When step < warmup_steps,
    #   learning_rate *= warmup_factor ** (warmup_steps - step)
    if warmup_scheme == "t2t":
        # 0.01^(1/warmup_steps): we start with a lr, 100 times smaller
        warmup_factor = tf.exp(tf.log(0.01) / warmup_steps)
        inv_decay = warmup_factor**(tf.to_float(warmup_steps -
                                                self.global_step))
    else:
        raise ValueError("Unknown warmup scheme %s" % warmup_scheme)

    return tf.cond(
        self.global_step < hparams.warmup_steps,
        lambda: inv_decay * self.learning_rate,
        lambda: self.learning_rate,
        name="learning_rate_warmup_cond")

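# Plain-Python sketch of the "t2t" warmup above (hypothetical numbers, not
# part of the original module): with warmup_steps=4000,
# warmup_factor = 0.01**(1/4000) ~= 0.99885, so step 0 runs at lr * 0.01 and
# the rate rises smoothly to lr by step 4000.
def _example_t2t_warmup(lr, step, warmup_steps=4000):
    warmup_factor = 0.01 ** (1.0 / warmup_steps)
    if step < warmup_steps:
        return lr * warmup_factor ** (warmup_steps - step)
    return lr
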
def load_model(model, ckpt_path, session, name):
    """Load model from a checkpoint."""
    start_time = time.time()
    try:
        model.saver.restore(session, ckpt_path)
    except tf.errors.NotFoundError as e:
        utils.print_out("Can't load checkpoint")
        print_variables_in_ckpt(ckpt_path)
        utils.print_out("%s" % str(e))

    session.run(tf.tables_initializer())
    utils.print_out(" loaded %s model parameters from %s, time %.2fs" %
                    (name, ckpt_path, time.time() - start_time))
    return model

def _build_encoder(self, hparams):
    """Build a GNMT encoder."""
    if hparams.encoder_type == "uni" or hparams.encoder_type == "bi":
        return super(GNMTModel, self)._build_encoder(hparams)
    if hparams.encoder_type != "gnmt":
        raise ValueError("Unknown encoder_type %s" % hparams.encoder_type)

    # Build GNMT encoder.
    num_bi_layers = 1
    num_uni_layers = self.num_encoder_layers - num_bi_layers
    utils.print_out("# Build a GNMT encoder")
    utils.print_out(" num_bi_layers = %d" % num_bi_layers)
    utils.print_out(" num_uni_layers = %d" % num_uni_layers)

    iterator = self.iterator
    source = iterator.source
    if self.time_major:
        source = tf.transpose(source)

    with tf.variable_scope("encoder") as scope:
        dtype = scope.dtype

        self.encoder_emb_inp = self.encoder_emb_lookup_fn(
            self.embedding_encoder, source)

        # Execute _build_bidirectional_rnn from Model class
        bi_encoder_outputs, bi_encoder_state = self._build_bidirectional_rnn(
            inputs=self.encoder_emb_inp,
            sequence_length=iterator.source_sequence_length,
            dtype=dtype,
            hparams=hparams,
            num_bi_layers=num_bi_layers,
            num_bi_residual_layers=0)  # no residual connection

        # Build unidirectional layers
        if self.extract_encoder_layers:
            encoder_state, encoder_outputs = (
                self._build_individual_encoder_layers(
                    bi_encoder_outputs, num_uni_layers, dtype, hparams))
        else:
            encoder_state, encoder_outputs = self._build_all_encoder_layers(
                bi_encoder_outputs, num_uni_layers, dtype, hparams)

        # Pass all encoder states to the decoder
        # except the first bi-directional layer
        encoder_state = (bi_encoder_state[1],) + (
            (encoder_state,) if num_uni_layers == 1 else encoder_state)

    return encoder_outputs, encoder_state

def avg_checkpoints(model_dir, num_last_checkpoints, global_step,
                    global_step_name):
    """Average the last N checkpoints in the model_dir."""
    checkpoint_state = tf.train.get_checkpoint_state(model_dir)
    if not checkpoint_state:
        utils.print_out("# No checkpoint file found in directory: %s" %
                        model_dir)
        return None

    # Checkpoints are ordered from oldest to newest.
    checkpoints = (
        checkpoint_state.all_model_checkpoint_paths[-num_last_checkpoints:])

    if len(checkpoints) < num_last_checkpoints:
        utils.print_out(
            "# Skipping averaging checkpoints because not enough checkpoints "
            "are available.")
        return None

    avg_model_dir = os.path.join(model_dir, "avg_checkpoints")
    if not tf.gfile.Exists(avg_model_dir):
        utils.print_out(
            "# Creating new directory %s for saving averaged checkpoints." %
            avg_model_dir)
        tf.gfile.MakeDirs(avg_model_dir)

    utils.print_out("# Reading and averaging variables in checkpoints:")
    var_list = tf.contrib.framework.list_variables(checkpoints[0])
    var_values, var_dtypes = {}, {}
    for (name, shape) in var_list:
        if name != global_step_name:
            var_values[name] = np.zeros(shape)
    for checkpoint in checkpoints:
        utils.print_out(" %s" % checkpoint)
        reader = tf.contrib.framework.load_checkpoint(checkpoint)
        for name in var_values:
            tensor = reader.get_tensor(name)
            var_dtypes[name] = tensor.dtype
            var_values[name] += tensor

    for name in var_values:
        var_values[name] /= len(checkpoints)

    # Build a graph with same variables in the checkpoints, and save the
    # averaged variables into the avg_model_dir.
    with tf.Graph().as_default():
        tf_vars = [
            # Index dtypes by v, not by the `name` leaked from the loop above.
            tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[v])
            for v in var_values
        ]

        placeholders = [tf.placeholder(v.dtype, shape=v.shape)
                        for v in tf_vars]
        assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
        tf.Variable(global_step, name=global_step_name, trainable=False)
        saver = tf.train.Saver(tf.global_variables())

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for p, assign_op, (name, value) in zip(
                    placeholders, assign_ops, six.iteritems(var_values)):
                sess.run(assign_op, {p: value})

            # Use the built saver to save the averaged checkpoint. Only keep 1
            # checkpoint and the best checkpoint will be moved to
            # avg_best_metric_dir.
            saver.save(sess, os.path.join(avg_model_dir, "translate.ckpt"))

    return avg_model_dir

def run_main(flags, default_hparams, train_fn, inference_fn,
             target_session=""):
    """Run main."""
    # Job
    jobid = flags.jobid
    num_workers = flags.num_workers
    utils.print_out("# Job id %d" % jobid)

    # GPU device
    utils.print_out("# Devices visible to TensorFlow: %s" %
                    repr(tf.Session().list_devices()))

    # Random
    random_seed = flags.random_seed
    if random_seed is not None and random_seed > 0:
        utils.print_out("# Set random seed to %d" % random_seed)
        random.seed(random_seed + jobid)
        np.random.seed(random_seed + jobid)

    # Model output directory
    out_dir = flags.out_dir
    if out_dir and not tf.gfile.Exists(out_dir):
        utils.print_out("# Creating output directory %s ..." % out_dir)
        tf.gfile.MakeDirs(out_dir)

    # Load hparams.
    loaded_hparams = False
    if flags.ckpt:  # Try to load hparams from the same directory as ckpt
        ckpt_dir = os.path.dirname(flags.ckpt)
        ckpt_hparams_file = os.path.join(ckpt_dir, "hparams")
        if tf.gfile.Exists(ckpt_hparams_file) or flags.hparams_path:
            hparams = create_or_load_hparams(
                ckpt_dir, default_hparams, flags.hparams_path,
                save_hparams=False)
            loaded_hparams = True
    if not loaded_hparams:  # Try to load from out_dir
        assert out_dir
        hparams = create_or_load_hparams(
            out_dir, default_hparams, flags.hparams_path,
            save_hparams=(jobid == 0))

    # Train / Decode
    if flags.inference_input_file:
        # Inference output directory
        trans_file = flags.inference_output_file
        assert trans_file
        trans_dir = os.path.dirname(trans_file)
        if not tf.gfile.Exists(trans_dir):
            tf.gfile.MakeDirs(trans_dir)

        # Inference indices
        hparams.inference_indices = None
        if flags.inference_list:
            hparams.inference_indices = [
                int(token) for token in flags.inference_list.split(",")
            ]

        # Inference
        ckpt = flags.ckpt
        if not ckpt:
            ckpt = tf.train.latest_checkpoint(out_dir)
        inference_fn(ckpt, flags.inference_input_file, trans_file, hparams,
                     num_workers, jobid)

        # Evaluation
        ref_file = flags.inference_ref_file
        if ref_file and tf.gfile.Exists(trans_file):
            for metric in hparams.metrics:
                score = evaluation_utils.evaluate(ref_file, trans_file, metric,
                                                  hparams.subword_option)
                utils.print_out(" %s: %.1f" % (metric, score))
    else:
        # Train
        train_fn(hparams, target_session=target_session)

def _build_decoder(self, encoder_outputs, encoder_state, hparams):
    """Build and run an RNN decoder with a final projection layer.

    Args:
      encoder_outputs: The outputs of encoder for every time step.
      encoder_state: The final state of the encoder.
      hparams: The Hyperparameters configurations.

    Returns:
      A tuple of final logits and final decoder state:
        logits: size [time, batch_size, vocab_size] when time_major=True.
    """
    tgt_sos_id = tf.cast(
        self.tgt_vocab_table.lookup(tf.constant(hparams.sos)), tf.int32)
    tgt_eos_id = tf.cast(
        self.tgt_vocab_table.lookup(tf.constant(hparams.eos)), tf.int32)
    iterator = self.iterator

    # maximum_iterations: The maximum decoding steps.
    maximum_iterations = self._get_infer_maximum_iterations(
        hparams, iterator.source_sequence_length)

    # Decoder.
    with tf.variable_scope("decoder") as decoder_scope:
        cell, decoder_initial_state = self._build_decoder_cell(
            hparams, encoder_outputs, encoder_state,
            iterator.source_sequence_length)

        # Optional ops depend on which mode we are in and which loss function
        # we are using.
        logits = tf.no_op()
        decoder_cell_outputs = None

        # Train or eval
        if self.mode != tf.contrib.learn.ModeKeys.INFER:
            # decoder_emb_inp: [max_time, batch_size, num_units]
            target_input = iterator.target_input
            if self.time_major:
                target_input = tf.transpose(target_input)
            decoder_emb_inp = tf.nn.embedding_lookup(
                self.embedding_decoder, target_input)

            # Helper
            helper = tf.contrib.seq2seq.TrainingHelper(
                decoder_emb_inp, iterator.target_sequence_length,
                time_major=self.time_major)

            # Decoder
            my_decoder = tf.contrib.seq2seq.BasicDecoder(
                cell, helper, decoder_initial_state)

            # Dynamic decoding
            outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(
                my_decoder,
                output_time_major=self.time_major,
                swap_memory=True,
                scope=decoder_scope)

            sample_id = outputs.sample_id

            if self.num_sampled_softmax > 0:
                # Note: this is required when using sampled_softmax_loss.
                decoder_cell_outputs = outputs.rnn_output

            # Note: there's a subtle difference here between train and
            # inference. We could have set output_layer when creating
            # my_decoder and shared more code between train and inference. We
            # chose to apply the output_layer to all timesteps for speed: 10%
            # improvements for small models & 20% for larger ones. If memory
            # is a concern, we should apply output_layer per timestep.
            num_layers = self.num_decoder_layers
            num_gpus = self.num_gpus
            device_id = num_layers if num_layers < num_gpus else (
                num_layers - 1)
            # Colocate output layer with the last RNN cell if there is no
            # extra GPU available. Otherwise, put last layer on a separate GPU.
            with tf.device(model_helper.get_device_str(device_id, num_gpus)):
                logits = self.output_layer(outputs.rnn_output)

            if self.num_sampled_softmax > 0:
                logits = tf.no_op()  # unused when using sampled softmax loss.

        # Inference
        else:
            infer_mode = hparams.infer_mode
            start_tokens = tf.fill([self.batch_size], tgt_sos_id)
            end_token = tgt_eos_id
            utils.print_out(
                " decoder: infer_mode=%s, beam_width=%d, length_penalty=%f" %
                (infer_mode, hparams.beam_width,
                 hparams.length_penalty_weight))

            if infer_mode == "beam_search":
                beam_width = hparams.beam_width
                length_penalty_weight = hparams.length_penalty_weight

                my_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                    cell=cell,
                    embedding=self.embedding_decoder,
                    start_tokens=start_tokens,
                    end_token=end_token,
                    initial_state=decoder_initial_state,
                    beam_width=beam_width,
                    output_layer=self.output_layer,
                    length_penalty_weight=length_penalty_weight)
            elif infer_mode == "sample":
                # Helper
                sampling_temperature = hparams.sampling_temperature
                assert sampling_temperature > 0.0, (
                    "sampling_temperature must be greater than 0.0 when using"
                    " sample decoder.")
                helper = tf.contrib.seq2seq.SampleEmbeddingHelper(
                    self.embedding_decoder,
                    start_tokens,
                    end_token,
                    softmax_temperature=sampling_temperature,
                    seed=self.random_seed)
            elif infer_mode == "greedy":
                helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                    self.embedding_decoder, start_tokens, end_token)
            else:
                raise ValueError("Unknown infer_mode '%s'" % infer_mode)

            if infer_mode != "beam_search":
                my_decoder = tf.contrib.seq2seq.BasicDecoder(
                    cell,
                    helper,
                    decoder_initial_state,
                    output_layer=self.output_layer)  # applied per timestep

            # Dynamic decoding
            outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(
                my_decoder,
                maximum_iterations=maximum_iterations,
                output_time_major=self.time_major,
                swap_memory=True,
                scope=decoder_scope)

            if infer_mode == "beam_search":
                sample_id = outputs.predicted_ids
            else:
                logits = outputs.rnn_output
                sample_id = outputs.sample_id

    return logits, decoder_cell_outputs, sample_id, final_context_state

def build_graph(self, hparams, scope=None):
    """Creates a sequence-to-sequence model with dynamic RNN decoder API.

    Args:
      hparams: Hyperparameter configurations.
      scope: VariableScope for the created subgraph; default "dynamic_seq2seq".

    Returns:
      A tuple of the form (logits, loss_tuple, final_context_state, sample_id),
      where:
        logits: float32 Tensor [batch_size x num_decoder_symbols].
        loss: loss = the total loss / batch_size.
        final_context_state: the final state of decoder RNN.
        sample_id: sampling indices.

    Raises:
      ValueError: if encoder_type differs from uni and bi, or
        attention_option is not (luong | scaled_luong | bahdanau |
        normed_bahdanau).
    """
    utils.print_out("# Creating %s graph ..." % self.mode)

    # Projection
    if not self.extract_encoder_layers:
        with tf.variable_scope(scope or "build_network"):
            with tf.variable_scope("decoder/output_projection"):
                if hparams.projection_type == 'sparse':
                    self.output_layer = core_layers.MaskedFullyConnected(
                        hparams.tgt_vocab_size,
                        use_bias=False,
                        name="output_projection")
                elif hparams.projection_type == 'dense':
                    self.output_layer = tf.layers.Dense(
                        hparams.tgt_vocab_size,
                        use_bias=False,
                        name="output_projection")
                else:
                    raise ValueError("Unknown projection type %s!" %
                                     hparams.projection_type)

    with tf.variable_scope(scope or "dynamic_seq2seq", dtype=self.dtype):
        # Encoder
        if hparams.language_model:  # no encoder for language modeling
            utils.print_out(" language modeling: no encoder")
            self.encoder_outputs = None
            encoder_state = None
        else:
            self.encoder_outputs, encoder_state = self._build_encoder(hparams)

        # Skip decoder if extracting only encoder layers
        if self.extract_encoder_layers:
            return

        # Decoder
        logits, decoder_cell_outputs, sample_id, final_context_state = (
            self._build_decoder(self.encoder_outputs, encoder_state, hparams))

        # Loss
        if self.mode != tf.contrib.learn.ModeKeys.INFER:
            with tf.device(
                    model_helper.get_device_str(self.num_encoder_layers - 1,
                                                self.num_gpus)):
                loss = self._compute_loss(logits, decoder_cell_outputs)
        else:
            loss = tf.constant(0.0)

        # model pruning
        if hparams.pruning_hparams is not None:
            pruning_hparams = pruning.get_pruning_hparams().parse(
                hparams.pruning_hparams)
            self.p = pruning.Pruning(
                pruning_hparams, global_step=self.global_step)
            self.mask_update_op = self.p.conditional_mask_update_op()
            masks = get_masks()
            thresholds = get_thresholds()
            masks_s = []
            for index, mask in enumerate(masks):
                masks_s.append(
                    tf.summary.scalar(mask.name + '/sparsity',
                                      tf.nn.zero_fraction(mask)))
                masks_s.append(
                    tf.summary.scalar(
                        thresholds[index].op.name + '/threshold',
                        thresholds[index]))
                masks_s.append(
                    tf.summary.histogram(mask.name + '/mask_tensor', mask))
            self.pruning_summary = tf.summary.merge(
                [
                    tf.summary.scalar('sparsity', self.p._sparsity),
                    tf.summary.scalar('last_mask_update_step',
                                      self.p._last_update_step)
                ] + masks_s)
        else:
            self.mask_update_op = tf.no_op()
            self.pruning_summary = tf.no_op()

        return logits, loss, final_context_state, sample_id

def _build_encoder_from_sequence(self, hparams, sequence, sequence_length):
    """Build an encoder from a sequence.

    Args:
      hparams: hyperparameters.
      sequence: tensor with input sequence data.
      sequence_length: tensor with length of the input sequence.

    Returns:
      encoder_outputs: RNN encoder outputs.
      encoder_state: RNN encoder state.

    Raises:
      ValueError: if encoder_type is neither "uni" nor "bi".
    """
    num_layers = self.num_encoder_layers
    num_residual_layers = self.num_encoder_residual_layers

    if self.time_major:
        sequence = tf.transpose(sequence)

    with tf.variable_scope("encoder") as scope:
        dtype = scope.dtype

        self.encoder_emb_inp = self.encoder_emb_lookup_fn(
            self.embedding_encoder, sequence)

        # Encoder_outputs: [max_time, batch_size, num_units]
        if hparams.encoder_type == "uni":
            utils.print_out(" num_layers = %d, num_residual_layers=%d" %
                            (num_layers, num_residual_layers))
            cell = self._build_encoder_cell(hparams, num_layers,
                                            num_residual_layers)

            encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
                cell,
                self.encoder_emb_inp,
                dtype=dtype,
                sequence_length=sequence_length,
                time_major=self.time_major,
                swap_memory=True)
        elif hparams.encoder_type == "bi":
            num_bi_layers = int(num_layers / 2)
            num_bi_residual_layers = int(num_residual_layers / 2)
            utils.print_out(" num_bi_layers = %d, num_bi_residual_layers=%d" %
                            (num_bi_layers, num_bi_residual_layers))

            encoder_outputs, bi_encoder_state = (
                self._build_bidirectional_rnn(
                    inputs=self.encoder_emb_inp,
                    sequence_length=sequence_length,
                    dtype=dtype,
                    hparams=hparams,
                    num_bi_layers=num_bi_layers,
                    num_bi_residual_layers=num_bi_residual_layers))

            if num_bi_layers == 1:
                encoder_state = bi_encoder_state
            else:
                # alternately interleave forward and backward states
                encoder_state = []
                for layer_id in range(num_bi_layers):
                    encoder_state.append(
                        bi_encoder_state[0][layer_id])  # forward
                    encoder_state.append(
                        bi_encoder_state[1][layer_id])  # backward
                encoder_state = tuple(encoder_state)
        else:
            raise ValueError("Unknown encoder_type %s" % hparams.encoder_type)

    # Use the top layer for now
    self.encoder_state_list = [encoder_outputs]

    return encoder_outputs, encoder_state

def _build_encoder(self, hparams):
    """Build encoder from source."""
    utils.print_out("# Build a basic encoder")
    return self._build_encoder_from_sequence(
        hparams, self.iterator.source, self.iterator.source_sequence_length)

def _single_cell(unit_type, num_units, forget_bias, dropout, mode,
                 residual_connection=False, device_str=None, residual_fn=None):
    """Create an instance of a single RNN cell."""
    # dropout (= 1 - keep_prob) is set to 0 during eval and infer
    dropout = dropout if mode == tf.contrib.learn.ModeKeys.TRAIN else 0.0

    # Cell Type
    if unit_type == "lstm":
        utils.print_out(" LSTM, forget_bias=%g" % forget_bias, new_line=False)
        single_cell = tf.contrib.rnn.BasicLSTMCell(
            num_units, forget_bias=forget_bias)
    elif unit_type == "gru":
        utils.print_out(" GRU", new_line=False)
        single_cell = tf.contrib.rnn.GRUCell(num_units)
    elif unit_type == "layer_norm_lstm":
        utils.print_out(" Layer Normalized LSTM, forget_bias=%g" % forget_bias,
                        new_line=False)
        single_cell = tf.contrib.rnn.LayerNormBasicLSTMCell(
            num_units, forget_bias=forget_bias, layer_norm=True)
    elif unit_type == "nas":
        utils.print_out(" NASCell", new_line=False)
        single_cell = tf.contrib.rnn.NASCell(num_units)
    elif unit_type == "mlstm":
        utils.print_out(" Masked_LSTM, forget_bias=%g" % forget_bias,
                        new_line=False)
        single_cell = tf.contrib.model_pruning.MaskedBasicLSTMCell(
            num_units, forget_bias=forget_bias)
    else:
        raise ValueError("Unknown unit type %s!" % unit_type)

    # Dropout (= 1 - keep_prob)
    if dropout > 0.0:
        single_cell = tf.contrib.rnn.DropoutWrapper(
            cell=single_cell, input_keep_prob=(1.0 - dropout))
        utils.print_out(" %s, dropout=%g " %
                        (type(single_cell).__name__, dropout),
                        new_line=False)

    # Residual
    if residual_connection:
        single_cell = tf.contrib.rnn.ResidualWrapper(
            single_cell, residual_fn=residual_fn)
        utils.print_out(" %s" % type(single_cell).__name__, new_line=False)

    # Device Wrapper
    if device_str:
        single_cell = tf.contrib.rnn.DeviceWrapper(single_cell, device_str)
        utils.print_out(" %s, device=%s" %
                        (type(single_cell).__name__, device_str),
                        new_line=False)

    return single_cell

def create_emb_for_encoder_and_decoder(share_vocab,
                                       src_vocab_size,
                                       tgt_vocab_size,
                                       src_embed_size,
                                       tgt_embed_size,
                                       embed_type='dense',
                                       dtype=tf.float32,
                                       num_enc_partitions=0,
                                       num_dec_partitions=0,
                                       src_vocab_file=None,
                                       tgt_vocab_file=None,
                                       src_embed_file=None,
                                       tgt_embed_file=None,
                                       use_char_encode=False,
                                       scope=None):
    """Create embedding matrix for both encoder and decoder.

    Args:
      share_vocab: A boolean. Whether to share embedding matrix for both
        encoder and decoder.
      src_vocab_size: An integer. The source vocab size.
      tgt_vocab_size: An integer. The target vocab size.
      src_embed_size: An integer. The embedding dimension for the encoder's
        embedding.
      tgt_embed_size: An integer. The embedding dimension for the decoder's
        embedding.
      embed_type: 'dense' or 'sparse' embedding variables.
      dtype: dtype of the embedding matrix. Default to float32.
      num_enc_partitions: number of partitions used for the encoder's
        embedding vars.
      num_dec_partitions: number of partitions used for the decoder's
        embedding vars.
      src_vocab_file: Path to the source vocabulary file.
      tgt_vocab_file: Path to the target vocabulary file.
      src_embed_file: Path to a pretrained source embedding file, if any.
      tgt_embed_file: Path to a pretrained target embedding file, if any.
      use_char_encode: A boolean. If set, no encoder embedding is created.
      scope: VariableScope for the created subgraph. Default to "embedding".

    Returns:
      embedding_encoder: Encoder's embedding matrix.
      embedding_decoder: Decoder's embedding matrix.

    Raises:
      ValueError: if share_vocab is set but source and target have different
        vocab sizes.
    """
    if num_enc_partitions <= 1:
        enc_partitioner = None
    else:
        # Note: num_partitions > 1 is required for distributed training
        # because embedding_lookup tries to colocate the single-partition
        # embedding variable with the lookup ops. This may cause embedding
        # variables being placed on worker jobs.
        enc_partitioner = tf.fixed_size_partitioner(num_enc_partitions)

    if num_dec_partitions <= 1:
        dec_partitioner = None
    else:
        # Same rationale as for enc_partitioner above.
        dec_partitioner = tf.fixed_size_partitioner(num_dec_partitions)

    if src_embed_file and enc_partitioner:
        raise ValueError(
            "Can't set num_enc_partitions > 1 when using pretrained encoder "
            "embedding")

    if tgt_embed_file and dec_partitioner:
        raise ValueError(
            "Can't set num_dec_partitions > 1 when using pretrained decoder "
            "embedding")

    with tf.variable_scope(scope or "embeddings", dtype=dtype,
                           partitioner=enc_partitioner):
        # Share embedding
        if share_vocab:
            if src_vocab_size != tgt_vocab_size:
                raise ValueError(
                    "Share embedding but different src/tgt vocab sizes"
                    " %d vs. %d" % (src_vocab_size, tgt_vocab_size))
            assert src_embed_size == tgt_embed_size
            utils.print_out("# Use the same embedding for source and target")
            vocab_file = src_vocab_file or tgt_vocab_file
            embed_file = src_embed_file or tgt_embed_file

            embedding_encoder = _create_or_load_embed(
                "embedding_share", vocab_file, embed_file, src_vocab_size,
                src_embed_size, dtype, embed_type=embed_type)
            embedding_decoder = embedding_encoder
        else:
            if not use_char_encode:
                with tf.variable_scope("encoder",
                                       partitioner=enc_partitioner):
                    embedding_encoder = _create_or_load_embed(
                        "embedding_encoder", src_vocab_file, src_embed_file,
                        src_vocab_size, src_embed_size, dtype,
                        embed_type=embed_type)
            else:
                embedding_encoder = None

            with tf.variable_scope("decoder", partitioner=dec_partitioner):
                embedding_decoder = _create_or_load_embed(
                    "embedding_decoder", tgt_vocab_file, tgt_embed_file,
                    tgt_vocab_size, tgt_embed_size, dtype,
                    embed_type=embed_type)

    return embedding_encoder, embedding_decoder

def _set_train_or_infer(self, res, reverse_target_vocab_table, hparams):
    """Set up training and inference."""
    if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
        self.train_loss = res[1]
        self.word_count = tf.reduce_sum(
            self.iterator.source_sequence_length) + tf.reduce_sum(
                self.iterator.target_sequence_length)
    elif self.mode == tf.contrib.learn.ModeKeys.EVAL:
        self.eval_loss = res[1]
    elif self.mode == tf.contrib.learn.ModeKeys.INFER:
        self.infer_logits, _, self.final_context_state, self.sample_id = res
        self.sample_words = reverse_target_vocab_table.lookup(
            tf.to_int64(self.sample_id))

    if self.mode != tf.contrib.learn.ModeKeys.INFER:
        # Count the number of predicted words to compute ppl.
        self.predict_count = tf.reduce_sum(
            self.iterator.target_sequence_length)

    params = tf.trainable_variables()

    # Gradients and SGD update operation for training the model.
    # Arrange for the embedding vars to appear at the beginning.
    if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
        self.learning_rate = tf.constant(hparams.learning_rate)
        # warm-up
        self.learning_rate = self._get_learning_rate_warmup(hparams)
        # decay
        self.learning_rate = self._get_learning_rate_decay(hparams)

        # Optimizer
        if hparams.optimizer == "sgd":
            opt = tf.train.GradientDescentOptimizer(self.learning_rate)
        elif hparams.optimizer == "adam":
            opt = tf.train.AdamOptimizer(self.learning_rate)
        else:
            raise ValueError("Unknown optimizer type %s" % hparams.optimizer)

        # Gradients
        gradients = tf.gradients(
            self.train_loss,
            params,
            colocate_gradients_with_ops=hparams.colocate_gradients_with_ops)

        clipped_grads, grad_norm_summary, grad_norm = (
            model_helper.gradient_clip(
                gradients, max_gradient_norm=hparams.max_gradient_norm))
        self.grad_norm_summary = grad_norm_summary
        self.grad_norm = grad_norm

        self.update = opt.apply_gradients(
            zip(clipped_grads, params), global_step=self.global_step)

        # Summary
        self.train_summary = self._get_train_summary()
    elif self.mode == tf.contrib.learn.ModeKeys.INFER:
        self.infer_summary = self._get_infer_summary(hparams)

    # Print trainable variables
    utils.print_out("# Trainable variables")
    utils.print_out("Format: <name>, <shape>, <(soft) device placement>")
    for param in params:
        utils.print_out(" %s, %s, %s" %
                        (param.name, str(param.get_shape()), param.op.device))