def _build(self):
    """Builds the example inputter and, when target embeddings are shared,
    an output projection that reuses the target embedding matrix transposed.
    """
    self.examples_inputter.build()
    if not EmbeddingsSharingLevel.share_target_embeddings(self.share_embeddings):
        return
    labels_inputter = self.labels_inputter
    # Tie the output weights to the target embedding matrix (transposed).
    self.output_layer = layers.Dense(
        labels_inputter.vocabulary_size,
        weight=labels_inputter.embedding,
        transpose=True,
        dtype=labels_inputter.dtype)
    # Build under the current variable scope's name so the projection
    # variables are created in a deterministic scope.
    current_scope = tf.get_variable_scope().name + "/"
    with tf.name_scope(current_scope):
        self.output_layer.build([None, self.decoder.output_size])
def build(self, input_shape):
    """Creates the model layers and initializes the decoder.

    When target embeddings are shared, the decoder output projection reuses
    the target embedding matrix as a transposed ``Dense`` layer; otherwise
    the decoder creates its own output layer.

    Args:
      input_shape: Shape passed to the parent ``build``.
    """
    super(SequenceToSequence, self).build(input_shape)
    target_vocab_size = self.labels_inputter.vocabulary_size
    if EmbeddingsSharingLevel.share_target_embeddings(self.share_embeddings):
        shared_projection = layers.Dense(
            target_vocab_size,
            weight=self.labels_inputter.embedding,
            transpose=True)
    else:
        shared_projection = None
    self.decoder.initialize(
        vocab_size=target_vocab_size,
        output_layer=shared_projection)
def _build(self):
    """Builds the example inputter and initializes the decoder, optionally
    tying the output projection to the input embedding matrix."""
    self.examples_inputter.build()
    inputter = self.examples_inputter
    vocab_size = inputter.vocabulary_size
    # Reuse the (transposed) embedding matrix as output weights when enabled.
    output_layer = (
        layers.Dense(
            vocab_size,
            weight=inputter.embedding,
            transpose=True,
            dtype=inputter.dtype)
        if self.reuse_embedding
        else None)
    self.decoder.initialize(vocab_size=vocab_size, output_layer=output_layer)
def build(self, input_shape):
    """Creates the model layers and initializes the decoder.

    When ``reuse_embedding`` is set, the decoder output projection reuses
    the input embedding matrix as a transposed ``Dense`` layer.

    Args:
      input_shape: Shape passed to the parent ``build``.
    """
    super(LanguageModel, self).build(input_shape)
    vocab_size = self.examples_inputter.vocabulary_size
    if self.reuse_embedding:
        projection = layers.Dense(
            vocab_size,
            weight=self.examples_inputter.embedding,
            transpose=True)
    else:
        projection = None
    self.decoder.initialize(vocab_size=vocab_size, output_layer=projection)
