def _bert_pretrained_knowledge(self, knowledge_tokens_ph, knowledge_lengths_ph):
    """Encode every retrieved knowledge snippet into one feature vector.

    The top_n snippets of each example are folded into the batch axis, run
    through ``self.encoder.lstm_encoder``, and the forward/backward features
    from ``sequence_feature`` are concatenated and unfolded per example.

    Args:
        knowledge_tokens_ph: Rank-4 tensor (rank enforced by
            ``get_shape_list(..., expected_rank=4)``) laid out as
            [batch, top_n, max_seq_len, embedding_dim]; presumably
            pre-computed BERT token embeddings (768-d) — TODO confirm.
        knowledge_lengths_ph: Token length of each snippet; flattened to a
            vector below. Assumes shape [batch, top_n] — TODO confirm.

    Returns:
        Tensor of shape [batch, top_n, 2 * self.hparams.rnn_hidden_dim].
    """
    input_shape = get_shape_list(knowledge_tokens_ph, expected_rank=4)
    batch_size = input_shape[0]
    top_n = input_shape[1]
    knowledge_max_seq_len = input_shape[2]
    embedding_dim = input_shape[3]

    # Fold top_n into the batch axis so each snippet is encoded independently.
    knowledge_tokens_embedded = tf.reshape(
        knowledge_tokens_ph, shape=[-1, knowledge_max_seq_len, embedding_dim])
    knowledge_lengths = tf.reshape(knowledge_lengths_ph, shape=[-1])

    knowledge_lstm_outputs = self.encoder.lstm_encoder(
        knowledge_tokens_embedded, knowledge_lengths, "knowledge_lstm")
    knowledge_fw, knowledge_bw = sequence_feature(
        knowledge_lstm_outputs, knowledge_lengths, sep_pos=True)
    knowledge_concat_features = tf.concat([knowledge_fw, knowledge_bw], axis=-1)

    # Unfold back to one feature vector per (example, snippet) pair.
    return tf.reshape(
        knowledge_concat_features,
        shape=[batch_size, top_n, self.hparams.rnn_hidden_dim * 2])
def _matching_aggregation_b_layer(self, m_text_b, text_b_len):
    """Aggregate the matched text_b sequence with the matching LSTM encoder.

    Args:
        m_text_b: Matching-layer representation of text_b, fed to
            ``self.matching_encoder.lstm_encoder``.
        text_b_len: Sequence lengths of text_b.

    Returns:
        Tuple ``(fw_state, bw_state)`` as returned by ``sequence_feature``
        on the encoder outputs.
    """
    m_text_b_lstm_outputs = self.matching_encoder.lstm_encoder(
        m_text_b, text_b_len, name="text_b_matching")
    m_fw_text_b_state, m_bw_text_b_state = sequence_feature(
        m_text_b_lstm_outputs, text_b_len)
    return m_fw_text_b_state, m_bw_text_b_state
def dialog_representation(self, context_sentence_representation, tot_context_len):
    """Run a dialog-level stacked bidirectional GRU over the turn vectors.

    Args:
        context_sentence_representation: One vector per dialog turn, used as
            the RNN input sequence.
        tot_context_len: Number of valid turns per dialog (RNN sequence
            length).

    Returns:
        Tuple ``(outputs, hidden)``: the full RNN output sequence, and the
        concatenation of the two features ``sequence_feature`` extracts from
        it.
    """
    hp = self.hparams
    with tf.variable_scope("tot-context-sentence-layer"):
        turn_outputs = stack_bidirectional_rnn(
            inputs=context_sentence_representation,
            sequence_length=tot_context_len,
            cell="CUDNNGRU",
            num_layers=hp.rnn_depth,
            num_units=hp.sentence_rnn_hidden_dim * 2,
            state_merge="concat",
            residual=hp.rnn_depth > 1,
            output_dropout_keep_prob=self.dropout_keep_prob)

    last_fw, first_bw = sequence_feature(turn_outputs, tot_context_len)
    dialog_hidden = tf.concat(values=[last_fw, first_bw], axis=-1)
    return turn_outputs, dialog_hidden
