コード例 #1
0
  def get_antecedent_features(self,
                              mention_emb,
                              mention_scores,
                              antecedents,
                              antecedents_len,
                              mention_starts,
                              mention_ends,
                              mention_speaker_ids,
                              genre_emb):
    """Build the pairwise feature embedding phi(i, j) for every candidate pair.

    Depending on ``self.config`` the following pieces are concatenated along
    the last axis:
      * ``use_metadata``: a learned same-speaker embedding (``same_speaker_emb``
        table, row 0 = different speaker, row 1 = same speaker) and
        ``genre_emb`` broadcast to every (mention, antecedent) pair;
      * ``use_features``: a learned embedding (``mention_distance_emb`` table,
        10 rows) of the mention-index distance bucketed by
        ``coref_ops.distance_bins``.
    Dropout with ``self.dropout`` is applied to the concatenation.

    ``mention_scores``, ``antecedents_len``, ``mention_starts`` and
    ``mention_ends`` are accepted for signature parity but are not read.

    Args:
      mention_emb: mention embeddings; only the leading dimension is used.
      antecedents: [num_mentions, max_ant] candidate antecedent indices.
      mention_speaker_ids: per-mention speaker ids.
      genre_emb: genre embedding vector.

    Returns:
      [num_mentions, max_ant, emb] pairwise feature tensor.
    """
    n_mentions = util.shape(mention_emb, 0)
    n_antecedents = util.shape(antecedents, 1)
    pair_features = []

    if self.config["use_metadata"]:
      antecedent_speakers = tf.gather(mention_speaker_ids, antecedents)  # [num_mentions, max_ant]
      is_same_speaker = tf.equal(tf.expand_dims(mention_speaker_ids, 1),
                                 antecedent_speakers)  # [num_mentions, max_ant]
      speaker_table = tf.get_variable("same_speaker_emb",
                                      [2, self.config["feature_size"]])
      # bool -> {0, 1} selects a row of the lookup table.
      pair_features.append(
          tf.gather(speaker_table, tf.to_int32(is_same_speaker)))  # [num_mentions, max_ant, emb]

      # The same genre vector is repeated for every (mention, antecedent) pair.
      genre_3d = tf.expand_dims(tf.expand_dims(genre_emb, 0), 0)  # [1, 1, emb]
      pair_features.append(
          tf.tile(genre_3d, [n_mentions, n_antecedents, 1]))  # [num_mentions, max_ant, emb]

    if self.config["use_features"]:
      # Difference in mention index between each mention and each candidate.
      distances = tf.expand_dims(tf.range(n_mentions), 1) - antecedents  # [num_mentions, max_ant]
      distance_buckets = coref_ops.distance_bins(distances)  # [num_mentions, max_ant]
      # Pin the rank; presumably the custom op's static output shape is unknown.
      distance_buckets.set_shape([None, None])
      distance_table = tf.get_variable("mention_distance_emb",
                                       [10, self.config["feature_size"]])
      pair_features.append(
          tf.gather(distance_table, distance_buckets))  # [num_mentions, max_ant, emb]

    # NOTE(review): with both feature groups disabled pair_features is empty
    # and tf.concat will presumably fail.
    combined = tf.concat(pair_features, 2)  # [num_mentions, max_ant, emb]
    return tf.nn.dropout(combined, self.dropout)  # [num_mentions, max_ant, emb]
