def _decode_target( self, labels, encoder_outputs, encoder_state, encoder_sequence_length, step=None, training=None, ): params = self.params target_inputs = self.labels_inputter(labels, training=training) input_fn = lambda ids: self.labels_inputter({"ids": ids}, training=training) sampling_probability = None if training: sampling_probability = decoder_util.get_sampling_probability( step, read_probability=params.get( "scheduled_sampling_read_probability"), schedule_type=params.get("scheduled_sampling_type"), k=params.get("scheduled_sampling_k"), ) initial_state = self.decoder.initial_state( memory=encoder_outputs, memory_sequence_length=encoder_sequence_length, initial_state=encoder_state, ) logits, _, attention = self.decoder( target_inputs, self.labels_inputter.get_length(labels), state=initial_state, input_fn=input_fn, sampling_probability=sampling_probability, training=training, ) outputs = dict(logits=logits, attention=attention) noisy_ids = labels.get("noisy_ids") if noisy_ids is not None and params.get("contrastive_learning"): # In case of contrastive learning, also forward the erroneous # translation to compute its log likelihood later. noisy_inputs = self.labels_inputter({"ids": noisy_ids}, training=training) noisy_logits, _, _ = self.decoder( noisy_inputs, labels["noisy_length"], state=initial_state, input_fn=input_fn, sampling_probability=sampling_probability, training=training, ) outputs["noisy_logits"] = noisy_logits return outputs
def _build(self, features, labels, params, mode, config=None):
    """Build the sequence-to-sequence graph for an estimator ``mode``.

    The features are embedded and encoded. When ``labels`` is set, the decoder
    is run on the reference target inputs. When ``mode`` is not TRAIN, the
    decoder is also run for inference: ``dynamic_decode`` when ``beam_width``
    is at most 1, otherwise ``dynamic_decode_and_search``.

    Args:
      features: Source features passed to the source inputter.
      labels: Target features, or ``None`` (e.g. at prediction time).
      params: Dictionary of hyperparameters, read with ``params.get``.
      mode: A ``tf.estimator.ModeKeys`` value.
      config: Optional run configuration; its ``model_dir`` is forwarded as
        ``log_dir`` to the inputters.

    Returns:
      A tuple ``(outputs, predictions)``:
        - ``outputs`` is ``None`` without labels. Otherwise it is the decoder
          logits, or a dict with ``"logits"`` and ``"attention"`` when
          ``labels`` contains an ``"alignment"`` entry.
        - ``predictions`` is ``None`` in TRAIN mode. Otherwise it is a dict with
          ``"tokens"``, ``"length"``, ``"log_probs"`` and, when the decoder
          returned one, ``"alignment"``.

    Raises:
      TypeError: If ``replace_unknown_target`` is enabled but the decoder
        returned no alignment history, or the source inputter is not a
        ``WordEmbedder``.
    """
    features_length = self.features_inputter.get_length(features)
    # The run configuration is optional: without it, no log directory is used.
    log_dir = config.model_dir if config is not None else None

    source_input_scope = self._get_input_scope(default_name="encoder")
    target_input_scope = self._get_input_scope(default_name="decoder")

    # NOTE(review): _maybe_reuse_embedding_fn presumably wraps the function so
    # that its embedding variables are created/reused under ``scope`` —
    # confirm against its definition.
    source_inputs = _maybe_reuse_embedding_fn(
        lambda ids: self.source_inputter.transform_data(
            ids, mode=mode, log_dir=log_dir),
        scope=source_input_scope)(features)
    with tf.variable_scope("encoder"):
        encoder_outputs, encoder_state, encoder_sequence_length = self.encoder.encode(
            source_inputs, sequence_length=features_length, mode=mode)

    target_vocab_size = self.target_inputter.vocabulary_size
    target_dtype = self.target_inputter.dtype
    # Embedding function given to the decoder for ids (used both in training
    # and in inference decoding below); it shares target_input_scope with the
    # embedding of the reference target inputs.
    target_embedding_fn = _maybe_reuse_embedding_fn(
        lambda ids: self.target_inputter.make_inputs(
            {"ids": ids}, training=mode == tf.estimator.ModeKeys.TRAIN),
        scope=target_input_scope)

    if labels is not None:
        target_inputs = _maybe_reuse_embedding_fn(
            lambda ids: self.target_inputter.transform_data(
                ids, mode=mode, log_dir=log_dir),
            scope=target_input_scope)(labels)
        with tf.variable_scope("decoder"):
            # Scheduled sampling is only enabled in TRAIN mode.
            sampling_probability = None
            if mode == tf.estimator.ModeKeys.TRAIN:
                sampling_probability = get_sampling_probability(
                    tf.train.get_or_create_global_step(),
                    read_probability=params.get(
                        "scheduled_sampling_read_probability"),
                    schedule_type=params.get("scheduled_sampling_type"),
                    k=params.get("scheduled_sampling_k"))
            # NOTE(review): the target length comes from labels_inputter while
            # the rest of this block uses target_inputter — presumably the same
            # inputter; confirm.
            logits, _, _, attention = self.decoder.decode(
                target_inputs,
                self.labels_inputter.get_length(labels),
                vocab_size=target_vocab_size,
                initial_state=encoder_state,
                sampling_probability=sampling_probability,
                embedding=target_embedding_fn,
                mode=mode,
                memory=encoder_outputs,
                memory_sequence_length=encoder_sequence_length,
                return_alignment_history=True)
            # The attention is only exposed when reference alignments are
            # provided (presumably for an alignment loss — confirm in the loss).
            if "alignment" in labels:
                outputs = {"logits": logits, "attention": attention}
            else:
                outputs = logits
    else:
        outputs = None

    if mode != tf.estimator.ModeKeys.TRAIN:
        # Reuse the "decoder" variables when the branch above already created
        # them, i.e. when labels are available.
        with tf.variable_scope("decoder", reuse=labels is not None):
            batch_size = tf.shape(
                tf.contrib.framework.nest.flatten(encoder_outputs)[0])[0]
            beam_width = params.get("beam_width", 1)
            maximum_iterations = params.get("maximum_iterations", 250)
            minimum_length = params.get("minimum_decoding_length", 0)
            sample_from = params.get("sampling_topk", 1)
            start_tokens = tf.fill([batch_size], constants.START_OF_SENTENCE_ID)
            end_token = constants.END_OF_SENTENCE_ID

            if beam_width <= 1:
                sampled_ids, _, sampled_length, log_probs, alignment = self.decoder.dynamic_decode(
                    target_embedding_fn,
                    start_tokens,
                    end_token,
                    vocab_size=target_vocab_size,
                    initial_state=encoder_state,
                    maximum_iterations=maximum_iterations,
                    minimum_length=minimum_length,
                    mode=mode,
                    memory=encoder_outputs,
                    memory_sequence_length=encoder_sequence_length,
                    dtype=target_dtype,
                    return_alignment_history=True,
                    sample_from=sample_from)
            else:
                length_penalty = params.get("length_penalty", 0)
                sampled_ids, _, sampled_length, log_probs, alignment = (
                    self.decoder.dynamic_decode_and_search(
                        target_embedding_fn,
                        start_tokens,
                        end_token,
                        vocab_size=target_vocab_size,
                        initial_state=encoder_state,
                        beam_width=beam_width,
                        length_penalty=length_penalty,
                        maximum_iterations=maximum_iterations,
                        minimum_length=minimum_length,
                        mode=mode,
                        memory=encoder_outputs,
                        memory_sequence_length=encoder_sequence_length,
                        dtype=target_dtype,
                        return_alignment_history=True,
                        sample_from=sample_from))

        # Map the decoded ids back to token strings.
        target_vocab_rev = self.target_inputter.vocabulary_lookup_reverse()
        target_tokens = target_vocab_rev.lookup(
            tf.cast(sampled_ids, tf.int64))

        if params.get("replace_unknown_target", False):
            if alignment is None:
                raise TypeError(
                    "replace_unknown_target is not compatible with decoders "
                    "that don't return alignment history")
            if not isinstance(self.source_inputter, inputters.WordEmbedder):
                raise TypeError(
                    "replace_unknown_target is only defined when the source "
                    "inputter is a WordEmbedder")
            source_tokens = features["tokens"]
            if beam_width > 1:
                # Repeat each source entry beam_width times to line up with
                # the beam hypotheses.
                source_tokens = tf.contrib.seq2seq.tile_batch(
                    source_tokens, multiplier=beam_width)
            # Merge batch and beam dimensions.
            original_shape = tf.shape(target_tokens)
            target_tokens = tf.reshape(target_tokens, [-1, original_shape[-1]])
            # NOTE(review): the alignment time dimension is not reconciled with
            # target_tokens' time dimension here (e.g. no attention step for
            # </s>); verify replace_unknown_target tolerates a mismatch.
            attention = tf.reshape(
                alignment, [-1, tf.shape(alignment)[2], tf.shape(alignment)[3]])
            replaced_target_tokens = replace_unknown_target(
                target_tokens, source_tokens, attention)
            target_tokens = tf.reshape(replaced_target_tokens, original_shape)

        predictions = {
            "tokens": target_tokens,
            "length": sampled_length,
            "log_probs": log_probs
        }
        if alignment is not None:
            predictions["alignment"] = alignment
    else:
        predictions = None

    return outputs, predictions
