Example #1
0
    def call(self, input_text, labels=None, training=None):
        """Run the embedding -> BiLSTM -> dropout -> dense tagger, optionally with CRF scoring.

        Token id 0 is treated as padding: it is masked out of the BiLSTM and
        excluded from the per-sequence lengths.

        Args:
            input_text: integer token ids; presumably (batch, seq_len) -- TODO confirm.
            labels: optional gold tag ids aligned with ``input_text``. When given,
                ``crf_log_likelihood`` is also evaluated.
            training: forwarded positionally to ``self.dropout``.

        Returns:
            ``(logits, text_lens)`` when ``labels`` is None, otherwise
            ``(logits, text_lens, log_likelihood)``.
        """
        # Non-padding positions; shared by the LSTM mask and the length count.
        mask = tf.math.not_equal(input_text, 0)
        text_lens = tf.math.reduce_sum(tf.cast(mask, dtype=tf.int32), axis=-1)

        hidden = self.embedding(input_text)
        hidden = self.bi_lstm(hidden, mask=mask)
        hidden = self.dropout(hidden, training)
        logits = self.dense(hidden)

        if labels is None:
            return logits, text_lens

        label_sequences = tf.convert_to_tensor(labels, dtype=tf.int32)
        # The second value returned by crf_log_likelihood is stored back on self.
        log_likelihood, self.transition_params = crf_log_likelihood(
            logits, label_sequences, text_lens, self.transition_params)
        return logits, text_lens, log_likelihood
Example #2
0
    def call(self, inputs, training=None, mask=None, labels=None):
        """Run the BERT encoder and tag projection, optionally with CRF scoring.

        Token id 0 is treated as padding when computing per-sequence lengths.

        Args:
            inputs: integer token ids; presumably (batch, seq_len) -- TODO confirm.
            training: accepted for Keras compatibility; not referenced here.
            mask: attention mask forwarded to ``self.bert_model``. When None, one
                is derived from ``inputs != 0`` (the same padding rule used for
                ``text_lens``).
            labels: optional gold tag ids aligned with ``inputs``. When given,
                ``crf_log_likelihood`` is also evaluated.

        Returns:
            ``(out_tags, text_lens)`` when ``labels`` is None, otherwise
            ``(out_tags, text_lens, log_likelihood)``.
        """
        not_pad = tf.math.not_equal(inputs, 0)
        text_lens = tf.math.reduce_sum(tf.cast(not_pad, dtype=tf.int32), axis=-1)

        if mask is None:
            # Without a mask the encoder presumably attends to every position,
            # padding included (HF-style BERT default) -- derive one from the
            # padding rule so hidden states of real tokens ignore padding.
            mask = tf.cast(not_pad, dtype=tf.int32)

        outputs = self.bert_model(inputs, attention_mask=mask)
        # First element of the encoder output; presumably the per-token hidden states.
        sequence_output = outputs[0]

        out_tags = self.out(sequence_output)

        if labels is None:
            return out_tags, text_lens

        label_sequences = tf.convert_to_tensor(labels, dtype=tf.int32)
        # The second value returned by crf_log_likelihood is stored back on self.
        log_likelihood, self.transition_params = crf_log_likelihood(
            out_tags, label_sequences, text_lens, self.transition_params)
        return out_tags, text_lens, log_likelihood
Example #3
0
    def call(self, char_ids, word_ids, entity_ids=None, data_max_len=None, training=None, mask=None):
        """Score entity tags and all position pairs for relations.

        Char and word embeddings are concatenated and encoded by ``self.bi_lstm``,
        with id 0 in ``char_ids`` treated as padding. Entity logits come from
        ``self.entity_classifier``. The entity ids embedded for the relation
        branch depend on ``training``:

        * truthy: the gold ``entity_ids``;
        * falsy: per-sequence Viterbi paths decoded with ``self.transition_params``.

        Every (i, j) position pair is then scored by ``self.rel_classifier``.

        Args:
            char_ids: integer char ids; presumably (batch, seq_len) -- TODO confirm.
            word_ids: integer word ids aligned with ``char_ids``.
            entity_ids: gold entity tag ids; required when ``training`` is truthy.
            data_max_len: sequence length used when the static length of the
                encoder output is unknown.
            training: selects gold (truthy) vs. Viterbi-decoded (falsy) entity ids.
            mask: unused; the padding mask is always derived from ``char_ids``.

        Returns:
            ``(entity_logits, rel_logits, mask_value)``, where ``mask_value`` is
            the boolean ``char_ids != 0`` mask.
        """
        mask_value = tf.math.not_equal(char_ids, 0)
        text_lens = tf.math.reduce_sum(tf.cast(mask_value, dtype=tf.int32), axis=-1)

        char_embed = self.char_embed(char_ids)
        word_embed = self.word_embed(word_ids)
        embed = tf.concat([char_embed, word_embed], axis=-1)
        sent_encoder = self.bi_lstm(embed, mask=mask_value)
        entity_logits = self.entity_classifier(sent_encoder)

        if training:
            ent_encoder = self.ent_embed(entity_ids)
            label_sequences = tf.convert_to_tensor(entity_ids, dtype=tf.int32)
            # The log-likelihood is not returned from this method. The call is
            # kept for its second result, which is stored back on self.
            _, self.transition_params = crf_log_likelihood(
                entity_logits, label_sequences, text_lens, self.transition_params)
        else:
            # Eager-only: walks the batch in Python and decodes each sequence
            # up to its own non-padding length.
            v_entity_ids = []
            for logit, text_len in zip(entity_logits, text_lens):
                viterbi_path, _ = viterbi_decode(logit[:text_len], self.transition_params)
                v_entity_ids.append(viterbi_path)
            # Pad to the full input length, not just the longest decoded path.
            # Otherwise ent_encoder can be shorter than sent_encoder and the
            # concat below fails (e.g. every row padded, or all lengths 0).
            seq_len = int(tf.shape(char_ids)[1])
            entity_ids = tf.keras.preprocessing.sequence.pad_sequences(
                v_entity_ids, maxlen=seq_len, padding='post', dtype='int32')
            ent_encoder = self.ent_embed(entity_ids)

        rel_encoder = tf.concat((sent_encoder, ent_encoder), axis=-1)
        # Unpacking also enforces that rel_encoder is rank 3.
        _, L, _ = rel_encoder.shape
        if L is None:
            # Fall back to the dynamic length when no explicit length is given.
            L = data_max_len if data_max_len is not None else tf.shape(rel_encoder)[1]

        # Pair every position with every other. After tiling, u[:, i, j]
        # holds the projection of position j and v[:, i, j] that of
        # position i (assuming selection_u/v keep the leading two axes).
        u = tf.expand_dims(self.selection_u(rel_encoder), axis=1)
        u = tf.keras.activations.relu(tf.tile(u, multiples=(1, L, 1, 1)))
        v = tf.expand_dims(self.selection_v(rel_encoder), axis=2)
        v = tf.keras.activations.relu(tf.tile(v, multiples=(1, 1, L, 1)))
        uv = self.selection_uv(tf.concat((u, v), axis=-1))
        rel_logits = self.rel_classifier(uv)
        return entity_logits, rel_logits, mask_value