Пример #1
0
    def get_mention_emb(self, text_emb, text_outputs, mention_starts,
                        mention_ends):
        """Return one span embedding per mention.

        The result is the concatenation, along axis 1, of:
          1. the ``text_outputs`` row at each mention start,
          2. the ``text_outputs`` row at each mention end,
          3. a learned width embedding (only when config ``use_features``),
          4. an attention-pooled sum of the ``text_emb`` rows covered by the
             mention (only when config ``model_heads``).

        Args:
          text_emb: word-level embeddings gathered by word index for the head
            attention (presumably [num_words, emb] -- TODO confirm).
          text_outputs: word-level encoder outputs gathered by word index.
          mention_starts: first word index of each mention, [num_mentions].
          mention_ends: last word index (inclusive) of each mention,
            [num_mentions].

        Returns:
          Tensor of shape [num_mentions, total_emb].

        Side effects:
          When config ``model_heads`` is set, stores the per-word head scores
          on ``self.head_scores``.
        """
        max_width = self.config["max_mention_width"]

        start_vecs = tf.gather(text_outputs, mention_starts)  # [num_mentions, emb]
        end_vecs = tf.gather(text_outputs, mention_ends)  # [num_mentions, emb]
        parts = [start_vecs, end_vecs]

        # Boundary summary repeated once per in-span slot and projected to a
        # scalar so it can be added to the per-word head scores below.
        # NOTE(review): the projected value is identical along the
        # max_mention_width axis, and the softmax below normalises over that
        # axis, so this term cancels out of the attention weights.
        boundary = tf.concat([start_vecs, end_vecs], 1)
        boundary = tf.tile(tf.expand_dims(boundary, 1), [1, max_width, 1])
        boundary_score = util.projection_name(boundary, 1,
                                              "mention_output_score")

        # Spans are inclusive of both endpoints.
        widths = 1 + mention_ends - mention_starts  # [num_mentions]

        if self.config["use_features"]:
            # NOTE(review): assumes every width is <= max_mention_width; a
            # longer span would index past the end of the width table.
            width_index = widths - 1  # [num_mentions]
            width_table = tf.get_variable("mention_width_embeddings", [
                max_width,
                self.config["feature_size"]
            ])
            width_vecs = tf.gather(width_table, width_index)  # [num_mentions, emb]
            parts.append(tf.nn.dropout(width_vecs, self.dropout))

        if self.config["model_heads"]:
            # Word positions start, start+1, ..., start+max_width-1 for each
            # mention, clipped to the last valid word index.
            offsets = tf.expand_dims(tf.range(max_width), 0)
            span_indices = offsets + tf.expand_dims(mention_starts, 1)
            span_indices = tf.minimum(util.shape(text_outputs, 0) - 1,
                                      span_indices)  # [num_mentions, max_mention_width]
            span_word_vecs = tf.gather(
                text_emb, span_indices)  # [num_mentions, max_mention_width, emb]
            self.head_scores = util.projection(text_emb, 1)  # [num_words, 1]
            head_logits = tf.gather(
                self.head_scores,
                span_indices) + boundary_score  # [num_mentions, max_mention_width, 1]
            # 1.0 inside the span, 0.0 in padding slots; log() turns padding
            # into -inf so it receives zero attention.
            in_span = tf.expand_dims(
                tf.sequence_mask(widths, max_width, dtype=tf.float32),
                2)  # [num_mentions, max_mention_width, 1]
            weights = tf.nn.softmax(head_logits + tf.log(in_span), dim=1)
            parts.append(tf.reduce_sum(weights * span_word_vecs, 1))

        return tf.concat(parts, 1)
