def decode_greedy(self, input_batch, do_sample=False, beam_size=0):
    """ Generates translation hypotheses via greedy decoding. """
    # Unpack inputs
    self.source_ids, self.target_ids_in, self.target_ids_out, self.source_mask, self.target_mask = input_batch

    # (Re-)generate the computational graph
    dec_vocab_size = self._build_graph()

    # Determine the size of the current batch
    batch_size, _ = get_shape_list(self.source_ids)

    # Encode source sequences
    with tf.name_scope('{:s}_encode'.format(self.name)):
        enc_output, cross_attn_mask = self.enc.encode(self.source_ids, self.source_mask)

    # Decode into target sequences
    with tf.name_scope('{:s}_decode'.format(self.name)):
        dec_output, scores = self.dec.decode_at_test(enc_output, cross_attn_mask, batch_size, beam_size, do_sample)

    return dec_output, scores, dec_vocab_size
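# A minimal, self-contained sketch of what greedy decoding computes, independent of the
# TensorFlow graph above: at each step the decoder takes the highest-probability token
# (argmax) and feeds it back in, stopping at an end-of-sequence id. `toy_step`, the
# vocabulary size, and the eos id below are hypothetical stand-ins, not the actual
# `decode_at_test` implementation.
import numpy as np

def toy_step(prefix, vocab_size, rng):
    """Hypothetical decoder step: returns a log-probability vector over the vocabulary."""
    logits = rng.standard_normal(vocab_size)
    return logits - np.log(np.exp(logits).sum())  # normalise to log-probabilities

def greedy_decode_sketch(vocab_size=8, eos_id=0, max_len=10, seed=0):
    rng = np.random.default_rng(seed)
    prefix, score = [], 0.0
    for _ in range(max_len):
        log_probs = toy_step(prefix, vocab_size, rng)
        next_id = int(np.argmax(log_probs))   # greedy choice, i.e. no beam search / sampling
        score += float(log_probs[next_id])    # accumulate the hypothesis log-probability
        prefix.append(next_id)
        if next_id == eos_id:
            break
    return prefix, score

# Usage: greedy_decode_sketch() returns a token id sequence and its log-probability,
# mirroring the (dec_output, scores) pair returned by decode_greedy.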
def _prepare_source():
    """ Pre-processes inputs to the encoder and generates the corresponding attention masks. """
    # Note: this is a nested helper; `source_ids`, `source_mask`, and `self` are captured from the enclosing scope.
    # Embed
    source_embeddings = self._embed(source_ids)

    # Obtain length and depth of the input tensors
    _, time_steps, depth = get_shape_list(source_embeddings)

    # Transform the input mask into an attention mask
    inverse_mask = tf.cast(tf.equal(source_mask, 0.0), dtype=self.float_dtype)
    attn_mask = inverse_mask * -1e9
    # Expansion to shape [batch_size, 1, 1, time_steps] is needed for compatibility with the attention logits
    attn_mask = tf.expand_dims(tf.expand_dims(attn_mask, 1), 1)

    # Differentiate between self-attention and cross-attention masks for further, optional modifications
    self_attn_mask = attn_mask
    cross_attn_mask = attn_mask

    # Add positional encodings
    positional_signal = get_positional_signal(time_steps, depth, self.float_dtype)
    source_embeddings += positional_signal

    # Apply dropout
    if self.config.dropout_embeddings > 0:
        source_embeddings = tf.layers.dropout(
            source_embeddings, rate=self.config.dropout_embeddings, training=self.training)

    return source_embeddings, self_attn_mask, cross_attn_mask
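# A small numpy sketch (not part of the model) illustrating the additive attention mask
# built in `_prepare_source`: positions where the source mask is 0 receive a large
# negative bias, so after the softmax their attention weight is effectively zero.
# Shapes follow the comment above: [batch_size, 1, 1, time_steps].
import numpy as np

def build_attn_mask_sketch(source_mask):
    """source_mask: float array of shape [batch_size, time_steps]; 1.0 = token, 0.0 = padding."""
    inverse_mask = (source_mask == 0.0).astype(np.float32)
    attn_mask = inverse_mask * -1e9        # large negative bias on padded positions
    return attn_mask[:, None, None, :]     # [batch_size, 1, 1, time_steps]

mask = np.array([[1.0, 1.0, 0.0]], dtype=np.float32)   # one sentence, last position padded
scores = np.zeros((1, 1, 1, 3), dtype=np.float32)      # dummy attention logits
biased = scores + build_attn_mask_sketch(mask)
weights = np.exp(biased) / np.exp(biased).sum(-1, keepdims=True)
# weights[..., -1] is ~0: the padded position is ignored by the softmax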
def decode(final_hidden, mask):
    """ Projects the final hidden states to per-position logits and derives a masked attention map. """
    padding_num = -2 ** 32 + 1

    final_hidden_shape = get_shape_list(final_hidden, expected_rank=3)
    batch_size = final_hidden_shape[0]
    seq_length = final_hidden_shape[1]
    hidden_size = final_hidden_shape[2]

    output_weights = tf.get_variable(
        "output_weights", [seq_length, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    output_bias = tf.get_variable(
        "output_bias", [seq_length], initializer=tf.zeros_initializer())

    # Project each position's hidden state onto seq_length output classes
    final_hidden_matrix = tf.reshape(final_hidden, [batch_size * seq_length, hidden_size])
    logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    logits = tf.reshape(logits, [batch_size, seq_length, seq_length])

    # Raw pairwise attention scores between positions, shape (N, T_q, T_k)
    attention_raw = tf.matmul(logits, tf.transpose(logits, [0, 2, 1]))
    # attention_raw = tf.matmul(final_hidden, tf.transpose(final_hidden, [0, 2, 1]))  # (N, T_q, T_k)

    # Mask out padded query/key pairs with a very negative value before the softmax
    mask_ = mask * tf.transpose(mask, [0, 2, 1])
    paddings = tf.ones_like(mask_) * padding_num
    attention_raw = tf.where(tf.equal(mask_, 0), paddings, attention_raw)
    attention_final = tf.nn.softmax(attention_raw)
    # logits_final = tf.matmul(attention_final, logits)

    return logits, attention_final
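# A numpy sketch (with made-up shapes) of the masking step in `decode` above: the
# per-position mask of shape (N, T, 1) is expanded into a pairwise (N, T, T) mask, and
# padded query/key pairs are filled with a very negative number before the softmax.
import numpy as np

def masked_pairwise_attention_sketch(scores, mask):
    """scores: (N, T, T) raw attention scores; mask: (N, T, 1) with 1.0 = token, 0.0 = padding."""
    padding_num = -2 ** 32 + 1
    pair_mask = mask * np.transpose(mask, (0, 2, 1))   # (N, T, T): 1 only where both positions are real
    scores = np.where(pair_mask == 0, padding_num, scores)
    scores = scores - scores.max(-1, keepdims=True)    # numerically stable softmax
    weights = np.exp(scores)
    return weights / weights.sum(-1, keepdims=True)

N, T = 1, 3
mask = np.array([[[1.0], [1.0], [0.0]]])               # last position is padding
attn = masked_pairwise_attention_sketch(np.zeros((N, T, T)), mask)
# Real query positions place (almost) no weight on the padded key position.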