def _call(self, features, labels, params, mode):
    """Builds the TF1 graph for training, evaluation, and/or inference.

    Args:
      features: Source features dictionary (must contain "tokens" when
        ``replace_unknown_target`` is enabled).
      labels: Target labels dictionary, or ``None`` at inference.
      params: Dictionary of run parameters (beam_width, maximum_iterations,
        length_penalty, sampling_topk, etc.).
      mode: A ``tf.estimator.ModeKeys`` value.

    Returns:
      A ``(outputs, predictions)`` tuple: ``outputs`` holds decoder logits
      (and attention when alignments are supervised) or ``None`` when
      ``labels`` is ``None``; ``predictions`` holds decoded tokens/lengths/
      log-probs or ``None`` in TRAIN mode.

    Raises:
      TypeError: If ``replace_unknown_target`` is enabled but alignment
        history or a word-level source inputter is unavailable.
    """
    training = mode == tf.estimator.ModeKeys.TRAIN
    self.examples_inputter.build()
    features_length = self.features_inputter.get_length(features)
    source_inputs = self.features_inputter.make_inputs(features, training=training)
    # Encode the source sequence.
    with tf.variable_scope("encoder"):
        encoder_outputs, encoder_state, encoder_sequence_length = self.encoder.encode(
            source_inputs,
            sequence_length=features_length,
            mode=mode)

    target_vocab_size = self.labels_inputter.vocabulary_size
    target_dtype = self.labels_inputter.dtype
    output_layer = None
    if EmbeddingsSharingLevel.share_target_embeddings(self.share_embeddings):
        # Output projection tied to the (transposed) target embedding matrix.
        output_layer = layers.Dense(
            target_vocab_size,
            weight=self.labels_inputter.embedding,
            transpose=True,
            dtype=target_dtype)
        # Force the layer variables under the current variable scope's name.
        with tf.name_scope(tf.get_variable_scope().name + "/"):
            output_layer.build([None, self.decoder.output_size])

    if labels is not None:
        # Decode conditioned on the gold target inputs, with optional
        # scheduled sampling during training.
        target_inputs = self.labels_inputter.make_inputs(labels, training=training)
        with tf.variable_scope("decoder"):
            sampling_probability = None
            if mode == tf.estimator.ModeKeys.TRAIN:
                sampling_probability = get_sampling_probability(
                    tf.train.get_or_create_global_step(),
                    read_probability=params.get(
                        "scheduled_sampling_read_probability"),
                    schedule_type=params.get("scheduled_sampling_type"),
                    k=params.get("scheduled_sampling_k"))
            logits, _, _, attention = self.decoder.decode(
                target_inputs,
                self.labels_inputter.get_length(labels),
                vocab_size=target_vocab_size,
                initial_state=encoder_state,
                sampling_probability=sampling_probability,
                embedding=self.labels_inputter.embedding,
                output_layer=output_layer,
                mode=mode,
                memory=encoder_outputs,
                memory_sequence_length=encoder_sequence_length,
                return_alignment_history=True)
        # Also expose attention when alignment supervision is provided.
        if "alignment" in labels:
            outputs = {"logits": logits, "attention": attention}
        else:
            outputs = logits
    else:
        outputs = None

    if mode != tf.estimator.ModeKeys.TRAIN:
        # Autoregressive decoding for evaluation/prediction. Variables are
        # reused when the teacher-forced pass above already created them.
        with tf.variable_scope("decoder", reuse=labels is not None):
            batch_size = tf.shape(
                tf.contrib.framework.nest.flatten(encoder_outputs)[0])[0]
            beam_width = params.get("beam_width", 1)
            maximum_iterations = params.get("maximum_iterations", 250)
            minimum_length = params.get("minimum_decoding_length", 0)
            sample_from = params.get("sampling_topk", 1)
            start_tokens = tf.fill([batch_size], constants.START_OF_SENTENCE_ID)
            end_token = constants.END_OF_SENTENCE_ID
            if beam_width <= 1:
                # Greedy / sampled decoding.
                sampled_ids, _, sampled_length, log_probs, alignment = self.decoder.dynamic_decode(
                    self.labels_inputter.embedding,
                    start_tokens,
                    end_token,
                    vocab_size=target_vocab_size,
                    initial_state=encoder_state,
                    output_layer=output_layer,
                    maximum_iterations=maximum_iterations,
                    minimum_length=minimum_length,
                    mode=mode,
                    memory=encoder_outputs,
                    memory_sequence_length=encoder_sequence_length,
                    dtype=target_dtype,
                    return_alignment_history=True,
                    sample_from=sample_from)
            else:
                # Beam search decoding.
                length_penalty = params.get("length_penalty", 0)
                sampled_ids, _, sampled_length, log_probs, alignment = (
                    self.decoder.dynamic_decode_and_search(
                        self.labels_inputter.embedding,
                        start_tokens,
                        end_token,
                        vocab_size=target_vocab_size,
                        initial_state=encoder_state,
                        output_layer=output_layer,
                        beam_width=beam_width,
                        length_penalty=length_penalty,
                        maximum_iterations=maximum_iterations,
                        minimum_length=minimum_length,
                        mode=mode,
                        memory=encoder_outputs,
                        memory_sequence_length=encoder_sequence_length,
                        dtype=target_dtype,
                        return_alignment_history=True,
                        sample_from=sample_from))

        # Map sampled ids back to token strings.
        target_vocab_rev = self.labels_inputter.vocabulary_lookup_reverse()
        target_tokens = target_vocab_rev.lookup(
            tf.cast(sampled_ids, tf.int64))

        if params.get("replace_unknown_target", False):
            # Replace <unk> target tokens with the most attended source token.
            if alignment is None:
                raise TypeError(
                    "replace_unknown_target is not compatible with decoders "
                    "that don't return alignment history")
            if not isinstance(self.features_inputter, inputters.WordEmbedder):
                raise TypeError(
                    "replace_unknown_target is only defined when the source "
                    "inputter is a WordEmbedder")
            source_tokens = features["tokens"]
            if beam_width > 1:
                source_tokens = tf.contrib.seq2seq.tile_batch(
                    source_tokens, multiplier=beam_width)
            # Merge batch and beam dimensions.
            original_shape = tf.shape(target_tokens)
            target_tokens = tf.reshape(target_tokens, [-1, original_shape[-1]])
            attention = tf.reshape(
                alignment, [-1, tf.shape(alignment)[2], tf.shape(alignment)[3]])
            replaced_target_tokens = replace_unknown_target(
                target_tokens, source_tokens, attention)
            target_tokens = tf.reshape(replaced_target_tokens, original_shape)

        predictions = {
            "tokens": target_tokens,
            "length": sampled_length,
            "log_probs": log_probs
        }
        if alignment is not None:
            predictions["alignment"] = alignment
    else:
        predictions = None

    return outputs, predictions
