def train(
    model_config,
    input_embeddings,
    source_len,
    output_vocab_size,
    output_vocab_embeddings_table,
    target_decode_steps,
    mode,
    input_copy_mask=None,
):
  """Constructs encoder and decoder transformations for training and eval.

  In the shapes described below, B is batch size, L is sequence length,
  D is the dimensionality of the model embeddings, and T is the output
  vocab size.

  Args:
    model_config: ModelConfig proto.
    input_embeddings: Tensor of shape (B, L, D) representing inputs.
    source_len: Tensor of shape (B) containing the length of each input
      sequence.
    output_vocab_size: Size of output vocabulary.
    output_vocab_embeddings_table: Tensor of shape (T, D) representing the
      table of embeddings for output symbols.
    target_decode_steps: DecodeSteps representing target outputs. Each tensor
      has shape (B, L).
    mode: Enum indicating model mode, TRAIN or EVAL.
    input_copy_mask: Mask for copying actions.

  Returns:
    Tuple of (logits, predictions), where logits is a tensor of shape
    (B, L, T) representing model output logits, and predictions is a dict of
    output Tensors derived from the one-best output symbol at each step.
  """
  logits = _transformer_body(
      input_embeddings,
      source_len,
      target_decode_steps,
      mode,
      model_config,
      output_vocab_size,
      output_vocab_embeddings_table,
      input_copy_mask=input_copy_mask,
  )
  predicted_ids = tf.to_int32(tf.argmax(logits, axis=-1))
  output_decode_steps = decode_utils.get_decode_steps(predicted_ids,
                                                      output_vocab_size,
                                                      model_config)
  predictions = decode_utils.get_predictions(output_decode_steps)
  return logits, predictions
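

# A minimal sketch (not part of the model code above) of how the returned
# logits might feed a masked cross-entropy loss at training time. The
# `target_ids` and `target_len` tensors are hypothetical stand-ins for
# whatever the surrounding training loop provides.
def _example_sequence_loss(logits, target_ids, target_len):
  """Illustrative masked cross-entropy over decode steps."""
  # Per-position cross-entropy between logits (B, L, T) and targets (B, L).
  xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=target_ids, logits=logits)
  # Zero out positions beyond each target sequence's length.
  mask = tf.sequence_mask(
      target_len, maxlen=tf.shape(target_ids)[1], dtype=tf.float32)
  # Average the loss over the unmasked positions only.
  return tf.reduce_sum(xent * mask) / tf.reduce_sum(mask)
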
def _greedy_decode(input_embeddings,
                   output_vocab_size,
                   target_end_id,
                   target_start_id,
                   output_vocab_embeddings_table,
                   source_len,
                   model_config,
                   mode,
                   input_copy_mask=None,
                   clean_output_mask=None):
  """Fast greedy decoding."""
  encoder_output = common_layers.linear_transform(
      input_embeddings,
      output_size=model_config.model_parameters.encoder_dims,
      scope="bert_to_transformer")

  decode_length = model_config.data_options.max_decode_length

  def symbols_to_logits_fn(logit_indices, current_index):
    """Go from targets to logits."""
    logit_indices = tf.expand_dims(logit_indices, 0)
    decode_steps = decode_utils.get_decode_steps(logit_indices,
                                                 output_vocab_size,
                                                 model_config)
    target_embeddings = _get_target_embeddings(
        input_embeddings, output_vocab_embeddings_table, decode_steps,
        model_config)
    decoder_output = _build_transformer_decoder(
        encoder_output,
        source_len,
        target_embeddings,
        mode,
        model_config,
        single_step_index=current_index)
    logits = _get_action_logits(
        encoder_output,
        decoder_output,
        output_vocab_embeddings_table,
        output_vocab_size,
        model_config,
        input_copy_mask=input_copy_mask,
        clean_output_mask=clean_output_mask)
    # Squeeze the batch and length dimensions, as both should be 1.
    logits = tf.squeeze(logits, axis=[0, 1])
    # Shape of logits should now be (output_vocab_size).
    return logits

  def loop_cond(i, decoded_ids, unused_logprobs):
    """Loop condition that returns False to stop the loop."""
    return tf.logical_and(
        tf.reduce_all(tf.not_equal(decoded_ids, target_end_id)),
        tf.less(i, decode_length))

  def inner_loop(i, decoded_ids, logprobs):
    """Decoder function invoked on each while loop iteration."""
    logits = symbols_to_logits_fn(decoded_ids, i)
    next_id = tf.argmax(logits, axis=0)
    softmax = tf.nn.softmax(logits)
    extended_vocab_size = tf.shape(softmax)[-1]
    mask = tf.one_hot(next_id, extended_vocab_size)
    prob = tf.reduce_sum(softmax * mask)
    logprob = tf.log(prob)
    # Add one-hot values to the output Tensors, since values at index > i + 1
    # should still be zero.
    logprobs += tf.one_hot(
        i + 1, decode_length + 1, on_value=logprob, dtype=tf.float32)
    decoded_ids += tf.one_hot(
        i + 1, decode_length + 1, on_value=next_id, dtype=tf.int64)
    return i + 1, decoded_ids, logprobs

  initial_ids = tf.zeros(dtype=tf.int64, shape=[decode_length + 1])
  initial_ids += tf.one_hot(
      0, decode_length + 1, on_value=tf.cast(target_start_id, tf.int64))
  initial_logprob = tf.zeros(dtype=tf.float32, shape=[decode_length + 1])
  initial_i = tf.constant(0)

  initial_values = [initial_i, initial_ids, initial_logprob]

  _, decoded_ids, logprobs = tf.while_loop(loop_cond, inner_loop,
                                           initial_values)

  # Remove <START> symbol.
  decoded_ids = decoded_ids[1:]
  logprobs = logprobs[1:]

  # Sum logprobs to get the score for the overall sequence.
  logprobs = tf.reduce_sum(logprobs, axis=0)

  # Expand decoded_ids and logprobs to reflect a beam width dimension of 1.
  decoded_ids = tf.expand_dims(decoded_ids, 0)
  logprobs = tf.expand_dims(logprobs, 0)

  # This is the output dict that the function returns.
  output_decode_steps = decode_utils.get_decode_steps(decoded_ids,
                                                      output_vocab_size,
                                                      model_config)
  predictions = decode_utils.get_predictions(output_decode_steps)
  predictions[constants.SCORES_KEY] = logprobs

  return predictions
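

# The inner loop above appends to fixed-shape buffers by adding one-hot
# vectors instead of concatenating, which keeps tensor shapes static inside
# tf.while_loop. A minimal standalone sketch of that idiom (illustrative
# only; the written values are hypothetical):
def _example_one_hot_buffer_write(length=5):
  """Fills a fixed-length buffer one position per loop iteration."""
  buf = tf.zeros([length + 1], dtype=tf.int64)

  def cond(i, unused_buf):
    return tf.less(i, length)

  def body(i, buf):
    next_id = tf.cast(i * 10, tf.int64)  # Stand-in for an argmax prediction.
    # Position i + 1 is still zero, so this addition acts as a write.
    buf += tf.one_hot(i + 1, length + 1, on_value=next_id, dtype=tf.int64)
    return i + 1, buf

  _, filled = tf.while_loop(cond, body, [tf.constant(0), buf])
  return filled  # Evaluates to [0, 0, 10, 20, 30, 40] for length=5.
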
def _beam_decode(input_embeddings,
                 alpha,
                 output_vocab_size,
                 target_end_id,
                 target_start_id,
                 output_vocab_embeddings_table,
                 source_len,
                 model_config,
                 mode,
                 beam_size,
                 input_copy_mask=None,
                 clean_output_mask=None):
  """Beam search decoding."""
  # Assume batch size is 1.
  batch_size = 1
  encoder_output = common_layers.linear_transform(
      input_embeddings,
      output_size=model_config.model_parameters.encoder_dims,
      scope="bert_to_transformer")
  decode_length = model_config.data_options.max_decode_length

  # Expand the decoder inputs to the beam width.
  input_embeddings = tf.tile(input_embeddings, [beam_size, 1, 1])
  encoder_output = tf.tile(encoder_output, [beam_size, 1, 1])

  def symbols_to_logits_fn(current_index, logit_indices):
    """Go from targets to logits.

    Args:
      current_index: Integer corresponding to 0-indexed decoder step.
      logit_indices: Tensor of shape [batch_size * beam_width,
        decode_length + 1] to input to the decoder.

    Returns:
      Tensor of shape [batch_size * beam_width, output_vocab_size]
      representing logits for the current decoder step.
    """
    decode_steps = decode_utils.get_decode_steps(logit_indices,
                                                 output_vocab_size,
                                                 model_config)
    target_embeddings = _get_target_embeddings(
        input_embeddings, output_vocab_embeddings_table, decode_steps,
        model_config)
    decoder_output = _build_transformer_decoder(
        encoder_output,
        source_len,
        target_embeddings,
        mode,
        model_config,
        single_step_index=current_index)
    logits = _get_action_logits(
        encoder_output,
        decoder_output,
        output_vocab_embeddings_table,
        output_vocab_size,
        model_config,
        input_copy_mask=input_copy_mask,
        clean_output_mask=clean_output_mask)
    # Squeeze the length dimension, as it should be 1.
    logits = tf.squeeze(logits, axis=[1])
    # Shape of logits should now be:
    # [batch_size * beam_width, output_vocab_size].
    return logits

  initial_ids = tf.ones([batch_size], dtype=tf.int32) * target_start_id

  source_length = input_embeddings.get_shape()[1]
  if source_length.value is None:
    # Fall back on using dynamic shape information.
    source_length = tf.shape(input_embeddings)[1]
  extended_vocab_size = output_vocab_size + source_length

  # ids has shape: [batch_size, beam_size, decode_length].
  # scores has shape: [batch_size, beam_size].
  ids, scores = beam_search.beam_search(symbols_to_logits_fn, initial_ids,
                                        beam_size, decode_length,
                                        extended_vocab_size, alpha,
                                        target_end_id, batch_size)

  # Remove the start symbol from the returned predicted IDs.
  predicted_ids = ids[:, :, 1:]

  # Since batch size is expected to be 1, squeeze the batch dimension.
  predicted_ids = tf.squeeze(predicted_ids, axis=[0])
  scores = tf.squeeze(scores, axis=[0])

  # This is the output dict that the function returns.
  output_decode_steps = decode_utils.get_decode_steps(predicted_ids,
                                                      output_vocab_size,
                                                      model_config)
  predictions = decode_utils.get_predictions(output_decode_steps)
  predictions[constants.SCORES_KEY] = scores

  return predictions
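

# The `alpha` argument controls length normalization in the beam search.
# Assuming a tensor2tensor-style implementation of `beam_search.beam_search`
# (the exact formula inside that module may differ), each hypothesis's summed
# log-probability is divided by the GNMT length penalty
# ((5 + length) / 6) ** alpha, so alpha = 0 leaves raw log-probabilities and
# larger alpha favors longer outputs. An illustrative computation:
def _example_length_normalized_score(logprob_sum, length, alpha):
  """Illustrative GNMT-style length-penalized score; assumed convention."""
  penalty = ((5.0 + float(length)) / 6.0)**alpha
  return logprob_sum / penalty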