def bilinear_attention_layer(self, document, query, doc_mask):
    """Bilinear attention over the document, conditioned on the query.

    Can be used in the pointer network to replace the summ operation.
        document: (B, D, 2h)
        query:    (B, 2h)
        return:   (B, 2h)
    """
    num_units = int(document.shape[-1])
    with tf.variable_scope('bilinear_att'):
        W_att = tf.get_variable('W_bilinear', shape=(num_units, num_units),
                                dtype=tf.float32,
                                initializer=tf.random_uniform_initializer(-0.01, 0.01))
        # (B, 1, 2h): query projected through the bilinear weight
        M = tf.expand_dims(tf.matmul(query, W_att), axis=1)
        # (B, D): attention distribution over document positions
        alpha = tf.nn.softmax(
            softmax_mask(tf.reduce_sum(document * M, axis=2), tf.to_float(doc_mask)))
        # (B, 2h): attention-weighted sum of document states
        return tf.reduce_sum(document * tf.expand_dims(alpha, axis=2), axis=1)
def forward(self):
    # in:  c, q, c_mask, q_mask, ch, qh, y1, y2
    # out: yp1, yp2, loss
    config = self.config
    N, PL, QL, CL, d, dc, dg = (config.batch_size, self.c_maxlen, self.q_maxlen,
                                config.char_limit, config.hidden, config.char_dim,
                                config.char_hidden)
    gru = cudnn_gru if config.use_cudnn else native_gru

    with tf.variable_scope('emb'):
        with tf.variable_scope('char'):
            # character-level embeddings, encoded by a shared bidirectional GRU
            ch_emb = tf.reshape(
                tf.nn.embedding_lookup(self.char_mat, self.ch), [N * PL, CL, dc])
            qh_emb = tf.reshape(
                tf.nn.embedding_lookup(self.char_mat, self.qh), [N * QL, CL, dc])
            ch_emb = dropout(ch_emb, keep_prob=config.keep_prob, is_train=self.is_train)
            qh_emb = dropout(qh_emb, keep_prob=config.keep_prob, is_train=self.is_train)
            cell_fw = tf.contrib.rnn.GRUCell(dg)
            cell_bw = tf.contrib.rnn.GRUCell(dg)
            _, (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(
                cell_fw, cell_bw, ch_emb, self.ch_len, dtype=tf.float32)
            ch_emb = tf.concat([state_fw, state_bw], axis=1)
            _, (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(
                cell_fw, cell_bw, qh_emb, self.qh_len, dtype=tf.float32)
            qh_emb = tf.concat([state_fw, state_bw], axis=1)
            qh_emb = tf.reshape(qh_emb, [N, QL, 2 * dg])
            ch_emb = tf.reshape(ch_emb, [N, PL, 2 * dg])

        with tf.variable_scope('word'):
            c_emb = tf.nn.embedding_lookup(self.word_mat, self.c)
            q_emb = tf.nn.embedding_lookup(self.word_mat, self.q)

        # concatenate word- and character-level representations
        c_emb = tf.concat([c_emb, ch_emb], axis=2)
        q_emb = tf.concat([q_emb, qh_emb], axis=2)

    # context encoding: method 1
    with tf.variable_scope('encoding'):
        rnn = gru(num_layers=3, num_units=d, batch_size=N,
                  input_size=c_emb.get_shape().as_list()[-1],
                  keep_prob=config.keep_prob, is_train=self.is_train)
        c = rnn(c_emb, seq_len=self.c_len, concat=True, keep_origin_input=True)
        q = rnn(q_emb, seq_len=self.q_len, concat=True, keep_origin_input=True)

    with tf.variable_scope('attention'):
        # query-aware context and context-aware query, each re-encoded by its own GRU
        qc_att = dot_attention(inputs=c, memory=q, hidden_size=d, mask=self.q_mask,
                               keep_prob=config.keep_prob, is_train=self.is_train,
                               scope='qc_dot_att')
        rnn_qc = gru(num_layers=1, num_units=d, batch_size=N,
                     input_size=qc_att.get_shape().as_list()[-1],
                     keep_prob=config.keep_prob, is_train=self.is_train, scope='qc')
        cq_att = dot_attention(inputs=q, memory=c, hidden_size=d, mask=self.c_mask,
                               keep_prob=config.keep_prob, is_train=self.is_train,
                               scope='cq_dot_att')
        rnn_cq = gru(num_layers=1, num_units=d, batch_size=N,
                     input_size=cq_att.get_shape().as_list()[-1],
                     keep_prob=config.keep_prob, is_train=self.is_train, scope='cq')
        c = rnn_qc(qc_att, seq_len=self.c_len, keep_origin_input=False)
        q = rnn_cq(cq_att, seq_len=self.q_len, keep_origin_input=False)
        # seq_length = self.q_len
        # idx = tf.concat(
        #     [tf.expand_dims(tf.range(tf.shape(q)[0]), axis=1),
        #      tf.expand_dims(seq_length - 1, axis=1)], axis=1)
        # # (B, 2h)
        # q_state = tf.gather_nd(q, idx)

    with tf.variable_scope('hybrid'):
        # (B, D, Q): mask of valid (context, query) position pairs
        doc_qry_mask = tf.keras.backend.batch_dot(
            tf.expand_dims(tf.cast(self.c_mask, tf.float32), 2),
            tf.expand_dims(tf.cast(self.q_mask, tf.float32), 1), axes=[2, 1])
        # (B, D, Q, 2h)
        doc_expand_embed = tf.tile(tf.expand_dims(c, 2), [1, 1, self.q_maxlen, 1])
        # (B, D, Q, 2h)
        qry_expand_embed = tf.tile(tf.expand_dims(q, 1), [1, self.c_maxlen, 1, 1])
        doc_qry_dot_embed = doc_expand_embed * qry_expand_embed
        # (B, D, Q, 6h)
        doc_qry_embed = tf.concat(
            [doc_expand_embed, qry_expand_embed, doc_qry_dot_embed], axis=3)

        # attention way
        num_units = int(doc_qry_embed.shape[-1])
        with tf.variable_scope('bi_attention'):
            w = tf.get_variable('W_att', shape=(num_units, 1), dtype=tf.float32,
                                initializer=tf.random_uniform_initializer(-0.01, 0.01))
            # (B, D, Q) similarity scores
            S = tf.matmul(tf.reshape(doc_qry_embed, (-1, num_units)), w)
            S = tf.reshape(S, (N, self.c_maxlen, self.q_maxlen))
            # context2query, (B, D, 2h)
            c2q = tf.keras.backend.batch_dot(
                tf.nn.softmax(softmax_mask(S, doc_qry_mask), dim=2), q)
            c2q_gated = c2q * c

        with tf.variable_scope('gated_attention'):
            # Gated Attention
            g_doc_qry_att = tf.keras.backend.batch_dot(
                c, tf.transpose(q, (0, 2, 1)))  # (B, D, Q)
            alphas = tf.nn.softmax(softmax_mask(g_doc_qry_att, doc_qry_mask), dim=2)
            q_rep = tf.keras.backend.batch_dot(alphas, q)  # (B, D, 2h)
            d_gated = c * q_rep

        G = tf.concat([c, c2q, q_rep, c2q_gated, d_gated], axis=-1)
        # G = tf.nn.relu(dense(G, d * 2))

    with tf.variable_scope('match'):
        # self-matching over the fused representation
        G = dot_attention(inputs=G, memory=G, hidden_size=d, mask=self.c_mask,
                          keep_prob=config.keep_prob, is_train=self.is_train)
        rnn = gru(num_layers=1, num_units=d, batch_size=N,
                  input_size=G.get_shape().as_list()[-1],
                  keep_prob=config.keep_prob, is_train=self.is_train)
        doc_encoding = rnn(G, seq_len=self.c_len, concat=False)

    with tf.variable_scope('pointer'):
        # use self-attention (summ) or bilinear attention for the initial state
        init = summ(q, d, mask=self.q_mask,
                    keep_prob=config.keep_prob, is_train=self.is_train)
        # init = self.bilinear_attention_layer(c, q_state, self.c_mask)
        pointer = ptr_layer(batch_size=N, hidden_size=init.get_shape().as_list()[-1],
                            keep_prob=config.keep_prob, is_train=self.is_train)
        logits1, logits2 = pointer(init, doc_encoding, d, self.c_mask)

    with tf.variable_scope('predict'):
        outer = tf.matmul(tf.expand_dims(tf.nn.softmax(logits1), axis=2),
                          tf.expand_dims(tf.nn.softmax(logits2), axis=1))
        # keep only spans whose end is at most 15 positions after the start
        outer = tf.matrix_band_part(outer, 0, 15)
        self.yp1 = tf.argmax(tf.reduce_max(outer, axis=2), axis=1)
        self.yp2 = tf.argmax(tf.reduce_max(outer, axis=1), axis=1)
        # loss1 = tf.nn.softmax_cross_entropy_with_logits_v2(
        #     logits=logits1, labels=tf.stop_gradient(self.y1))
        loss1 = tf.nn.softmax_cross_entropy_with_logits(
            logits=logits1, labels=tf.stop_gradient(self.y1))
        # loss2 = tf.nn.softmax_cross_entropy_with_logits_v2(
        #     logits=logits2, labels=tf.stop_gradient(self.y2))
        loss2 = tf.nn.softmax_cross_entropy_with_logits(
            logits=logits2, labels=tf.stop_gradient(self.y2))
        self.loss = tf.reduce_mean(loss1 + loss2)
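
# NOTE: `softmax_mask` is referenced throughout but not defined in this section; it is
# presumably imported from the repo's helper module. The sketch below is an assumption
# of what that helper does in R-Net-style code (push masked logits toward -inf so padded
# positions get ~0 probability after softmax), not the repo's own definition. The name,
# signature, and constant are illustrative; `import tensorflow as tf` at module top is
# assumed (not shown in this section).
def _softmax_mask_sketch(val, mask, neg_inf=1e30):
    """Hedged sketch of the assumed masking helper."""
    return -neg_inf * (1 - tf.cast(mask, tf.float32)) + val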