def transformer_prepare_decoder(targets, hparams, features=None):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(
          common_layers.shape_list(targets)[1]))
  if features and "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias += common_attention.attention_bias_same_segment(
        targets_segmentation, targets_segmentation)
  else:
    targets_position = None
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    if targets_position is not None:
      decoder_input = common_attention.add_timing_signal_1d_given_position(
          decoder_input, targets_position)
    else:
      decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias)
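# Illustrative NumPy sketch (not part of the library code): roughly what the
# "same segment" bias used above computes for packed datasets. Positions whose
# segment ids differ get a large negative bias, so packed examples cannot
# attend to one another. The [batch, 1, length, length] shape and the -1e9
# constant follow the usual tensor2tensor convention but are assumptions here.
def _same_segment_bias_sketch(query_segment, memory_segment, neg_inf=-1e9):
  import numpy as np
  # query_segment, memory_segment: [batch, length] integer segment ids.
  different = query_segment[:, :, None] != memory_segment[:, None, :]
  # Broadcasts over attention heads at axis 1.
  return different.astype(np.float32)[:, None, :, :] * neg_inf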
def transformer_prepare_encoder(inputs, target_space, hparams, features=None):
  """Prepare one shard of the model for the encoder.

  Args:
    inputs: a Tensor.
    target_space: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
  """
  ishape_static = inputs.shape.as_list()
  encoder_input = inputs
  if features and "inputs_segmentation" in features:
    # Packed dataset. Keep the examples from seeing each other.
    inputs_segmentation = features["inputs_segmentation"]
    inputs_position = features["inputs_position"]
    targets_segmentation = features["targets_segmentation"]
    encoder_self_attention_bias = common_attention.attention_bias_same_segment(
        inputs_segmentation, inputs_segmentation)
    encoder_decoder_attention_bias = (
        common_attention.attention_bias_same_segment(
            targets_segmentation, inputs_segmentation))
  else:
    # Usual case - not a packed dataset.
    encoder_padding = common_attention.embedding_to_padding(encoder_input)
    ignore_padding = common_attention.attention_bias_ignore_padding(
        encoder_padding)
    encoder_self_attention_bias = ignore_padding
    encoder_decoder_attention_bias = ignore_padding
    inputs_position = None
  if hparams.proximity_bias:
    encoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(inputs)[1])
  # Append target_space_id embedding to inputs.
  emb_target_space = common_layers.embedding(
      target_space, 32, ishape_static[-1], name="target_space_embedding")
  emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
  encoder_input += emb_target_space
  if hparams.pos == "timing":
    if inputs_position is not None:
      encoder_input = common_attention.add_timing_signal_1d_given_position(
          encoder_input, inputs_position)
    else:
      encoder_input = common_attention.add_timing_signal_1d(encoder_input)
  return (encoder_input, encoder_self_attention_bias,
          encoder_decoder_attention_bias)
def transformer_prepare_decoder(targets, hparams):
  """Copied from tensor2tensor.models.transformer."""
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        tf.shape(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias)
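# Illustrative NumPy sketch (not part of the library code): roughly what
# attention_bias_lower_triangle produces -- a [1, 1, length, length] tensor
# that is 0 on and below the diagonal and a large negative number above it,
# so position i cannot attend to positions j > i. The -1e9 constant is an
# assumption about the implementation, used here only for illustration.
def _causal_bias_sketch(length, neg_inf=-1e9):
  import numpy as np
  # 1s strictly above the diagonal mark the "future" positions to be masked.
  upper = np.triu(np.ones((length, length), dtype=np.float32), k=1)
  return (upper * neg_inf)[None, None, :, :]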
def transformer_prepare_decoder(targets, hparams, features=None):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(
          common_layers.shape_list(targets)[1]))
  if features and "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias += common_attention.attention_bias_same_segment(
        targets_segmentation, targets_segmentation)
  else:
    targets_position = None
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  # Standard timing-signal addition (disabled):
  # if hparams.pos == "timing":
  #   if targets_position is not None:
  #     decoder_input = common_attention.add_timing_signal_1d_given_position(
  #         decoder_input, targets_position)
  #   else:
  #     decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  raw_decoder_input = common_layers.shift_right(features['targets_raw'])
  terminal_decoder_bias, nonterminal_decoder_bias = _get_t_nt_bias(
      raw_decoder_input, hparams, decoder_self_attention_bias)
  raw_decoder_input = tf.squeeze(raw_decoder_input, axis=[-2, -1])
  pos_signals = generate_positional_signals(
      raw_decoder_input, hparams, terminal_decoder_bias,
      nonterminal_decoder_bias)
  pos_embeddings = generate_positional_embeddings(
      pos_signals, hparams.decoder_pos, hparams)
  if "sum" in hparams.decoder_pos_integration:
    decoder_input = decoder_input + pos_embeddings
  elif "ffn" in hparams.decoder_pos_integration:
    with tf.variable_scope("decoder_pos_ffn"):
      decoder_input = tf.concat([decoder_input, pos_embeddings], axis=2)
      decoder_input = transformer_ffn_layer(decoder_input, hparams,
                                            conv_padding="LEFT")
  return (decoder_input, decoder_self_attention_bias, terminal_decoder_bias,
          nonterminal_decoder_bias, pos_signals)
def transformer_prepare_decoder(targets, hparams, features=None):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
  if hparams.prepend_mode == "prepend_inputs_full_attention":
    decoder_self_attention_bias = (
        common_attention.attention_bias_prepend_inputs_full_attention(
            common_attention.embedding_to_padding(targets)))
  else:
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(
            common_layers.shape_list(targets)[1]))
  if features and "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias += common_attention.attention_bias_same_segment(
        targets_segmentation, targets_segmentation)
  else:
    targets_position = None
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    if targets_position is not None:
      decoder_input = common_attention.add_timing_signal_1d_given_position(
          decoder_input, targets_position)
    else:
      decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  if hparams.activation_dtype == "bfloat16":
    decoder_self_attention_bias = tf.cast(decoder_self_attention_bias,
                                          tf.bfloat16)
  return (decoder_input, decoder_self_attention_bias)
def transformer_prepare_decoder(targets, hparams):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        tf.shape(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  return (decoder_input, decoder_self_attention_bias)
def transformer_prepare_encoder(inputs, target_space, hparams):
  """Prepare one shard of the model for the encoder.

  Args:
    inputs: Tensor with shape [batch, memory_length, depth]
    target_space: a Tensor.
    hparams: run hyperparameters

  Returns:
    An EncoderState holding the bottom of the encoder stack (`input`), the
    encoder self-attention bias, and the encoder-decoder attention bias
    (`output` is left as None).
  """
  ignore_padding = get_ignore_padding(inputs)
  encoder_self_attention_bias = ignore_padding

  # Bias for self-attention to encourage attention to close positions.
  if hparams.proximity_bias:
    encoder_self_attention_bias += comm_attn.attention_bias_proximal(
        length=tf.shape(inputs)[1])

  # Append target_space_id embedding to inputs.
  emb_target_space = common_layers.embedding(
      x=target_space,
      vocab_size=32,
      dense_size=inputs.shape.as_list()[-1],
      name='target_space_embedding')
  # Reshape to [1, 1, depth] so it broadcasts over batch and length.
  emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
  encoder_input = inputs + emb_target_space

  if hparams.pos == 'timing':
    encoder_input = comm_attn.add_timing_signal_1d(encoder_input)

  # Putting this here since always called immediately after...
  encoder_input = with_dropout(encoder_input, hparams)

  return EncoderState(input=encoder_input,
                      self_attn_bias=encoder_self_attention_bias,
                      decoder_attn_bias=ignore_padding,
                      output=None)
def transformer_prepare_encoder(inputs, target_space, hparams):
  """Copied from tensor2tensor.models.transformer."""
  ishape_static = inputs.shape.as_list()
  encoder_input = inputs
  encoder_padding = common_attention.embedding_to_padding(encoder_input)
  ignore_padding = common_attention.attention_bias_ignore_padding(
      encoder_padding)
  encoder_self_attention_bias = ignore_padding
  encoder_decoder_attention_bias = ignore_padding
  if hparams.proximity_bias:
    encoder_self_attention_bias += common_attention.attention_bias_proximal(
        tf.shape(inputs)[1])
  # Append target_space_id embedding to inputs.
  emb_target_space = common_layers.embedding(
      target_space, 32, ishape_static[-1], name="target_space_embedding")
  emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
  encoder_input += emb_target_space
  if hparams.pos == "timing":
    encoder_input = common_attention.add_timing_signal_1d(encoder_input)
  return (encoder_input, encoder_self_attention_bias,
          encoder_decoder_attention_bias)
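# Illustrative NumPy sketch (not part of the library code): the padding bias
# built above. A position counts as padding when its embedding vector is
# entirely zero (embedding_to_padding), and padded memory positions then get
# a large negative bias so no query attends to them
# (attention_bias_ignore_padding). The -1e9 constant is an assumption.
def _padding_bias_sketch(emb, neg_inf=-1e9):
  import numpy as np
  # emb: [batch, length, depth] float embeddings.
  is_padding = (np.abs(emb).sum(axis=-1) == 0.0).astype(np.float32)
  # Shape [batch, 1, 1, length]: broadcasts over heads and query positions.
  return is_padding[:, None, None, :] * neg_inf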
def transformer_prepare_encoder(inputs, target_space, hparams):
  """Prepare one shard of the model for the encoder.

  Args:
    inputs: a Tensor.
    target_space: a Tensor.
    hparams: run hyperparameters

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
  """
  ishape_static = inputs.shape.as_list()
  encoder_input = inputs
  encoder_padding = common_attention.embedding_to_padding(encoder_input)
  ignore_padding = common_attention.attention_bias_ignore_padding(
      encoder_padding)
  encoder_self_attention_bias = ignore_padding
  encoder_decoder_attention_bias = ignore_padding
  if hparams.proximity_bias:
    encoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(inputs)[1])
  # Append target_space_id embedding to inputs.
  emb_target_space = common_layers.embedding(
      target_space,
      32,
      ishape_static[-1],
      name="target_space_embedding",
      use_eager_mode=hparams.use_eager_mode)
  emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
  encoder_input += emb_target_space
  if hparams.pos == "timing":
    encoder_input = common_attention.add_timing_signal_1d(encoder_input)
  return (encoder_input, encoder_self_attention_bias,
          encoder_decoder_attention_bias)
def transformer_prepare_encoder(inputs, target_space, hparams):
  """Prepare one shard of the model for the encoder.

  Args:
    inputs: a Tensor.
    target_space: a Tensor.
    hparams: run hyperparameters

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
  """
  ishape_static = inputs.shape.as_list()
  encoder_input = inputs
  encoder_padding = common_attention.embedding_to_padding(encoder_input)
  ignore_padding = common_attention.attention_bias_ignore_padding(
      encoder_padding)
  encoder_self_attention_bias = ignore_padding
  encoder_decoder_attention_bias = ignore_padding
  if hparams.proximity_bias:
    encoder_self_attention_bias += common_attention.attention_bias_proximal(
        tf.shape(inputs)[1])
  # Append target_space_id embedding to inputs.
  emb_target_space = common_layers.embedding(
      target_space, 32, ishape_static[-1], name="target_space_embedding")
  emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
  encoder_input += emb_target_space
  # Input-masking noise (disabled):
  # random_uniform_mask = tf.expand_dims(
  #     tf.to_float(tf.to_int32(
  #         tf.random_uniform([tf.shape(encoder_input)[0],
  #                            tf.shape(encoder_input)[1]])
  #         < hparams.mask_noise_prob)), axis=2)
  # encoder_input = encoder_input * (1 - random_uniform_mask)
  if hparams.pos == "timing":
    encoder_input = common_attention.add_timing_signal_1d(encoder_input)
  return (encoder_input, encoder_self_attention_bias,
          encoder_decoder_attention_bias)
def transformer_prepare_decoder(targets, hparams):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    A DecoderState holding the bottom of the decoder stack (`input`) and the
    decoder self-attention bias.
  """
  decoder_self_attention_bias = (comm_attn.attention_bias_lower_triangle(
      tf.shape(targets)[1]))
  if hparams.proximity_bias:
    decoder_self_attention_bias += comm_attn.attention_bias_proximal(
        tf.shape(targets)[1])
  # Shift targets right so that position i only sees outputs before i,
  # matching the causal (lower-triangle) self-attention bias above.
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == 'timing':
    decoder_input = comm_attn.add_timing_signal_1d(decoder_input)

  # Putting this here since always called immediately after...
  decoder_input = with_dropout(decoder_input, hparams)

  return DecoderState(input=decoder_input,
                      self_attn_bias=decoder_self_attention_bias)
def transformer_prepare_decoder(targets, hparams):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        tf.shape(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  # Debugging shape prints (disabled):
  # decoder_input = tf.Print(decoder_input, [tf.shape(decoder_input)],
  #                          summarize=1000, message="decoder_input")
  # decoder_self_attention_bias = tf.Print(
  #     decoder_self_attention_bias, [tf.shape(decoder_self_attention_bias)],
  #     summarize=1000, message="decoder_self_attention_bias")
  return (decoder_input, decoder_self_attention_bias)
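# Illustrative NumPy sketch (not part of the library code): the sinusoidal
# signal added when hparams.pos == "timing". The constants mirror the common
# defaults (timescales from 1.0 to 1e4) and assume an even channel count;
# treat the exact layout as an assumption for illustration only.
def _timing_signal_sketch(length, channels, min_timescale=1.0,
                          max_timescale=1.0e4):
  import numpy as np
  position = np.arange(length, dtype=np.float32)
  num_timescales = channels // 2
  log_increment = (np.log(max_timescale / min_timescale) /
                   max(num_timescales - 1, 1))
  inv_timescales = min_timescale * np.exp(
      np.arange(num_timescales, dtype=np.float32) * -log_increment)
  scaled_time = position[:, None] * inv_timescales[None, :]
  signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
  # [1, length, channels]; broadcast-added to decoder_input / encoder_input.
  return signal[None, :, :]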
def transformer_prepare_encoder(inputs, target_space, hparams, features=None):
  """Prepare one shard of the model for the encoder.

  Args:
    inputs: a Tensor.
    target_space: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
  """
  ishape_static = inputs.shape.as_list()
  encoder_input = inputs
  if features and "inputs_segmentation" in features:
    # Packed dataset. Keep the examples from seeing each other.
    inputs_segmentation = features["inputs_segmentation"]
    inputs_position = features["inputs_position"]
    targets_segmentation = features["targets_segmentation"]
    if (hasattr(hparams, "unidirectional_encoder") and
        hparams.unidirectional_encoder):
      tf.logging.info("Using unidirectional encoder")
      encoder_self_attention_bias = (
          common_attention.attention_bias_lower_triangle(
              common_layers.shape_list(inputs)[1]))
    else:
      encoder_self_attention_bias = (
          common_attention.attention_bias_same_segment(
              inputs_segmentation, inputs_segmentation))
    encoder_decoder_attention_bias = (
        common_attention.attention_bias_same_segment(targets_segmentation,
                                                      inputs_segmentation))
  else:
    encoder_padding = common_attention.embedding_to_padding(encoder_input)
    ignore_padding = common_attention.attention_bias_ignore_padding(
        encoder_padding)
    if (hasattr(hparams, "unidirectional_encoder") and
        hparams.unidirectional_encoder):
      tf.logging.info("Using unidirectional encoder")
      encoder_self_attention_bias = (
          common_attention.attention_bias_lower_triangle(
              common_layers.shape_list(inputs)[1]))
    else:
      # Usual case - not a packed dataset.
      encoder_self_attention_bias = ignore_padding
    encoder_decoder_attention_bias = ignore_padding
    inputs_position = None
  if hparams.proximity_bias:
    encoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(inputs)[1])
  if hparams.get("use_target_space_embedding", True):
    # Append target_space_id embedding to inputs.
    emb_target_space = common_layers.embedding(
        target_space,
        32,
        ishape_static[-1],
        name="target_space_embedding",
        dtype=tf.bfloat16
        if hparams.activation_dtype == "bfloat16" else tf.float32)
    emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
    encoder_input += emb_target_space
  if hparams.pos == "timing":
    if inputs_position is not None:
      encoder_input = common_attention.add_timing_signal_1d_given_position(
          encoder_input, inputs_position)
    else:
      encoder_input = common_attention.add_timing_signal_1d(encoder_input)
  elif hparams.pos == "emb":
    encoder_input = common_attention.add_positional_embedding(
        encoder_input, hparams.max_length, "inputs_positional_embedding",
        inputs_position)
  if hparams.activation_dtype == "bfloat16":
    encoder_self_attention_bias = tf.cast(encoder_self_attention_bias,
                                          tf.bfloat16)
    encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
                                             tf.bfloat16)
  return (encoder_input, encoder_self_attention_bias,
          encoder_decoder_attention_bias)
def _fast_decode(self,
                 features,
                 decode_length,
                 beam_size=1,
                 top_beams=1,
                 alpha=1.0,
                 sentence_cache=None):
  """Fast decoding.

  Implements both greedy and beam search decoding; uses beam search iff
  beam_size > 1, otherwise beam-search-related arguments are ignored.

  Args:
    features: a map of string to model features.
    decode_length: an integer. How many additional timesteps to decode.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. The larger the alpha,
      the stronger the preference for longer translations.

  Returns:
    A dict of decoding results {
        "outputs": integer `Tensor` of decoded ids of shape
            [batch_size, <= decode_length] if beam_size == 1 or
            [batch_size, top_beams, <= decode_length]
        "scores": decoding log probs from the beam search,
            None if using greedy decoding (beam_size=1)
    }

  Raises:
    NotImplementedError: If there are multiple data shards.
  """
  if self._num_datashards != 1:
    raise NotImplementedError("Fast decoding only supports a single shard.")
  dp = self._data_parallelism
  hparams = self._hparams
  target_modality = self._problem_hparams.target_modality
  if self.has_input:
    inputs = features["inputs"]
    if target_modality.is_class_modality:
      decode_length = 1
    else:
      decode_length = common_layers.shape_list(inputs)[1] + decode_length

    # TODO(llion): Clean up this reshaping logic.
    inputs = tf.expand_dims(inputs, axis=1)
    if len(inputs.shape) < 5:
      inputs = tf.expand_dims(inputs, axis=4)
    s = common_layers.shape_list(inputs)
    batch_size = s[0]
    inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
    # _shard_features called to ensure that the variable names match
    inputs = self._shard_features({"inputs": inputs})["inputs"]
    input_modality = self._problem_hparams.input_modality["inputs"]
    with tf.variable_scope(input_modality.name):
      inputs = input_modality.bottom_sharded(inputs, dp)
    with tf.variable_scope("body"):
      encoder_output, encoder_decoder_attention_bias = dp(
          self.encode, inputs, features["target_space_id"], hparams,
          features=features)
    encoder_output = encoder_output[0]
    encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
    partial_targets = None
  else:
    # The problem has no inputs.
    # In this case, features["inputs"] contains partial targets.
    # We force the outputs to begin with these sequences.
    encoder_output = None
    encoder_decoder_attention_bias = None
    partial_targets = tf.squeeze(tf.to_int64(features["inputs"]), [2, 3])
    partial_targets_length = common_layers.shape_list(partial_targets)[1]
    decode_length += partial_targets_length
    batch_size = tf.shape(partial_targets)[0]

  if hparams.pos == "timing":
    timing_signal = common_attention.get_timing_signal_1d(
        decode_length + 1, hparams.hidden_size)

  def preprocess_targets(targets, i):
    """Performs preprocessing steps on the targets to prepare for the decoder.

    This includes:
      - Embedding the ids.
      - Flattening to 3D tensor.
      - Optionally adding timing signals.

    Args:
      targets: inputs ids to the decoder. [batch_size, 1]
      i: scalar, Step number of the decoding loop.

    Returns:
      Processed targets [batch_size, 1, hidden_dim]
    """
    # _shard_features called to ensure that the variable names match
    targets = self._shard_features({"targets": targets})["targets"]
    with tf.variable_scope(target_modality.name):
      targets = target_modality.targets_bottom_sharded(targets, dp)[0]
    targets = common_layers.flatten4d3d(targets)

    # TODO(llion): Explain! Is this even needed?
    targets = tf.cond(
        tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

    if hparams.pos == "timing":
      targets += timing_signal[:, i:i + 1]
    return targets

  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(decode_length))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        decode_length)

  def symbols_to_logits_fn(ids, i, cache):
    """Go from ids to logits for next symbol."""
    ids = ids[:, -1:]
    targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
    targets = preprocess_targets(targets, i)

    bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

    with tf.variable_scope("body"):
      body_outputs = dp(
          self.decode, targets, cache.get("encoder_output"),
          cache.get("encoder_decoder_attention_bias"), bias, hparams, cache,
          nonpadding=features_to_nonpadding(features, "targets"))

    with tf.variable_scope(target_modality.name):
      logits = target_modality.top_sharded(body_outputs, None, dp)[0]

    ret = tf.squeeze(logits, axis=[1, 2, 3])
    if partial_targets is not None:
      # If the position is within the given partial targets, we alter the
      # logits to always return those values.
      # A faster approach would be to process the partial targets in one
      # iteration in order to fill the corresponding parts of the cache.
      # This would require broader changes, though.
      vocab_size = tf.shape(ret)[1]

      def forced_logits():
        return tf.one_hot(
            tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0,
            -1e9)

      ret = tf.cond(
          tf.less(i, partial_targets_length), forced_logits, lambda: ret)
    return ret, cache, body_outputs

  ret = fast_decode(
      encoder_output=encoder_output,
      encoder_decoder_attention_bias=encoder_decoder_attention_bias,
      symbols_to_logits_fn=symbols_to_logits_fn,
      hparams=hparams,
      decode_length=decode_length,
      vocab_size=target_modality.top_dimensionality,
      beam_size=beam_size,
      top_beams=top_beams,
      alpha=alpha,
      batch_size=batch_size,
      sentence_cache=self.sentence_cache,
      cache_flag=self.cache_flag)
  if partial_targets is not None:
    ret["outputs"] = ret["outputs"][:, partial_targets_length:]
  return ret
def _fast_decode(self, features, decode_length, beam_size=1, top_beams=1, alpha=1.0): """ Fast decoding Implements both greedy and beam search decoding, uses beam search iff beam_size > 1, otherwise beam search related arguments are ignored. :param features: a map of string to model features. :param decode_length: :param beam_size: beam search size :param top_beams: an integer, how many of the beams to return :param alpha: :return: """ if self._num_datashards != 1: raise NotImplementedError("Fast decoding only supports a single shard.") dp = self._data_parallelism hparams = self._hparams target_modality = self._problem_hparams.target_modality assert self.has_input, "problems for dual-transformer must have inputs" # decode with an input source(needs encoder outputs) wav_inputs = features["wav_inputs"] txt_inputs = features["txt_inputs"] if target_modality.is_class_modality: decode_length = 1 else: decode_length = (common_layers.shape_list(wav_inputs)[1] + features.get("decode_length", decode_length)) wav_inputs = tf.expand_dims(wav_inputs, axis=1) txt_inputs = tf.expand_dims(txt_inputs, axis=1) if len(wav_inputs.shape) < 5: wav_inputs = tf.expand_dims(wav_inputs, axis=4) if len(txt_inputs.shape) < 5: txt_inputs = tf.expand_dims(txt_inputs, axis=4) s = common_layers.shape_list(wav_inputs) batch_size = s[0] wav_inputs = tf.reshape(wav_inputs, [s[0] * s[1], s[2], s[3], s[4]]) txt_inputs = tf.reshape(txt_inputs, [s[0] * s[1], s[2], s[3], s[4]]) # _shard_features called to ensure that the variable names match wav_inputs = self._shard_features({"wav_inputs": wav_inputs})["wav_inputs"] wav_input_modality = self._problem_hparams.input_modality["wav_inputs"] txt_inputs = self._shard_features({"txt_inputs": txt_inputs})["txt_inputs"] txt_input_modality = self._problem_hparams.input_modality["txt_inputs"] with tf.variable_scope(wav_input_modality.name): wav_inputs = wav_input_modality.bottom_sharded(wav_inputs, dp) with tf.variable_scope(txt_input_modality.name): txt_inputs = txt_input_modality.bottom_sharded(txt_inputs, dp) with tf.variable_scope("body"): wav_enc_output, wav_enc_dec_attention_bias, \ txt_enc_output,txt_enc_dec_attention_bias = dp( self.dual_encode, wav_inputs, txt_inputs, features["target_space_id"], hparams, features=features) wav_enc_output = wav_enc_output[0] txt_enc_output = txt_enc_output[0] wav_enc_dec_attention_bias = wav_enc_dec_attention_bias[0] txt_enc_dec_attention_bias = txt_enc_dec_attention_bias[0] if hparams.pos == "timing": timing_signal = common_attention.get_timing_signal_1d( decode_length + 1, hparams.hidden_size) def preprocess_targets(targets, i): """Performs preprocessing steps on the targets to prepare for the decoder. This includes: - Embedding the ids. - Flattening to 3D tensor. - Optionally adding timing signals. Args: targets: inputs ids to the decoder. [batch_size, 1] i: scalar, Step number of the decoding loop. Returns: Processed targets [batch_size, 1, hidden_dim] """ # _shard_features called to ensure that the variable names match targets = self._shard_features({"targets": targets})["targets"] with tf.variable_scope(target_modality.name): targets = target_modality.targets_bottom_sharded(targets, dp)[0] targets = common_layers.flatten4d3d(targets) # TODO(llion): Explain! Is this even needed? 
targets = tf.cond( tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets) if hparams.pos == "timing": targets += timing_signal[:, i:i + 1] return targets decoder_self_attention_bias = ( common_attention.attention_bias_lower_triangle(decode_length)) if hparams.proximity_bias: decoder_self_attention_bias += common_attention.attention_bias_proximal( decode_length) def symbols_to_logits_fn(ids, i, cache): """Go from ids to logits for next symbol.""" ids = ids[:, -1:] targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3) targets = preprocess_targets(targets, i) bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1] with tf.variable_scope("body"): body_outputs = dp( self.dual_decode, targets, cache.get("wav_enc_outputs"), cache.get("txt_enc_outputs"), cache.get("wav_enc_dec_attention_bias"), cache.get("txt_enc_dec_attention_bias"), bias, hparams, cache, nonpadding=features_to_nonpadding(features, "targets")) with tf.variable_scope(target_modality.name): logits = target_modality.top_sharded(body_outputs, None, dp)[0] ret = tf.squeeze(logits, axis=[1, 2, 3]) return ret, cache ret = fast_decode( wav_encoder_output=wav_enc_output, txt_encoder_output=txt_enc_output, wav_enc_dec_attention_bias=wav_enc_dec_attention_bias, txt_enc_dec_attention_bias=txt_enc_dec_attention_bias, symbols_to_logits_fn=symbols_to_logits_fn, hparams=hparams, decode_length=decode_length, vocab_size=target_modality.top_dimensionality, beam_size=beam_size, top_beams=top_beams, alpha=alpha, batch_size=batch_size, force_decode_length=self._decode_hparams.force_decode_length) return ret
def _fast_decode(self, features, decode_length, beam_size=1, top_beams=1,
                 alpha=1.0):
  """Fast decoding.

  Implements both greedy and beam search decoding; uses beam search iff
  beam_size > 1, otherwise beam-search-related arguments are ignored.

  Args:
    features: a map of string to model features.
    decode_length: an integer. How many additional timesteps to decode.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. The larger the alpha,
      the stronger the preference for longer translations.

  Returns:
    samples: an integer `Tensor`. Top samples from the beam search

  Raises:
    NotImplementedError: If there are multiple data shards.
  """
  if self._num_datashards != 1:
    raise NotImplementedError("Fast decoding only supports a single shard.")
  dp = self._data_parallelism
  hparams = self._hparams

  inputs = features["inputs"]
  target_modality = self._problem_hparams.target_modality
  if target_modality.is_class_modality:
    decode_length = 1
  else:
    decode_length = common_layers.shape_list(inputs)[1] + decode_length

  # TODO(llion): Clean up this reshaping logic.
  inputs = tf.expand_dims(inputs, axis=1)
  if len(inputs.shape) < 5:
    inputs = tf.expand_dims(inputs, axis=4)
  s = common_layers.shape_list(inputs)
  inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
  # _shard_features called to ensure that the variable names match
  inputs = self._shard_features({"inputs": inputs})["inputs"]
  input_modality = self._problem_hparams.input_modality["inputs"]
  with tf.variable_scope(input_modality.name):
    inputs = input_modality.bottom_sharded(inputs, dp)
  with tf.variable_scope("body"):
    encoder_output, encoder_decoder_attention_bias = dp(
        self.encode, inputs, features["target_space_id"], hparams,
        features=features)
  encoder_output = encoder_output[0]
  encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]

  if hparams.pos == "timing":
    timing_signal = common_attention.get_timing_signal_1d(
        decode_length + 1, hparams.hidden_size)

  def preprocess_targets(targets, i):
    """Performs preprocessing steps on the targets to prepare for the decoder.

    This includes:
      - Embedding the ids.
      - Flattening to 3D tensor.
      - Optionally adding timing signals.

    Args:
      targets: inputs ids to the decoder. [batch_size, 1]
      i: scalar, Step number of the decoding loop.

    Returns:
      Processed targets [batch_size, 1, hidden_dim]
    """
    # _shard_features called to ensure that the variable names match
    targets = self._shard_features({"targets": targets})["targets"]
    with tf.variable_scope(target_modality.name):
      targets = target_modality.targets_bottom_sharded(targets, dp)[0]
    targets = common_layers.flatten4d3d(targets)

    # TODO(llion): Explain! Is this even needed?
    targets = tf.cond(
        tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

    if hparams.pos == "timing":
      targets += timing_signal[:, i:i + 1]
    return targets

  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(decode_length))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        decode_length)

  def symbols_to_logits_fn(ids, i, cache):
    """Go from ids to logits for next symbol."""
    ids = ids[:, -1:]
    targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
    targets = preprocess_targets(targets, i)

    bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

    with tf.variable_scope("body"):
      body_outputs = dp(
          self.decode, targets, cache["encoder_output"],
          cache["encoder_decoder_attention_bias"], bias, hparams, cache,
          nonpadding=features_to_nonpadding(features, "targets"))

    with tf.variable_scope(target_modality.name):
      logits = target_modality.top_sharded(body_outputs, None, dp)[0]

    return tf.squeeze(logits, axis=[1, 2, 3]), cache

  return fast_decode(
      encoder_output=encoder_output,
      encoder_decoder_attention_bias=encoder_decoder_attention_bias,
      symbols_to_logits_fn=symbols_to_logits_fn,
      hparams=hparams,
      decode_length=decode_length,
      vocab_size=target_modality.top_dimensionality,
      beam_size=beam_size,
      top_beams=top_beams,
      alpha=alpha)
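# For context (hedged, simplified): fast_decode consumes symbols_to_logits_fn
# roughly like the greedy loop below, one call per output position, with
# `cache` carrying the incrementally built decoder key/value state. The real
# implementation uses tf.while_loop (and beam_search.beam_search when
# beam_size > 1); this is only a sketch of the control flow.
#
#   ids = tf.zeros([batch_size, 1], dtype=tf.int64)        # start symbol
#   for i in range(decode_length):
#     logits, cache = symbols_to_logits_fn(ids, i, cache)
#     next_id = tf.expand_dims(tf.argmax(logits, axis=-1), axis=1)
#     ids = tf.concat([ids, next_id], axis=1)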
def _fast_decode(self, features, decode_length, beam_size=1, top_beams=1, alpha=1.0): """Fast decoding. Implements both greedy and beam search decoding, uses beam search iff beam_size > 1, otherwise beam search related arguments are ignored. Args: features: a map of string to model features. decode_length: an integer. How many additional timesteps to decode. beam_size: number of beams. top_beams: an integer. How many of the beams to return. alpha: Float that controls the length penalty. larger the alpha, stronger the preference for longer translations. Returns: A dict of decoding results { "outputs": integer `Tensor` of decoded ids of shape [batch_size, <= decode_length] if beam_size == 1 or [batch_size, top_beams, <= decode_length] "scores": decoding log probs from the beam search, None if using greedy decoding (beam_size=1) } Raises: NotImplementedError: If there are multiple data shards. """ if self._num_datashards != 1: raise NotImplementedError("Fast decoding only supports a single shard.") dp = self._data_parallelism hparams = self._hparams target_modality = self._problem_hparams.modality["targets"] target_vocab_size = self._problem_hparams.vocab_size["targets"] if target_vocab_size is not None and hasattr(hparams, "vocab_divisor"): target_vocab_size += (-target_vocab_size) % hparams.vocab_divisor if "targets_segmentation" in features: raise NotImplementedError( "Decoding not supported on packed datasets " " If you want to decode from a dataset, use the non-packed version" " of the dataset when decoding.") if self.has_input: inputs = features["inputs"] if target_modality == modalities.ModalityType.CLASS_LABEL: decode_length = 1 else: decode_length = ( common_layers.shape_list(inputs)[1] + features.get( "decode_length", decode_length)) # TODO(llion): Clean up this reshaping logic. inputs = tf.expand_dims(inputs, axis=1) if len(inputs.shape) < 5: inputs = tf.expand_dims(inputs, axis=4) s = common_layers.shape_list(inputs) batch_size = s[0] inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]]) # _shard_features called to ensure that the variable names match inputs = self._shard_features({"inputs": inputs})["inputs"] input_modality = self._problem_hparams.modality["inputs"] input_vocab_size = self._problem_hparams.vocab_size["inputs"] if input_vocab_size is not None and hasattr(hparams, "vocab_divisor"): input_vocab_size += (-input_vocab_size) % hparams.vocab_divisor modality_name = hparams.name.get( "inputs", modalities.get_name(input_modality))(hparams, input_vocab_size) with tf.variable_scope(modality_name): bottom = hparams.bottom.get("inputs", modalities.get_bottom(input_modality)) inputs = dp(bottom, inputs, hparams, input_vocab_size) with tf.variable_scope("body"): encoder_output, encoder_decoder_attention_bias = dp( self.encode, inputs, features["target_space_id"], hparams, features=features) encoder_output = encoder_output[0] encoder_decoder_attention_bias = encoder_decoder_attention_bias[0] if 'partial_targets' in features: partial_targets = features['partial_targets'] else: partial_targets = None else: # The problem has no inputs. encoder_output = None encoder_decoder_attention_bias = None # Prepare partial targets. # In either features["inputs"] or features["targets"]. # We force the outputs to begin with these sequences. 
partial_targets = features.get("inputs") if partial_targets is None: partial_targets = features["targets"] assert partial_targets is not None if partial_targets is not None: partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2) partial_targets = tf.to_int64(partial_targets) partial_targets_shape = common_layers.shape_list(partial_targets) partial_targets_length = partial_targets_shape[1] decode_length = ( partial_targets_length + features.get("decode_length", decode_length)) batch_size = partial_targets_shape[0] if hparams.pos == "timing": positional_encoding = common_attention.get_timing_signal_1d( decode_length + 1, hparams.hidden_size) elif hparams.pos == "emb": positional_encoding = common_attention.add_positional_embedding( tf.zeros([1, decode_length, hparams.hidden_size]), hparams.max_length, "body/targets_positional_embedding", None) else: positional_encoding = None def preprocess_targets(targets, i): """Performs preprocessing steps on the targets to prepare for the decoder. This includes: - Embedding the ids. - Flattening to 3D tensor. - Optionally adding timing signals. Args: targets: inputs ids to the decoder. [batch_size, 1] i: scalar, Step number of the decoding loop. Returns: Processed targets [batch_size, 1, hidden_dim] """ # _shard_features called to ensure that the variable names match targets = self._shard_features({"targets": targets})["targets"] modality_name = hparams.name.get( "targets", modalities.get_name(target_modality))(hparams, target_vocab_size) with tf.variable_scope(modality_name): bottom = hparams.bottom.get( "targets", modalities.get_targets_bottom(target_modality)) targets = dp(bottom, targets, hparams, target_vocab_size)[0] targets = common_layers.flatten4d3d(targets) # GO embeddings are all zero, this is because transformer_prepare_decoder # Shifts the targets along by one for the input which pads with zeros. # If the modality already maps GO to the zero embeddings this is not # needed. 
targets = tf.cond( tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets) if positional_encoding is not None: targets += positional_encoding[:, i:i + 1] return targets decoder_self_attention_bias = ( common_attention.attention_bias_lower_triangle(decode_length)) if hparams.proximity_bias: decoder_self_attention_bias += common_attention.attention_bias_proximal( decode_length) # Create tensors for encoder-decoder attention history att_cache = {"attention_history": {}} num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers att_batch_size, enc_seq_length = common_layers.shape_list(encoder_output)[0:2] for layer in range(num_layers): att_cache["attention_history"]["layer_%d" % layer] = tf.zeros( [att_batch_size, hparams.num_heads, 0, enc_seq_length]) att_cache["body_outputs"] = tf.zeros([att_batch_size, 1, 0, hparams.hidden_size]) def update_decoder_attention_history(cache): for k in filter(lambda x: "decoder" in x and not "self" in x and not "logits" in x, self.attention_weights.keys()): m = re.search(r"(layer_\d+)", k) if m is None: continue cache["attention_history"][m[0]] = tf.concat( [cache["attention_history"][m[0]], self.attention_weights[k]], axis=2) def symbols_to_logits_fn(ids, i, cache): """Go from ids to logits for next symbol.""" ids = ids[:, -1:] targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3) targets = preprocess_targets(targets, i) bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1] with tf.variable_scope("body"): body_outputs = dp( self.decode, targets, cache.get("encoder_output"), cache.get("encoder_decoder_attention_bias"), bias, hparams, cache, nonpadding=features_to_nonpadding(features, "targets")) update_decoder_attention_history(cache) cache["body_outputs"] = tf.concat([cache["body_outputs"], body_outputs[0]], axis=2) modality_name = hparams.name.get( "targets", modalities.get_name(target_modality))(hparams, target_vocab_size) with tf.variable_scope(modality_name): top = hparams.top.get("targets", modalities.get_top(target_modality)) logits = dp(top, body_outputs, None, hparams, target_vocab_size)[0] ret = tf.squeeze(logits, axis=[1, 2, 3]) if partial_targets is not None: # If the position is within the given partial targets, we alter the # logits to always return those values. # A faster approach would be to process the partial targets in one # iteration in order to fill the corresponding parts of the cache. # This would require broader changes, though. vocab_size = tf.shape(ret)[1] def forced_logits(): return tf.one_hot( tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0, -1e9) ret = tf.cond( tf.less(i, partial_targets_length), forced_logits, lambda: ret) return ret, cache ret = fast_decode( encoder_output=encoder_output, encoder_decoder_attention_bias=encoder_decoder_attention_bias, symbols_to_logits_fn=symbols_to_logits_fn, hparams=hparams, decode_length=decode_length, vocab_size=target_vocab_size, beam_size=beam_size, top_beams=top_beams, alpha=alpha, batch_size=batch_size, force_decode_length=self._decode_hparams.force_decode_length, cache=att_cache) if partial_targets is not None: if beam_size <= 1 or top_beams <= 1: ret["outputs"] = ret["outputs"][:, partial_targets_length:] else: ret["outputs"] = ret["outputs"][:, :, partial_targets_length:] return ret
def _fast_decode(self, features, decode_length, last_position_only=True, beam_size=1, top_beams=1, alpha=1.0): """Fast decoding. Implements both greedy and beam search decoding, uses beam search iff beam_size > 1, otherwise beam search related arguments are ignored. Args: features: a map of string to model features. decode_length: an integer. How many additional timesteps to decode. last_position_only: MUST be true for fast decoding! beam_size: number of beams. top_beams: an integer. How many of the beams to return. alpha: Float that controls the length penalty. larger the alpha, stronger the preference for slonger translations. Returns: samples: an integer `Tensor`. Top samples from the beam search Raises: ValueError: If last_position_only if False NotImplementedError: If there are multiple data shards. """ if not last_position_only: raise ValueError( "Fast decoding only deals with the last positions!") if self._num_datashards != 1: raise NotImplementedError( "Fast decoding only supports a single shard.") dp = self._data_parallelism hparams = self._hparams inputs = features["inputs"] batch_size = tf.shape(inputs)[0] target_modality = self._problem_hparams.target_modality if t2t_model.is_class_modality(target_modality): decode_length = 1 else: decode_length = tf.shape(inputs)[1] + decode_length # TODO(llion): Clean up this reshaping logic. inputs = tf.expand_dims(inputs, axis=1) if len(inputs.shape) < 5: inputs = tf.expand_dims(inputs, axis=4) s = tf.shape(inputs) inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]]) # _shard_features called to ensure that the variable names match inputs = self._shard_features({"inputs": inputs})["inputs"] input_modality = self._problem_hparams.input_modality["inputs"] with tf.variable_scope(input_modality.name): inputs = input_modality.bottom_sharded(inputs, dp) with tf.variable_scope("body"): encoder_output, encoder_decoder_attention_bias = dp( self.encode, inputs, features["target_space_id"], hparams) encoder_output = encoder_output[0] encoder_decoder_attention_bias = encoder_decoder_attention_bias[0] if hparams.pos == "timing": timing_signal = common_attention.get_timing_signal_1d( decode_length + 1, hparams.hidden_size) def preprocess_targets(targets, i): """Performs preprocessing steps on the targets to prepare for the decoder. This includes: - Embedding the ids. - Flattening to 3D tensor. - Optionally adding timing signals. Args: targets: inputs ids to the decoder. [batch_size, 1] i: scalar, Step number of the decoding loop. Returns: Processed targets [batch_size, 1, hidden_dim] """ # _shard_features called to ensure that the variable names match targets = self._shard_features({"targets": targets})["targets"] with tf.variable_scope(target_modality.name): targets = target_modality.targets_bottom_sharded(targets, dp)[0] targets = common_layers.flatten4d3d(targets) # TODO(llion): Explain! Is this even needed? 
targets = tf.cond(tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets) if hparams.pos == "timing": targets += timing_signal[:, i:i + 1] return targets decoder_self_attention_bias = ( common_attention.attention_bias_lower_triangle(decode_length)) if hparams.proximity_bias: decoder_self_attention_bias += common_attention.attention_bias_proximal( decode_length) def symbols_to_logits_fn(ids, i, cache): """Go from ids to logits for next symbol.""" ids = ids[:, -1:] targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3) targets = preprocess_targets(targets, i) bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1] with tf.variable_scope("body"): body_outputs = dp(self.decode, targets, cache["encoder_output"], cache["encoder_decoder_attention_bias"], bias, hparams, cache) with tf.variable_scope(target_modality.name): logits = target_modality.top_sharded(body_outputs, None, dp)[0] return tf.squeeze(logits, axis=[1, 2, 3]), cache key_channels = hparams.attention_key_channels or hparams.hidden_size value_channels = hparams.attention_value_channels or hparams.hidden_size num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers cache = { "layer_%d" % layer: { "k": tf.zeros([batch_size, 0, key_channels]), "v": tf.zeros([batch_size, 0, value_channels]), } for layer in range(num_layers) } # Set 2nd dim to None since it's not invariant in the tf.while_loop # Note: Tensor.set_shape() does not work here since it merges shape info. # TODO(llion); Find a more robust solution. # pylint: disable=protected-access for layer in cache: cache[layer]["k"]._shape = tf.TensorShape( [None, None, key_channels]) cache[layer]["v"]._shape = tf.TensorShape( [None, None, value_channels]) # pylint: enable=protected-access cache["encoder_output"] = encoder_output cache[ "encoder_decoder_attention_bias"] = encoder_decoder_attention_bias if beam_size > 1: # Beam Search target_modality = ( self._hparams.problems[self._problem_idx].target_modality) vocab_size = target_modality.top_dimensionality initial_ids = tf.zeros([batch_size], dtype=tf.int32) decoded_ids, _ = beam_search.beam_search(symbols_to_logits_fn, initial_ids, beam_size, decode_length, vocab_size, alpha, states=cache) if top_beams == 1: decoded_ids = decoded_ids[:, 0, 1:] else: decoded_ids = decoded_ids[:, :top_beams, 1:] else: # Greedy def inner_loop(i, next_id, decoded_ids, cache): logits, cache = symbols_to_logits_fn(next_id, i, cache) next_id = tf.expand_dims(tf.argmax(logits, axis=-1), axis=1) decoded_ids = tf.concat([decoded_ids, next_id], axis=1) return i + 1, next_id, decoded_ids, cache decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64) next_id = tf.zeros([batch_size, 1], dtype=tf.int64) _, _, decoded_ids, _ = tf.while_loop( # TODO(llion): Early stopping. lambda i, *_: tf.less(i, decode_length), inner_loop, [tf.constant(0), next_id, decoded_ids, cache], shape_invariants=[ tf.TensorShape([]), tf.TensorShape([None, None]), tf.TensorShape([None, None]), nest.map_structure(lambda t: tf.TensorShape(t.shape), cache), ]) return decoded_ids
def _fast_decode(self, features, decode_length, beam_size=1, top_beams=1, alpha=1.0): """Fast decoding. Overrides tensor2tensor.models.transformer.Transformer._fast_decode to let symbols_to_logits_fn return multiple things. Implements both greedy and beam search decoding, uses beam search iff beam_size > 1, otherwise beam search related arguments are ignored. Args: features: a map of string to model features. decode_length: an integer. How many additional timesteps to decode. beam_size: number of beams. top_beams: an integer. How many of the beams to return. alpha: Float that controls the length penalty. larger the alpha, stronger the preference for longer translations. Returns: A dict of decoding results { "body_output": tensor of size [batch_size, <= decode_length, hidden_size] (or [batch_size, top_beams, <= decode_length, hidden_size]) giving the raw output of the Transformer decoder corresponding to the predicted sequences "outputs": integer `Tensor` of decoded ids of shape [batch_size, <= decode_length] if beam_size == 1 or [batch_size, top_beams, <= decode_length] "scores": decoding log probs from the beam search, None if using greedy decoding (beam_size=1) } Raises: NotImplementedError: If there are multiple data shards. """ if self._num_datashards != 1: raise NotImplementedError( "Fast decoding only supports a single shard.") dp = self._data_parallelism hparams = self._hparams target_modality = self._problem_hparams.target_modality if isinstance(target_modality, dict): primary_target_feature = self._problem_hparams.primary_target_modality primary_target_modality = target_modality[primary_target_feature] bottom_variable_scope = "%s/%s" % (primary_target_modality.name, primary_target_feature) else: primary_target_feature = "targets" primary_target_modality = target_modality bottom_variable_scope = target_modality.name if self.has_input: inputs = features["inputs"] if primary_target_modality.is_class_modality: decode_length = 1 else: decode_length = (common_layers.shape_list(inputs)[1] + features.get("decode_length", decode_length)) # TODO(llion): Clean up this reshaping logic. inputs = tf.expand_dims(inputs, axis=1) if len(inputs.shape) < 5: inputs = tf.expand_dims(inputs, axis=4) s = common_layers.shape_list(inputs) batch_size = s[0] inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]]) # _shard_features called to ensure that the variable names match inputs = self._shard_features({"inputs": inputs})["inputs"] input_modality = self._problem_hparams.input_modality["inputs"] with tf.variable_scope(input_modality.name): inputs = input_modality.bottom_sharded(inputs, dp) with tf.variable_scope("body"): encoder_output, encoder_decoder_attention_bias = dp( self.encode, inputs, features["target_space_id"], hparams, features=features) encoder_output = encoder_output[0] encoder_decoder_attention_bias = encoder_decoder_attention_bias[0] partial_targets = None else: # The problem has no inputs. encoder_output = None encoder_decoder_attention_bias = None # Prepare partial targets. # In either features["inputs"] or features["targets"]. # We force the outputs to begin with these sequences. 
partial_targets = features.get("inputs") if partial_targets is None: partial_targets = features[primary_target_feature] assert partial_targets is not None partial_targets = common_layers.expand_squeeze_to_nd( partial_targets, 2) partial_targets = tf.to_int64(partial_targets) partial_targets_shape = common_layers.shape_list(partial_targets) partial_targets_length = partial_targets_shape[1] decode_length = (partial_targets_length + features.get("decode_length", decode_length)) batch_size = partial_targets_shape[0] if hparams.pos == "timing": positional_encoding = common_attention.get_timing_signal_1d( decode_length + 1, hparams.hidden_size) elif hparams.pos == "emb": positional_encoding = common_attention.add_positional_embedding( tf.zeros([1, decode_length + 1, hparams.hidden_size]), hparams.max_length, "targets_positional_embedding", None) else: positional_encoding = None def preprocess_targets(targets, i): """Performs preprocessing steps on the targets to prepare for the decoder. This includes: - Embedding the ids. - Flattening to 3D tensor. - Optionally adding timing signals. Args: targets: inputs ids to the decoder. [batch_size, 1] i: scalar, Step number of the decoding loop. Returns: Processed targets [batch_size, 1, hidden_dim] """ # _shard_features called to ensure that the variable names match targets = self._shard_features({primary_target_feature: targets})[primary_target_feature] with tf.variable_scope(bottom_variable_scope): targets = primary_target_modality.targets_bottom_sharded( targets, dp)[0] targets = common_layers.flatten4d3d(targets) # At step 0, targets will have 0 size, and instead we want to # create an embedding of all-zero, corresponding to the start symbol # this matches what transformer_prepare_decoder does to the target # outputs during training targets = tf.cond(tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets) if positional_encoding is not None: targets += positional_encoding[:, i:i + 1] return targets decoder_self_attention_bias = ( common_attention.attention_bias_lower_triangle(decode_length)) if hparams.proximity_bias: decoder_self_attention_bias += common_attention.attention_bias_proximal( decode_length) def symbols_to_logits_fn(ids, i, cache): """Go from ids to logits for next symbol.""" ids = ids[:, -1:] targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3) targets = preprocess_targets(targets, i) bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1] logits = self._symbols_to_logits_fn(targets, features, bias, cache) logits = tf.squeeze(logits, axis=[1, 2, 3]) if partial_targets is not None: # If the position is within the given partial targets, we alter the # logits to always return those values. # A faster approach would be to process the partial targets in one # iteration in order to fill the corresponding parts of the cache. # This would require broader changes, though. vocab_size = tf.shape(logits)[1] def forced_logits(): return tf.one_hot( tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0, -1e9) logits = tf.cond(tf.less(i, partial_targets_length), forced_logits, lambda: logits) return logits, cache cache = dict() infer_out = dict() if encoder_output is not None: padding_mask = 1. 
- common_attention.attention_bias_to_padding( encoder_decoder_attention_bias) masked_encoded_output = encoder_output * tf.expand_dims( padding_mask, axis=2) infer_out["encoded_inputs"] = tf.reduce_sum(masked_encoded_output, axis=1) self._prepare_decoder_cache(batch_size, features, cache) ret = fast_decode( encoder_output=encoder_output, encoder_decoder_attention_bias=encoder_decoder_attention_bias, symbols_to_logits_fn=symbols_to_logits_fn, hparams=hparams, decode_length=decode_length, vocab_size=primary_target_modality.top_dimensionality, beam_size=beam_size, top_beams=top_beams, alpha=alpha, batch_size=batch_size, force_decode_length=self._decode_hparams.force_decode_length, cache=cache) infer_out.update(ret) if "cache" in ret: infer_out.update(ret["cache"]) if partial_targets is not None: if beam_size <= 1 or top_beams <= 1: infer_out["outputs"] = infer_out[ "outputs"][:, partial_targets_length:] else: infer_out["outputs"] = infer_out[ "outputs"][:, :, partial_targets_length:] return infer_out
def transformer_prepare_encoder(inputs, target_space, hparams, features=None, type_ids=None, num_types=None, reuse_target_embedding=tf.AUTO_REUSE): """Prepare one shard of the model for the encoder. Args: inputs: a Tensor. target_space: a Tensor. hparams: run hyperparameters features: optionally pass the entire features dictionary as well. This is needed now for "packed" datasets. type_ids: optional, an int64 Tensor of shape [batch, length] that allows for adding type embeddings, similar to positional embeddings. num_types: optional, an int that decides the number of types in type_ids. reuse_target_embedding: option to reuse variable name in the case that symbol modalities are reused between inputs/targets. Returns: encoder_input: a Tensor, bottom of encoder stack encoder_self_attention_bias: a bias tensor for use in encoder self-attention encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder attention """ ishape_static = inputs.shape.as_list() encoder_input = inputs if features and "inputs_segmentation" in features: # Packed dataset. Keep the examples from seeing each other. inputs_segmentation = features["inputs_segmentation"] inputs_position = features["inputs_position"] targets_segmentation = features["targets_segmentation"] if (hasattr(hparams, "unidirectional_encoder") and hparams.unidirectional_encoder): tf.logging.info("Using unidirectional encoder") encoder_self_attention_bias = ( common_attention.attention_bias_lower_triangle( common_layers.shape_list(inputs)[1])) else: encoder_self_attention_bias = ( common_attention.attention_bias_same_segment( inputs_segmentation, inputs_segmentation)) encoder_decoder_attention_bias = ( common_attention.attention_bias_same_segment( targets_segmentation, inputs_segmentation)) else: encoder_padding = common_attention.embedding_to_padding(encoder_input) ignore_padding = common_attention.attention_bias_ignore_padding( encoder_padding) if (hasattr(hparams, "unidirectional_encoder") and hparams.unidirectional_encoder): tf.logging.info("Using unidirectional encoder") encoder_self_attention_bias = ( common_attention.attention_bias_lower_triangle( common_layers.shape_list(inputs)[1])) else: # Usual case - not a packed dataset. encoder_self_attention_bias = ignore_padding encoder_decoder_attention_bias = ignore_padding inputs_position = None if hparams.proximity_bias: encoder_self_attention_bias += common_attention.attention_bias_proximal( common_layers.shape_list(inputs)[1]) if target_space is not None and hparams.get("use_target_space_embedding", True): # Append target_space_id embedding to inputs. 
emb_target_space = common_layers.embedding( target_space, 32, ishape_static[-1], name="target_space_embedding", dtype=hparams.get("activation_dtype", "float32"), reuse=reuse_target_embedding) emb_target_space = tf.reshape(emb_target_space, [1, 1, -1]) encoder_input += emb_target_space if hparams.pos == "timing": if inputs_position is not None: encoder_input = common_attention.add_timing_signal_1d_given_position( encoder_input, inputs_position) else: encoder_input = common_attention.add_timing_signal_1d( encoder_input) elif hparams.pos == "timing_from_features": encoder_input = common_attention.add_timing_signals_from_features( encoder_input, features, hparams.position_features) elif hparams.pos == "emb": encoder_input = common_attention.add_positional_embedding( encoder_input, hparams.max_length, "inputs_positional_embedding", inputs_position) # Add type embeddings if type_ids is not None: if not num_types: raise ValueError("Need to set num_types as well.") encoder_input = common_attention.add_positional_embedding( encoder_input, num_types, "inputs_type_embedding", type_ids) encoder_self_attention_bias = common_layers.cast_like( encoder_self_attention_bias, encoder_input) encoder_decoder_attention_bias = common_layers.cast_like( encoder_decoder_attention_bias, encoder_input) return (encoder_input, encoder_self_attention_bias, encoder_decoder_attention_bias)
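# --- Illustrative sketch (added; not part of the original model code) ---
# The "Add type embeddings" step above reuses add_positional_embedding with
# type_ids in place of positions: a learned [num_types, hidden] table is
# indexed by type id and added to the encoder input. A rough NumPy stand-in
# (the table is random here; in the real model it is a trained variable):
import numpy as np

def add_type_embedding(encoder_input, type_ids, type_table):
  # encoder_input: [batch, length, hidden]
  # type_ids:      [batch, length] ints in [0, num_types)
  # type_table:    [num_types, hidden]
  return encoder_input + type_table[type_ids]

batch, length, hidden, num_types = 2, 5, 8, 3
encoder_input = np.zeros((batch, length, hidden), dtype=np.float32)
type_ids = np.array([[0, 0, 1, 1, 2], [0, 1, 1, 2, 2]])
type_table = np.random.randn(num_types, hidden).astype(np.float32)
print(add_type_embedding(encoder_input, type_ids, type_table).shape)  # (2, 5, 8)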
def transformer_prepare_encoder(inputs, target_space, hparams, features=None):
  """Prepare one shard of the model for the encoder.

  Args:
    inputs: a Tensor.
      sg: inputs here have been flattened to 3d
          [batch, height, width, embed_size] -> [batch, height*width, embed_size]
    target_space: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
  """
  ishape_static = inputs.shape.as_list()
  encoder_input = inputs
  if features and "inputs_segmentation" in features:
    # Packed dataset. Keep the examples from seeing each other.
    inputs_segmentation = features["inputs_segmentation"]
    inputs_position = features["inputs_position"]
    targets_segmentation = features["targets_segmentation"]
    encoder_self_attention_bias = common_attention.attention_bias_same_segment(
        inputs_segmentation, inputs_segmentation)
    encoder_decoder_attention_bias = (
        common_attention.attention_bias_same_segment(
            targets_segmentation, inputs_segmentation))
  else:
    # Usual case - not a packed dataset.
    encoder_padding = common_attention.embedding_to_padding(encoder_input)
    # sg: [batch_size, sentence_len]
    ignore_padding = common_attention.attention_bias_ignore_padding(
        encoder_padding)
    # sg: [batch_size, 1, 1, sentence_len]
    # a bias tensor to be added to the attention logits:
    # padded positions get a bias of -1e9, non-padded positions get 0.
    encoder_self_attention_bias = ignore_padding
    encoder_decoder_attention_bias = ignore_padding
    inputs_position = None
  if hparams.proximity_bias:
    encoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(inputs)[1])
  # Append target_space_id embedding to inputs.
  emb_target_space = common_layers.embedding(
      target_space,
      32,  # sg: 32 is the vocabulary size here (per the comments in the
           # function, possibly not exact); t2t currently only defines
           # SpaceIDs 1 through 32 in problem.py.
      ishape_static[-1],  # sg: embedding dimension
      name="target_space_embedding",
      dtype=tf.bfloat16
      if hparams.activation_dtype == "bfloat16" else tf.float32)
  # sg: [1,128] a dense vector to represent the SpaceID
  emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
  # sg: [1,1,128]
  encoder_input += emb_target_space
  if hparams.pos == "timing":
    if inputs_position is not None:
      encoder_input = common_attention.add_timing_signal_1d_given_position(
          encoder_input, inputs_position)
    else:
      encoder_input = common_attention.add_timing_signal_1d(encoder_input)
  if hparams.activation_dtype == "bfloat16":
    encoder_self_attention_bias = tf.cast(encoder_self_attention_bias,
                                          tf.bfloat16)
    encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
                                             tf.bfloat16)
  return (encoder_input, encoder_self_attention_bias,
          encoder_decoder_attention_bias)
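# --- Illustrative sketch (added; not part of the original model code) ---
# The hparams.pos == "timing" branch above adds a sinusoidal position signal.
# A NumPy approximation of common_attention.get_timing_signal_1d, assuming the
# default timescales (1 to 1e4) and an even number of channels:
import numpy as np

def timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
  # Half the channels carry sin, half carry cos, at geometrically spaced
  # timescales, giving each position a unique, smoothly varying code.
  position = np.arange(length, dtype=np.float32)
  num_timescales = channels // 2
  log_timescale_increment = (
      np.log(max_timescale / min_timescale) / max(num_timescales - 1, 1))
  inv_timescales = min_timescale * np.exp(
      np.arange(num_timescales, dtype=np.float32) * -log_timescale_increment)
  scaled_time = position[:, None] * inv_timescales[None, :]
  signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
  return signal[None, :, :]  # [1, length, channels], broadcast over the batch

print(timing_signal_1d(6, 8).shape)  # (1, 6, 8)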
def _fast_decode(self, features, decode_length, beam_size=1, top_beams=1, alpha=1.0): """Fast decoding. Implements both greedy and beam search decoding, uses beam search iff beam_size > 1, otherwise beam search related arguments are ignored. Args: features: a map of string to model features. decode_length: an integer. How many additional timesteps to decode. beam_size: number of beams. top_beams: an integer. How many of the beams to return. alpha: Float that controls the length penalty. larger the alpha, stronger the preference for slonger translations. Returns: A dict of decoding results { "outputs": integer `Tensor` of decoded ids of shape [batch_size, <= decode_length] if beam_size == 1 or [batch_size, top_beams, <= decode_length] "scores": decoding log probs from the beam search, None if using greedy decoding (beam_size=1) } Raises: NotImplementedError: If there are multiple data shards. """ if self._num_datashards != 1: raise NotImplementedError("Fast decoding only supports a single shard.") dp = self._data_parallelism hparams = self._hparams target_modality = self._problem_hparams.target_modality if self.has_input: inputs = features["inputs"] if target_modality.is_class_modality: decode_length = 1 else: decode_length = common_layers.shape_list(inputs)[1] + decode_length # TODO(llion): Clean up this reshaping logic. inputs = tf.expand_dims(inputs, axis=1) if len(inputs.shape) < 5: inputs = tf.expand_dims(inputs, axis=4) s = common_layers.shape_list(inputs) batch_size = s[0] inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]]) # _shard_features called to ensure that the variable names match inputs = self._shard_features({"inputs": inputs})["inputs"] input_modality = self._problem_hparams.input_modality["inputs"] with tf.variable_scope(input_modality.name): inputs = input_modality.bottom_sharded(inputs, dp) with tf.variable_scope("body"): encoder_output, encoder_decoder_attention_bias = dp( self.encode, inputs, features["target_space_id"], hparams, features=features) encoder_output = encoder_output[0] encoder_decoder_attention_bias = encoder_decoder_attention_bias[0] partial_targets = None else: # The problem has no inputs. # In this case, features["inputs"] contains partial targets. # We force the outputs to begin with these sequences. encoder_output = None encoder_decoder_attention_bias = None partial_targets = tf.squeeze(tf.to_int64(features["inputs"]), [2, 3]) partial_targets_length = common_layers.shape_list(partial_targets)[1] decode_length += partial_targets_length batch_size = tf.shape(partial_targets)[0] if hparams.pos == "timing": timing_signal = common_attention.get_timing_signal_1d( decode_length + 1, hparams.hidden_size) def preprocess_targets(targets, i): """Performs preprocessing steps on the targets to prepare for the decoder. This includes: - Embedding the ids. - Flattening to 3D tensor. - Optionally adding timing signals. Args: targets: inputs ids to the decoder. [batch_size, 1] i: scalar, Step number of the decoding loop. Returns: Processed targets [batch_size, 1, hidden_dim] """ # _shard_features called to ensure that the variable names match targets = self._shard_features({"targets": targets})["targets"] with tf.variable_scope(target_modality.name): targets = target_modality.targets_bottom_sharded(targets, dp)[0] targets = common_layers.flatten4d3d(targets) # TODO(llion): Explain! Is this even needed? 
targets = tf.cond( tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets) if hparams.pos == "timing": targets += timing_signal[:, i:i + 1] return targets decoder_self_attention_bias = ( common_attention.attention_bias_lower_triangle(decode_length)) if hparams.proximity_bias: decoder_self_attention_bias += common_attention.attention_bias_proximal( decode_length) def symbols_to_logits_fn(ids, i, cache): """Go from ids to logits for next symbol.""" ids = ids[:, -1:] targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3) targets = preprocess_targets(targets, i) bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1] with tf.variable_scope("body"): body_outputs = dp( self.decode, targets, cache.get("encoder_output"), cache.get("encoder_decoder_attention_bias"), bias, hparams, cache, nonpadding=features_to_nonpadding(features, "targets")) with tf.variable_scope(target_modality.name): logits = target_modality.top_sharded(body_outputs, None, dp)[0] ret = tf.squeeze(logits, axis=[1, 2, 3]) if partial_targets is not None: # If the position is within the given partial targets, we alter the # logits to always return those values. # A faster approach would be to process the partial targets in one # iteration in order to fill the corresponding parts of the cache. # This would require broader changes, though. vocab_size = tf.shape(ret)[1] def forced_logits(): return tf.one_hot(tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0, -1e9) ret = tf.cond( tf.less(i, partial_targets_length), forced_logits, lambda: ret) return ret, cache ret = fast_decode( encoder_output=encoder_output, encoder_decoder_attention_bias=encoder_decoder_attention_bias, symbols_to_logits_fn=symbols_to_logits_fn, hparams=hparams, decode_length=decode_length, vocab_size=target_modality.top_dimensionality, beam_size=beam_size, top_beams=top_beams, alpha=alpha, batch_size=batch_size) if partial_targets is not None: ret["outputs"] = ret["outputs"][:, partial_targets_length:] return ret
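# --- Illustrative sketch (added; not part of the original model code) ---
# decoder_self_attention_bias above is a [1, 1, length, length] lower-triangular
# mask; at step i, symbols_to_logits_fn slices out [:, :, i:i+1, :i+1], the one
# bias row that lets position i attend only to positions 0..i. A NumPy sketch:
import numpy as np

def attention_bias_lower_triangle(length):
  # 0.0 where attention is allowed (j <= i), -1e9 above the diagonal.
  allowed = np.tril(np.ones((length, length), dtype=np.float32))
  return (-1e9 * (1.0 - allowed))[None, None, :, :]

decode_length = 5
bias = attention_bias_lower_triangle(decode_length)
i = 2
step_bias = bias[:, :, i:i + 1, :i + 1]      # what the decoding loop uses at step i
print(bias.shape, step_bias.shape)           # (1, 1, 5, 5) (1, 1, 1, 3)
print(step_bias[0, 0, 0])                    # [0. 0. 0.] -> positions 0..2 visible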
def _fast_decode(self, features, decode_length, beam_size=1, top_beams=1, alpha=1.0): """Fast decoding. Implements both greedy and beam search decoding, uses beam search iff beam_size > 1, otherwise beam search related arguments are ignored. Args: features: a map of string to model features. decode_length: an integer. How many additional timesteps to decode. beam_size: number of beams. top_beams: an integer. How many of the beams to return. alpha: Float that controls the length penalty. larger the alpha, stronger the preference for slonger translations. Returns: A dict of decoding results { "outputs": integer `Tensor` of decoded ids of shape [batch_size, <= decode_length] if beam_size == 1 or [batch_size, top_beams, <= decode_length] "scores": decoding log probs from the beam search, None if using greedy decoding (beam_size=1) } Raises: NotImplementedError: If there are multiple data shards. """ if self._num_datashards != 1: raise NotImplementedError("Fast decoding only supports a single shard.") dp = self._data_parallelism hparams = self._hparams target_modality = self._problem_hparams.target_modality story = features[babi_qa.FeatureNames.STORY] question = features[babi_qa.FeatureNames.QUESTION] if target_modality.is_class_modality: decode_length = 1 else: decode_length = (common_layers.shape_list(story)[1] + common_layers.shape_list(question)[1] + decode_length) story = tf.expand_dims(story, axis=1) question = tf.expand_dims(question, axis=1) if len(story.shape) < 5: story = tf.expand_dims(story, axis=4) if len(question.shape) < 5: question = tf.expand_dims(question, axis=4) s = common_layers.shape_list(story) batch_size = s[0] story = tf.reshape(story, [s[0] * s[1], s[2], s[3], s[4]]) s = common_layers.shape_list(question) batch_size = s[0] question = tf.reshape(question, [s[0] * s[1], s[2], s[3], s[4]]) # _shard_features called to ensure that the variable names match story = self._shard_features({babi_qa.FeatureNames.STORY: story} )[babi_qa.FeatureNames.STORY] question = self._shard_features({babi_qa.FeatureNames.QUESTION: question} )[ babi_qa.FeatureNames.QUESTION] story_modality = self._problem_hparams.input_modality[ babi_qa.FeatureNames.STORY] question_modality = self._problem_hparams.input_modality[ babi_qa.FeatureNames.QUESTION] with tf.variable_scope(story_modality.name): story = story_modality.bottom_sharded(story, dp) with tf.variable_scope(question_modality.name, reuse=(story_modality.name == question_modality.name)): question = question_modality.bottom_sharded(question, dp) with tf.variable_scope("body"): if target_modality.is_class_modality: encoder_output = dp(self.encode, story, question, features["target_space_id"], hparams) else: encoder_output, encoder_decoder_attention_bias = dp(self.encode, story, question, features["target_space_id"],hparams,features=features) encoder_decoder_attention_bias = encoder_decoder_attention_bias[0] encoder_output = encoder_output[0] if hparams.pos == "timing": timing_signal = common_attention.get_timing_signal_1d(decode_length + 1, hparams.hidden_size) def preprocess_targets(targets, i): """Performs preprocessing steps on the targets to prepare for the decoder. This includes: - Embedding the ids. - Flattening to 3D tensor. - Optionally adding timing signals. Args: targets: inputs ids to the decoder. [batch_size, 1] i: scalar, Step number of the decoding loop. 
Returns: Processed targets [batch_size, 1, hidden_dim] """ # _shard_features called to ensure that the variable names match targets = self._shard_features({"targets": targets})["targets"] with tf.variable_scope(target_modality.name): targets = target_modality.targets_bottom_sharded(targets, dp)[0] targets = common_layers.flatten4d3d(targets) targets = tf.cond(tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets) if hparams.pos == "timing": targets += timing_signal[:, i:i + 1] return targets decoder_self_attention_bias = ( common_attention.attention_bias_lower_triangle(decode_length)) if hparams.proximity_bias: decoder_self_attention_bias += common_attention.attention_bias_proximal( decode_length) def symbols_to_logits_fn(ids, i, cache): """Go from ids to logits for next symbol.""" ids = ids[:, -1:] targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3) targets = preprocess_targets(targets, i) bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1] with tf.variable_scope("body"): body_outputs = dp(self.decode, targets, cache.get("encoder_output"), cache.get("encoder_decoder_attention_bias"), bias, hparams, cache, nonpadding=features_to_nonpadding(features, "targets") ) with tf.variable_scope(target_modality.name): logits = target_modality.top_sharded(body_outputs, None, dp)[0] ret = tf.squeeze(logits, axis=[1, 2, 3]) return ret, cache def labels_to_logits_fn(unused_ids, unused_i, cache): """Go from labels to logits""" with tf.variable_scope("body"): body_outputs = dp(tf.expand_dims, cache.get("encoder_output"), 2) with tf.variable_scope(target_modality.name): logits = target_modality.top_sharded(body_outputs, None, dp)[0] ret = tf.squeeze(logits, axis=[1, 2, 3]) return ret, cache if target_modality.is_class_modality: ret = transformer.fast_decode(encoder_output=encoder_output, encoder_decoder_attention_bias=None, symbols_to_logits_fn=labels_to_logits_fn, hparams=hparams, decode_length=decode_length, vocab_size=target_modality.top_dimensionality, beam_size=beam_size, top_beams=top_beams, alpha=alpha, batch_size=batch_size) else: ret = transformer.fast_decode(encoder_output=encoder_output, encoder_decoder_attention_bias=encoder_decoder_attention_bias, symbols_to_logits_fn=symbols_to_logits_fn, hparams=hparams, decode_length=decode_length, vocab_size=target_modality.top_dimensionality, beam_size=beam_size, top_beams=top_beams, alpha=alpha, batch_size=batch_size) return ret
def _fast_decode(self, features, decode_length, beam_size=1, top_beams=1, alpha=1.0): """Fast decoding. Implements both greedy and beam search decoding, uses beam search iff beam_size > 1, otherwise beam search related arguments are ignored. Args: features: a map of string to model features. decode_length: an integer. How many additional timesteps to decode. beam_size: number of beams. top_beams: an integer. How many of the beams to return. alpha: Float that controls the length penalty. larger the alpha, stronger the preference for slonger translations. Returns: samples: an integer `Tensor`. Top samples from the beam search Raises: NotImplementedError: If there are multiple data shards. """ if self._num_datashards != 1: raise NotImplementedError("Fast decoding only supports a single shard.") dp = self._data_parallelism hparams = self._hparams inputs = features["inputs"] batch_size = common_layers.shape_list(inputs)[0] target_modality = self._problem_hparams.target_modality if target_modality.is_class_modality: decode_length = 1 else: decode_length = common_layers.shape_list(inputs)[1] + decode_length # TODO(llion): Clean up this reshaping logic. inputs = tf.expand_dims(inputs, axis=1) if len(inputs.shape) < 5: inputs = tf.expand_dims(inputs, axis=4) s = common_layers.shape_list(inputs) inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]]) # _shard_features called to ensure that the variable names match inputs = self._shard_features({"inputs": inputs})["inputs"] input_modality = self._problem_hparams.input_modality["inputs"] with tf.variable_scope(input_modality.name): inputs = input_modality.bottom_sharded(inputs, dp) with tf.variable_scope("body"): encoder_output, encoder_decoder_attention_bias = dp( self.encode, inputs, features["target_space_id"], hparams, features=features) encoder_output = encoder_output[0] encoder_decoder_attention_bias = encoder_decoder_attention_bias[0] if hparams.pos == "timing": timing_signal = common_attention.get_timing_signal_1d( decode_length + 1, hparams.hidden_size) def preprocess_targets(targets, i): """Performs preprocessing steps on the targets to prepare for the decoder. This includes: - Embedding the ids. - Flattening to 3D tensor. - Optionally adding timing signals. Args: targets: inputs ids to the decoder. [batch_size, 1] i: scalar, Step number of the decoding loop. Returns: Processed targets [batch_size, 1, hidden_dim] """ # _shard_features called to ensure that the variable names match targets = self._shard_features({"targets": targets})["targets"] with tf.variable_scope(target_modality.name): targets = target_modality.targets_bottom_sharded(targets, dp)[0] targets = common_layers.flatten4d3d(targets) # TODO(llion): Explain! Is this even needed? 
targets = tf.cond( tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets) if hparams.pos == "timing": targets += timing_signal[:, i:i + 1] return targets decoder_self_attention_bias = ( common_attention.attention_bias_lower_triangle(decode_length)) if hparams.proximity_bias: decoder_self_attention_bias += common_attention.attention_bias_proximal( decode_length) def symbols_to_logits_fn(ids, i, cache): """Go from ids to logits for next symbol.""" ids = ids[:, -1:] targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3) targets = preprocess_targets(targets, i) bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1] with tf.variable_scope("body"): body_outputs = dp( self.decode, targets, cache["encoder_output"], cache["encoder_decoder_attention_bias"], bias, hparams, cache, nonpadding=_features_to_nonpadding(features, "targets")) with tf.variable_scope(target_modality.name): logits = target_modality.top_sharded(body_outputs, None, dp)[0] return tf.squeeze(logits, axis=[1, 2, 3]), cache key_channels = hparams.attention_key_channels or hparams.hidden_size value_channels = hparams.attention_value_channels or hparams.hidden_size num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers cache = { "layer_%d" % layer: { "k": tf.zeros([batch_size, 0, key_channels]), "v": tf.zeros([batch_size, 0, value_channels]), } for layer in range(num_layers) } # Set 2nd dim to None since it's not invariant in the tf.while_loop # Note: Tensor.set_shape() does not work here since it merges shape info. # TODO(llion); Find a more robust solution. # pylint: disable=protected-access if not context.in_eager_mode(): for layer in cache: cache[layer]["k"]._shape = tf.TensorShape([None, None, key_channels]) cache[layer]["v"]._shape = tf.TensorShape([None, None, value_channels]) # pylint: enable=protected-access cache["encoder_output"] = encoder_output cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias if beam_size > 1: # Beam Search target_modality = ( self._hparams.problems[self._problem_idx].target_modality) vocab_size = target_modality.top_dimensionality initial_ids = tf.zeros([batch_size], dtype=tf.int32) decoded_ids, scores = beam_search.beam_search( symbols_to_logits_fn, initial_ids, beam_size, decode_length, vocab_size, alpha, states=cache, stop_early=(top_beams == 1)) if top_beams == 1: decoded_ids = decoded_ids[:, 0, 1:] else: decoded_ids = decoded_ids[:, :top_beams, 1:] else: # Greedy def inner_loop(i, next_id, decoded_ids, cache): logits, cache = symbols_to_logits_fn(next_id, i, cache) temperature = (0.0 if hparams.sampling_method == "argmax" else hparams.sampling_temp) next_id = tf.expand_dims( common_layers.sample_with_temperature(logits, temperature), axis=1) decoded_ids = tf.concat([decoded_ids, next_id], axis=1) return i + 1, next_id, decoded_ids, cache decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64) scores = None next_id = tf.zeros([batch_size, 1], dtype=tf.int64) _, _, decoded_ids, _ = tf.while_loop( # TODO(llion): Early stopping. lambda i, *_: tf.less(i, decode_length), inner_loop, [tf.constant(0), next_id, decoded_ids, cache], shape_invariants=[ tf.TensorShape([]), tf.TensorShape([None, None]), tf.TensorShape([None, None]), nest.map_structure(lambda t: tf.TensorShape(t.shape), cache), ]) return decoded_ids, scores
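# --- Illustrative sketch (added; not part of the original model code) ---
# The per-layer cache above starts with zero-length key/value tensors and grows
# by one position per decoding step; that is why the while_loop shape invariants
# leave the time dimension as None. A NumPy sketch of that growth (channel sizes
# are invented):
import numpy as np

batch_size, key_channels, value_channels = 2, 4, 4
cache = {
    "k": np.zeros((batch_size, 0, key_channels), dtype=np.float32),
    "v": np.zeros((batch_size, 0, value_channels), dtype=np.float32),
}
for step in range(3):
  # In the model these come from the current token's key/value projections.
  new_k = np.random.randn(batch_size, 1, key_channels).astype(np.float32)
  new_v = np.random.randn(batch_size, 1, value_channels).astype(np.float32)
  cache["k"] = np.concatenate([cache["k"], new_k], axis=1)
  cache["v"] = np.concatenate([cache["v"], new_v], axis=1)
  print(step, cache["k"].shape)  # time dimension grows: (2, 1, 4), (2, 2, 4), (2, 3, 4)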
def _fast_decode(self, features, decode_length, beam_size=1, top_beams=1, alpha=1.0): if self._num_datashards != 1: raise NotImplementedError( "Fast decoding only supports a single shard.") dp = self._data_parallelism hparams = self._hparams target_modality = self._problem_hparams.modality["targets"] if "targets_segmentation" in features: raise NotImplementedError( "Decoding not supported on packed datasets " " If you want to decode from a dataset, use the non-packed version" " of the dataset when decoding.") if self.has_input: inputs = features["inputs"] if target_modality.is_class_modality: decode_length = 1 else: decode_length = (common_layers.shape_list(inputs)[1] + features.get("decode_length", decode_length)) contexts = {} for feature_name in features: if 'context' in feature_name and 'raw' not in feature_name: contexts[feature_name] = features[feature_name] inputs = tf.expand_dims(inputs, axis=1) if len(inputs.shape) < 5: inputs = tf.expand_dims(inputs, axis=4) s = common_layers.shape_list(inputs) batch_size = s[0] inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]]) # _shard_features called to ensure that the variable names match inputs = self._shard_features({"inputs": inputs})["inputs"] input_modality = self._problem_hparams.modality["inputs"] context_modality = {} for context_name in contexts: if context_name in self._problem_hparams.modality: context_modality[ context_name] = self._problem_hparams.modality[ context_name] else: context_modality[context_name] = input_modality with tf.variable_scope(input_modality.name, reuse=tf.AUTO_REUSE): inputs = input_modality.bottom_sharded(inputs, dp) for feature_name in contexts: with tf.variable_scope(context_modality[feature_name].name, reuse=tf.AUTO_REUSE): contexts[feature_name] = context_modality[ feature_name].bottom_sharded(contexts[feature_name], dp) contexts_list = [ contexts[feature_name][0] for feature_name in contexts ] contexts = tf.concat(contexts_list, axis=1) inputs = [tf.concat([contexts, inputs[0]], axis=1)] with tf.variable_scope("body"): encoder_output, encoder_decoder_attention_bias = dp( self.encode, inputs, features["target_space_id"], hparams, features=features) encoder_output = encoder_output[0] encoder_decoder_attention_bias = encoder_decoder_attention_bias[0] partial_targets = None else: # The problem has no inputs. encoder_output = None encoder_decoder_attention_bias = None # Prepare partial targets. # In either features["inputs"] or features["targets"]. # We force the outputs to begin with these sequences. partial_targets = features.get("inputs") if partial_targets is None: partial_targets = features["targets"] assert partial_targets is not None partial_targets = common_layers.expand_squeeze_to_nd( partial_targets, 2) partial_targets = tf.to_int64(partial_targets) partial_targets_shape = common_layers.shape_list(partial_targets) partial_targets_length = partial_targets_shape[1] decode_length = (partial_targets_length + features.get("decode_length", decode_length)) batch_size = partial_targets_shape[0] if hparams.pos == "timing": positional_encoding = common_attention.get_timing_signal_1d( decode_length + 1, hparams.hidden_size) elif hparams.pos == "emb": positional_encoding = common_attention.add_positional_embedding( tf.zeros([1, decode_length, hparams.hidden_size]), hparams.max_length, "body/targets_positional_embedding", None) else: positional_encoding = None def preprocess_targets(targets, i): """Performs preprocessing steps on the targets to prepare for the decoder. 
This includes: - Embedding the ids. - Flattening to 3D tensor. - Optionally adding timing signals. Args: targets: inputs ids to the decoder. [batch_size, 1] i: scalar, Step number of the decoding loop. Returns: Processed targets [batch_size, 1, hidden_dim] """ # _shard_features called to ensure that the variable names match targets = self._shard_features({"targets": targets})["targets"] with tf.variable_scope(target_modality.name): targets = target_modality.targets_bottom_sharded(targets, dp)[0] targets = common_layers.flatten4d3d(targets) targets = tf.cond(tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets) if positional_encoding is not None: targets += positional_encoding[:, i:i + 1] return targets decoder_self_attention_bias = ( common_attention.attention_bias_lower_triangle(decode_length)) if hparams.proximity_bias: decoder_self_attention_bias += common_attention.attention_bias_proximal( decode_length) def symbols_to_logits_fn(ids, i, cache): """Go from ids to logits for next symbol.""" ids = ids[:, -1:] targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3) targets = preprocess_targets(targets, i) bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1] with tf.variable_scope("body"): body_outputs = dp(self.decode, targets, cache.get("encoder_output"), cache.get("encoder_decoder_attention_bias"), bias, hparams, cache, nonpadding=features_to_nonpadding( features, "targets")) with tf.variable_scope(target_modality.name): logits = target_modality.top_sharded(body_outputs, None, dp)[0] ret = tf.squeeze(logits, axis=[1, 2, 3]) if partial_targets is not None: # If the position is within the given partial targets, we alter the # logits to always return those values. # A faster approach would be to process the partial targets in one # iteration in order to fill the corresponding parts of the cache. # This would require broader changes, though. vocab_size = tf.shape(ret)[1] def forced_logits(): return tf.one_hot( tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0, -1e9) ret = tf.cond(tf.less(i, partial_targets_length), forced_logits, lambda: ret) return ret, cache ret = fast_decode( encoder_output=encoder_output, encoder_decoder_attention_bias=encoder_decoder_attention_bias, symbols_to_logits_fn=symbols_to_logits_fn, hparams=hparams, decode_length=decode_length, vocab_size=target_modality.top_dimensionality, beam_size=beam_size, top_beams=top_beams, alpha=alpha, batch_size=batch_size, force_decode_length=self._decode_hparams.force_decode_length) if partial_targets is not None: if beam_size <= 1 or top_beams <= 1: ret["outputs"] = ret["outputs"][:, partial_targets_length:] else: ret["outputs"] = ret["outputs"][:, :, partial_targets_length:] return ret
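# --- Illustrative sketch (added; not part of the original model code) ---
# In the has_input branch above, each context feature is embedded and the
# results are concatenated with the input embeddings along the length axis, so
# the encoder sees one long [contexts; inputs] sequence. A NumPy sketch with
# invented shapes (the real tensors are 4-D before flattening):
import numpy as np

batch, hidden = 2, 8
contexts = {
    "context_0": np.random.randn(batch, 4, hidden).astype(np.float32),
    "context_1": np.random.randn(batch, 3, hidden).astype(np.float32),
}
inputs = np.random.randn(batch, 5, hidden).astype(np.float32)
context_seq = np.concatenate([contexts[k] for k in sorted(contexts)], axis=1)
encoder_in = np.concatenate([context_seq, inputs], axis=1)
print(encoder_in.shape)  # (2, 12, 8): 4 + 3 context positions + 5 input positions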
def transformer_prepare_encoder(inputs, target_space, hparams, features=None): """Prepare one shard of the model for the encoder. Args: inputs: a Tensor. target_space: a Tensor. hparams: run hyperparameters features: optionally pass the entire features dictionary as well. This is needed now for "packed" datasets. Returns: encoder_input: a Tensor, bottom of encoder stack encoder_self_attention_bias: a bias tensor for use in encoder self-attention encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder attention """ ishape_static = inputs.shape.as_list() encoder_input = inputs if features and "inputs_segmentation" in features: # Packed dataset. Keep the examples from seeing each other. inputs_segmentation = features["inputs_segmentation"] inputs_position = features["inputs_position"] targets_segmentation = features["targets_segmentation"] encoder_self_attention_bias = common_attention.attention_bias_same_segment( inputs_segmentation, inputs_segmentation) encoder_decoder_attention_bias = ( common_attention.attention_bias_same_segment( targets_segmentation, inputs_segmentation)) else: # Usual case - not a packed dataset. encoder_padding = common_attention.embedding_to_padding(encoder_input) ignore_padding = common_attention.attention_bias_ignore_padding( encoder_padding) encoder_self_attention_bias = ignore_padding encoder_decoder_attention_bias = ignore_padding inputs_position = None if hparams.proximity_bias: encoder_self_attention_bias += common_attention.attention_bias_proximal( common_layers.shape_list(inputs)[1]) # Append target_space_id embedding to inputs. emb_target_space = common_layers.embedding(target_space, 32, ishape_static[-1], name="target_space_embedding") emb_target_space = tf.reshape(emb_target_space, [1, 1, -1]) encoder_input += emb_target_space #if hparams.pos == "timing": # if inputs_position is not None: # encoder_input = common_attention.add_timing_signal_1d_given_position( # encoder_input, inputs_position) # else: # encoder_input = common_attention.add_timing_signal_1d(encoder_input) raw_encoder_input = tf.squeeze(features['inputs_raw'], axis=[-2, -1]) pos_signals = generate_positional_signals(raw_encoder_input, hparams) pos_embeddings = generate_positional_embeddings(pos_signals, hparams.encoder_pos, hparams) if "sum" in hparams.encoder_pos_integration: encoder_input = encoder_input + pos_embeddings elif "ffn" in hparams.encoder_pos_integration: with tf.variable_scope("encoder_pos_ffn"): encoder_input = tf.concat([encoder_input, pos_embeddings], axis=2) encoder_input = transformer_ffn_layer(encoder_input, hparams, conv_padding="SAME") return (encoder_input, encoder_self_attention_bias, encoder_decoder_attention_bias)
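# --- Illustrative sketch (added; not part of the original model code) ---
# hparams.encoder_pos_integration above either adds the positional embeddings to
# the token embeddings ("sum") or concatenates them and mixes with a feed-forward
# layer ("ffn"). A rough NumPy sketch; the single projection matrix stands in for
# the transformer_ffn_layer used in the real code:
import numpy as np

def integrate_positions(encoder_input, pos_embeddings, mode, proj=None):
  # encoder_input, pos_embeddings: [batch, length, hidden]
  if mode == "sum":
    return encoder_input + pos_embeddings
  if mode == "ffn":
    concat = np.concatenate([encoder_input, pos_embeddings], axis=2)
    return concat @ proj  # project [.., 2*hidden] back to [.., hidden]
  raise ValueError("unknown mode: %s" % mode)

batch, length, hidden = 2, 7, 8
x = np.random.randn(batch, length, hidden).astype(np.float32)
p = np.random.randn(batch, length, hidden).astype(np.float32)
proj = np.random.randn(2 * hidden, hidden).astype(np.float32)
print(integrate_positions(x, p, "sum").shape)        # (2, 7, 8)
print(integrate_positions(x, p, "ffn", proj).shape)  # (2, 7, 8)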
def _greedy_infer(self, features, decode_length, last_position_only=True):
    """Fast version of greedy decoding.

    Args:
      features: a map of string to `Tensor`
      decode_length: an integer. How many additional timesteps to decode.
      last_position_only: MUST be true for fast decoding!

    Returns:
      samples: [batch_size, input_length + decode_length]
      logits: Not returned
      losses: Not returned

    Raises:
      ValueError: If last_position_only is False
      NotImplementedError: If there are multiple data shards.
    """
    if not last_position_only:
      raise ValueError("Fast decoding only deals with the last positions!")
    if self._num_datashards != 1:
      raise NotImplementedError("Fast decoding only supports a single shard.")
    dp = self._data_parallelism
    hparams = self._hparams

    inputs = features["inputs"]
    batch_size = tf.shape(inputs)[0]
    target_modality = self._problem_hparams.target_modality
    if t2t_model.is_class_modality(target_modality):
      decode_length = 1
    else:
      decode_length = tf.shape(inputs)[1] + decode_length

    # TODO(llion): Clean up this reshaping logic.
    inputs = tf.expand_dims(inputs, axis=1)
    if len(inputs.shape) < 5:
      inputs = tf.expand_dims(inputs, axis=4)
    s = tf.shape(inputs)
    inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
    # _shard_features called to ensure that the variable names match
    inputs = self._shard_features({"inputs": inputs})["inputs"]
    input_modality = self._problem_hparams.input_modality["inputs"]
    with tf.variable_scope(input_modality.name):
      inputs = input_modality.bottom_sharded(inputs, dp)
    with tf.variable_scope("body"):
      encoder_output, encoder_decoder_attention_bias = dp(
          self.encode, inputs, features["target_space_id"], hparams)

    if hparams.pos == "timing":
      timing_signal = common_attention.get_timing_signal_1d(
          decode_length + 1, hparams.hidden_size)

    def preprocess_targets(targets, i):
      """Performs preprocessing steps on the targets to prepare for the decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: inputs ids to the decoder. [batch_size, 1]
        i: scalar, Step number of the decoding loop.

      Returns:
        Processed targets [batch_size, 1, hidden_dim]
      """
      # _shard_features called to ensure that the variable names match
      targets = self._shard_features({"targets": targets})["targets"]
      with tf.variable_scope(target_modality.name):
        targets = target_modality.targets_bottom_sharded(targets, dp)[0]
      targets = common_layers.flatten4d3d(targets)
      # TODO(llion): Explain! Is this even needed?
      targets = tf.cond(
          tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)
      if hparams.pos == "timing":
        targets += timing_signal[:, i:i + 1]
      return targets

    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(decode_length))
    if hparams.proximity_bias:
      decoder_self_attention_bias += common_attention.attention_bias_proximal(
          decode_length)

    def symbols_to_logits_fn(ids, i, cache):
      """Go from ids to logits for next symbol."""
      targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
      targets = preprocess_targets(targets, i)

      bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

      with tf.variable_scope("body"):
        body_outputs = dp(self.decode, targets, encoder_output[0],
                          encoder_decoder_attention_bias[0], bias, hparams,
                          cache)

      with tf.variable_scope(target_modality.name):
        logits = target_modality.top_sharded(body_outputs, None, dp)[0]

      return tf.squeeze(logits, axis=[1, 2, 3])

    def inner_loop(i, next_id, decoded_ids, cache):
      logits = symbols_to_logits_fn(next_id, i, cache)
      next_id = tf.expand_dims(tf.argmax(logits, axis=-1), axis=1)
      decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
      return i + 1, next_id, decoded_ids, cache

    key_channels = hparams.attention_key_channels or hparams.hidden_size
    value_channels = hparams.attention_value_channels or hparams.hidden_size
    num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

    cache = {
        "layer_%d" % layer: {
            "k": tf.zeros([batch_size, 0, key_channels]),
            "v": tf.zeros([batch_size, 0, value_channels]),
        } for layer in range(num_layers)
    }
    decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
    next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
    _, _, decoded_ids, _ = tf.while_loop(
        # TODO(llion): Early stopping.
        lambda i, *_: tf.less(i, decode_length),
        inner_loop,
        [tf.constant(0), next_id, decoded_ids, cache],
        shape_invariants=[
            tf.TensorShape([]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            {
                "layer_%d" % layer: {
                    "k": tf.TensorShape([None, None, key_channels]),
                    "v": tf.TensorShape([None, None, value_channels]),
                } for layer in range(num_layers)
            }
        ])

    return decoded_ids, None, None
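# --- Illustrative sketch (added; not part of the original model code) ---
# Stripped of sharding and TensorFlow control flow, the greedy inner loop above
# is: feed the previously emitted id, take the argmax of the returned logits,
# append, repeat for decode_length steps. A plain NumPy sketch with a dummy
# logits function standing in for the decoder stack:
import numpy as np

def greedy_decode(symbols_to_logits_fn, batch_size, decode_length, cache):
  decoded_ids = np.zeros((batch_size, 0), dtype=np.int64)
  next_id = np.zeros((batch_size, 1), dtype=np.int64)   # step 0 acts as the start symbol
  for i in range(decode_length):
    logits = symbols_to_logits_fn(next_id, i, cache)
    next_id = np.argmax(logits, axis=-1)[:, None]
    decoded_ids = np.concatenate([decoded_ids, next_id], axis=1)
  return decoded_ids

def dummy_symbols_to_logits_fn(ids, i, cache):
  # Stand-in for the transformer decoder: always prefers token i + 1.
  vocab_size = 10
  logits = np.zeros((ids.shape[0], vocab_size), dtype=np.float32)
  logits[:, (i + 1) % vocab_size] = 1.0
  return logits

print(greedy_decode(dummy_symbols_to_logits_fn, batch_size=2, decode_length=4, cache={}))
# [[1 2 3 4]
#  [1 2 3 4]]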
def _fast_decode(self, features, decode_length, beam_size=1, top_beams=1, alpha=1.0): #dp = self._data_parallelism hparams = self._hparams target_modality = self._problem_hparams.modality["targets"] inputs = features["inputs"] decode_length = (common_layers.shape_list(inputs)[1] + features.get("decode_length", decode_length)) #inputs = tf.expand_dims(inputs, axis=1) #if len(inputs.shape) < 5: # inputs = tf.expand_dims(inputs, axis=4) s = common_layers.shape_list(inputs) batch_size = s[0] #inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]]) # _shard_features called to ensure that the variable names match #inputs = self._shard_features({"inputs": inputs})["inputs"] input_modality = self._problem_hparams.modality["inputs"] context_modality = {} contexts = {} for feature_name in features: if 'context' in feature_name and 'raw' not in feature_name: contexts[feature_name] = features[feature_name] for context_name in contexts: if context_name in self._problem_hparams.modality: context_modality[ context_name] = self._problem_hparams.modality[ context_name] else: context_modality[context_name] = input_modality with tf.variable_scope(input_modality.name, reuse=tf.AUTO_REUSE): inputs = input_modality.bottom(inputs) for context_name in contexts: contexts[context_name] = context_modality[context_name].bottom( contexts[context_name]) with tf.variable_scope("body", reuse=tf.AUTO_REUSE): encoder_output, encoder_decoder_attention_bias = self.encode( inputs, contexts, features["target_space_id"], hparams, features=features) #encoder_output = encoder_output[0] #encoder_decoder_attention_bias = encoder_decoder_attention_bias[0] partial_targets = None if hparams.pos == "timing": positional_encoding = common_attention.get_timing_signal_1d( decode_length + 1, hparams.hidden_size) elif hparams.pos == "emb": positional_encoding = common_attention.add_positional_embedding( tf.zeros([1, decode_length + 1, hparams.hidden_size]), hparams.max_length, "targets_positional_embedding", None) else: positional_encoding = None def preprocess_targets(targets, i): """Performs preprocessing steps on the targets to prepare for the decoder. This includes: - Embedding the ids. - Flattening to 3D tensor. - Optionally adding timing signals. Args: targets: inputs ids to the decoder. [batch_size, 1] i: scalar, Step number of the decoding loop. 
Returns: Processed targets [batch_size, 1, hidden_dim] """ # _shard_features called to ensure that the variable names match #targets = self._shard_features({"targets": targets})["targets"] with tf.variable_scope(target_modality.name): targets = target_modality.targets_bottom(targets) targets = common_layers.flatten4d3d(targets) targets = tf.cond(tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets) if positional_encoding is not None: targets += positional_encoding[:, i:i + 1] return targets decoder_self_attention_bias = ( common_attention.attention_bias_lower_triangle(decode_length)) if hparams.proximity_bias: decoder_self_attention_bias += common_attention.attention_bias_proximal( decode_length) def symbols_to_logits_fn(ids, i, cache): """Go from ids to logits for next symbol.""" ids = ids[:, -1:] targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3) targets = preprocess_targets(targets, i) bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1] with tf.variable_scope("body"): body_outputs = self.decode( targets, cache.get("encoder_output"), cache.get("encoder_decoder_attention_bias"), bias, hparams, cache, nonpadding=features_to_nonpadding(features, "targets")) with tf.variable_scope(target_modality.name): logits = target_modality.top(body_outputs, None) ret = tf.squeeze(logits, axis=[1, 2, 3]) return ret, cache ret = fast_decode( encoder_output=encoder_output, encoder_decoder_attention_bias=encoder_decoder_attention_bias, symbols_to_logits_fn=symbols_to_logits_fn, hparams=hparams, decode_length=decode_length, vocab_size=target_modality.top_dimensionality, beam_size=beam_size, top_beams=top_beams, alpha=alpha, batch_size=batch_size, force_decode_length=self._decode_hparams.force_decode_length) return ret
def _fast_decode( self, features, decode_length, beam_size=1, top_beams=1, alpha=1.0, preprocess_targets_method=None, ): if self._num_datashards != 1: raise NotImplementedError( 'Fast decoding only supports a single shard.') dp = self._data_parallelism hparams = self._hparams target_modality = self._problem_hparams.modality['targets'] target_vocab_size = self._problem_hparams.vocab_size['targets'] if target_vocab_size is not None and hasattr(hparams, 'vocab_divisor'): target_vocab_size += (-target_vocab_size) % hparams.vocab_divisor target_tag_modality = self._problem_hparams.modality[ 'targets_error_tag'] target_tag_vocab_size = self._problem_hparams.vocab_size[ 'targets_error_tag'] if target_tag_vocab_size is not None and hasattr( hparams, 'vocab_divisor'): target_tag_vocab_size += ( -target_tag_vocab_size) % hparams.vocab_divisor if 'targets_segmentation' in features: raise NotImplementedError( 'Decoding not supported on packed datasets ' ' If you want to decode from a dataset, use the non-packed version' ' of the dataset when decoding.') if self.has_input: inputs_shape = common_layers.shape_list(features['inputs']) if (target_modality == modalities.ModalityType.CLASS_LABEL or self._problem_hparams.get('regression_targets')): decode_length = 1 else: decode_length = inputs_shape[1] + features.get( 'decode_length', decode_length) batch_size = inputs_shape[0] inputs = self._prepare_inputs_for_decode(features) with tf.variable_scope('body'): encoder_output, encoder_decoder_attention_bias = dp( self.encode, inputs, features['target_space_id'], hparams, features=features, ) encoder_output = encoder_output[0] encoder_decoder_attention_bias = encoder_decoder_attention_bias[0] partial_targets = features.get('partial_targets') else: encoder_output = None encoder_decoder_attention_bias = None partial_targets = features.get('inputs') if partial_targets is None: partial_targets = features['targets'] assert partial_targets is not None if partial_targets is not None: partial_targets = common_layers.expand_squeeze_to_nd( partial_targets, 2) partial_targets = tf.to_int64(partial_targets) partial_targets_shape = common_layers.shape_list(partial_targets) partial_targets_length = partial_targets_shape[1] decode_length = partial_targets_length + features.get( 'decode_length', decode_length) batch_size = partial_targets_shape[0] if hparams.pos == 'timing': positional_encoding = common_attention.get_timing_signal_1d( decode_length + 1, hparams.hidden_size) elif hparams.pos == 'timing_from_features': positional_encoding = common_attention.add_timing_signals_from_features( tf.zeros([1, decode_length, hparams.hidden_size]), features, hparams.position_features, ) elif hparams.pos == 'emb': positional_encoding = common_attention.add_positional_embedding( tf.zeros([1, decode_length, hparams.hidden_size]), hparams.max_length, 'body/targets_positional_embedding', None, ) else: positional_encoding = None def preprocess_targets(targets, i): targets = self._shard_features({'targets': targets})['targets'] modality_name = hparams.name.get( 'targets', modalities.get_name(target_modality))(hparams, target_vocab_size) with tf.variable_scope(modality_name + '/targets'): bottom = hparams.bottom.get( 'targets', modalities.get_targets_bottom(target_modality)) targets = dp(bottom, targets, hparams, target_vocab_size)[0] targets = common_layers.flatten4d3d(targets) if not self.get_decode_start_id(): targets = tf.cond( tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets, ) if positional_encoding is not None: targets += 
positional_encoding[:, i:i + 1] return targets def preprocess_targets_tag_method(targets, i): targets = self._shard_features({'targets_error_tag': targets})['targets_error_tag'] modality_name = hparams.name.get( 'targets_error_tag', modalities.get_name(target_tag_modality))( hparams, target_tag_vocab_size) with tf.variable_scope(modality_name + '/targets_error_tag'): bottom = hparams.bottom.get( 'targets_error_tag', modalities.get_targets_bottom(target_tag_modality), ) targets = dp(bottom, targets, hparams, target_tag_vocab_size)[0] targets = common_layers.flatten4d3d(targets) if not self.get_decode_start_id(): targets = tf.cond( tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets, ) if positional_encoding is not None: targets += positional_encoding[:, i:i + 1] return targets decoder_self_attention_bias = common_attention.attention_bias_lower_triangle( decode_length) if hparams.proximity_bias: decoder_self_attention_bias += common_attention.attention_bias_proximal( decode_length) att_cache = {'attention_history': {}} num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers if encoder_output is not None: att_batch_size, enc_seq_length = common_layers.shape_list( encoder_output)[0:2] for layer in range(num_layers): att_cache['attention_history']['layer_%d' % layer] = tf.zeros( [att_batch_size, hparams.num_heads, 0, enc_seq_length]) def update_decoder_attention_history(cache): for k in [ x for x in self.attention_weights if 'decoder' in x and 'self' not in x and 'logits' not in x ]: idx = k.find('layer_') if idx < 0: continue # Get layer number from the string name. layer_nbr = k[idx + 6:] idx = 0 while (idx + 1 < len(layer_nbr) and layer_nbr[:idx + 1].isdigit()): idx += 1 layer_nbr = 'layer_%d' % int(layer_nbr[:idx]) if layer_nbr in cache['attention_history']: cache['attention_history'][layer_nbr] = tf.concat( [ cache['attention_history'][layer_nbr], self.attention_weights[k], ], axis=2, ) if not preprocess_targets_method: preprocess_targets_method = preprocess_targets def symbols_to_logits_fn(ids, ids_tag, i, cache): """Go from ids to logits for next symbol.""" ids = ids[:, -1:] targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3) targets = preprocess_targets_method(targets, i) ids_tag = ids_tag[:, -1:] targets_tag = tf.expand_dims(tf.expand_dims(ids_tag, axis=2), axis=3) targets_tag = preprocess_targets_tag_method(targets_tag, i) bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1] with tf.variable_scope('body'): with tf.variable_scope('edit_ops_layer'): with tf.variable_scope('ffn'): x = targets preproc = lambda z: common_layers.layer_preprocess( z, hparams, layer_collection=None) layer_inputs = [ tf.concat(preproc(x), axis=0), tf.concat(preproc(targets_tag), axis=0), ] y = transformer_layers.transformer_ffn_layer( tf.concat(layer_inputs, axis=2), hparams, conv_padding='LEFT', nonpadding_mask=features_to_nonpadding( features, 'targets'), losses=None, cache=cache, decode_loop_step=None, layer_collection=None, ) targets = common_layers.layer_postprocess( x, y, hparams) if hparams.middle_prediction: num_decoder_layers = (hparams.num_decoder_layers or hparams.num_hidden_layers) hparams.num_decoder_layers = int( num_decoder_layers / hparams.middle_prediction_layer_factor) body_outputs = dp( self.decode, targets, cache.get('encoder_output'), cache.get('encoder_decoder_attention_bias'), bias, hparams, cache, nonpadding=features_to_nonpadding(features, 'targets'), )[0] body_outputs, logits_tag = dp( self._prediction_cascade_predict, hparams, 
features_to_nonpadding(features, 'targets'), cache.get('encoder_decoder_attention_bias'), cache.get('encoder_output'), body_outputs, ) logits_tag = logits_tag[0]['targets_error_tag'] if hparams.middle_prediction: with tf.variable_scope('after_prediction'): body_outputs = dp( self.decode, targets + body_outputs[0], cache.get('encoder_output'), cache.get('encoder_decoder_attention_bias'), bias, hparams, cache, nonpadding=features_to_nonpadding( features, 'targets'), ) update_decoder_attention_history(cache) modality_name = hparams.name.get( 'targets', modalities.get_name(target_modality))(hparams, target_vocab_size) with tf.variable_scope('targets/' + modality_name): top = hparams.top.get('targets', modalities.get_top(target_modality)) logits = dp(top, body_outputs, None, hparams, target_vocab_size)[0] ret = tf.squeeze(logits, axis=[1, 2]) if partial_targets is not None: vocab_size = tf.shape(ret)[1] def forced_logits(): return tf.one_hot( tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0, -1e9, ) ret = tf.cond( tf.less(i, partial_targets_length), forced_logits, lambda: ret, ) logits_tag = tf.squeeze(logits_tag, axis=[1]) return ret, logits_tag, cache sos_id = self.get_decode_start_id() or 0 eos_id = self.get_decode_end_id() or beam_search.EOS_ID temperature = features.get('sampling_temp', getattr(hparams, 'sampling_temp', 0.0)) top_k = features.get('sampling_keep_top_k', getattr(hparams, 'sampling_keep_top_k', -1)) ret = fast_decode( encoder_output=encoder_output, encoder_decoder_attention_bias=encoder_decoder_attention_bias, symbols_to_logits_fn=symbols_to_logits_fn, hparams=hparams, decode_length=decode_length, vocab_size=target_vocab_size, init_cache_fn=_init_transformer_cache, beam_size=beam_size, top_beams=top_beams, alpha=alpha, batch_size=batch_size, force_decode_length=self._decode_hparams.force_decode_length, sos_id=sos_id, eos_id=eos_id, sampling_temperature=temperature, top_k=top_k, cache=att_cache, ) if partial_targets is not None: if beam_size <= 1 or top_beams <= 1: ret['outputs'] = ret['outputs'][:, partial_targets_length:] else: ret['outputs'] = ret['outputs'][:, :, partial_targets_length:] return ret
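# --- Illustrative sketch (added; not part of the original model code) ---
# The sampling_temperature and sampling_keep_top_k values threaded into
# fast_decode above control how the next id is drawn from the logits. A hedged
# NumPy sketch of that rule as I understand it: temperature 0 degenerates to
# argmax, and top_k = -1 keeps the whole vocabulary.
import numpy as np

def sample_next_id(logits, temperature=0.0, top_k=-1, rng=np.random):
  # logits: [vocab_size]
  if top_k > 0:
    cutoff = np.sort(logits)[-top_k]          # k-th largest logit
    logits = np.where(logits < cutoff, -1e9, logits)
  if temperature == 0.0:
    return int(np.argmax(logits))
  scaled = logits / temperature
  probs = np.exp(scaled - np.max(scaled))
  probs /= probs.sum()
  return int(rng.choice(len(logits), p=probs))

logits = np.array([2.0, 0.5, 1.0, -1.0], dtype=np.float32)
print(sample_next_id(logits))                              # 0 (argmax)
print(sample_next_id(logits, temperature=0.8, top_k=2))    # 0 or 2 only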
def _fast_decode(self,
                 features,
                 decode_length,
                 beam_size=1,
                 top_beams=1,
                 alpha=1.0):
  """Fast decoding.

  Implements both greedy and beam search decoding; uses beam search iff
  beam_size > 1, otherwise beam search related arguments are ignored.

  Args:
    features: a map of string to model features.
    decode_length: an integer. How many additional timesteps to decode.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. The larger the alpha, the
      stronger the preference for longer translations.

  Returns:
    decoded_ids: an integer `Tensor`. Top samples from the beam search.
    scores: decoding log probs from the beam search, None if using greedy
      decoding.

  Raises:
    NotImplementedError: If there are multiple data shards.
  """
  if self._num_datashards != 1:
    raise NotImplementedError("Fast decoding only supports a single shard.")
  dp = self._data_parallelism
  hparams = self._hparams

  inputs = features["inputs"]
  batch_size = common_layers.shape_list(inputs)[0]
  target_modality = self._problem_hparams.target_modality
  if target_modality.is_class_modality:
    decode_length = 1
  else:
    decode_length = common_layers.shape_list(inputs)[1] + decode_length

  # TODO(llion): Clean up this reshaping logic.
  inputs = tf.expand_dims(inputs, axis=1)
  if len(inputs.shape) < 5:
    inputs = tf.expand_dims(inputs, axis=4)
  s = common_layers.shape_list(inputs)
  inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
  # _shard_features called to ensure that the variable names match.
  inputs = self._shard_features({"inputs": inputs})["inputs"]
  input_modality = self._problem_hparams.input_modality["inputs"]
  with tf.variable_scope(input_modality.name):
    inputs = input_modality.bottom_sharded(inputs, dp)
  with tf.variable_scope("body"):
    encoder_output, encoder_decoder_attention_bias = dp(
        self.encode,
        inputs,
        features["target_space_id"],
        hparams,
        features=features)
  encoder_output = encoder_output[0]
  encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]

  if hparams.pos == "timing":
    timing_signal = common_attention.get_timing_signal_1d(
        decode_length + 1, hparams.hidden_size)

  def preprocess_targets(targets, i):
    """Performs preprocessing steps on the targets to prepare for the decoder.

    This includes:
      - Embedding the ids.
      - Flattening to 3D tensor.
      - Optionally adding timing signals.

    Args:
      targets: input ids to the decoder. [batch_size, 1]
      i: scalar, step number of the decoding loop.

    Returns:
      Processed targets [batch_size, 1, hidden_dim]
    """
    # _shard_features called to ensure that the variable names match.
    targets = self._shard_features({"targets": targets})["targets"]
    with tf.variable_scope(target_modality.name):
      targets = target_modality.targets_bottom_sharded(targets, dp)[0]
    targets = common_layers.flatten4d3d(targets)

    # TODO(llion): Explain! Is this even needed?
    targets = tf.cond(
        tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

    if hparams.pos == "timing":
      targets += timing_signal[:, i:i + 1]
    return targets

  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(decode_length))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        decode_length)

  def symbols_to_logits_fn(ids, i, cache):
    """Go from ids to logits for next symbol."""
    ids = ids[:, -1:]
    targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
    targets = preprocess_targets(targets, i)

    bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

    with tf.variable_scope("body"):
      body_outputs = dp(
          self.decode,
          targets,
          cache["encoder_output"],
          cache["encoder_decoder_attention_bias"],
          bias,
          hparams,
          cache,
          nonpadding=transformer._features_to_nonpadding(features, "targets"))

    with tf.variable_scope(target_modality.name):
      logits = target_modality.top_sharded(body_outputs, None, dp)[0]

    return tf.squeeze(logits, axis=[1, 2, 3]), cache

  key_channels = hparams.attention_key_channels or hparams.hidden_size
  value_channels = hparams.attention_value_channels or hparams.hidden_size
  num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

  # Per-layer key/value cache; empty along the time axis and grown by one
  # position per decoding step.
  cache = {
      "layer_%d" % layer: {
          "k": tf.zeros([batch_size, 0, key_channels]),
          "v": tf.zeros([batch_size, 0, value_channels]),
      } for layer in range(num_layers)
  }

  # Set 2nd dim to None since it's not invariant in the tf.while_loop.
  # Note: Tensor.set_shape() does not work here since it merges shape info.
  # TODO(llion): Find a more robust solution.
  # pylint: disable=protected-access
  if not context.in_eager_mode():
    for layer in cache:
      cache[layer]["k"]._shape = tf.TensorShape([None, None, key_channels])
      cache[layer]["v"]._shape = tf.TensorShape([None, None, value_channels])
  # pylint: enable=protected-access
  cache["encoder_output"] = encoder_output
  cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

  if beam_size > 1:  # Beam search.
    target_modality = (
        self._hparams.problems[self._problem_idx].target_modality)
    vocab_size = target_modality.top_dimensionality
    initial_ids = tf.zeros([batch_size], dtype=tf.int32)
    decoded_ids, scores = beam_search.beam_search(
        symbols_to_logits_fn,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        alpha,
        states=cache,
        stop_early=(top_beams == 1))
    decoded_ids = decoded_ids[:, :, 1:]

    # Do roulette wheel selection or inverse roulette wheel selection over
    # the returned beams.
    if self._hparams.roulette == "Normal" or self._hparams.roulette == "Inverse":
      if self._hparams.roulette == "Normal":
        probabilities = tf.pow(tf.constant(2.0), scores)
        start = 0
      else:
        probabilities = tf.subtract(
            tf.constant(1.0), tf.pow(tf.constant(2.0), scores))
        start = beam_size - self._hparams.roulette_beam_size

      summ = tf.reduce_sum(probabilities)
      ex_probs = tf.divide(probabilities, summ)
      # ex_probs = tf.nn.softmax(probabilities)

      # Sample a number between 0 and 1.
      wheel = tf.random_uniform([1])
      upper_bound = tf.constant(0.0)

      # Change this as well if using inverse.
      for i in range(start, self._hparams.roulette_beam_size):
        upper_bound = tf.add(ex_probs[:, i], upper_bound)
        truth_value = tf.squeeze(
            tf.logical_and(wheel >= upper_bound - ex_probs[:, i],
                           wheel <= upper_bound))
        decoded_ids, scores, i = tf.cond(
            truth_value,
            lambda: (decoded_ids[:, i, :], scores[:, i], beam_size),
            lambda: (decoded_ids, scores, i))

  else:  # Greedy decoding.

    def inner_loop(i, next_id, decoded_ids, cache):
      logits, cache = symbols_to_logits_fn(next_id, i, cache)
      temperature = (0.0 if hparams.sampling_method == "argmax" else
                     hparams.sampling_temp)
      next_id = tf.expand_dims(
          common_layers.sample_with_temperature(logits, temperature), axis=1)
      decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
      return i + 1, next_id, decoded_ids, cache

    decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
    scores = None
    next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
    _, _, decoded_ids, _ = tf.while_loop(
        # TODO(llion): Early stopping.
        lambda i, *_: tf.less(i, decode_length),
        inner_loop,
        [tf.constant(0), next_id, decoded_ids, cache],
        shape_invariants=[
            tf.TensorShape([]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            nest.map_structure(lambda t: tf.TensorShape(t.shape), cache),
        ])

  return decoded_ids, scores
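# --- Illustrative sketch (not part of the tensor2tensor API) -----------------
# The greedy branch above depends on an incremental key/value cache: every
# decoding step feeds only the newest position through the decoder and appends
# that position's keys and values to `cache`, so earlier positions are never
# recomputed. The toy function below shows the same caching pattern in plain
# NumPy with a made-up single-head attention; `embeddings`, `w_k`, `w_v` and
# `w_out` are hypothetical stand-ins, not tensor2tensor symbols.
import numpy as np


def _toy_cached_greedy_decode(embeddings, w_k, w_v, w_out, steps):
  """Greedy decode that caches per-position keys/values (illustration only).

  Args:
    embeddings: [vocab, hidden] array used as the input embedding table.
    w_k, w_v: [hidden, hidden] projections producing keys and values.
    w_out: [hidden, vocab] output projection producing logits.
    steps: number of positions to decode.

  Returns:
    A list of `steps` greedily chosen ids.
  """
  hidden = embeddings.shape[1]
  cache = {"k": np.zeros((0, hidden)), "v": np.zeros((0, hidden))}
  next_id, decoded = 0, []
  for _ in range(steps):
    x = embeddings[next_id]                        # embed only the newest id
    cache["k"] = np.vstack([cache["k"], x @ w_k])  # append this step's key
    cache["v"] = np.vstack([cache["v"], x @ w_v])  # ... and its value
    weights = np.exp(cache["k"] @ x)               # attend over cached keys
    weights /= weights.sum()
    context_vec = weights @ cache["v"]
    logits = context_vec @ w_out
    next_id = int(np.argmax(logits))               # greedy pick, like argmax above
    decoded.append(next_id)
  return decoded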
def _fast_decode_tpu(self, features, decode_length, beam_size=1):
  """Fast decoding.

  Implements only greedy decoding on TPU.

  Args:
    features: A map of string to model features.
    decode_length: An integer, how many additional timesteps to decode.
    beam_size: An integer, number of beams.

  Returns:
    A dict of decoding results {
        "outputs": integer `Tensor` of decoded ids of shape
            [batch_size, <= decode_length]
        "scores": decoding log probs from the beam search,
            None if using greedy decoding (beam_size=1)
    }.

  Raises:
    NotImplementedError: If there are multiple data shards or beam_size > 1.
  """
  if self._num_datashards != 1:
    raise NotImplementedError("Fast decoding only supports a single shard.")
  if "targets_segmentation" in features:
    raise NotImplementedError(
        "Decoding not supported on packed datasets."
        " If you want to decode from a dataset, use the non-packed version"
        " of the dataset when decoding.")
  dp = self._data_parallelism
  hparams = self._hparams
  target_modality = self._problem_hparams.target_modality

  if self.has_input:
    inputs = features["inputs"]
    if target_modality.is_class_modality:
      decode_length = 1
    else:
      decode_length = (
          common_layers.shape_list(inputs)[1] + features.get(
              "decode_length", decode_length))

    # TODO(llion): Clean up this reshaping logic.
    inputs = tf.expand_dims(inputs, axis=1)
    if len(inputs.shape) < 5:
      inputs = tf.expand_dims(inputs, axis=4)
    s = common_layers.shape_list(inputs)
    batch_size = s[0]
    inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
    # _shard_features called to ensure that the variable names match.
    inputs = self._shard_features({"inputs": inputs})["inputs"]
    input_modality = self._problem_hparams.input_modality["inputs"]
    with tf.variable_scope(input_modality.name):
      inputs = input_modality.bottom_sharded(inputs, dp)
    with tf.variable_scope("body"):
      encoder_output, encoder_decoder_attention_bias = dp(
          self.encode,
          inputs,
          features["target_space_id"],
          hparams,
          features=features)
    encoder_output = encoder_output[0]
    encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
    partial_targets = None
  else:
    # The problem has no inputs.
    encoder_output = None
    encoder_decoder_attention_bias = None

    # Prepare partial targets.
    # In either features["inputs"] or features["targets"].
    # We force the outputs to begin with these sequences.
    partial_targets = features.get("inputs")
    if partial_targets is None:
      partial_targets = features["targets"]
    assert partial_targets is not None
    partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
    partial_targets = tf.to_int64(partial_targets)
    partial_targets_shape = common_layers.shape_list(partial_targets)
    partial_targets_length = partial_targets_shape[1]
    decode_length = (
        partial_targets_length + features.get("decode_length", decode_length))
    batch_size = partial_targets_shape[0]

  if hparams.pos == "timing":
    positional_encoding = common_attention.get_timing_signal_1d(
        decode_length + 1, hparams.hidden_size)
  elif hparams.pos == "emb":
    positional_encoding = common_attention.add_positional_embedding(
        tf.zeros([1, decode_length + 1, hparams.hidden_size]),
        hparams.max_length, "body/targets_positional_embedding", None)
  else:
    positional_encoding = None

  def preprocess_targets(targets, i):
    """Performs preprocessing steps on the targets to prepare for the decoder.

    This includes:
      - Embedding the ids.
      - Flattening to 3D tensor.
      - Optionally adding timing signals.

    Args:
      targets: A tensor, input ids to the decoder. [batch_size, 1].
      i: An integer, step number of the decoding loop.

    Returns:
      A tensor, processed targets [batch_size, 1, hidden_dim].
    """
    # _shard_features called to ensure that the variable names match.
    targets = self._shard_features({"targets": targets})["targets"]
    with tf.variable_scope(target_modality.name):
      targets = target_modality.targets_bottom_sharded(targets, dp)[0]
    targets = common_layers.flatten4d3d(targets)

    # TODO(llion): Explain! Is this even needed?
    targets = tf.cond(
        tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

    if positional_encoding is not None:
      positional_encoding_shape = positional_encoding.shape.as_list()
      targets += tf.slice(
          positional_encoding, [0, i, 0],
          [positional_encoding_shape[0], 1, positional_encoding_shape[2]])
    return targets

  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(decode_length))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        decode_length)

  def symbols_to_logits_tpu_fn(ids, i, cache):
    """Go from ids to logits for next symbol on TPU.

    Args:
      ids: A tensor, symbol IDs.
      i: An integer, step number of the decoding loop. Only used for
          inference on TPU.
      cache: A dict, containing tensors which are the results of previous
          attentions, used for fast decoding.

    Returns:
      ret: A tensor, computed logits.
      cache: A dict, containing tensors which are the results of previous
          attentions, used for fast decoding.
    """
    ids = ids[:, -1:]
    targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
    targets = preprocess_targets(targets, i)

    bias_shape = decoder_self_attention_bias.shape.as_list()
    bias = tf.slice(decoder_self_attention_bias, [0, 0, i, 0],
                    [bias_shape[0], bias_shape[1], 1, bias_shape[3]])

    with tf.variable_scope("body"):
      body_outputs = dp(
          self.decode,
          targets,
          cache.get("encoder_output"),
          cache.get("encoder_decoder_attention_bias"),
          bias,
          hparams,
          cache,
          i,
          nonpadding=features_to_nonpadding(features, "targets"))

    with tf.variable_scope(target_modality.name):
      logits = target_modality.top_sharded(body_outputs, None, dp)[0]

    ret = tf.squeeze(logits, axis=[1, 2, 3])
    if partial_targets is not None:
      # If the position is within the given partial targets, we alter the
      # logits to always return those values.
      # A faster approach would be to process the partial targets in one
      # iteration in order to fill the corresponding parts of the cache.
      # This would require broader changes, though.
      vocab_size = tf.shape(ret)[1]

      def forced_logits():
        return tf.one_hot(
            tf.tile(
                tf.slice(partial_targets, [0, i],
                         [partial_targets.shape.as_list()[0], 1]),
                [beam_size]), vocab_size, 0.0, -1e9)

      ret = tf.cond(
          tf.less(i, partial_targets_length), forced_logits, lambda: ret)
    return ret, cache

  ret = fast_decode_tpu(
      encoder_output=encoder_output,
      encoder_decoder_attention_bias=encoder_decoder_attention_bias,
      symbols_to_logits_fn=symbols_to_logits_tpu_fn,
      hparams=hparams,
      decode_length=decode_length,
      beam_size=beam_size,
      batch_size=batch_size,
      force_decode_length=self._decode_hparams.force_decode_length)
  if partial_targets is not None:
    ret["outputs"] = ret["outputs"][:, partial_targets_length:]
  return ret
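# --- Illustrative sketch (not part of the tensor2tensor API) -----------------
# `symbols_to_logits_tpu_fn` forces the decoder to emit a known prefix by
# replacing the model's logits with a one-hot-style row: 0.0 at the forced id
# and a large negative value everywhere else, so argmax (and any reasonable
# sampling) can only pick the forced symbol. The helper below reproduces that
# trick in plain Python for a single example; it is an illustration only and
# is not called by the model above.
def _toy_force_prefix_logits(model_logits, forced_ids, step):
  """Return logits that force `forced_ids[step]` while inside the prefix.

  Args:
    model_logits: list of floats, the model's logits over the vocabulary.
    forced_ids: list of ids the output must begin with.
    step: current decoding position.

  Returns:
    Logits of the same length: unchanged past the prefix, masked inside it.
  """
  if step >= len(forced_ids):
    return model_logits                  # past the prefix: free decoding
  forced = [-1e9] * len(model_logits)    # mask out every other symbol
  forced[forced_ids[step]] = 0.0         # keep only the forced id viable
  return forced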