def _call(self, features, labels, params, mode):
    """Build the sequence-to-sequence graph for an estimator ``mode``.

    The features are embedded and encoded. When ``labels`` is set, the decoder
    is run on the reference target inputs. When ``mode`` is not TRAIN, inference
    always goes through ``dynamic_decode_and_search`` (``beam_width`` defaults
    to 1).

    Args:
      features: Source features passed to ``self.features_inputter``.
      labels: Target features, or ``None`` (e.g. at prediction time).
      params: Dictionary of hyperparameters, read with ``params.get``.
      mode: A ``tf.estimator.ModeKeys`` value.

    Returns:
      A tuple ``(outputs, predictions)``:
        - ``outputs`` is ``None`` without labels. Otherwise it is the decoder
          logits, or a dict with ``"logits"`` and ``"attention"`` when
          ``labels`` contains an ``"alignment"`` entry.
        - ``predictions`` is ``None`` in TRAIN mode. Otherwise it is a dict with
          ``"tokens"``, ``"length"``, ``"log_probs"`` and, when the decoder
          returned one, ``"alignment"``.

    Raises:
      TypeError: If ``replace_unknown_target`` is enabled but the decoder
        returned no alignment history, or the features inputter is not a
        ``WordEmbedder``.
    """
    training = mode == tf.estimator.ModeKeys.TRAIN

    features_length = self.features_inputter.get_length(features)
    source_inputs = self.features_inputter.make_inputs(features, training=training)
    with tf.variable_scope("encoder"):
        encoder_outputs, encoder_state, encoder_sequence_length = self.encoder.encode(
            source_inputs, sequence_length=features_length, mode=mode)

    target_vocab_size = self.labels_inputter.vocabulary_size
    target_dtype = self.labels_inputter.dtype
    if labels is not None:
        target_inputs = self.labels_inputter.make_inputs(labels, training=training)
        with tf.variable_scope("decoder"):
            # Scheduled sampling is only enabled in TRAIN mode.
            sampling_probability = None
            if mode == tf.estimator.ModeKeys.TRAIN:
                sampling_probability = get_sampling_probability(
                    tf.train.get_or_create_global_step(),
                    read_probability=params.get(
                        "scheduled_sampling_read_probability"),
                    schedule_type=params.get("scheduled_sampling_type"),
                    k=params.get("scheduled_sampling_k"))
            logits, _, _, attention = self.decoder.decode(
                target_inputs,
                self.labels_inputter.get_length(labels),
                vocab_size=target_vocab_size,
                initial_state=encoder_state,
                sampling_probability=sampling_probability,
                embedding=self.labels_inputter.embedding,
                output_layer=self.output_layer,
                mode=mode,
                memory=encoder_outputs,
                memory_sequence_length=encoder_sequence_length,
                return_alignment_history=True)
            # The attention is only exposed when reference alignments are
            # provided (presumably for an alignment loss — confirm in the loss).
            if "alignment" in labels:
                outputs = {"logits": logits, "attention": attention}
            else:
                outputs = logits
    else:
        outputs = None

    if mode != tf.estimator.ModeKeys.TRAIN:
        # Reuse the "decoder" variables when the branch above already created
        # them, i.e. when labels are available.
        with tf.variable_scope("decoder", reuse=labels is not None):
            batch_size = tf.shape(
                tf.contrib.framework.nest.flatten(encoder_outputs)[0])[0]
            beam_width = params.get("beam_width", 1)
            start_tokens = tf.fill([batch_size], constants.START_OF_SENTENCE_ID)
            end_token = constants.END_OF_SENTENCE_ID
            sampled_ids, _, sampled_length, log_probs, alignment = (
                self.decoder.dynamic_decode_and_search(
                    self.labels_inputter.embedding,
                    start_tokens,
                    end_token,
                    vocab_size=target_vocab_size,
                    initial_state=encoder_state,
                    output_layer=self.output_layer,
                    beam_width=beam_width,
                    length_penalty=params.get("length_penalty", 0),
                    maximum_iterations=params.get("maximum_iterations", 250),
                    minimum_length=params.get("minimum_decoding_length", 0),
                    mode=mode,
                    memory=encoder_outputs,
                    memory_sequence_length=encoder_sequence_length,
                    dtype=target_dtype,
                    return_alignment_history=True,
                    sample_from=params.get("sampling_topk"),
                    sample_temperature=params.get("sampling_temperature")))
        # Map the decoded ids back to token strings.
        target_vocab_rev = self.labels_inputter.vocabulary_lookup_reverse()
        target_tokens = target_vocab_rev.lookup(
            tf.cast(sampled_ids, tf.int64))

        if params.get("replace_unknown_target", False):
            if alignment is None:
                raise TypeError(
                    "replace_unknown_target is not compatible with decoders "
                    "that don't return alignment history")
            if not isinstance(self.features_inputter, inputters.WordEmbedder):
                raise TypeError(
                    "replace_unknown_target is only defined when the source "
                    "inputter is a WordEmbedder")
            source_tokens = features["tokens"]
            if beam_width > 1:
                # Repeat each source entry beam_width times to line up with
                # the beam hypotheses.
                source_tokens = tf.contrib.seq2seq.tile_batch(
                    source_tokens, multiplier=beam_width)
            # Merge batch and beam dimensions.
            original_shape = tf.shape(target_tokens)
            target_tokens = tf.reshape(target_tokens, [-1, original_shape[-1]])
            # Assumes alignment is rank 4 with the first two dimensions being
            # batch and beam — TODO confirm against the decoder.
            align_shape = shape_list(alignment)
            attention = tf.reshape(alignment, [
                align_shape[0] * align_shape[1], align_shape[2], align_shape[3]
            ])
            # We don't have attention for </s> but ensure that the attention time dimension matches
            # the tokens time dimension.
            attention = reducer.align_in_time(attention, tf.shape(target_tokens)[1])
            replaced_target_tokens = replace_unknown_target(
                target_tokens, source_tokens, attention)
            target_tokens = tf.reshape(replaced_target_tokens, original_shape)

        predictions = {
            "tokens": target_tokens,
            "length": sampled_length,
            "log_probs": log_probs
        }
        if alignment is not None:
            predictions["alignment"] = alignment
    else:
        predictions = None

    return outputs, predictions
