Example 1
class Model(object):
    def __init__(self,
                 hparams,
                 bert_config=None,
                 bert_model: BertModel = None,
                 input_phs=None,
                 label_id_phs=None,
                 length_phs=None,
                 knowledge_phs=None,
                 similar_phs=None):
        self.hparams = hparams
        self.bert_hidden_dropout_prob = 0.1
        if self.hparams.do_evaluate:
            self.bert_hidden_dropout_prob = 0.0
        self.bert_config = bert_config
        self.bert_model = bert_model
        self.input_phs = input_phs
        self.label_id_phs = label_id_phs
        self.length_phs = length_phs
        self.knowledge_token_phs = knowledge_phs
        self.similar_phs = similar_phs

        self._get_pretrained_variables()

        self.encoder = BasicEncoder(self.hparams,
                                    self.hparams.dropout_keep_prob)
Example 2
class Model(object):
    def __init__(self,
                 hparams,
                 bert_config=None,
                 bert_model=None,
                 input_phs=None,
                 label_id_phs=None,
                 length_phs=None,
                 knowledge_phs=None,
                 similar_phs=None):
        self.hparams = hparams
        self.bert_config = bert_config
        self.bert_model = bert_model
        self.label_id_phs = label_id_phs
        self.length_phs = length_phs
        self.knowledge_token_phs = knowledge_phs

        self.encoder = BasicEncoder(self.hparams,
                                    self.hparams.dropout_keep_prob)
Example 3
class ESIMAttention(object):
    def __init__(self,
                 hparams,
                 dropout_keep_prob,
                 text_a,
                 text_a_len,
                 text_b,
                 text_b_len,
                 only_text_a=True):
        self.hparams = hparams
        self.dropout_keep_prob = dropout_keep_prob

        self.matching_encoder = BasicEncoder(self.hparams,
                                             self.dropout_keep_prob)
        m_text_a, m_text_b = self._attention_matching_layer(text_a, text_b)

        m_fw_text_a_state, m_bw_text_a_state = self._matching_aggregation_a_layer(
            m_text_a, text_a_len)
        self.text_a_att_outs = tf.concat(
            [m_fw_text_a_state, m_bw_text_a_state], axis=-1)
Example 4
class ESIMAttention(object):
    def __init__(self,
                 hparams,
                 dropout_keep_prob,
                 text_a,
                 text_a_len,
                 text_b,
                 text_b_len,
                 only_text_a=True):
        self.hparams = hparams
        self.dropout_keep_prob = dropout_keep_prob

        self.matching_encoder = BasicEncoder(self.hparams,
                                             self.dropout_keep_prob)
        m_text_a, m_text_b = self._attention_matching_layer(text_a, text_b)

        m_fw_text_a_state, m_bw_text_a_state = self._matching_aggregation_a_layer(
            m_text_a, text_a_len)
        self.text_a_att_outs = tf.concat(
            [m_fw_text_a_state, m_bw_text_a_state], axis=-1)

    def _attention_text_b(self, similarity_matrix, text_a):
        """
		:param similarity_matrix: [batch_size, max_text_b_len, max_text_a_len]
		:param text_a: [batch_size, max_text_a_len, hidden_dim]
		:return:
		"""
        attention_weight_text_a = \
         tf.where(tf.equal(similarity_matrix, 0.), similarity_matrix, tf.nn.softmax(similarity_matrix))

        attended_text_b = tf.matmul(attention_weight_text_a, text_a)

        return attended_text_b

    def _attention_text_a(self, similarity_matrix, text_b):
        """
		:param similarity_matrix: [batch_size, max_text_b_len, max_text_a_len]
		:param text_b: [batch_size, max_text_b_len, hidden_dim]
		:return: attend_text_a
		"""
        sim_trans_mat = tf.transpose(similarity_matrix, perm=[0, 2, 1])
        attention_weight_text_b = \
         tf.where(tf.equal(sim_trans_mat, 0.), sim_trans_mat, tf.nn.softmax(sim_trans_mat))

        attended_text_a = tf.matmul(attention_weight_text_b, text_b)

        return attended_text_a

    def _similarity_matrix(self, text_a, text_b):
        """
		Dot attention : text_a, text_b bert
		:param text_a: [batch, max_text_a_len, 768]
		:param text_b: [batch, max_text_b_len, 768]
		:return: similarity_matrix #[batch, max_text_b_len, max_text_a_len]
		"""
        similarity_matrix = tf.matmul(text_b,
                                      tf.transpose(text_a, perm=[0, 2, 1]))

        return similarity_matrix

    def _attention_matching_layer(self, text_a, text_b):
        similarity = self._similarity_matrix(text_a, text_b)
        # shape: [batch, max_text_a_len, 768]
        attended_text_a = self._attention_text_a(similarity, text_b)
        # shape: [batch, max_text_b_len, 768]
        attended_text_b = self._attention_text_b(similarity, text_a)

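        # ESIM-style enhancement: concatenate each sequence with its attended
        # counterpart plus their element-wise difference and product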
        m_text_a = tf.concat(axis=-1,
                             values=[
                                 text_a, attended_text_a,
                                 text_a - attended_text_a,
                                 tf.multiply(text_a, attended_text_a)
                             ])

        m_text_b = tf.concat(axis=-1,
                             values=[
                                 text_b, attended_text_b,
                                 text_b - attended_text_b,
                                 tf.multiply(text_b, attended_text_b)
                             ])

        return m_text_a, m_text_b

    def _matching_aggregation_a_layer(self, m_text_a, text_a_len):
        """text_a_matching"""
        m_text_a_lstm_outputs = self.matching_encoder.lstm_encoder(
            m_text_a, text_a_len, name="text_a_matching")
        m_text_a_max = tf.reduce_max(m_text_a_lstm_outputs, axis=1)
        m_fw_text_a_state, m_bw_text_a_state = sequence_feature(
            m_text_a_lstm_outputs, text_a_len)

        return m_fw_text_a_state, m_bw_text_a_state

    def _matching_aggregation_b_layer(self, m_text_b, text_b_len):
        """text_b_matching"""
        m_text_b_lstm_outputs = self.matching_encoder.lstm_encoder(
            m_text_b, text_b_len, name="text_b_matching")
        m_text_b_max = tf.reduce_max(m_text_b_lstm_outputs, axis=1)
        m_fw_text_b_state, m_bw_text_b_state = sequence_feature(
            m_text_b_lstm_outputs, text_b_len)

        return m_fw_text_b_state, m_bw_text_b_state
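
The attention-matching layer above is the ESIM local-inference step: a dot-product similarity matrix, a softmax whose output is forced to zero wherever the raw similarity is exactly zero (padded positions), and the [x, x_att, x - x_att, x * x_att] enhancement concat. A minimal, self-contained sketch of the same computation on toy tensors (TF 1.x; the shapes and values are illustrative only, not taken from the repo):

import tensorflow as tf
import numpy as np

# toy encoded sequences: batch=1, text_a has 3 tokens, text_b has 2 tokens, hidden=4
text_a = tf.constant(np.random.rand(1, 3, 4), dtype=tf.float32)
text_b = tf.constant(np.random.rand(1, 2, 4), dtype=tf.float32)

# dot-product similarity: [batch, max_text_b_len, max_text_a_len]
similarity = tf.matmul(text_b, tf.transpose(text_a, perm=[0, 2, 1]))

# zero similarities keep zero weight, everything else takes its softmax value,
# mirroring _attention_text_a / _attention_text_b above
weights_over_a = tf.where(tf.equal(similarity, 0.),
                          similarity, tf.nn.softmax(similarity))
attended_b = tf.matmul(weights_over_a, text_a)  # [1, 2, 4]

# enhancement concat: original, attended, difference, element-wise product
m_text_b = tf.concat([text_b, attended_b,
                      text_b - attended_b,
                      tf.multiply(text_b, attended_b)], axis=-1)  # [1, 2, 16]

with tf.Session() as sess:
    print(sess.run(tf.shape(m_text_b)))  # [1 2 16]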
Example 5
class Model(object):
    def __init__(self,
                 hparams,
                 bert_config=None,
                 bert_model=None,
                 input_phs=None,
                 label_id_phs=None,
                 length_phs=None,
                 knowledge_phs=None,
                 similar_phs=None):
        self.hparams = hparams
        self.bert_config = bert_config
        self.bert_model = bert_model
        self.label_id_phs = label_id_phs
        self.length_phs = length_phs
        self.knowledge_token_phs = knowledge_phs

        self.encoder = BasicEncoder(self.hparams,
                                    self.hparams.dropout_keep_prob)

        # dialog_label_ids, knowledge_label_ids = label_id_phs
        # knowledge_ids_ph, knowledge_mask_ph, knowledge_seg_ids_ph = knowledge_phs
        # dialog_len_ph, response_len_ph, knowledge_len_ph = length_phs

    def build_graph(self):
        # dialog_cls : [batch, 768]
        # knowledge_bilstm_out : [batch, 5, 1536]
        bert_dialog_cls = self.bert_model.get_pooled_output()
        knowledge_bilstm_out = self._bert_pretrained_knowledge(
            self.knowledge_token_phs, self.length_phs[2])

        knowledge_labels = self.label_id_phs[1]
        # knowledge_exist_len : [batch_size]
        knowledge_exist_len = tf.reduce_sum(knowledge_labels, axis=-1)
        # knowledge_exist_len = tf.Print(knowledge_exist_len, [knowledge_exist_len], message="knowledge_exist_len_sum", summarize=16)

        # knowledge_labels = tf.tile(tf.expand_dims(self.label_id_phs[1], axis=-1), multiples=[1, 1, tf.shape(knowledge_bilstm_out)[2]])
        # knowledge_bilstm_out = tf.multiply(knowledge_bilstm_out, tf.cast(knowledge_labels,tf.float32))

        # [batch, 5] [[1, 0, 0], [1, 1, 1], [1, 1, 0], [1, 0, 0], [1, 1, 1]]
        tiled_bert_dialog_cls = tf.tile(
            tf.expand_dims(bert_dialog_cls, axis=1),
            [1, self.hparams.top_n, 1])
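        # pair the dialog [CLS] vector with each of the top_n knowledge encodings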
        dialog_knowledge_concat = tf.concat(
            [tiled_bert_dialog_cls, knowledge_bilstm_out], axis=-1)

        lstm_outputs = self.encoder.lstm_encoder(dialog_knowledge_concat,
                                                 knowledge_exist_len,
                                                 "dialog_cls_knowledge_lstm",
                                                 rnn_hidden_dim=256)
        features_fw, features_bw = sequence_feature(lstm_outputs,
                                                    knowledge_exist_len,
                                                    sep_pos=False)

        lstm_hidden_outputs = tf.concat([features_fw, features_bw], axis=-1)
        filtered_lstm_hidden_outputs = tf.multiply(
            tf.cast(tf.expand_dims(knowledge_exist_len, axis=-1), tf.float32),
            lstm_hidden_outputs)
        # filtered_lstm_hidden_outputs = tf.Print(filtered_lstm_hidden_outputs, [filtered_lstm_hidden_outputs], message="lstm_hidden_outputs", summarize=512)

        # batch, 768
        # concat_outputs = tf.concat(lstm_hidden_outputs, axis=-1)
        dialog_cls_projection = tf.layers.dense(
            name="dialog_cls_projection",
            inputs=bert_dialog_cls,
            units=512,
            kernel_initializer=create_initializer(initializer_range=0.02))
        dialog_knowledge_cls_projection = tf.layers.dense(
            name="dialog_knowledge_cls_projection",
            inputs=lstm_hidden_outputs,
            units=512,
            activation=tf.nn.relu,
            kernel_initializer=create_initializer(initializer_range=0.02))

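        # rows whose filtered knowledge features sum to zero (knowledge_exist_len == 0,
        # i.e. no knowledge attached) fall back to the dialog-only projection;
        # all other rows use the dialog+knowledge projection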
        output_layer = tf.where(
            tf.equal(tf.reduce_sum(filtered_lstm_hidden_outputs, axis=-1), 0.),
            dialog_cls_projection, dialog_knowledge_cls_projection)

        self.test1 = tf.equal(output_layer, dialog_cls_projection)
        self.test2 = tf.equal(output_layer, dialog_knowledge_cls_projection)
        self.test_sum = tf.cast(self.test1, tf.int32) + tf.cast(
            self.test2, tf.int32)

        logits, loss_op = self._final_output_layer(output_layer)

        return logits, loss_op

    def _bert_pretrained_knowledge(self, knowledge_tokens_ph,
                                   knowledge_lengths_ph):
        # knowledge_phs : [batch, top_n, max_seq_len, 768]
        input_shape = get_shape_list(knowledge_tokens_ph, expected_rank=4)
        batch_size = input_shape[0]
        top_n = input_shape[1]
        knowledge_max_seq_len = input_shape[2]
        embedding_dim = input_shape[3]

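        # merge the batch and top_n axes so each knowledge sequence is encoded
        # independently by the BiLSTM, then restore [batch, top_n, rnn_hidden_dim * 2]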
        print(input_shape)
        knowledge_tokens_embedded = tf.reshape(
            knowledge_tokens_ph,
            shape=[-1, knowledge_max_seq_len, embedding_dim])
        knowledge_lengths = tf.reshape(knowledge_lengths_ph, shape=[-1])

        print(knowledge_tokens_embedded)
        print(knowledge_lengths)
        knowledge_lstm_outputs = self.encoder.lstm_encoder(
            knowledge_tokens_embedded, knowledge_lengths, "knowledge_lstm")
        knowledge_fw, knowledge_bw = sequence_feature(knowledge_lstm_outputs,
                                                      knowledge_lengths,
                                                      sep_pos=True)
        knowledge_concat_features = tf.concat([knowledge_fw, knowledge_bw],
                                              axis=-1)
        knowledge_concat_features = \
          tf.reshape(knowledge_concat_features, shape=[batch_size, top_n, self.hparams.rnn_hidden_dim*2])
        # [batch, top_n, 1536]

        return knowledge_concat_features

    def _final_output_layer(self, final_input_layer):

        dialog_label_ids, knowledge_label_ids = self.label_id_phs
        if self.hparams.loss_type == "sigmoid": logits_units = 1
        else: logits_units = 2

        logits = tf.layers.dense(
            inputs=final_input_layer,
            units=logits_units,
            kernel_initializer=tf.contrib.layers.xavier_initializer(),
            name="logits")

        if self.hparams.loss_type == "sigmoid":
            logits = tf.squeeze(logits, axis=-1)
            loss_op = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=logits,
                labels=tf.cast(dialog_label_ids, tf.float32),
                name="binary_cross_entropy")
        else:
            loss_op = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits, labels=dialog_label_ids, name="cross_entropy")

        loss_op = tf.reduce_mean(loss_op, name="cross_entropy_mean")

        return logits, loss_op
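
The _final_output_layer switch between the sigmoid and softmax heads can be exercised in isolation. Below is a minimal sketch with dummy features and labels (TF 1.x); the batch size, feature width and the loss_type variable are stand-ins for the real hparams:

import tensorflow as tf

loss_type = "sigmoid"  # or "softmax"; stands in for hparams.loss_type
final_input_layer = tf.random_normal([8, 512])             # dummy features, batch=8
dialog_label_ids = tf.constant([0, 1, 1, 0, 1, 0, 0, 1])   # dummy binary labels

logits_units = 1 if loss_type == "sigmoid" else 2
logits = tf.layers.dense(inputs=final_input_layer, units=logits_units,
                         kernel_initializer=tf.contrib.layers.xavier_initializer(),
                         name="logits")

if loss_type == "sigmoid":
    # one logit per example, binary cross-entropy against float labels
    logits = tf.squeeze(logits, axis=-1)
    loss_op = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=logits, labels=tf.cast(dialog_label_ids, tf.float32))
else:
    # two logits per example, sparse softmax cross-entropy against int labels
    loss_op = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=dialog_label_ids)

loss_op = tf.reduce_mean(loss_op)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(loss_op))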
Example 6
class Model(object):
    def __init__(self,
                 hparams,
                 bert_config=None,
                 bert_model: BertModel = None,
                 label_id_phs=None,
                 input_phs=None,
                 length_phs=None,
                 knowledge_phs=None,
                 similar_phs=None):

        self.hparams = hparams
        self.bert_config = bert_config
        self.bert_model = bert_model
        self.label_id_phs = label_id_phs
        self.length_phs = length_phs
        self.knowledge_token_phs = knowledge_phs
        self.similar_phs = similar_phs

        self._get_pretrained_variables()

        self.encoder = BasicEncoder(self.hparams,
                                    self.hparams.dropout_keep_prob)

        # dialog_label_ids, knowledge_label_ids = label_id_phs
        # knowledge_ids_ph, knowledge_mask_ph, knowledge_seg_ids_ph = knowledge_phs
        # dialog_len_ph, response_len_ph, knowledge_len_ph = length_phs
    def _get_pretrained_variables(self):
        self.bert_pretrained_word_embeddings = self.bert_model.embedding_table

    def _bert_sentences_split(self, bert_sequence_output, max_seq_len_a,
                              max_seq_len_b):
        dialog_bert_outputs, response_bert_outputs = tf.split(
            bert_sequence_output, [max_seq_len_a, max_seq_len_b], axis=1)

        return dialog_bert_outputs, response_bert_outputs

    def _similar_dialog_lstm(self, similar_dialog_input_ph,
                             similar_dialog_len_ph):
        """
		:param similar_dialog_input_ph: [batch, top_n, max_sequence_len]
		:param similar_dialog_len_ph: [batch, top_n]
		:return:
		"""
        input_shape = get_shape_list(similar_dialog_input_ph, expected_rank=3)
        batch_size = input_shape[0]
        top_n = input_shape[1]
        max_seq_len = input_shape[2]

        # batch, top_n, max_seq_len, 768
        similar_dialog_embedded = tf.nn.embedding_lookup(
            self.bert_pretrained_word_embeddings, similar_dialog_input_ph)
        # reshape inputs, length
        similar_dialog_embedded = tf.reshape(
            similar_dialog_embedded,
            shape=[-1, max_seq_len, self.hparams.embedding_dim])
        similar_dialog_len = tf.reshape(similar_dialog_len_ph, shape=[-1])
        similar_dialog_out = self.encoder.lstm_encoder(similar_dialog_embedded,
                                                       similar_dialog_len,
                                                       "similar_dialog_lstm")
        similar_dialog_out = tf.reshape(similar_dialog_out,
                                        shape=[
                                            batch_size, top_n, max_seq_len,
                                            self.hparams.rnn_hidden_dim * 2
                                        ])

        return similar_dialog_out

    def build_graph(self):
        # dialog_cls : [batch, 768]
        # knowledge_bilstm_out : [batch, 5, 1536]
        bert_seq_out = self.bert_model.get_sequence_output()
        dialog_bert_outputs, response_bert_outputs = \
          self._bert_sentences_split(bert_seq_out, self.hparams.dialog_max_seq_length, self.hparams.response_max_seq_length)
        # bert_cls_out = self.bert_model.get_pooled_output()

        similar_input_ids_ph, similar_input_mask_ph, similar_len_ph = self.similar_phs
        # batch, top_n, max_seq_out, rnn_hidden_dim * 2
        similar_dialogs_lstm_outputs = self._similar_dialog_lstm(
            similar_input_ids_ph, similar_len_ph)
        unstacked_similar_dialog_lstm_out = tf.unstack(
            similar_dialogs_lstm_outputs, self.hparams.top_n, axis=1)
        unstacked_similar_dialog_len = tf.unstack(similar_len_ph,
                                                  self.hparams.top_n,
                                                  axis=1)

        # response_bert_outputs : batch, 40, 768
        # similar_dialogs_lstm_outputs : batch, top_n, 320, 512
        dialog_len_ph, response_len_ph, _ = self.length_phs  #[batch]
        response_len = response_len_ph - 1
        response_lstm_outputs = self.encoder.lstm_encoder(
            response_bert_outputs, response_len, name="response_lstm")
        response_fw, response_bw = sequence_feature(response_lstm_outputs,
                                                    response_len)
        response_concat = tf.concat([response_fw, response_bw], axis=-1)
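        # match the response against each of the top_n similar dialogs with ESIM attention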
        esim_att_out_l = []
        for each_dialog_out, each_dialog_len in zip(
                unstacked_similar_dialog_lstm_out,
                unstacked_similar_dialog_len):
            # batch, 320, rnn_hidden_dim*2 -> each_dialog_out
            esim_att = ESIMAttention(self.hparams,
                                     self.hparams.dropout_keep_prob,
                                     text_a=response_lstm_outputs,
                                     text_a_len=response_len,
                                     text_b=each_dialog_out,
                                     text_b_len=each_dialog_len)
            # batch, rnn_hidden_dim * 2 : 256 * 2 = 512
            esim_att_out_l.append(esim_att.text_a_att_outs)

        # batch, rnn_hidden_dim * 2 : (total top_n : 3)
        mlp_layers = []
        for each_att_out in esim_att_out_l:
            concat_features = tf.concat([response_concat, each_att_out],
                                        axis=-1)
            layer_input = concat_features
            for i in range(3):
                dense_out = tf.layers.dense(
                    inputs=layer_input,
                    units=768,
                    activation=tf.nn.relu,
                    kernel_initializer=create_initializer(0.02),
                    name="mlp_%d" % i,
                    # share the MLP weights across the top_n ESIM branches; without
                    # reuse the second branch would raise a duplicate-variable error
                    reuse=tf.AUTO_REUSE)
                dense_out = tf.nn.dropout(dense_out,
                                          self.hparams.dropout_keep_prob)
                layer_input = dense_out
            mlp_layers.append(layer_input)
        # element-wise summation
        output_layer = tf.add_n(mlp_layers, "mlp_layers_add_n")

        logits, loss_op = self._final_output_layer(output_layer)

        return logits, loss_op

    def _final_output_layer(self, final_input_layer):

        dialog_label_ids, knowledge_label_ids = self.label_id_phs
        if self.hparams.loss_type == "sigmoid": logits_units = 1
        else: logits_units = 2

        logits = tf.layers.dense(
            inputs=final_input_layer,
            units=logits_units,
            kernel_initializer=tf.contrib.layers.xavier_initializer(),
            name="logits")

        if self.hparams.loss_type == "sigmoid":
            logits = tf.squeeze(logits, axis=-1)
            loss_op = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=logits,
                labels=tf.cast(dialog_label_ids, tf.float32),
                name="binary_cross_entropy")
        else:
            loss_op = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits, labels=dialog_label_ids, name="cross_entropy")

        loss_op = tf.reduce_mean(loss_op, name="cross_entropy_mean")

        return logits, loss_op