コード例 #2
0
  def get_antecedent_scores(self, mention_emb, mention_scores, antecedents, antecedents_len, mention_starts, mention_ends, mention_speaker_ids, genre_emb, mention_ner_ids):
    """Score every mention against each of its candidate antecedents.

    For mention i and candidate antecedent j the pair representation is
    [g_i, g_j, g_i * g_j, phi(i, j)]. Here g_i is the mention's own embedding
    (row i of ``mention_emb``) and g_j is the candidate's embedding, gathered
    through ``antecedents``. phi(i, j) is the concatenation of the enabled
    pairwise features, with dropout applied:
      * ``use_metadata``: same-speaker embedding + genre embedding,
      * ``use_features``: embedding of the bucketed mention-index distance,
      * ``use_ner_phi``: same-NER-id embedding.

    ``util.ffnn`` maps each pair to a scalar. Slots at index >=
    ``antecedents_len`` then receive -inf (the log of a 0/1 sequence mask),
    both unary mention scores are added, and a column of zeros is prepended
    at index 0.

    ``mention_starts`` and ``mention_ends`` are accepted but not read.

    Returns:
      A tuple ``(antecedent_scores, pair_emb)``:
        antecedent_scores: [num_mentions, max_ant + 1]; column 0 is the zero column.
        pair_emb: [num_mentions, max_ant, emb] pair representations fed to the FFNN.
    """
    n_mentions = util.shape(mention_emb, 0)
    n_ant = util.shape(antecedents, 1)

    def _same_id_emb(var_name, ids):
      # Embed whether each candidate shares its id with the mention
      # (table row 0 = different, row 1 = equal).
      antecedent_ids = tf.gather(ids, antecedents)  # [num_mentions, max_ant]
      matches = tf.equal(tf.expand_dims(ids, 1), antecedent_ids)  # [num_mentions, max_ant]
      table = tf.get_variable(var_name, [2, self.config["feature_size"]])
      return tf.gather(table, tf.to_int32(matches))  # [num_mentions, max_ant, emb]

    phi_parts = []

    if self.config["use_metadata"]:
      phi_parts.append(_same_id_emb("same_speaker_emb", mention_speaker_ids))
      # The same genre vector is repeated for every (mention, antecedent) pair.
      genre_grid = tf.tile(tf.expand_dims(tf.expand_dims(genre_emb, 0), 0),
                           [n_mentions, n_ant, 1])  # [num_mentions, max_ant, emb]
      phi_parts.append(genre_grid)

    if self.config["use_features"]:
      offsets = tf.expand_dims(tf.range(n_mentions), 1) - antecedents  # [num_mentions, max_ant]
      offset_bins = coref_ops.distance_bins(offsets)  # [num_mentions, max_ant]
      offset_bins.set_shape([None, None])
      distance_table = tf.get_variable("mention_distance_emb", [10, self.config["feature_size"]])
      phi_parts.append(tf.gather(distance_table, offset_bins))  # [num_mentions, max_ant, emb]

    if self.config["use_ner_phi"]:
      phi_parts.append(_same_id_emb("same_ner_emb", mention_ner_ids))

    # phi(i, j)
    phi = tf.nn.dropout(tf.concat(phi_parts, 2), self.dropout)  # [num_mentions, max_ant, emb]

    g_j = tf.gather(mention_emb, antecedents)  # candidate embeddings [num_mentions, max_ant, emb]
    g_i = tf.tile(tf.expand_dims(mention_emb, 1), [1, n_ant, 1])  # mention embeddings [num_mentions, max_ant, emb]
    pair_emb = tf.concat([g_i, g_j, g_j * g_i, phi], 2)  # [g_i, g_j, g_i * g_j, phi(i, j)]

    with tf.variable_scope("iteration"), tf.variable_scope("antecedent_scoring"):
      raw_scores = util.ffnn(pair_emb, self.config["ffnn_depth"], self.config["ffnn_size"], 1, self.dropout)  # [num_mentions, max_ant, 1]
    scores = tf.squeeze(raw_scores, 2)  # [num_mentions, max_ant]

    # log(1) = 0 for real candidates, log(0) = -inf for slots >= antecedents_len.
    scores += tf.log(tf.sequence_mask(antecedents_len, n_ant, dtype=tf.float32))
    scores += tf.expand_dims(mention_scores, 1) + tf.gather(mention_scores, antecedents)
    zero_column = tf.zeros([util.shape(mention_scores, 0), 1])
    return tf.concat([zero_column, scores], 1), pair_emb
コード例 #3
0
ファイル: coref_model.py プロジェクト: qq547276542/e2e-coref
  def get_antecedent_scores(self, mention_emb, mention_scores, antecedents, antecedents_len, mention_starts, mention_ends, mention_speaker_ids, genre_emb):
    """Score every mention against each of its candidate antecedents.

    The pair representation for mention i and candidate j is
    [g_i, g_j, g_i * g_j, phi(i, j)]. phi(i, j) concatenates:
      * the same-speaker and genre embeddings (``use_metadata``),
      * the bucketed mention-distance embedding (``use_features``).
    Dropout is applied to phi(i, j).

    A feed-forward net (``util.ffnn``, under scopes
    ``iteration/antecedent_scoring``) turns each pair into a scalar. Slots at
    index >= ``antecedents_len`` get -inf, and both unary mention scores are
    added.

    ``mention_starts`` and ``mention_ends`` are accepted but not read.

    Returns:
      [num_mentions, max_ant + 1] scores; column 0 is a column of zeros.
    """
    mention_count = util.shape(mention_emb, 0)
    candidate_count = util.shape(antecedents, 1)

    phi_terms = []
    if self.config["use_metadata"]:
      candidate_speakers = tf.gather(mention_speaker_ids, antecedents)  # [num_mentions, max_ant]
      shares_speaker = tf.equal(tf.expand_dims(mention_speaker_ids, 1), candidate_speakers)  # [num_mentions, max_ant]
      speaker_lookup = tf.get_variable("same_speaker_emb", [2, self.config["feature_size"]])
      phi_terms += [
          # row 0 = different speaker, row 1 = same speaker
          tf.gather(speaker_lookup, tf.to_int32(shares_speaker)),  # [num_mentions, max_ant, emb]
          # genre vector repeated for every pair
          tf.tile(tf.expand_dims(tf.expand_dims(genre_emb, 0), 0),
                  [mention_count, candidate_count, 1]),  # [num_mentions, max_ant, emb]
      ]

    if self.config["use_features"]:
      gaps = tf.expand_dims(tf.range(mention_count), 1) - antecedents  # [num_mentions, max_ant]
      gap_bins = coref_ops.distance_bins(gaps)  # [num_mentions, max_ant]
      gap_bins.set_shape([None, None])
      gap_lookup = tf.get_variable("mention_distance_emb", [10, self.config["feature_size"]])
      phi_terms.append(tf.gather(gap_lookup, gap_bins))  # [num_mentions, max_ant, emb]

    phi = tf.nn.dropout(tf.concat(phi_terms, 2), self.dropout)  # [num_mentions, max_ant, emb]

    candidate_emb = tf.gather(mention_emb, antecedents)  # g_j: [num_mentions, max_ant, emb]
    anchor_emb = tf.tile(tf.expand_dims(mention_emb, 1), [1, candidate_count, 1])  # g_i: [num_mentions, max_ant, emb]
    pair_emb = tf.concat([anchor_emb, candidate_emb, candidate_emb * anchor_emb, phi], 2)

    with tf.variable_scope("iteration"), tf.variable_scope("antecedent_scoring"):
      pair_scores = util.ffnn(pair_emb, self.config["ffnn_depth"], self.config["ffnn_size"], 1, self.dropout)  # [num_mentions, max_ant, 1]
    pair_scores = tf.squeeze(pair_scores, 2)  # [num_mentions, max_ant]

    # 0 for real candidates, -inf for slots at index >= antecedents_len.
    pair_scores += tf.log(tf.sequence_mask(antecedents_len, candidate_count, dtype=tf.float32))
    pair_scores += tf.expand_dims(mention_scores, 1) + tf.gather(mention_scores, antecedents)
    zero_column = tf.zeros([util.shape(mention_scores, 0), 1])
    return tf.concat([zero_column, pair_scores], 1)  # [num_mentions, max_ant + 1]
