def get_slow_antecedent_scores(self, top_span_emb, top_antecedents, top_antecedent_emb,
                               top_antecedent_offsets, top_span_speaker_ids, genre_emb):
    """Score each (span, candidate antecedent) pair with the pairwise feed-forward scorer.

    Shape annotations follow the convention used in this file: k is the number
    of top spans and c the number of antecedents kept per span.

    Args:
        top_span_emb: Span embeddings, [k, emb].
        top_antecedents: Indices of the candidate antecedents, [k, c].
        top_antecedent_emb: Embeddings of those antecedents, [k, c, emb].
        top_antecedent_offsets: Span-to-antecedent offsets, [k, c]. Only read
            when config["use_features"] is set.
        top_span_speaker_ids: Speaker id of each top span, [k]. Only read when
            config["use_metadata"] is set.
        genre_emb: Genre embedding, [emb]. Only read when
            config["use_metadata"] is set.

    Returns:
        A [k, c] tensor of pairwise antecedent scores.
    """
    k = util.shape(top_span_emb, 0)
    c = util.shape(top_antecedents, 1)

    feature_emb_list = []

    if self.config["use_metadata"]:
        # Binary "same speaker" feature, embedded through a 2-row table.
        top_antecedent_speaker_ids = tf.gather(top_span_speaker_ids, top_antecedents)  # [k, c]
        same_speaker = tf.equal(
            tf.expand_dims(top_span_speaker_ids, 1), top_antecedent_speaker_ids)  # [k, c]
        speaker_pair_emb = tf.gather(
            tf.get_variable("same_speaker_emb", [2, self.config["feature_size"]]),
            tf.to_int32(same_speaker))  # [k, c, emb]
        feature_emb_list.append(speaker_pair_emb)

        # The genre embedding is broadcast to every pair.
        tiled_genre_emb = tf.tile(
            tf.expand_dims(tf.expand_dims(genre_emb, 0), 0), [k, c, 1])  # [k, c, emb]
        feature_emb_list.append(tiled_genre_emb)

    if self.config["use_features"]:
        # Bucketed distance feature, embedded through a 10-row table.
        antecedent_distance_buckets = self.bucket_distance(top_antecedent_offsets)  # [k, c]
        antecedent_distance_emb = tf.gather(
            tf.get_variable("antecedent_distance_emb", [10, self.config["feature_size"]]),
            antecedent_distance_buckets)  # [k, c, emb]
        feature_emb_list.append(antecedent_distance_emb)

    # tf.concat cannot take an empty list, so the feature block is only built
    # when at least one feature group is enabled.
    feature_emb = None
    if feature_emb_list:
        feature_emb = tf.concat(feature_emb_list, 2)  # [k, c, emb]
        feature_emb = tf.nn.dropout(feature_emb, self.dropout)  # [k, c, emb]

    target_emb = tf.expand_dims(top_span_emb, 1)  # [k, 1, emb]
    similarity_emb = top_antecedent_emb * target_emb  # [k, c, emb]
    target_emb = tf.tile(target_emb, [1, c, 1])  # [k, c, emb]

    pair_emb_list = [target_emb, top_antecedent_emb, similarity_emb]
    if feature_emb is not None:
        pair_emb_list.append(feature_emb)
    pair_emb = tf.concat(pair_emb_list, 2)  # [k, c, emb]

    with tf.variable_scope("slow_antecedent_scores"):
        slow_antecedent_scores = util.ffnn(
            pair_emb, self.config["ffnn_depth"], self.config["ffnn_size"], 1,
            self.dropout)  # [k, c, 1]
    slow_antecedent_scores = tf.squeeze(slow_antecedent_scores, 2)  # [k, c]
    return slow_antecedent_scores  # [k, c]