def _call(self, features, labels, params, mode):
    """Build the sequence-to-sequence graph for an estimator ``mode``.

    The features are embedded and encoded. When ``labels`` is set, the decoder
    is run on the reference target inputs. If the labels carry ``"noisy_ids"``
    and ``contrastive_learning`` is enabled, the decoder is also run on those
    noisy ids. When ``mode`` is not TRAIN, inference goes through
    ``dynamic_decode_and_search``. The result can then have unknown tokens
    replaced, decoding noise added, and be cut to ``num_hypotheses``
    hypotheses.

    Args:
      features: Source features passed to ``self.features_inputter``.
      labels: Target features, or ``None`` (e.g. at prediction time).
      params: Dictionary of hyperparameters, read with ``params.get``.
      mode: A ``tf.estimator.ModeKeys`` value.

    Returns:
      A tuple ``(outputs, predictions)``:
        - ``outputs`` is ``None`` without labels. Otherwise it is the decoder
          logits, or a dict with ``"logits"`` plus ``"attention"`` (when
          ``labels`` has ``"alignment"``) and/or ``"noisy_logits"`` (contrastive
          learning).
        - ``predictions`` is ``None`` in TRAIN mode. Otherwise it is a dict with
          ``"tokens"``, ``"length"``, ``"log_probs"`` and optionally
          ``"alignment"``, each sliced to ``num_hypotheses`` along axis 1 when
          ``num_hypotheses > 0``.

    Raises:
      TypeError: If ``replace_unknown_target`` is enabled but the decoder
        returned no alignment history, or the features inputter is not a
        ``WordEmbedder``.
      ValueError: If ``num_hypotheses`` is greater than ``beam_width``.
    """
    training = mode == tf.estimator.ModeKeys.TRAIN

    features_length = self.features_inputter.get_length(features)
    source_inputs = self.features_inputter.make_inputs(features, training=training)
    with tf.variable_scope("encoder"):
        encoder_outputs, encoder_state, encoder_sequence_length = self.encoder.encode(
            source_inputs, sequence_length=features_length, mode=mode)

    target_vocab_size = self.labels_inputter.vocabulary_size
    target_dtype = self.labels_inputter.dtype
    if labels is not None:
        # Scheduled sampling is only enabled in TRAIN mode.
        sampling_probability = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            sampling_probability = get_sampling_probability(
                tf.train.get_or_create_global_step(),
                read_probability=params.get("scheduled_sampling_read_probability"),
                schedule_type=params.get("scheduled_sampling_type"),
                k=params.get("scheduled_sampling_k"))

        def _decode_inputs(inputs, length, reuse=None):
            # Decodes ``inputs`` inside the "decoder" scope; ``reuse=True``
            # lets a second pass share the variables created by the first.
            with tf.variable_scope("decoder", reuse=reuse):
                return self.decoder.decode(
                    inputs,
                    length,
                    vocab_size=target_vocab_size,
                    initial_state=encoder_state,
                    sampling_probability=sampling_probability,
                    embedding=self.labels_inputter.embedding,
                    output_layer=self.output_layer,
                    mode=mode,
                    memory=encoder_outputs,
                    memory_sequence_length=encoder_sequence_length,
                    return_alignment_history=True)

        target_inputs = self.labels_inputter.make_inputs(labels, training=training)
        logits, _, _, attention = _decode_inputs(target_inputs, labels["length"])
        # The attention is only exposed when reference alignments are provided.
        if "alignment" in labels:
            outputs = {
                "logits": logits,
                "attention": attention
            }
        else:
            outputs = logits

        noisy_ids = labels.get("noisy_ids")
        if noisy_ids is not None and params.get("contrastive_learning"):
            # In case of contrastive learning, also forward the erroneous
            # translation to compute its log likelihood later.
            noisy_inputs = self.labels_inputter.make_inputs({"ids": noisy_ids}, training=training)
            noisy_logits = _decode_inputs(noisy_inputs, labels["noisy_length"], reuse=True)[0]
            # Promote bare logits to a dict so the noisy logits can be attached.
            if not isinstance(outputs, dict):
                outputs = dict(logits=outputs)
            outputs["noisy_logits"] = noisy_logits
    else:
        outputs = None

    if mode != tf.estimator.ModeKeys.TRAIN:
        # Reuse the "decoder" variables when the branch above already created
        # them, i.e. when labels are available.
        with tf.variable_scope("decoder", reuse=labels is not None):
            batch_size = tf.shape(tf.contrib.framework.nest.flatten(encoder_outputs)[0])[0]
            beam_width = params.get("beam_width", 1)
            start_tokens = tf.fill([batch_size], constants.START_OF_SENTENCE_ID)
            end_token = constants.END_OF_SENTENCE_ID
            sampled_ids, _, sampled_length, log_probs, alignment = (
                self.decoder.dynamic_decode_and_search(
                    self.labels_inputter.embedding,
                    start_tokens,
                    end_token,
                    vocab_size=target_vocab_size,
                    initial_state=encoder_state,
                    output_layer=self.output_layer,
                    beam_width=beam_width,
                    length_penalty=params.get("length_penalty", 0),
                    maximum_iterations=params.get("maximum_iterations", 250),
                    minimum_length=params.get("minimum_decoding_length", 0),
                    mode=mode,
                    memory=encoder_outputs,
                    memory_sequence_length=encoder_sequence_length,
                    dtype=target_dtype,
                    return_alignment_history=True,
                    sample_from=params.get("sampling_topk"),
                    sample_temperature=params.get("sampling_temperature"),
                    coverage_penalty=params.get("coverage_penalty", 0)))
        # Map the decoded ids back to token strings.
        target_vocab_rev = self.labels_inputter.vocabulary_lookup_reverse()
        target_tokens = target_vocab_rev.lookup(tf.cast(sampled_ids, tf.int64))

        if params.get("replace_unknown_target", False):
            if alignment is None:
                raise TypeError("replace_unknown_target is not compatible with decoders "
                                "that don't return alignment history")
            if not isinstance(self.features_inputter, inputters.WordEmbedder):
                raise TypeError("replace_unknown_target is only defined when the source "
                                "inputter is a WordEmbedder")
            source_tokens = features["tokens"]
            if beam_width > 1:
                # Repeat each source entry beam_width times to line up with
                # the beam hypotheses.
                source_tokens = tf.contrib.seq2seq.tile_batch(source_tokens, multiplier=beam_width)
            # Merge batch and beam dimensions.
            original_shape = tf.shape(target_tokens)
            target_tokens = tf.reshape(target_tokens, [-1, original_shape[-1]])
            # Assumes alignment is rank 4 with the first two dimensions being
            # batch and beam — TODO confirm against the decoder.
            align_shape = shape_list(alignment)
            attention = tf.reshape(
                alignment, [align_shape[0] * align_shape[1], align_shape[2], align_shape[3]])
            # We don't have attention for </s> but ensure that the attention time dimension matches
            # the tokens time dimension.
            attention = reducer.align_in_time(attention, tf.shape(target_tokens)[1])
            replaced_target_tokens = replace_unknown_target(target_tokens, source_tokens, attention)
            target_tokens = tf.reshape(replaced_target_tokens, original_shape)

        decoding_noise = params.get("decoding_noise")
        if decoding_noise:
            # The length is decremented before noising and incremented after,
            # so _add_noise works on lengths without </s>.
            sampled_length -= 1  # Ignore </s>
            target_tokens, sampled_length = _add_noise(
                target_tokens,
                sampled_length,
                decoding_noise,
                params.get("decoding_subword_token", "■"))
            sampled_length += 1
            alignment = None  # Invalidate alignments.

        predictions = {
            "tokens": target_tokens,
            "length": sampled_length,
            "log_probs": log_probs
        }
        if alignment is not None:
            predictions["alignment"] = alignment

        # Keep only the first num_hypotheses entries along axis 1 (presumably
        # the beam axis — confirm); a value <= 0 keeps everything.
        num_hypotheses = params.get("num_hypotheses", 1)
        if num_hypotheses > 0:
            if num_hypotheses > beam_width:
                raise ValueError("n_best cannot be greater than beam_width")
            for key, value in six.iteritems(predictions):
                predictions[key] = value[:, :num_hypotheses]
    else:
        predictions = None

    return outputs, predictions
