def testAlignInTimeSame(self):
    a = [[[1], [0], [0]], [[1], [2], [3]]]
    length = 3
    b = reducer.align_in_time(tf.constant(a, dtype=tf.float32), tf.constant(length))
    self.assertEqual(1, b.shape[-1])
    self.assertAllEqual(a, self.evaluate(b))
def testAlignInTimeSmaller(self):
    a = [[[1], [0], [0]], [[1], [2], [0]]]
    length = 2
    b = [[[1], [0]], [[1], [2]]]
    c = reducer.align_in_time(tf.constant(a, dtype=tf.float32), tf.constant(length))
    self.assertEqual(1, c.shape[-1])
    self.assertAllEqual(b, self.evaluate(c))
def testAlignInTimeLarger(self):
    a = [[[1], [0], [0]], [[1], [2], [3]]]
    length = 4
    b = [[[1], [0], [0], [0]], [[1], [2], [3], [0]]]
    c = reducer.align_in_time(tf.constant(a, dtype=tf.float32), tf.constant(length))
    self.assertEqual(1, c.get_shape().as_list()[-1])
    self.assertAllEqual(b, self.evaluate(c))
def testAlignInTimeSame(self):
    a = [[[1], [0], [0]], [[1], [2], [3]]]
    length = 3
    b = reducer.align_in_time(tf.constant(a, dtype=tf.float32), tf.constant(length))
    self.assertEqual(1, b.get_shape().as_list()[-1])
    with self.test_session() as sess:
        self.assertAllEqual(a, sess.run(b))
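# The tests above pin down the contract of reducer.align_in_time: the time
# dimension (axis 1) of a [batch, time, depth] tensor is truncated or
# zero-padded to the requested length. A minimal sketch of such a helper,
# assuming only standard TensorFlow ops (the repository's actual
# implementation may differ):
import tensorflow as tf

def align_in_time_sketch(x, length):
    """Truncates or zero-pads `x` on axis 1 so that tf.shape(x)[1] == length."""
    time_dim = tf.shape(x)[1]
    return tf.cond(
        time_dim < length,
        lambda: tf.pad(x, [[0, 0], [0, length - time_dim], [0, 0]]),
        lambda: x[:, :length],
    )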
def _call(self, features, labels, params, mode):
    training = mode == tf.estimator.ModeKeys.TRAIN
    features_length = self.features_inputter.get_length(features)
    source_inputs = self.features_inputter.make_inputs(features, training=training)
    with tf.variable_scope("encoder"):
        encoder_outputs, encoder_state, encoder_sequence_length = self.encoder.encode(
            source_inputs, sequence_length=features_length, mode=mode)

    target_vocab_size = self.labels_inputter.vocabulary_size
    target_dtype = self.labels_inputter.dtype
    if labels is not None:
        target_inputs = self.labels_inputter.make_inputs(labels, training=training)
        with tf.variable_scope("decoder"):
            sampling_probability = None
            if mode == tf.estimator.ModeKeys.TRAIN:
                sampling_probability = get_sampling_probability(
                    tf.train.get_or_create_global_step(),
                    read_probability=params.get("scheduled_sampling_read_probability"),
                    schedule_type=params.get("scheduled_sampling_type"),
                    k=params.get("scheduled_sampling_k"))
            logits, _, _, attention = self.decoder.decode(
                target_inputs,
                self.labels_inputter.get_length(labels),
                vocab_size=target_vocab_size,
                initial_state=encoder_state,
                sampling_probability=sampling_probability,
                embedding=self.labels_inputter.embedding,
                output_layer=self.output_layer,
                mode=mode,
                memory=encoder_outputs,
                memory_sequence_length=encoder_sequence_length,
                return_alignment_history=True)
        if "alignment" in labels:
            outputs = {"logits": logits, "attention": attention}
        else:
            outputs = logits
    else:
        outputs = None

    if mode != tf.estimator.ModeKeys.TRAIN:
        with tf.variable_scope("decoder", reuse=labels is not None):
            batch_size = tf.shape(tf.contrib.framework.nest.flatten(encoder_outputs)[0])[0]
            beam_width = params.get("beam_width", 1)
            start_tokens = tf.fill([batch_size], constants.START_OF_SENTENCE_ID)
            end_token = constants.END_OF_SENTENCE_ID
            sampled_ids, _, sampled_length, log_probs, alignment = (
                self.decoder.dynamic_decode_and_search(
                    self.labels_inputter.embedding,
                    start_tokens,
                    end_token,
                    vocab_size=target_vocab_size,
                    initial_state=encoder_state,
                    output_layer=self.output_layer,
                    beam_width=beam_width,
                    length_penalty=params.get("length_penalty", 0),
                    maximum_iterations=params.get("maximum_iterations", 250),
                    minimum_length=params.get("minimum_decoding_length", 0),
                    mode=mode,
                    memory=encoder_outputs,
                    memory_sequence_length=encoder_sequence_length,
                    dtype=target_dtype,
                    return_alignment_history=True,
                    sample_from=params.get("sampling_topk"),
                    sample_temperature=params.get("sampling_temperature")))

        target_vocab_rev = self.labels_inputter.vocabulary_lookup_reverse()
        target_tokens = target_vocab_rev.lookup(tf.cast(sampled_ids, tf.int64))

        if params.get("replace_unknown_target", False):
            if alignment is None:
                raise TypeError("replace_unknown_target is not compatible with decoders "
                                "that don't return alignment history")
            if not isinstance(self.features_inputter, inputters.WordEmbedder):
                raise TypeError("replace_unknown_target is only defined when the source "
                                "inputter is a WordEmbedder")
            source_tokens = features["tokens"]
            if beam_width > 1:
                source_tokens = tf.contrib.seq2seq.tile_batch(
                    source_tokens, multiplier=beam_width)
            # Merge batch and beam dimensions.
            original_shape = tf.shape(target_tokens)
            target_tokens = tf.reshape(target_tokens, [-1, original_shape[-1]])
            align_shape = shape_list(alignment)
            attention = tf.reshape(
                alignment, [align_shape[0] * align_shape[1], align_shape[2], align_shape[3]])
            # We don't have attention for </s> but ensure that the attention time
            # dimension matches the tokens time dimension.
            attention = reducer.align_in_time(attention, tf.shape(target_tokens)[1])
            replaced_target_tokens = replace_unknown_target(
                target_tokens, source_tokens, attention)
            target_tokens = tf.reshape(replaced_target_tokens, original_shape)

        predictions = {
            "tokens": target_tokens,
            "length": sampled_length,
            "log_probs": log_probs
        }
        if alignment is not None:
            predictions["alignment"] = alignment
    else:
        predictions = None

    return outputs, predictions
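# replace_unknown_target, used above, substitutes each predicted unknown token
# with the source token that receives the highest attention weight at that
# decoding step. A minimal sketch, assuming target_tokens and source_tokens of
# shape [batch, time], attention of shape [batch, target_time, source_time],
# and a TF version where tf.gather supports batch_dims; the actual helper in
# the repository may handle more cases:
def replace_unknown_target_sketch(target_tokens, source_tokens, attention,
                                  unknown_token=constants.UNKNOWN_TOKEN):
    # For each target step, index of the most attended source position.
    aligned_positions = tf.argmax(attention, axis=-1, output_type=tf.int32)
    # Gather the corresponding source tokens, per batch entry.
    aligned_tokens = tf.gather(source_tokens, aligned_positions, batch_dims=1)
    # Keep the prediction unless it is the unknown token.
    return tf.where(tf.equal(target_tokens, unknown_token), aligned_tokens, target_tokens)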
def _dynamic_decode(self, features, encoder_outputs, encoder_state, encoder_sequence_length):
    params = self.params
    batch_size = tf.shape(tf.nest.flatten(encoder_outputs)[0])[0]
    start_ids = tf.fill([batch_size], constants.START_OF_SENTENCE_ID)
    beam_size = params.get("beam_width", 1)
    if beam_size > 1:
        # Tile encoder outputs to prepare for beam search.
        encoder_outputs = tfa.seq2seq.tile_batch(encoder_outputs, beam_size)
        encoder_sequence_length = tfa.seq2seq.tile_batch(encoder_sequence_length, beam_size)
        if encoder_state is not None:
            encoder_state = tfa.seq2seq.tile_batch(encoder_state, beam_size)

    # Dynamically decodes from the encoder outputs.
    initial_state = self.decoder.initial_state(
        memory=encoder_outputs,
        memory_sequence_length=encoder_sequence_length,
        initial_state=encoder_state)
    sampled_ids, sampled_length, log_probs, alignment, _ = self.decoder.dynamic_decode(
        self.labels_inputter,
        start_ids,
        initial_state=initial_state,
        decoding_strategy=decoding.DecodingStrategy.from_params(params),
        sampler=decoding.Sampler.from_params(params),
        maximum_iterations=params.get("maximum_decoding_length", 250),
        minimum_iterations=params.get("minimum_decoding_length", 0))
    target_tokens = self.labels_inputter.ids_to_tokens.lookup(tf.cast(sampled_ids, tf.int64))

    # Maybe replace unknown targets by the source tokens with the highest attention weight.
    if params.get("replace_unknown_target", False):
        if alignment is None:
            raise TypeError("replace_unknown_target is not compatible with decoders "
                            "that don't return alignment history")
        if not isinstance(self.features_inputter, inputters.WordEmbedder):
            raise TypeError("replace_unknown_target is only defined when the source "
                            "inputter is a WordEmbedder")
        source_tokens = features["tokens"]
        if beam_size > 1:
            source_tokens = tfa.seq2seq.tile_batch(source_tokens, beam_size)
        # Merge batch and beam dimensions.
        original_shape = tf.shape(target_tokens)
        target_tokens = tf.reshape(target_tokens, [-1, original_shape[-1]])
        align_shape = shape_list(alignment)
        attention = tf.reshape(
            alignment, [align_shape[0] * align_shape[1], align_shape[2], align_shape[3]])
        # We don't have attention for </s> but ensure that the attention time
        # dimension matches the tokens time dimension.
        attention = reducer.align_in_time(attention, tf.shape(target_tokens)[1])
        replaced_target_tokens = replace_unknown_target(target_tokens, source_tokens, attention)
        target_tokens = tf.reshape(replaced_target_tokens, original_shape)

    # Maybe add noise to the predictions.
    decoding_noise = params.get("decoding_noise")
    if decoding_noise:
        target_tokens, sampled_length = _add_noise(
            target_tokens,
            sampled_length,
            decoding_noise,
            params.get("decoding_subword_token", "■"))
        alignment = None  # Invalidate alignments.

    predictions = {
        "tokens": target_tokens,
        "length": sampled_length,
        "log_probs": log_probs
    }
    if alignment is not None:
        predictions["alignment"] = alignment

    # Maybe restrict the number of returned hypotheses based on the user parameter.
    num_hypotheses = params.get("num_hypotheses", 1)
    if num_hypotheses > 0:
        if num_hypotheses > beam_size:
            raise ValueError("n_best cannot be greater than beam_width")
        for key, value in six.iteritems(predictions):
            predictions[key] = value[:, :num_hypotheses]
    return predictions
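# tfa.seq2seq.tile_batch, used above, repeats each batch entry `multiplier`
# times along the batch axis so that every beam can be treated as an
# independent batch element during beam search. For example:
#   tfa.seq2seq.tile_batch(tf.constant([[1], [2]]), 2)
#   # -> [[1], [1], [2], [2]]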
def decode(self,
           inputs,
           sequence_length,
           vocab_size=None,
           initial_state=None,
           sampling_probability=None,
           embedding=None,
           output_layer=None,
           mode=tf.estimator.ModeKeys.TRAIN,
           memory=None,
           memory_sequence_length=None,
           return_alignment_history=False):
    _ = memory
    _ = memory_sequence_length
    batch_size = tf.shape(inputs)[0]
    if (sampling_probability is not None
            and (tf.contrib.framework.is_tensor(sampling_probability)
                 or sampling_probability > 0.0)):
        if embedding is None:
            raise ValueError("embedding argument must be set when using scheduled sampling")
        tf.summary.scalar("sampling_probability", sampling_probability)
        helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
            inputs, sequence_length, embedding, sampling_probability)
        fused_projection = False
    else:
        helper = tf.contrib.seq2seq.TrainingHelper(inputs, sequence_length)
        fused_projection = True  # With TrainingHelper, project all timesteps at once.

    cell, initial_state = self._build_cell(
        mode,
        batch_size,
        initial_state=initial_state,
        memory=memory,
        memory_sequence_length=memory_sequence_length,
        dtype=inputs.dtype,
        alignment_history=return_alignment_history)

    if output_layer is None:
        output_layer = build_output_layer(self.num_units, vocab_size, dtype=inputs.dtype)

    decoder = tf.contrib.seq2seq.BasicDecoder(
        cell,
        helper,
        initial_state,
        output_layer=output_layer if not fused_projection else None)

    outputs, state, length = tf.contrib.seq2seq.dynamic_decode(decoder)
    if fused_projection and output_layer is not None:
        logits = output_layer(outputs.rnn_output)
    else:
        logits = outputs.rnn_output
    # Make sure outputs have the same time_dim as inputs.
    inputs_len = tf.shape(inputs)[1]
    logits = align_in_time(logits, inputs_len)

    if return_alignment_history:
        alignment_history = _get_alignment_history(state)
        if alignment_history is not None:
            alignment_history = align_in_time(alignment_history, inputs_len)
        return (logits, state, length, alignment_history)
    return (logits, state, length)
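# get_sampling_probability, referenced by the _call methods, maps the global
# step to the probability of feeding back a sampled token instead of the gold
# token (scheduled sampling, Bengio et al., 2015). A minimal sketch of one
# common schedule, inverse sigmoid decay; the function name and the exact
# schedule types supported by the repository are assumptions here:
def inverse_sigmoid_sampling_probability(step, k):
    """Returns 1 - k / (k + exp(step / k)); assumes k >= 1."""
    step = tf.cast(step, tf.float32)
    return 1.0 - k / (k + tf.exp(step / k))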
def _call(self, features, labels, params, mode):
    training = mode == tf.estimator.ModeKeys.TRAIN
    features_length = self.features_inputter.get_length(features)
    source_inputs = self.features_inputter.make_inputs(features, training=training)
    with tf.variable_scope("encoder"):
        encoder_outputs, encoder_state, encoder_sequence_length = self.encoder.encode(
            source_inputs, sequence_length=features_length, mode=mode)

    target_vocab_size = self.labels_inputter.vocabulary_size
    target_dtype = self.labels_inputter.dtype
    if labels is not None:
        sampling_probability = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            sampling_probability = get_sampling_probability(
                tf.train.get_or_create_global_step(),
                read_probability=params.get("scheduled_sampling_read_probability"),
                schedule_type=params.get("scheduled_sampling_type"),
                k=params.get("scheduled_sampling_k"))

        def _decode_inputs(inputs, length, reuse=None):
            with tf.variable_scope("decoder", reuse=reuse):
                return self.decoder.decode(
                    inputs,
                    length,
                    vocab_size=target_vocab_size,
                    initial_state=encoder_state,
                    sampling_probability=sampling_probability,
                    embedding=self.labels_inputter.embedding,
                    output_layer=self.output_layer,
                    mode=mode,
                    memory=encoder_outputs,
                    memory_sequence_length=encoder_sequence_length,
                    return_alignment_history=True)

        target_inputs = self.labels_inputter.make_inputs(labels, training=training)
        logits, _, _, attention = _decode_inputs(target_inputs, labels["length"])
        if "alignment" in labels:
            outputs = {"logits": logits, "attention": attention}
        else:
            outputs = logits

        noisy_ids = labels.get("noisy_ids")
        if noisy_ids is not None and params.get("contrastive_learning"):
            # In case of contrastive learning, also forward the erroneous
            # translation to compute its log likelihood later.
            noisy_inputs = self.labels_inputter.make_inputs({"ids": noisy_ids}, training=training)
            noisy_logits = _decode_inputs(noisy_inputs, labels["noisy_length"], reuse=True)[0]
            if not isinstance(outputs, dict):
                outputs = dict(logits=outputs)
            outputs["noisy_logits"] = noisy_logits
    else:
        outputs = None

    if mode != tf.estimator.ModeKeys.TRAIN:
        with tf.variable_scope("decoder", reuse=labels is not None):
            batch_size = tf.shape(tf.contrib.framework.nest.flatten(encoder_outputs)[0])[0]
            beam_width = params.get("beam_width", 1)
            start_tokens = tf.fill([batch_size], constants.START_OF_SENTENCE_ID)
            end_token = constants.END_OF_SENTENCE_ID
            sampled_ids, _, sampled_length, log_probs, alignment = (
                self.decoder.dynamic_decode_and_search(
                    self.labels_inputter.embedding,
                    start_tokens,
                    end_token,
                    vocab_size=target_vocab_size,
                    initial_state=encoder_state,
                    output_layer=self.output_layer,
                    beam_width=beam_width,
                    length_penalty=params.get("length_penalty", 0),
                    maximum_iterations=params.get("maximum_iterations", 250),
                    minimum_length=params.get("minimum_decoding_length", 0),
                    mode=mode,
                    memory=encoder_outputs,
                    memory_sequence_length=encoder_sequence_length,
                    dtype=target_dtype,
                    return_alignment_history=True,
                    sample_from=params.get("sampling_topk"),
                    sample_temperature=params.get("sampling_temperature"),
                    coverage_penalty=params.get("coverage_penalty", 0)))

        target_vocab_rev = self.labels_inputter.vocabulary_lookup_reverse()
        target_tokens = target_vocab_rev.lookup(tf.cast(sampled_ids, tf.int64))

        if params.get("replace_unknown_target", False):
            if alignment is None:
                raise TypeError("replace_unknown_target is not compatible with decoders "
                                "that don't return alignment history")
            if not isinstance(self.features_inputter, inputters.WordEmbedder):
                raise TypeError("replace_unknown_target is only defined when the source "
                                "inputter is a WordEmbedder")
            source_tokens = features["tokens"]
            if beam_width > 1:
                source_tokens = tf.contrib.seq2seq.tile_batch(
                    source_tokens, multiplier=beam_width)
            # Merge batch and beam dimensions.
            original_shape = tf.shape(target_tokens)
            target_tokens = tf.reshape(target_tokens, [-1, original_shape[-1]])
            align_shape = shape_list(alignment)
            attention = tf.reshape(
                alignment, [align_shape[0] * align_shape[1], align_shape[2], align_shape[3]])
            # We don't have attention for </s> but ensure that the attention time
            # dimension matches the tokens time dimension.
            attention = reducer.align_in_time(attention, tf.shape(target_tokens)[1])
            replaced_target_tokens = replace_unknown_target(
                target_tokens, source_tokens, attention)
            target_tokens = tf.reshape(replaced_target_tokens, original_shape)

        decoding_noise = params.get("decoding_noise")
        if decoding_noise:
            sampled_length -= 1  # Ignore </s>.
            target_tokens, sampled_length = _add_noise(
                target_tokens,
                sampled_length,
                decoding_noise,
                params.get("decoding_subword_token", "■"))
            sampled_length += 1
            alignment = None  # Invalidate alignments.

        predictions = {
            "tokens": target_tokens,
            "length": sampled_length,
            "log_probs": log_probs
        }
        if alignment is not None:
            predictions["alignment"] = alignment

        num_hypotheses = params.get("num_hypotheses", 1)
        if num_hypotheses > 0:
            if num_hypotheses > beam_width:
                raise ValueError("n_best cannot be greater than beam_width")
            for key, value in six.iteritems(predictions):
                predictions[key] = value[:, :num_hypotheses]
    else:
        predictions = None

    return outputs, predictions
def _dynamic_decode(
    self,
    features,
    encoder_outputs,
    encoder_state,
    encoder_sequence_length,
    tflite_run=False,
):
    params = self.params
    batch_size = tf.shape(tf.nest.flatten(encoder_outputs)[0])[0]
    start_ids = tf.fill([batch_size], constants.START_OF_SENTENCE_ID)
    beam_size = params.get("beam_width", 1)
    if beam_size > 1:
        # Tile encoder outputs to prepare for beam search.
        encoder_outputs = tfa.seq2seq.tile_batch(encoder_outputs, beam_size)
        encoder_sequence_length = tfa.seq2seq.tile_batch(encoder_sequence_length, beam_size)
        encoder_state = tf.nest.map_structure(
            lambda state: tfa.seq2seq.tile_batch(state, beam_size)
            if state is not None
            else None,
            encoder_state,
        )

    # Dynamically decodes from the encoder outputs.
    initial_state = self.decoder.initial_state(
        memory=encoder_outputs,
        memory_sequence_length=encoder_sequence_length,
        initial_state=encoder_state,
    )
    (
        sampled_ids,
        sampled_length,
        log_probs,
        alignment,
        _,
    ) = self.decoder.dynamic_decode(
        self.labels_inputter,
        start_ids,
        initial_state=initial_state,
        decoding_strategy=decoding.DecodingStrategy.from_params(params, tflite_mode=tflite_run),
        sampler=decoding.Sampler.from_params(params),
        maximum_iterations=params.get("maximum_decoding_length", 250),
        minimum_iterations=params.get("minimum_decoding_length", 0),
        tflite_output_size=params.get("tflite_output_size", 250) if tflite_run else None,
    )
    if tflite_run:
        target_tokens = sampled_ids
    else:
        target_tokens = self.labels_inputter.ids_to_tokens.lookup(tf.cast(sampled_ids, tf.int64))

    # Maybe replace unknown targets by the source tokens with the highest attention weight.
    if params.get("replace_unknown_target", False):
        if alignment is None:
            raise TypeError("replace_unknown_target is not compatible with decoders "
                            "that don't return alignment history")
        if not isinstance(self.features_inputter, inputters.WordEmbedder):
            raise TypeError("replace_unknown_target is only defined when the source "
                            "inputter is a WordEmbedder")
        source_tokens = features if tflite_run else features["tokens"]
        if beam_size > 1:
            source_tokens = tfa.seq2seq.tile_batch(source_tokens, beam_size)
        original_shape = tf.shape(target_tokens)
        if tflite_run:
            target_tokens = tf.squeeze(target_tokens, axis=0)
            output_size = original_shape[-1]
            unknown_token = self.labels_inputter.vocabulary_size - 1
        else:
            target_tokens = tf.reshape(target_tokens, [-1, original_shape[-1]])
            output_size = tf.shape(target_tokens)[1]
            unknown_token = constants.UNKNOWN_TOKEN
        align_shape = misc.shape_list(alignment)
        attention = tf.reshape(
            alignment,
            [align_shape[0] * align_shape[1], align_shape[2], align_shape[3]],
        )
        attention = reducer.align_in_time(attention, output_size)
        replaced_target_tokens = replace_unknown_target(
            target_tokens, source_tokens, attention, unknown_token=unknown_token)
        if tflite_run:
            target_tokens = replaced_target_tokens
        else:
            target_tokens = tf.reshape(replaced_target_tokens, original_shape)

    if tflite_run:
        if beam_size > 1:
            target_tokens = tf.transpose(target_tokens)
            target_tokens = target_tokens[:, :1]
        target_tokens = tf.squeeze(target_tokens)
        return target_tokens

    # Maybe add noise to the predictions.
    decoding_noise = params.get("decoding_noise")
    if decoding_noise:
        target_tokens, sampled_length = _add_noise(
            target_tokens,
            sampled_length,
            decoding_noise,
            params.get("decoding_subword_token", "■"),
            params.get("decoding_subword_token_is_spacer"),
        )
        alignment = None  # Invalidate alignments.

    predictions = {"log_probs": log_probs}
    if self.labels_inputter.tokenizer.in_graph:
        detokenized_text = self.labels_inputter.tokenizer.detokenize(
            tf.reshape(target_tokens, [batch_size * beam_size, -1]),
            sequence_length=tf.reshape(sampled_length, [batch_size * beam_size]),
        )
        predictions["text"] = tf.reshape(detokenized_text, [batch_size, beam_size])
    else:
        predictions["tokens"] = target_tokens
        predictions["length"] = sampled_length
        if alignment is not None:
            predictions["alignment"] = alignment

    # Maybe restrict the number of returned hypotheses based on the user parameter.
    num_hypotheses = params.get("num_hypotheses", 1)
    if num_hypotheses > 0:
        if num_hypotheses > beam_size:
            raise ValueError("n_best cannot be greater than beam_width")
        for key, value in predictions.items():
            predictions[key] = value[:, :num_hypotheses]
    return predictions