Пример #2
0
  def get_context_antecedent_scores(self,
                                    mention_emb,
                                    mention_scores,
                                    antecedents,
                                    antecedents_len,
                                    mention_starts,
                                    mention_ends,
                                    mention_speaker_ids,
                                    genre_emb,
                                    context_starts,
                                    context_ends,
                                    text_outputs,
                                    text_emb):
    """Score mention/antecedent pairs using mention and context co-attention.

    For every mention and each of its `max_ant` candidate antecedents this
    builds pair features (speaker/genre and/or distance, depending on config).
    It then runs four attention passes in sequence, each conditioned on the
    results of the previous ones:
      * C_a: attends over the antecedent's context window.
      * C_t: attends over the target mention's context window.
      * M_t: attends over the words of the target mention.
      * M_a: attends over the words of the antecedent mention.
    Each pass pools rows of `text_emb`. The attended vectors, the raw
    boundary features and the pair features are concatenated and scored by a
    feed-forward network.

    Args:
      mention_emb: one embedding per mention, [num_mentions, emb].
      mention_scores: per-mention score, [num_mentions].
      antecedents: candidate antecedent (mention) indices, [num_mentions, max_ant].
      antecedents_len: number of valid candidates per mention, [num_mentions].
      mention_starts, mention_ends: word indices of each mention (end
        inclusive), [num_mentions].
      mention_speaker_ids: speaker id per mention, [num_mentions].
      genre_emb: genre embedding vector, [emb].
      context_starts, context_ends: word indices of each mention's context
        window (end inclusive), [num_mentions].
      text_outputs: encoder outputs indexed by word position.
      text_emb: word embeddings indexed by word position; presumably the
        flattened [num_words, emb] table -- TODO confirm against caller.

    Returns:
      Tensor [num_mentions, max_ant + 1]. Column 0 is a fixed 0 score for the
      "no antecedent" option; the rest are masked pair scores.

    Side effects:
      Sets `self.num_words` to the full `tf.shape(text_outputs)` (a shape
      vector, not a scalar word count) and `self.num_mentions`.
    """
    num_mentions = util.shape(mention_emb, 0)
    max_antecedents = util.shape(antecedents, 1)

    # NOTE(review): this stores the whole shape vector of text_outputs, not a
    # scalar count, despite the attribute name.
    self.num_words = tf.shape(text_outputs)
    self.num_mentions = num_mentions

    # Pair-level features. NOTE: if both "use_metadata" and "use_features"
    # are off this list stays empty and the tf.concat below fails.
    feature_emb_list = []

    if self.config["use_metadata"]:
      antecedent_speaker_ids = tf.gather(mention_speaker_ids, antecedents) # [num_mentions, max_ant]
      same_speaker = tf.equal(tf.expand_dims(mention_speaker_ids, 1), antecedent_speaker_ids) # [num_mentions, max_ant]
      speaker_pair_emb = tf.gather(tf.get_variable("same_speaker_emb", [2, self.config["feature_size"]]), tf.to_int32(same_speaker)) # [num_mentions, max_ant, emb]
      feature_emb_list.append(speaker_pair_emb)

      tiled_genre_emb = tf.tile(tf.expand_dims(tf.expand_dims(genre_emb, 0), 0), [num_mentions, max_antecedents, 1]) # [num_mentions, max_ant, emb]
      feature_emb_list.append(tiled_genre_emb)

    if self.config["use_features"]:
      # Distance measured in mention indices (target index - antecedent index),
      # bucketed by coref_ops.distance_bins into one of 10 embedding rows.
      target_indices = tf.range(num_mentions) # [num_mentions]
      mention_distance = tf.expand_dims(target_indices, 1) - antecedents # [num_mentions, max_ant]
      mention_distance_bins = coref_ops.distance_bins(mention_distance) # [num_mentions, max_ant]
      mention_distance_bins.set_shape([None, None])
      mention_distance_emb = tf.gather(tf.get_variable("mention_distance_emb", [10, self.config["feature_size"]]), mention_distance_bins) # [num_mentions, max_ant, emb]
      feature_emb_list.append(mention_distance_emb)

    feature_emb = tf.concat(feature_emb_list, 2) # [num_mentions, max_ant, emb]
    feature_emb = tf.nn.dropout(feature_emb, self.dropout) # [num_mentions, max_ant, emb]


    #############################
    #
    # Get matrix for co-attention
    #
    #############################
    

    ####### Mention Level #######
    
    # Boundary features (start/end encoder outputs) of every mention.
    mention_start_emb = tf.gather(text_outputs, mention_starts) # [num_mentions, emb]
    mention_end_emb = tf.gather(text_outputs, mention_ends) # [num_mentions, emb]

    mention_features = tf.concat([mention_start_emb, mention_end_emb], 1)
    
    # Word positions covered by each mention, padded to max_mention_width and
    # clipped to the last valid word; the mask marks the real in-span slots.
    mention_width = 1 + mention_ends - mention_starts # [num_mentions]
    mention_indices = tf.expand_dims(tf.range(self.config["max_mention_width"]), 0) + tf.expand_dims(mention_starts, 1) # [num_mentions, max_mention_width]
    mention_indices = tf.minimum(util.shape(text_outputs, 0) - 1, mention_indices) # [num_mentions, max_mention_width]
    mention_mask = tf.expand_dims(tf.sequence_mask(mention_width, self.config["max_mention_width"], dtype=tf.float32), 2) # [num_mentions, max_mention_width, 1]

    # Same word positions looked up for each candidate antecedent.
    antecedent_indices = tf.gather(mention_indices, antecedents) # [num_mentions, max_ant, max_mention_width]
    antecedent_mask = tf.gather(mention_mask, antecedents) # [num_mentions, max_ant, max_mention_width, 1]
    antecedent_indices_emb = tf.gather(text_outputs, antecedent_indices) # [num_mentions, max_ant, max_mention_width, emb]

    # ... and for the target mention, repeated once per antecedent slot.
    target_indices = tf.tile(tf.expand_dims(mention_indices, 1), [1, max_antecedents, 1]) # [num_mentions, max_ant, max_mention_width]
    target_mask = tf.tile(tf.expand_dims(mention_mask, 1), [1, max_antecedents, 1, 1]) # [num_mentions, max_ant, max_mention_width, 1]
    target_indices_emb = tf.gather(text_outputs, target_indices) # [num_mentions, max_ant, max_mention_width, emb]


    ####### Context Level #######

    # Same construction as above but over each mention's context window
    # [context_starts, context_ends], padded to max_context_width.
    context_start_emb = tf.gather(text_outputs, context_starts)
    context_end_emb = tf.gather(text_outputs, context_ends)

    context_width = 1 + context_ends - context_starts
    context_indices = tf.expand_dims(tf.range(self.config["max_context_width"]), 0) + tf.expand_dims(context_starts, 1) # [num_mentions, max_context_width]
    context_indices = tf.minimum(util.shape(text_outputs, 0) - 1, context_indices) # [num_mentions, max_context_width]
    context_mask = tf.expand_dims(tf.sequence_mask(context_width, self.config["max_context_width"], dtype=tf.float32), 2) # [num_mentions, max_context_width, 1]

    antecedent_context_indices = tf.gather(context_indices, antecedents) # [num_mentions, max_ant, max_context_width]
    antecedent_context_mask = tf.gather(context_mask, antecedents) # [num_mentions, max_ant, max_context_width, 1]
    antecedent_context_indices_emb = tf.gather(text_outputs, antecedent_context_indices) # [num_mentions, max_ant, max_context_width, emb]

    target_context_indices = tf.tile(tf.expand_dims(context_indices, 1), [1, max_antecedents, 1]) # [num_mentions, max_ant, max_context_width]
    target_context_mask = tf.tile(tf.expand_dims(context_mask, 1), [1, max_antecedents, 1, 1]) # [num_mentions, max_ant, max_context_width, 1]
    target_context_indices_emb = tf.gather(text_outputs, target_context_indices) # [num_mentions, max_ant, max_context_width, emb]


    #### Initial Embeddings #####
    
    antecedent_emb = tf.gather(mention_emb, antecedents) # [num_mentions, max_ant, emb]
    target_emb_tiled = tf.tile(tf.expand_dims(mention_emb, 1), [1, max_antecedents, 1]) # [num_mentions, max_ant, emb]
    
    context_emb = tf.concat([context_start_emb, context_end_emb], 1)

    # NOTE(review): this value is never read -- antecedent_context_emb is
    # overwritten by the C_a attention output before its first use.
    antecedent_context_emb = tf.gather(context_emb, antecedents) # [num_mentions, max_ant, emb]
    target_context_emb_tiled = tf.tile(tf.expand_dims(context_emb, 1), [1, max_antecedents, 1]) # [num_mentions, max_ant, emb]

    # Element-wise product of the *initial* (pre-attention) embeddings; it is
    # not recomputed after the attention passes below.
    similarity_emb = antecedent_emb * target_emb_tiled # [num_mentions, max_ant, emb]
    

    #############################
    #
    # Calculate Co-attention
    #
    #############################
    # Each pass: project a "window" summary to 100 dims, broadcast it over the
    # attended axis, add a 100-dim projection of every attended word, project
    # the sum to a scalar, mask padding with log(mask) (-inf) and softmax over
    # axis 2. The result replaces one of the four embeddings.
    # NOTE(review): all four passes call util.projection_name(..., 'att_score').
    # Whether that shares weights or needs variable reuse depends on
    # util.projection_name, which is not visible here -- confirm.


    ###### C_a Attention ########

    window_emb = tf.concat([antecedent_emb, target_emb_tiled, target_context_emb_tiled], 2)
    window_scores = util.projection_name(window_emb, 100, 'c_a_window')
    window_scores = tf.tile(tf.expand_dims(window_scores, 2), [1, 1, self.config['max_context_width'], 1])

    target_scores = util.projection_name(antecedent_context_indices_emb, 100, 'c_a_target')

    temp_scores = util.projection_name(window_scores + target_scores, 1, 'att_score')

    temp_att = tf.nn.softmax(temp_scores + tf.log(antecedent_context_mask), dim=2) # [num_mentions, max_ant, max_context_width, 1]
    antecedent_context_emb = tf.reduce_sum(temp_att * tf.gather(text_emb, antecedent_context_indices), 2)


    ###### C_t Attention ########

    window_emb = tf.concat([antecedent_emb, target_emb_tiled, antecedent_context_emb], 2)
    window_scores = util.projection_name(window_emb, 100, 'c_t_window')
    window_scores = tf.tile(tf.expand_dims(window_scores, 2), [1, 1, self.config['max_context_width'], 1])

    target_scores = util.projection_name(target_context_indices_emb, 100, 'c_t_target')

    temp_scores = util.projection_name(window_scores + target_scores, 1, 'att_score') 

    temp_att = tf.nn.softmax(temp_scores + tf.log(target_context_mask), dim=2) # [num_mentions, max_ant, max_context_width, 1]
    target_context_emb_tiled = tf.reduce_sum(temp_att * tf.gather(text_emb, target_context_indices), 2)

    
    ###### M_t Attention ########

    window_emb = tf.concat([antecedent_emb, antecedent_context_emb, target_context_emb_tiled], 2)
    window_scores = util.projection_name(window_emb, 100, 'm_t_window')
    window_scores = tf.tile(tf.expand_dims(window_scores, 2), [1, 1, self.config['max_mention_width'], 1])

    target_scores = util.projection_name(target_indices_emb, 100, 'm_t_target')

    temp_scores = util.projection_name(window_scores + target_scores, 1, 'att_score')
    
    temp_att = tf.nn.softmax(temp_scores + tf.log(target_mask), dim=2) # [num_mentions, max_ant, max_mention_width, 1]
    target_emb_tiled = tf.reduce_sum(temp_att * tf.gather(text_emb, target_indices), 2)


    ###### M_a Attention ########

    window_emb = tf.concat([target_emb_tiled, target_context_emb_tiled, antecedent_context_emb], 2)
    window_scores = util.projection_name(window_emb, 100, 'm_a_window')
    window_scores = tf.tile(tf.expand_dims(window_scores, 2), [1, 1, self.config['max_mention_width'], 1])

    target_scores = util.projection_name(antecedent_indices_emb, 100, 'm_a_target')

    temp_scores = util.projection_name(window_scores + target_scores, 1, 'att_score')
  
    temp_att = tf.nn.softmax(temp_scores + tf.log(antecedent_mask), dim=2) # [num_mentions, max_ant, max_mention_width, 1]
    antecedent_emb = tf.reduce_sum(temp_att * tf.gather(text_emb, antecedent_indices), 2)
    

    #############################
    #
    # Calculate Pair Embeddings
    #
    #############################

    antecedent_feature = tf.gather(mention_features, antecedents) # [num_mentions, max_ant, emb]
    target_feature = tf.tile(tf.expand_dims(mention_features, 1), [1, max_antecedents, 1]) # [num_mentions, max_ant, emb]
    # similarity_emb = antecedent_emb * target_emb_tiled # [num_mentions, max_ant, emb]

    # pair_emb = tf.concat([target_emb_tiled_1, antecedent_emb_1, similarity_emb, feature_emb], 2) # [num_mentions, max_ant, emb]

    pair_emb = tf.concat([
                          target_feature,
                          target_emb_tiled,
                          antecedent_feature,
                          antecedent_emb,
                          antecedent_context_emb,
                          target_context_emb_tiled,
                          similarity_emb,
                          feature_emb], 2)

    # NOTE: the triple-quoted string below is an unused expression statement
    # (an alternative additive composition kept as text); it is never run and
    # refers to names not defined in this method.
    '''
    pair_emb = tf.nn.relu(util.projection_name(target_emb_tiled, self.config['ffnn_size'], 'comp_mt') +\
                util.projection_name(antecedent_emb, self.config['ffnn_size'], 'comp_ma') +\
                util.projection_name(antecedent_context_emb_1, self.config['ffnn_size'], 'comp_ca') +\
                util.projection_name(target_context_emb_tiled_1, self.config['ffnn_size'], 'comp_ct') +\
                util.projection_name(similarity_emb, self.config['ffnn_size'], 'comp_sim') +\
                util.projection_name(feature_emb, self.config['ffnn_size'], 'comp_feature'))
    '''

    #############################

    with tf.variable_scope("iteration"):
      with tf.variable_scope("antecedent_scoring"):
        antecedent_scores = util.ffnn(pair_emb, self.config["ffnn_depth"], self.config["ffnn_size"], 1, self.dropout) # [num_mentions, max_ant, 1]
    antecedent_scores = tf.squeeze(antecedent_scores, 2) # [num_mentions, max_ant]

    # log(0) = -inf for padded candidate slots beyond antecedents_len.
    antecedent_mask = tf.log(tf.sequence_mask(antecedents_len, max_antecedents, dtype=tf.float32)) # [num_mentions, max_ant]
    antecedent_scores += antecedent_mask # [num_mentions, max_ant]

    # Add both mentions' own scores, then prepend a constant 0 column for the
    # "no antecedent" choice.
    antecedent_scores += tf.expand_dims(mention_scores, 1) + tf.gather(mention_scores, antecedents) # [num_mentions, max_ant]
    antecedent_scores = tf.concat([tf.zeros([util.shape(mention_scores, 0), 1]), antecedent_scores], 1) # [num_mentions, max_ant + 1]
    return antecedent_scores  # [num_mentions, max_ant + 1]
