def batch_gather(self, emb, indices):
    """Gather row-local positions from every row of a batched tensor.

    ``emb`` is viewed as one long table by merging its first two axes, so a
    single ``tf.gather`` fetches ``emb[b, indices[b, i]]`` for each row ``b``.

    Args:
      emb: expected [batch_size, seqlen] or [batch_size, seqlen, emb].
      indices: [batch_size, num_indices] positions along axis 1 of ``emb``.

    Returns:
      [batch_size, num_indices, emb] for a rank-3 ``emb``, or
      [batch_size, num_indices] for a rank-2 ``emb``.
    """
    rank = len(emb.get_shape())
    rows = util_tf2.shape(emb, 0)
    width = util_tf2.shape(emb, 1)
    feature_dim = util_tf2.shape(emb, 2) if rank > 2 else 1
    table = tf.reshape(emb, [rows * width, feature_dim])  # [batch_size * seqlen, emb]
    # Row b starts at flat position b * seqlen.
    row_base = tf.expand_dims(tf.range(rows) * width, 1)  # [batch_size, 1]
    picked = tf.gather(table, indices + row_base)  # [batch_size, num_indices, emb]
    if rank == 2:
        # Remove the size-1 feature axis introduced by the reshape above.
        picked = tf.squeeze(picked, 2)  # [batch_size, num_indices]
    return picked
def cnn(self, inputs, filter_sizes, num_filters):
    """Apply one conv1d per filter width, max-pool over axis 1, and concatenate.

    For each width in ``filter_sizes`` the method does the following:
      * creates variables ``w`` and ``b`` under variable scope "conv_{i}";
      * runs a stride-1, VALID-padded ``tf.nn.conv1d`` with ``num_filters``
        output channels;
      * adds the bias and applies ReLU;
      * takes the maximum over axis 1.

    Args:
      inputs: [num_words, num_chars, input_size] tensor.
      filter_sizes: iterable of convolution widths.
      num_filters: number of output channels of each convolution.

    Returns:
      [num_words, num_filters * len(filter_sizes)] tensor.
    """
    input_size = util_tf2.shape(inputs, 2)
    pooled_outputs = []
    for i, filter_size in enumerate(filter_sizes):
        with tf.compat.v1.variable_scope("conv_{}".format(i)):
            w = tf.compat.v1.get_variable("w", [filter_size, input_size, num_filters])
            b = tf.compat.v1.get_variable("b", [num_filters])
        # VALID padding with stride 1 leaves num_chars - filter_size + 1 positions.
        conv = tf.nn.conv1d(input=inputs, filters=w, stride=1, padding="VALID")  # [num_words, num_chars - filter_size + 1, num_filters]
        h = tf.nn.relu(tf.nn.bias_add(conv, b))  # [num_words, num_chars - filter_size + 1, num_filters]
        pooled_outputs.append(tf.reduce_max(input_tensor=h, axis=1))  # [num_words, num_filters]
    return tf.concat(pooled_outputs, 1)  # [num_words, num_filters * len(filter_sizes)]
def flatten_emb_by_sentence(self, emb, text_len_mask):
    """Merge the sentence and position axes and drop masked-out positions.

    Args:
      emb: [num_sentences, max_sentence_length] or
        [num_sentences, max_sentence_length, emb] tensor.
      text_len_mask: boolean [num_sentences, max_sentence_length] mask.

    Returns:
      [num_words] or [num_words, emb], where num_words is the number of True
      entries in ``text_len_mask``.

    Raises:
      ValueError: if ``emb`` has a rank other than 2 or 3.
    """
    dims = tf.shape(input=emb)
    rank = len(emb.get_shape())
    if rank not in (2, 3):
        raise ValueError("Unsupported rank: {}".format(rank))
    positions = dims[0] * dims[1]
    if rank == 3:
        flat = tf.reshape(emb, [positions, util_tf2.shape(emb, 2)])
    else:
        flat = tf.reshape(emb, [positions])
    keep = tf.reshape(text_len_mask, [positions])
    return tf.boolean_mask(tensor=flat, mask=keep)