def get_fast_antecedent_scores(self, top_span_emb):
    """Compute the bilinear coarse score between every pair of top spans.

    One side is a learned projection of the span embeddings (same width as the
    input), the other side is the span embeddings themselves; both go through
    dropout before being multiplied.

    Args:
        top_span_emb: Span embeddings, [k, emb].

    Returns:
        A [k, k] tensor of pairwise scores.
    """
    with tf.variable_scope("src_projection"):
        emb_width = util.shape(top_span_emb, -1)
        projected_emb = util.projection(top_span_emb, emb_width)
        source_side = tf.nn.dropout(projected_emb, self.dropout)  # [k, emb]
    target_side = tf.nn.dropout(top_span_emb, self.dropout)  # [k, emb]
    pair_scores = tf.matmul(source_side, target_side, transpose_b=True)  # [k, k]
    return pair_scores
def distance_pruning(self, top_span_emb, top_span_mention_scores, c):
    """Pick, for each span, the c immediately preceding spans as antecedents.

    Args:
        top_span_emb: Span embeddings, [k, emb]; only its first dimension is used.
        top_span_mention_scores: Mention score per span, [k].
        c: Number of antecedents to keep per span.

    Returns:
        A 4-tuple of [k, c] tensors:
        (antecedent indices clamped at 0, validity mask, coarse scores,
        offsets 1..c).
    """
    k = util.shape(top_span_emb, 0)

    # Offsets 1..c, repeated for every span.
    offsets_row = tf.expand_dims(tf.range(c) + 1, 0)  # [1, c]
    offsets = tf.tile(offsets_row, [k, 1])  # [k, c]

    # Indices before the first span come out negative: mask them, then clamp
    # so that the gather below stays in range.
    unclamped = tf.expand_dims(tf.range(k), 1) - offsets  # [k, c]
    valid = unclamped >= 0  # [k, c]
    antecedents = tf.maximum(unclamped, 0)  # [k, c]

    span_scores = tf.expand_dims(top_span_mention_scores, 1)  # [k, 1]
    antecedent_scores = tf.gather(top_span_mention_scores, antecedents)  # [k, c]
    coarse_scores = span_scores + antecedent_scores  # [k, c]
    # log(0) = -inf pushes masked-out pairs to the bottom.
    coarse_scores += tf.log(tf.to_float(valid))  # [k, c]

    return antecedents, valid, coarse_scores, offsets
def get_span_emb(self, head_emb, context_outputs, span_starts, span_ends):
    """Build a span representation from its boundary states and optional extras.

    The result always contains the contextual states at the span start and end.
    With config["use_features"] a span-width embedding is appended, and with
    config["model_heads"] an attention-weighted sum of `head_emb` over the
    span's tokens is appended. As a side effect, the per-word head scores are
    stored on `self.head_scores` when config["model_heads"] is set.

    Args:
        head_emb: Per-word embeddings used for the head summary, [num_words, emb].
        context_outputs: Per-word contextual states, [num_words, emb].
        span_starts: Start word index of each span, [k].
        span_ends: End word index of each span, [k].

    Returns:
        A [k, emb] tensor of span embeddings.
    """
    cfg = self.config

    start_emb = tf.gather(context_outputs, span_starts)  # [k, emb]
    end_emb = tf.gather(context_outputs, span_ends)  # [k, emb]
    pieces = [start_emb, end_emb]

    widths = 1 + span_ends - span_starts  # [k]

    if cfg["use_features"]:
        width_ids = widths - 1  # [k], zero-based row in the width table
        width_table = tf.get_variable(
            "span_width_embeddings",
            [cfg["max_span_width"], cfg["feature_size"]])
        width_emb = tf.gather(width_table, width_ids)  # [k, emb]
        width_emb = tf.nn.dropout(width_emb, self.dropout)
        pieces.append(width_emb)

    if cfg["model_heads"]:
        max_width = cfg["max_span_width"]

        # Word index of every slot in a max-width window starting at each
        # span start, clamped to the last word.
        window_offsets = tf.expand_dims(tf.range(max_width), 0)  # [1, max_span_width]
        token_ids = window_offsets + tf.expand_dims(span_starts, 1)  # [k, max_span_width]
        last_word = util.shape(context_outputs, 0) - 1
        token_ids = tf.minimum(last_word, token_ids)  # [k, max_span_width]
        token_emb = tf.gather(head_emb, token_ids)  # [k, max_span_width, emb]

        with tf.variable_scope("head_scores"):
            self.head_scores = util.projection(context_outputs, 1)  # [num_words, 1]
        token_scores = tf.gather(self.head_scores, token_ids)  # [k, max_span_width, 1]

        # Slots beyond the span's real width get log(0) = -inf, i.e. zero
        # attention after the softmax.
        inside_span = tf.sequence_mask(widths, max_width, dtype=tf.float32)
        inside_span = tf.expand_dims(inside_span, 2)  # [k, max_span_width, 1]
        token_scores += tf.log(inside_span)  # [k, max_span_width, 1]

        attention = tf.nn.softmax(token_scores, 1)  # [k, max_span_width, 1]
        head_summary = tf.reduce_sum(attention * token_emb, 1)  # [k, emb]
        pieces.append(head_summary)

    return tf.concat(pieces, 1)  # [k, emb]