def context_sentence_representation(self, context_sentence_embedded,
                                    context_sentence_len, tot_context_len,
                                    speaker):
    """Build one representation vector per context sentence (dialog turn).

    Two views of every sentence are concatenated on the last axis:
      1. a masked sum of its word embeddings, weighted by 1 / tot_context_len
         for valid turns and 0 for padded turns, and
      2. the ``sequence_feature`` output of a stacked bidirectional GRU run
         over the sentence's tokens.
    When ``self.pos_emb_bool`` or ``self.user_emb_bool`` is set, the result
    is further passed through ``self._sentence_order_speaker_feature``.

    Args:
        context_sentence_embedded: Word embeddings of the context sentences;
            assumes [batch, max_dialog_len, max_c_sentence_len, embedding_dim]
            (implied by the reduce over axis 2 and the reshapes) — TODO
            confirm.
        context_sentence_len: Token count of each sentence; assumes
            [batch, max_dialog_len] — TODO confirm.
        tot_context_len: Number of valid sentences per dialog; assumes
            [batch] — TODO confirm.
        speaker: Forwarded to ``_sentence_order_speaker_feature``.

    Returns:
        Per-turn representation tensor. Before the optional sentence
        features it has last dimension
        embedding_dim + 2 * sentence_rnn_hidden_dim.
    """
    # --- Bag-of-words view: masked sum of word embeddings per sentence. ---
    token_mask = tf.sequence_mask(context_sentence_len,
                                  maxlen=self.max_c_sentence_len)
    token_mask = tf.expand_dims(tf.cast(token_mask, tf.float32), axis=-1)
    context_sentence_sum = tf.reduce_sum(
        tf.multiply(context_sentence_embedded, token_mask), axis=2)

    # NOTE(review): the per-sentence sum is divided by the number of *turns*
    # in the dialog (tot_context_len), not by the sentence's token count,
    # despite the "Avg" naming — confirm this is intended.
    tot_context_mask = tf.cast(
        tf.sequence_mask(tot_context_len, maxlen=self.max_dialog_len),
        tf.float32)
    # Clamp to 1: a dialog with zero turns has an all-zero mask, so this
    # yields zeros instead of 0/0 = NaN.
    safe_tot_context_len = tf.cast(tf.maximum(tot_context_len, 1), tf.float32)
    turn_weights = tot_context_mask / tf.expand_dims(safe_tot_context_len, -1)
    context_sentence_mean = tf.multiply(context_sentence_sum,
                                        tf.expand_dims(turn_weights, -1))

    # --- RNN view: encode every sentence independently. ---
    flat_sentence_embedded = tf.reshape(
        context_sentence_embedded,
        [-1, self.max_c_sentence_len, self.hparams.embedding_dim])
    flat_sentence_len = tf.reshape(context_sentence_len, [-1])
    with tf.variable_scope("context-sentence-encoder"):
        c_sentence_outputs = stack_bidirectional_rnn(
            cell="CUDNNGRU",
            num_layers=self.hparams.rnn_depth,
            num_units=self.hparams.sentence_rnn_hidden_dim * 2,
            inputs=flat_sentence_embedded,
            sequence_length=flat_sentence_len,
            state_merge="concat",
            output_dropout_keep_prob=self.dropout_keep_prob,
            residual=self.hparams.rnn_depth > 1)

    c_sentence_fw_last_state, c_sentence_bw_first_state = sequence_feature(
        c_sentence_outputs, flat_sentence_len)
    c_sentence_hidden = tf.concat(
        axis=-1, values=[c_sentence_fw_last_state, c_sentence_bw_first_state])
    # Unfold back to one vector per (example, turn).
    c_sentence_hidden = tf.reshape(c_sentence_hidden, [
        self.batch_size, self.max_dialog_len,
        self.hparams.sentence_rnn_hidden_dim * 2
    ])

    sentence_repr = tf.concat(
        axis=-1, values=[context_sentence_mean, c_sentence_hidden])

    if self.pos_emb_bool or self.user_emb_bool:
        print("with sentence_features")
        return self._sentence_order_speaker_feature(sentence_repr, speaker)

    print("without sentence_features")
    return sentence_repr
def build_graph(self):
    """Build logits and loss from the BERT dialog [CLS] and knowledge features.

    Each of the top_n knowledge vectors is concatenated with the pooled BERT
    dialog output, and the resulting sequence is run through an LSTM whose
    length is the number of knowledge snippets flagged in
    ``self.label_id_phs[1]``. Examples with no flagged knowledge fall back to
    a projection of the dialog [CLS] vector alone.

    Returns:
        Tuple ``(logits, loss_op)`` from ``self._final_output_layer``.
    """
    # Pooled BERT output; presumably [batch, 768] — TODO confirm.
    bert_dialog_cls = self.bert_model.get_pooled_output()
    # [batch, top_n, 2 * rnn_hidden_dim]
    knowledge_bilstm_out = self._bert_pretrained_knowledge(
        self.knowledge_token_phs, self.length_phs[2])

    # Assumes the labels are a [batch, top_n] 0/1 mask with the valid
    # snippets packed at the front, so their count works as a sequence
    # length — TODO confirm.
    knowledge_labels = self.label_id_phs[1]
    knowledge_exist_len = tf.reduce_sum(knowledge_labels, axis=-1)

    tiled_bert_dialog_cls = tf.tile(
        tf.expand_dims(bert_dialog_cls, axis=1), [1, self.hparams.top_n, 1])
    dialog_knowledge_concat = tf.concat(
        [tiled_bert_dialog_cls, knowledge_bilstm_out], axis=-1)
    lstm_outputs = self.encoder.lstm_encoder(dialog_knowledge_concat,
                                             knowledge_exist_len,
                                             "dialog_cls_knowledge_lstm",
                                             rnn_hidden_dim=256)
    features_fw, features_bw = sequence_feature(lstm_outputs,
                                                knowledge_exist_len,
                                                sep_pos=False)
    lstm_hidden_outputs = tf.concat([features_fw, features_bw], axis=-1)

    dialog_cls_projection = tf.layers.dense(
        name="dialog_cls_projection",
        inputs=bert_dialog_cls,
        units=512,
        kernel_initializer=create_initializer(initializer_range=0.02))
    dialog_knowledge_cls_projection = tf.layers.dense(
        name="dialog_knowledge_cls_projection",
        inputs=lstm_hidden_outputs,
        units=512,
        activation=tf.nn.relu,
        kernel_initializer=create_initializer(initializer_range=0.02))

    # Per example: use the dialog-only projection when no knowledge snippet
    # exists, otherwise the knowledge-aware one. Testing the count directly
    # cannot be fooled by LSTM features that happen to sum to 0 or are
    # non-finite for zero-length sequences (0 * NaN = NaN).
    no_knowledge = tf.equal(knowledge_exist_len, 0)
    output_layer = tf.where(no_knowledge, dialog_cls_projection,
                            dialog_knowledge_cls_projection)

    # Debug tensors exposed on the instance: element-wise flags for which
    # branch output_layer matches, and their sum.
    self.test1 = tf.equal(output_layer, dialog_cls_projection)
    self.test2 = tf.equal(output_layer, dialog_knowledge_cls_projection)
    self.test_sum = tf.cast(self.test1, tf.int32) + tf.cast(
        self.test2, tf.int32)

    logits, loss_op = self._final_output_layer(output_layer)
    return logits, loss_op