Пример #3
0
  def get_predictions_and_loss(self,
                              word_emb,
                              char_index,
                              text_len,
                              speaker_ids,
                              genre,
                              is_training,
                              gold_starts,
                              gold_ends,
                              cluster_ids,
                              tag_labels,
                              tag_seq,
                              tag_loss_label):
    """Build the tagging + coreference graph and return predictions and losses.

    Pipeline, as built below:
      1. Word (+ optional character-CNN) embeddings with lexical dropout.
      2. Sentence encoder (`self.encode_sentences`) over the padded batch.
      3. A width-5 CNN tagger over the flattened word embeddings producing
         3-way tag logits per word.
      4. `coref_ops.memory` turns the argmax tags into mention spans and
         per-mention scores.
      5. Mention embeddings, candidate antecedents and antecedent scores.

    Args:
      word_emb: word embeddings, [num_sentences, max_sentence_length, emb].
      char_index: character ids, [num_sentences, max_sentence_length,
        max_word_length].
      text_len: number of real words per sentence, [num_sentences].
      speaker_ids: speaker id per word, gathered by mention start index
        (presumably the flattened word index -- TODO confirm).
      genre: row index into the genre embedding table.
      is_training: scalar; when true, the configured dropout rates apply.
      gold_starts, gold_ends, cluster_ids: gold mentions and their cluster
        ids, used by `coref_ops.antecedents` to build antecedent labels.
      tag_labels: tagging targets for softmax cross-entropy; same shape as
        the tag logits, [num_words, 3].
      tag_seq: gold tag sequence; only passed through to the outputs.
      tag_loss_label: per-word weights applied to the tagging loss.

    Returns:
      `(predictions, mention_loss, tagging_loss)`. `predictions` is the list
      [candidate_starts, candidate_ends, mention_scores, mention_starts,
      mention_ends, antecedents, antecedent_scores, tag_outputs, tag_seq].
      Both losses are scalar sums.

    Side effects:
      Sets `self.dropout` and `self.lexical_dropout` (keep probabilities) and
      stores several intermediate tensors on `self` for inspection.
    """
    # Keep probabilities: 1.0 at inference, 1 - rate while training.
    self.dropout = 1 - (tf.to_float(is_training) * self.config["dropout_rate"])
    self.lexical_dropout = 1 - (tf.to_float(is_training) * self.config["lexical_dropout_rate"])

    num_sentences = tf.shape(word_emb)[0]
    max_sentence_length = tf.shape(word_emb)[1]

    text_emb_list = [word_emb]

    if self.config["char_embedding_size"] > 0:
      char_emb = tf.gather(tf.get_variable("char_embeddings", [len(self.char_dict), self.config["char_embedding_size"]]), char_index) # [num_sentences, max_sentence_length, max_word_length, emb]
      flattened_char_emb = tf.reshape(char_emb, [num_sentences * max_sentence_length, util.shape(char_emb, 2), util.shape(char_emb, 3)]) # [num_sentences * max_sentence_length, max_word_length, emb]
      flattened_aggregated_char_emb = util.cnn(flattened_char_emb, self.config["filter_widths"], self.config["filter_size"]) # [num_sentences * max_sentence_length, emb]
      aggregated_char_emb = tf.reshape(flattened_aggregated_char_emb, [num_sentences, max_sentence_length, util.shape(flattened_aggregated_char_emb, 1)]) # [num_sentences, max_sentence_length, emb]
      text_emb_list.append(aggregated_char_emb)

    text_emb = tf.concat(text_emb_list, 2)
    text_emb = tf.nn.dropout(text_emb, self.lexical_dropout)

    # Boolean mask of real (non-padding) word slots, flattened to
    # [num_sentences * max_sentence_length].
    text_len_mask = tf.sequence_mask(text_len, maxlen=max_sentence_length)
    text_len_mask = tf.reshape(text_len_mask, [num_sentences * max_sentence_length])

    text_outputs = self.encode_sentences(text_emb, text_len, text_len_mask)
    text_outputs = tf.nn.dropout(text_outputs, self.dropout)

    genre_emb = tf.gather(tf.get_variable("genre_embeddings", [len(self.genres), self.config["feature_size"]]), genre) # [emb]

    sentence_indices = tf.tile(tf.expand_dims(tf.range(num_sentences), 1), [1, max_sentence_length]) # [num_sentences, max_sentence_length]
    flattened_sentence_indices = self.flatten_emb_by_sentence(sentence_indices, text_len_mask) # [num_words]
    flattened_text_emb = self.flatten_emb_by_sentence(text_emb, text_len_mask) # presumably [num_words, emb]
    self.flattened_sentence_indices = flattened_sentence_indices

    # Tagger: width-5 CNN (100 filters) over the flattened word embeddings,
    # then a projection to 3 tag classes per word.
    text_conv = tf.expand_dims(flattened_text_emb, 0)
    text_conv = util.cnn_name(text_conv, [5], 100, 'tag_conv')[0]
    text_conv = tf.nn.dropout(text_conv, self.dropout)

    # Unnormalised tag scores; the softmax is only applied inside the loss.
    tag_logits = util.projection_name(text_conv, 3, 'tag_fc') # [num_words, 3]

    tag_outputs = tf.argmax(tag_logits, axis=1, output_type=tf.int32) # [num_words]

    # Largest *logit* per word (not a probability), handed to coref_ops.memory.
    tag_high = tf.reduce_max(tag_logits, axis=1) # [num_words]

    num_words = tf.shape(text_conv)[0]

    # NOTE(review): `num_words=1` is hard-coded although the real word count
    # is available in `num_words` above. If the op takes it as an input
    # rather than a fixed attr, this may be meant to be the real count --
    # confirm against the coref_ops.memory kernel before changing.
    mention_starts, mention_ends, mention_scores = coref_ops.memory(
      tag_seq=tag_outputs,
      tag_high=tag_high,
      num_words=1)
    mention_starts.set_shape([None])
    mention_ends.set_shape([None])
    mention_scores.set_shape([None])

    # Exposed for inspection/debugging.
    self.num_mention = tf.shape(mention_starts)[0]
    self.num_gold_mention = tf.shape(gold_starts)[0]
    self.num_words = num_words
    self.mention_starts = mention_starts
    self.gold_starts = gold_starts
    self.mention_ends = mention_ends
    self.tag_outputs = tag_outputs
    self.tag_seq = tag_seq

    mention_emb = self.get_mention_emb(flattened_text_emb, text_outputs, mention_starts, mention_ends) # [num_mentions, emb]

    # Mentions come straight from the tagger, so the "candidates" reported in
    # the outputs are the same spans.
    candidate_starts = mention_starts
    candidate_ends = mention_ends

    mention_speaker_ids = tf.gather(speaker_ids, mention_starts) # [num_mentions]

    max_antecedents = self.config["max_antecedents"]
    antecedents, antecedent_labels, antecedents_len = coref_ops.antecedents(mention_starts, mention_ends, gold_starts, gold_ends, cluster_ids, max_antecedents) # ([num_mentions, max_ant], [num_mentions, max_ant + 1], [num_mentions]
    antecedents.set_shape([None, None])
    antecedent_labels.set_shape([None, None])
    antecedents_len.set_shape([None])

    antecedent_scores = self.get_antecedent_scores(mention_emb, mention_scores, antecedents, antecedents_len, mention_starts, mention_ends, mention_speaker_ids, genre_emb) # [num_mentions, max_ant + 1]

    raw_mention_loss = self.softmax_loss(antecedent_scores, antecedent_labels) # [num_mentions]
    raw_tagging_loss = tf.nn.softmax_cross_entropy_with_logits(logits=tag_logits, labels=tag_labels) # [num_words]
    mention_loss = tf.reduce_sum(raw_mention_loss) # []
    # Per-word weighting lets callers exclude words from the tagging loss.
    tagging_loss = tf.reduce_sum(tf.multiply(tf.to_float(tag_loss_label), raw_tagging_loss)) # []

    return [
            candidate_starts,
            candidate_ends,
            mention_scores,
            mention_starts,
            mention_ends,
            antecedents,
            antecedent_scores,
            tag_outputs,
            tag_seq
          ], mention_loss, tagging_loss