コード例 #4
0
  def get_context_antecedent_scores(self,
                                    mention_emb,
                                    mention_scores,
                                    antecedents,
                                    antecedents_len,
                                    mention_starts,
                                    mention_ends,
                                    mention_speaker_ids,
                                    genre_emb,
                                    context_starts,
                                    context_ends,
                                    text_outputs,
                                    text_emb):
    """Score candidate antecedents with co-attention over span tokens and context windows.

    First the pairwise features phi(i, j) are built: same speaker + genre
    under ``use_metadata``, bucketed mention distance under ``use_features``.
    Then token positions and ``text_outputs`` representations are gathered for:
      * each mention span: up to ``max_mention_width`` tokens starting at
        ``mention_starts``, masked to ``mention_ends - mention_starts + 1``;
      * each context window: up to ``max_context_width`` tokens starting at
        ``context_starts``, masked to ``context_ends - context_starts + 1``.
    Token indices past the last row of ``text_outputs`` are clipped to it.

    Four attention steps run in sequence. Each re-computes one representation
    as a weighted sum of ``text_emb`` rows, conditioned on the other three:
      C_a: antecedent context, C_t: target (mention) context,
      M_t: target mention tokens, M_a: antecedent mention tokens.
    Later steps consume the outputs of earlier steps, so statement order
    matters.

    The pair embedding concatenates:
      * the start/end ``text_outputs`` of both spans,
      * the four attended vectors,
      * the element-wise product of the *initial* (pre-attention) mention
        embeddings,
      * phi(i, j).
    ``util.ffnn`` scores it. Slots at index >= ``antecedents_len`` get -inf,
    both unary mention scores are added, and a zero column is prepended.

    Side effects: sets ``self.num_words`` to ``tf.shape(text_outputs)`` (the
    full shape tensor, not a scalar count) and sets ``self.num_mentions``.

    NOTE(review): all four attention steps call ``util.projection_name`` with
    the same name ``'att_score'``. Whether they share weights (or raise on
    variable re-creation) depends on that helper, which is not visible here;
    confirm.

    Returns:
      [num_mentions, max_ant + 1] antecedent scores; column 0 is the zero column.
    """
    num_mentions = util.shape(mention_emb, 0)
    max_antecedents = util.shape(antecedents, 1)

    # NOTE(review): tf.shape returns the full shape of text_outputs, not a scalar word count.
    self.num_words = tf.shape(text_outputs)
    self.num_mentions = num_mentions

    feature_emb_list = []

    if self.config["use_metadata"]:
      antecedent_speaker_ids = tf.gather(mention_speaker_ids, antecedents) # [num_mentions, max_ant]
      same_speaker = tf.equal(tf.expand_dims(mention_speaker_ids, 1), antecedent_speaker_ids) # [num_mentions, max_ant]
      # bool -> {0, 1}: row 0 = different speaker, row 1 = same speaker.
      speaker_pair_emb = tf.gather(tf.get_variable("same_speaker_emb", [2, self.config["feature_size"]]), tf.to_int32(same_speaker)) # [num_mentions, max_ant, emb]
      feature_emb_list.append(speaker_pair_emb)

      # The same genre vector is repeated for every (mention, antecedent) pair.
      tiled_genre_emb = tf.tile(tf.expand_dims(tf.expand_dims(genre_emb, 0), 0), [num_mentions, max_antecedents, 1]) # [num_mentions, max_ant, emb]
      feature_emb_list.append(tiled_genre_emb)

    if self.config["use_features"]:
      target_indices = tf.range(num_mentions) # [num_mentions]
      # Difference in mention index between each mention and each candidate.
      mention_distance = tf.expand_dims(target_indices, 1) - antecedents # [num_mentions, max_ant]
      mention_distance_bins = coref_ops.distance_bins(mention_distance) # [num_mentions, max_ant]
      mention_distance_bins.set_shape([None, None])
      mention_distance_emb = tf.gather(tf.get_variable("mention_distance_emb", [10, self.config["feature_size"]]), mention_distance_bins) # [num_mentions, max_ant, emb]
      feature_emb_list.append(mention_distance_emb)

    feature_emb = tf.concat(feature_emb_list, 2) # [num_mentions, max_ant, emb]
    feature_emb = tf.nn.dropout(feature_emb, self.dropout) # [num_mentions, max_ant, emb]


    #############################
    #
    # Get matrix for co-attention
    #
    #############################
    

    ####### Mention Level #######
    
    mention_start_emb = tf.gather(text_outputs, mention_starts) # [num_mentions, emb]
    mention_end_emb = tf.gather(text_outputs, mention_ends) # [num_mentions, emb]

    # Span boundary features: [start token output, end token output].
    mention_features = tf.concat([mention_start_emb, mention_end_emb], 1) # [num_mentions, 2 * emb]
    
    mention_width = 1 + mention_ends - mention_starts # [num_mentions]
    # Token positions of each span, padded to max_mention_width and clipped to the last token.
    mention_indices = tf.expand_dims(tf.range(self.config["max_mention_width"]), 0) + tf.expand_dims(mention_starts, 1) # [num_mentions, max_mention_width]
    mention_indices = tf.minimum(util.shape(text_outputs, 0) - 1, mention_indices) # [num_mentions, max_mention_width]
    mention_mask = tf.expand_dims(tf.sequence_mask(mention_width, self.config["max_mention_width"], dtype=tf.float32), 2) # [num_mentions, max_mention_width, 1]

    # Span tokens of every candidate antecedent.
    antecedent_indices = tf.gather(mention_indices, antecedents) # [num_mentions, max_ant, max_mention_width]
    antecedent_mask = tf.gather(mention_mask, antecedents) # [num_mentions, max_ant, max_mention_width, 1]
    antecedent_indices_emb = tf.gather(text_outputs, antecedent_indices) # [num_mentions, max_ant, max_mention_width, emb]

    # Span tokens of the target mention, repeated for each of its candidates.
    # (Rebinds target_indices; the earlier value was only used for mention_distance.)
    target_indices = tf.tile(tf.expand_dims(mention_indices, 1), [1, max_antecedents, 1]) # [num_mentions, max_ant, max_mention_width]
    target_mask = tf.tile(tf.expand_dims(mention_mask, 1), [1, max_antecedents, 1, 1]) # [num_mentions, max_ant, max_mention_width, 1]
    target_indices_emb = tf.gather(text_outputs, target_indices) # [num_mentions, max_ant, max_mention_width, emb]


    ####### Context Level #######

    context_start_emb = tf.gather(text_outputs, context_starts) # [num_mentions, emb]
    context_end_emb = tf.gather(text_outputs, context_ends) # [num_mentions, emb]

    context_width = 1 + context_ends - context_starts # [num_mentions]
    # Token positions of each context window, padded to max_context_width and clipped to the last token.
    context_indices = tf.expand_dims(tf.range(self.config["max_context_width"]), 0) + tf.expand_dims(context_starts, 1) # [num_mentions, max_context_width]
    context_indices = tf.minimum(util.shape(text_outputs, 0) - 1, context_indices) # [num_mentions, max_context_width]
    context_mask = tf.expand_dims(tf.sequence_mask(context_width, self.config["max_context_width"], dtype=tf.float32), 2) # [num_mentions, max_context_width, 1]

    antecedent_context_indices = tf.gather(context_indices, antecedents) # [num_mentions, max_ant, max_context_width]
    antecedent_context_mask = tf.gather(context_mask, antecedents) # [num_mentions, max_ant, max_context_width, 1]
    antecedent_context_indices_emb = tf.gather(text_outputs, antecedent_context_indices) # [num_mentions, max_ant, max_context_width, emb]

    target_context_indices = tf.tile(tf.expand_dims(context_indices, 1), [1, max_antecedents, 1]) # [num_mentions, max_ant, max_context_width]
    target_context_mask = tf.tile(tf.expand_dims(context_mask, 1), [1, max_antecedents, 1, 1]) # [num_mentions, max_ant, max_context_width, 1]
    target_context_indices_emb = tf.gather(text_outputs, target_context_indices) # [num_mentions, max_ant, max_context_width, emb]


    #### Initial Embeddings #####
    
    antecedent_emb = tf.gather(mention_emb, antecedents) # [num_mentions, max_ant, emb]
    target_emb_tiled = tf.tile(tf.expand_dims(mention_emb, 1), [1, max_antecedents, 1]) # [num_mentions, max_ant, emb]
    
    # Context window boundary features: [start token output, end token output].
    context_emb = tf.concat([context_start_emb, context_end_emb], 1) # [num_mentions, 2 * emb]

    antecedent_context_emb = tf.gather(context_emb, antecedents) # [num_mentions, max_ant, emb]
    target_context_emb_tiled = tf.tile(tf.expand_dims(context_emb, 1), [1, max_antecedents, 1]) # [num_mentions, max_ant, emb]

    # Uses the initial embeddings. The attention steps below rebind
    # antecedent_emb / target_emb_tiled, but this product is not recomputed.
    similarity_emb = antecedent_emb * target_emb_tiled # [num_mentions, max_ant, emb]
    

    #############################
    #
    # Calculate Co-attention
    #
    #############################
    # Shape comments below assume util.projection_name projects the last axis
    # to the requested size -- TODO confirm against util.


    ###### C_a Attention ########
    # Antecedent context attended over its context-window tokens, conditioned
    # on both (initial) mention embeddings and the initial target context.

    window_emb = tf.concat([antecedent_emb, target_emb_tiled, target_context_emb_tiled], 2)
    window_scores = util.projection_name(window_emb, 100, 'c_a_window') # [num_mentions, max_ant, 100]
    # Repeat the conditioning projection for every context position.
    window_scores = tf.tile(tf.expand_dims(window_scores, 2), [1, 1, self.config['max_context_width'], 1]) # [num_mentions, max_ant, max_context_width, 100]

    target_scores = util.projection_name(antecedent_context_indices_emb, 100, 'c_a_target') # [num_mentions, max_ant, max_context_width, 100]

    temp_scores = util.projection_name(window_scores + target_scores, 1, 'att_score') # [num_mentions, max_ant, max_context_width, 1]

    # log(mask) adds -inf at masked positions so they receive zero weight.
    temp_att = tf.nn.softmax(temp_scores + tf.log(antecedent_context_mask), dim=2) # [num_mentions, max_ant, max_context_width, 1]
    # Weighted sum over text_emb rows (scores above were computed from text_outputs).
    antecedent_context_emb = tf.reduce_sum(temp_att * tf.gather(text_emb, antecedent_context_indices), 2)


    ###### C_t Attention ########
    # Target context attended over its window, conditioned on the attended C_a output.

    window_emb = tf.concat([antecedent_emb, target_emb_tiled, antecedent_context_emb], 2)
    window_scores = util.projection_name(window_emb, 100, 'c_t_window')
    window_scores = tf.tile(tf.expand_dims(window_scores, 2), [1, 1, self.config['max_context_width'], 1]) # [num_mentions, max_ant, max_context_width, 100]

    target_scores = util.projection_name(target_context_indices_emb, 100, 'c_t_target')

    temp_scores = util.projection_name(window_scores + target_scores, 1, 'att_score') 

    temp_att = tf.nn.softmax(temp_scores + tf.log(target_context_mask), dim=2) # [num_mentions, max_ant, max_context_width, 1]
    target_context_emb_tiled = tf.reduce_sum(temp_att * tf.gather(text_emb, target_context_indices), 2)

    
    ###### M_t Attention ########
    # Target mention attended over its own span tokens, conditioned on the
    # (still initial) antecedent embedding and both attended contexts.

    window_emb = tf.concat([antecedent_emb, antecedent_context_emb, target_context_emb_tiled], 2)
    window_scores = util.projection_name(window_emb, 100, 'm_t_window')
    window_scores = tf.tile(tf.expand_dims(window_scores, 2), [1, 1, self.config['max_mention_width'], 1]) # [num_mentions, max_ant, max_mention_width, 100]

    target_scores = util.projection_name(target_indices_emb, 100, 'm_t_target')

    temp_scores = util.projection_name(window_scores + target_scores, 1, 'att_score')
    
    temp_att = tf.nn.softmax(temp_scores + tf.log(target_mask), dim=2) # [num_mentions, max_ant, max_mention_width, 1]
    target_emb_tiled = tf.reduce_sum(temp_att * tf.gather(text_emb, target_indices), 2)


    ###### M_a Attention ########
    # Antecedent mention attended over its span tokens, conditioned on the
    # attended target mention and both attended contexts.

    window_emb = tf.concat([target_emb_tiled, target_context_emb_tiled, antecedent_context_emb], 2)
    window_scores = util.projection_name(window_emb, 100, 'm_a_window')
    window_scores = tf.tile(tf.expand_dims(window_scores, 2), [1, 1, self.config['max_mention_width'], 1]) # [num_mentions, max_ant, max_mention_width, 100]

    target_scores = util.projection_name(antecedent_indices_emb, 100, 'm_a_target')

    temp_scores = util.projection_name(window_scores + target_scores, 1, 'att_score')
  
    temp_att = tf.nn.softmax(temp_scores + tf.log(antecedent_mask), dim=2) # [num_mentions, max_ant, max_mention_width, 1]
    antecedent_emb = tf.reduce_sum(temp_att * tf.gather(text_emb, antecedent_indices), 2)
    

    #############################
    #
    # Calculate Pair Embeddings
    #
    #############################

    antecedent_feature = tf.gather(mention_features, antecedents) # [num_mentions, max_ant, emb]
    target_feature = tf.tile(tf.expand_dims(mention_features, 1), [1, max_antecedents, 1]) # [num_mentions, max_ant, emb]
    # similarity_emb = antecedent_emb * target_emb_tiled # [num_mentions, max_ant, emb]

    # pair_emb = tf.concat([target_emb_tiled_1, antecedent_emb_1, similarity_emb, feature_emb], 2) # [num_mentions, max_ant, emb]

    # Boundary features, attended mentions, attended contexts, initial
    # similarity and phi(i, j), all along the last axis.
    pair_emb = tf.concat([
                          target_feature,
                          target_emb_tiled,
                          antecedent_feature,
                          antecedent_emb,
                          antecedent_context_emb,
                          target_context_emb_tiled,
                          similarity_emb,
                          feature_emb], 2)

    # NOTE(review): the triple-quoted string below is an unused expression
    # statement (an alternative pair_emb); it has no runtime effect.
    '''
    pair_emb = tf.nn.relu(util.projection_name(target_emb_tiled, self.config['ffnn_size'], 'comp_mt') +\
                util.projection_name(antecedent_emb, self.config['ffnn_size'], 'comp_ma') +\
                util.projection_name(antecedent_context_emb_1, self.config['ffnn_size'], 'comp_ca') +\
                util.projection_name(target_context_emb_tiled_1, self.config['ffnn_size'], 'comp_ct') +\
                util.projection_name(similarity_emb, self.config['ffnn_size'], 'comp_sim') +\
                util.projection_name(feature_emb, self.config['ffnn_size'], 'comp_feature'))
    '''

    #############################

    with tf.variable_scope("iteration"):
      with tf.variable_scope("antecedent_scoring"):
        antecedent_scores = util.ffnn(pair_emb, self.config["ffnn_depth"], self.config["ffnn_size"], 1, self.dropout) # [num_mentions, max_ant, 1]
    antecedent_scores = tf.squeeze(antecedent_scores, 2) # [num_mentions, max_ant]

    # Rebinds antecedent_mask: 0 for real candidates, -inf for slots >= antecedents_len.
    antecedent_mask = tf.log(tf.sequence_mask(antecedents_len, max_antecedents, dtype=tf.float32)) # [num_mentions, max_ant]
    antecedent_scores += antecedent_mask # [num_mentions, max_ant]

    antecedent_scores += tf.expand_dims(mention_scores, 1) + tf.gather(mention_scores, antecedents) # [num_mentions, max_ant]
    antecedent_scores = tf.concat([tf.zeros([util.shape(mention_scores, 0), 1]), antecedent_scores], 1) # [num_mentions, max_ant + 1]
    return antecedent_scores  # [num_mentions, max_ant + 1]