def build_graph(self):
    """Build logits and loss by matching the response against similar dialogs.

    The BERT sequence output is split into dialog and response parts. The
    response is LSTM-encoded and ESIM-attended against each of the top_n
    retrieved similar dialogs. Each attended feature, concatenated with the
    response feature, goes through a 3-layer ReLU MLP whose weights are
    shared across the top_n branches, and the branch outputs are summed.

    Returns:
        Tuple ``(logits, loss_op)`` from ``self._final_output_layer``.
    """
    bert_seq_out = self.bert_model.get_sequence_output()
    # Only the response part of the split is used below.
    _, response_bert_outputs = self._bert_sentences_split(
        bert_seq_out, self.hparams.dialog_max_seq_length,
        self.hparams.response_max_seq_length)

    similar_input_ids_ph, _, similar_len_ph = self.similar_phs
    # Assumed [batch, top_n, max_seq_len, rnn_hidden_dim * 2] — TODO confirm.
    similar_dialogs_lstm_outputs = self._similar_dialog_lstm(
        similar_input_ids_ph, similar_len_ph)
    unstacked_similar_dialog_lstm_out = tf.unstack(
        similar_dialogs_lstm_outputs, self.hparams.top_n, axis=1)
    unstacked_similar_dialog_len = tf.unstack(similar_len_ph,
                                              self.hparams.top_n,
                                              axis=1)

    _, response_len_ph, _ = self.length_phs
    # NOTE(review): one token is dropped from the response length, presumably
    # a trailing separator token — confirm against _bert_sentences_split.
    response_len = response_len_ph - 1
    response_lstm_outputs = self.encoder.lstm_encoder(
        response_bert_outputs, response_len, name="response_lstm")
    response_fw, response_bw = sequence_feature(response_lstm_outputs,
                                                response_len)
    response_concat = tf.concat([response_fw, response_bw], axis=-1)

    # One ESIM attention output per similar dialog.
    esim_att_out_l = []
    for each_dialog_out, each_dialog_len in zip(
            unstacked_similar_dialog_lstm_out, unstacked_similar_dialog_len):
        esim_att = ESIMAttention(self.hparams,
                                 self.hparams.dropout_keep_prob,
                                 text_a=response_lstm_outputs,
                                 text_a_len=response_len,
                                 text_b=each_dialog_out,
                                 text_b_len=each_dialog_len)
        esim_att_out_l.append(esim_att.text_a_att_outs)

    mlp_layers = []
    for each_att_out in esim_att_out_l:
        layer_input = tf.concat([response_concat, each_att_out], axis=-1)
        for i in range(3):
            # The same layer names are used for every branch. AUTO_REUSE
            # shares the weights explicitly instead of raising "variable
            # already exists" when top_n > 1.
            layer_input = tf.layers.dense(
                inputs=layer_input,
                units=768,
                activation=tf.nn.relu,
                kernel_initializer=create_initializer(0.02),
                name="mlp_%d" % i,
                reuse=tf.AUTO_REUSE)
            # NOTE(review): keep_prob comes from static hparams, so dropout is
            # also active at inference unless hparams are swapped for eval —
            # confirm.
            layer_input = tf.nn.dropout(layer_input,
                                        self.hparams.dropout_keep_prob)
        mlp_layers.append(layer_input)

    # Element-wise sum of the per-branch MLP outputs.
    output_layer = tf.add_n(mlp_layers, "mlp_layers_add_n")

    logits, loss_op = self._final_output_layer(output_layer)
    return logits, loss_op