def _build(self, features, labels, params, mode, config):
    """Build the sequence-to-sequence graph for an estimator ``mode``.

    The features are embedded and encoded. When ``labels`` is set, the decoder
    is run on the reference target inputs. When ``mode`` is not TRAIN, the
    decoder is also run for inference: ``dynamic_decode`` when ``beam_width``
    is at most 1, otherwise ``dynamic_decode_and_search``.

    Args:
      features: Source features passed to the source inputter.
      labels: Target features, or ``None`` (e.g. at prediction time).
      params: Dictionary of hyperparameters, read with ``params.get``.
      mode: A ``tf.estimator.ModeKeys`` value.
      config: Run configuration; its ``model_dir`` is forwarded as ``log_dir``
        to the inputters.

    Returns:
      A tuple ``(logits, predictions)``. ``logits`` is ``None`` when ``labels``
      is ``None``. ``predictions`` is ``None`` in TRAIN mode, otherwise a dict
      with ``"tokens"``, ``"length"`` and ``"log_probs"``.
    """
    features_length = self._get_features_length(features)

    with tf.variable_scope("encoder"):
        source_inputs = self.source_inputter.transform_data(
            features, mode=mode, log_dir=config.model_dir)
        encoder_outputs, encoder_state, encoder_sequence_length = self.encoder.encode(
            source_inputs, sequence_length=features_length, mode=mode)

    target_vocab_size = self.target_inputter.vocabulary_size

    with tf.variable_scope("decoder") as decoder_scope:
        if labels is not None:
            # Scheduled sampling is a training-time technique. Only compute a
            # sampling probability in TRAIN mode, so that evaluation with
            # labels decodes the references without sampling.
            # get_sampling_probability itself returns None when scheduled
            # sampling is not configured.
            sampling_probability = None
            if mode == tf.estimator.ModeKeys.TRAIN:
                sampling_probability = get_sampling_probability(
                    tf.train.get_or_create_global_step(),
                    read_probability=params.get("scheduled_sampling_read_probability"),
                    schedule_type=params.get("scheduled_sampling_type"),
                    k=params.get("scheduled_sampling_k"))
            target_inputs = self.target_inputter.transform_data(
                labels, mode=mode, log_dir=config.model_dir)
            logits, _, _ = self.decoder.decode(
                target_inputs,
                self._get_labels_length(labels),
                target_vocab_size,
                initial_state=encoder_state,
                sampling_probability=sampling_probability,
                embedding=self._scoped_target_embedding_fn(mode, decoder_scope),
                mode=mode,
                memory=encoder_outputs,
                memory_sequence_length=encoder_sequence_length)
        else:
            logits = None

    if mode != tf.estimator.ModeKeys.TRAIN:
        # Re-enter the captured decoder scope, reusing its variables when the
        # labels branch above already created them.
        with tf.variable_scope(decoder_scope, reuse=labels is not None) as decoder_scope:
            batch_size = tf.shape(encoder_sequence_length)[0]
            beam_width = params.get("beam_width", 1)
            maximum_iterations = params.get("maximum_iterations", 250)
            start_tokens = tf.fill([batch_size], constants.START_OF_SENTENCE_ID)
            end_token = constants.END_OF_SENTENCE_ID

            if beam_width <= 1:
                sampled_ids, _, sampled_length, log_probs = self.decoder.dynamic_decode(
                    self._scoped_target_embedding_fn(mode, decoder_scope),
                    start_tokens,
                    end_token,
                    target_vocab_size,
                    initial_state=encoder_state,
                    maximum_iterations=maximum_iterations,
                    mode=mode,
                    memory=encoder_outputs,
                    memory_sequence_length=encoder_sequence_length)
            else:
                length_penalty = params.get("length_penalty", 0)
                sampled_ids, _, sampled_length, log_probs = self.decoder.dynamic_decode_and_search(
                    self._scoped_target_embedding_fn(mode, decoder_scope),
                    start_tokens,
                    end_token,
                    target_vocab_size,
                    initial_state=encoder_state,
                    beam_width=beam_width,
                    length_penalty=length_penalty,
                    maximum_iterations=maximum_iterations,
                    mode=mode,
                    memory=encoder_outputs,
                    memory_sequence_length=encoder_sequence_length)

        # Reverse lookup built from the vocabulary file only: ids beyond it
        # (the OOV buckets) map to the unknown token.
        target_vocab_rev = tf.contrib.lookup.index_to_string_table_from_file(
            self.target_inputter.vocabulary_file,
            vocab_size=target_vocab_size - self.target_inputter.num_oov_buckets,
            default_value=constants.UNKNOWN_TOKEN)
        predictions = {
            "tokens": target_vocab_rev.lookup(tf.cast(sampled_ids, tf.int64)),
            "length": sampled_length,
            "log_probs": log_probs
        }
    else:
        predictions = None

    return logits, predictions