def distance_pruning(self, top_span_emb, top_span_mention_scores, c):
    """Pair each span with the ``c`` spans immediately before it.

    Span i gets candidates i-1, ..., i-c. Candidates before span 0 are
    handled in three ways:
      * they are marked invalid in the mask;
      * their index is clamped to 0 so the gather stays in range;
      * their score receives log(0) = -inf.

    Args:
      top_span_emb: [k, emb]; only its first dimension (k) is read.
      top_span_mention_scores: [k] mention scores.
      c: number of antecedent candidates per span.

    Returns:
      Tuple of [k, c] tensors:
      (antecedent indices, validity mask, pairwise scores, offsets).
    """
    k = util_tf2.shape(top_span_emb, 0)
    offsets = tf.tile(tf.expand_dims(tf.range(c) + 1, 0), [k, 1])  # [k, c]; each row is 1..c
    raw_ids = tf.expand_dims(tf.range(k), 1) - offsets  # [k, c]
    valid = raw_ids >= 0  # [k, c]
    safe_ids = tf.maximum(raw_ids, 0)  # [k, c]
    own = tf.expand_dims(top_span_mention_scores, 1)  # [k, 1]
    other = tf.gather(top_span_mention_scores, safe_ids)  # [k, c]
    pair_scores = own + other
    pair_scores += tf.math.log(tf.cast(valid, dtype=tf.float32))  # -inf where invalid
    return safe_ids, valid, pair_scores, offsets
def lstm_contextualize(self, text_emb, text_len, text_len_mask):
    """Run stacked bidirectional LSTM layers over each sentence, then flatten.

    Each layer is built under variable scope "layer_{i}" as follows:
      * a forward and a backward util_tf2.CustomLSTMCell are created (under
        "fw_cell" / "bw_cell"), each starting from its initial state tiled
        to one row per sentence;
      * forward and backward outputs are concatenated on axis 2;
      * dropout is applied with rate ``1 - self.lstm_dropout``, so
        ``self.lstm_dropout`` appears to be a keep probability;
      * from the second layer on, a sigmoid highway gate mixes the new
        output with the previous layer's output.

    Args:
      text_emb: [num_sentences, max_sentence_length, emb] input embeddings.
      text_len: [num_sentences] sequence lengths for the RNN.
      text_len_mask: [num_sentences, max_sentence_length] mask of real tokens.

    Returns:
      [num_words, emb] last-layer outputs with padded positions removed via
      ``self.flatten_emb_by_sentence``. Requires at least one layer in
      config["contextualization_layers"]; otherwise the final output name is
      never bound.
    """
    batch = tf.shape(input=text_emb)[0]

    def tiled_start_state(cell):
        # Repeat the cell's (c, h) initial state once per sentence.
        return tf.compat.v1.nn.rnn_cell.LSTMStateTuple(
            tf.tile(cell.initial_state.c, [batch, 1]),
            tf.tile(cell.initial_state.h, [batch, 1]))

    layer_input = text_emb  # [num_sentences, max_sentence_length, emb]
    for depth in range(self.config["contextualization_layers"]):
        with tf.compat.v1.variable_scope("layer_{}".format(depth)):
            with tf.compat.v1.variable_scope("fw_cell"):
                forward = util_tf2.CustomLSTMCell(self.config["contextualization_size"], batch, self.lstm_dropout)
            with tf.compat.v1.variable_scope("bw_cell"):
                backward = util_tf2.CustomLSTMCell(self.config["contextualization_size"], batch, self.lstm_dropout)
            (fw_out, bw_out), _ = tf.compat.v1.nn.bidirectional_dynamic_rnn(
                cell_fw=forward,
                cell_bw=backward,
                inputs=layer_input,
                sequence_length=text_len,
                initial_state_fw=tiled_start_state(forward),
                initial_state_bw=tiled_start_state(backward))
            layer_output = tf.concat([fw_out, bw_out], 2)  # [num_sentences, max_sentence_length, emb]
            layer_output = tf.nn.dropout(layer_output, 1 - (self.lstm_dropout))
            if depth > 0:
                # Highway connection; the previous layer's output has the same width.
                gate = tf.sigmoid(util_tf2.projection(layer_output, util_tf2.shape(layer_output, 2)))
                layer_output = gate * layer_output + (1 - gate) * layer_input
            layer_input = layer_output
    return self.flatten_emb_by_sentence(layer_output, text_len_mask)