コード例 #5
0
    def get_antecedent_scores(self, mention_emb, mention_scores, antecedents,
                              antecedents_len, mention_starts, mention_ends,
                              mention_speaker_ids, genre_emb, k):
        """Score antecedents by delegating to ``self.rgcn_tagging``.

        Builds the pairwise feature embedding phi(i, j) with dropout applied:
        same speaker + genre under ``use_metadata``, bucketed mention distance
        under ``use_features``. Then adds a leading axis of size 1 to
        ``mention_emb`` and passes it, together with ``mention_scores``, the
        features and ``k``, to ``self.rgcn_tagging``.

        ``antecedents_len``, ``mention_starts`` and ``mention_ends`` are kept
        for interface compatibility but are not read.

        Side effects: stores the result on ``self.scores``.

        Returns:
          The value returned by ``self.rgcn_tagging``. NOTE(review): its shape
          is determined by rgcn_tagging, which is not visible here -- confirm.
        """
        num_mentions = util.shape(mention_emb, 0)
        max_antecedents = util.shape(antecedents, 1)

        feature_emb_list = []

        if self.config["use_metadata"]:
            antecedent_speaker_ids = tf.gather(
                mention_speaker_ids, antecedents)  # [num_mentions, max_ant]
            same_speaker = tf.equal(
                tf.expand_dims(mention_speaker_ids, 1),
                antecedent_speaker_ids)  # [num_mentions, max_ant]
            # bool -> {0, 1}: row 0 = different speaker, row 1 = same speaker.
            speaker_pair_emb = tf.gather(
                tf.get_variable("same_speaker_emb",
                                [2, self.config["feature_size"]]),
                tf.to_int32(same_speaker))  # [num_mentions, max_ant, emb]
            feature_emb_list.append(speaker_pair_emb)

            # The same genre vector is repeated for every pair.
            tiled_genre_emb = tf.tile(
                tf.expand_dims(tf.expand_dims(genre_emb, 0), 0),
                [num_mentions, max_antecedents, 1
                 ])  # [num_mentions, max_ant, emb]
            feature_emb_list.append(tiled_genre_emb)

        if self.config["use_features"]:
            target_indices = tf.range(num_mentions)  # [num_mentions]
            # Difference in mention index between each mention and each candidate.
            mention_distance = tf.expand_dims(
                target_indices, 1) - antecedents  # [num_mentions, max_ant]
            mention_distance_bins = coref_ops.distance_bins(
                mention_distance)  # [num_mentions, max_ant]
            mention_distance_bins.set_shape([None, None])
            mention_distance_emb = tf.gather(
                tf.get_variable("mention_distance_emb",
                                [10, self.config["feature_size"]]),
                mention_distance_bins)  # [num_mentions, max_ant, emb]
            feature_emb_list.append(mention_distance_emb)

        feature_emb = tf.concat(feature_emb_list,
                                2)  # [num_mentions, max_ant, emb]
        feature_emb = tf.nn.dropout(
            feature_emb, self.dropout)  # [num_mentions, max_ant, emb]

        # Add a leading axis of size 1 before handing the spans to rgcn_tagging.
        span_emb = tf.expand_dims(mention_emb, 0)
        antecedent_scores = self.rgcn_tagging(span_emb, mention_scores,
                                              feature_emb, k)
        self.scores = antecedent_scores
        return antecedent_scores
