Ejemplo n.º 1
0
    def __call__(self,
                 input_ids,
                 start_labels,
                 end_labels,
                 token_type_ids_list,
                 query_len_list,
                 text_length_list,
                 has_answer_label,
                 is_training,
                 is_testing=False):
        """Build the BERT span-extraction graph with an extra "has answer" head.

        Two dense heads over the BERT sequence output produce per-token start
        and end logits; a single-unit dense head over the pooled output
        produces the has-answer logit.

        Args:
            input_ids: Token ids passed to ``modeling.BertModel``.
            start_labels: Start-position labels forwarded to ``focal_loss``.
            end_labels: End-position labels forwarded to ``focal_loss``.
            token_type_ids_list: Segment ids passed to ``modeling.BertModel``.
            query_len_list: Per-example query lengths; the first ``query_len``
                positions of each row are excluded by ``final_mask``.
            text_length_list: Per-example total lengths (also passed to
                ``modeling.BertModel``); positions at or beyond this length
                are excluded by ``final_mask``.
            has_answer_label: Per-example answerability label; cast to
                float32 and reshaped to ``[-1, 1]`` for the sigmoid loss.
            is_training: Forwarded to ``modeling.BertModel``.
            is_testing: When True, no loss ops are built and ``final_loss`` is
                omitted from the returned tuple.

        Returns:
            ``(final_loss, predict_start_ids, predict_end_ids, final_mask,
            predict_start_prob, predict_end_prob, predict_has_answer_probs)``
            when ``is_testing`` is False, otherwise the same tuple without
            ``final_loss``.
        """
        bert_model = modeling.BertModel(config=self.bert_config,
                                        is_training=is_training,
                                        input_ids=input_ids,
                                        text_length=text_length_list,
                                        token_type_ids=token_type_ids_list,
                                        use_one_hot_embeddings=False)
        bert_seq_output = bert_model.get_sequence_output()
        first_seq_hidden = bert_model.get_pooled_output()

        # Unnamed tf.layers.dense calls are named by creation order, so keep
        # start -> end -> has_answer to stay compatible with checkpoints.
        start_logits = tf.layers.dense(bert_seq_output, self.num_labels)
        end_logits = tf.layers.dense(bert_seq_output, self.num_labels)
        has_answer_logits = tf.layers.dense(first_seq_hidden, 1)

        # final_mask is consumed together with start_logits/end_logits by
        # focal_loss, so size it to their sequence axis. Without maxlen,
        # sequence_mask uses max(lengths) as its width, which is narrower than
        # the logits whenever the batch is padded past its longest text.
        seq_len = tf.shape(start_logits)[1]
        query_span_mask = tf.sequence_mask(query_len_list,
                                           maxlen=seq_len,
                                           dtype=tf.int32)
        total_seq_mask = tf.sequence_mask(text_length_list,
                                          maxlen=seq_len,
                                          dtype=tf.int32)
        # 1 on text positions after the query, 0 on the query and on padding
        # (query positions beyond text_length would come out as -1, exactly
        # as before).
        final_mask = total_seq_mask - query_span_mask

        predict_start_ids = tf.argmax(start_logits,
                                      axis=-1,
                                      name="pred_start_ids")
        predict_end_ids = tf.argmax(end_logits, axis=-1, name="pred_end_ids")
        predict_start_prob = tf.nn.softmax(start_logits, axis=-1)
        predict_end_prob = tf.nn.softmax(end_logits, axis=-1)
        predict_has_answer_probs = tf.nn.sigmoid(has_answer_logits)

        if is_testing:
            return (predict_start_ids, predict_end_ids, final_mask,
                    predict_start_prob, predict_end_prob,
                    predict_has_answer_probs)

        # 1.8 is focal_loss's sixth positional argument (presumably the
        # focusing parameter gamma -- confirm against focal_loss).
        start_loss = focal_loss(start_logits, start_labels, final_mask,
                                self.num_labels, True, 1.8)
        end_loss = focal_loss(end_logits, end_labels, final_mask,
                              self.num_labels, True, 1.8)

        # sigmoid_cross_entropy_with_logits requires labels and logits of the
        # same shape. has_answer_logits comes from a 1-unit dense layer
        # (assumed [batch, 1], i.e. the pooled output is [batch, hidden]), so
        # reshape the labels to [-1, 1]; a no-op if they already are [batch, 1].
        has_answer_label = tf.reshape(tf.cast(has_answer_label, tf.float32),
                                      [-1, 1])
        per_example_loss = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=has_answer_label, logits=has_answer_logits)
        has_answer_loss = tf.reduce_mean(per_example_loss)

        # NOTE(review): the weights sum to 3.5 while the divisor is 3.0; kept
        # unchanged to preserve the existing loss scale.
        final_loss = (1.5 * start_loss + end_loss + has_answer_loss) / 3.0
        return (final_loss, predict_start_ids, predict_end_ids, final_mask,
                predict_start_prob, predict_end_prob, predict_has_answer_probs)