def testSamplingProbability(self):
    """Check get_sampling_probability on invalid and valid configurations."""
    step = tf.constant(5, dtype=tf.int64)
    large_step = tf.constant(1000, dtype=tf.int64)

    # Nothing configured: no sampling probability at all.
    self.assertIsNone(decoder.get_sampling_probability(step))

    # Incomplete or unknown schedules, with the exception each must raise.
    rejected = (
        ({"schedule_type": "linear"}, ValueError),
        ({"schedule_type": "linear", "k": 1}, ValueError),
        ({"schedule_type": "foo", "k": 1}, TypeError),
    )
    for bad_kwargs, error_type in rejected:
        with self.assertRaises(error_type):
            decoder.get_sampling_probability(step, **bad_kwargs)

    # A constant read probability is compared directly, without evaluation.
    self.assertAlmostEqual(
        0.1, decoder.get_sampling_probability(step, read_probability=0.9))

    # (step, keyword arguments, expected evaluated sampling probability).
    scheduled = (
        (step,
         {"read_probability": 1.0, "schedule_type": "linear", "k": 0.1},
         0.5),
        (step,
         {"read_probability": 2.0, "schedule_type": "linear", "k": 0.1},
         0.5),
        (large_step,
         {"read_probability": 1.0, "schedule_type": "linear", "k": 0.1},
         1.0),
        (step,
         {"schedule_type": "exponential", "k": 0.8},
         1.0 - pow(0.8, 5)),
        (step,
         {"schedule_type": "inverse_sigmoid", "k": 1},
         1.0 - (1.0 / (1.0 + math.exp(5.0 / 1.0)))),
    )
    for at_step, kwargs, expected in scheduled:
        probability = decoder.get_sampling_probability(at_step, **kwargs)
        self.assertAlmostEqual(expected, self.evaluate(probability))