def flatten_emb_by_sentence(self, emb, text_len_mask):
    """Merge the sentence and position axes of `emb`, dropping masked-out positions.

    Args:
        emb: Rank-2 [num_sentences, max_sentence_length] or rank-3
            [num_sentences, max_sentence_length, emb] tensor.
        text_len_mask: Boolean mask, [num_sentences, max_sentence_length].

    Returns:
        The entries of `emb` whose mask is true, with the first two axes
        flattened into one.

    Raises:
        ValueError: If `emb` is neither rank 2 nor rank 3.
    """
    sentence_count = tf.shape(emb)[0]
    padded_length = tf.shape(emb)[1]

    rank = len(emb.get_shape())
    if rank not in (2, 3):
        raise ValueError("Unsupported rank: {}".format(rank))

    if rank == 2:
        flat_shape = [sentence_count * padded_length]
    else:
        flat_shape = [sentence_count * padded_length, util.shape(emb, 2)]
    flat_emb = tf.reshape(emb, flat_shape)

    flat_mask = tf.reshape(text_len_mask, [sentence_count * padded_length])
    return tf.boolean_mask(flat_emb, flat_mask)
def lstm_contextualize(self, text_emb, text_len, text_len_mask):
    """Run the stacked bidirectional LSTM encoder and flatten its output to words.

    Each layer lives in its own "layer_<i>" variable scope. Every layer after
    the first mixes its output with its input through a learned sigmoid gate.
    config["contextualization_layers"] must be at least 1.

    Args:
        text_emb: Input embeddings, [num_sentences, max_sentence_length, emb].
        text_len: Length of each sentence, passed to the RNN as sequence_length.
        text_len_mask: Boolean mask, [num_sentences, max_sentence_length].

    Returns:
        The last layer's outputs with padding removed, as produced by
        flatten_emb_by_sentence.
    """
    sentence_count = tf.shape(text_emb)[0]

    def tiled_initial_state(cell):
        # Repeat the cell's initial state along the sentence axis.
        return tf.contrib.rnn.LSTMStateTuple(
            tf.tile(cell.initial_state.c, [sentence_count, 1]),
            tf.tile(cell.initial_state.h, [sentence_count, 1]))

    layer_inputs = text_emb  # [num_sentences, max_sentence_length, emb]

    for layer_idx in range(self.config["contextualization_layers"]):
        with tf.variable_scope("layer_{}".format(layer_idx)):
            with tf.variable_scope("fw_cell"):
                forward_cell = util.CustomLSTMCell(
                    self.config["contextualization_size"], sentence_count,
                    self.lstm_dropout)
            with tf.variable_scope("bw_cell"):
                backward_cell = util.CustomLSTMCell(
                    self.config["contextualization_size"], sentence_count,
                    self.lstm_dropout)

            forward_state = tiled_initial_state(forward_cell)
            backward_state = tiled_initial_state(backward_cell)

            rnn_outputs, _ = tf.nn.bidirectional_dynamic_rnn(
                cell_fw=forward_cell,
                cell_bw=backward_cell,
                inputs=layer_inputs,
                sequence_length=text_len,
                initial_state_fw=forward_state,
                initial_state_bw=backward_state)
            forward_out, backward_out = rnn_outputs

            # [num_sentences, max_sentence_length, emb]
            layer_outputs = tf.concat([forward_out, backward_out], 2)
            layer_outputs = tf.nn.dropout(layer_outputs, self.lstm_dropout)

            if layer_idx > 0:
                # Highway connection between this layer's input and output.
                gate_width = util.shape(layer_outputs, 2)
                gates = tf.sigmoid(
                    util.projection(layer_outputs, gate_width))  # same shape as outputs
                transformed = gates * layer_outputs
                carried = (1 - gates) * layer_inputs
                layer_outputs = transformed + carried

            layer_inputs = layer_outputs

    return self.flatten_emb_by_sentence(layer_outputs, text_len_mask)