def get_span_emb(self, head_emb, context_outputs, span_starts, span_ends):
    """Build one vector per candidate span by concatenating feature pieces.

    Pieces, in concatenation order along axis 1:
      * contextual vector at the span start;
      * contextual vector at the span end;
      * if config["use_features"]: a learned span-width embedding (variable
        "span_width_embeddings"), with dropout rate ``1 - self.dropout``;
      * if config["model_heads"]: a softmax-attention-weighted sum of
        ``head_emb`` over the words inside the span. The attention logits
        come from a projection of ``context_outputs`` under scope
        "head_scores" and are stored on ``self.head_scores``.

    Args:
      head_emb: [num_words, emb] embeddings averaged by the head attention.
      context_outputs: [num_words, emb] contextualized word vectors.
      span_starts: [num_candidates] inclusive start word index.
      span_ends: [num_candidates] inclusive end word index.

    Returns:
      [num_candidates, emb] span embeddings.
    """
    pieces = [
        tf.gather(context_outputs, span_starts),  # [num_candidates, emb]
        tf.gather(context_outputs, span_ends),  # [num_candidates, emb]
    ]
    widths = 1 + span_ends - span_starts  # [num_candidates], inclusive span length

    if self.config["use_features"]:
        width_table = tf.compat.v1.get_variable(
            "span_width_embeddings",
            [self.config["max_span_width"], self.config["feature_size"]])
        width_vec = tf.gather(width_table, widths - 1)  # [num_candidates, feature_size]
        pieces.append(tf.nn.dropout(width_vec, 1 - (self.dropout)))

    if self.config["model_heads"]:
        window = self.config["max_span_width"]
        offsets = tf.expand_dims(tf.range(window), 0)  # [1, max_span_width]
        # Positions start .. start + window - 1, clipped to the last word.
        positions = tf.minimum(util_tf2.shape(context_outputs, 0) - 1,
                               offsets + tf.expand_dims(span_starts, 1))  # [num_candidates, max_span_width]
        window_emb = tf.gather(head_emb, positions)  # [num_candidates, max_span_width, emb]
        with tf.compat.v1.variable_scope("head_scores"):
            self.head_scores = util_tf2.projection(context_outputs, 1)  # [num_words, 1]
        logits = tf.gather(self.head_scores, positions)  # [num_candidates, max_span_width, 1]
        # log(0) = -inf excludes positions past the span end from the softmax.
        inside = tf.expand_dims(tf.sequence_mask(widths, window, dtype=tf.float32), 2)
        logits += tf.math.log(inside)
        weights = tf.nn.softmax(logits, 1)  # [num_candidates, max_span_width, 1]
        pieces.append(tf.reduce_sum(input_tensor=weights * window_emb, axis=1))  # [num_candidates, emb]

    return tf.concat(pieces, 1)  # [num_candidates, emb]
def coarse_to_fine_pruning(self, top_span_emb, top_span_mention_scores, c):
    """Keep the ``c`` best-scoring earlier spans for each span.

    A pair (i, j) is scored as mention_i + mention_j + fast_score(i, j). Its
    score gets log(0) = -inf unless j precedes i. ``tf.nn.top_k`` then keeps
    ``c`` candidates per row, without sorting them.

    Args:
      top_span_emb: [k, emb] span embeddings.
      top_span_mention_scores: [k] mention scores.
      c: number of antecedents kept per span (``tf.nn.top_k`` needs c <= k).

    Returns:
      Tuple of [k, c] tensors: (antecedent indices, validity mask,
      pairwise scores, offsets i - j).
    """
    k = util_tf2.shape(top_span_emb, 0)
    ids = tf.range(k)  # [k]
    offsets = tf.expand_dims(ids, 1) - tf.expand_dims(ids, 0)  # [k, k]; offsets[i, j] = i - j
    earlier = offsets >= 1  # [k, k]
    scores = tf.expand_dims(top_span_mention_scores, 1) + tf.expand_dims(top_span_mention_scores, 0)  # [k, k]
    scores += tf.math.log(tf.cast(earlier, dtype=tf.float32))
    scores += self.get_fast_antecedent_scores(top_span_emb)
    _, chosen = tf.nn.top_k(scores, c, sorted=False)  # [k, c]
    return (
        chosen,
        self.batch_gather(earlier, chosen),
        self.batch_gather(scores, chosen),
        self.batch_gather(offsets, chosen),
    )
