# Shared imports for the model `call` methods below.
import tensorflow as tf
from tensorflow_addons.text import crf_log_likelihood, viterbi_decode


# BiLSTM-CRF tagger.
def call(self, input_text, labels=None, training=None):
    # Sequence lengths: count non-padding (non-zero) token ids per example.
    text_lens = tf.math.reduce_sum(
        tf.cast(tf.math.not_equal(input_text, 0), dtype=tf.int32), axis=-1)
    mask = tf.math.logical_not(tf.math.equal(input_text, 0))
    inputs = self.embedding(input_text)
    inputs = self.bi_lstm(inputs, mask=mask)
    inputs = self.dropout(inputs, training=training)
    logits = self.dense(inputs)
    if labels is not None:
        # With gold labels, also return the CRF log-likelihood for the loss.
        label_sequences = tf.convert_to_tensor(labels, dtype=tf.int32)
        log_likelihood, self.transition_params = crf_log_likelihood(
            logits, label_sequences, text_lens, self.transition_params)
        return logits, text_lens, log_likelihood
    else:
        return logits, text_lens
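
# Usage sketch (not from the original code): how a BiLSTM-CRF call like the
# one above is typically driven. The negative mean CRF log-likelihood serves
# as the training loss, and viterbi_decode recovers tag paths at inference.
# `model`, `optimizer`, `batch_tokens`, and `batch_labels` are hypothetical
# placeholders, assuming `model` is an instance of the class this method
# belongs to and exposes `transition_params`.
def train_step_sketch(model, optimizer, batch_tokens, batch_labels):
    with tf.GradientTape() as tape:
        logits, text_lens, log_likelihood = model(
            batch_tokens, labels=batch_labels, training=True)
        loss = -tf.reduce_mean(log_likelihood)  # CRF negative log-likelihood
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss


def predict_sketch(model, batch_tokens):
    logits, text_lens = model(batch_tokens, training=False)
    paths = []
    for logit, text_len in zip(logits, text_lens):
        path, _ = viterbi_decode(logit[:text_len], model.transition_params)
        paths.append(path)
    return paths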

# BERT-CRF tagger.
def call(self, inputs, training=None, mask=None, labels=None):
    # Sequence lengths: count non-padding (non-zero) token ids per example.
    text_lens = tf.math.reduce_sum(
        tf.cast(tf.math.not_equal(inputs, 0), dtype=tf.int32), axis=-1)
    outputs = self.bert_model(inputs, attention_mask=mask)
    outputs = outputs[0]  # last hidden state: (batch, seq_len, hidden)
    out_tags = self.out(outputs)
    if labels is not None:
        label_sequences = tf.convert_to_tensor(labels, dtype=tf.int32)
        log_likelihood, self.transition_params = crf_log_likelihood(
            out_tags, label_sequences, text_lens, self.transition_params)
        return out_tags, text_lens, log_likelihood
    else:
        return out_tags, text_lens
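
# Usage sketch (not from the original code): unlike the BiLSTM variant, this
# BERT call takes an explicit attention_mask instead of deriving one from the
# input ids. A minimal encoding helper, assuming `tokenizer` is a hypothetical
# HuggingFace tokenizer matching `self.bert_model` and `bert_crf` is an
# instance of the class this method belongs to.
def encode_batch_sketch(tokenizer, sentences, max_len=128):
    enc = tokenizer(sentences, padding='max_length', truncation=True,
                    max_length=max_len, return_tensors='tf')
    return enc['input_ids'], enc['attention_mask']

# input_ids, attention_mask = encode_batch_sketch(tokenizer, ['one sentence'])
# out_tags, text_lens, ll = bert_crf(input_ids, training=True,
#                                    mask=attention_mask, labels=batch_labels)
# entity_loss = -tf.reduce_mean(ll)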

# Joint entity-relation model with multi-head selection over token pairs.
def call(self, char_ids, word_ids, entity_ids=None, data_max_len=None,
         training=None, mask=None):
    mask_value = tf.math.logical_not(tf.math.equal(char_ids, 0))
    text_lens = tf.math.reduce_sum(
        tf.cast(tf.math.not_equal(char_ids, 0), dtype=tf.int32), axis=-1)
    # Concatenate character- and word-level embeddings, encode with a BiLSTM.
    char_embed = self.char_embed(char_ids)
    word_embed = self.word_embed(word_ids)
    embed = tf.concat([char_embed, word_embed], axis=-1)
    sent_encoder = self.bi_lstm(embed, mask=mask_value)
    entity_logits = self.entity_classifier(sent_encoder)
    if training:
        # Teacher forcing: embed the gold entity tags and score them with the CRF.
        ent_encoder = self.ent_embed(entity_ids)
        label_sequences = tf.convert_to_tensor(entity_ids, dtype=tf.int32)
        log_likelihood, self.transition_params = crf_log_likelihood(
            entity_logits, label_sequences, text_lens, self.transition_params)
    else:
        # Inference: Viterbi-decode the predicted tag paths and embed those instead.
        v_entity_ids = []
        for logit, text_len in zip(entity_logits, text_lens):
            viterbi_path, _ = viterbi_decode(logit[:text_len], self.transition_params)
            v_entity_ids.append(viterbi_path)
        entity_ids = tf.keras.preprocessing.sequence.pad_sequences(
            v_entity_ids, padding='post', dtype='int32')
        ent_encoder = self.ent_embed(entity_ids)
    rel_encoder = tf.concat((sent_encoder, ent_encoder), axis=-1)
    B, L, H = rel_encoder.shape
    if L is None:  # dynamic sequence axis; fall back to the dataset maximum
        L = data_max_len
    # Multi-head selection: tile projected rows (u) against projected columns (v)
    # to build an (L, L) grid of token-pair representations.
    u = tf.expand_dims(self.selection_u(rel_encoder), axis=1)
    u = tf.keras.activations.relu(tf.tile(u, multiples=(1, L, 1, 1)))
    v = tf.expand_dims(self.selection_v(rel_encoder), axis=2)
    v = tf.keras.activations.relu(tf.tile(v, multiples=(1, 1, L, 1)))
    uv = self.selection_uv(tf.concat((u, v), axis=-1))
    rel_logits = self.rel_classifier(uv)
    if training:
        # Surface the CRF log-likelihood so the caller can build the entity loss.
        return entity_logits, rel_logits, mask_value, log_likelihood
    return entity_logits, rel_logits, mask_value
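
# Usage sketch (not from the original code): the relation head above scores
# every (head token, tail token) pair against every relation type, so
# rel_logits has shape (batch, L, L, num_relations). A hedged decoding sketch,
# assuming `rel_classifier` outputs raw logits and `threshold` is a free
# parameter; it keeps only pairs where both positions are real tokens.
def decode_relations_sketch(rel_logits, mask_value, threshold=0.5):
    probs = tf.sigmoid(rel_logits)                              # (B, L, L, R)
    pair_mask = tf.logical_and(tf.expand_dims(mask_value, 2),   # (B, L, 1)
                               tf.expand_dims(mask_value, 1))   # (B, 1, L)
    probs = probs * tf.cast(tf.expand_dims(pair_mask, -1), probs.dtype)
    # Indices of surviving entries: (batch, head_pos, tail_pos, relation_id).
    return tf.where(probs > threshold).numpy().tolist()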