def _build(self, features, labels, params, mode, config):
    """Build the sequence-to-sequence graph for an estimator ``mode``.

    The features are embedded and encoded. When ``labels`` is set, the decoder
    is run on the reference target inputs. When ``mode`` is not TRAIN, the
    decoder is also run for inference: ``dynamic_decode`` when ``beam_width``
    is at most 1, otherwise ``dynamic_decode_and_search``. In both cases it
    uses the target inputter's dtype.

    Args:
      features: Source features passed to the source inputter.
      labels: Target features, or ``None`` (e.g. at prediction time).
      params: Dictionary of hyperparameters, read with ``params.get``.
      mode: A ``tf.estimator.ModeKeys`` value.
      config: Run configuration; its ``model_dir`` is forwarded as ``log_dir``
        to the inputters.

    Returns:
      A tuple ``(logits, predictions)``. ``logits`` is ``None`` when ``labels``
      is ``None``. ``predictions`` is ``None`` in TRAIN mode, otherwise a dict
      with ``"tokens"``, ``"length"`` and ``"log_probs"``.
    """
    features_length = self._get_features_length(features)

    with tf.variable_scope("encoder"):
        source_inputs = self.source_inputter.transform_data(
            features, mode=mode, log_dir=config.model_dir)
        encoder_outputs, encoder_state, encoder_sequence_length = self.encoder.encode(
            source_inputs, sequence_length=features_length, mode=mode)

    target_vocab_size = self.target_inputter.vocabulary_size
    target_dtype = self.target_inputter.dtype

    with tf.variable_scope("decoder") as decoder_scope:
        if labels is not None:
            # Scheduled sampling is a training-time technique. Only compute a
            # sampling probability in TRAIN mode, so that evaluation with
            # labels decodes the references without sampling.
            # get_sampling_probability itself returns None when scheduled
            # sampling is not configured.
            sampling_probability = None
            if mode == tf.estimator.ModeKeys.TRAIN:
                sampling_probability = get_sampling_probability(
                    tf.train.get_or_create_global_step(),
                    read_probability=params.get("scheduled_sampling_read_probability"),
                    schedule_type=params.get("scheduled_sampling_type"),
                    k=params.get("scheduled_sampling_k"))
            target_inputs = self.target_inputter.transform_data(
                labels, mode=mode, log_dir=config.model_dir)
            logits, _, _ = self.decoder.decode(
                target_inputs,
                self._get_labels_length(labels),
                target_vocab_size,
                initial_state=encoder_state,
                sampling_probability=sampling_probability,
                embedding=self._scoped_target_embedding_fn(mode, decoder_scope),
                mode=mode,
                memory=encoder_outputs,
                memory_sequence_length=encoder_sequence_length)
        else:
            logits = None

    if mode != tf.estimator.ModeKeys.TRAIN:
        # Re-enter the captured decoder scope, reusing its variables when the
        # labels branch above already created them.
        with tf.variable_scope(decoder_scope, reuse=labels is not None) as decoder_scope:
            batch_size = tf.shape(encoder_sequence_length)[0]
            beam_width = params.get("beam_width", 1)
            maximum_iterations = params.get("maximum_iterations", 250)
            start_tokens = tf.fill([batch_size], constants.START_OF_SENTENCE_ID)
            end_token = constants.END_OF_SENTENCE_ID

            if beam_width <= 1:
                sampled_ids, _, sampled_length, log_probs = self.decoder.dynamic_decode(
                    self._scoped_target_embedding_fn(mode, decoder_scope),
                    start_tokens,
                    end_token,
                    target_vocab_size,
                    initial_state=encoder_state,
                    maximum_iterations=maximum_iterations,
                    mode=mode,
                    memory=encoder_outputs,
                    memory_sequence_length=encoder_sequence_length,
                    dtype=target_dtype)
            else:
                length_penalty = params.get("length_penalty", 0)
                sampled_ids, _, sampled_length, log_probs = self.decoder.dynamic_decode_and_search(
                    self._scoped_target_embedding_fn(mode, decoder_scope),
                    start_tokens,
                    end_token,
                    target_vocab_size,
                    initial_state=encoder_state,
                    beam_width=beam_width,
                    length_penalty=length_penalty,
                    maximum_iterations=maximum_iterations,
                    mode=mode,
                    memory=encoder_outputs,
                    memory_sequence_length=encoder_sequence_length,
                    dtype=target_dtype)

        # Reverse lookup built from the vocabulary file only: ids beyond it
        # (the OOV buckets) map to the unknown token.
        target_vocab_rev = tf.contrib.lookup.index_to_string_table_from_file(
            self.target_inputter.vocabulary_file,
            vocab_size=target_vocab_size - self.target_inputter.num_oov_buckets,
            default_value=constants.UNKNOWN_TOKEN)
        predictions = {
            "tokens": target_vocab_rev.lookup(tf.cast(sampled_ids, tf.int64)),
            "length": sampled_length,
            "log_probs": log_probs
        }
    else:
        predictions = None

    return logits, predictions
