def _prepare_source():
    """ Pre-processes inputs to the encoder and generates the corresponding attention masks. """
    # Embed
    pre_source_embeddings = self._embed(source_ids)
    with tf.variable_scope(self.name):
        source_embeddings = self.emb_ffn.forward(pre_source_embeddings)
    glove_embeddings = self.embedding_layer.get_glove_embed(source_pids)
    source_embeddings += glove_embeddings
    # Obtain length and depth of the input tensors
    _, time_steps, depth = get_shape_list(source_embeddings)
    # Transform input mask into attention mask
    # Restore source_mask to shape [batch_size, time_steps]
    shape_mask = get_shape_list(source_mask)
    source_mask1 = tf.slice(source_mask, [0, 0, 0], [shape_mask[0], shape_mask[1], 1])
    source_mask2 = tf.reshape(source_mask1, [shape_mask[0], shape_mask[1]])
    inverse_mask = tf.cast(tf.equal(source_mask2, 0.0), dtype=self.float_dtype)
    attn_mask = inverse_mask * -1e9
    # Expansion to shape [batch_size, 1, 1, time_steps] is needed for compatibility with attention logits
    attn_mask = tf.expand_dims(tf.expand_dims(attn_mask, 1), 1)
    # Differentiate between self-attention and cross-attention masks for further, optional modifications
    self_attn_mask = attn_mask
    cross_attn_mask = attn_mask
    # Add positional encodings
    positional_signal = get_positional_signal(time_steps, depth, self.float_dtype)
    source_embeddings += positional_signal
    # Apply dropout
    if self.config.transformer_dropout_embeddings > 0:
        source_embeddings = tf.layers.dropout(source_embeddings,
                                              rate=self.config.transformer_dropout_embeddings,
                                              training=self.training)
    return source_embeddings, self_attn_mask, cross_attn_mask
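The mask construction above is easier to follow outside the graph. Below is a minimal NumPy sketch (with made-up shapes, not part of the codebase) of how a binary padding mask of shape [batch_size, time_steps] becomes the additive bias of shape [batch_size, 1, 1, time_steps] that is later added to the attention logits before the softmax.

# Minimal NumPy sketch of the mask transformation used in _prepare_source (illustrative only).
import numpy as np

source_mask = np.array([[1., 1., 1., 0., 0.],
                        [1., 1., 0., 0., 0.]])          # 1 = real token, 0 = padding
inverse_mask = (source_mask == 0.).astype(np.float32)   # 1 where a position must be ignored
attn_mask = inverse_mask * -1e9                          # large negative bias -> ~0 weight after softmax
attn_mask = attn_mask[:, None, None, :]                  # [batch_size, 1, 1, time_steps]
print(attn_mask.shape)                                   # (2, 1, 1, 5)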
def _additive_attn(self, queries, keys, values, attn_mask):
    """ Uses additive attention to compute contextually enriched source-side representations. """
    # Account for beam-search
    num_beams = get_shape_list(queries)[0] // get_shape_list(keys)[0]
    keys = tf.tile(keys, [num_beams, 1, 1])
    values = tf.tile(values, [num_beams, 1, 1])

    def _logits_fn(query):
        """ Computes time-step-wise attention scores. """
        query = tf.expand_dims(query, 1)
        return tf.reduce_sum(self.attn_weight * tf.nn.tanh(keys + query), axis=-1)

    # Obtain attention scores
    transposed_queries = tf.transpose(queries, [1, 0, 2])  # time-major
    attn_logits = tf.map_fn(_logits_fn, transposed_queries)
    attn_logits = tf.transpose(attn_logits, [1, 0, 2])
    if attn_mask is not None:
        # Transpose and tile the mask
        attn_logits += tf.tile(tf.squeeze(attn_mask, 1),
                               [get_shape_list(queries)[0] // get_shape_list(attn_mask)[0], 1, 1])
    # Compute the attention weights
    attn_weights = tf.nn.softmax(attn_logits, axis=-1, name='attn_weights')
    # Optionally apply dropout
    if self.dropout_attn > 0.0:
        attn_weights = tf.layers.dropout(attn_weights, rate=self.dropout_attn, training=self.training)
    # Obtain context vectors
    weighted_memories = tf.matmul(attn_weights, values)
    return weighted_memories
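For reference, the scoring rule inside _logits_fn is Bahdanau-style additive attention, score(q, k_j) = v^T tanh(k_j + q), where any projection of queries and keys into a shared space is assumed to have happened beforehand. A small NumPy sketch under those assumptions, for a single batch element and a single query:

# NumPy sketch of additive (Bahdanau-style) attention for one query (illustrative only).
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

num_features = 4
keys = np.random.randn(3, num_features)      # [time_steps_k, num_features]
values = np.random.randn(3, num_features)    # [time_steps_k, num_features]
query = np.random.randn(num_features)        # a single query vector
attn_weight = np.random.randn(num_features)  # plays the role of self.attn_weight

logits = np.sum(attn_weight * np.tanh(keys + query), axis=-1)  # [time_steps_k]
weights = softmax(logits)
context = weights @ values                                     # [num_features]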
def _multiplicative_attn(self, queries, keys, values, attn_mask):
    """ Uses multiplicative attention to compute contextually enriched source-side representations. """
    # Account for beam-search
    num_beams = get_shape_list(queries)[0] // get_shape_list(keys)[0]
    keys = tf.tile(keys, [num_beams, 1, 1])
    values = tf.tile(values, [num_beams, 1, 1])

    def _logits_fn(query):
        """ Computes time-step-wise attention scores. """
        query = tf.expand_dims(query, 1)
        return tf.multiply(keys, query)

    # Obtain attention scores
    transposed_queries = tf.transpose(queries, [1, 0, 2])  # time-major
    # attn_logits has shape=[time_steps_q, batch_size, time_steps_k, num_features]
    attn_logits = tf.map_fn(_logits_fn, transposed_queries)
    if attn_mask is not None:
        transposed_mask = \
            tf.transpose(tf.tile(attn_mask,
                                 [get_shape_list(queries)[0] // get_shape_list(attn_mask)[0], 1, 1, 1]),
                         [2, 0, 3, 1])
        attn_logits += transposed_mask
    # Compute the attention weights
    attn_weights = tf.nn.softmax(attn_logits, axis=-2, name='attn_weights')
    # Optionally apply dropout
    if self.dropout_attn > 0.0:
        attn_weights = tf.layers.dropout(attn_weights, rate=self.dropout_attn, training=self.training)
    # Obtain context vectors
    expanded_values = tf.expand_dims(values, axis=1)
    weighted_memories = \
        tf.reduce_sum(tf.multiply(tf.transpose(attn_weights, [1, 0, 2, 3]), expanded_values), axis=2)
    return weighted_memories
def _multiplicative_attn(self, queries, keys, values, attn_mask):
    """ Uses multiplicative attention to compute contextually enriched source-side representations. """
    # Account for beam-search
    num_beams = get_shape_list(queries)[0] // get_shape_list(keys)[0]
    keys = tf.tile(keys, [num_beams, 1, 1])
    values = tf.tile(values, [num_beams, 1, 1])
    # Use multiplicative attention
    transposed_keys = tf.transpose(keys, [0, 2, 1])
    attn_logits = tf.matmul(queries, transposed_keys)
    if attn_mask is not None:
        # Transpose and tile the mask
        attn_logits += tf.tile(tf.squeeze(attn_mask, 1),
                               [get_shape_list(queries)[0] // get_shape_list(attn_mask)[0], 1, 1])
    # Compute the attention weights
    attn_weights = tf.nn.softmax(attn_logits, axis=-1, name='attn_weights')
    # Optionally apply dropout
    if self.dropout_attn > 0.0:
        attn_weights = tf.layers.dropout(attn_weights, rate=self.dropout_attn, training=self.training)
    # Obtain context vectors
    weighted_memories = tf.matmul(attn_weights, values)
    return weighted_memories
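This variant is plain (unscaled) multiplicative attention: the logits are the dot products Q K^T, the softmax runs over the key axis, and the context vectors are the weighted sums of the values. A minimal NumPy sketch of the same computation for one batch element:

# NumPy sketch of multiplicative (Luong-style) attention for one batch element (illustrative only).
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

queries = np.random.randn(2, 4)   # [time_steps_q, num_features]
keys = np.random.randn(3, 4)      # [time_steps_k, num_features]
values = np.random.randn(3, 4)    # [time_steps_k, num_features]

attn_logits = queries @ keys.T            # [time_steps_q, time_steps_k]
attn_weights = softmax(attn_logits, -1)   # each query's weights sum to 1 over the keys
context = attn_weights @ values           # [time_steps_q, num_features]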
def gather_memories(memory_dict, gather_coordinates):
    """ Gathers layer-wise memory tensors corresponding to top sequences from the provided
    memory dictionary during beam search. """
    # Initialize dicts
    gathered_memories = dict()
    # Get coordinate shapes
    coords_dims = get_shape_list(gather_coordinates)
    # Gather
    for layer_key in memory_dict.keys():
        layer_dict = memory_dict[layer_key]
        gathered_memories[layer_key] = dict()
        for attn_key in layer_dict.keys():
            attn_tensor = layer_dict[attn_key]
            attn_dims = get_shape_list(attn_tensor)
            # Not sure if this is faster than the 'memory-less' version
            flat_tensor = \
                tf.transpose(tf.reshape(attn_tensor, [-1, coords_dims[0]] + attn_dims[1:]), [1, 0, 2, 3])
            gathered_values = tf.reshape(
                tf.transpose(tf.gather_nd(flat_tensor, gather_coordinates), [1, 0, 2, 3]),
                [tf.multiply(coords_dims[1], coords_dims[0])] + attn_dims[1:])
            gathered_memories[layer_key][attn_key] = gathered_values
    return gathered_memories
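The reshuffling above is easier to see outside the graph. Assuming the memories are stored beam-major (the leading dimension is beam_size * batch_size) and that gather_coordinates picks, for every sentence, the beams that survived the current step, the gather amounts to the following simplified NumPy sketch (made-up sizes, and a plain per-beam index instead of the (batch, beam) coordinate pairs used by tf.gather_nd):

# NumPy sketch of reordering beam-major memories after a beam-search step (illustrative only).
import numpy as np

beam_size, batch_size, time, depth = 3, 2, 4, 5
memory = np.random.randn(beam_size * batch_size, time, depth)    # beam-major layout

# For each sentence, the indices of the beams that survived this step.
surviving_beams = np.array([[0, 2, 2],
                            [1, 0, 2]])                           # [batch_size, beam_size]

batch_major = memory.reshape(beam_size, batch_size, time, depth).transpose(1, 0, 2, 3)
gathered = np.take_along_axis(batch_major, surviving_beams[:, :, None, None], axis=1)
reordered = gathered.transpose(1, 0, 2, 3).reshape(beam_size * batch_size, time, depth)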
def _embed(self, index_sequence):
    """ Embeds source-side indices to obtain the corresponding dense tensor representations. """
    # Important change:
    # index_sequence: (batch_size, seq_len, u_len)
    u_emb = self.embedding_layer.embed(index_sequence)  # (batch_size, seq_len, u_len, embedding_size)
    shape = get_shape_list(u_emb)
    # Add positional encodings, specifically for md5: [1, u_len, embedding_size]
    if self.config.utf8_type == "md5":
        md5_positional_signal = get_positional_signal(shape[2], shape[3], self.float_dtype)
        u_emb += md5_positional_signal
    # Trim to 2048
    input_size = self.config.pre_source_embedding_size  # defaults to 2048
    cc = input_size - shape[2] * shape[3]
    if self.config.pre_source_embed_cross:
        # Seems to perform worse, and BLEU scores are abnormal at test time
        embsize = tf.to_int32(input_size / shape[2])
        accsize = input_size % shape[2]
        fix_merge_emb = tf.pad(u_emb,
                               [[0, 0], [0, 0], [0, 0], [0, tf.reduce_max([embsize - shape[3], 0])]],
                               constant_values=1.0)
        fix_merge_emb = tf.slice(fix_merge_emb, [0, 0, 0, 0], [-1, -1, -1, embsize])
        fix_merge_emb = tf.reshape(fix_merge_emb, [shape[0], shape[1], shape[2] * embsize])
        fix_merge_emb = tf.pad(fix_merge_emb, [[0, 0], [0, 0], [0, accsize]], constant_values=1.0)
    else:
        merge_emb = tf.reshape(u_emb, [shape[0], shape[1], shape[2] * shape[3]])  # (batch_size, seq_len, u_len*embedding_size)
        fix_merge_emb = tf.pad(merge_emb, [[0, 0], [0, 0], [0, tf.reduce_max([cc, 0])]], constant_values=0)
        fix_merge_emb = tf.slice(fix_merge_emb, [0, 0, 0], [-1, -1, input_size])
    return fix_merge_emb
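The pad-then-slice idiom in the default branch simply flattens the per-unit embeddings and forces them to a fixed width: if u_len * embedding_size is smaller than pre_source_embedding_size the vector is right-padded, otherwise it is truncated. A NumPy sketch of that idea (all sizes are made up):

# NumPy sketch of flattening per-unit embeddings to a fixed pre-source width (illustrative only).
import numpy as np

batch_size, seq_len, u_len, embedding_size = 2, 3, 4, 8
input_size = 48                                           # stands in for pre_source_embedding_size

u_emb = np.random.randn(batch_size, seq_len, u_len, embedding_size)
merged = u_emb.reshape(batch_size, seq_len, u_len * embedding_size)    # 32 features here

pad = max(input_size - merged.shape[-1], 0)
fixed = np.pad(merged, [(0, 0), (0, 0), (0, pad)])[:, :, :input_size]  # pad or truncate to input_size
print(fixed.shape)                                        # (2, 3, 48)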
def _prepare_source():
    """ Pre-processes inputs to the encoder and generates the corresponding attention masks. """
    # Embed
    source_embeddings = self._embed(source_ids)
    # Obtain length and depth of the input tensors
    _, time_steps, depth = get_shape_list(source_embeddings)
    # Transform input mask into attention mask
    inverse_mask = tf.cast(tf.equal(source_mask, 0.0), dtype=FLOAT_DTYPE)
    attn_mask = inverse_mask * -1e9
    # Expansion to shape [batch_size, 1, 1, time_steps] is needed for compatibility with attention logits
    attn_mask = tf.expand_dims(tf.expand_dims(attn_mask, 1), 1)
    # Differentiate between self-attention and cross-attention masks for further, optional modifications
    self_attn_mask = attn_mask
    cross_attn_mask = attn_mask
    # Add positional encodings
    positional_signal = get_positional_signal(time_steps, depth, FLOAT_DTYPE)
    source_embeddings += positional_signal
    # Apply dropout
    if self.config.transformer_dropout_embeddings > 0:
        source_embeddings = tf.layers.dropout(
            source_embeddings,
            rate=self.config.transformer_dropout_embeddings,
            training=self.training)
    return source_embeddings, self_attn_mask, cross_attn_mask
def _dot_product_attn(self, queries, keys, values, attn_mask, scaling_on):
    """ Defines the dot-product attention function; see Vaswani et al. (2017), Eq. (1). """
    # query/ key/ value have shape = [batch_size, time_steps, num_heads, num_features]
    # Tile keys and values tensors to match the number of decoding beams; ignored if already done by fusion module
    num_beams = get_shape_list(queries)[0] // get_shape_list(keys)[0]
    keys = tf.cond(tf.greater(num_beams, 1),
                   lambda: tf.tile(keys, [num_beams, 1, 1, 1]),
                   lambda: keys)
    values = tf.cond(tf.greater(num_beams, 1),
                     lambda: tf.tile(values, [num_beams, 1, 1, 1]),
                     lambda: values)
    # Transpose split inputs
    queries = tf.transpose(queries, [0, 2, 1, 3])
    values = tf.transpose(values, [0, 2, 1, 3])
    attn_logits = tf.matmul(queries, tf.transpose(keys, [0, 2, 3, 1]))
    # Scale attention logits by key dimensions to prevent softmax saturation, if specified
    if scaling_on:
        key_dims = get_shape_list(keys)[-1]
        normalizer = tf.sqrt(tf.cast(key_dims, self.float_dtype))
        attn_logits /= normalizer
    # Optionally mask out positions which should not be attended to;
    # the attention mask should have shape=[batch, num_heads, query_length, key_length]
    # and attn_logits has shape=[batch, num_heads, query_length, key_length]
    if attn_mask is not None:
        attn_mask = tf.cond(tf.greater(num_beams, 1),
                            lambda: tf.tile(attn_mask, [num_beams, 1, 1, 1]),
                            lambda: attn_mask)
        attn_logits += attn_mask
    # Calculate attention weights
    attn_weights = tf.nn.softmax(attn_logits)
    # Optionally apply dropout
    if self.dropout_attn > 0.0:
        attn_weights = tf.layers.dropout(attn_weights, rate=self.dropout_attn, training=self.training)
    # Weigh attention values
    weighted_memories = tf.matmul(attn_weights, values)
    return weighted_memories
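As a reference point, the function above implements Eq. (1) of Vaswani et al. (2017), Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V, applied independently per head. A minimal NumPy sketch for one head of one batch element, including the additive padding mask:

# NumPy sketch of scaled dot-product attention for a single head (illustrative only).
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

d_k = 8
Q = np.random.randn(5, d_k)           # [query_length, d_k]
K = np.random.randn(6, d_k)           # [key_length, d_k]
V = np.random.randn(6, d_k)           # [key_length, d_k]
mask = np.zeros((5, 6))               # additive mask; -1e9 at padded key positions
mask[:, 4:] = -1e9                    # e.g. keys 4 and 5 are padding

logits = Q @ K.T / np.sqrt(d_k) + mask
weights = softmax(logits, -1)         # padded keys receive ~0 weight
context = weights @ V                 # [query_length, d_k]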
def _merge_from_heads(self, split_inputs):
    """ Inverts the _split_among_heads operation. """
    # Transpose split_inputs to perform the merge along the last two dimensions of the split input
    split_inputs = tf.transpose(split_inputs, [0, 2, 1, 3])
    # Retrieve the depth of the tensor to be merged
    split_inputs_dims = get_shape_list(split_inputs)
    split_inputs_depth = split_inputs_dims[-1]
    # Merge the depth and num_heads dimensions of split_inputs
    merged_inputs = tf.reshape(split_inputs, split_inputs_dims[:-2] + [self.num_heads * split_inputs_depth])
    return merged_inputs
def get_memory_invariants(memories):
    """ Calculates the invariant shapes for the model memories (i.e. states of the RNN or
    layer-wise attentions of the transformer). """
    memory_type = type(memories)
    if memory_type == dict:
        memory_invariants = dict()
        for layer_id in memories.keys():
            memory_invariants[layer_id] = {
                key: tf.TensorShape([None] * len(get_shape_list(memories[layer_id][key])))
                for key in memories[layer_id].keys()}
    else:
        raise ValueError('Memory type not supported, must be a dictionary.')
    return memory_invariants
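These invariants exist so that the tf.while_loop used during decoding accepts memories whose time dimension grows step by step: every axis is declared unknown via TensorShape([None] * rank). A small plain-Python sketch of the same idea, using hypothetical memory shapes:

# Plain-Python sketch of deriving 'fully unknown' shape invariants from nested memories (illustrative only).
example_memories = {
    'layer_1': {'keys':   (4, 0, 512),    # [batch*beam, decoded_steps_so_far, depth]
                'values': (4, 0, 512)},
    'layer_2': {'keys':   (4, 0, 512),
                'values': (4, 0, 512)},
}

invariants = {
    layer_id: {key: [None] * len(shape) for key, shape in layer_mems.items()}
    for layer_id, layer_mems in example_memories.items()
}
# invariants['layer_1']['keys'] == [None, None, None], standing in for tf.TensorShape([None, None, None])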
def gather_attn(attn):
    # TODO Specify second and third?
    shapes = {attn: ('batch_size', None, None)}
    tf_utils.assert_shapes(shapes)
    attn_dims = get_shape_list(attn)
    new_shape = [beam_size, batch_size_x] + attn_dims[1:]
    tmp = tf.reshape(attn, new_shape)
    flat_tensor = tf.transpose(tmp, [1, 0, 2, 3])
    tmp = tf.gather_nd(flat_tensor, gather_coordinates)
    tmp = tf.transpose(tmp, [1, 0, 2, 3])
    gathered_values = tf.reshape(tmp, attn_dims)
    return gathered_values
def _pre_embed(self, index_sequence):
    u_emb = self.embedding_layer.embed(index_sequence)  # (batch_size, u_len, embedding_size)
    shape = get_shape_list(u_emb)
    if self.config.utf8_type == "md5":
        md5_positional_signal = get_positional_signal(shape[1], shape[2], self.float_dtype)
        u_emb += md5_positional_signal
    input_size = self.config.pre_source_embedding_size
    cc = input_size - shape[1] * shape[2]
    merge_emb = tf.reshape(u_emb, [shape[0], shape[1] * shape[2]])  # merge_emb: (batch_size, u_len*embedding_size)
    fix_merge_emb = tf.pad(merge_emb, [[0, 0], [0, tf.reduce_max([cc, 0])]], constant_values=1.0)
    fix_merge_emb = tf.slice(fix_merge_emb, [0, 0], [-1, input_size])
    return fix_merge_emb
def _split_among_heads(self, inputs):
    """ Splits the attention inputs among multiple heads. """
    # Retrieve the depth of the input tensor to be split (input is 3d)
    inputs_dims = get_shape_list(inputs)
    inputs_depth = inputs_dims[-1]
    # Assert the depth is compatible with the specified number of attention heads
    if isinstance(inputs_depth, int) and isinstance(self.num_heads, int):
        assert inputs_depth % self.num_heads == 0, \
            ('Attention inputs depth {:d} is not evenly divisible by the specified number of attention heads {:d}'
             .format(inputs_depth, self.num_heads))
    split_inputs = tf.reshape(inputs, inputs_dims[:-1] + [self.num_heads, inputs_depth // self.num_heads])
    return split_inputs
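_split_among_heads and _merge_from_heads are exact inverses: a reshape into [batch, time, num_heads, depth/num_heads] (plus the head-major transpose performed inside the attention function) and the corresponding merge back to [batch, time, depth]. A NumPy round-trip sketch with made-up sizes:

# NumPy sketch of the split/merge round-trip across attention heads (illustrative only).
import numpy as np

batch, time, depth, num_heads = 2, 5, 16, 4
x = np.random.randn(batch, time, depth)

split = x.reshape(batch, time, num_heads, depth // num_heads)          # _split_among_heads
head_major = split.transpose(0, 2, 1, 3)                               # layout used inside _dot_product_attn
merged = head_major.transpose(0, 2, 1, 3).reshape(batch, time, depth)  # _merge_from_heads
assert np.array_equal(x, merged)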
def decode_greedy(model, do_sample=False, beam_size=0, normalization_alpha=None):
    # Determine size of current batch
    batch_size, _, _ = get_shape_list(model.source_ids)
    # Encode source sequences
    with tf.name_scope('{:s}_encode'.format(model.name)):
        enc_output, cross_attn_mask = model.enc.encode(model.source_pids, model.source_ids, model.source_mask)
    # Decode into target sequences
    with tf.name_scope('{:s}_decode'.format(model.name)):
        dec_output, scores = decode_at_test(model, model.dec, enc_output, cross_attn_mask,
                                            batch_size, beam_size, do_sample, normalization_alpha)
    return dec_output, scores
def get_memory_invariants(self, memories):
    """Generate shape invariants for memories.

    Args:
        memories: dictionary (see top-level class description)

    Returns:
        Dictionary of shape invariants with same structure as memories.
    """
    with tf.name_scope(self._scope):
        invariants = dict()
        for layer_id in memories.keys():
            layer_mems = memories[layer_id]
            invariants[layer_id] = {
                key: tf.TensorShape([None] * len(get_shape_list(layer_mems[key])))
                for key in layer_mems.keys()
            }
        return invariants
def _prepare_source():
    """ Pre-processes inputs to the encoder and generates the corresponding attention masks. """
    # Embed
    source_embeddings = self._embed(source_ids)
    # Obtain length and depth of the input tensors
    _, time_steps, depth = get_shape_list(source_embeddings)
    # Transform input mask into attention mask
    inverse_mask = tf.cast(tf.equal(source_mask, 0.0), dtype=self.float_dtype)
    attn_mask = inverse_mask * -1e9
    # Expansion to shape [batch_size, 1, 1, time_steps] is needed for compatibility with attention logits
    attn_mask = tf.expand_dims(tf.expand_dims(attn_mask, 1), 1)
    # Differentiate between self-attention and cross-attention masks for further, optional modifications
    self_attn_mask = attn_mask
    cross_attn_mask = attn_mask
    # Add positional encodings
    positional_signal = get_positional_signal(time_steps, depth, self.float_dtype)
    source_embeddings += positional_signal
    # Apply dropout
    if self.config.transformer_dropout_embeddings > 0:
        source_embeddings = tf.layers.dropout(source_embeddings,
                                              rate=self.config.transformer_dropout_embeddings,
                                              training=self.training)
    return source_embeddings, self_attn_mask, cross_attn_mask
def get_memory_invariants(model_memories):
    """Calculates the invariant shapes for a model's memories.

    'Memories' are states of the RNN or layer-wise attentions of the transformer.

    Args:
        model_memories: a dict of dicts.

    Returns:
        An invariant dictionary for the model.
    """
    if type(model_memories) != dict:
        raise ValueError('Memory type not supported, must be a dictionary.')
    invariants = dict()
    for layer_id in model_memories.keys():
        layer_mems = model_memories[layer_id]
        invariants[layer_id] = {
            key: tf.TensorShape([None] * len(get_shape_list(layer_mems[key])))
            for key in model_memories[layer_id].keys()
        }
    return invariants
def get_memory_invariants(self, memories):
    """Generate shape invariants for memories.

    Args:
        memories: dictionary (see top-level class description)

    Returns:
        Dictionary of shape invariants with same structure as memories.
    """
    with tf.name_scope(self._scope):
        d = self._model.decoder
        high_depth = 0 if d.high_gru_stack is None \
            else len(d.high_gru_stack.grus)
        num_dims = len(get_shape_list(memories['base_states']))
        # TODO Specify shape in full?
        partial_shape = tf.TensorShape([None] * num_dims)
        invariants = {}
        invariants['base_states'] = partial_shape
        invariants['high_states'] = [partial_shape] * high_depth
        return invariants
def decode_greedy(models, do_sample=False, beam_size=0, normalization_alpha=None):
    """Decodes a source sequence using beam search or sampling.

    Args:
        models: a list of Transformer objects.
        do_sample: randomly sample instead of argmax for greedy search.
        beam_size: integer specifying the beam width.
        normalization_alpha: length normalization hyperparameter.

    Returns:
        A tuple (ids, scores), where ids is a Tensor with shape (batch_size, k, max_seq_len)
        containing k translations for each input sentence in model.inputs.x and scores is a
        Tensor with shape (batch_size, k).
    """
    # Get some parameter values. For ensembling, some settings are required to
    # be consistent across all models but others are not. In the former case,
    # we assume that consistency has already been checked. For the parameters
    # that are allowed to vary across models, the first model's settings take
    # precedence.
    batch_size, _ = get_shape_list(models[0].source_ids)
    model_name = models[0].name
    decoder_name = models[0].dec.name
    from_rnn = models[0].dec.from_rnn
    config = models[0].dec.config
    float_dtype = models[0].dec.float_dtype
    int_dtype = models[0].dec.int_dtype
    vocab_size = models[0].dec.embedding_layer.get_vocab_size()

    # Generate a positional signal for the longest possible output.
    with tf.name_scope('{:s}_decode'.format(model_name)):
        with tf.variable_scope(decoder_name):
            positional_signal = get_positional_signal(
                config.translation_maxlen, config.embedding_size, float_dtype)

    # Generate a decoding function for each model.
    decoding_functions = []
    for model in models:
        assert model.name == model_name
        # Encode source sequences.
        with tf.name_scope('{:s}_encode'.format(model.name)):
            enc_output, cross_attn_mask = model.enc.encode(model.source_ids, model.source_mask)
        # Generate a model-specific decoding function.
        with tf.name_scope('{:s}_decode'.format(model.name)):
            func = generate_decoding_function(enc_output, cross_attn_mask, model.dec,
                                              positional_signal)
        decoding_functions.append(func)

    # Decode into target sequences.
    with tf.name_scope('{:s}_decode'.format(model_name)):
        with tf.variable_scope(decoder_name):
            if beam_size > 0:
                # Initialize target IDs with <GO>
                initial_ids = tf.cast(tf.fill([batch_size], 1), dtype=int_dtype)
                initial_memories = [
                    model.dec._get_initial_memories(batch_size, beam_size=beam_size)
                    for model in models]
                output_sequences, scores = _beam_search(
                    decoding_functions, initial_ids, initial_memories, int_dtype, float_dtype,
                    config.translation_maxlen, batch_size, beam_size, vocab_size, 0,
                    normalization_alpha)
            else:
                # Initialize target IDs with <GO>
                initial_ids = tf.cast(tf.fill([batch_size, 1], 1), dtype=int_dtype)
                initial_memories = [
                    model.dec._get_initial_memories(batch_size, beam_size=1)
                    for model in models]
                output_sequences, scores = greedy_search(
                    models[0], decoding_functions[0], initial_ids, initial_memories[0],
                    int_dtype, float_dtype, config.translation_maxlen, batch_size, 0,
                    do_sample, time_major=False)

    return output_sequences, scores
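normalization_alpha controls length normalization of the beam scores; the exact penalty is defined inside _beam_search. As a hypothetical sketch (not necessarily the formula used in this codebase), one common choice divides each hypothesis's accumulated log-probability by length ** alpha so that longer hypotheses are not unfairly penalized:

# Hypothetical sketch of length-normalized beam scoring (the actual penalty lives in _beam_search).
import numpy as np

log_probs = np.array([-4.2, -6.1, -6.3])   # accumulated log-probabilities of three hypotheses
lengths = np.array([4, 9, 10])             # their lengths in tokens
alpha = 0.6

normalized = log_probs / (lengths.astype(np.float64) ** alpha)
best = int(np.argmax(normalized))          # with alpha > 0, longer but likelier hypotheses can win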