def get_fast_antecedent_scores(self, top_span_emb):
    """Compute pairwise span scores as a dot product of two dropped-out views.

    The source view is ``util_tf2.projection`` of the embeddings, created
    under variable scope "src_projection". The target view is the raw
    embeddings. Both use dropout rate ``1 - self.dropout``.

    Args:
      top_span_emb: [k, emb] span embeddings.

    Returns:
      [k, k] matrix whose (i, j) entry is source_i . target_j.
    """
    with tf.compat.v1.variable_scope("src_projection"):
        projected = util_tf2.projection(top_span_emb, util_tf2.shape(top_span_emb, -1))
        source = tf.nn.dropout(projected, 1 - (self.dropout))  # [k, emb]
    target = tf.nn.dropout(top_span_emb, 1 - (self.dropout))  # [k, emb]
    return tf.matmul(source, target, transpose_b=True)  # [k, k]
def get_predictions_and_loss(self, tokens, context_word_emb, head_word_emb, lm_emb, char_index, \
                             text_len, is_training, entity_starts, entity_ends, entity_labels):
    """Build the span-classification graph and return entity scores and loss.

    Steps visible in this method:
      1. Set self.dropout / self.lexical_dropout / self.lstm_dropout from the
         configured rates via self.get_dropout.
      2. Build word representations. Both the context and head lists get
         word embeddings, plus a character-CNN embedding when
         config["char_embedding_size"] > 0. The context side alone also gets
         a softmax-weighted, learned-scale mix of the lm_emb layers.
      3. Contextualize with self.lstm_contextualize (one row per real word).
      4. Enumerate every span of up to self.max_span_width words that ends
         inside the document and does not cross a sentence boundary.
      5. Embed and score those spans, then keep
         floor(num_words * config["top_span_ratio"]) of them via
         coref_ops.extract_spans.
      6. Score the kept spans and compute the entity loss.

    Args:
      tokens: not referenced in this method body.
      context_word_emb: [num_sentences, max_sentence_length, emb] word
        embeddings feeding the LSTM.
      head_word_emb: [num_sentences, max_sentence_length, emb] word
        embeddings used by the span head attention.
      lm_emb: [num_sentences, max_sentence_length, lm_size, lm_layers]
        precomputed language-model features.
      char_index: character ids. After the embedding lookup the result is
        [num_sentences, max_sentence_length, max_word_length, char_emb_size].
      text_len: [num_sentences] number of real tokens per sentence.
      is_training: forwarded to self.get_dropout.
      entity_starts, entity_ends, entity_labels: gold annotations forwarded
        to self.get_entity_labels (presumably word indices and label ids —
        TODO confirm against get_entity_labels).

    Returns:
      ([self.entity_scores, self.entity_labels_mask], entity_loss). The two
      list entries are also stored on self, as are self.lm_weights and
      self.lm_scaling.
    """
    # NOTE(review): 1 - self.*dropout is later passed to tf.nn.dropout as the
    # rate, so get_dropout presumably returns a keep probability — confirm.
    self.dropout = self.get_dropout(self.config["dropout_rate"], is_training)
    self.lexical_dropout = self.get_dropout(self.config["lexical_dropout_rate"], is_training)
    self.lstm_dropout = self.get_dropout(self.config["lstm_dropout_rate"], is_training)
    num_sentences = tf.shape(input=context_word_emb)[0]
    max_sentence_length = tf.shape(input=context_word_emb)[1]
    # embeddings
    # word embedding + char embedding + LM (ELMo-style) embedding
    context_emb_list = [context_word_emb]
    head_emb_list = [head_word_emb]
    # character embedding: look up char vectors, run the CNN per word, and
    # add the result to both the context and head embeddings
    if self.config["char_embedding_size"] > 0:
        char_emb = tf.gather(tf.compat.v1.get_variable("char_embeddings", [len(self.char_dict), self.config["char_embedding_size"]]), \
            char_index)  # [num_sentences, max_sentence_length, max_word_length, char_embedding_size]
        flattened_char_emb = tf.reshape(char_emb, [num_sentences * max_sentence_length, util_tf2.shape(char_emb, 2), \
            util_tf2.shape(char_emb, 3)])  # [num_sentences * max_sentence_length, max_word_length, char_embedding_size]
        flattened_aggregated_char_emb = self.cnn(flattened_char_emb, self.config["filter_widths"], self.config["filter_size"])  # [num_sentences * max_sentence_length, emb]
        aggregated_char_emb = tf.reshape(flattened_aggregated_char_emb, [num_sentences, max_sentence_length, \
            util_tf2.shape(flattened_aggregated_char_emb, 1)])  # [num_sentences, max_sentence_length, emb]
        context_emb_list.append(aggregated_char_emb)
        head_emb_list.append(aggregated_char_emb)
    # LM embedding: softmax-weighted sum over the lm_layers axis, times a
    # learned scalar; added to the context embeddings only
    # lm_emb: [num_sentence, max_sentence_length, lm_size, lm_layers]
    lm_emb_size = util_tf2.shape(lm_emb, 2)
    lm_num_layers = util_tf2.shape(lm_emb, 3)
    with tf.compat.v1.variable_scope("lm_aggregation"):
        self.lm_weights = tf.nn.softmax(tf.compat.v1.get_variable("lm_scores", [lm_num_layers], initializer=tf.compat.v1.constant_initializer(0.0)))
        self.lm_scaling = tf.compat.v1.get_variable("lm_scaling", [], initializer=tf.compat.v1.constant_initializer(1.0))
    flattened_lm_emb = tf.reshape(lm_emb, [num_sentences * max_sentence_length * lm_emb_size, lm_num_layers])  # [num_sentences * max_sentence_length * lm_emb_size, lm_num_layers]
    flattened_aggregated_lm_emb = tf.matmul(flattened_lm_emb, tf.expand_dims(self.lm_weights, 1))  # [num_sentences * max_sentence_length * lm_emb_size, 1]
    aggregated_lm_emb = tf.reshape(flattened_aggregated_lm_emb, [num_sentences, max_sentence_length, lm_emb_size])
    aggregated_lm_emb *= self.lm_scaling
    context_emb_list.append(aggregated_lm_emb)
    # concatenate embeddings along the feature axis
    context_emb = tf.concat(context_emb_list, 2)  # [num_sentences, max_sentence_length, emb]
    head_emb = tf.concat(head_emb_list, 2)  # [num_sentences, max_sentence_length, emb]
    # lexical dropout (rate = 1 - self.lexical_dropout)
    context_emb = tf.nn.dropout(context_emb, 1 - (self.lexical_dropout))  # [num_sentences, max_sentence_length, emb]
    head_emb = tf.nn.dropout(head_emb, 1 - (self.lexical_dropout))  # [num_sentences, max_sentence_length, emb]
    # embedding part done
    # tf.sequence_mask: for lengths t[d_1, ..., d_n],
    # mask[i_1, ..., i_n, j] = (j < t[i_1, ..., i_n]), i.e. True on real tokens
    text_len_mask = tf.sequence_mask(text_len, maxlen=max_sentence_length)  # [num_sentences, max_sentence_length]
    # bi-directional LSTM; output is flattened to one row per real word
    context_outputs = self.lstm_contextualize(context_emb, text_len, text_len_mask)  # [num_words, emb]
    num_words = util_tf2.shape(context_outputs, 0)
    # enumerate candidate spans
    sentence_indices = tf.tile(tf.expand_dims(tf.range(num_sentences), 1), [1, max_sentence_length])  # [num_sentences, max_sentence_length]
    flattened_sentence_indices = self.flatten_emb_by_sentence(sentence_indices, text_len_mask)  # [num_words], sentence id of each word
    flattened_head_emb = self.flatten_emb_by_sentence(head_emb, text_len_mask)  # [num_words, emb]
    # NOTE(review): uses self.max_span_width here while get_span_emb reads
    # self.config["max_span_width"]; presumably the same value — confirm.
    candidate_starts = tf.tile(tf.expand_dims(tf.range(num_words), 1), [1, self.max_span_width])  # [num_words, max_span_width]
    candidate_ends = candidate_starts + tf.expand_dims(tf.range(self.max_span_width), 0)  # [num_words, max_span_width], inclusive ends
    candidate_start_sentence_indices = tf.gather(flattened_sentence_indices, candidate_starts)  # [num_words, max_span_width]
    # clamp so the gather stays in range; out-of-range ends are masked below
    candidate_end_sentence_indices = tf.gather(flattened_sentence_indices, tf.minimum(candidate_ends, num_words - 1))  # [num_words, max_span_width]
    # candidate spans must end inside the document and come from one sentence
    candidate_mask = tf.logical_and(candidate_ends < num_words, tf.equal(candidate_start_sentence_indices, candidate_end_sentence_indices))  # [num_words, max_span_width]
    flattened_candidate_mask = tf.reshape(candidate_mask, [-1])  # [num_words * max_span_width]
    candidate_starts = tf.boolean_mask(tensor=tf.reshape(candidate_starts, [-1]), mask=flattened_candidate_mask)  # [num_candidates]
    candidate_ends = tf.boolean_mask(tensor=tf.reshape(candidate_ends, [-1]), mask=flattened_candidate_mask)  # [num_candidates]
    candidate_sentence_indices = tf.boolean_mask(tensor=tf.reshape(candidate_start_sentence_indices, [-1]), mask=flattened_candidate_mask)  # [num_candidates]
    # gold labels for each candidate span
    candidate_entity_labels = self.get_entity_labels(candidate_starts, candidate_ends, entity_starts, entity_ends, entity_labels)  # [num_candidates]
    candidate_span_emb = self.get_span_emb(flattened_head_emb, context_outputs, candidate_starts, candidate_ends)  # [num_candidates, emb]
    candidate_mention_scores = self.get_mention_scores(candidate_span_emb)  # [num_candidates, 1]
    candidate_mention_scores = tf.squeeze(candidate_mention_scores, 1)  # [num_candidates]
    # prune spans: keep k = floor(num_words * top_span_ratio) of them
    k = tf.cast(tf.floor(tf.cast(tf.shape(input=context_outputs)[0], dtype=tf.float32) * self.config["top_span_ratio"]), dtype=tf.int32)
    # NOTE(review): custom op; semantics of the trailing `True` flag are
    # defined by coref_ops.extract_spans — confirm there.
    top_span_indices = coref_ops.extract_spans(
        tf.expand_dims(candidate_mention_scores, 0),
        tf.expand_dims(candidate_starts, 0),
        tf.expand_dims(candidate_ends, 0),
        tf.expand_dims(k, 0),
        util_tf2.shape(context_outputs, 0),
        True)  # [1, k]
    top_span_indices.set_shape([1, None])
    top_span_indices = tf.squeeze(top_span_indices, 0)  # [k]
    # NOTE(review): top_span_starts/ends/mention_scores/sentence_indices are
    # not used later in this method.
    top_span_starts = tf.gather(candidate_starts, top_span_indices)  # [k]
    top_span_ends = tf.gather(candidate_ends, top_span_indices)  # [k]
    top_span_emb = tf.gather(candidate_span_emb, top_span_indices)  # [k, emb]
    top_span_mention_scores = tf.gather(candidate_mention_scores, top_span_indices)  # [k]
    top_span_sentence_indices = tf.gather(candidate_sentence_indices, top_span_indices)  # [k]
    top_span_entity_labels = tf.gather(candidate_entity_labels, top_span_indices)  # [k]
    # entity scores
    self.entity_scores = self.get_entity_scores(top_span_emb)
    self.entity_labels_mask = self.get_entity_label_mask(top_span_entity_labels)
    # entity loss function
    entity_loss = self.get_entity_loss(self.entity_scores, self.entity_labels_mask)  # []
    return [self.entity_scores, self.entity_labels_mask], entity_loss