def coarse_to_fine_pruning(self, top_span_emb, top_span_mention_scores, c):
    """Keep, for each span, the c antecedents with the highest coarse score.

    The coarse score of a pair is the sum of both mention scores, the bilinear
    score from get_fast_antecedent_scores, and log(mask), where the mask only
    admits spans that come earlier in the top-span ordering.

    Args:
        top_span_emb: Span embeddings, [k, emb].
        top_span_mention_scores: Mention score per span, [k].
        c: Number of antecedents to keep per span.

    Returns:
        A 4-tuple of [k, c] tensors:
        (antecedent indices, validity mask, coarse scores, offsets).
    """
    k = util.shape(top_span_emb, 0)
    span_positions = tf.range(k)  # [k]

    # offsets[i, j] = i - j; only strictly earlier spans (offset >= 1) are
    # valid antecedents.
    pair_offsets = (tf.expand_dims(span_positions, 1)
                    - tf.expand_dims(span_positions, 0))  # [k, k]
    pair_mask = pair_offsets >= 1  # [k, k]

    pair_scores = (tf.expand_dims(top_span_mention_scores, 1)
                   + tf.expand_dims(top_span_mention_scores, 0))  # [k, k]
    pair_scores += tf.log(tf.to_float(pair_mask))  # [k, k]
    pair_scores += self.get_fast_antecedent_scores(top_span_emb)  # [k, k]

    _, kept = tf.nn.top_k(pair_scores, c, sorted=False)  # [k, c]
    kept_mask = util.batch_gather(pair_mask, kept)  # [k, c]
    kept_scores = util.batch_gather(pair_scores, kept)  # [k, c]
    kept_offsets = util.batch_gather(pair_offsets, kept)  # [k, c]
    return kept, kept_mask, kept_scores, kept_offsets