def _build(self, features, labels, params, mode, config=None):
    """Build the sequence-to-sequence graph for an estimator ``mode``.

    The features are embedded and encoded. When ``labels`` is set, the decoder
    is run on the reference target inputs. When ``mode`` is not TRAIN, the
    decoder is also run for inference: ``dynamic_decode`` when ``beam_width``
    is at most 1, otherwise ``dynamic_decode_and_search``. Both request the
    alignment history. Unknown target tokens can then be replaced using that
    alignment.

    Args:
      features: Source features passed to the source inputter.
      labels: Target features, or ``None`` (e.g. at prediction time).
      params: Dictionary of hyperparameters, read with ``params.get``.
      mode: A ``tf.estimator.ModeKeys`` value.
      config: Optional run configuration; its ``model_dir`` is forwarded as
        ``log_dir`` to the inputters.

    Returns:
      A tuple ``(logits, predictions)``. ``logits`` is ``None`` when ``labels``
      is ``None``. ``predictions`` is ``None`` in TRAIN mode, otherwise a dict
      with ``"tokens"``, ``"length"``, ``"log_probs"`` and, when the decoder
      returned one, ``"alignment"``.

    Raises:
      TypeError: If ``replace_unknown_target`` is enabled but the decoder
        returned no alignment history, or the source inputter is not a
        ``WordEmbedder``.
    """
    features_length = self._get_features_length(features)
    # The run configuration is optional: without it, no log directory is used.
    log_dir = config.model_dir if config is not None else None

    with tf.variable_scope("encoder"):
        source_inputs = self.source_inputter.transform_data(
            features, mode=mode, log_dir=log_dir)
        encoder_outputs, encoder_state, encoder_sequence_length = self.encoder.encode(
            source_inputs, sequence_length=features_length, mode=mode)

    target_vocab_size = self.target_inputter.vocabulary_size
    target_dtype = self.target_inputter.dtype

    with tf.variable_scope("decoder") as decoder_scope:
        if labels is not None:
            # Scheduled sampling is a training-time technique. Only compute a
            # sampling probability in TRAIN mode, so that evaluation with
            # labels decodes the references without sampling.
            # get_sampling_probability itself returns None when scheduled
            # sampling is not configured.
            sampling_probability = None
            if mode == tf.estimator.ModeKeys.TRAIN:
                sampling_probability = get_sampling_probability(
                    tf.train.get_or_create_global_step(),
                    read_probability=params.get(
                        "scheduled_sampling_read_probability"),
                    schedule_type=params.get("scheduled_sampling_type"),
                    k=params.get("scheduled_sampling_k"))
            target_inputs = self.target_inputter.transform_data(
                labels, mode=mode, log_dir=log_dir)
            logits, _, _ = self.decoder.decode(
                target_inputs,
                self._get_labels_length(labels),
                vocab_size=target_vocab_size,
                initial_state=encoder_state,
                sampling_probability=sampling_probability,
                embedding=self._scoped_target_embedding_fn(
                    mode, decoder_scope),
                mode=mode,
                memory=encoder_outputs,
                memory_sequence_length=encoder_sequence_length)
        else:
            logits = None

    if mode != tf.estimator.ModeKeys.TRAIN:
        # Re-enter the captured decoder scope, reusing its variables when the
        # labels branch above already created them.
        with tf.variable_scope(decoder_scope, reuse=labels is not None) as decoder_scope:
            batch_size = tf.shape(encoder_sequence_length)[0]
            beam_width = params.get("beam_width", 1)
            maximum_iterations = params.get("maximum_iterations", 250)
            start_tokens = tf.fill([batch_size], constants.START_OF_SENTENCE_ID)
            end_token = constants.END_OF_SENTENCE_ID

            if beam_width <= 1:
                sampled_ids, _, sampled_length, log_probs, alignment = self.decoder.dynamic_decode(
                    self._scoped_target_embedding_fn(mode, decoder_scope),
                    start_tokens,
                    end_token,
                    vocab_size=target_vocab_size,
                    initial_state=encoder_state,
                    maximum_iterations=maximum_iterations,
                    mode=mode,
                    memory=encoder_outputs,
                    memory_sequence_length=encoder_sequence_length,
                    dtype=target_dtype,
                    return_alignment_history=True)
            else:
                length_penalty = params.get("length_penalty", 0)
                sampled_ids, _, sampled_length, log_probs, alignment = (
                    self.decoder.dynamic_decode_and_search(
                        self._scoped_target_embedding_fn(
                            mode, decoder_scope),
                        start_tokens,
                        end_token,
                        vocab_size=target_vocab_size,
                        initial_state=encoder_state,
                        beam_width=beam_width,
                        length_penalty=length_penalty,
                        maximum_iterations=maximum_iterations,
                        mode=mode,
                        memory=encoder_outputs,
                        memory_sequence_length=encoder_sequence_length,
                        dtype=target_dtype,
                        return_alignment_history=True))

        # Reverse lookup built from the vocabulary file only: ids beyond it
        # (the OOV buckets) map to the unknown token.
        target_vocab_rev = tf.contrib.lookup.index_to_string_table_from_file(
            self.target_inputter.vocabulary_file,
            vocab_size=target_vocab_size - self.target_inputter.num_oov_buckets,
            default_value=constants.UNKNOWN_TOKEN)
        target_tokens = target_vocab_rev.lookup(
            tf.cast(sampled_ids, tf.int64))

        if params.get("replace_unknown_target", False):
            if alignment is None:
                raise TypeError(
                    "replace_unknown_target is not compatible with decoders "
                    "that don't return alignment history")
            if not isinstance(self.source_inputter, inputters.WordEmbedder):
                raise TypeError(
                    "replace_unknown_target is only defined when the source "
                    "inputter is a WordEmbedder")
            source_tokens = features["tokens"]
            if beam_width > 1:
                # Repeat each source entry beam_width times to line up with
                # the beam hypotheses.
                source_tokens = tf.contrib.seq2seq.tile_batch(
                    source_tokens, multiplier=beam_width)
            # Merge batch and beam dimensions.
            original_shape = tf.shape(target_tokens)
            target_tokens = tf.reshape(target_tokens, [-1, original_shape[-1]])
            # NOTE(review): the alignment time dimension is not reconciled with
            # target_tokens' time dimension here (e.g. no attention step for
            # </s>); verify replace_unknown_target tolerates a mismatch.
            attention = tf.reshape(
                alignment, [-1, tf.shape(alignment)[2], tf.shape(alignment)[3]])
            replaced_target_tokens = replace_unknown_target(
                target_tokens, source_tokens, attention)
            target_tokens = tf.reshape(replaced_target_tokens, original_shape)

        predictions = {
            "tokens": target_tokens,
            "length": sampled_length,
            "log_probs": log_probs
        }
        if alignment is not None:
            predictions["alignment"] = alignment
    else:
        predictions = None

    return logits, predictions