Пример #4
0
    def get_antecedent_scores(self, mention_emb, mention_scores, antecedents,
                              antecedents_len, mention_starts, mention_ends,
                              mention_speaker_ids, genre_emb, text_emb,
                              text_outputs, context_pre_starts,
                              context_pos_ends):
        """Score every (mention, candidate antecedent) pair.

        Each mention embedding is first extended with an attention-pooled
        context vector. The vector is built from up to
        ``max_context_width // 2`` words before the mention
        (``context_pre_starts`` .. ``mention_starts - 1``) and up to
        ``max_context_width // 2`` words after it
        (``mention_ends + 1`` .. ``context_pos_ends``). Pair embeddings
        [target, antecedent, target * antecedent, pair features] are scored
        by a feed-forward network. The score is masked to the valid candidates
        and summed with both mentions' own scores.

        Args:
          mention_emb: one embedding per mention, [num_mentions, emb].
          mention_scores: per-mention score, [num_mentions].
          antecedents: candidate antecedent indices, [num_mentions, max_ant].
          antecedents_len: number of valid candidates per mention,
            [num_mentions].
          mention_starts, mention_ends: word indices of each mention (end
            inclusive), [num_mentions].
          mention_speaker_ids: speaker id per mention, [num_mentions].
          genre_emb: genre embedding vector, [emb].
          text_emb: word embeddings indexed by word position; the context
            attention pools over these rows.
          text_outputs: encoder outputs indexed by word position.
          context_pre_starts: first word of the left context, [num_mentions].
          context_pos_ends: last word (inclusive) of the right context,
            [num_mentions].

        Returns:
          Tensor [num_mentions, max_ant + 1]. Column 0 is a fixed 0 score for
          the "no antecedent" option.

        Side effects:
          Sets ``self.mention_emb_shape`` and ``self.mention_start_shape``
          (shape tensors kept for inspection).
        """
        num_mentions = util.shape(mention_emb, 0)
        max_antecedents = util.shape(antecedents, 1)

        # Integer half-width of the context window. Plain `/` is true division
        # on Python 3: tf.range() would then produce float32 and adding the
        # int32 start indices below would fail.
        half_context_width = self.config["max_context_width"] // 2
        # Total number of context slots (left half + right half). This equals
        # max_context_width for even values and keeps all tensors consistent
        # for odd values.
        context_window = 2 * half_context_width

        feature_emb_list = []

        if self.config["use_metadata"]:
            antecedent_speaker_ids = tf.gather(
                mention_speaker_ids, antecedents)  # [num_mentions, max_ant]
            same_speaker = tf.equal(
                tf.expand_dims(mention_speaker_ids, 1),
                antecedent_speaker_ids)  # [num_mentions, max_ant]
            speaker_pair_emb = tf.gather(
                tf.get_variable("same_speaker_emb",
                                [2, self.config["feature_size"]]),
                tf.to_int32(same_speaker))  # [num_mentions, max_ant, emb]
            feature_emb_list.append(speaker_pair_emb)

            tiled_genre_emb = tf.tile(
                tf.expand_dims(tf.expand_dims(genre_emb, 0), 0),
                [num_mentions, max_antecedents, 1
                 ])  # [num_mentions, max_ant, emb]
            feature_emb_list.append(tiled_genre_emb)

        if self.config["use_features"]:
            target_indices = tf.range(num_mentions)  # [num_mentions]
            mention_distance = tf.expand_dims(
                target_indices, 1) - antecedents  # [num_mentions, max_ant]
            mention_distance_bins = coref_ops.distance_bins(
                mention_distance)  # [num_mentions, max_ant]
            mention_distance_bins.set_shape([None, None])
            mention_distance_emb = tf.gather(
                tf.get_variable("mention_distance_emb",
                                [10, self.config["feature_size"]]),
                mention_distance_bins)  # [num_mentions, max_ant, emb]
            feature_emb_list.append(mention_distance_emb)

        # NOTE: fails if neither "use_metadata" nor "use_features" is set,
        # because tf.concat rejects an empty list.
        feature_emb = tf.concat(feature_emb_list,
                                2)  # [num_mentions, max_ant, emb]
        feature_emb = tf.nn.dropout(
            feature_emb, self.dropout)  # [num_mentions, max_ant, emb]

        ########### Context Embeddings #################

        # Left context spans [context_pre_starts, mention_starts - 1];
        # right context spans [mention_ends + 1, context_pos_ends].
        context_pos_starts = mention_ends + 1

        context_pre_width = mention_starts - context_pre_starts  # [num_mentions]
        context_pos_width = context_pos_ends - mention_ends  # [num_mentions]

        context_start_emb = tf.gather(text_outputs, context_pre_starts)
        context_end_emb = tf.gather(text_outputs, context_pos_ends)

        # Context boundary summary and the mention embedding, each repeated
        # once per context slot so they can be concatenated per word.
        context_output = tf.concat([context_start_emb, context_end_emb], 1)
        context_output = tf.tile(tf.expand_dims(context_output, 1),
                                 [1, context_window, 1])

        mention_output = tf.tile(tf.expand_dims(mention_emb, 1),
                                 [1, context_window, 1])

        max_word_index = util.shape(text_outputs, 0) - 1

        context_pre_indices = tf.expand_dims(
            tf.range(half_context_width), 0) + tf.expand_dims(
                context_pre_starts, 1)  # [num_mentions, half_context_width]
        context_pre_indices = tf.minimum(
            max_word_index,
            context_pre_indices)  # [num_mentions, half_context_width]
        context_pre_mask = tf.expand_dims(
            tf.sequence_mask(context_pre_width,
                             half_context_width,
                             dtype=tf.float32),
            2)  # [num_mentions, half_context_width, 1]

        context_pos_indices = tf.expand_dims(
            tf.range(half_context_width), 0) + tf.expand_dims(
                context_pos_starts, 1)  # [num_mentions, half_context_width]
        context_pos_indices = tf.minimum(
            max_word_index,
            context_pos_indices)  # [num_mentions, half_context_width]
        context_pos_mask = tf.expand_dims(
            tf.sequence_mask(context_pos_width,
                             half_context_width,
                             dtype=tf.float32),
            2)  # [num_mentions, half_context_width, 1]

        context_indices = tf.concat([context_pre_indices, context_pos_indices],
                                    1)  # [num_mentions, context_window]
        context_mask = tf.concat([context_pre_mask, context_pos_mask],
                                 1)  # [num_mentions, context_window, 1]

        context_glove_emb = tf.gather(
            text_emb, context_indices)  # [num_mentions, context_window, emb]

        context_att_score = util.projection_name(
            tf.concat([context_glove_emb, context_output, mention_output], 2),
            1, "context_att")  # [num_mentions, context_window, 1]

        # log(0) = -inf removes padding slots from the softmax.
        # NOTE(review): if a mention has no context on either side, every mask
        # entry is 0 and the softmax row becomes NaN; confirm upstream always
        # provides at least one context word.
        context_attention = tf.nn.softmax(
            context_att_score + tf.log(context_mask),
            dim=1)  # [num_mentions, context_window, 1]

        context_emb = tf.reduce_sum(context_attention * context_glove_emb,
                                    1)  # [num_mentions, emb]

        mention_emb = tf.concat([context_emb, mention_emb], 1)

        ################################################

        antecedent_emb = tf.gather(mention_emb,
                                   antecedents)  # [num_mentions, max_ant, emb]
        self.mention_emb_shape = tf.shape(mention_emb)
        self.mention_start_shape = tf.shape(antecedents)
        target_emb_tiled = tf.tile(
            tf.expand_dims(mention_emb, 1),
            [1, max_antecedents, 1])  # [num_mentions, max_ant, emb]
        similarity_emb = antecedent_emb * target_emb_tiled  # [num_mentions, max_ant, emb]

        pair_emb = tf.concat(
            [target_emb_tiled, antecedent_emb, similarity_emb, feature_emb],
            2)  # [num_mentions, max_ant, emb]

        with tf.variable_scope("iteration"):
            with tf.variable_scope("antecedent_scoring"):
                antecedent_scores = util.ffnn(
                    pair_emb, self.config["ffnn_depth"],
                    self.config["ffnn_size"], 1,
                    self.dropout)  # [num_mentions, max_ant, 1]
        antecedent_scores = tf.squeeze(antecedent_scores,
                                       2)  # [num_mentions, max_ant]

        # -inf for padded candidate slots beyond antecedents_len.
        antecedent_mask = tf.log(
            tf.sequence_mask(antecedents_len,
                             max_antecedents,
                             dtype=tf.float32))  # [num_mentions, max_ant]
        antecedent_scores += antecedent_mask  # [num_mentions, max_ant]

        antecedent_scores += tf.expand_dims(mention_scores, 1) + tf.gather(
            mention_scores, antecedents)  # [num_mentions, max_ant]
        # Prepend a constant 0 score for the "no antecedent" option.
        antecedent_scores = tf.concat(
            [tf.zeros([util.shape(mention_scores, 0), 1]), antecedent_scores],
            1)  # [num_mentions, max_ant + 1]
        return antecedent_scores  # [num_mentions, max_ant + 1]