def get_predictions_and_loss(self, tokens, context_word_emb, head_word_emb, lm_emb,
                             char_index, text_len, speaker_ids, genre, is_training,
                             gold_starts, gold_ends, cluster_ids):
    """Build the full coreference graph for one document.

    Pipeline: assemble word representations (word, optional character CNN,
    weighted LM layers) -> BiLSTM encoding -> enumerate candidate spans within
    sentences -> score and prune mentions -> prune antecedents -> iteratively
    refine span embeddings and score antecedents -> compute the loss.

    Side effects: sets self.dropout, self.lexical_dropout, self.lstm_dropout,
    self.lm_weights and self.lm_scaling.

    Args:
        tokens: Token strings; only fed to the ELMo hub module when
            self.lm_file is not set.
        context_word_emb: Word embeddings for the encoder,
            [num_sentences, max_sentence_length, emb].
        head_word_emb: Word embeddings for the head attention, same leading
            two dimensions.
        lm_emb: Precomputed LM embeddings,
            [num_sentences, max_sentence_length, emb, layers]; replaced by the
            hub module's output when self.lm_file is not set.
        char_index: Character ids; only used when
            config["char_embedding_size"] > 0.
        text_len: Length of each sentence.
        speaker_ids: Speaker id per word, indexed by span start.
        genre: Index into the genre embedding table.
        is_training: Forwarded to self.get_dropout.
        gold_starts, gold_ends, cluster_ids: Gold mention boundaries and
            cluster ids, forwarded to self.get_candidate_labels.

    Returns:
        A pair (predictions, loss) where predictions is the list
        [candidate_starts, candidate_ends, candidate_mention_scores,
        top_span_starts, top_span_ends, top_antecedents,
        top_antecedent_scores] and loss is a scalar tensor.
    """
    cfg = self.config

    self.dropout = self.get_dropout(cfg["dropout_rate"], is_training)
    self.lexical_dropout = self.get_dropout(cfg["lexical_dropout_rate"], is_training)
    self.lstm_dropout = self.get_dropout(cfg["lstm_dropout_rate"], is_training)

    sentence_count = tf.shape(context_word_emb)[0]
    padded_length = tf.shape(context_word_emb)[1]

    context_parts = [context_word_emb]
    head_parts = [head_word_emb]

    # ---- Character CNN features (optional) ----
    if cfg["char_embedding_size"] > 0:
        char_table = tf.get_variable(
            "char_embeddings",
            [len(self.char_dict), cfg["char_embedding_size"]])
        # [num_sentences, max_sentence_length, max_word_length, emb]
        char_emb = tf.gather(char_table, char_index)
        # [num_sentences * max_sentence_length, max_word_length, emb]
        flat_char_emb = tf.reshape(char_emb, [
            sentence_count * padded_length,
            util.shape(char_emb, 2),
            util.shape(char_emb, 3),
        ])
        # [num_sentences * max_sentence_length, emb]
        flat_char_features = util.cnn(
            flat_char_emb, cfg["filter_widths"], cfg["filter_size"])
        # [num_sentences, max_sentence_length, emb]
        char_features = tf.reshape(flat_char_features, [
            sentence_count,
            padded_length,
            util.shape(flat_char_features, 1),
        ])
        context_parts.append(char_features)
        head_parts.append(char_features)

    # ---- Language-model embeddings ----
    if not self.lm_file:
        # No cached LM file: compute ELMo on the fly through TF Hub.
        elmo_module = hub.Module("https://tfhub.dev/google/elmo/2")
        lm_embeddings = elmo_module(
            inputs={"tokens": tokens, "sequence_len": text_len},
            signature="tokens",
            as_dict=True)
        word_emb = lm_embeddings["word_emb"]  # [num_sentences, max_sentence_length, 512]
        # The word layer is duplicated along the feature axis so that it can
        # be stacked with the two LSTM layers.
        doubled_word_emb = tf.concat([word_emb, word_emb], -1)
        lm_emb = tf.stack([
            doubled_word_emb,
            lm_embeddings["lstm_outputs1"],
            lm_embeddings["lstm_outputs2"],
        ], -1)  # [num_sentences, max_sentence_length, 1024, 3]

    lm_emb_size = util.shape(lm_emb, 2)
    lm_num_layers = util.shape(lm_emb, 3)
    with tf.variable_scope("lm_aggregation"):
        lm_scores = tf.get_variable(
            "lm_scores", [lm_num_layers],
            initializer=tf.constant_initializer(0.0))
        self.lm_weights = tf.nn.softmax(lm_scores)
        self.lm_scaling = tf.get_variable(
            "lm_scaling", [], initializer=tf.constant_initializer(1.0))

    # Weighted sum over LM layers, done as a single matmul on a 2-D view.
    flat_lm_emb = tf.reshape(
        lm_emb,
        [sentence_count * padded_length * lm_emb_size, lm_num_layers])
    layer_weight_column = tf.expand_dims(self.lm_weights, 1)  # [layers, 1]
    # [num_sentences * max_sentence_length * emb, 1]
    flat_mixed_lm = tf.matmul(flat_lm_emb, layer_weight_column)
    mixed_lm = tf.reshape(
        flat_mixed_lm, [sentence_count, padded_length, lm_emb_size])
    mixed_lm *= self.lm_scaling
    context_parts.append(mixed_lm)

    # ---- Encoder ----
    context_emb = tf.concat(context_parts, 2)  # [num_sentences, max_sentence_length, emb]
    head_emb = tf.concat(head_parts, 2)  # [num_sentences, max_sentence_length, emb]
    context_emb = tf.nn.dropout(context_emb, self.lexical_dropout)
    head_emb = tf.nn.dropout(head_emb, self.lexical_dropout)

    # [num_sentences, max_sentence_length]
    text_len_mask = tf.sequence_mask(text_len, maxlen=padded_length)

    context_outputs = self.lstm_contextualize(
        context_emb, text_len, text_len_mask)  # [num_words, emb]
    num_words = util.shape(context_outputs, 0)

    genre_table = tf.get_variable(
        "genre_embeddings", [len(self.genres), cfg["feature_size"]])
    genre_emb = tf.gather(genre_table, genre)  # [emb]

    # Sentence id of every word, flattened the same way as the encoder output.
    sentence_ids = tf.tile(
        tf.expand_dims(tf.range(sentence_count), 1),
        [1, padded_length])  # [num_sentences, max_sentence_length]
    flat_sentence_ids = self.flatten_emb_by_sentence(
        sentence_ids, text_len_mask)  # [num_words]
    flat_head_emb = self.flatten_emb_by_sentence(
        head_emb, text_len_mask)  # [num_words, emb]

    # ---- Candidate span enumeration ----
    word_positions = tf.expand_dims(tf.range(num_words), 1)  # [num_words, 1]
    start_grid = tf.tile(
        word_positions, [1, self.max_span_width])  # [num_words, max_span_width]
    width_offsets = tf.expand_dims(tf.range(self.max_span_width), 0)
    end_grid = start_grid + width_offsets  # [num_words, max_span_width]

    start_sentence_grid = tf.gather(
        flat_sentence_ids, start_grid)  # [num_words, max_span_width]
    clamped_end_grid = tf.minimum(end_grid, num_words - 1)
    end_sentence_grid = tf.gather(
        flat_sentence_ids, clamped_end_grid)  # [num_words, max_span_width]

    # A candidate is kept when it ends inside the document and does not cross
    # a sentence boundary.
    ends_in_document = end_grid < num_words
    within_one_sentence = tf.equal(start_sentence_grid, end_sentence_grid)
    valid_grid = tf.logical_and(
        ends_in_document, within_one_sentence)  # [num_words, max_span_width]
    valid_flat = tf.reshape(valid_grid, [-1])  # [num_words * max_span_width]

    candidate_starts = tf.boolean_mask(
        tf.reshape(start_grid, [-1]), valid_flat)  # [num_candidates]
    candidate_ends = tf.boolean_mask(
        tf.reshape(end_grid, [-1]), valid_flat)  # [num_candidates]
    # Not consumed further down in this method.
    candidate_sentence_indices = tf.boolean_mask(
        tf.reshape(start_sentence_grid, [-1]), valid_flat)  # [num_candidates]

    candidate_cluster_ids = self.get_candidate_labels(
        candidate_starts, candidate_ends, gold_starts, gold_ends,
        cluster_ids)  # [num_candidates]

    # ---- Mention scoring and span pruning ----
    candidate_span_emb = self.get_span_emb(
        flat_head_emb, context_outputs, candidate_starts,
        candidate_ends)  # [num_candidates, emb]
    candidate_mention_scores = self.get_mention_scores(
        candidate_span_emb)  # [num_candidates, 1]
    candidate_mention_scores = tf.squeeze(
        candidate_mention_scores, 1)  # [num_candidates]

    # Number of spans to keep is proportional to the document length.
    word_count = tf.to_float(tf.shape(context_outputs)[0])
    k = tf.to_int32(tf.floor(word_count * cfg["top_span_ratio"]))
    top_span_indices = coref_ops.extract_spans(
        tf.expand_dims(candidate_mention_scores, 0),
        tf.expand_dims(candidate_starts, 0),
        tf.expand_dims(candidate_ends, 0),
        tf.expand_dims(k, 0),
        util.shape(context_outputs, 0),
        True)  # [1, k]
    top_span_indices.set_shape([1, None])
    top_span_indices = tf.squeeze(top_span_indices, 0)  # [k]

    top_span_starts = tf.gather(candidate_starts, top_span_indices)  # [k]
    top_span_ends = tf.gather(candidate_ends, top_span_indices)  # [k]
    top_span_emb = tf.gather(candidate_span_emb, top_span_indices)  # [k, emb]
    top_span_cluster_ids = tf.gather(candidate_cluster_ids, top_span_indices)  # [k]
    top_span_mention_scores = tf.gather(
        candidate_mention_scores, top_span_indices)  # [k]
    # Not consumed further down in this method.
    top_span_sentence_indices = tf.gather(
        candidate_sentence_indices, top_span_indices)  # [k]
    top_span_speaker_ids = tf.gather(speaker_ids, top_span_starts)  # [k]

    # ---- Antecedent pruning ----
    c = tf.minimum(cfg["max_top_antecedents"], k)

    prune_antecedents = (self.coarse_to_fine_pruning
                         if cfg["coarse_to_fine"]
                         else self.distance_pruning)
    (top_antecedents, top_antecedents_mask, top_fast_antecedent_scores,
     top_antecedent_offsets) = prune_antecedents(
         top_span_emb, top_span_mention_scores, c)

    # Score of the "no antecedent" option, fixed at zero.
    dummy_scores = tf.zeros([k, 1])  # [k, 1]

    # ---- Iterative refinement of span embeddings ----
    for depth in range(cfg["coref_depth"]):
        with tf.variable_scope("coref_layer", reuse=(depth > 0)):
            top_antecedent_emb = tf.gather(
                top_span_emb, top_antecedents)  # [k, c, emb]
            slow_scores = self.get_slow_antecedent_scores(
                top_span_emb, top_antecedents, top_antecedent_emb,
                top_antecedent_offsets, top_span_speaker_ids,
                genre_emb)  # [k, c]
            top_antecedent_scores = top_fast_antecedent_scores + slow_scores  # [k, c]

            # Attention over {the span itself (dummy slot), its antecedents}.
            attention = tf.nn.softmax(
                tf.concat([dummy_scores, top_antecedent_scores], 1))  # [k, c + 1]
            attended_candidates = tf.concat(
                [tf.expand_dims(top_span_emb, 1), top_antecedent_emb],
                1)  # [k, c + 1, emb]
            attended_span_emb = tf.reduce_sum(
                tf.expand_dims(attention, 2) * attended_candidates,
                1)  # [k, emb]

            with tf.variable_scope("f"):
                # Learned gate between the attended and the current embedding.
                gate_input = tf.concat([top_span_emb, attended_span_emb], 1)
                gate = tf.sigmoid(util.projection(
                    gate_input, util.shape(top_span_emb, -1)))  # [k, emb]
                top_span_emb = (gate * attended_span_emb
                                + (1 - gate) * top_span_emb)  # [k, emb]

    top_antecedent_scores = tf.concat(
        [dummy_scores, top_antecedent_scores], 1)  # [k, c + 1]

    # ---- Labels and loss ----
    top_antecedent_cluster_ids = tf.gather(
        top_span_cluster_ids, top_antecedents)  # [k, c]
    # Adding the int cast of log(mask) alters the ids at masked-out positions
    # and adds 0 at valid ones.
    mask_penalty = tf.to_int32(tf.log(tf.to_float(top_antecedents_mask)))
    top_antecedent_cluster_ids += mask_penalty  # [k, c]

    same_cluster = tf.equal(
        top_antecedent_cluster_ids,
        tf.expand_dims(top_span_cluster_ids, 1))  # [k, c]
    has_gold_cluster = tf.expand_dims(top_span_cluster_ids > 0, 1)  # [k, 1]
    pairwise_labels = tf.logical_and(same_cluster, has_gold_cluster)  # [k, c]
    # The dummy antecedent is correct exactly when no real antecedent is.
    dummy_labels = tf.logical_not(
        tf.reduce_any(pairwise_labels, 1, keepdims=True))  # [k, 1]
    top_antecedent_labels = tf.concat(
        [dummy_labels, pairwise_labels], 1)  # [k, c + 1]

    loss = self.softmax_loss(top_antecedent_scores, top_antecedent_labels)  # [k]
    loss = tf.reduce_sum(loss)  # []

    predictions = [
        candidate_starts,
        candidate_ends,
        candidate_mention_scores,
        top_span_starts,
        top_span_ends,
        top_antecedents,
        top_antecedent_scores,
    ]
    return predictions, loss