Ejemplo n.º 2
0
    def __call__(self,
                 input_ids,
                 start_labels,
                 end_labels,
                 token_type_ids_list,
                 query_len_list,
                 text_length_list,
                 is_training,
                 is_testing=False):
        """Build the BERT span-extraction graph (start/end position heads).

        Two dense heads over the BERT sequence output produce per-token start
        and end logits; the loss is the sum of two ``focal_loss`` terms.

        Args:
            input_ids: Token ids passed to ``modeling.BertModel``.
            start_labels: Start-position labels forwarded to ``focal_loss``.
            end_labels: End-position labels forwarded to ``focal_loss``.
            token_type_ids_list: Segment ids passed to ``modeling.BertModel``.
            query_len_list: Per-example query lengths; the first ``query_len``
                positions of each row are excluded by ``final_mask``.
            text_length_list: Per-example total lengths (also passed to
                ``modeling.BertModel``); positions at or beyond this length
                are excluded by ``final_mask``.
            is_training: Forwarded to ``modeling.BertModel``.
            is_testing: When True, no loss ops are built and ``final_loss`` is
                omitted from the returned tuple.

        Returns:
            ``(final_loss, predict_start_ids, predict_end_ids, final_mask)``
            when ``is_testing`` is False, otherwise
            ``(predict_start_ids, predict_end_ids, final_mask)``.
        """
        bert_model = modeling.BertModel(config=self.bert_config,
                                        is_training=is_training,
                                        input_ids=input_ids,
                                        text_length=text_length_list,
                                        token_type_ids=token_type_ids_list,
                                        use_one_hot_embeddings=False)
        bert_seq_output = bert_model.get_sequence_output()

        # Unnamed tf.layers.dense calls are named by creation order, so keep
        # start -> end to stay compatible with checkpoints.
        start_logits = tf.layers.dense(bert_seq_output, self.num_labels)
        end_logits = tf.layers.dense(bert_seq_output, self.num_labels)

        # final_mask is consumed together with start_logits/end_logits by
        # focal_loss, so size it to their sequence axis. Without maxlen,
        # sequence_mask uses max(lengths) as its width, which is narrower than
        # the logits whenever the batch is padded past its longest text.
        seq_len = tf.shape(start_logits)[1]
        query_span_mask = tf.sequence_mask(query_len_list,
                                           maxlen=seq_len,
                                           dtype=tf.int32)
        total_seq_mask = tf.sequence_mask(text_length_list,
                                          maxlen=seq_len,
                                          dtype=tf.int32)
        # 1 on text positions after the query, 0 on the query and on padding
        # (query positions beyond text_length would come out as -1, exactly
        # as before).
        final_mask = total_seq_mask - query_span_mask

        predict_start_ids = tf.argmax(start_logits,
                                      axis=-1,
                                      name="pred_start_ids")
        predict_end_ids = tf.argmax(end_logits, axis=-1, name="pred_end_ids")

        if is_testing:
            return predict_start_ids, predict_end_ids, final_mask

        # focal_loss is called with its own default for any further
        # (e.g. focusing) parameter -- see focal_loss for the value.
        start_loss = focal_loss(start_logits, start_labels, final_mask,
                                self.num_labels, True)
        end_loss = focal_loss(end_logits, end_labels, final_mask,
                              self.num_labels, True)

        final_loss = start_loss + end_loss
        return final_loss, predict_start_ids, predict_end_ids, final_mask
Ejemplo n.º 3
0
    def _build_loss(self, **kwargs):
        """Return the combined confidence + regression training loss.

        The total is ``focal_loss(logits, y) + r_alpha * smooth_l1_loss(logits, y)``,
        with all ops built under the ``losses`` variable scope. The two
        individual terms are also stored on ``self.conf_loss`` and
        ``self.regress_loss`` for debugging.

        Keyword Args:
            r_alpha: Weight applied to the regression term (default 1). Any
                other keyword arguments are accepted and ignored.
        """
        # kwargs is a fresh dict on every call, so reading instead of popping
        # is equivalent.
        regress_weight = kwargs.get('r_alpha', 1)
        with tf.variable_scope('losses'):
            classification_term = focal_loss(self.logits, self.y)
            regression_term = smooth_l1_loss(self.logits, self.y)
            combined = classification_term + regress_weight * regression_term

        # Expose the individual terms so they can be inspected separately.
        self.conf_loss, self.regress_loss = classification_term, regression_term
        return combined