def _call(self, features, labels, params, mode):
    """Builds the TF1 graph for language model training/evaluation or generation.

    Args:
      features: Dictionary with "ids" and "length" entries.
      labels: Unused here; presence/absence does not alter the graph.
      params: Dictionary of run parameters (maximum_iterations,
        minimum_decoding_length, sampling_topk, sampling_temperature).
      mode: A ``tf.estimator.ModeKeys`` value.

    Returns:
      A ``(outputs, predictions)`` tuple: ``outputs`` holds the logits in
      TRAIN/EVAL modes, ``predictions`` holds generated tokens/lengths in
      PREDICT mode; the other element is ``None``.
    """
    training = mode == tf.estimator.ModeKeys.TRAIN
    outputs, predictions = None, None

    # Initialize input and output layers.
    self.examples_inputter.build()
    vocab_size = self.examples_inputter.vocabulary_size
    output_layer = None
    if self.reuse_embedding:
        # Tie the output projection to the (transposed) embedding matrix.
        output_layer = layers.Dense(
            vocab_size,
            weight=self.examples_inputter.embedding,
            transpose=True,
            dtype=self.examples_inputter.dtype)
    self.decoder.initialize(vocab_size=vocab_size, output_layer=output_layer)

    ids, length = features["ids"], features["length"]
    if mode != tf.estimator.ModeKeys.PREDICT:
        # For training and evaluation, forward the full sequence.
        logits, _ = self._decode(ids, length, training=training)
        outputs = dict(logits=logits)
    else:
        # Generation requires all contexts in the batch to be the same length
        # (the split below takes a fixed number of context steps per example).
        assert_fixed_length = tf.debugging.Assert(
            tf.reduce_all(tf.equal(length, tf.reduce_max(length))),
            ["Language model does not support variable length contexts during "
             "generation, consider setting batch_size or bucket_width to 1"])

        # Run decoder on the context, if any.
        with tf.control_dependencies([assert_fixed_length]):
            # Split off the last token: it becomes the first generation input.
            context_ids, start_ids = tf.split(
                ids, [tf.shape(ids)[1] - 1, 1], axis=1)
            context_length = length - 1
            batch_size = tf.shape(context_length)[0]
            # With an empty context, start from the decoder's initial state;
            # otherwise run the decoder over the context to get its state.
            state = tf.cond(
                tf.equal(tf.reduce_sum(context_length), 0),
                true_fn=lambda: self.decoder.get_initial_state(
                    batch_size=batch_size, dtype=self.dtype),
                false_fn=lambda: self._decode(
                    context_ids, context_length)[1],
                name=self.name + "/")  # Force the name scope.

        # Iteratively decode from the last decoder state.
        sampled_ids, sampled_length, _ = decoder_util.greedy_decode(
            self._decode,
            tf.squeeze(start_ids, 1),
            constants.END_OF_SENTENCE_ID,
            decode_length=params.get("maximum_iterations", 250),
            state=state,
            min_decode_length=params.get("minimum_decoding_length", 0),
            last_step_as_input=True,
            sample_from=params.get("sampling_topk", 1),
            sample_temperature=params.get("sampling_temperature", 1))

        # Build the full prediction: context ids followed by sampled ids,
        # mapped back to token strings.
        full_ids = tf.concat([ids, sampled_ids], 1)
        full_length = length + sampled_length
        vocab_rev = self.examples_inputter.vocabulary_lookup_reverse()
        tokens = vocab_rev.lookup(tf.cast(full_ids, tf.int64))
        predictions = dict(tokens=tokens, length=full_length)

    return outputs, predictions