def _fast_decode(self,
                 features,
                 decode_length,
                 beam_size=1,
                 top_beams=1,
                 alpha=1.0):
  """Fast decoding with context features concatenated to the encoder input."""
  if self._num_datashards != 1:
    raise NotImplementedError("Fast decoding only supports a single shard.")
  dp = self._data_parallelism
  hparams = self._hparams
  target_modality = self._problem_hparams.modality["targets"]
  if "targets_segmentation" in features:
    raise NotImplementedError(
        "Decoding not supported on packed datasets. If you want to decode "
        "from a dataset, use the non-packed version of the dataset when "
        "decoding.")
  if self.has_input:
    inputs = features["inputs"]
    if target_modality.is_class_modality:
      decode_length = 1
    else:
      decode_length = (common_layers.shape_list(inputs)[1] +
                       features.get("decode_length", decode_length))

    contexts = {}
    for feature_name in features:
      if 'context' in feature_name and 'raw' not in feature_name:
        contexts[feature_name] = features[feature_name]

    inputs = tf.expand_dims(inputs, axis=1)
    if len(inputs.shape) < 5:
      inputs = tf.expand_dims(inputs, axis=4)
    s = common_layers.shape_list(inputs)
    batch_size = s[0]
    inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
    # _shard_features called to ensure that the variable names match
    inputs = self._shard_features({"inputs": inputs})["inputs"]
    input_modality = self._problem_hparams.modality["inputs"]

    context_modality = {}
    for context_name in contexts:
      if context_name in self._problem_hparams.modality:
        context_modality[context_name] = self._problem_hparams.modality[
            context_name]
      else:
        context_modality[context_name] = input_modality

    with tf.variable_scope(input_modality.name, reuse=tf.AUTO_REUSE):
      inputs = input_modality.bottom_sharded(inputs, dp)
    for feature_name in contexts:
      with tf.variable_scope(context_modality[feature_name].name,
                             reuse=tf.AUTO_REUSE):
        contexts[feature_name] = context_modality[feature_name].bottom_sharded(
            contexts[feature_name], dp)

    contexts_list = [contexts[feature_name][0] for feature_name in contexts]
    contexts = tf.concat(contexts_list, axis=1)
    inputs = [tf.concat([contexts, inputs[0]], axis=1)]

    with tf.variable_scope("body"):
      encoder_output, encoder_decoder_attention_bias = dp(
          self.encode,
          inputs,
          features["target_space_id"],
          hparams,
          features=features)
    encoder_output = encoder_output[0]
    encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
    partial_targets = None
  else:
    # The problem has no inputs.
    encoder_output = None
    encoder_decoder_attention_bias = None

    # Prepare partial targets.
    # In either features["inputs"] or features["targets"].
    # We force the outputs to begin with these sequences.
    partial_targets = features.get("inputs")
    if partial_targets is None:
      partial_targets = features["targets"]
    assert partial_targets is not None
    partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
    partial_targets = tf.to_int64(partial_targets)
    partial_targets_shape = common_layers.shape_list(partial_targets)
    partial_targets_length = partial_targets_shape[1]
    decode_length = (partial_targets_length +
                     features.get("decode_length", decode_length))
    batch_size = partial_targets_shape[0]

  if hparams.pos == "timing":
    positional_encoding = common_attention.get_timing_signal_1d(
        decode_length + 1, hparams.hidden_size)
  elif hparams.pos == "emb":
    positional_encoding = common_attention.add_positional_embedding(
        tf.zeros([1, decode_length, hparams.hidden_size]),
        hparams.max_length, "body/targets_positional_embedding", None)
  else:
    positional_encoding = None

  def preprocess_targets(targets, i):
    """Performs preprocessing steps on the targets to prepare for the decoder.

    This includes:
      - Embedding the ids.
      - Flattening to 3D tensor.
      - Optionally adding timing signals.

    Args:
      targets: inputs ids to the decoder. [batch_size, 1]
      i: scalar, Step number of the decoding loop.

    Returns:
      Processed targets [batch_size, 1, hidden_dim]
    """
    # _shard_features called to ensure that the variable names match
    targets = self._shard_features({"targets": targets})["targets"]
    with tf.variable_scope(target_modality.name):
      targets = target_modality.targets_bottom_sharded(targets, dp)[0]
    targets = common_layers.flatten4d3d(targets)

    targets = tf.cond(
        tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

    if positional_encoding is not None:
      targets += positional_encoding[:, i:i + 1]
    return targets

  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(decode_length))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        decode_length)

  def symbols_to_logits_fn(ids, i, cache):
    """Go from ids to logits for next symbol."""
    ids = ids[:, -1:]
    targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
    targets = preprocess_targets(targets, i)

    bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

    with tf.variable_scope("body"):
      body_outputs = dp(
          self.decode,
          targets,
          cache.get("encoder_output"),
          cache.get("encoder_decoder_attention_bias"),
          bias,
          hparams,
          cache,
          nonpadding=features_to_nonpadding(features, "targets"))

    with tf.variable_scope(target_modality.name):
      logits = target_modality.top_sharded(body_outputs, None, dp)[0]

    ret = tf.squeeze(logits, axis=[1, 2, 3])
    if partial_targets is not None:
      # If the position is within the given partial targets, we alter the
      # logits to always return those values.
      # A faster approach would be to process the partial targets in one
      # iteration in order to fill the corresponding parts of the cache.
      # This would require broader changes, though.
      vocab_size = tf.shape(ret)[1]

      def forced_logits():
        return tf.one_hot(
            tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0,
            -1e9)

      ret = tf.cond(
          tf.less(i, partial_targets_length), forced_logits, lambda: ret)
    return ret, cache

  ret = fast_decode(
      encoder_output=encoder_output,
      encoder_decoder_attention_bias=encoder_decoder_attention_bias,
      symbols_to_logits_fn=symbols_to_logits_fn,
      hparams=hparams,
      decode_length=decode_length,
      vocab_size=target_modality.top_dimensionality,
      beam_size=beam_size,
      top_beams=top_beams,
      alpha=alpha,
      batch_size=batch_size,
      force_decode_length=self._decode_hparams.force_decode_length)
  if partial_targets is not None:
    if beam_size <= 1 or top_beams <= 1:
      ret["outputs"] = ret["outputs"][:, partial_targets_length:]
    else:
      ret["outputs"] = ret["outputs"][:, :, partial_targets_length:]
  return ret
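# --- Illustrative sketch (not part of the original file) ---
# The symbols_to_logits_fn defined above is consumed by fast_decode. As a
# rough mental model, greedy decoding drives it like the loop below. The
# names `greedy_decode_sketch`, `sos_id`, and the empty cache are assumptions
# for illustration only; the real fast_decode also threads per-layer
# key/value caches and supports beam search.
import tensorflow as tf


def greedy_decode_sketch(symbols_to_logits_fn, batch_size, decode_length,
                         sos_id=0):
  """Greedy loop: repeatedly ask for next-symbol logits and take the argmax."""
  ids = tf.fill([batch_size, 1], sos_id)  # GO id; its embedding is zeroed at i == 0
  cache = {}
  for i in range(decode_length):  # decode_length assumed to be a Python int here
    logits, cache = symbols_to_logits_fn(ids, i, cache)          # [batch, vocab]
    next_id = tf.argmax(logits, axis=-1, output_type=tf.int32)   # [batch]
    ids = tf.concat([ids, tf.expand_dims(next_id, axis=1)], axis=1)
  return ids[:, 1:]  # drop the initial GO id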
def _fast_decode(self,
                 features,
                 decode_length,
                 beam_size=1,
                 top_beams=1,
                 alpha=1.0):
  """Fast decoding with context features passed to the encoder separately."""
  # dp = self._data_parallelism
  hparams = self._hparams
  target_modality = self._problem_hparams.modality["targets"]

  inputs = features["inputs"]
  decode_length = (common_layers.shape_list(inputs)[1] +
                   features.get("decode_length", decode_length))

  # inputs = tf.expand_dims(inputs, axis=1)
  # if len(inputs.shape) < 5:
  #   inputs = tf.expand_dims(inputs, axis=4)
  s = common_layers.shape_list(inputs)
  batch_size = s[0]
  # inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
  # _shard_features called to ensure that the variable names match
  # inputs = self._shard_features({"inputs": inputs})["inputs"]
  input_modality = self._problem_hparams.modality["inputs"]

  context_modality = {}
  contexts = {}
  for feature_name in features:
    if 'context' in feature_name and 'raw' not in feature_name:
      contexts[feature_name] = features[feature_name]
  for context_name in contexts:
    if context_name in self._problem_hparams.modality:
      context_modality[context_name] = self._problem_hparams.modality[
          context_name]
    else:
      context_modality[context_name] = input_modality

  with tf.variable_scope(input_modality.name, reuse=tf.AUTO_REUSE):
    inputs = input_modality.bottom(inputs)
    for context_name in contexts:
      contexts[context_name] = context_modality[context_name].bottom(
          contexts[context_name])

  with tf.variable_scope("body", reuse=tf.AUTO_REUSE):
    encoder_output, encoder_decoder_attention_bias = self.encode(
        inputs,
        contexts,
        features["target_space_id"],
        hparams,
        features=features)
  # encoder_output = encoder_output[0]
  # encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
  partial_targets = None

  if hparams.pos == "timing":
    positional_encoding = common_attention.get_timing_signal_1d(
        decode_length + 1, hparams.hidden_size)
  elif hparams.pos == "emb":
    positional_encoding = common_attention.add_positional_embedding(
        tf.zeros([1, decode_length + 1, hparams.hidden_size]),
        hparams.max_length, "targets_positional_embedding", None)
  else:
    positional_encoding = None

  def preprocess_targets(targets, i):
    """Performs preprocessing steps on the targets to prepare for the decoder.

    This includes:
      - Embedding the ids.
      - Flattening to 3D tensor.
      - Optionally adding timing signals.

    Args:
      targets: inputs ids to the decoder. [batch_size, 1]
      i: scalar, Step number of the decoding loop.

    Returns:
      Processed targets [batch_size, 1, hidden_dim]
    """
    # _shard_features called to ensure that the variable names match
    # targets = self._shard_features({"targets": targets})["targets"]
    with tf.variable_scope(target_modality.name):
      targets = target_modality.targets_bottom(targets)
    targets = common_layers.flatten4d3d(targets)

    targets = tf.cond(
        tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

    if positional_encoding is not None:
      targets += positional_encoding[:, i:i + 1]
    return targets

  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(decode_length))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        decode_length)

  def symbols_to_logits_fn(ids, i, cache):
    """Go from ids to logits for next symbol."""
    ids = ids[:, -1:]
    targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
    targets = preprocess_targets(targets, i)

    bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

    with tf.variable_scope("body"):
      body_outputs = self.decode(
          targets,
          cache.get("encoder_output"),
          cache.get("encoder_decoder_attention_bias"),
          bias,
          hparams,
          cache,
          nonpadding=features_to_nonpadding(features, "targets"))

    with tf.variable_scope(target_modality.name):
      logits = target_modality.top(body_outputs, None)

    ret = tf.squeeze(logits, axis=[1, 2, 3])
    return ret, cache

  ret = fast_decode(
      encoder_output=encoder_output,
      encoder_decoder_attention_bias=encoder_decoder_attention_bias,
      symbols_to_logits_fn=symbols_to_logits_fn,
      hparams=hparams,
      decode_length=decode_length,
      vocab_size=target_modality.top_dimensionality,
      beam_size=beam_size,
      top_beams=top_beams,
      alpha=alpha,
      batch_size=batch_size,
      force_decode_length=self._decode_hparams.force_decode_length)
  return ret
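# --- Illustrative sketch (not part of the original file) ---
# attention_bias_lower_triangle(decode_length) produces a [1, 1, L, L] bias
# that is 0 on and below the diagonal and a large negative number above it.
# At decode step i, symbols_to_logits_fn only needs row i restricted to the
# first i + 1 columns, which is what the slice
# decoder_self_attention_bias[:, :, i:i + 1, :i + 1] extracts. A numpy
# equivalent (an assumption, not the tensor2tensor implementation):
import numpy as np


def lower_triangle_bias_sketch(length, neg=-1e9):
  band = np.tril(np.ones([length, length], dtype=np.float32))
  return ((1.0 - band) * neg)[None, None, :, :]   # [1, 1, length, length]


_bias = lower_triangle_bias_sketch(5)
_step_bias = _bias[:, :, 2:3, :3]   # step i = 2 -> shape (1, 1, 1, 3), all zeros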
def _fast_decode(self,
                 features,
                 decode_length,
                 beam_size=1,
                 top_beams=1,
                 alpha=1.0):
  """Fast decoding.

  Implements both greedy and beam search decoding, uses beam search iff
  beam_size > 1, otherwise beam search related arguments are ignored.

  Args:
    features: a map of string to model features.
    decode_length: an integer. How many additional timesteps to decode.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. larger the alpha, stronger
      the preference for longer translations.

  Returns:
    A dict of decoding results {
        "outputs": integer `Tensor` of decoded ids of shape
            [batch_size, <= decode_length] if beam_size == 1 or
            [batch_size, top_beams, <= decode_length]
        "scores": decoding log probs from the beam search,
            None if using greedy decoding (beam_size=1)
    }

  Raises:
    NotImplementedError: If there are multiple data shards.
  """
  if self._num_datashards != 1:
    raise NotImplementedError("Fast decoding only supports a single shard.")
  dp = self._data_parallelism
  hparams = self._hparams
  target_modality = self._problem_hparams.target_modality

  story = features[babi_qa.FeatureNames.STORY]
  question = features[babi_qa.FeatureNames.QUESTION]

  if target_modality.is_class_modality:
    decode_length = 1
  else:
    decode_length = (common_layers.shape_list(story)[1] +
                     common_layers.shape_list(question)[1] + decode_length)

  story = tf.expand_dims(story, axis=1)
  question = tf.expand_dims(question, axis=1)

  if len(story.shape) < 5:
    story = tf.expand_dims(story, axis=4)
  if len(question.shape) < 5:
    question = tf.expand_dims(question, axis=4)

  s = common_layers.shape_list(story)
  batch_size = s[0]
  story = tf.reshape(story, [s[0] * s[1], s[2], s[3], s[4]])

  s = common_layers.shape_list(question)
  batch_size = s[0]
  question = tf.reshape(question, [s[0] * s[1], s[2], s[3], s[4]])

  # _shard_features called to ensure that the variable names match
  story = self._shard_features(
      {babi_qa.FeatureNames.STORY: story})[babi_qa.FeatureNames.STORY]
  question = self._shard_features(
      {babi_qa.FeatureNames.QUESTION: question})[babi_qa.FeatureNames.QUESTION]

  story_modality = self._problem_hparams.input_modality[
      babi_qa.FeatureNames.STORY]
  question_modality = self._problem_hparams.input_modality[
      babi_qa.FeatureNames.QUESTION]

  with tf.variable_scope(story_modality.name):
    story = story_modality.bottom_sharded(story, dp)

  with tf.variable_scope(
      question_modality.name,
      reuse=(story_modality.name == question_modality.name)):
    question = question_modality.bottom_sharded(question, dp)

  with tf.variable_scope("body"):
    if target_modality.is_class_modality:
      encoder_output = dp(self.encode, story, question,
                          features["target_space_id"], hparams)
    else:
      encoder_output, encoder_decoder_attention_bias = dp(
          self.encode, story, question, features["target_space_id"], hparams,
          features=features)
      encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
    encoder_output = encoder_output[0]

  if hparams.pos == "timing":
    timing_signal = common_attention.get_timing_signal_1d(
        decode_length + 1, hparams.hidden_size)

  def preprocess_targets(targets, i):
    """Performs preprocessing steps on the targets to prepare for the decoder.

    This includes:
      - Embedding the ids.
      - Flattening to 3D tensor.
      - Optionally adding timing signals.

    Args:
      targets: inputs ids to the decoder. [batch_size, 1]
      i: scalar, Step number of the decoding loop.

    Returns:
      Processed targets [batch_size, 1, hidden_dim]
    """
    # _shard_features called to ensure that the variable names match
    targets = self._shard_features({"targets": targets})["targets"]
    with tf.variable_scope(target_modality.name):
      targets = target_modality.targets_bottom_sharded(targets, dp)[0]
    targets = common_layers.flatten4d3d(targets)

    targets = tf.cond(
        tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

    if hparams.pos == "timing":
      targets += timing_signal[:, i:i + 1]
    return targets

  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(decode_length))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        decode_length)

  def symbols_to_logits_fn(ids, i, cache):
    """Go from ids to logits for next symbol."""
    ids = ids[:, -1:]
    targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
    targets = preprocess_targets(targets, i)

    bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

    with tf.variable_scope("body"):
      body_outputs = dp(
          self.decode,
          targets,
          cache.get("encoder_output"),
          cache.get("encoder_decoder_attention_bias"),
          bias,
          hparams,
          cache,
          nonpadding=features_to_nonpadding(features, "targets"))

    with tf.variable_scope(target_modality.name):
      logits = target_modality.top_sharded(body_outputs, None, dp)[0]

    ret = tf.squeeze(logits, axis=[1, 2, 3])
    return ret, cache

  def labels_to_logits_fn(unused_ids, unused_i, cache):
    """Go from labels to logits."""
    with tf.variable_scope("body"):
      body_outputs = dp(tf.expand_dims, cache.get("encoder_output"), 2)

    with tf.variable_scope(target_modality.name):
      logits = target_modality.top_sharded(body_outputs, None, dp)[0]

    ret = tf.squeeze(logits, axis=[1, 2, 3])
    return ret, cache

  if target_modality.is_class_modality:
    ret = transformer.fast_decode(
        encoder_output=encoder_output,
        encoder_decoder_attention_bias=None,
        symbols_to_logits_fn=labels_to_logits_fn,
        hparams=hparams,
        decode_length=decode_length,
        vocab_size=target_modality.top_dimensionality,
        beam_size=beam_size,
        top_beams=top_beams,
        alpha=alpha,
        batch_size=batch_size)
  else:
    ret = transformer.fast_decode(
        encoder_output=encoder_output,
        encoder_decoder_attention_bias=encoder_decoder_attention_bias,
        symbols_to_logits_fn=symbols_to_logits_fn,
        hparams=hparams,
        decode_length=decode_length,
        vocab_size=target_modality.top_dimensionality,
        beam_size=beam_size,
        top_beams=top_beams,
        alpha=alpha,
        batch_size=batch_size)
  return ret
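# --- Illustrative sketch (not part of the original file) ---
# For a class-label target, labels_to_logits_fn above skips the autoregressive
# decoder entirely: the cached encoder output is expanded and mapped straight
# to class logits through target_modality.top. A simplified stand-in for that
# idea (the pooling and dense layer are assumptions; the real class-label
# modality defines its own top transformation):
import tensorflow as tf


def class_logits_sketch(encoder_output, num_classes):
  """encoder_output: [batch, length, hidden] -> logits [batch, num_classes]."""
  pooled = tf.reduce_mean(encoder_output, axis=1)   # mean-pool over time (assumption)
  return tf.layers.dense(pooled, num_classes, name="class_top_sketch")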
def _fast_decode(self,
                 features,
                 decode_length,
                 beam_size=1,
                 top_beams=1,
                 alpha=1.0):
  """Fast decoding.

  Implements both greedy and beam search decoding, uses beam search iff
  beam_size > 1, otherwise beam search related arguments are ignored.

  Args:
    features: a map of string to model features.
    decode_length: an integer. How many additional timesteps to decode.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. larger the alpha, stronger
      the preference for longer translations.

  Returns:
    A dict of decoding results {
        "outputs": integer `Tensor` of decoded ids of shape
            [batch_size, <= decode_length] if beam_size == 1 or
            [batch_size, top_beams, <= decode_length]
        "scores": decoding log probs from the beam search,
            None if using greedy decoding (beam_size=1)
    }

  Raises:
    NotImplementedError: If there are multiple data shards.
  """
  if self._num_datashards != 1:
    raise NotImplementedError("Fast decoding only supports a single shard.")
  dp = self._data_parallelism
  hparams = self._hparams
  target_modality = self._problem_hparams.modality["targets"]
  target_vocab_size = self._problem_hparams.vocab_size["targets"]
  if target_vocab_size is not None and hasattr(hparams, "vocab_divisor"):
    target_vocab_size += (-target_vocab_size) % hparams.vocab_divisor

  if "targets_segmentation" in features:
    raise NotImplementedError(
        "Decoding not supported on packed datasets. If you want to decode "
        "from a dataset, use the non-packed version of the dataset when "
        "decoding.")
  if self.has_input:
    inputs = features["inputs"]
    if target_modality == modalities.ModalityType.CLASS_LABEL:
      decode_length = 1
    else:
      decode_length = (
          common_layers.shape_list(inputs)[1] + features.get(
              "decode_length", decode_length))

    # TODO(llion): Clean up this reshaping logic.
    inputs = tf.expand_dims(inputs, axis=1)
    if len(inputs.shape) < 5:
      inputs = tf.expand_dims(inputs, axis=4)
    s = common_layers.shape_list(inputs)
    batch_size = s[0]
    inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
    # _shard_features called to ensure that the variable names match
    inputs = self._shard_features({"inputs": inputs})["inputs"]
    input_modality = self._problem_hparams.modality["inputs"]
    input_vocab_size = self._problem_hparams.vocab_size["inputs"]
    if input_vocab_size is not None and hasattr(hparams, "vocab_divisor"):
      input_vocab_size += (-input_vocab_size) % hparams.vocab_divisor
    modality_name = hparams.name.get(
        "inputs",
        modalities.get_name(input_modality))(hparams, input_vocab_size)
    with tf.variable_scope(modality_name):
      bottom = hparams.bottom.get("inputs",
                                  modalities.get_bottom(input_modality))
      inputs = dp(bottom, inputs, hparams, input_vocab_size)
    with tf.variable_scope("body"):
      encoder_output, encoder_decoder_attention_bias = dp(
          self.encode,
          inputs,
          features["target_space_id"],
          hparams,
          features=features)
    encoder_output = encoder_output[0]
    encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
    if 'partial_targets' in features:
      partial_targets = features['partial_targets']
    else:
      partial_targets = None
  else:
    # The problem has no inputs.
    encoder_output = None
    encoder_decoder_attention_bias = None

    # Prepare partial targets.
    # In either features["inputs"] or features["targets"].
    # We force the outputs to begin with these sequences.
    partial_targets = features.get("inputs")
    if partial_targets is None:
      partial_targets = features["targets"]
    assert partial_targets is not None

  if partial_targets is not None:
    partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
    partial_targets = tf.to_int64(partial_targets)
    partial_targets_shape = common_layers.shape_list(partial_targets)
    partial_targets_length = partial_targets_shape[1]
    decode_length = (
        partial_targets_length + features.get("decode_length", decode_length))
    batch_size = partial_targets_shape[0]

  if hparams.pos == "timing":
    positional_encoding = common_attention.get_timing_signal_1d(
        decode_length + 1, hparams.hidden_size)
  elif hparams.pos == "emb":
    positional_encoding = common_attention.add_positional_embedding(
        tf.zeros([1, decode_length, hparams.hidden_size]), hparams.max_length,
        "body/targets_positional_embedding", None)
  else:
    positional_encoding = None

  def preprocess_targets(targets, i):
    """Performs preprocessing steps on the targets to prepare for the decoder.

    This includes:
      - Embedding the ids.
      - Flattening to 3D tensor.
      - Optionally adding timing signals.

    Args:
      targets: inputs ids to the decoder. [batch_size, 1]
      i: scalar, Step number of the decoding loop.

    Returns:
      Processed targets [batch_size, 1, hidden_dim]
    """
    # _shard_features called to ensure that the variable names match
    targets = self._shard_features({"targets": targets})["targets"]
    modality_name = hparams.name.get(
        "targets",
        modalities.get_name(target_modality))(hparams, target_vocab_size)
    with tf.variable_scope(modality_name):
      bottom = hparams.bottom.get(
          "targets", modalities.get_targets_bottom(target_modality))
      targets = dp(bottom, targets, hparams, target_vocab_size)[0]
    targets = common_layers.flatten4d3d(targets)

    # GO embeddings are all zero, this is because transformer_prepare_decoder
    # shifts the targets along by one for the input which pads with zeros.
    # If the modality already maps GO to the zero embeddings this is not
    # needed.
    targets = tf.cond(
        tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

    if positional_encoding is not None:
      targets += positional_encoding[:, i:i + 1]
    return targets

  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(decode_length))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        decode_length)

  # Create tensors for encoder-decoder attention history
  att_cache = {"attention_history": {}}
  num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
  att_batch_size, enc_seq_length = common_layers.shape_list(
      encoder_output)[0:2]
  for layer in range(num_layers):
    att_cache["attention_history"]["layer_%d" % layer] = tf.zeros(
        [att_batch_size, hparams.num_heads, 0, enc_seq_length])
  att_cache["body_outputs"] = tf.zeros(
      [att_batch_size, 1, 0, hparams.hidden_size])

  def update_decoder_attention_history(cache):
    for k in filter(
        lambda x: "decoder" in x and "self" not in x and "logits" not in x,
        self.attention_weights.keys()):
      m = re.search(r"(layer_\d+)", k)
      if m is None:
        continue
      cache["attention_history"][m[0]] = tf.concat(
          [cache["attention_history"][m[0]], self.attention_weights[k]],
          axis=2)

  def symbols_to_logits_fn(ids, i, cache):
    """Go from ids to logits for next symbol."""
    ids = ids[:, -1:]
    targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
    targets = preprocess_targets(targets, i)

    bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

    with tf.variable_scope("body"):
      body_outputs = dp(
          self.decode,
          targets,
          cache.get("encoder_output"),
          cache.get("encoder_decoder_attention_bias"),
          bias,
          hparams,
          cache,
          nonpadding=features_to_nonpadding(features, "targets"))

    update_decoder_attention_history(cache)
    cache["body_outputs"] = tf.concat(
        [cache["body_outputs"], body_outputs[0]], axis=2)

    modality_name = hparams.name.get(
        "targets",
        modalities.get_name(target_modality))(hparams, target_vocab_size)
    with tf.variable_scope(modality_name):
      top = hparams.top.get("targets", modalities.get_top(target_modality))
      logits = dp(top, body_outputs, None, hparams, target_vocab_size)[0]

    ret = tf.squeeze(logits, axis=[1, 2, 3])
    if partial_targets is not None:
      # If the position is within the given partial targets, we alter the
      # logits to always return those values.
      # A faster approach would be to process the partial targets in one
      # iteration in order to fill the corresponding parts of the cache.
      # This would require broader changes, though.
      vocab_size = tf.shape(ret)[1]

      def forced_logits():
        return tf.one_hot(
            tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0,
            -1e9)

      ret = tf.cond(
          tf.less(i, partial_targets_length), forced_logits, lambda: ret)
    return ret, cache

  ret = fast_decode(
      encoder_output=encoder_output,
      encoder_decoder_attention_bias=encoder_decoder_attention_bias,
      symbols_to_logits_fn=symbols_to_logits_fn,
      hparams=hparams,
      decode_length=decode_length,
      vocab_size=target_vocab_size,
      beam_size=beam_size,
      top_beams=top_beams,
      alpha=alpha,
      batch_size=batch_size,
      force_decode_length=self._decode_hparams.force_decode_length,
      cache=att_cache)
  if partial_targets is not None:
    if beam_size <= 1 or top_beams <= 1:
      ret["outputs"] = ret["outputs"][:, partial_targets_length:]
    else:
      ret["outputs"] = ret["outputs"][:, :, partial_targets_length:]
  return ret
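# --- Illustrative sketch (not part of the original file) ---
# The attention-history cache above starts with a zero-length time axis and
# grows by one slice per decode step via tf.concat, which is what
# update_decoder_attention_history does with self.attention_weights. The
# shapes and random stand-in weights below are assumptions for illustration:
import tensorflow as tf

_batch, _heads, _enc_len = 2, 8, 12
_history = tf.zeros([_batch, _heads, 0, _enc_len])          # no steps recorded yet
for _ in range(3):
  _step = tf.random_uniform([_batch, _heads, 1, _enc_len])  # stand-in attention weights
  _history = tf.concat([_history, _step], axis=2)
# _history now has static shape [2, 8, 3, 12]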
def _fast_decode(self,
                 features,
                 decode_length,
                 beam_size=1,
                 top_beams=1,
                 alpha=1.0):
  """Fast decoding.

  Overrides tensor2tensor.models.transformer.Transformer._fast_decode
  to let symbols_to_logits_fn return multiple things.

  Implements both greedy and beam search decoding, uses beam search iff
  beam_size > 1, otherwise beam search related arguments are ignored.

  Args:
    features: a map of string to model features.
    decode_length: an integer. How many additional timesteps to decode.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. larger the alpha, stronger
      the preference for longer translations.

  Returns:
    A dict of decoding results {
        "body_output": tensor of size
            [batch_size, <= decode_length, hidden_size]
            (or [batch_size, top_beams, <= decode_length, hidden_size])
            giving the raw output of the Transformer decoder corresponding
            to the predicted sequences
        "outputs": integer `Tensor` of decoded ids of shape
            [batch_size, <= decode_length] if beam_size == 1 or
            [batch_size, top_beams, <= decode_length]
        "scores": decoding log probs from the beam search,
            None if using greedy decoding (beam_size=1)
    }

  Raises:
    NotImplementedError: If there are multiple data shards.
  """
  if self._num_datashards != 1:
    raise NotImplementedError("Fast decoding only supports a single shard.")
  dp = self._data_parallelism
  hparams = self._hparams
  target_modality = self._problem_hparams.target_modality
  if isinstance(target_modality, dict):
    primary_target_feature = self._problem_hparams.primary_target_modality
    primary_target_modality = target_modality[primary_target_feature]
    bottom_variable_scope = "%s/%s" % (primary_target_modality.name,
                                       primary_target_feature)
  else:
    primary_target_feature = "targets"
    primary_target_modality = target_modality
    bottom_variable_scope = target_modality.name

  if self.has_input:
    inputs = features["inputs"]
    if primary_target_modality.is_class_modality:
      decode_length = 1
    else:
      decode_length = (common_layers.shape_list(inputs)[1] +
                       features.get("decode_length", decode_length))

    # TODO(llion): Clean up this reshaping logic.
    inputs = tf.expand_dims(inputs, axis=1)
    if len(inputs.shape) < 5:
      inputs = tf.expand_dims(inputs, axis=4)
    s = common_layers.shape_list(inputs)
    batch_size = s[0]
    inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
    # _shard_features called to ensure that the variable names match
    inputs = self._shard_features({"inputs": inputs})["inputs"]
    input_modality = self._problem_hparams.input_modality["inputs"]
    with tf.variable_scope(input_modality.name):
      inputs = input_modality.bottom_sharded(inputs, dp)
    with tf.variable_scope("body"):
      encoder_output, encoder_decoder_attention_bias = dp(
          self.encode,
          inputs,
          features["target_space_id"],
          hparams,
          features=features)
    encoder_output = encoder_output[0]
    encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
    partial_targets = None
  else:
    # The problem has no inputs.
    encoder_output = None
    encoder_decoder_attention_bias = None

    # Prepare partial targets.
    # In either features["inputs"] or features["targets"].
    # We force the outputs to begin with these sequences.
    partial_targets = features.get("inputs")
    if partial_targets is None:
      partial_targets = features[primary_target_feature]
    assert partial_targets is not None
    partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
    partial_targets = tf.to_int64(partial_targets)
    partial_targets_shape = common_layers.shape_list(partial_targets)
    partial_targets_length = partial_targets_shape[1]
    decode_length = (partial_targets_length +
                     features.get("decode_length", decode_length))
    batch_size = partial_targets_shape[0]

  if hparams.pos == "timing":
    positional_encoding = common_attention.get_timing_signal_1d(
        decode_length + 1, hparams.hidden_size)
  elif hparams.pos == "emb":
    positional_encoding = common_attention.add_positional_embedding(
        tf.zeros([1, decode_length + 1, hparams.hidden_size]),
        hparams.max_length, "targets_positional_embedding", None)
  else:
    positional_encoding = None

  def preprocess_targets(targets, i):
    """Performs preprocessing steps on the targets to prepare for the decoder.

    This includes:
      - Embedding the ids.
      - Flattening to 3D tensor.
      - Optionally adding timing signals.

    Args:
      targets: inputs ids to the decoder. [batch_size, 1]
      i: scalar, Step number of the decoding loop.

    Returns:
      Processed targets [batch_size, 1, hidden_dim]
    """
    # _shard_features called to ensure that the variable names match
    targets = self._shard_features(
        {primary_target_feature: targets})[primary_target_feature]
    with tf.variable_scope(bottom_variable_scope):
      targets = primary_target_modality.targets_bottom_sharded(targets, dp)[0]
    targets = common_layers.flatten4d3d(targets)

    # At step 0, targets will have 0 size, and instead we want to
    # create an embedding of all-zero, corresponding to the start symbol.
    # This matches what transformer_prepare_decoder does to the target
    # outputs during training.
    targets = tf.cond(
        tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

    if positional_encoding is not None:
      targets += positional_encoding[:, i:i + 1]
    return targets

  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(decode_length))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        decode_length)

  def symbols_to_logits_fn(ids, i, cache):
    """Go from ids to logits for next symbol."""
    ids = ids[:, -1:]
    targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
    targets = preprocess_targets(targets, i)

    bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

    logits = self._symbols_to_logits_fn(targets, features, bias, cache)

    logits = tf.squeeze(logits, axis=[1, 2, 3])
    if partial_targets is not None:
      # If the position is within the given partial targets, we alter the
      # logits to always return those values.
      # A faster approach would be to process the partial targets in one
      # iteration in order to fill the corresponding parts of the cache.
      # This would require broader changes, though.
      vocab_size = tf.shape(logits)[1]

      def forced_logits():
        return tf.one_hot(
            tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0,
            -1e9)

      logits = tf.cond(
          tf.less(i, partial_targets_length), forced_logits, lambda: logits)
    return logits, cache

  cache = dict()
  infer_out = dict()
  if encoder_output is not None:
    padding_mask = 1. - common_attention.attention_bias_to_padding(
        encoder_decoder_attention_bias)
    masked_encoded_output = encoder_output * tf.expand_dims(
        padding_mask, axis=2)

    infer_out["encoded_inputs"] = tf.reduce_sum(masked_encoded_output, axis=1)

  self._prepare_decoder_cache(batch_size, features, cache)

  ret = fast_decode(
      encoder_output=encoder_output,
      encoder_decoder_attention_bias=encoder_decoder_attention_bias,
      symbols_to_logits_fn=symbols_to_logits_fn,
      hparams=hparams,
      decode_length=decode_length,
      vocab_size=primary_target_modality.top_dimensionality,
      beam_size=beam_size,
      top_beams=top_beams,
      alpha=alpha,
      batch_size=batch_size,
      force_decode_length=self._decode_hparams.force_decode_length,
      cache=cache)
  infer_out.update(ret)
  if "cache" in ret:
    infer_out.update(ret["cache"])

  if partial_targets is not None:
    if beam_size <= 1 or top_beams <= 1:
      infer_out["outputs"] = infer_out["outputs"][:, partial_targets_length:]
    else:
      infer_out["outputs"] = infer_out["outputs"][:, :,
                                                  partial_targets_length:]

  return infer_out
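# --- Illustrative sketch (not part of the original file) ---
# The "encoded_inputs" entry above sums encoder states over time after zeroing
# padded positions: attention_bias_to_padding recovers a [batch, length]
# padding indicator (1.0 at padded positions) from the attention bias, so
# 1. - padding marks real tokens. A standalone equivalent, with the padding
# tensor taken as a given input (an assumption):
import tensorflow as tf


def pooled_encoder_sketch(encoder_output, padding):
  """encoder_output: [batch, length, hidden]; padding: [batch, length] in {0., 1.}."""
  nonpadding = 1.0 - padding
  masked = encoder_output * tf.expand_dims(nonpadding, axis=2)
  return tf.reduce_sum(masked, axis=1)              # [batch, hidden]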