コード例 #6
0
    def get_antecedent_scores(self, mention_emb, mention_scores, antecedents,
                              antecedents_len, mention_starts, mention_ends,
                              mention_speaker_ids, genre_emb, text_emb,
                              text_outputs, context_pre_starts,
                              context_pos_ends):
        """Score antecedents after augmenting each mention with an attended context vector.

        Each mention gets a context window made of two halves of at most
        ``max_context_width // 2`` tokens each:
          * the tokens before it (``context_pre_starts`` .. ``mention_starts - 1``);
          * the tokens after it (``mention_ends + 1`` .. ``context_pos_ends``).
        Attention over the ``text_emb`` rows of those tokens yields a context
        vector, which is concatenated in front of ``mention_emb``. The
        attention scores use the token embedding, the ``text_outputs`` at the
        two window boundaries, and the mention embedding.

        Scoring then follows the usual pattern:
          * [g_i, g_j, g_i * g_j, phi(i, j)] -> ``util.ffnn``;
          * -inf for slots at index >= ``antecedents_len``;
          * both unary mention scores added;
          * a zero column prepended.

        Side effects: sets ``self.mention_emb_shape`` (``tf.shape`` of the
        context-augmented mention embeddings) and ``self.mention_start_shape``
        (``tf.shape(antecedents)``).

        Returns:
          [num_mentions, max_ant + 1] antecedent scores; column 0 is the zero
          column.
        """
        num_mentions = util.shape(mention_emb, 0)
        max_antecedents = util.shape(antecedents, 1)

        feature_emb_list = []

        if self.config["use_metadata"]:
            antecedent_speaker_ids = tf.gather(
                mention_speaker_ids, antecedents)  # [num_mentions, max_ant]
            same_speaker = tf.equal(
                tf.expand_dims(mention_speaker_ids, 1),
                antecedent_speaker_ids)  # [num_mentions, max_ant]
            # bool -> {0, 1}: row 0 = different speaker, row 1 = same speaker.
            speaker_pair_emb = tf.gather(
                tf.get_variable("same_speaker_emb",
                                [2, self.config["feature_size"]]),
                tf.to_int32(same_speaker))  # [num_mentions, max_ant, emb]
            feature_emb_list.append(speaker_pair_emb)

            tiled_genre_emb = tf.tile(
                tf.expand_dims(tf.expand_dims(genre_emb, 0), 0),
                [num_mentions, max_antecedents, 1
                 ])  # [num_mentions, max_ant, emb]
            feature_emb_list.append(tiled_genre_emb)

        if self.config["use_features"]:
            target_indices = tf.range(num_mentions)  # [num_mentions]
            mention_distance = tf.expand_dims(
                target_indices, 1) - antecedents  # [num_mentions, max_ant]
            mention_distance_bins = coref_ops.distance_bins(
                mention_distance)  # [num_mentions, max_ant]
            mention_distance_bins.set_shape([None, None])
            mention_distance_emb = tf.gather(
                tf.get_variable("mention_distance_emb",
                                [10, self.config["feature_size"]]),
                mention_distance_bins)  # [num_mentions, max_ant, emb]
            feature_emb_list.append(mention_distance_emb)

        feature_emb = tf.concat(feature_emb_list,
                                2)  # [num_mentions, max_ant, emb]
        feature_emb = tf.nn.dropout(
            feature_emb, self.dropout)  # [num_mentions, max_ant, emb]

        ########### Context Embeddings #################

        # Tokens per side of the window. Floor division is required: on
        # Python 3 "/" yields a float, which makes tf.range emit float32
        # (breaking the int32 index arithmetic) and gives tf.sequence_mask a
        # float maxlen.
        half_width = self.config["max_context_width"] // 2
        # Total attended positions (pre half + post half). This equals
        # max_context_width when it is even. Tiling to this value keeps the
        # per-token concat below shape-consistent when it is odd.
        window_width = 2 * half_width
        last_token = util.shape(text_outputs, 0) - 1

        context_pos_starts = mention_ends + 1
        context_pre_width = mention_starts - context_pre_starts  # tokens before the mention
        context_pos_width = context_pos_ends - mention_ends  # tokens after the mention

        def _half_window(starts, widths):
            # Clipped token indices [num_mentions, half_width] and float mask
            # [num_mentions, half_width, 1] for one side of the window.
            indices = tf.expand_dims(tf.range(half_width), 0) + tf.expand_dims(
                starts, 1)
            indices = tf.minimum(last_token, indices)
            mask = tf.expand_dims(
                tf.sequence_mask(widths, half_width, dtype=tf.float32), 2)
            return indices, mask

        context_start_emb = tf.gather(text_outputs, context_pre_starts)
        context_end_emb = tf.gather(text_outputs, context_pos_ends)

        # Window boundary outputs and the mention embedding, repeated for
        # every attended position so they can be concatenated per token.
        context_output = tf.concat([context_start_emb, context_end_emb], 1)
        context_output = tf.tile(tf.expand_dims(context_output, 1),
                                 [1, window_width, 1])
        mention_output = tf.tile(tf.expand_dims(mention_emb, 1),
                                 [1, window_width, 1])

        context_pre_indices, context_pre_mask = _half_window(
            context_pre_starts, context_pre_width)
        context_pos_indices, context_pos_mask = _half_window(
            context_pos_starts, context_pos_width)

        context_indices = tf.concat([context_pre_indices, context_pos_indices],
                                    1)  # [num_mentions, window_width]
        context_mask = tf.concat([context_pre_mask, context_pos_mask],
                                 1)  # [num_mentions, window_width, 1]

        context_glove_emb = tf.gather(text_emb, context_indices)

        context_att_score = util.projection_name(
            tf.concat([context_glove_emb, context_output, mention_output], 2),
            1, "context_att")

        # log(mask) is -inf outside the real context, so those positions get
        # zero weight. NOTE(review): if a mention has no context on either
        # side, every position is masked and the softmax yields NaN.
        context_attention = tf.nn.softmax(
            context_att_score + tf.log(context_mask),
            dim=1)  # [num_mentions, window_width, 1]

        context_emb = tf.reduce_sum(context_attention * context_glove_emb,
                                    1)  # [num_mentions, emb]

        mention_emb = tf.concat([context_emb, mention_emb], 1)

        ################################################

        antecedent_emb = tf.gather(mention_emb,
                                   antecedents)  # [num_mentions, max_ant, emb]
        self.mention_emb_shape = tf.shape(mention_emb)
        self.mention_start_shape = tf.shape(antecedents)
        target_emb_tiled = tf.tile(
            tf.expand_dims(mention_emb, 1),
            [1, max_antecedents, 1])  # [num_mentions, max_ant, emb]
        similarity_emb = antecedent_emb * target_emb_tiled  # [num_mentions, max_ant, emb]

        pair_emb = tf.concat(
            [target_emb_tiled, antecedent_emb, similarity_emb, feature_emb],
            2)  # [num_mentions, max_ant, emb]

        with tf.variable_scope("iteration"):
            with tf.variable_scope("antecedent_scoring"):
                antecedent_scores = util.ffnn(
                    pair_emb, self.config["ffnn_depth"],
                    self.config["ffnn_size"], 1,
                    self.dropout)  # [num_mentions, max_ant, 1]
        antecedent_scores = tf.squeeze(antecedent_scores,
                                       2)  # [num_mentions, max_ant]

        # 0 for real candidates, -inf for slots at index >= antecedents_len.
        antecedent_mask = tf.log(
            tf.sequence_mask(antecedents_len,
                             max_antecedents,
                             dtype=tf.float32))  # [num_mentions, max_ant]
        antecedent_scores += antecedent_mask  # [num_mentions, max_ant]

        antecedent_scores += tf.expand_dims(mention_scores, 1) + tf.gather(
            mention_scores, antecedents)  # [num_mentions, max_ant]
        antecedent_scores = tf.concat(
            [tf.zeros([util.shape(mention_scores, 0), 1]), antecedent_scores],
            1)  # [num_mentions, max_ant + 1]
        return antecedent_scores  # [num_mentions, max_ant + 1]